diff --git a/include/infinicore/ops/attention.hpp b/include/infinicore/ops/attention.hpp
index 1bc447c77..79c34bbb9 100644
--- a/include/infinicore/ops/attention.hpp
+++ b/include/infinicore/ops/attention.hpp
@@ -2,6 +2,7 @@
 
 #include "../device.hpp"
 #include "common/op.hpp"
+#include <optional>
 
 namespace infinicore::op {
 
 class Attention {
@@ -13,4 +14,15 @@ class Attention {
 
 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
 void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+
+Tensor self_attention(Tensor query,
+                      Tensor key,
+                      Tensor value,
+                      std::optional<float> scale);
+
+void self_attention_(Tensor out,
+                     Tensor query,
+                     Tensor key,
+                     Tensor value,
+                     std::optional<float> scale);
 } // namespace infinicore::op
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index 255079790..ec11d3d8b 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -4,6 +4,7 @@
 from .random_sample import random_sample
 from .rms_norm import rms_norm
 from .rope import RopeAlgo, rope
+from .self_attention import self_attention
 from .silu import silu
 from .swiglu import swiglu
 
@@ -17,4 +18,5 @@
     "embedding",
     "rope",
     "RopeAlgo",
+    "self_attention",
 ]
diff --git a/python/infinicore/nn/functional/self_attention.py b/python/infinicore/nn/functional/self_attention.py
new file mode 100644
index 000000000..5f3560c6e
--- /dev/null
+++ b/python/infinicore/nn/functional/self_attention.py
@@ -0,0 +1,39 @@
+from typing import Optional
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def self_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    scale: Optional[float] = None,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Computes scaled dot product attention on query, key and value tensors."""
+
+    seq_len = query.shape[-2]
+    total_seq_len = key.shape[-2]
+
+    assert (1 == seq_len and total_seq_len > 1) or (seq_len == total_seq_len), (
+        "self_attention expects seq_len == 1 (decode) or seq_len == total_seq_len (prefill)."
+    )
+
+    if out is None:
+        return Tensor(
+            _infinicore.self_attention(
+                query._underlying, key._underlying, value._underlying, scale
+            )
+        )
+
+    _infinicore.self_attention_(
+        out._underlying,
+        query._underlying,
+        key._underlying,
+        value._underlying,
+        scale,
+    )
+
+    return out
diff --git a/src/infinicore/ops/attention/attention.cc b/src/infinicore/ops/attention/attention.cc
index bf4fd8203..076c00116 100644
--- a/src/infinicore/ops/attention/attention.cc
+++ b/src/infinicore/ops/attention/attention.cc
@@ -1,5 +1,7 @@
 #include "infinicore/ops/attention.hpp"
-
+#include "infinicore/ops/causal_softmax.hpp"
+#include "infinicore/ops/gemm.hpp"
+#include <cmath>
 
 namespace infinicore::op {
 
 common::OpDispatcher &Attention::dispatcher() {
@@ -25,4 +27,88 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
     Attention::execute(out, q, k, v, k_cache, v_cache, pos);
 }
 
+Tensor self_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
+                      Tensor key_states,   // [bs, num_key_value_heads, total_token, head_dim]
+                      Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
+                      std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+    Size head_dim = key_shape[3];
+
+    Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());
+
+    self_attention_(output_values, query_states, key_states, value_states, scale);
+
+    return output_values;
+}
+
+void self_attention_(Tensor out,
+                     Tensor query_states,
+                     Tensor key_states,
+                     Tensor value_states,
+                     std::optional<float> scale) {
+
+    auto query_shape = query_states->shape();
+    auto key_shape = key_states->shape();
+
+    Size batch_size = query_shape[0];
+    Size num_attention_heads = query_shape[1];
+    Size ntoken = query_shape[2];
+
+    Size num_key_value_heads = key_shape[1];
+    Size total_token = key_shape[2];
+    Size head_dim = key_shape[3];
+
+    assert(0 == (num_attention_heads % num_key_value_heads));
+    Size ngroup = num_attention_heads / num_key_value_heads;
+
+    float attention_scale{0.0f};
+    if (scale.has_value()) {
+        attention_scale = scale.value();
+    } else {
+        attention_scale = 1.f / float(sqrt(head_dim));
+    }
+
+    Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
+    for (Size ib = 0; ib < batch_size; ++ib) {
+        Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim});      // [num_attention_heads, ntoken, head_dim]
+        Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim});   // [num_key_value_heads, total_token, head_dim]
+        Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [num_key_value_heads, total_token, head_dim]
+        Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
+        {
+            /*
+            Inputs:
+                q, [num_attention_heads, ntoken, head_dim]
+                k, [num_key_value_heads, total_token, head_dim]
+                v, [num_key_value_heads, total_token, head_dim]
+            Output:
+                att_val : {num_key_value_heads, ngroup * ntoken, head_dim}
+            */
+
+            auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * ntoken, dh}
+            auto k_gemm = k->permute({0, 2, 1});                                     // => {nkvh, dh, total_token}
+            auto v_gemm = v;                                                         // => {nkvh, total_token, dh}
+
+            // qk_score : => {nkvh, ngroup * ntoken, total_token}
+            Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
+                                   k_gemm, // {nkvh, dh, total_token}
+                                   attention_scale, 0.f);
+
+            // softmax
+            auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
+            causal_softmax_(qk_softmax, qk_softmax);
+
+            // values
+            gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
+                  qk_score, // {nkvh, ngroup * ntoken, total_token}
+                  v_gemm,   // {nkvh, total_token, dh}
+                  1.0f, 0.0f);
+        }
+    }
+}
 } // namespace infinicore::op
diff --git a/src/infinicore/pybind11/ops/attention.hpp b/src/infinicore/pybind11/ops/attention.hpp
index 4af2d5f74..f4831b5c0 100644
--- a/src/infinicore/pybind11/ops/attention.hpp
+++ b/src/infinicore/pybind11/ops/attention.hpp
@@ -8,6 +8,29 @@ namespace py = pybind11;
 
 namespace infinicore::ops {
 
+Tensor py_self_attention(Tensor query,
+                         Tensor key,
+                         Tensor value,
+                         pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    return op::self_attention(query, key, value, scale_float);
+}
+
+void py_self_attention_(Tensor out,
+                        Tensor query,
+                        Tensor key,
+                        Tensor value,
+                        pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    op::self_attention_(out, query, key, value, scale_float);
+}
+
 inline void bind_attention(py::module &m) {
     m.def("attention",
           &op::attention,
@@ -21,7 +44,7 @@ inline void bind_attention(py::module &m) {
 
         Args:
             q: Query tensor
-            k: Key tensor 
+            k: Key tensor
             v: Value tensor
             k_cache: Key cache tensor
             v_cache: Value cache tensor
             pos: Current position in the sequence
         )doc");
@@ -51,6 +74,23 @@ inline void bind_attention(py::module &m) {
             v_cache: Value cache tensor
             pos: Current position in the sequence
         )doc");
+
+    m.def("self_attention",
+          &ops::py_self_attention,
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(Computes scaled dot product attention on query, key and value tensors)doc");
+
+    m.def("self_attention_",
+          &ops::py_self_attention_,
+          py::arg("out"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(In-place: computes scaled dot product attention on query, key and value tensors and writes the result into out)doc");
 }
 
 } // namespace infinicore::ops
diff --git a/test/infinicore/ops/self_attention.py b/test/infinicore/ops/self_attention.py
new file mode 100644
index 000000000..a52596d56
--- /dev/null
+++ b/test/infinicore/ops/self_attention.py
@@ -0,0 +1,136 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+_TEST_CASES_DATA = [
+    # bs, ntoken, total_token, num_attention_heads, num_key_value_heads, head_dim
+    (1, 4, 4, 8, 8, 64),
+    (1, 1, 4, 8, 8, 64),
+    (4, 16, 16, 32, 8, 64),
+    (4, 1, 128, 32, 8, 64),
+]
+
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.bfloat16: {"atol": 5e-2, "rtol": 5e-2},
+}
+
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+# _TENSOR_DTYPES = [infinicore.bfloat16]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return list of TestCase objects for sdpa operation.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        bs = data[0]
+        ntoken, total_token = data[1], data[2]
+        num_attention_heads, num_key_value_heads = data[3], data[4]
+        head_dim = data[5]
+
+        # Determine shapes based on batch dimension
+        query_shape = (bs, num_attention_heads, ntoken, head_dim)
+        key_shape = (bs, num_key_value_heads, total_token, head_dim)
+        value_shape = (bs, num_key_value_heads, total_token, head_dim)
+        out_shape = (bs, num_attention_heads, ntoken, head_dim)
+
+        # Check if tensors support in-place operations
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            query_spec = TensorSpec.from_tensor(query_shape, None, dtype)
+            key_spec = TensorSpec.from_tensor(key_shape, None, dtype)
+            value_spec = TensorSpec.from_tensor(value_shape, None, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[query_spec, key_spec, value_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description="sdpa - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[query_spec, key_spec, value_spec],
+                        kwargs=None,
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description="sdpa - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """sdpa operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("sdpa")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, query, key, value, out=None, **kwargs):
+        """PyTorch sdpa implementation"""
+        ntoken = query.shape[-2]
+        total_token = key.shape[-2]
+
+        is_causal = True
+        if 1 == ntoken and total_token > 1:
+            is_causal = False
+
+        result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, enable_gqa=True
+        )
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    def infinicore_operator(self, query, key, value, out=None, **kwargs):
+        """InfiniCore sdpa implementation"""
+        return infinicore.nn.functional.self_attention(query, key, value, out=out)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
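
The patch reduces grouped-query attention to per-batch GEMMs by folding each group of query heads onto its shared key/value head before the matmuls. As a cross-check of that layout trick, the PyTorch sketch below reproduces the same computation and compares it against torch's reference kernel. It is illustrative only and not part of the patch: grouped_self_attention is a hypothetical helper name, the mask assumes causal_softmax_ is bottom-right aligned (a lone decode token attends to every cached key), and the comparison requires enable_gqa (available in PyTorch 2.5+).

import torch


def grouped_self_attention(q, k, v, scale=None):
    # q: [bs, num_attention_heads, ntoken, head_dim]
    # k, v: [bs, num_key_value_heads, total_token, head_dim]
    bs, nah, ntoken, dh = q.shape
    nkvh, total_token = k.shape[1], k.shape[2]
    ngroup = nah // nkvh
    scale = 1.0 / dh**0.5 if scale is None else scale
    out = torch.empty_like(q)
    for b in range(bs):
        # Fold each group of query heads onto its shared KV head, mirroring the C++ view.
        qb = q[b].reshape(nkvh, ngroup * ntoken, dh)
        score = scale * torch.bmm(qb, k[b].transpose(1, 2))  # [nkvh, ngroup*ntoken, total_token]
        score = score.view(nah, ntoken, total_token)
        # Bottom-right-aligned causal mask: a single decode token sees every cached key.
        mask = torch.tril(torch.ones(ntoken, total_token), diagonal=total_token - ntoken).bool()
        score = score.masked_fill(~mask, float("-inf")).softmax(dim=-1)
        attn = torch.bmm(score.view(nkvh, ngroup * ntoken, total_token), v[b])
        out[b] = attn.view(nah, ntoken, dh)
    return out


q = torch.randn(2, 8, 5, 64)
k = torch.randn(2, 2, 5, 64)
v = torch.randn(2, 2, 5, 64)
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True, enable_gqa=True)
assert torch.allclose(grouped_self_attention(q, k, v), ref, atol=1e-4)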