@@ -358,6 +358,8 @@ def bundle_resources_for_swift_cli(args):
358358from transformers .models .clip import modeling_clip
359359
360360# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
361+ # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
362+ # For backward compatibility with versions < 4.35.0, both functions are patched here.
361363def patched_make_causal_mask (input_ids_shape , dtype , device , past_key_values_length : int = 0 ):
362364 """ Patch to replace torch.finfo(dtype).min with -1e4
363365 """
@@ -370,8 +372,9 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
370372 if past_key_values_length > 0 :
371373 mask = torch .cat ([torch .zeros (tgt_len , past_key_values_length , dtype = dtype , device = device ), mask ], dim = - 1 )
372374 return mask [None , None , :, :].expand (bsz , 1 , tgt_len , tgt_len + past_key_values_length )
373-
374- modeling_clip ._make_causal_mask = patched_make_causal_mask
375+
376+ modeling_clip ._make_causal_mask = patched_make_causal_mask # For transformers >= 4.30.0 and transformers < 4.35.0
377+ modeling_clip ._create_4d_causal_attention_mask = patched_make_causal_mask # For transformers >= 4.35.0
375378
376379def convert_text_encoder (text_encoder , tokenizer , submodule_name , args ):
377380 """ Converts the text encoder component of Stable Diffusion
0 commit comments