Skip to content

Commit 2b46bcb

Browse files
authored
Update ddpm.py
clean up no.1
1 parent 417d5f3 commit 2b46bcb

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

ldm/models/diffusion/ddpm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def __init__(self,
461461
self.instantiate_cond_stage(cond_stage_config)
462462
self.cond_stage_forward = cond_stage_forward
463463
self.clip_denoised = False
464-
self.bbox_tokenizer = None # # TODO: special class?
464+
self.bbox_tokenizer = None
465465

466466
self.restarted_from_ckpt = False
467467
if ckpt_path is not None:
@@ -598,7 +598,7 @@ def get_weighting(self, h, w, Ly, Lx, device):
598598
weighting = weighting * L_weighting
599599
return weighting
600600

601-
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code !
601+
def get_fold_unfold(self, x, kernel_size, stride, uf=1, df=1): # todo load once not every time, shorten code
602602
"""
603603
:param x: img of size (bs, c, h, w)
604604
:return: n img crops of size (n, bs, c, kernel_size[0], kernel_size[1])
@@ -793,7 +793,7 @@ def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_qua
793793
z = z.view((z.shape[0], -1, ks[0], ks[1], z.shape[-1])) # (bn, nc, ks[0], ks[1], L )
794794

795795
# 2. apply model loop over last dim
796-
if isinstance(self.first_stage_model, VQModelInterface): # todo ask what this is
796+
if isinstance(self.first_stage_model, VQModelInterface):
797797
output_list = [self.first_stage_model.decode(z[:, :, :, :, i],
798798
force_not_quantize=predict_cids or force_not_quantize)
799799
for i in range(z.shape[-1])]
@@ -901,7 +901,7 @@ def apply_model(self, x_noisy, t, cond, return_ids=False):
901901

902902
if hasattr(self, "split_input_params"):
903903
assert len(cond) == 1 # todo can only deal with one conditioning atm
904-
assert not return_ids # todo dont know what this is -> I exclude --> Good
904+
assert not return_ids
905905
ks = self.split_input_params["ks"] # eg. (128, 128)
906906
stride = self.split_input_params["stride"] # eg. (64, 64)
907907

0 commit comments

Comments
 (0)