Commit df002a9
authored
Fix: Return Attention Scores when
* Fix: Ensure Attention Layer Returns Attention Scores when `return_attention_scores=True`
This pull request addresses an issue in the Attention layer where the return_attention_scores parameter wasn't correctly handled in the compute_output_shape method. This fix ensures that attention scores are returned when return_attention_scores=True.
## Changes Made
Modified compute_output_shape method to return the shape of both the attention output and the attention scores when return_attention_scores=True.
* Formatting
* Fixed score return and added unit tests for return_attention_scores=True
* Removed debug print statementreturn_attention_scores=True (#20684)1 parent c1316e5 commit df002a9
2 files changed
+111
-6
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
| 4 | + | |
4 | 5 | | |
5 | 6 | | |
6 | 7 | | |
| |||
84 | 85 | | |
85 | 86 | | |
86 | 87 | | |
| 88 | + | |
| 89 | + | |
87 | 90 | | |
88 | 91 | | |
89 | 92 | | |
| |||
217 | 220 | | |
218 | 221 | | |
219 | 222 | | |
| 223 | + | |
220 | 224 | | |
221 | 225 | | |
222 | 226 | | |
| |||
226 | 230 | | |
227 | 231 | | |
228 | 232 | | |
229 | | - | |
| 233 | + | |
230 | 234 | | |
231 | 235 | | |
232 | 236 | | |
233 | 237 | | |
234 | 238 | | |
235 | | - | |
| 239 | + | |
236 | 240 | | |
237 | | - | |
238 | | - | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
239 | 244 | | |
240 | 245 | | |
241 | 246 | | |
| |||
244 | 249 | | |
245 | 250 | | |
246 | 251 | | |
247 | | - | |
248 | | - | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
| 267 | + | |
| 268 | + | |
| 269 | + | |
| 270 | + | |
| 271 | + | |
| 272 | + | |
| 273 | + | |
| 274 | + | |
| 275 | + | |
| 276 | + | |
| 277 | + | |
| 278 | + | |
| 279 | + | |
| 280 | + | |
| 281 | + | |
| 282 | + | |
| 283 | + | |
| 284 | + | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
249 | 295 | | |
250 | 296 | | |
251 | 297 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
358 | 358 | | |
359 | 359 | | |
360 | 360 | | |
| 361 | + | |
| 362 | + | |
| 363 | + | |
| 364 | + | |
| 365 | + | |
| 366 | + | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
| 397 | + | |
| 398 | + | |
| 399 | + | |
| 400 | + | |
| 401 | + | |
| 402 | + | |
| 403 | + | |
| 404 | + | |
| 405 | + | |
| 406 | + | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
0 commit comments