88 changes: 71 additions & 17 deletions ella.py
@@ -29,7 +29,37 @@
current_paths, _ = folder_paths.folder_names_and_paths["ella_encoder"]
folder_paths.folder_names_and_paths["ella_encoder"] = (current_paths, folder_paths.supported_pt_extensions)


# === device/dtype alignment helpers ===
def _infer_float_dtype_from_embeds(d: dict):
    """Recursively find the first floating-point dtype among the embed tensors."""
    for v in d.values():
        if torch.is_tensor(v) and v.is_floating_point():
            return v.dtype
        if isinstance(v, (list, tuple)):
            for t in v:
                if torch.is_tensor(t) and t.is_floating_point():
                    return t.dtype
        if isinstance(v, dict):
            dt = _infer_float_dtype_from_embeds(v)
            if dt is not None:
                return dt
    return None

def _align_to_model_device_dtype(x, device, dtype):
    """Recursively move tensors to `device`, casting floating-point tensors to `dtype`."""
    if x is None:
        return None
    if torch.is_tensor(x):
        if x.is_floating_point():
            return x.to(device=device, dtype=dtype, non_blocking=True)
        return x.to(device=device, non_blocking=True)
    if isinstance(x, (list, tuple)):
        return type(x)(_align_to_model_device_dtype(xx, device, dtype) for xx in x)
    if isinstance(x, dict):
        return {k: _align_to_model_device_dtype(v, device, dtype) for k, v in x.items()}
    return x

# === /helpers ===
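For reference, a minimal sketch of how the two helpers compose; the shapes, names, and devices below are illustrative, not taken from this PR:

import torch

embeds = {
    "clip_embeds": torch.randn(1, 77, 768, dtype=torch.float32),
    "mask": torch.ones(1, 77, dtype=torch.bool),  # non-float: device moves, dtype is preserved
}
want = _infer_float_dtype_from_embeds(embeds)  # torch.float32 (first float tensor found)
aligned = _align_to_model_device_dtype(embeds, "cpu", torch.float16)
assert aligned["clip_embeds"].dtype == torch.float16
assert aligned["mask"].dtype == torch.bool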
def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
    num_steps = len(timesteps) - 1
    # print(f"creating ELLA conds for {num_steps} timesteps")
@@ -39,7 +69,14 @@ def ella_encode(ella: ELLA, timesteps: torch.Tensor, embeds: dict):
        start = i / num_steps  # Start percentage is calculated based on the index
        end = (i + 1) / num_steps  # End percentage is calculated based on the next index

        # cond_ella = ella(timestep, **embeds)
        # align dtype/device to the ELLA model
        device = getattr(ella, "output_device", timesteps.device)
        want_dtype = _infer_float_dtype_from_embeds(embeds) or torch.float16
        _t = timestep.to(device=device, dtype=want_dtype)
        _embeds = _align_to_model_device_dtype(embeds, device, want_dtype)

        cond_ella = ella(_t, **_embeds)

        cond_ella_dict = {"start_percent": start, "end_percent": end}
        conds.append([cond_ella, cond_ella_dict])
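A hedged usage sketch (the function's `return conds` falls outside this hunk; `ella` and `embeds` are assumed to come from the loader and T5 encoder nodes):

timesteps = torch.linspace(999, 0, steps=11)  # 10 scheduling intervals
conds = ella_encode(ella, timesteps, embeds)
# each entry: [cond_tensor, {"start_percent": ..., "end_percent": ...}]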
@@ -69,13 +106,24 @@ def __init__(
                self.embeds[i][k] = CONDCrossAttn(self.embeds[i][k])

    def process_cond(self, embeds: Dict[str, CONDCrossAttn], batch_size, **kwargs):
        # return {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
        out = {k: v.process_cond(batch_size, self.ella.output_device, **kwargs).cond for k, v in embeds.items()}
        # align floats to a common dtype inferred from the outputs, falling back to fp16
        want_dtype = _infer_float_dtype_from_embeds(out) or torch.float16
        return _align_to_model_device_dtype(out, self.ella.output_device, want_dtype)

    def prepare_conds(self):
        cond_embeds = self.process_cond(self.embeds[0], 1)
        want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16
        t999 = torch.tensor([999.0], device=self.ella.output_device, dtype=want_dtype)
        cond = self.ella(t999, **cond_embeds)

        uncond_embeds = self.process_cond(self.embeds[1], 1)
        # reuse the same timestep tensor so cond and uncond share device and dtype
        uncond = self.ella(t999, **uncond_embeds)

        if self.mode == APPLY_MODE_ELLA_ONLY:
            return cond, uncond
        if "clip_embeds" not in cond_embeds or "clip_embeds" not in uncond_embeds:
@@ -94,22 +142,28 @@ def __call__(self, apply_model, kwargs: dict):
        _device = c["c_crossattn"].device

        time_aware_encoder_hidden_states = []
        # infer the working dtype from each group's cond embeds;
        # process_cond has already aligned them to self.ella.output_device
        for i in cond_or_uncond:
            cond_embeds = self.process_cond(self.embeds[i], input_x.size(0) // len(cond_or_uncond))
            want_dtype = _infer_float_dtype_from_embeds(cond_embeds) or torch.float16

            # the sampler may hand over a CPU fp32 timestep; align it to the model
            t_model = self.model_sampling.timestep(timestep_[0])
            t_model = t_model.to(device=self.ella.output_device, dtype=want_dtype)

            h = self.ella(t_model, **cond_embeds)

            if self.mode == APPLY_MODE_ELLA_ONLY or "clip_embeds" not in cond_embeds:
                time_aware_encoder_hidden_states.append(h)
                continue
            h = torch.concat([h, cond_embeds["clip_embeds"]], dim=1)
            time_aware_encoder_hidden_states.append(h)

        # batch the per-cond hidden states and move them to the downstream UNet's device
        hidden = torch.cat(time_aware_encoder_hidden_states, dim=0)
        c["c_crossattn"] = hidden.to(_device)


        return apply_model(input_x, timestep_, **c)
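The `__call__(apply_model, kwargs)` signature matches ComfyUI's UNet function wrapper protocol, so an instance of this class is presumably attached roughly like this (wiring assumed, not shown in this diff):

patched = model.clone()  # model: a comfy ModelPatcher
patched.set_model_unet_function_wrapper(ella_proxy)  # ella_proxy: instance of the class above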

84 changes: 75 additions & 9 deletions model.py
@@ -3,6 +3,8 @@
from typing import Optional

import torch
import os
from safetensors.torch import load_file
from comfy import model_management
from comfy.model_patcher import ModelPatcher
from safetensors.torch import load_model
@@ -12,6 +14,59 @@
from .utils import patch_device_empty_setter, remove_weights


ELLA_DEBUG = os.getenv("ELLA_DEBUG", "0") in ("1", "true", "True")

def _count_params(m: torch.nn.Module):
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable

def _size_mb(m: torch.nn.Module):
    # estimated parameter size in MiB
    bytes_total = sum(p.numel() * p.element_size() for p in m.parameters())
    return bytes_total / (1024 ** 2)
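These two diagnostics pair naturally with the `ELLA_DEBUG` flag above; a sketch of the intended call site (not part of this diff):

if ELLA_DEBUG:
    total, trainable = _count_params(model)
    print(f"[ELLA/debug] params: {total:,} ({trainable:,} trainable), ~{_size_mb(model):.1f} MiB")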



def load_model_lenient(model: torch.nn.Module, path: str):
    """Load a safetensors checkpoint, skipping mismatched-shape keys and casting dtypes to match the model."""
    sd_file = load_file(path)  # dict[name -> Tensor]
    model_sd = model.state_dict()

    new_sd = {}
    skipped_shape = []
    extra = []
    casted = []

    for k, v in sd_file.items():
        if k in model_sd:
            if model_sd[k].shape == v.shape:
                if model_sd[k].dtype != v.dtype:
                    v = v.to(model_sd[k].dtype)
                    casted.append(k)
                # move to the parameter's device
                if v.device != model_sd[k].device:
                    v = v.to(model_sd[k].device)
                new_sd[k] = v
            else:
                skipped_shape.append((k, tuple(v.shape), tuple(model_sd[k].shape)))
        else:
            extra.append(k)

    if skipped_shape:
        print(f"[ELLA/load] skipped by shape: {len(skipped_shape)} (e.g. {skipped_shape[:3]})")
    if extra:
        print(f"[ELLA/load] extra keys in ckpt: {len(extra)} (e.g. {extra[:5]})")
    if casted:
        print(f"[ELLA/load] dtype casted: {len(casted)} (e.g. {casted[:5]})")

    missing = [k for k in model_sd.keys() if k not in new_sd]
    if missing:
        print(f"[ELLA/load] missing in ckpt: {len(missing)} (e.g. {missing[:5]})")

    model.load_state_dict(new_sd, strict=False)
    return model
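A hedged usage sketch of the lenient loader (the path is illustrative):

model = ELLAModel()
load_model_lenient(model, "/models/ella/ella-sd1.5-tsc-t5xl.safetensors")
# keys that match by name and shape are loaded; everything else is reported and skipped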


class AdaLayerNorm(nn.Module):
    def __init__(self, embedding_dim: int, time_embedding_dim: Optional[int] = None):
        super().__init__()
@@ -135,6 +190,8 @@ def load_model(self):

    def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
        self.load_model()
        model_device = self.model.device

        # remove a1111/comfyui prompt weights; the t5 embedder currently does not accept weights
        caption = remove_weights(caption)
        if max_length is None:
@@ -152,14 +209,23 @@ def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
                )
            else:
                text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)

            # move the freshly tokenized tensors to the model's device
            text_input_ids = text_inputs.input_ids.to(model_device)
            attention_mask = text_inputs.attention_mask.to(model_device)
        else:
            # move caller-provided tensors to the model's device
            text_input_ids = text_input_ids.to(model_device)
            attention_mask = attention_mask.to(model_device)

        outputs = self.model(text_input_ids, attention_mask=attention_mask)

        # move the output to the configured output device
        return outputs.last_hidden_state.to(self.output_device)
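A hedged usage sketch, assuming `t5` is an instance of this embedder:

embeds = t5("a corgi wearing a party hat")
print(embeds.shape, embeds.device)  # (1, seq_len, hidden_size) on self.output_device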


class TimestepEmbedding(nn.Module):
    def __init__(
@@ -316,8 +382,8 @@ def __init__(self, path: str, **kwargs) -> None:
        self.dtype = model_management.text_encoder_dtype(self.load_device)
        self.output_device = model_management.intermediate_device()
        self.model = ELLAModel()
        # lenient load; the model is cast to fp16 regardless of self.dtype
        load_model_lenient(self.model, path)
        self.model.to(dtype=torch.float16)
        self.patcher = ModelPatcher(self.model, load_device=self.load_device, offload_device=self.offload_device)

    def load_model(self):