@@ -322,6 +322,8 @@ def bundle_resources_for_swift_cli(args):
322322from transformers .models .clip import modeling_clip
323323
324324# Copied from https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/clip/modeling_clip.py#L677C1-L692C1
325+ # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
326+ # For backward compatibility with versions < 4.35.0, both functions are patched here.
325327def patched_make_causal_mask (input_ids_shape , dtype , device , past_key_values_length : int = 0 ):
326328 """ Patch to replace torch.finfo(dtype).min with -1e4
327329 """
@@ -334,9 +336,7 @@ def patched_make_causal_mask(input_ids_shape, dtype, device, past_key_values_len
334336 if past_key_values_length > 0 :
335337 mask = torch .cat ([torch .zeros (tgt_len , past_key_values_length , dtype = dtype , device = device ), mask ], dim = - 1 )
336338 return mask [None , None , :, :].expand (bsz , 1 , tgt_len , tgt_len + past_key_values_length )
337-
338- # Starting from transformers >= 4.35.0, the _make_causal_mask function is replaced by _create_4d_causal_attention_mask in modeling_clip.
339- # For backward compatibility with versions < 4.35.0, both functions are patched here.
339+
# Route both causal-mask entry points through the patched implementation so the
# -1e4 fill value (instead of torch.finfo(dtype).min) is used regardless of the
# installed transformers version: the private helper was renamed from
# _make_causal_mask (< 4.35.0) to _create_4d_causal_attention_mask (>= 4.35.0).
for _mask_fn_name in (
    "_make_causal_mask",                  # transformers < 4.35.0
    "_create_4d_causal_attention_mask",   # transformers >= 4.35.0
):
    setattr(modeling_clip, _mask_fn_name, patched_make_causal_mask)
del _mask_fn_name  # avoid leaking the loop variable at module scope
342342
0 commit comments