Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions model.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -135,6 +135,8 @@ def load_model(self):

def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length=None, **kwargs):
self.load_model()
model_device = self.model.device

# Remove A1111/ComfyUI prompt weights; the T5 embedder currently does not accept weights
caption = remove_weights(caption)
if max_length is None:
Expand All @@ -152,14 +154,23 @@ def __call__(self, caption, text_input_ids=None, attention_mask=None, max_length
)
else:
text_inputs = self.tokenizer(caption, return_tensors="pt", add_special_tokens=True)
text_input_ids = text_inputs.input_ids
attention_mask = text_inputs.attention_mask
text_input_ids = text_input_ids.to(self.model.device) # type: ignore
attention_mask = attention_mask.to(self.model.device) # type: ignore
outputs = self.model(text_input_ids, attention_mask=attention_mask) # type: ignore


# Ensure tensors are on the correct device
text_input_ids = text_inputs.input_ids.to(model_device)
attention_mask = text_inputs.attention_mask.to(model_device)
else:
# Ensure provided tensors are on the correct device
text_input_ids = text_input_ids.to(model_device)
attention_mask = attention_mask.to(model_device)

# Ensure model is on the correct device
self.model.to(model_device)

outputs = self.model(text_input_ids, attention_mask=attention_mask)

# Move output to the specified output device
return outputs.last_hidden_state.to(self.output_device)


class TimestepEmbedding(nn.Module):
def __init__(
Expand Down