25 | 25 | stack, |
26 | 26 | switch, |
27 | 27 | ) |
| 28 | +from pytensor.tensor.blockwise import Blockwise |
28 | 29 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise |
29 | 30 | from pytensor.tensor.shape import shape, specify_broadcastable |
30 | 31 | from pytensor.tensor.type import ( |
31 | 32 | DenseTensorType, |
32 | | - TensorType, |
33 | 33 | complex_dtypes, |
34 | 34 | continuous_dtypes, |
35 | 35 | discrete_dtypes, |
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False): |
2868 | 2868 | return log(sum(exp(x), axis=axis, keepdims=keepdims)) |
2869 | 2869 |
2870 | 2870 |
2871 | | -class MatMul(Op): |
2872 | | - __props__ = ("dtype",) |
2873 | | - |
2874 | | - def __init__(self, dtype=None): |
2875 | | - self.dtype = dtype |
2876 | | - |
2877 | | - @classmethod |
2878 | | - def _get_output_shape(cls, x1, x2, shapes, validate=False): |
2879 | | - x1_shape, x2_shape = shapes |
2880 | | - |
2881 | | - if x1.ndim == 1 and x2.ndim == 1: |
2882 | | - if validate and x1_shape[0] != x2_shape[0]: |
2883 | | - raise ValueError("1d inputs must have the same length.") |
2884 | | - return () |
2885 | | - elif x1.ndim == 1 and x2.ndim > 1: |
2886 | | - if validate and x1_shape[0] != x2_shape[-2]: |
2887 | | - raise ValueError( |
2888 | | - "length of input 1 must be equal the length " |
2889 | | - "of the 2nd-last dimension of input 2" |
2890 | | - ) |
2891 | | - return x2_shape[:-2] + x2_shape[-1:] |
2892 | | - elif x1.ndim > 1 and x2.ndim == 1: |
2893 | | - if validate and x1_shape[-1] != x2_shape[0]: |
2894 | | - raise ValueError( |
2895 | | - "length of input 2 must be equal the length " |
2896 | | - "of the last dimension of input 1" |
2897 | | - ) |
2898 | | - return x1_shape[:-1] |
2899 | | - elif x1.ndim == 2 and x2.ndim == 2: |
2900 | | - if validate and x1_shape[-1] != x2_shape[0]: |
2901 | | - raise ValueError( |
2902 | | - "number of columns of input 1 must be equal to " |
2903 | | - "the number of rows of input 2" |
2904 | | - ) |
2905 | | - return x1_shape[:-1] + x2_shape[-1:] |
2906 | | - elif x1.ndim > 2 and x2.ndim == 2: |
2907 | | - if validate and x1_shape[-1] != x2_shape[0]: |
2908 | | - raise ValueError( |
2909 | | - "number of rows of input 2 must be equal to " |
2910 | | - "the length of the last dimension of input 1" |
2911 | | - ) |
2912 | | - return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:] |
2913 | | - elif x1.ndim == 2 and x2.ndim > 2: |
2914 | | - if validate and x1_shape[-1] != x2_shape[-2]: |
2915 | | - raise ValueError( |
2916 | | - "number of columns of input 1 must be equal " |
2917 | | - "the length of the 2nd-last dimension of input 2" |
2918 | | - ) |
2919 | | - return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:] |
2920 | | - else: |
2921 | | - if validate: |
2922 | | - from pytensor.tensor.random.basic import broadcast_shapes |
2923 | | - |
2924 | | - bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2]) |
2925 | | - if x1_shape[-1] != x2_shape[-2]: |
2926 | | - raise ValueError( |
2927 | | - "length of the last dimension of input 1 must be equal " |
2928 | | - "to the length of the 2nd-last dimension of input 2" |
2929 | | - ) |
2930 | | - else: |
2931 | | - from pytensor.tensor.extra_ops import broadcast_shape |
2932 | | - |
2933 | | - bshape = broadcast_shape( |
2934 | | - x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True |
2935 | | - ) |
2936 | | - return bshape + x1_shape[-2:-1] + x2_shape[-1:] |
2937 | | - |
2938 | | - def make_node(self, a, b): |
2939 | | - a = as_tensor_variable(a) |
2940 | | - b = as_tensor_variable(b) |
2941 | | - |
2942 | | - if 0 in {a.ndim, b.ndim}: |
2943 | | - raise ValueError("inputs to `matmul` cannot be scalar.") |
2944 | | - |
2945 | | - out_shape = self._get_output_shape( |
2946 | | - a, b, (a.type.shape, b.type.shape), validate=True |
2947 | | - ) |
2948 | | - out = TensorType(dtype=self.dtype, shape=out_shape)() |
2949 | | - return Apply(self, [a, b], [out]) |
2950 | | - |
2951 | | - def perform(self, node, inputs, outputs): |
2952 | | - x1, x2 = inputs |
2953 | | - outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype) |
2954 | | - |
2955 | | - def infer_shape(self, fgraph, node, shapes): |
2956 | | - x1, x2 = node.inputs |
2957 | | - return [self._get_output_shape(x1, x2, shapes)] |
| 2871 | +_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)") |
2958 | 2872 |
2959 | 2873 |
2960 | 2874 | def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None): |
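Note: the 87 removed lines of hand-written shape logic collapse into a single `Blockwise` wrapper around the core `_dot` Op. The gufunc-style signature `(n,k),(k,m)->(n,m)` marks the trailing two dimensions of each operand as the core matrix dimensions; any leading dimensions are treated as batch dimensions and broadcast between the operands, with shape inference and gradients derived from the core Op. A rough NumPy analogue of what the signature buys, using `np.vectorize` as a stand-in for `Blockwise` (a sketch for intuition, not pytensor's implementation):

```python
import numpy as np

# Stand-in for Blockwise(_dot, signature="(n,k),(k,m)->(n,m)"):
# np.vectorize with a gufunc signature loops the core 2-D product
# over leading dimensions and broadcasts them between operands.
batched_dot = np.vectorize(np.dot, signature="(n,k),(k,m)->(n,m)")

a = np.ones((4, 2, 3))  # batch of 4 matrices with core shape (2, 3)
b = np.ones((3, 5))     # a single matrix with core shape (3, 5)
print(batched_dot(a, b).shape)  # (4, 2, 5): batch dims broadcast
```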
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None |
2999 | 2913 | - Stacks of matrices are broadcast together as if the matrices were elements, |
3000 | 2914 | respecting the signature ``(n, k), (k, m) -> (n, m)``: |
3001 | 2915 | """ |
3002 | | - return MatMul(dtype=dtype)(x1, x2) |
| 2916 | + x1 = as_tensor_variable(x1) |
| 2917 | + x2 = as_tensor_variable(x2) |
| 2918 | + if x1.type.ndim == 0 or x2.type.ndim == 0: |
| 2919 | + raise ValueError("matmul operand cannot be scalar") |
| 2920 | + if x1.type.ndim == 1 and x2.type.ndim == 1: |
| 2921 | + out = _dot(x1, x2) |
| 2922 | + elif x1.type.ndim == 1: |
| 2923 | + out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) |
| 2924 | + elif x2.type.ndim == 1: |
| 2925 | + out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) |
| 2926 | + else: |
| 2927 | + out = _matrix_matrix_matmul(x1, x2) |
| 2928 | + |
| 2929 | + if dtype is not None: |
| 2930 | + out = out.astype(dtype) |
| 2931 | + |
| 2932 | + return out |
3003 | 2933 |
3004 | 2934 |
3005 | 2935 | __all__ = [ |
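
Note: the new dispatch reproduces NumPy's `matmul` promotion rules for 1-D operands: a 1-D first operand is treated as a row vector (axis prepended, then squeezed out of the result), and a 1-D second operand as a column vector (axis appended, then squeezed). A quick NumPy check of the column-vector case, mirroring the `x2[:, None]` / `.squeeze(-1)` branch above (illustrative only):

```python
import numpy as np

a = np.arange(6.0).reshape(2, 3)
v = np.arange(3.0)

# Promote the 1-D operand to a column, multiply, drop the appended
# axis, as the `x2[:, None]` / `.squeeze(-1)` branch does above.
promoted = np.matmul(a, v[:, None]).squeeze(-1)
assert np.array_equal(np.matmul(a, v), promoted)  # both have shape (2,)
```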