From 73a1aa1384bc1198329710b6c2e22c1b7ef7d182 Mon Sep 17 00:00:00 2001
From: mr-asa
Date: Wed, 27 Aug 2025 19:14:24 +0300
Subject: [PATCH 1/2] fix: resolve device mismatch in T5TextEmbedder (#1)

---
 model.py | 25 ++++++++++++++++++-------
 1 file changed, 18 insertions(+), 7 deletions(-)

diff --git a/model.py b/model.py
index 9fbd12e..1816c56 100644
--- a/model.py
+++ b/model.py
@@ -135,6 +135,8 @@ def load_model(self):
 
     def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
         self.load_model()
+        model_device = self.model.device
+
         # remove a1111/comfyui prompt weight, t5 embedder currently does not accept weight
         caption = remove_weights(caption)
         if max_length is None:
@@ -152,14 +154,23 @@ def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length
                 )
         else:
             text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)
-            text_input_ids = text_inputs.input_ids
-            attention_mask = text_inputs.attention_mask
-        text_input_ids = text_input_ids.to(self.model.device)  # type: ignore
-        attention_mask = attention_mask.to(self.model.device)  # type: ignore
-        outputs = self.model(text_input_ids, attention_mask=attention_mask)  # type: ignore
-
+
+            # Ensure tensors are on the correct device
+            text_input_ids = text_inputs.input_ids.to(model_device)
+            attention_mask = text_inputs.attention_mask.to(model_device)
+        else:
+            # Ensure provided tensors are on the correct device
+            text_input_ids = text_input_ids.to(model_device)
+            attention_mask = attention_mask.to(model_device)
+
+        # Ensure model is on the correct device
+        self.model.to(model_device)
+
+        outputs = self.model(text_input_ids, attention_mask=attention_mask)
+
+        # Move output to the specified output device
         return outputs.last_hidden_state.to(self.output_device)
-
+
 
 class TimestepEmbedding(nn.Module):
     def __init__(

From 3c70a83192c934fc9caf870e7f4c41c5309a3e5b Mon Sep 17 00:00:00 2001
From: mr-asa
Date: Mon, 25 Aug 2025 13:03:31 +0300
Subject: [PATCH 2/2] feat: add device and dtype alignment helpers for ELLA model

---
 ella.py  | 88 +++++++++++++++++++++++++++++++++++++++++++++-----------
 model.py | 59 +++++++++++++++++++++++++++++++++++--
 2 files changed, 128 insertions(+), 19 deletions(-)

diff --git a/ella.py b/ella.py
index ca3ee39..fbb79d8 100644
--- a/ella.py
+++ b/ella.py
@@ -29,7 +29,37 @@
 current_paths, _ = folder_paths.folder_names_and_paths["ella_encoder"]
 folder_paths.folder_names_and_paths["ella_encoder"] = (current_paths, folder_paths.supported_pt_extensions)
 
-
+# === device/dtype alignment helpers ===
+def _infer_float_dtype_from_embeds(d: dict):
+    import torch
+    for v in d.values():
+        if torch.is_tensor(v) and v.is_floating_point():
+            return v.dtype
+        if isinstance(v, (list, tuple)):
+            for t in v:
+                if torch.is_tensor(t) and t.is_floating_point():
+                    return t.dtype
+        if isinstance(v, dict):
+            dt = _infer_float_dtype_from_embeds(v)
+            if dt is not None:
+                return dt
+    return None
+
+def _align_to_model_device_dtype(x, device, dtype):
+    import torch
+    if x is None:
+        return None
+    if torch.is_tensor(x):
+        if x.is_floating_point():
+            return x.to(device=device, dtype=dtype, non_blocking=True)
+        return x.to(device=device, non_blocking=True)
+    if isinstance(x, (list, tuple)):
+        return type(x)(_align_to_model_device_dtype(xx, device, dtype) for xx in x)
+    if isinstance(x, dict):
+        return {k: _align_to_model_device_dtype(v, device, dtype) for k, v in x.items()}
+    return x
+
+# === /helpers ===
 def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
     num_steps = len(timesteps) - 1
     # print(f"creating ELLA conds for {num_steps} timesteps")
@@ -39,7 +69,14 @@ def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
         start = i / num_steps  # Start percentage is calculated based on the index
         end = (i + 1) / num_steps  # End percentage is calculated based on the next index
 
-        cond_ella = ella(timestep, **embeds)
+        # cond_ella = ella(timestep, **embeds)
+        # align dtype/device to ELLA model
+        device = getattr(ella, "output_device", timesteps.device)
+        want_dtype = _infer_float_dtype_from_embeds(embeds) or torch.float16
+        _t = timestep.to(device=device, dtype=want_dtype)
+        _embeds = _align_to_model_device_dtype(embeds, device, want_dtype)
+
+        cond_ella = ella(_t, **_embeds)
         cond_ella_dict = {"start_percent": start, "end_percent": end}
         conds.append([cond_ella, cond_ella_dict])
 
@@ -69,13 +106,24 @@ def __init__(
             self.embeds[i][k] = CONDCrossAttn(self.embeds[i][k])
 
     def process_cond(self, embeds: Dict[str, CONDCrossAttn], batch_size, **kwargs):
-        return {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
+        # return {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
+        out = {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
+        # align floats to a common dtype inferred from outputs (or fallback fp16)
+        want_dtype = _infer_float_dtype_from_embeds(out) or torch.float16
+        return _align_to_model_device_dtype(out, self.ella.output_device, want_dtype)
 
     def prepare_conds(self):
+
         cond_embeds = self.process_cond(self.embeds[0], 1)
-        cond = self.ella(torch.Tensor([999]), **cond_embeds)
+        want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16
+        t999 = torch.tensor([999.0], device=self.ella.output_device, dtype=want_dtype)
+        cond = self.ella(t999, **cond_embeds)
+
         uncond_embeds = self.process_cond(self.embeds[1], 1)
-        uncond = self.ella(torch.Tensor([999]), **uncond_embeds)
+        # same dtype for consistency
+        t999u = t999
+        uncond = self.ella(t999u, **uncond_embeds)
+
         if self.mode == APPLY_MODE_ELLA_ONLY:
             return cond, uncond
         if "clip_embeds" not in cond_embeds or "clip_embeds" not in uncond_embeds:
@@ -94,22 +142,28 @@ def __call__(self, apply_model, kwargs: dict):
         _device = c["c_crossattn"].device
 
         time_aware_encoder_hidden_states = []
-        for i in cond_or_uncond:
+        # get the dtype of the target model from the cond-data of the first group
+        # (process_cond has already aligned device to self.ella.output_device)
+        for idx, i in enumerate(cond_or_uncond):
             cond_embeds = self.process_cond(self.embeds[i], input_x.size(0) // len(cond_or_uncond))
-            h = self.ella(
-                self.model_sampling.timestep(timestep_[0]),
-                **cond_embeds,
-            )
-            if self.mode == APPLY_MODE_ELLA_ONLY:
+            want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16
+
+            # timestep from sampler can be on CPU and in fp32 - we will align it
+            t_model = self.model_sampling.timestep(timestep_[0])
+            t_model = t_model.to(device=self.ella.output_device, dtype=want_dtype)
+
+            h = self.ella(t_model, **cond_embeds)
+
+            if self.mode == APPLY_MODE_ELLA_ONLY or "clip_embeds" not in cond_embeds:
                 time_aware_encoder_hidden_states.append(h)
-                continue
-            if "clip_embeds" not in cond_embeds:
+            else:
+                h = torch.concat([h, cond_embeds["clip_embeds"]], dim=1)
                 time_aware_encoder_hidden_states.append(h)
-                continue
-            h = torch.concat([h, cond_embeds["clip_embeds"]], dim=1)
-            time_aware_encoder_hidden_states.append(h)
 
-        c["c_crossattn"] = torch.cat(time_aware_encoder_hidden_states, dim=0).to(_device)
+        # build a batch and move it under the downstream-UNet device
+        hidden = torch.cat(time_aware_encoder_hidden_states, dim=0)
+        c["c_crossattn"] = hidden.to(_device)
+
         return apply_model(input_x, timestep_, **c)
 
 
diff --git a/model.py b/model.py
index 1816c56..921f36a 100644
--- a/model.py
+++ b/model.py
@@ -3,6 +3,8 @@
 from typing import Optional
 
 import torch
+import os
+from safetensors.torch import load_file
 from comfy import model_management
 from comfy.model_patcher import ModelPatcher
 from safetensors.torch import load_model
@@ -12,6 +14,59 @@
 from .utils import patch_device_empty_setter, remove_weights
 
+ELLA_DEBUG = os.getenv("ELLA_DEBUG", "0") in ("1", "true", "True")
+
+def _count_params(m: torch.nn.Module):
+    total = sum(p.numel() for p in m.parameters())
+    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
+    return total, trainable
+
+def _size_mb(m: torch.nn.Module):
+    # estimate the size of parameters in MB
+    bytes_total = sum(p.numel() * p.element_size() for p in m.parameters())
+    return bytes_total / (1024 ** 2)
+
+
+
+def load_model_lenient(model: torch.nn.Module, path: str):
+    sd_file = load_file(path)  # dict[name -> Tensor]
+    model_sd = model.state_dict()
+
+    new_sd = {}
+    skipped_shape = []
+    extra = []
+    casted = []
+
+    for k, v in sd_file.items():
+        if k in model_sd:
+            if model_sd[k].shape == v.shape:
+                if model_sd[k].dtype != v.dtype:
+                    v = v.to(model_sd[k].dtype)
+                    casted.append(k)
+                # transfer to the parameter device
+                if v.device != model_sd[k].device:
+                    v = v.to(model_sd[k].device)
+                new_sd[k] = v
+            else:
+                skipped_shape.append((k, tuple(v.shape), tuple(model_sd[k].shape)))
+        else:
+            extra.append(k)
+
+    if skipped_shape:
+        print(f"[ELLA/load] skipped by shape: {len(skipped_shape)} (e.g. {skipped_shape[:3]})")
+    if extra:
+        print(f"[ELLA/load] extra keys in ckpt: {len(extra)} (e.g. {extra[:5]})")
+    if casted:
+        print(f"[ELLA/load] dtype casted: {len(casted)} (e.g. {casted[:5]})")
+
+    missing = [k for k in model_sd.keys() if k not in new_sd]
+    if missing:
+        print(f"[ELLA/load] missing in ckpt: {len(missing)} (e.g. {missing[:5]})")
+
+    model.load_state_dict(new_sd, strict=False)
+    return model
+
+
 
 class AdaLayerNorm(nn.Module):
     def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None):
         super().__init__()
@@ -327,8 +382,8 @@ def __init__(self, path: str, **kwargs) -> None:
         self.dtype = model_management.text_encoder_dtype(self.load_device)
         self.output_device = model_management.intermediate_device()
         self.model = ELLAModel()
-        load_model(self.model, path, strict=True)
-        self.model.to(self.dtype)  # type: ignore
+        load_model_lenient(self.model, path)
+        self.model.to(dtype=torch.float16)
         self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)
 
     def load_model(self):
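
Note (illustration, not part of the patches above): a minimal sketch of how the two ella.py helpers are expected to behave, assuming `_infer_float_dtype_from_embeds` and `_align_to_model_device_dtype` are in scope; the dict key names, shapes, and target device below are made up for the example. The first floating-point tensor found decides the working dtype, floating tensors are moved and cast to it, and non-float tensors such as attention masks only change device.

    import torch

    # Hypothetical embeds dict; key names and shapes are illustrative only.
    embeds = {
        "embeds": torch.randn(1, 64, 2048, dtype=torch.float16),      # first float tensor decides the dtype
        "clip_embeds": torch.randn(1, 77, 768, dtype=torch.float32),  # will be cast to match
        "attention_mask": torch.ones(1, 64, dtype=torch.long),        # integer tensor: device move only
    }

    target_device = torch.device("cpu")  # stand-in for ella.output_device
    want_dtype = _infer_float_dtype_from_embeds(embeds) or torch.float16  # -> torch.float16 here
    aligned = _align_to_model_device_dtype(embeds, target_device, want_dtype)

    assert aligned["clip_embeds"].dtype == torch.float16     # floats follow the inferred dtype
    assert aligned["attention_mask"].dtype == torch.long     # non-float tensors keep their dtype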