@@ -18,6 +18,7 @@

 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

@@ -330,6 +331,19 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)

+        if attention_mask is None and encoder_hidden_states_mask is not None:
+            # The joint sequence is [text, image].
+            seq_len_img = hidden_states.shape[1]
+            img_mask = torch.ones(
+                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
+            )
+            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
+
+            # Convert the mask to the format expected by SDPA
+            attention_mask = attention_mask[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=joint_query.dtype)
+            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
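To make the intent of the new mask handling easier to follow, here is a minimal, self-contained sketch of the same transformation: a per-token text mask is extended with an all-ones image mask, broadcast to `[batch, 1, 1, seq]`, and converted into an additive bias in which padded positions receive the most negative value of the query dtype. All shapes, values, and tensor names below are illustrative, not taken from the model.

import torch
import torch.nn.functional as F

# Illustrative sizes: 2 prompts, 5 text tokens, 7 image tokens, 16-dim heads.
batch, txt_len, img_len, dim = 2, 5, 7, 16
query = torch.randn(batch, 1, txt_len + img_len, dim)
key = torch.randn_like(query)
value = torch.randn_like(query)

# 1 = real token, 0 = padding; the second prompt has only 3 real text tokens.
text_mask = torch.tensor([[1.0, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
img_mask = torch.ones(batch, img_len)
joint_mask = torch.cat([text_mask, img_mask], dim=1)  # [batch, txt_len + img_len]

# Broadcast over heads and query positions, then make the mask additive:
# kept positions add 0, padded positions add the dtype's most negative value.
bias = (1.0 - joint_mask[:, None, None, :]) * torch.finfo(query.dtype).min

out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
print(out.shape)  # torch.Size([2, 1, 12, 16])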
@@ -600,6 +614,16 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            if world_size > 1 and encoder_hidden_states is not None:
+                seq_len = encoder_hidden_states.shape[1]
+                pad_len = (world_size - seq_len % world_size) % world_size
+                if pad_len > 0:
+                    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
+                    if encoder_hidden_states_mask is not None:
+                        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
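As a quick reference for the padding logic above, the sketch below reproduces the round-up-to-a-multiple arithmetic and the `F.pad` calls outside of any process group, so the text sequence can later be split evenly across ranks. `world_size` is hard-coded and all shapes are made up for illustration.

import torch
import torch.nn.functional as F

world_size = 4                                    # stand-in for dist.get_world_size()
encoder_hidden_states = torch.randn(2, 13, 64)    # [batch, seq_len, dim], illustrative
encoder_hidden_states_mask = torch.ones(2, 13)

seq_len = encoder_hidden_states.shape[1]
pad_len = (world_size - seq_len % world_size) % world_size   # 13 -> pad by 3 to reach 16
if pad_len > 0:
    # F.pad pads from the last dimension backwards: (0, 0) leaves the feature dim
    # untouched, (0, pad_len) appends pad_len zero rows to the sequence dimension.
    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
    # Mark the padded tokens as invalid so attention can ignore them.
    encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)

print(encoder_hidden_states.shape)    # torch.Size([2, 16, 64])
print(encoder_hidden_states_mask[0])  # ones for the 13 real tokens, zeros for the 3 pads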
@@ -630,7 +654,13 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )

-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # Use the shape of the (possibly padded) encoder hidden states to generate the rotary embeddings.
+        if encoder_hidden_states is not None:
+            recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            recalculated_txt_seq_lens = txt_seq_lens
+
+        image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
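Because the text tokens may just have been padded, the `txt_seq_lens` passed in by the caller can be shorter than `encoder_hidden_states.shape[1]`; rebuilding the lengths from the padded tensor keeps the text portion of the rotary embeddings aligned with the image portion of the joint sequence. A tiny illustration with made-up numbers:

import torch

# Pretend the caller passed 13- and 11-token prompts that were padded to 16.
encoder_hidden_states = torch.randn(2, 16, 64)
txt_seq_lens = [13, 11]

# One entry per batch element, all equal to the padded text length.
recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
print(recalculated_txt_seq_lens)  # [16, 16]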