You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
### Description
MultiheadAttention CUDA BF16 Support
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
1117
-
.TypeConstraint("QK", {"tensor(float)", "tensor(float16)"}, "Constrain QK output to float32 or float16 tensors, independent of input type or output type.")
1116
+
.TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.")
1117
+
.TypeConstraint("QK", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain QK output to float32 or float16 tensors, independent of input type or output type.")
1118
1118
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
0 commit comments