diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs
index 14c82ff4d..ce3cd0636 100644
--- a/src/linalg/impl_linalg.rs
+++ b/src/linalg/impl_linalg.rs
@@ -1140,3 +1140,406 @@ where A: LinalgScalar
         }
     }
 }
+
+/// Specifies the axes along which to perform a tensor contraction in [`tensordot`].
+///
+/// This enum defines how the axes of two tensors should be paired and reduced
+/// during a generalized dot product.
+///
+/// # Variants
+///
+/// * `Num(usize)` — Contract over the last *n* axes of the left-hand tensor
+///   and the first *n* axes of the right-hand tensor.
+///
+/// * `Pair(Vec<isize>, Vec<isize>)` — Explicitly specify which axes from each
+///   tensor to contract over. The first vector refers to axis indices in the
+///   left-hand tensor, and the second to the right-hand tensor.
+///   The two lists must be of equal length, and corresponding axes must have
+///   matching dimension sizes.
+///
+/// # Examples
+///
+/// ```
+/// use ndarray::linalg::AxisSpec;
+/// // Contract over one axis (e.g., last of `a`, first of `b`)
+/// let axes = AxisSpec::Num(1);
+///
+/// // Explicitly contract over multiple axes
+/// let axes = AxisSpec::Pair(vec![1, 2], vec![0, 3]);
+/// ```
+///
+/// # Notes
+///
+/// - Axis indices in `Pair` can be negative, in which case they are interpreted
+///   relative to the end of the tensor (e.g., `-1` refers to the last axis).
+/// - The number and dimensionality of contracted axes determine the rank of
+///   the result of [`tensordot`].
+/// - `AxisSpec` exists to disambiguate and formalise axis specifications,
+///   avoiding confusion with [`crate::Axis`] and [`crate::iter::Axes`].
+///
+/// # See also
+///
+/// - [`tensordot`] — Performs the generalized tensor contraction described by this specification.
+/// - [`crate::Axis`] — Represents a single axis index within an array.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum AxisSpec
+{
+    /// Contract over the last *n* axes of the left-hand tensor and the first *n* axes
+    /// of the right-hand tensor.
+    ///
+    /// For example, `Num(1)` performs standard matrix multiplication,
+    /// `Num(0)` performs an outer product, and `Num(2)` contracts over two axes.
+    ///
+    /// # Example
+    /// ```
+    /// # use ndarray::linalg::AxisSpec;
+    /// let axes = AxisSpec::Num(1); // last of `a`, first of `b`
+    /// ```
+    Num(usize),
+    /// Explicitly specify which axes of each tensor to contract over.
+    ///
+    /// The first vector lists the axes of the left-hand tensor `a`,
+    /// and the second vector lists the corresponding axes of the right-hand tensor `b`.
+    /// Both vectors must be the same length, and each corresponding axis pair
+    /// must have matching dimension sizes.
+    ///
+    /// Negative indices are supported and count from the end
+    /// (e.g. `-1` refers to the last axis).
+    ///
+    /// # Example
+    /// ```
+    /// # use ndarray::linalg::AxisSpec;
+    /// let axes = AxisSpec::Pair(vec![1, -1], vec![0, 2]);
+    /// ```
+    Pair(Vec<isize>, Vec<isize>),
+}
+
+/// Generalised tensor contraction.
+///
+/// This operation extends `dot` and matrix multiplication
+/// to tensors of arbitrary rank. The contraction pattern is
+/// defined by [`AxisSpec`].
+pub trait Tensordot<Rhs: ?Sized>
+{
+    /// The result of the contraction.
+    type Output;
+
+    /// Perform a tensor contraction along specified axes.
+    ///
+    /// Computes the sum of the products of the elements of `self` and `rhs`
+    /// over the axes selected by `axes`:
+    ///
+    /// - [`AxisSpec::Num`]`(n)` sums over the last `n` axes of `self` and the
+    ///   first `n` axes of `rhs`.
+    /// - [`AxisSpec::Pair`]`(lhs_axes, rhs_axes)` sums over the listed axes of
+    ///   `self` and the corresponding listed axes of `rhs`.
+    ///
+    /// The result holds the non-contracted axes of `self` (in order) followed
+    /// by the non-contracted axes of `rhs` (in order).
+    ///
+    /// # Panics
+    ///
+    /// Panics if the `axes` specification is invalid:
+    ///
+    /// - `Num(n)` with `n` larger than the rank of either operand.
+    /// - `Pair` lists of different lengths.
+    /// - An axis index that is out of range (after resolving negative indices).
+    /// - The same axis listed more than once for one operand.
+    /// - Paired axes whose lengths differ.
+    #[track_caller]
+    fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output;
+}
+
+/// Perform a tensor contraction along specified axes.
+///
+/// See [`Tensordot::tensordot`] for more details.
+#[track_caller]
+pub fn tensordot<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
+where
+    T: LinalgScalar,
+    Sa: Data<Elem = T>,
+    Sb: Data<Elem = T>,
+    Da: Dimension,
+    Db: Dimension,
+{
+    tensordot_impl::<T, Sa, Sb, Da, Db>(a, b, axes)
+}
+
+/// Converts the axis indices given for one operand into validated, non-negative positions.
+///
+/// Negative indices count from the end (`-1` is the last axis). `operand` is the
+/// name (`"a"` or `"b"`) used in panic messages.
+///
+/// Panics if an index is out of range for `ndim` or if an axis is listed twice.
+#[track_caller]
+fn normalize_axes(axes: &[isize], ndim: usize, operand: &str) -> Vec<usize>
+{
+    let mut seen = vec![false; ndim];
+    let mut resolved = Vec::with_capacity(axes.len());
+    for &ax in axes {
+        // `ax` is negative here and `ndim` is non-negative, so the sum cannot overflow.
+        let shifted = if ax < 0 { ax + ndim as isize } else { ax };
+        assert!(
+            (0..ndim as isize).contains(&shifted),
+            "tensordot: axis {} out of bounds for {} (ndim={})",
+            ax,
+            operand,
+            ndim
+        );
+        let pos = shifted as usize;
+        // A repeated axis cannot be turned into a valid axis permutation, so reject it
+        // here with a message that names the offending axis.
+        assert!(!seen[pos], "tensordot: duplicate axis {} for {}", ax, operand);
+        seen[pos] = true;
+        resolved.push(pos);
+    }
+    resolved
+}
+
+/// Resolves an [`AxisSpec`] into the contracted axis positions of `a` and `b`.
+///
+/// `nda` and `ndb` are the ranks of the two operands. The returned vectors have
+/// equal length and pair up element by element.
+#[track_caller]
+fn resolve_contraction_axes(axes: AxisSpec, nda: usize, ndb: usize) -> (Vec<usize>, Vec<usize>)
+{
+    match axes {
+        AxisSpec::Num(n) => {
+            // Compare as `usize`: casting a very large `n` to `isize` would wrap to a
+            // negative value and slip past this check.
+            assert!(
+                n <= nda && n <= ndb,
+                "tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}",
+                n,
+                nda,
+                ndb
+            );
+            (((nda - n)..nda).collect(), (0..n).collect())
+        }
+        AxisSpec::Pair(aa, bb) => {
+            assert_eq!(
+                aa.len(),
+                bb.len(),
+                "tensordot: axes length mismatch (a has {}, b has {})",
+                aa.len(),
+                bb.len()
+            );
+            (normalize_axes(&aa, nda, "a"), normalize_axes(&bb, ndb, "b"))
+        }
+    }
+}
+
+/// Contracts `a` with `b` over the axes described by `axes`.
+///
+/// Both operands are permuted so that the contracted axes are trailing in `a`
+/// and leading in `b`, reshaped to matrices of shape `(m, k)` and `(k, n)`,
+/// multiplied, and the product is reshaped to the non-contracted axes of `a`
+/// followed by those of `b`.
+#[track_caller]
+fn tensordot_impl<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
+where
+    T: LinalgScalar,
+    Sa: Data<Elem = T>,
+    Sb: Data<Elem = T>,
+    Da: Dimension,
+    Db: Dimension,
+{
+    let (axes_a, axes_b) = resolve_contraction_axes(axes, a.ndim(), b.ndim());
+
+    // Paired axes must have the same length.
+    for (&ia, &ib) in axes_a.iter().zip(&axes_b) {
+        let (da, db) = (a.shape()[ia], b.shape()[ib]);
+        assert_eq!(
+            da, db,
+            "tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}",
+            ia, da, ib, db
+        );
+    }
+
+    // Axes that survive the contraction, in their original order.
+    let free_a: Vec<usize> = (0..a.ndim()).filter(|ax| !axes_a.contains(ax)).collect();
+    let free_b: Vec<usize> = (0..b.ndim()).filter(|ax| !axes_b.contains(ax)).collect();
+
+    // Move the contracted axes to the end of `a` and to the front of `b`.
+    let perm_a: Vec<usize> = free_a.iter().chain(&axes_a).copied().collect();
+    let perm_b: Vec<usize> = axes_b.iter().chain(&free_b).copied().collect();
+
+    // Matrix dimensions: (m, k) for `a`, (k, n) for `b`.
+    let m: usize = free_a.iter().map(|&ax| a.shape()[ax]).product();
+    let k: usize = axes_a.iter().map(|&ax| a.shape()[ax]).product();
+    let n: usize = free_b.iter().map(|&ax| b.shape()[ax]).product();
+
+    let a_perm = a.view().into_dyn().permuted_axes(IxDyn(&perm_a));
+    let a_std = a_perm.as_standard_layout();
+    let b_perm = b.view().into_dyn().permuted_axes(IxDyn(&perm_b));
+    let b_std = b_perm.as_standard_layout();
+
+    // The inputs are in standard layout and the element counts match by construction,
+    // so these reshapes are not expected to fail.
+    let a2 = a_std
+        .into_shape_with_order(Ix2(m, k))
+        .expect("reshaping a to 2D");
+    let b2 = b_std
+        .into_shape_with_order(Ix2(k, n))
+        .expect("reshaping b to 2D");
+
+    let c2 = a2.dot(&b2);
+
+    let out_shape: Vec<usize> = free_a
+        .iter()
+        .map(|&ax| a.shape()[ax])
+        .chain(free_b.iter().map(|&ax| b.shape()[ax]))
+        .collect();
+
+    c2.into_shape_with_order(IxDyn(&out_shape))
+        .expect("reshaping result to output shape")
+}
+
+// ArrayBase × ArrayBase
+impl<A, S, S2, D1, D2> Tensordot<ArrayBase<S2, D2>> for ArrayBase<S, D1>
+where
+    A: LinalgScalar,
+    S: Data<Elem = A>,
+    S2: Data<Elem = A>,
+    D1: Dimension,
+    D2: Dimension,
+{
+    type Output = ArrayD<A>;
+
+    #[track_caller]
+    fn tensordot(&self, rhs: &ArrayBase<S2, D2>, axes: AxisSpec) -> Self::Output
+    {
+        tensordot_impl::<A, S, S2, D1, D2>(self, rhs, axes)
+    }
+}
+
+// ArrayBase × ArrayRef (rhs is ArrayRef) → pass a view to backend
+impl<A, S, D1, D2> Tensordot<ArrayRef<A, D2>> for ArrayBase<S, D1>
+where
+    A: LinalgScalar,
+    S: Data<Elem = A>,
+    D1: Dimension,
+    D2: Dimension,
+{
+    type Output = ArrayD<A>;
+
+    #[track_caller]
+    fn tensordot(&self, rhs: &ArrayRef<A, D2>, axes: AxisSpec) -> Self::Output
+    {
+        let rhs_view: ArrayBase<ViewRepr<&A>, D2> = rhs.view();
+        tensordot_impl::<A, S, ViewRepr<&A>, D1, D2>(self, &rhs_view, axes)
+    }
+}
+
+// ArrayRef × ArrayBase (self is ArrayRef) → pass a view to backend
+impl<A, S, D1, D2> Tensordot<ArrayBase<S, D2>> for ArrayRef<A, D1>
+where
+    A: LinalgScalar,
+    S: Data<Elem = A>,
+    D1: Dimension,
+    D2: Dimension,
+{
+    type Output = ArrayD<A>;
+
+    #[track_caller]
+    fn tensordot(&self, rhs: &ArrayBase<S, D2>, axes: AxisSpec) -> Self::Output
+    {
+        let self_view: ArrayBase<ViewRepr<&A>, D1> = self.view();
+        tensordot_impl::<A, ViewRepr<&A>, S, D1, D2>(&self_view, rhs, axes)
+    }
+}
+
+#[cfg(test)]
+mod tensordot_tests
+{
+    use super::*;
+    use crate::{ArrayD, IxDyn};
+
+    #[test]
+    fn basic_pair_axes()
+    {
+        // a.shape = [3, 4, 5], b.shape = [4, 3, 2]
+        let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::<Vec<i32>>()).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::<Vec<i32>>()).unwrap();
+
+        let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1]));
+
+        assert_eq!(
+            c.shape(),
+            &[5, 2],
+            "unexpected output shape: got {:?}, expected [5, 2]",
+            c.shape()
+        );
+
+        // Every element, in logical (row-major) order.
+        let got: Vec<i32> = c.iter().copied().collect();
+        let expected = vec![4400, 4730, 4532, 4874, 4664, 5018, 4796, 5162, 4928, 5306];
+        assert_eq!(got, expected, "tensor contraction result mismatch");
+    }
+
+    #[test]
+    fn integer_axes()
+    {
+        // a.shape = [2, 2, 2], b.shape = [2, 2]
+        let a = ArrayD::from_shape_vec(IxDyn(&[2, 2, 2]), (1..=8).collect::<Vec<i32>>()).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![10, 20, 30, 40]).unwrap();
+
+        // Contract over 2 axes
+        let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Num(2));
+
+        assert_eq!(
+            c.shape(),
+            &[2],
+            "unexpected output shape: got {:?}, expected [2]",
+            c.shape()
+        );
+
+        let got: Vec<i32> = c.iter().copied().collect();
+        assert_eq!(got, vec![300, 700], "tensor contraction result mismatch");
+    }
+
+    #[test]
+    fn negative_axes_match_positive_axes()
+    {
+        let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::<Vec<i32>>()).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::<Vec<i32>>()).unwrap();
+
+        let positive = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1]));
+        let negative = tensordot(&a, &b, AxisSpec::Pair(vec![-2, -3], vec![-3, -2]));
+
+        assert_eq!(positive, negative);
+    }
+
+    #[test]
+    fn zero_axes_is_outer_product()
+    {
+        let a = ArrayD::from_shape_vec(IxDyn(&[2]), vec![1, 2]).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[3]), vec![10, 20, 30]).unwrap();
+
+        let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Num(0));
+
+        assert_eq!(c.shape(), &[2, 3]);
+        let got: Vec<i32> = c.iter().copied().collect();
+        assert_eq!(got, vec![10, 20, 30, 20, 40, 60]);
+    }
+
+    #[test]
+    #[should_panic(expected = "duplicate axis")]
+    fn duplicate_axes_panic()
+    {
+        let a = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1, 2, 3, 4]).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1, 2, 3, 4]).unwrap();
+
+        let _: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Pair(vec![0, 0], vec![0, 1]));
+    }
+
+    #[test]
+    #[should_panic(expected = "cannot contract over")]
+    fn too_many_axes_panic()
+    {
+        let a = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1, 2, 3, 4]).unwrap();
+        let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![1, 2, 3, 4]).unwrap();
+
+        let _: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Num(usize::MAX));
+    }
+}
diff --git a/src/linalg/mod.rs b/src/linalg/mod.rs
index dc6964f9b..e31294b42 100644
--- a/src/linalg/mod.rs
+++ b/src/linalg/mod.rs
@@ -12,5 +12,6 @@ pub use self::impl_linalg::general_mat_mul;
 pub use self::impl_linalg::general_mat_vec_mul;
 pub use self::impl_linalg::kron;
 pub use self::impl_linalg::Dot;
+pub use self::impl_linalg::{tensordot, AxisSpec, Tensordot};
 
 mod impl_linalg;