Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
376 changes: 376 additions & 0 deletions src/linalg/impl_linalg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1140,3 +1140,379 @@ where A: LinalgScalar
}
}
}

/// Specifies the axes along which to perform a tensor contraction in [`tensordot`].
///
/// This enum defines how the axes of two tensors should be paired and reduced
/// during a generalized dot product.
///
/// # Variants
///
/// * `Num(usize)` — Contract over the last *n* axes of the left-hand tensor
/// and the first *n* axes of the right-hand tensor.
///
/// * `Pair(Vec<isize>, Vec<isize>)` — Explicitly specify which axes from each
/// tensor to contract over. The first vector refers to axis indices in the
/// left-hand tensor, and the second to the right-hand tensor.
/// The two lists must be of equal length, and corresponding axes must have
/// matching dimension sizes.
///
/// # Examples
///
/// ```
/// use ndarray::linalg::AxisSpec;
/// // Contract over one axis (e.g., last of `a`, first of `b`)
/// let axes = AxisSpec::Num(1);
///
/// // Explicitly contract over multiple axes
/// let axes = AxisSpec::Pair(vec![1, 2], vec![0, 3]);
/// ```
///
/// # Notes
///
/// - Axis indices can be negative, in which case they are interpreted
/// relative to the end of the tensor (e.g., `-1` refers to the last axis).
/// - The number and dimensionality of contracted axes determine the rank of
/// the result of [tensordot].
/// - `AxisSpec` exists to disambiguate and formalise axis specifications,
/// avoiding confusion with [crate::Axis] and [crate::iter::Axes].
///
/// # See also
///
/// [tensordot] — Performs the generalized tensor contraction described by this specification.
/// [`Axis`] — Represents a single axis index within an array.
#[derive(Clone, Debug)]
pub enum AxisSpec
{
/// Contract over the last *n* axes of the left-hand tensor and the first *n* axes
/// of the right-hand tensor.

/// For example, `Num(1)` performs standard matrix multiplication,
/// `Num(0)` performs an outer product, and `Num(2)` contracts over two axes.
///
/// # Example
/// ```
/// # use ndarray::linalg::AxisSpec;
/// let axes = AxisSpec::Num(1); // last of `a`, first of `b`
/// ```
Num(usize),
/// Explicitly specify which axes of each tensor to contract over.
///
/// The first vector lists the axes of the left-hand tensor `a`,
/// and the second vector lists the corresponding axes of the right-hand tensor `b`.
/// Both vectors must be the same length, and each corresponding axis pair
/// must have matching dimension sizes.
///
/// Negative indices are supported and count from the end
/// (e.g. `-1` refers to the last axis).
///
/// # Example
/// ```
/// # use ndarray::linalg::AxisSpec;
/// let axes = AxisSpec::Pair(vec![1, -1], vec![0, 2]);
/// ```
Pair(Vec<isize>, Vec<isize>),
}

// Generalised tensor contraction.
///
/// This operation extends `dot` and matrix multiplication
/// to tensors of arbitrary rank. The contraction pattern is
/// defined by [`AxisSpec`].
pub trait Tensordot<Rhs: ?Sized>
{
/// The result of the contraction.
type Output;

/// Perform a tensor contraction along specified axes.
///
/// Given two tensors `self` and `rhs` and an `AxisSpec` specification,
/// containing either a specific number of axes to contract over or
/// explicit lists of axes for each tensor.
///
/// This function computes sum the products of the elements(components) of `self` and `Rhs` over the axes specified by the `axes` argument.
/// The AxisInfo argument can be a single non-negative integer scalar, N; if it is such, then the last N dimensions of `self` and the first N dimensions of `Rhs` are summed over.
/// If AxisInfo is a pair of lists of integers, then the first list contains the axes to be summed over in `self`, and the second list contains the axes to be summed over in `Rhs`.
///
/// # Safety and Panics
///
/// This function uses several internal `unwrap` and `expect` calls when
/// reshaping or permuting arrays. These operations are **guaranteed safe**
/// when the caller-provided `axes` specification is valid, because:
///
/// - Each axis index is bounds-checked before use.
/// - Contracted axes are validated for duplicate indices and matching
/// dimension sizes.
/// - The resulting permutation and reshape patterns are internally consistent.
///
/// The only circumstances under which an internal `unwrap`/`expect` may panic are:
///
/// - The `axes` specification refers to out-of-range or duplicate axes.
/// - The contraction dimensions differ in size.
/// - The computed product of reshaped dimensions does not equal the
/// array’s total element count (which would indicate internal logic error).
#[track_caller]
fn tensordot(&self, rhs: &Rhs, axes: AxisSpec) -> Self::Output;
}

