@@ -44,7 +44,10 @@ def _unmasked_pre_div_sdpa_script(query, key, value):
4444 scaled_key = op .Div (key_transposed , divisor )
4545 attn_score = op .MatMul (scaled_query , scaled_key )
4646 attn_weight = op .Softmax (attn_score , axis = - 1 )
47- attn_output = op .MatMul (attn_weight , value )
47+ is_nan = op .IsNaN (attn_weight )
48+ zero = op .Constant (value_float = 0.0 )
49+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
50+ attn_output = op .MatMul (adj_attn_weight , value )
4851 return attn_output
4952
5053
@@ -56,7 +59,10 @@ def _unmasked_pre_mul_sdpa_script(query, key, value):
5659 scaled_key = op .Mul (key_transposed , multiplier )
5760 attn_score = op .MatMul (scaled_query , scaled_key )
5861 attn_weight = op .Softmax (attn_score , axis = - 1 )
59- attn_output = op .MatMul (attn_weight , value )
62+ is_nan = op .IsNaN (attn_weight )
63+ zero = op .Constant (value_float = 0.0 )
64+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
65+ attn_output = op .MatMul (adj_attn_weight , value )
6066 return attn_output
6167
6268
@@ -67,7 +73,10 @@ def _unmasked_post_div_sdpa_script(query, key, value):
6773 attn_score = op .MatMul (query , key_transposed )
6874 scaled_attn_score = op .Div (attn_score , divisor )
6975 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
70- attn_output = op .MatMul (attn_weight , value )
76+ is_nan = op .IsNaN (attn_weight )
77+ zero = op .Constant (value_float = 0.0 )
78+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
79+ attn_output = op .MatMul (adj_attn_weight , value )
7180 return attn_output
7281
7382
@@ -78,7 +87,10 @@ def _unmasked_post_mul_sdpa_script(query, key, value):
7887 attn_score = op .MatMul (query , key_transposed )
7988 scaled_attn_score = op .Mul (attn_score , multiplier )
8089 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
81- attn_output = op .MatMul (attn_weight , value )
90+ is_nan = op .IsNaN (attn_weight )
91+ zero = op .Constant (value_float = 0.0 )
92+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
93+ attn_output = op .MatMul (adj_attn_weight , value )
8294 return attn_output
8395
8496
@@ -90,7 +102,10 @@ def _custom_scale_pre_div_sdpa_script(query, key, value):
90102 scaled_key = op .Div (key_transposed , divisor )
91103 attn_score = op .MatMul (scaled_query , scaled_key )
92104 attn_weight = op .Softmax (attn_score , axis = - 1 )
93- attn_output = op .MatMul (attn_weight , value )
105+ is_nan = op .IsNaN (attn_weight )
106+ zero = op .Constant (value_float = 0.0 )
107+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
108+ attn_output = op .MatMul (adj_attn_weight , value )
94109 return attn_output
95110
96111
@@ -102,7 +117,10 @@ def _custom_scale_pre_mul_sdpa_script(query, key, value):
102117 scaled_key = op .Mul (key_transposed , multiplier )
103118 attn_score = op .MatMul (scaled_query , scaled_key )
104119 attn_weight = op .Softmax (attn_score , axis = - 1 )
105- attn_output = op .MatMul (attn_weight , value )
120+ is_nan = op .IsNaN (attn_weight )
121+ zero = op .Constant (value_float = 0.0 )
122+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
123+ attn_output = op .MatMul (adj_attn_weight , value )
106124 return attn_output
107125
108126
@@ -115,7 +133,10 @@ def _custom_multi_scale_pre_mul_sdpa_script(query, key, value):
115133 scaled_key = op .Mul (key_transposed , multiplier_k )
116134 attn_score = op .MatMul (scaled_query , scaled_key )
117135 attn_weight = op .Softmax (attn_score , axis = - 1 )
118- attn_output = op .MatMul (attn_weight , value )
136+ is_nan = op .IsNaN (attn_weight )
137+ zero = op .Constant (value_float = 0.0 )
138+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
139+ attn_output = op .MatMul (adj_attn_weight , value )
119140 return attn_output
120141
121142
@@ -126,7 +147,10 @@ def _custom_scale_post_div_sdpa_script(query, key, value):
126147 attn_score = op .MatMul (query , key_transposed )
127148 scaled_attn_score = op .Div (attn_score , divisor )
128149 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
129- attn_output = op .MatMul (attn_weight , value )
150+ is_nan = op .IsNaN (attn_weight )
151+ zero = op .Constant (value_float = 0.0 )
152+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
153+ attn_output = op .MatMul (adj_attn_weight , value )
130154 return attn_output
131155
132156
@@ -137,7 +161,10 @@ def _custom_scale_post_mul_sdpa_script(query, key, value):
137161 attn_score = op .MatMul (query , key_transposed )
138162 scaled_attn_score = op .Mul (attn_score , multiplier )
139163 attn_weight = op .Softmax (scaled_attn_score , axis = - 1 )
140- attn_output = op .MatMul (attn_weight , value )
164+ is_nan = op .IsNaN (attn_weight )
165+ zero = op .Constant (value_float = 0.0 )
166+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
167+ attn_output = op .MatMul (adj_attn_weight , value )
141168 return attn_output
142169
143170
@@ -150,7 +177,10 @@ def _masked_pre_div_sdpa_script(query, key, value, mask):
150177 attn_score = op .MatMul (scaled_query , scaled_key )
151178 masked_attn_score = op .Add (attn_score , mask )
152179 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
153- attn_output = op .MatMul (attn_weight , value )
180+ is_nan = op .IsNaN (attn_weight )
181+ zero = op .Constant (value_float = 0.0 )
182+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
183+ attn_output = op .MatMul (adj_attn_weight , value )
154184 return attn_output
155185
156186
@@ -163,7 +193,10 @@ def _masked_pre_mul_sdpa_script(query, key, value, mask):
163193 attn_score = op .MatMul (scaled_query , scaled_key )
164194 masked_attn_score = op .Add (attn_score , mask )
165195 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
166- attn_output = op .MatMul (attn_weight , value )
196+ is_nan = op .IsNaN (attn_weight )
197+ zero = op .Constant (value_float = 0.0 )
198+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
199+ attn_output = op .MatMul (adj_attn_weight , value )
167200 return attn_output
168201
169202
@@ -175,7 +208,10 @@ def _masked_post_div_sdpa_script(query, key, value, mask):
175208 scaled_attn_score = op .Div (attn_score , divisor )
176209 masked_attn_score = op .Add (scaled_attn_score , mask )
177210 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
178- attn_output = op .MatMul (attn_weight , value )
211+ is_nan = op .IsNaN (attn_weight )
212+ zero = op .Constant (value_float = 0.0 )
213+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
214+ attn_output = op .MatMul (adj_attn_weight , value )
179215 return attn_output
180216
181217
@@ -187,7 +223,10 @@ def _masked_post_mul_sdpa_script(query, key, value, mask):
187223 scaled_attn_score = op .Mul (attn_score , multiplier )
188224 masked_attn_score = op .Add (scaled_attn_score , mask )
189225 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
190- attn_output = op .MatMul (attn_weight , value )
226+ is_nan = op .IsNaN (attn_weight )
227+ zero = op .Constant (value_float = 0.0 )
228+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
229+ attn_output = op .MatMul (adj_attn_weight , value )
191230 return attn_output
192231
193232
@@ -200,7 +239,10 @@ def _masked_custom_scale_pre_div_sdpa_script(query, key, value, mask):
200239 attn_score = op .MatMul (scaled_query , scaled_key )
201240 masked_attn_score = op .Add (attn_score , mask )
202241 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
203- attn_output = op .MatMul (attn_weight , value )
242+ is_nan = op .IsNaN (attn_weight )
243+ zero = op .Constant (value_float = 0.0 )
244+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
245+ attn_output = op .MatMul (adj_attn_weight , value )
204246 return attn_output
205247
206248
@@ -213,7 +255,10 @@ def _masked_custom_scale_pre_mul_sdpa_script(query, key, value, mask):
213255 attn_score = op .MatMul (scaled_query , scaled_key )
214256 masked_attn_score = op .Add (attn_score , mask )
215257 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
216- attn_output = op .MatMul (attn_weight , value )
258+ is_nan = op .IsNaN (attn_weight )
259+ zero = op .Constant (value_float = 0.0 )
260+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
261+ attn_output = op .MatMul (adj_attn_weight , value )
217262 return attn_output
218263
219264
@@ -225,7 +270,10 @@ def _masked_custom_scale_post_div_sdpa_script(query, key, value, mask):
225270 scaled_attn_score = op .Div (attn_score , divisor )
226271 masked_attn_score = op .Add (scaled_attn_score , mask )
227272 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
228- attn_output = op .MatMul (attn_weight , value )
273+ is_nan = op .IsNaN (attn_weight )
274+ zero = op .Constant (value_float = 0.0 )
275+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
276+ attn_output = op .MatMul (adj_attn_weight , value )
229277 return attn_output
230278
231279
@@ -237,7 +285,10 @@ def _masked_custom_scale_post_mul_sdpa_script(query, key, value, mask):
237285 scaled_attn_score = op .Mul (attn_score , multiplier )
238286 masked_attn_score = op .Add (scaled_attn_score , mask )
239287 attn_weight = op .Softmax (masked_attn_score , axis = - 1 )
240- attn_output = op .MatMul (attn_weight , value )
288+ is_nan = op .IsNaN (attn_weight )
289+ zero = op .Constant (value_float = 0.0 )
290+ adj_attn_weight = op .Where (is_nan , zero , attn_weight )
291+ attn_output = op .MatMul (adj_attn_weight , value )
241292 return attn_output
242293
243294
0 commit comments