diff --git a/cldm/cldm_unicontrol.py b/cldm/cldm_unicontrol.py index b776cc7..b0bf99f 100644 --- a/cldm/cldm_unicontrol.py +++ b/cldm/cldm_unicontrol.py @@ -436,7 +436,7 @@ def get_input(self, batch, k, bs=None, *args, **kwargs): batch['txt'] = batch['txt'] + [self.mapping_task[task_name]] x, c_all = super().get_input(batch, self.first_stage_key, *args, **kwargs) - c, c_task = c_all[:BS,:,:], c_all[BS:,:1,:] + c, c_task = c_all[:BS,:,:], c_all[BS:BS+1,:1,:] control = batch[self.control_key] if bs is not None: control = control[:bs] diff --git a/cldm/cldm_unicontrol_v11.py b/cldm/cldm_unicontrol_v11.py index 3a049f9..a2093c9 100644 --- a/cldm/cldm_unicontrol_v11.py +++ b/cldm/cldm_unicontrol_v11.py @@ -436,7 +436,7 @@ def get_input(self, batch, k, bs=None, *args, **kwargs): batch['txt'] = batch['txt'] + [self.mapping_task[task_name]] x, c_all = super().get_input(batch, self.first_stage_key, *args, **kwargs) - c, c_task = c_all[:BS,:,:], c_all[BS:,:1,:] + c, c_task = c_all[:BS,:,:], c_all[BS:BS+1,:1,:] control = batch[self.control_key] if bs is not None: control = control[:bs]