/// Perform a tensor contraction along specified axes.
///
/// See [`Tensordot::tensordot`] for more details.
#[track_caller]
pub fn tensordot<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
where
T: LinalgScalar,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
Da: Dimension,
Db: Dimension,
{
tensordot_impl::<T, Sa, Sb, Da, Db>(a, b, axes)
}

/// Performs the full contraction given resolved axis specification.
#[track_caller]
fn tensordot_impl<T, Sa, Sb, Da, Db>(a: &ArrayBase<Sa, Da>, b: &ArrayBase<Sb, Db>, axes: AxisSpec) -> ArrayD<T>
where
T: LinalgScalar,
Sa: Data<Elem = T>,
Sb: Data<Elem = T>,
Da: Dimension,
Db: Dimension,
{
let nda = a.ndim() as isize;
let ndb = b.ndim() as isize;

// Resolve axes
let (mut axes_a, mut axes_b): (Vec<isize>, Vec<isize>) = match axes {
AxisSpec::Num(n) => {
let n = n as isize;
assert!(
n <= nda && n <= ndb,
"tensordot: cannot contract over {} axes; a.ndim()={}, b.ndim()={}",
n,
nda,
ndb
);
((nda - n)..nda).zip(0..n).map(|(ia, ib)| (ia, ib)).unzip()
}
AxisSpec::Pair(aa, bb) => {
assert_eq!(
aa.len(),
bb.len(),
"tensordot: axes length mismatch (a has {}, b has {})",
aa.len(),
bb.len()
);
(aa, bb)
}
};

// Normalise negative indices
for ax in &mut axes_a {
if *ax < 0 {
*ax += nda;
}
}
for ax in &mut axes_b {
if *ax < 0 {
*ax += ndb;
}
}

// Validate
for &ax in &axes_a {
assert!(
(0..nda).contains(&ax),
"tensordot: axis {} out of bounds for a (ndim={})",
ax,
nda
);
}
for &ax in &axes_b {
assert!(
(0..ndb).contains(&ax),
"tensordot: axis {} out of bounds for b (ndim={})",
ax,
ndb
);
}

// Shape checks
for (ia, ib) in axes_a.iter().zip(&axes_b) {
let da = a.shape()[*ia as usize];
let db = b.shape()[*ib as usize];
assert_eq!(
da, db,
"tensordot: shape mismatch along contraction axis: a[{}]={} vs b[{}]={}",
ia, da, ib, db
);
}

// Determine non-contracted axes
let notin_a: Vec<usize> = (0..nda as usize)
.filter(|k| !axes_a.iter().any(|&ax| ax as usize == *k))
.collect();
let notin_b: Vec<usize> = (0..ndb as usize)
.filter(|k| !axes_b.iter().any(|&ax| ax as usize == *k))
.collect();

// Reorder axes
let mut newaxes_a = notin_a.clone();
newaxes_a.extend(axes_a.iter().map(|&x| x as usize));
let mut newaxes_b = axes_b.iter().map(|&x| x as usize).collect::<Vec<_>>();
newaxes_b.extend(notin_b.iter().copied());

// Matrix shapes
let m = notin_a.iter().fold(1, |p, &ax| p * a.shape()[ax]);
let k = axes_a.iter().fold(1, |p, &ax| p * a.shape()[ax as usize]);
let n = notin_b.iter().fold(1, |p, &ax| p * b.shape()[ax]);

let a_dyn = a.view().into_dimensionality::<IxDyn>().unwrap();
let b_dyn = b.view().into_dimensionality::<IxDyn>().unwrap();

let a_perm = a_dyn.permuted_axes(IxDyn(&newaxes_a));
let a_std = a_perm.as_standard_layout();

let b_perm = b_dyn.permuted_axes(IxDyn(&newaxes_b));
let b_std = b_perm.as_standard_layout();

let a2 = a_std
.into_shape_with_order(Ix2(m, k))
.expect("reshaping a to 2D");
let b2 = b_std
.into_shape_with_order(Ix2(k, n))
.expect("reshaping b to 2D");

let c2 = a2.dot(&b2);

let mut out_shape: Vec<usize> = notin_a.iter().map(|&ax| a.shape()[ax]).collect();
out_shape.extend(notin_b.iter().map(|&ax| b.shape()[ax]));

c2.into_shape_with_order(IxDyn(&out_shape)).unwrap()
}

