diff --git a/MANIFEST.in b/MANIFEST.in index 7c96ba026..d93298de4 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h include ot/lp/full_bipartitegraph_omp.h include ot/lp/network_simplex_simple.h include ot/lp/network_simplex_simple_omp.h +include ot/lp/sparse_bipartitegraph.h include ot/partial/partial_cython.pyx diff --git a/docs/source/_static/images/comparison.png b/docs/source/_static/images/comparison.png new file mode 100644 index 000000000..587a4fb95 Binary files /dev/null and b/docs/source/_static/images/comparison.png differ diff --git a/examples/plot_sparse_emd.py b/examples/plot_sparse_emd.py new file mode 100644 index 000000000..a57ad8e54 --- /dev/null +++ b/examples/plot_sparse_emd.py @@ -0,0 +1,225 @@ +# -*- coding: utf-8 -*- +""" +============================================ +Sparse Optimal Transport +============================================ + +In many real-world optimal transport (OT) problems, the transport plan is +naturally sparse: only a small fraction of all possible source-target pairs +actually exchange mass. Using sparse OT solvers can provide significant +computational speedups and memory savings compared to dense solvers. + +This example demonstrates how to use sparse cost matrices with POT's EMD solver, +comparing sparse and dense formulations on both a minimal example and a larger +concentric circles dataset. +""" + +# Author: Nathan Neike +# License: MIT License +# sphinx_gallery_thumbnail_number = 2 + +import numpy as np +import matplotlib.pyplot as plt +from scipy.sparse import coo_matrix +import ot + + +############################################################################## +# Minimal example with 4 points +# ------------------------------ + +# %% + +X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]]) +Y = np.array([[0, 1], [1, 1], [0.5, 1], [1.5, 1]]) +a = np.array([0.25, 0.25, 0.25, 0.25]) +b = np.array([0.25, 0.25, 0.25, 0.25]) + +# Build sparse cost matrix allowing only selected edges +rows = [0, 1, 2, 3] +cols = [0, 1, 2, 3] +vals = [np.linalg.norm(X[i] - Y[j]) for i, j in zip(rows, cols)] +M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4)) + + +############################################################################## +# Solve and display sparse OT solution +# ------------------------------------- + +# %% + +G, log = ot.emd(a, b, M_sparse, log=True) + +print("Sparse OT cost:", log["cost"]) +print("Solution format:", type(G)) +print("Non-zero edges:", G.nnz) +print("\nEdges:") +G_coo = G if isinstance(G, coo_matrix) else G.tocoo() +for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data): + if v > 1e-10: + print(f" source {i} -> target {j}, flow={v:.3f}") + + +############################################################################## +# Visualize sparse vs dense edge structure +# ----------------------------------------- + +# %% + +plt.figure(figsize=(8, 4)) + +plt.subplot(1, 2, 1) +plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) +plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) +for i, j in zip(rows, cols): + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, alpha=0.6) +plt.title("Sparse OT: Allowed Edges Only") +plt.xlim(-0.5, 2.0) +plt.ylim(-0.5, 1.5) + +plt.subplot(1, 2, 2) +plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3) +plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3) +for i in range(len(X)): + for j in range(len(Y)): + plt.plot([X[i, 0], Y[j, 0]], [X[i, 1], Y[j, 1]], "b-", linewidth=1, 
alpha=0.3) +plt.title("Dense OT: All Possible Edges") +plt.xlim(-0.5, 2.0) +plt.ylim(-0.5, 1.5) + +plt.tight_layout() + + +############################################################################## +# Larger example: concentric circles +# ----------------------------------- + +# %% + +n_clusters = 8 +points_per_cluster = 25 +n = n_clusters * points_per_cluster +k_neighbors = 8 +rng = np.random.default_rng(0) + +r_source = 1.0 +r_target = 2.0 +noise_scale = 0.06 + +theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False) +cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster) + +X_large = np.column_stack( + [r_source * np.cos(theta), r_source * np.sin(theta)] +) + rng.normal(scale=noise_scale, size=(n, 2)) +Y_large = np.column_stack( + [r_target * np.cos(theta), r_target * np.sin(theta)] +) + rng.normal(scale=noise_scale, size=(n, 2)) + +a_large = np.zeros(n) +b_large = np.zeros(n) +for k in range(n_clusters): + idx = np.where(cluster_labels == k)[0] + a_large[idx] = 1.0 / n_clusters / points_per_cluster + b_large[idx] = 1.0 / n_clusters / points_per_cluster + +M_full = ot.dist(X_large, Y_large, metric="euclidean") + +# Build sparse cost matrix: intra-cluster k-nearest neighbors +angles_X = np.arctan2(X_large[:, 1], X_large[:, 0]) +angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0]) + +rows = [] +cols = [] +vals = [] +for k in range(n_clusters): + src_idx = np.where(cluster_labels == k)[0] + tgt_idx = np.where(cluster_labels == k)[0] + for i in src_idx: + diff = np.angle(np.exp(1j * (angles_Y[tgt_idx] - angles_X[i]))) + idx = np.argsort(np.abs(diff))[:k_neighbors] + for j_local in idx: + j = tgt_idx[j_local] + rows.append(i) + cols.append(j) + vals.append(M_full[i, j]) + +M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n)) +allowed_sparse = set(zip(rows, cols)) + +############################################################################## +# Visualize edge structures +# -------------------------- + +# %% + +plt.figure(figsize=(16, 6)) + +plt.subplot(1, 2, 1) +for i in range(n): + for j in range(n): + plt.plot( + [X_large[i, 0], Y_large[j, 0]], + [X_large[i, 1], Y_large[j, 1]], + color="blue", + alpha=0.2, + linewidth=0.05, + ) +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) +plt.axis("equal") +plt.title("Dense OT: All Possible Edges") + +plt.subplot(1, 2, 2) +for i, j in allowed_sparse: + plt.plot( + [X_large[i, 0], Y_large[j, 0]], + [X_large[i, 1], Y_large[j, 1]], + color="blue", + alpha=1, + linewidth=0.05, + ) +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20) +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20) +plt.axis("equal") +plt.title("Sparse OT: Intra-Cluster k-NN Edges") + +plt.tight_layout() +plt.show() + +############################################################################## +# Solve and visualize transport plans +# ------------------------------------ + +# %% + +G_dense = ot.emd(a_large, b_large, M_full) +cost_dense = np.sum(G_dense * M_full) +print(f"Dense OT cost: {cost_dense:.6f}") + +G_sparse, log_sparse = ot.emd(a_large, b_large, M_sparse_large, log=True) +cost_sparse = log_sparse["cost"] +print(f"Sparse OT cost: {cost_sparse:.6f}") + +plt.figure(figsize=(16, 6)) + +plt.subplot(1, 2, 1) +ot.plot.plot2D_samples_mat( + X_large, Y_large, G_dense, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 +) +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", 
marker="x", s=20, zorder=3) +plt.axis("equal") +plt.title("Dense OT: Optimal Transport Plan") + +plt.subplot(1, 2, 2) +ot.plot.plot2D_samples_mat( + X_large, Y_large, G_sparse, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5 +) +plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3) +plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3) +plt.axis("equal") +plt.title("Sparse OT: Optimal Transport Plan") + +plt.tight_layout() +plt.show() diff --git a/ot/backend.py b/ot/backend.py index a11c78209..7ca505c0f 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -178,7 +178,16 @@ def _get_backend_instance(backend_impl): def _check_args_backend(backend_impl, args): - is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args) + # Get backend instance to use issparse method + backend = _get_backend_instance(backend_impl) + + # Check if each arg is either: + # 1. An instance of backend.__type__ (e.g., np.ndarray for NumPy) + # 2. A sparse matrix recognized by backend.issparse() (e.g., scipy.sparse for NumPy) + is_instance = set( + isinstance(arg, backend_impl.__type__) or backend.issparse(arg) for arg in args + ) + # check that all arguments matched or not the type if len(is_instance) == 1: return is_instance.pop() @@ -839,6 +848,31 @@ def todense(self, a): """ raise NotImplementedError() + def sparse_coo_data(self, a): + r""" + Extracts COO format data (row, col, data, shape) from a sparse matrix. + + Returns row indices, column indices, data values, and shape as numpy arrays/tuple. + This is used to interface with C++ solvers that require explicit edge lists. + + Parameters + ---------- + a : sparse matrix + Sparse matrix in backend's COO format + + Returns + ------- + row : numpy.ndarray + Row indices (1D array) + col : numpy.ndarray + Column indices (1D array) + data : numpy.ndarray + Data values (1D array) + shape : tuple + Shape of the matrix (n_rows, n_cols) + """ + raise NotImplementedError() + def where(self, condition, x, y): r""" Returns elements chosen from x or y depending on condition. 
@@ -1349,6 +1383,15 @@ def todense(self, a):
         else:
             return a
 
+    def sparse_coo_data(self, a):
+        # Convert to COO format if needed
+        if not isinstance(a, coo_matrix):
+            a_coo = coo_matrix(a)
+        else:
+            a_coo = a
+
+        return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return np.where(condition)
@@ -1768,6 +1811,15 @@ def todense(self, a):
         # Currently, JAX does not support sparse matrices
         return a
 
+    def sparse_coo_data(self, a):
+        # JAX doesn't support sparse matrices, so this shouldn't be called
+        # But if it is, convert the dense array to sparse using scipy
+        a_np = self.to_numpy(a)
+        from scipy.sparse import coo_matrix
+
+        a_coo = coo_matrix(a_np)
+        return a_coo.row, a_coo.col, a_coo.data, a_coo.shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return jnp.where(condition)
@@ -2340,6 +2392,20 @@ def todense(self, a):
         else:
             return a
 
+    def sparse_coo_data(self, a):
+        # For torch sparse tensors, coalesce first to ensure unique indices
+        a_coalesced = a.coalesce()
+        indices = a_coalesced._indices()
+        values = a_coalesced._values()
+
+        # Convert to numpy
+        row = self.to_numpy(indices[0])
+        col = self.to_numpy(indices[1])
+        data = self.to_numpy(values)
+        shape = tuple(a_coalesced.shape)
+
+        return row, col, data, shape
+
     def where(self, condition, x=None, y=None):
         if x is None and y is None:
             return torch.where(condition)
diff --git a/ot/lp/EMD.h b/ot/lp/EMD.h
index b56f0601b..e3564a2d2 100644
--- a/ot/lp/EMD.h
+++ b/ot/lp/EMD.h
@@ -32,6 +32,24 @@ enum ProblemType {
 int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
 int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);
 
+int EMD_wrap_sparse(
+    int n1,
+    int n2,
+    double *X,
+    double *Y,
+    uint64_t n_edges,            // Number of edges in sparse graph
+    uint64_t *edge_sources,      // Source indices for each edge (n_edges)
+    uint64_t *edge_targets,      // Target indices for each edge (n_edges)
+    double *edge_costs,          // Cost for each edge (n_edges)
+    uint64_t *flow_sources_out,  // Output: source indices of non-zero flows
+    uint64_t *flow_targets_out,  // Output: target indices of non-zero flows
+    double *flow_values_out,     // Output: flow values
+    uint64_t *n_flows_out,
+    double *alpha,               // Output: dual variables for sources (n1)
+    double *beta,                // Output: dual variables for targets (n2)
+    double *cost,                // Output: total transportation cost
+    uint64_t maxIter             // Maximum iterations for solver
+);
 
 #endif
diff --git a/ot/lp/EMD_wrapper.cpp b/ot/lp/EMD_wrapper.cpp
index 4aa5a6e72..bd3672535 100644
--- a/ot/lp/EMD_wrapper.cpp
+++ b/ot/lp/EMD_wrapper.cpp
@@ -15,8 +15,10 @@
 
 #include "network_simplex_simple.h"
 #include "network_simplex_simple_omp.h"
+#include "sparse_bipartitegraph.h"
 #include "EMD.h"
 #include <cstdint>
+#include <vector>
 
 
 int EMD_wrap(int n1, int n2, double *X, double *Y, double *D, double *G,
@@ -216,3 +218,156 @@ int EMD_wrap_omp(int n1, int n2, double *X, double *Y, double *D, double *G,
 
     return ret;
 }
+
+// ============================================================================
+// SPARSE VERSION: Accepts edge list instead of dense cost matrix
+// ============================================================================
+int EMD_wrap_sparse(
+    int n1,
+    int n2,
+    double *X,
+    double *Y,
+    uint64_t n_edges,
+    uint64_t *edge_sources,
+    uint64_t *edge_targets,
+    double *edge_costs,
+    uint64_t *flow_sources_out,
+    uint64_t *flow_targets_out,
+    double *flow_values_out,
+    uint64_t *n_flows_out,
+    double *alpha,
+    double *beta,
+    double *cost,
+    uint64_t maxIter
+) {
+    using namespace lemon;
+
+    uint64_t n = 0;
+    for (int i = 0; i < n1; i++) {
+        double val = *(X + i);
+        if (val > 0) {
+            n++;
+        } else if (val < 0) {
+            return INFEASIBLE;
+        }
+    }
+
+    uint64_t m = 0;
+    for (int i = 0; i < n2; i++) {
+        double val = *(Y + i);
+        if (val > 0) {
+            m++;
+        } else if (val < 0) {
+            return INFEASIBLE;
+        }
+    }
+
+    std::vector<uint64_t> indI(n);    // indI[graph_idx] = original_source_idx
+    std::vector<uint64_t> indJ(m);    // indJ[graph_idx] = original_target_idx
+    std::vector<double> weights1(n);  // Source masses (positive only)
+    std::vector<double> weights2(m);  // Target masses (negative for demand)
+
+    // Create reverse mapping: original_idx → graph_idx
+    std::vector<int64_t> source_to_graph(n1, -1);
+    std::vector<int64_t> target_to_graph(n2, -1);
+
+    uint64_t cur = 0;
+    for (int i = 0; i < n1; i++) {
+        double val = *(X + i);
+        if (val > 0) {
+            weights1[cur] = val;       // Store the mass
+            indI[cur] = i;             // Forward map: graph → original
+            source_to_graph[i] = cur;  // Reverse map: original → graph
+            cur++;
+        }
+    }
+
+    cur = 0;
+    for (int i = 0; i < n2; i++) {
+        double val = *(Y + i);
+        if (val > 0) {
+            weights2[cur] = -val;
+            indJ[cur] = i;             // Forward map: graph → original
+            target_to_graph[i] = cur;  // Reverse map: original → graph
+            cur++;
+        }
+    }
+
+    typedef SparseBipartiteDigraph Digraph;
+    DIGRAPH_TYPEDEFS(Digraph);
+
+    Digraph di(n, m);
+
+    std::vector<std::pair<int, int>> edges;  // (source, target) pairs
+    std::vector<uint64_t> edge_to_arc;       // edge_to_arc[k] = arc ID for edge k
+    std::vector<double> arc_costs;           // arc_costs[arc_id] = cost (for O(1) lookup)
+    edges.reserve(n_edges);
+    edge_to_arc.reserve(n_edges);
+
+    uint64_t valid_edge_count = 0;
+    for (uint64_t k = 0; k < n_edges; k++) {
+        int64_t src_orig = edge_sources[k];
+        int64_t tgt_orig = edge_targets[k];
+        int64_t src = source_to_graph[src_orig];
+        int64_t tgt = target_to_graph[tgt_orig];
+
+        if (src >= 0 && tgt >= 0) {
+            edges.emplace_back(src, tgt + n);
+            edge_to_arc.push_back(valid_edge_count);
+            arc_costs.push_back(edge_costs[k]);  // Store cost indexed by arc ID
+            valid_edge_count++;
+        } else {
+            edge_to_arc.push_back(UINT64_MAX);
+        }
+    }
+
+    di.buildFromEdges(edges);
+
+    NetworkSimplexSimple<Digraph, double, double, node_id_type> net(
+        di, true, (int)(n + m), di.arcNum(), maxIter
+    );
+
+    net.supplyMap(&weights1[0], (int)n, &weights2[0], (int)m);
+
+    for (uint64_t k = 0; k < n_edges; k++) {
+        if (edge_to_arc[k] != UINT64_MAX) {
+            net.setCost(edge_to_arc[k], edge_costs[k]);
+        }
+    }
+
+    int ret = net.run();
+
+    if (ret == (int)net.OPTIMAL || ret == (int)net.MAX_ITER_REACHED) {
+        *cost = 0;
+        *n_flows_out = 0;
+
+        Arc a;
+        di.first(a);
+        for (; a != INVALID; di.next(a)) {
+            uint64_t i = di.source(a);
+            uint64_t j = di.target(a);
+            double flow = net.flow(a);
+
+            uint64_t orig_i = indI[i];
+            uint64_t orig_j = indJ[j - n];
+
+            double arc_cost = arc_costs[a];
+
+            *cost += flow * arc_cost;
+
+            *(alpha + orig_i) = -net.potential(i);
+            *(beta + orig_j) = net.potential(j);
+
+            if (flow > 1e-15) {
+                flow_sources_out[*n_flows_out] = orig_i;
+                flow_targets_out[*n_flows_out] = orig_j;
+                flow_values_out[*n_flows_out] = flow;
+                (*n_flows_out)++;
+            }
+        }
+    }
+    return ret;
+}
\ No newline at end of file
diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py
index 492e4c7ac..200001378 100644
--- a/ot/lp/_network_simplex.py
+++ b/ot/lp/_network_simplex.py
@@ -13,7 +13,7 @@
 from ..utils import list_to_array, check_number_threads
 from ..backend import get_backend
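
The solver changes below enable the following user-facing pattern (a small sketch assuming this branch; absent edges are forbidden, i.e. treated as infinite cost, and a sparse input yields a sparse plan):

    import numpy as np
    from scipy.sparse import coo_matrix, issparse
    import ot

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    # Only the two diagonal pairs are allowed edges
    M = coo_matrix(([1.0, 1.0], ([0, 1], [0, 1])), shape=(2, 2))
    G, log = ot.emd(a, b, M, log=True)
    assert issparse(G)  # sparse input -> sparse transport plan
    # log["cost"] == 1.0: each allowed edge carries mass 0.5 at cost 1.0
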
-from .emd_wrap import emd_c, check_result +from .emd_wrap import emd_c, emd_c_sparse, check_result def center_ot_dual(alpha0, beta0, a=None, b=None): @@ -215,8 +215,13 @@ def emd( Source histogram (uniform weight if empty list) b : (nt,) array-like, float Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float - Loss matrix (c-order array in numpy with type float64) + M : (ns,nt) array-like or sparse matrix, float + Loss matrix. Can be: + + - Dense array (c-order array in numpy with type float64) + - Sparse matrix in backend's format (scipy.sparse.coo_matrix for NumPy backend, + torch.sparse_coo_tensor for PyTorch backend, etc.) + numItermax : int, optional (default=100000) The maximum number of iterations before stopping the optimization algorithm if it has not converged. @@ -233,15 +238,27 @@ def emd( If True, checks that the marginals mass are equal. If False, skips the check. + .. note:: The solver automatically detects sparse format using the backend's + :py:meth:`issparse` method. For sparse inputs: + + - Uses a memory-efficient sparse EMD algorithm + - Returns the transport plan as a sparse matrix in the backend's format + - Supports scipy.sparse (NumPy), torch.sparse (PyTorch), etc. + - JAX and TensorFlow backends don't support sparse matrices + Returns ------- - gamma: array-like, shape (ns, nt) - Optimal transportation matrix for the given - parameters + gamma: array-like or sparse matrix, shape (ns, nt) + Optimal transportation matrix for the given parameters. + + - For dense inputs: returns a dense array + - For sparse inputs: returns a sparse matrix in the backend's format + (e.g., scipy.sparse.coo_matrix for NumPy, torch.sparse_coo_tensor for PyTorch) + log: dict, optional - If input log is true, a dictionary containing the - cost and dual variables and exit status + If input log is True, a dictionary containing the cost, dual variables, + and exit status. Examples @@ -272,38 +289,78 @@ def emd( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + # Get backend from M first, then use it for list_to_array + # This ensures empty lists [] are converted to arrays in the correct backend + nx_M = get_backend(M) + a, b = list_to_array(a, b, nx=nx_M) + nx = get_backend(a, b, M) + + # Check if M is sparse using backend's issparse method + is_sparse = nx.issparse(M) + + if is_sparse: + # Check if backend supports sparse matrices + backend_name = nx.__class__.__name__ + if backend_name in ["JaxBackend", "TensorflowBackend"]: + raise NotImplementedError( + f"Sparse optimal transport is not supported for {backend_name}. " + "JAX does not have native sparse matrix support, and TensorFlow's " + "sparse implementation is incomplete. Please convert your sparse " + "matrix to dense format using M.toarray() or equivalent before calling emd()." + ) + + # Extract COO data using backend method - returns numpy arrays + edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M) + + # Ensure correct dtypes for C++ solver + if edge_sources.dtype != np.uint64: + edge_sources = edge_sources.astype(np.uint64) + if edge_targets.dtype != np.uint64: + edge_targets = edge_targets.astype(np.uint64) + if edge_costs.dtype != np.float64: + edge_costs = edge_costs.astype(np.float64) + + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. 
" + "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)." + ) + else: + is_sparse = False + a, b, M = list_to_array(a, b, M) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a + + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape - # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if is_sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") - # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") - - # if empty array given then use uniform distributions - if len(a) == 0: - a = np.ones((M.shape[0],), dtype=np.float64) / M.shape[0] - if len(b) == 0: - b = np.ones((M.shape[1],), dtype=np.float64) / M.shape[1] assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -321,32 +378,77 @@ def emd( numThreads = check_number_threads(numThreads) - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # ============================================================================ + # CALL SOLVER (sparse or dense) + # ============================================================================ + if is_sparse: + # Sparse solver - never build full matrix + flow_sources, flow_targets, flow_values, cost, u, v, result_code = emd_c_sparse( + a, b, edge_sources, edge_targets, edge_costs, numItermax + ) + else: + # Dense solver + G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) + # ============================================================================ + # POST-PROCESS DUAL VARIABLES AND CREATE TRANSPORT PLAN + # ============================================================================ + + # Center dual potentials if center_dual: u, v = center_ot_dual(u, v, a, b) + # Handle null weights if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) + if is_sparse: + u, v = center_ot_dual(u, v, a, b) + else: + u, v = estimate_dual_null_weights(u, v, a, b, M) result_code_string = check_result(result_code) - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, + + # Create transport plan in backend format + if is_sparse: + # Convert flow to sparse matrix using backend's coo_matrix method + flow_values_backend = nx.from_numpy(flow_values, type_as=type_as) + flow_sources_backend = nx.from_numpy( + flow_sources.astype(np.int64), type_as=type_as + ) + flow_targets_backend = nx.from_numpy( + flow_targets.astype(np.int64), type_as=type_as + ) + + G = nx.coo_matrix( + flow_values_backend, + flow_sources_backend, + flow_targets_backend, + shape=(n1, n2), + type_as=type_as, ) + else: + # Warn about integer casting for dense case + if not nx.is_floating_point(type_as): + warnings.warn( + "Input histogram consists of integer. The transport plan will be " + "casted accordingly, possibly resulting in a loss of precision. " + "If this behaviour is unwanted, please make sure your input " + "histogram consists of floating point elements.", + stacklevel=2, + ) + G = nx.from_numpy(G, type_as=type_as) + + # Return results if log: - log = {} - log["cost"] = cost - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code - return nx.from_numpy(G, type_as=type_as), log - return nx.from_numpy(G, type_as=type_as) + log_dict = { + "cost": cost, + "u": nx.from_numpy(u, type_as=type_as), + "v": nx.from_numpy(v, type_as=type_as), + "warning": result_code_string, + "result_code": result_code, + } + return G, log_dict + else: + return G def emd2( @@ -356,10 +458,10 @@ def emd2( processes=1, numItermax=100000, log=False, - return_matrix=False, center_dual=True, numThreads=1, check_marginals=True, + return_matrix=False, ): r"""Solves the Earth Movers distance problem and returns the loss @@ -399,8 +501,13 @@ def emd2( Source histogram (uniform weight if empty list) b : (nt,) array-like, float64 Target histogram (uniform weight if empty list) - M : (ns,nt) array-like, float64 - Loss matrix (for numpy c-order array with type float64) + M : (ns,nt) array-like or sparse matrix, float64 + Loss matrix. Can be: + + - Dense array (c-order array in numpy with type float64) + - Sparse matrix in backend's format (scipy.sparse.coo_matrix for NumPy backend, + torch.sparse_coo_tensor for PyTorch backend, etc.) + processes : int, optional (default=1) Nb of processes used for multiple emd computation (deprecated) numItermax : int, optional (default=100000) @@ -421,6 +528,14 @@ def emd2( If True, checks that the marginals mass are equal. If False, skips the check. + .. note:: The solver automatically detects sparse format using the backend's + :py:meth:`issparse` method. For sparse inputs: + + - Uses a memory-efficient sparse EMD algorithm + - Edges not included are treated as having infinite cost (forbidden) + - Supports scipy.sparse (NumPy), torch.sparse (PyTorch), etc. 
+ - JAX and TensorFlow backends don't support sparse matrices + Returns ------- @@ -460,34 +575,96 @@ def emd2( ot.optim.cg : General regularized OT """ - a, b, M = list_to_array(a, b, M) - nx = get_backend(M, a, b) + edge_sources = None + edge_targets = None + edge_costs = None + n1, n2 = None, None + + # Get backend from M first, then use it for list_to_array + # This ensures empty lists [] are converted to arrays in the correct backend + nx_M = get_backend(M) + a, b = list_to_array(a, b, nx=nx_M) + nx = get_backend(a, b, M) + + # Check if M is sparse using backend's issparse method + is_sparse = nx.issparse(M) + + # Save original sparse tensor for gradient tracking (before conversion to numpy) + M_original_sparse = None + + if is_sparse: + # Check if backend supports sparse matrices + backend_name = nx.__class__.__name__ + if backend_name in ["JaxBackend", "TensorflowBackend"]: + raise NotImplementedError( + f"Sparse optimal transport is not supported for {backend_name}. " + "JAX does not have native sparse matrix support, and TensorFlow's " + "sparse implementation is incomplete. Please convert your sparse " + "matrix to dense format using M.toarray() or equivalent before calling emd2()." + ) + + # Save original M for gradient tracking (before numpy conversion) + M_original_sparse = M + + # Extract COO data using backend method - returns numpy arrays + edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M) + + # Ensure correct dtypes for C++ solver + if edge_sources.dtype != np.uint64: + edge_sources = edge_sources.astype(np.uint64) + if edge_targets.dtype != np.uint64: + edge_targets = edge_targets.astype(np.uint64) + if edge_costs.dtype != np.float64: + edge_costs = edge_costs.astype(np.float64) + + elif isinstance(M, tuple): + raise ValueError( + "Tuple format for sparse cost matrix is not supported. " + "Please use backend-appropriate sparse COO format (e.g., scipy.sparse.coo_matrix, torch.sparse_coo_tensor, etc.)." 
+ ) + else: + # Dense matrix + is_sparse = False + a, b, M = list_to_array(a, b, M) if len(a) != 0: type_as = a elif len(b) != 0: type_as = b else: - type_as = M + type_as = a # Can't use M for sparse case + + # Set n1, n2 if not already set (dense case) + if n1 is None: + n1, n2 = M.shape # if empty array given then use uniform distributions if len(a) == 0: - a = nx.ones((M.shape[0],), type_as=type_as) / M.shape[0] + a = nx.ones((n1,), type_as=type_as) / n1 if len(b) == 0: - b = nx.ones((M.shape[1],), type_as=type_as) / M.shape[1] + b = nx.ones((n2,), type_as=type_as) / n2 - # store original tensors - a0, b0, M0 = a, b, M + a0, b0 = a, b + M0 = None if is_sparse else M - # convert to numpy - M, a, b = nx.to_numpy(M, a, b) + if is_sparse: + # Use the original sparse tensor (preserves gradients for PyTorch) + # instead of converting from numpy + edge_costs_original = M_original_sparse + else: + edge_costs_original = None + + if is_sparse: + a, b = nx.to_numpy(a, b) + else: + M, a, b = nx.to_numpy(M, a, b) + M = np.asarray(M, dtype=np.float64, order="C") a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) - M = np.asarray(M, dtype=np.float64, order="C") assert ( - a.shape[0] == M.shape[0] and b.shape[0] == M.shape[1] + a.shape[0] == n1 and b.shape[0] == n2 ), "Dimension mismatch, check dimensions of M with a and b" # ensure that same mass @@ -504,54 +681,72 @@ def emd2( numThreads = check_number_threads(numThreads) - if log or return_matrix: - - def f(b): - bsel = b != 0 - + # ============================================================================ + # DEFINE SOLVER FUNCTION (works for both sparse and dense) + # ============================================================================ + def f(b): + bsel = b != 0 + + # Call appropriate solver + if is_sparse: + # Solve sparse EMD + flow_sources, flow_targets, flow_values, cost, u, v, result_code = ( + emd_c_sparse(a, b, edge_sources, edge_targets, edge_costs, numItermax) + ) + else: + # Solve dense EMD G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - if center_dual: - u, v = center_ot_dual(u, v, a, b) + # Center dual potentials + if center_dual: + u, v = center_ot_dual(u, v, a, b) - if np.any(~asel) or np.any(~bsel): + # Handle null weights + if np.any(~asel) or np.any(~bsel): + if is_sparse: + u, v = center_ot_dual(u, v, a, b) + else: u, v = estimate_dual_null_weights(u, v, a, b, M) - result_code_string = check_result(result_code) - log = {} - if not nx.is_floating_point(type_as): - warnings.warn( - "Input histogram consists of integer. The transport plan will be " - "casted accordingly, possibly resulting in a loss of precision. 
" - "If this behaviour is unwanted, please make sure your input " - "histogram consists of floating point elements.", - stacklevel=2, + # Prepare cost with gradients + if is_sparse: + # Build gradient mapping for sparse case + edge_to_idx = { + (edge_sources[k], edge_targets[k]): k for k in range(len(edge_sources)) + } + + grad_edge_costs = np.zeros(len(edge_costs), dtype=np.float64) + for idx in range(len(flow_sources)): + src, tgt, flow = flow_sources[idx], flow_targets[idx], flow_values[idx] + edge_idx = edge_to_idx.get((src, tgt), -1) + if edge_idx >= 0: + grad_edge_costs[edge_idx] = flow + + # Convert gradient to sparse format matching edge_costs_original + grad_edge_costs_backend = nx.from_numpy(grad_edge_costs, type_as=type_as) + if nx.issparse(edge_costs_original): + # Reconstruct sparse gradient tensor with same structure as original + grad_M_sparse = nx.coo_matrix( + grad_edge_costs_backend, + nx.from_numpy(edge_sources.astype(np.int64), type_as=type_as), + nx.from_numpy(edge_targets.astype(np.int64), type_as=type_as), + shape=(n1, n2), + type_as=type_as, ) - G = nx.from_numpy(G, type_as=type_as) - if return_matrix: - log["G"] = G - log["u"] = nx.from_numpy(u, type_as=type_as) - log["v"] = nx.from_numpy(v, type_as=type_as) - log["warning"] = result_code_string - log["result_code"] = result_code + else: + grad_M_sparse = grad_edge_costs_backend + cost = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), - (a0, b0, M0), - (log["u"] - nx.mean(log["u"]), log["v"] - nx.mean(log["v"]), G), + (a0, b0, edge_costs_original), + ( + nx.from_numpy(u - np.mean(u), type_as=type_as), + nx.from_numpy(v - np.mean(v), type_as=type_as), + grad_M_sparse, + ), ) - return [cost, log] - else: - - def f(b): - bsel = b != 0 - G, cost, u, v, result_code = emd_c(a, b, M, numItermax, numThreads) - - if center_dual: - u, v = center_ot_dual(u, v, a, b) - - if np.any(~asel) or np.any(~bsel): - u, v = estimate_dual_null_weights(u, v, a, b, M) - + else: + # Dense case: warn about integer casting if not nx.is_floating_point(type_as): warnings.warn( "Input histogram consists of integer. 
The transport plan will be " @@ -560,18 +755,39 @@ def f(b): "histogram consists of floating point elements.", stacklevel=2, ) - G = nx.from_numpy(G, type_as=type_as) + + G_backend = nx.from_numpy(G, type_as=type_as) cost = nx.set_gradients( nx.from_numpy(cost, type_as=type_as), (a0, b0, M0), ( nx.from_numpy(u - np.mean(u), type_as=type_as), nx.from_numpy(v - np.mean(v), type_as=type_as), - G, + G_backend, ), ) - check_result(result_code) + check_result(result_code) + + # Return results + if log or return_matrix: + log_dict = { + "u": nx.from_numpy(u, type_as=type_as), + "v": nx.from_numpy(v, type_as=type_as), + "warning": check_result(result_code), + "result_code": result_code, + } + + if return_matrix: + if is_sparse: + G = np.zeros((len(a), len(b)), dtype=np.float64) + G[flow_sources, flow_targets] = flow_values + log_dict["G"] = nx.from_numpy(G, type_as=type_as) + else: + log_dict["G"] = G_backend + + return [cost, log_dict] + else: return cost if len(b.shape) == 1: diff --git a/ot/lp/emd_wrap.pyx b/ot/lp/emd_wrap.pyx index 53df54fc3..3b19d3fdd 100644 --- a/ot/lp/emd_wrap.pyx +++ b/ot/lp/emd_wrap.pyx @@ -14,7 +14,7 @@ from ..utils import dist cimport cython cimport libc.math as math -from libc.stdint cimport uint64_t +from libc.stdint cimport uint64_t, int64_t import warnings @@ -22,6 +22,7 @@ import warnings cdef extern from "EMD.h": int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter) nogil int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads) nogil + int EMD_wrap_sparse(int n1, int n2, double *X, double *Y, uint64_t n_edges, uint64_t *edge_sources, uint64_t *edge_targets, double *edge_costs, uint64_t *flow_sources_out, uint64_t *flow_targets_out, double *flow_values_out, uint64_t *n_flows_out, double *alpha, double *beta, double *cost, uint64_t maxIter) nogil cdef enum ProblemType: INFEASIBLE, OPTIMAL, UNBOUNDED, MAX_ITER_REACHED @@ -206,3 +207,78 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cur_idx += 1 cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost + +@cython.boundscheck(False) +@cython.wraparound(False) +def emd_c_sparse(np.ndarray[double, ndim=1, mode="c"] a, + np.ndarray[double, ndim=1, mode="c"] b, + np.ndarray[uint64_t, ndim=1, mode="c"] edge_sources, + np.ndarray[uint64_t, ndim=1, mode="c"] edge_targets, + np.ndarray[double, ndim=1, mode="c"] edge_costs, + uint64_t max_iter): + """ + Sparse EMD solver - only considers edges in edge_sources/edge_targets + + Parameters + ---------- + a : (n1,) array + Source histogram + b : (n2,) array + Target histogram + edge_sources : (k,) array, uint64 + Source indices for each edge + edge_targets : (k,) array, uint64 + Target indices for each edge + edge_costs : (k,) array, float64 + Cost for each edge + max_iter : uint64_t + Maximum iterations + + Returns + ------- + flow_sources : (n_flows,) array, uint64 + Source indices of non-zero flows + flow_targets : (n_flows,) array, uint64 + Target indices of non-zero flows + flow_values : (n_flows,) array, float64 + Flow values + cost : float + Total cost + alpha : (n1,) array + Dual variables for sources + beta : (n2,) array + Dual variables for targets + result_code : int + Result status + """ + cdef int n1 = a.shape[0] + cdef int n2 = b.shape[0] + cdef uint64_t n_edges = edge_sources.shape[0] + cdef uint64_t n_flows_out = 0 + cdef int result_code = 0 + cdef double cost = 0 + + # 
Allocate output arrays (max size = n_edges)
+    cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_sources = np.zeros(n_edges, dtype=np.uint64)
+    cdef np.ndarray[uint64_t, ndim=1, mode="c"] flow_targets = np.zeros(n_edges, dtype=np.uint64)
+    cdef np.ndarray[double, ndim=1, mode="c"] flow_values = np.zeros(n_edges, dtype=np.float64)
+    cdef np.ndarray[double, ndim=1, mode="c"] alpha = np.zeros(n1)
+    cdef np.ndarray[double, ndim=1, mode="c"] beta = np.zeros(n2)
+
+    with nogil:
+        result_code = EMD_wrap_sparse(
+            n1, n2,
+            <double*> a.data, <double*> b.data,
+            n_edges,
+            <uint64_t*> edge_sources.data, <uint64_t*> edge_targets.data, <double*> edge_costs.data,
+            <uint64_t*> flow_sources.data, <uint64_t*> flow_targets.data, <double*> flow_values.data,
+            &n_flows_out,
+            <double*> alpha.data, <double*> beta.data, &cost, max_iter
+        )
+
+    # Trim to actual number of flows
+    flow_sources = flow_sources[:n_flows_out]
+    flow_targets = flow_targets[:n_flows_out]
+    flow_values = flow_values[:n_flows_out]
+
+    return flow_sources, flow_targets, flow_values, cost, alpha, beta, result_code
\ No newline at end of file
diff --git a/ot/lp/sparse_bipartitegraph.h b/ot/lp/sparse_bipartitegraph.h
new file mode 100644
index 000000000..7ba13b41a
--- /dev/null
+++ b/ot/lp/sparse_bipartitegraph.h
@@ -0,0 +1,281 @@
+/* -*- mode: C++; indent-tabs-mode: nil; -*-
+ *
+ * Sparse bipartite graph for optimal transport
+ * Only stores edges that are explicitly added (not all n1×n2 edges)
+ *
+ * Uses CSR (Compressed Sparse Row) format for better cache locality and performance
+ * - Binary search for arc lookup: O(log k) where k = avg edges per node
+ * - Compact memory layout for better cache performance
+ * - Edges are sorted by (source, target) during construction
+ */
+
+#pragma once
+
+#include "core.h"
+#include <vector>
+#include <algorithm>
+#include <cstdint>
+
+namespace lemon {
+
+    class SparseBipartiteDigraphBase {
+    public:
+
+        typedef SparseBipartiteDigraphBase Digraph;
+        typedef int Node;
+        typedef int64_t Arc;
+
+    protected:
+
+        int _node_num;
+        int64_t _arc_num;
+        int _n1, _n2;
+
+        std::vector<Node> _arc_sources;  // _arc_sources[arc_id] = source node
+        std::vector<Node> _arc_targets;  // _arc_targets[arc_id] = target node
+
+        // CSR format
+        // _row_ptr[i] = start index in _col_indices for source node i
+        // _row_ptr[i+1] - _row_ptr[i] = number of outgoing edges from node i
+        std::vector<int64_t> _row_ptr;
+        std::vector<Node> _col_indices;
+        std::vector<Arc> _arc_ids;
+
+        mutable std::vector<std::vector<Arc>> _in_arcs;  // _in_arcs[node] = incoming arc IDs
+        mutable bool _in_arcs_built;
+
+        SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {}
+
+        void construct(int n1, int n2) {
+            _node_num = n1 + n2;
+            _n1 = n1;
+            _n2 = n2;
+            _arc_num = 0;
+            _arc_sources.clear();
+            _arc_targets.clear();
+            _row_ptr.clear();
+            _col_indices.clear();
+            _arc_ids.clear();
+            _in_arcs.clear();
+            _in_arcs_built = false;
+        }
+
+        void build_in_arcs() const {
+            if (_in_arcs_built) return;
+
+            _in_arcs.resize(_node_num);
+
+            for (Arc a = 0; a < _arc_num; ++a) {
+                Node tgt = _arc_targets[a];
+                _in_arcs[tgt].push_back(a);
+            }
+
+            _in_arcs_built = true;
+        }
+
+    public:
+
+        Node operator()(int ix) const { return Node(ix); }
+        static int index(const Node& node) { return node; }
+
+        void buildFromEdges(const std::vector<std::pair<Node, Node>>& edges) {
+            _arc_num = edges.size();
+
+            if (_arc_num == 0) {
+                _row_ptr.assign(_n1 + 1, 0);
+                return;
+            }
+
+            // Create indexed edges: (source, target, original_arc_id)
+            std::vector<std::tuple<Node, Node, Arc>> indexed_edges;
+            indexed_edges.reserve(_arc_num);
+            for (Arc i = 0; i < _arc_num; ++i) {
+                indexed_edges.emplace_back(edges[i].first, edges[i].second, i);
+            }
+
+            // Sort by source node, then by target node (CSR requirement)
+            std::sort(indexed_edges.begin(), indexed_edges.end(),
+                [](const auto& a, const auto& b) {
+                    if (std::get<0>(a) != std::get<0>(b))
+                        return std::get<0>(a) < std::get<0>(b);
+                    return std::get<1>(a) < std::get<1>(b);
+                });
+
+            _arc_sources.resize(_arc_num);
+            _arc_targets.resize(_arc_num);
+            _col_indices.resize(_arc_num);
+            _arc_ids.resize(_arc_num);
+            _row_ptr.resize(_n1 + 1);
+
+            _row_ptr[0] = 0;
+            int current_row = 0;
+
+            for (int64_t i = 0; i < _arc_num; ++i) {
+                Node src = std::get<0>(indexed_edges[i]);
+                Node tgt = std::get<1>(indexed_edges[i]);
+                Arc orig_arc_id = std::get<2>(indexed_edges[i]);
+
+                // Fill out row_ptr for rows with no outgoing edges
+                while (current_row < src) {
+                    _row_ptr[++current_row] = i;
+                }
+
+                _arc_sources[orig_arc_id] = src;
+                _arc_targets[orig_arc_id] = tgt;
+                _col_indices[i] = tgt;
+                _arc_ids[i] = orig_arc_id;
+            }
+
+            // Fill remaining row_ptr entries
+            while (current_row < _n1) {
+                _row_ptr[++current_row] = _arc_num;
+            }
+
+            _in_arcs_built = false;
+        }
+
+        // Find arc from s to t using binary search (returns -1 if not found)
+        Arc arc(const Node& s, const Node& t) const {
+            if (s < 0 || s >= _n1 || t < _n1 || t >= _node_num) {
+                return Arc(-1);
+            }
+
+            int64_t start = _row_ptr[s];
+            int64_t end = _row_ptr[s + 1];
+
+            // Binary search for target t in col_indices[start:end]
+            auto it = std::lower_bound(
+                _col_indices.begin() + start,
+                _col_indices.begin() + end,
+                t
+            );
+
+            if (it != _col_indices.begin() + end && *it == t) {
+                int64_t pos = it - _col_indices.begin();
+                return _arc_ids[pos];
+            }
+
+            return Arc(-1);
+        }
+
+        int nodeNum() const { return _node_num; }
+        int64_t arcNum() const { return _arc_num; }
+
+        int maxNodeId() const { return _node_num - 1; }
+        int64_t maxArcId() const { return _arc_num - 1; }
+
+        Node source(Arc arc) const {
+            return (arc >= 0 && arc < _arc_num) ? _arc_sources[arc] : Node(-1);
+        }
+
+        Node target(Arc arc) const {
+            return (arc >= 0 && arc < _arc_num) ? _arc_targets[arc] : Node(-1);
+        }
+
+        static int id(Node node) { return node; }
+        static int64_t id(Arc arc) { return arc; }
+
+        static Node nodeFromId(int id) { return Node(id); }
+        static Arc arcFromId(int64_t id) { return Arc(id); }
+
+        Arc findArc(Node s, Node t, Arc prev = -1) const {
+            return prev == -1 ? arc(s, t) : Arc(-1);
+        }
+
+        void first(Node& node) const {
+            node = _node_num - 1;
+        }
+
+        static void next(Node& node) {
+            --node;
+        }
+
+        void first(Arc& arc) const {
+            arc = _arc_num - 1;
+        }
+
+        static void next(Arc& arc) {
+            --arc;
+        }
+
+        void firstOut(Arc& arc, const Node& node) const {
+            if (node < 0 || node >= _n1) {
+                arc = -1;
+                return;
+            }
+
+            int64_t start = _row_ptr[node];
+            int64_t end = _row_ptr[node + 1];
+
+            arc = (start < end) ? _arc_ids[start] : Arc(-1);
+        }
+
+        void nextOut(Arc& arc) const {
+            if (arc < 0) return;
+
+            Node src = _arc_sources[arc];
+            int64_t start = _row_ptr[src];
+            int64_t end = _row_ptr[src + 1];
+
+            for (int64_t i = start; i < end; ++i) {
+                if (_arc_ids[i] == arc) {
+                    arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1);
+                    return;
+                }
+            }
+            arc = -1;
+        }
+
+        void firstIn(Arc& arc, const Node& node) const {
+            build_in_arcs();  // Lazy build on first call
+
+            if (node < 0 || node >= _node_num || node < _n1) {
+                arc = -1;  // Invalid node, or a source node (no incoming arcs)
+                return;
+            }
+
+            const std::vector<Arc>& in = _in_arcs[node];
+            arc = in.empty() ? Arc(-1) : in[0];
+        }
+
+        void nextIn(Arc& arc) const {
+            if (arc < 0) return;
+
+            Node tgt = _arc_targets[arc];
+            const std::vector<Arc>& in = _in_arcs[tgt];
+
+            // Find current arc in the list and return next one
+            for (size_t i = 0; i < in.size(); ++i) {
+                if (in[i] == arc) {
+                    arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1);
+                    return;
+                }
+            }
+            arc = -1;
+        }
+    };
+
+    /// Sparse bipartite digraph - only stores edges that are explicitly added
+    class SparseBipartiteDigraph : public SparseBipartiteDigraphBase {
+        typedef SparseBipartiteDigraphBase Parent;
+
+    public:
+
+        SparseBipartiteDigraph() { construct(0, 0); }
+
+        SparseBipartiteDigraph(int n1, int n2) { construct(n1, n2); }
+
+        Node operator()(int ix) const { return Parent::operator()(ix); }
+        static int index(const Node& node) { return Parent::index(node); }
+
+        void buildFromEdges(const std::vector<std::pair<Node, Node>>& edges) {
+            Parent::buildFromEdges(edges);
+        }
+
+        Arc arc(Node s, Node t) const { return Parent::arc(s, t); }
+
+        int nodeNum() const { return Parent::nodeNum(); }
+        int64_t arcNum() const { return Parent::arcNum(); }
+    };
+
+} //namespace lemon
diff --git a/ot/plot.py b/ot/plot.py
index 1505235c8..e3091ac8a 100644
--- a/ot/plot.py
+++ b/ot/plot.py
@@ -215,36 +215,72 @@ def plot2D_samples_mat(xs, xt, G, thr=1e-8, **kwargs):
     Plot lines between source and target 2D samples with a color
     proportional to the value of the matrix :math:`\mathbf{G}` between samples.
+    Supports both dense and sparse matrices. For sparse matrices, automatically
+    detects the format and efficiently iterates only over non-zero entries.
 
     Parameters
     ----------
     xs : ndarray, shape (ns,2)
         Source samples positions
-    b : ndarray, shape (nt,2)
+    xt : ndarray, shape (nt,2)
         Target samples positions
-    G : ndarray, shape (na,nb)
-        OT matrix
+    G : ndarray or sparse matrix, shape (ns,nt)
+        OT matrix (dense array, scipy.sparse matrix, or backend array)
     thr : float, optional
         threshold above which the line is drawn
     **kwargs : dict
         parameters given to the plot functions (default color is black if
         nothing given)
     """
+    from . 
import backend
+    from scipy.sparse import issparse, coo_matrix
+
     if ("color" not in kwargs) and ("c" not in kwargs):
         kwargs["color"] = "k"
-    mx = G.max()
-    if "alpha" in kwargs:
-        scale = kwargs["alpha"]
-        del kwargs["alpha"]
-    else:
-        scale = 1
-    for i in range(xs.shape[0]):
-        for j in range(xt.shape[0]):
-            if G[i, j] / mx > thr:
+
+    scale = kwargs.pop("alpha", 1)
+
+    # Convert to numpy/scipy format for plotting
+    try:
+        nx = backend.get_backend(G)
+        if nx.issparse(G):
+            # Backend sparse -> extract as numpy arrays for COO format
+            # (sparse_coo_data returns row, col, data, shape)
+            rows, cols, data, _ = nx.sparse_coo_data(G)
+            rows = nx.to_numpy(rows).astype(int)
+            cols = nx.to_numpy(cols).astype(int)
+            data = nx.to_numpy(data)
+            is_sparse = True
+        else:
+            # Backend dense -> convert to numpy
+            G = nx.to_numpy(G)
+            is_sparse = False
+    except (ValueError, AttributeError):
+        # Not a backend array, check if scipy.sparse
+        is_sparse = issparse(G)
+        if is_sparse:
+            G_coo = G if isinstance(G, coo_matrix) else G.tocoo()
+            rows, cols, data = G_coo.row, G_coo.col, G_coo.data
+
+    if is_sparse:
+        # Sparse: iterate over non-zero entries only
+        mx = data.max() if len(data) > 0 else 1.0
+        for i, j, val in zip(rows, cols, data):
+            if val / mx > thr:
                 pl.plot(
                     [xs[i, 0], xt[j, 0]],
                     [xs[i, 1], xt[j, 1]],
-                    alpha=G[i, j] / mx * scale,
+                    alpha=val / mx * scale,
                     **kwargs,
                 )
+    else:
+        # Dense: iterate over all entries
+        mx = np.max(G)
+        for i in range(xs.shape[0]):
+            for j in range(xt.shape[0]):
+                if G[i, j] / mx > thr:
+                    pl.plot(
+                        [xs[i, 0], xt[j, 0]],
+                        [xs[i, 1], xt[j, 1]],
+                        alpha=G[i, j] / mx * scale,
+                        **kwargs,
+                    )
diff --git a/test/test_backend.py b/test/test_backend.py
index 2a0fc9a48..efd696ef0 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -75,6 +75,48 @@ class nx_subclass(nx.__type__):
     assert effective_nx.__name__ == nx.__name__
 
 
+def test_get_backend_sparse_matrix():
+    """Test that get_backend correctly handles sparse matrices and rejects mixed backends."""
+    from scipy.sparse import coo_matrix
+
+    a_np = np.array([0.5, 0.5])
+    b_np = np.array([0.5, 0.5])
+    M_scipy = coo_matrix(([1.0, 2.0], ([0, 1], [0, 1])), shape=(2, 2))
+
+    nx = get_backend(a_np, b_np, M_scipy)
+    assert nx.__name__ == "numpy", "NumPy backend should accept scipy.sparse matrices"
+
+    nx = get_backend(M_scipy)
+    assert nx.__name__ == "numpy", "scipy.sparse should use NumPy backend"
+
+    if torch:
+        a_torch = torch.tensor([0.5, 0.5])
+        b_torch = torch.tensor([0.5, 0.5])
+        M_torch_sparse = torch.sparse_coo_tensor(
+            torch.tensor([[0, 1], [0, 1]]), torch.tensor([1.0, 2.0]), (2, 2)
+        )
+
+        nx = get_backend(a_torch, b_torch, M_torch_sparse)
+        assert (
+            nx.__name__ == "torch"
+        ), "PyTorch backend should accept torch.sparse tensors"
+
+        nx = get_backend(M_torch_sparse)
+        assert nx.__name__ == "torch", "torch.sparse should use PyTorch backend"
+
+        # Case 1: PyTorch dense + scipy.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(a_torch, b_torch, M_scipy)
+
+        # Case 2: NumPy dense + torch.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(a_np, b_np, M_torch_sparse)
+
+        # Case 3: scipy.sparse + torch.sparse (incompatible)
+        with pytest.raises(ValueError):
+            get_backend(M_scipy, M_torch_sparse)
+
+
 def test_convert_between_backends(nx):
     A = np.zeros((3, 2))
     B = np.zeros((3, 1))
diff --git a/test/test_ot.py b/test/test_ot.py
index e8217d54d..e0a438a80 100644
--- a/test/test_ot.py
+++ b/test/test_ot.py
@@ -12,6 +12,7 @@
 import ot
 from ot.datasets import make_1D_gauss as gauss
 from ot.backend import torch, tf, get_backend
+from scipy.sparse 
import coo_matrix def test_emd_dimension_and_mass_mismatch(): @@ -914,6 +915,474 @@ def test_dual_variables(): assert constraint_violation.max() < 1e-8 +def test_emd_sparse_vs_dense(): + """Test that sparse and dense EMD solvers produce identical results. + + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. + """ + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + G_dense, log_dense = ot.emd(a, b, C_augmented_dense, log=True) + G_sparse, log_sparse = ot.emd(a, b, C_augmented, log=True) + + cost_dense = log_dense["cost"] + cost_sparse = log_sparse["cost"] + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + # For dense, G_dense is returned; for sparse, reconstruct from flow edges + np.testing.assert_allclose(a, G_dense.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_dense.sum(0), rtol=1e-5, atol=1e-7) + + # G_sparse is now returned as a sparse matrix + from scipy.sparse import issparse + + assert issparse(G_sparse), "Sparse solver should return a sparse matrix" + + # Convert to dense for marginal checks + G_sparse_dense = G_sparse.toarray() + np.testing.assert_allclose(a, G_sparse_dense.sum(1), rtol=1e-5, atol=1e-7) + np.testing.assert_allclose(b, G_sparse_dense.sum(0), rtol=1e-5, atol=1e-7) + + +def test_emd2_sparse_vs_dense(): + """Test that sparse and dense emd2 solvers produce identical results. + + Uses augmented k-NN graph approach: first solves with dense solver to + identify needed edges, then compares both solvers on the same graph. 
+ """ + n_source = 100 + n_target = 100 + k = 10 + + rng = np.random.RandomState(42) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + C_augmented_dense = np.full((n_source, n_target), large_cost) + C_augmented_array = C_augmented.toarray() + C_augmented_dense[C_augmented_array > 0] = C_augmented_array[C_augmented_array > 0] + + cost_dense = ot.emd2(a, b, C_augmented_dense) + cost_sparse = ot.emd2(a, b, C_augmented) + + np.testing.assert_allclose(cost_dense, cost_sparse, rtol=1e-5, atol=1e-7) + + +def test_emd2_sparse_gradients(): + """Test that PyTorch sparse tensors support gradient computation.""" + if not torch: + pytest.skip("PyTorch not available") + + n = 10 + a = torch.tensor(ot.utils.unif(n), requires_grad=True, dtype=torch.float64) + b = torch.tensor(ot.utils.unif(n), requires_grad=True, dtype=torch.float64) + + rows, cols, costs = [], [], [] + for i in range(n): + rows.append(i) + cols.append(i) + costs.append(0.1) + for offset in [1, 2]: + j = (i + offset) % n + rows.append(i) + cols.append(j) + costs.append(float(offset)) + + indices = torch.tensor( + np.vstack([np.array(rows), np.array(cols)]), dtype=torch.int64 + ) + values = torch.tensor(costs, dtype=torch.float64) + M_sparse = torch.sparse_coo_tensor(indices, values, (n, n), dtype=torch.float64) + + cost = ot.emd2(a, b, M_sparse) + cost.backward() + + assert a.grad is not None + assert b.grad is not None + np.testing.assert_allclose( + a.grad.sum().item(), -b.grad.sum().item(), rtol=1e-5, atol=1e-7 + ) + + +def test_emd2_sparse_vs_dense_gradients(): + """Verify gradient w.r.t. 
cost matrix M equals transport plan G.""" + if not torch: + pytest.skip("PyTorch not available") + + n = 4 + a = torch.tensor([0.25, 0.25, 0.25, 0.25], requires_grad=True, dtype=torch.float64) + b = torch.tensor([0.25, 0.25, 0.25, 0.25], requires_grad=True, dtype=torch.float64) + + M_full = torch.tensor( + [ + [0.1, 1.0, 2.0, 3.0], + [1.0, 0.1, 1.0, 2.0], + [2.0, 1.0, 0.1, 1.0], + [3.0, 2.0, 1.0, 0.1], + ], + dtype=torch.float64, + requires_grad=True, + ) + + cost_dense = ot.emd2(a, b, M_full) + cost_dense.backward() + G_dense = ot.emd(a.detach(), b.detach(), M_full.detach()) + + np.testing.assert_allclose( + M_full.grad.numpy(), G_dense.numpy(), rtol=1e-7, atol=1e-10 + ) + + a.grad = None + b.grad = None + + rows, cols, costs = [], [], [] + for i in range(n): + for j in range(max(0, i - 1), min(n, i + 2)): + rows.append(i) + cols.append(j) + costs.append(M_full[i, j].item()) + + rows_t = torch.tensor(rows, dtype=torch.int64) + cols_t = torch.tensor(cols, dtype=torch.int64) + M_sparse = torch.sparse_coo_tensor( + torch.stack([rows_t, cols_t]), + torch.tensor(costs, dtype=torch.float64), + (n, n), + dtype=torch.float64, + requires_grad=True, + ) + + cost_sparse = ot.emd2(a, b, M_sparse) + cost_sparse.backward() + G_sparse = ot.emd(a.detach(), b.detach(), M_sparse.detach()).to_dense() + + grad_values = M_sparse.grad.coalesce().values().numpy() + G_values = G_sparse[rows_t, cols_t].numpy() + + np.testing.assert_allclose(grad_values, G_values, rtol=1e-7, atol=1e-10) + assert grad_values.sum() > 0 + assert np.abs(grad_values.sum() - 1.0) < 1e-7 + + +def test_emd_sparse_backends(nx): + """Test that sparse EMD works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. + """ + # Skip backends that don't support sparse matrices + # JAX: no sparse support + # TensorFlow: coo_matrix() returns dense tensors + backend_name = nx.__class__.__name__.lower() + if "jax" in backend_name or "tensorflow" in backend_name: + pytest.skip("Backend does not support sparse matrices") + + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + G_np, log_np = ot.emd(a, b, C_augmented, log=True) + + ab, bb = 
nx.from_numpy(a, b) + # Convert sparse matrix to backend format using backend's coo_matrix method + rows_array = np.array(rows_aug, dtype=np.int64) + cols_array = np.array(cols_aug, dtype=np.int64) + data_array = np.array(data_aug) + # For backends that need specific dtypes for indices (e.g., TensorFlow), cast them + rows_backend = nx.from_numpy(rows_array) + cols_backend = nx.from_numpy(cols_array) + data_backend = nx.from_numpy(data_array, type_as=ab) + C_augmented_backend = nx.coo_matrix( + data_backend, rows_backend, cols_backend, shape=(n_source, n_target) + ) + G_backend, log_backend = ot.emd(ab, bb, C_augmented_backend, log=True) + + cost_np = log_np["cost"] + cost_backend = nx.to_numpy(log_backend["cost"]) + + np.testing.assert_allclose(cost_np, cost_backend, rtol=1e-5, atol=1e-7) + + from scipy.sparse import issparse + + assert issparse(G_np), "NumPy backend should return scipy.sparse matrix" + + # Convert both to dense numpy arrays for comparison + if issparse(G_np): + G_np_dense = G_np.toarray() + else: + G_np_dense = np.asarray(G_np) + + # Convert backend result to dense first, then to numpy + if nx.issparse(G_backend): + G_backend_dense = nx.to_numpy(nx.todense(G_backend)) + else: + G_backend_dense = nx.to_numpy(G_backend) + + if issparse(G_backend_dense): + G_backend_dense = G_backend_dense.toarray() + + np.testing.assert_allclose(G_np_dense, G_backend_dense, rtol=1e-5, atol=1e-7) + + +def test_emd2_sparse_backends(nx): + """Test that sparse emd2 works with different backends for weights a and b. + + Uses augmented k-NN graph approach to ensure feasibility. + """ + # Skip backends that don't support sparse matrices + backend_name = nx.__class__.__name__.lower() + if "jax" in backend_name or "tensorflow" in backend_name: + pytest.skip("Backend does not support sparse matrices") + + n_source = 50 + n_target = 50 + k = 10 + + rng = np.random.RandomState(42) + + a = ot.utils.unif(n_source) + b = ot.utils.unif(n_target) + + x_source = rng.randn(n_source, 2) + x_target = rng.randn(n_target, 2) + 0.5 + C = ot.dist(x_source, x_target) + + rows = [] + cols = [] + data = [] + + for i in range(n_source): + distances = C[i, :] + nearest_k = np.argpartition(distances, k)[:k] + for j in nearest_k: + rows.append(i) + cols.append(j) + data.append(C[i, j]) + + C_knn = coo_matrix((data, (rows, cols)), shape=(n_source, n_target)) + + # Augment with necessary edges (same approach as test_emd2_sparse_vs_dense) + large_cost = 1e8 + C_dense_infty = np.full((n_source, n_target), large_cost) + C_knn_array = C_knn.toarray() + C_dense_infty[C_knn_array > 0] = C_knn_array[C_knn_array > 0] + + G_dense_initial = ot.emd(a, b, C_dense_infty) + eps = 1e-9 + active_mask = G_dense_initial > eps + knn_mask = C_knn_array > 0 + extra_edges_mask = active_mask & ~knn_mask + + rows_aug = [] + cols_aug = [] + data_aug = [] + + knn_rows, knn_cols = np.where(knn_mask) + for i, j in zip(knn_rows, knn_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + extra_rows, extra_cols = np.where(extra_edges_mask) + for i, j in zip(extra_rows, extra_cols): + rows_aug.append(i) + cols_aug.append(j) + data_aug.append(C[i, j]) + + C_augmented = coo_matrix( + (data_aug, (rows_aug, cols_aug)), shape=(n_source, n_target) + ) + + cost_np = ot.emd2(a, b, C_augmented) + + ab, bb = nx.from_numpy(a, b) + # Convert sparse matrix to backend format + rows_array = np.array(rows_aug, dtype=np.int64) + cols_array = np.array(cols_aug, dtype=np.int64) + data_array = np.array(data_aug) + rows_backend = 
nx.from_numpy(rows_array) + cols_backend = nx.from_numpy(cols_array) + data_backend = nx.from_numpy(data_array, type_as=ab) + C_augmented_backend = nx.coo_matrix( + data_backend, rows_backend, cols_backend, shape=(n_source, n_target) + ) + + cost_backend = ot.emd2(ab, bb, C_augmented_backend) + + cost_backend_np = nx.to_numpy(cost_backend) + + np.testing.assert_allclose(cost_np, cost_backend_np, rtol=1e-5, atol=1e-7) + + def check_duality_gap(a, b, M, G, u, v, cost): cost_dual = np.vdot(a, u) + np.vdot(b, v) # Check that dual and primal cost are equal
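
Taken together, the sparse gradient tests above reduce to this core PyTorch usage pattern (a condensed sketch of `test_emd2_sparse_gradients`; values are illustrative):

    import torch
    import ot

    n = 4
    a = torch.full((n,), 1.0 / n, dtype=torch.float64, requires_grad=True)
    b = torch.full((n,), 1.0 / n, dtype=torch.float64, requires_grad=True)
    indices = torch.tensor([[0, 1, 2, 3], [0, 1, 2, 3]])  # allowed (i, j) pairs
    values = torch.tensor([0.1, 0.2, 0.3, 0.4], dtype=torch.float64)
    M = torch.sparse_coo_tensor(indices, values, (n, n), requires_grad=True)

    cost = ot.emd2(a, b, M)  # dispatches to the sparse solver
    cost.backward()
    # d(cost)/dM shares M's sparsity pattern and equals the optimal plan, so
    # here its values are [0.25, 0.25, 0.25, 0.25] (total transported mass 1.0)
    print(M.grad.coalesce().values())
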