// ArrayBase × ArrayBase
impl<A, S, S2, D1, D2> Tensordot<ArrayBase<S2, D2>> for ArrayBase<S, D1>
where
A: LinalgScalar,
S: Data<Elem = A>,
S2: Data<Elem = A>,
D1: Dimension,
D2: Dimension,
{
type Output = ArrayD<A>;

#[track_caller]
fn tensordot(&self, rhs: &ArrayBase<S2, D2>, axes: AxisSpec) -> Self::Output
{
tensordot_impl::<A, S, S2, D1, D2>(self, rhs, axes)
}
}

// ArrayBase × ArrayRef (rhs is ArrayRef) → pass a view to backend
impl<A, S, D1, D2> Tensordot<ArrayRef<A, D2>> for ArrayBase<S, D1>
where
A: LinalgScalar,
S: Data<Elem = A>,
D1: Dimension,
D2: Dimension,
{
type Output = ArrayD<A>;

#[track_caller]
fn tensordot(&self, rhs: &ArrayRef<A, D2>, axes: AxisSpec) -> Self::Output
{
let rhs_view: ArrayBase<ViewRepr<&A>, D2> = rhs.view();
tensordot_impl::<A, S, ViewRepr<&A>, D1, D2>(self, &rhs_view, axes)
}
}

// ArrayRef × ArrayBase (self is ArrayRef) → pass a view to backend
impl<A, S, D1, D2> Tensordot<ArrayBase<S, D2>> for ArrayRef<A, D1>
where
A: LinalgScalar,
S: Data<Elem = A>,
D1: Dimension,
D2: Dimension,
{
type Output = ArrayD<A>;

#[track_caller]
fn tensordot(&self, rhs: &ArrayBase<S, D2>, axes: AxisSpec) -> Self::Output
{
let self_view: ArrayBase<ViewRepr<&A>, D1> = self.view();
tensordot_impl::<A, ViewRepr<&A>, S, D1, D2>(&self_view, rhs, axes)
}
}

#[cfg(test)]
mod tensordot_tests
{
use super::*;
use crate::{ArrayD, IxDyn};

#[test]
fn basic_pair_axes()
{
// a.shape = [3, 4, 5], b.shape = [4, 3, 2]
let a = ArrayD::from_shape_vec(IxDyn(&[3, 4, 5]), (0..60).collect::<Vec<_>>()).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[4, 3, 2]), (0..24).collect::<Vec<_>>()).unwrap();

let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Pair(vec![1, 0], vec![0, 1]));

// Expected shape: [5, 2]
assert_eq!(
c.shape(),
&[5, 2],
"unexpected output shape: got {:?}, expected [5, 2]",
c.shape()
);

// Spot check one known value
assert_eq!(
c[[0, 0]], 4400,
"unexpected value at [0,0]: got {}, expected 4400",
c[[0, 0]]
);

// Check consistency of the entire first row (informative failure message)
let first_row = c.slice(s![0, ..]).to_vec();
assert_eq!(
first_row.len(),
2,
"first row length mismatch: got {}, expected 2",
first_row.len()
);
}

#[test]
fn integer_axes()
{
// a.shape = [2, 2, 2], b.shape = [2, 2]
let a = ArrayD::from_shape_vec(IxDyn(&[2, 2, 2]), (1..=8).collect::<Vec<_>>()).unwrap();
let b = ArrayD::from_shape_vec(IxDyn(&[2, 2]), vec![10, 20, 30, 40]).unwrap();

// Contract over 2 axes
let c: ArrayD<i32> = tensordot(&a, &b, AxisSpec::Num(2));

assert_eq!(
c.shape(),
&[2],
"unexpected output shape: got {:?}, expected [2]",
c.shape()
);

// Extract result as slice for easy comparison
let got = c.as_slice().expect("array not contiguous");
let expected = [300, 700]; // verified numeric result

assert_eq!(
got,
expected,
"tensor contraction result mismatch:\n got {:?}\n expected {:?}",
got,
expected
);
}
}
1 change: 1 addition & 0 deletions src/linalg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ pub use self::impl_linalg::general_mat_mul;
pub use self::impl_linalg::general_mat_vec_mul;
pub use self::impl_linalg::kron;
pub use self::impl_linalg::Dot;
pub use self::impl_linalg::{tensordot, AxisSpec, Tensordot};

mod impl_linalg;