Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ include ot/lp/full_bipartitegraph.h
include ot/lp/full_bipartitegraph_omp.h
include ot/lp/network_simplex_simple.h
include ot/lp/network_simplex_simple_omp.h
include ot/lp/sparse_bipartitegraph.h
include ot/partial/partial_cython.pyx
Binary file added docs/source/_static/images/comparison.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
225 changes: 225 additions & 0 deletions examples/plot_sparse_emd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
"""
============================================
Sparse Optimal Transport
============================================

In many real-world optimal transport (OT) problems, the transport plan is
naturally sparse: only a small fraction of all possible source-target pairs
actually exchange mass. Using sparse OT solvers can provide significant
computational speedups and memory savings compared to dense solvers.

This example demonstrates how to use sparse cost matrices with POT's EMD solver,
comparing sparse and dense formulations on both a minimal example and a larger
concentric circles dataset.
"""

# Author: Nathan Neike <nathan.neike@example.com>
# License: MIT License
# sphinx_gallery_thumbnail_number = 2

import numpy as np
import matplotlib.pyplot as plt
from scipy.sparse import coo_matrix
import ot
import ot.plot


##############################################################################
# Minimal example with 4 points
# ------------------------------

# %%

# Sources sit on the line y = 0; each target lies directly above a source on
# the line y = 1.
X = np.array([[0, 0], [1, 0], [0.5, 0], [1.5, 0]])
Y = X + np.array([0, 1])

# Uniform mass: every point carries a quarter of the total.
a = np.full(4, 0.25)
b = np.full(4, 0.25)

# Only the diagonal pairs (source i -> target i) are allowed; the sparse COO
# cost matrix stores the Euclidean length of each allowed edge.
rows = list(range(4))
cols = list(range(4))
vals = [np.linalg.norm(X[src] - Y[tgt]) for src, tgt in zip(rows, cols)]
M_sparse = coo_matrix((vals, (rows, cols)), shape=(4, 4))


##############################################################################
# Solve and display sparse OT solution
# -------------------------------------

# %%

# Solve the problem using the sparse cost matrix; log=True also returns a
# dictionary from which the cost is read.
G, log = ot.emd(a, b, M_sparse, log=True)

print("Sparse OT cost:", log["cost"])
print("Solution format:", type(G))
print("Non-zero edges:", G.nnz)
print("\nEdges:")

# Walk the plan in COO form so every stored entry is one (source, target, flow).
if isinstance(G, coo_matrix):
    G_coo = G
else:
    G_coo = G.tocoo()
for i, j, v in zip(G_coo.row, G_coo.col, G_coo.data):
    # Only report flows above the numerical threshold.
    if v > 1e-10:
        print(f" source {i} -> target {j}, flow={v:.3f}")


##############################################################################
# Visualize sparse vs dense edge structure
# -----------------------------------------

# %%

# Describe both panels up front: (title, edges to draw, edge transparency).
sparse_edges = list(zip(rows, cols))
dense_edges = [(i, j) for i in range(len(X)) for j in range(len(Y))]
panels = [
    ("Sparse OT: Allowed Edges Only", sparse_edges, 0.6),
    ("Dense OT: All Possible Edges", dense_edges, 0.3),
]

plt.figure(figsize=(8, 4))

for position, (title, edges, edge_alpha) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    # Sources are red circles and targets blue crosses; zorder keeps the
    # markers above the edge lines.
    plt.scatter(X[:, 0], X[:, 1], c="r", marker="o", s=100, zorder=3)
    plt.scatter(Y[:, 0], Y[:, 1], c="b", marker="x", s=100, zorder=3)
    for i, j in edges:
        plt.plot(
            [X[i, 0], Y[j, 0]],
            [X[i, 1], Y[j, 1]],
            "b-",
            linewidth=1,
            alpha=edge_alpha,
        )
    plt.title(title)
    plt.xlim(-0.5, 2.0)
    plt.ylim(-0.5, 1.5)

plt.tight_layout()


##############################################################################
# Larger example: concentric circles
# -----------------------------------

# %%

# Problem size: n_clusters groups of points_per_cluster points each.
n_clusters = 8
points_per_cluster = 25
n = n_clusters * points_per_cluster
k_neighbors = 8
rng = np.random.default_rng(0)

# Sources lie near the inner circle, targets near the outer one.
r_source = 1.0
r_target = 2.0
noise_scale = 0.06

# Evenly spaced angles; consecutive runs of points_per_cluster angles share a
# cluster label.
theta = np.linspace(0.0, 2.0 * np.pi, n, endpoint=False)
cluster_labels = np.repeat(np.arange(n_clusters), points_per_cluster)

# Scale the unit circle to each radius, then jitter with Gaussian noise.
# The source noise is drawn before the target noise.
unit_circle = np.column_stack([np.cos(theta), np.sin(theta)])
X_large = r_source * unit_circle + rng.normal(scale=noise_scale, size=(n, 2))
Y_large = r_target * unit_circle + rng.normal(scale=noise_scale, size=(n, 2))

# Every point of every cluster receives the same mass.
point_mass = 1.0 / n_clusters / points_per_cluster
a_large = np.full(n, point_mass)
b_large = np.full(n, point_mass)

# Dense Euclidean cost between every source/target pair.
M_full = ot.dist(X_large, Y_large, metric="euclidean")

# Sparse cost matrix: each source keeps only the k_neighbors targets of its
# own cluster that are closest in polar angle.
angles_X = np.arctan2(X_large[:, 1], X_large[:, 0])
angles_Y = np.arctan2(Y_large[:, 1], Y_large[:, 0])

rows, cols, vals = [], [], []
for k in range(n_clusters):
    # Sources and targets share the same cluster labelling, so one index set
    # serves both sides.
    src_idx = np.flatnonzero(cluster_labels == k)
    tgt_idx = src_idx
    for i in src_idx:
        # Wrap the angular offset through the complex exponential so that
        # angles on either side of the +/-pi cut compare correctly.
        diff = np.angle(np.exp(1j * (angles_Y[tgt_idx] - angles_X[i])))
        nearest = tgt_idx[np.argsort(np.abs(diff))[:k_neighbors]]
        rows.extend([i] * len(nearest))
        cols.extend(nearest)
        vals.extend(M_full[i, nearest])

M_sparse_large = coo_matrix((vals, (rows, cols)), shape=(n, n))
# Set of allowed (source, target) pairs, used when drawing the edge structure.
allowed_sparse = set(zip(rows, cols))

##############################################################################
# Visualize edge structures
# --------------------------

# %%

# Describe both panels up front: (title, edges to draw, edge opacity).
all_pairs = [(i, j) for i in range(n) for j in range(n)]
edge_panels = [
    ("Dense OT: All Possible Edges", all_pairs, 0.2),
    ("Sparse OT: Intra-Cluster k-NN Edges", allowed_sparse, 1),
]

plt.figure(figsize=(16, 6))

for position, (title, edges, edge_alpha) in enumerate(edge_panels, start=1):
    plt.subplot(1, 2, position)
    # Edges are drawn first, then the source (red circles) and target
    # (blue crosses) markers.
    for i, j in edges:
        plt.plot(
            [X_large[i, 0], Y_large[j, 0]],
            [X_large[i, 1], Y_large[j, 1]],
            color="blue",
            alpha=edge_alpha,
            linewidth=0.05,
        )
    plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20)
    plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20)
    plt.axis("equal")
    plt.title(title)

plt.tight_layout()
plt.show()

##############################################################################
# Solve and visualize transport plans
# ------------------------------------

# %%

# Dense solve: the cost is the plan weighted elementwise by the full cost matrix.
G_dense = ot.emd(a_large, b_large, M_full)
cost_dense = (G_dense * M_full).sum()
print(f"Dense OT cost: {cost_dense:.6f}")

# Sparse solve: request the log so the cost can be read from it directly.
sparse_solution = ot.emd(a_large, b_large, M_sparse_large, log=True)
G_sparse, log_sparse = sparse_solution
cost_sparse = log_sparse["cost"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for the log here: np.sum(G_sparse * M_sparse_large) should work.

print(f"Sparse OT cost: {cost_sparse:.6f}")

# NOTE: plot2D_samples_mat lives in the ot.plot submodule, which POT does not
# load with a bare ``import ot``; it has to be imported explicitly
# (``import ot.plot``) alongside the other imports of this example.
plan_panels = [
    ("Dense OT: Optimal Transport Plan", G_dense),
    ("Sparse OT: Optimal Transport Plan", G_sparse),
]

plt.figure(figsize=(16, 6))

for position, (title, plan) in enumerate(plan_panels, start=1):
    plt.subplot(1, 2, position)
    # Draw the plan's edges (entries above thr), then overlay the samples so
    # the markers stay on top (zorder=3).
    ot.plot.plot2D_samples_mat(
        X_large, Y_large, plan, thr=1e-10, c=[0.5, 0.5, 1], alpha=0.5
    )
    plt.scatter(X_large[:, 0], X_large[:, 1], c="r", marker="o", s=20, zorder=3)
    plt.scatter(Y_large[:, 0], Y_large[:, 1], c="b", marker="x", s=20, zorder=3)
    plt.axis("equal")
    plt.title(title)

plt.tight_layout()
plt.show()
68 changes: 67 additions & 1 deletion ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,16 @@ def _get_backend_instance(backend_impl):


def _check_args_backend(backend_impl, args):
is_instance = set(isinstance(arg, backend_impl.__type__) for arg in args)
# Get backend instance to use issparse method
backend = _get_backend_instance(backend_impl)

# Check if each arg is either:
# 1. An instance of backend.__type__ (e.g., np.ndarray for NumPy)
# 2. A sparse matrix recognized by backend.issparse() (e.g., scipy.sparse for NumPy)
is_instance = set(
isinstance(arg, backend_impl.__type__) or backend.issparse(arg) for arg in args
)

# check that all arguments matched or not the type
if len(is_instance) == 1:
return is_instance.pop()
Expand Down Expand Up @@ -839,6 +848,31 @@ def todense(self, a):
"""
raise NotImplementedError()

def sparse_coo_data(self, a):
    r"""
    Returns the COO components (row, col, data, shape) of a sparse matrix.

    The matrix is exposed as explicit edge lists so that it can be handed to
    the C++ solvers. Concrete backends must override this method; the base
    implementation always raises.

    Parameters
    ----------
    a : sparse matrix
        Sparse matrix in the backend's COO format

    Returns
    -------
    row : numpy.ndarray
        Row index of each stored entry (1D array)
    col : numpy.ndarray
        Column index of each stored entry (1D array)
    data : numpy.ndarray
        Value of each stored entry (1D array)
    shape : tuple
        Matrix shape as (n_rows, n_cols)

    Raises
    ------
    NotImplementedError
        Always, in this base implementation.
    """
    raise NotImplementedError()

def where(self, condition, x, y):
r"""
Returns elements chosen from x or y depending on condition.
Expand Down Expand Up @@ -1349,6 +1383,15 @@ def todense(self, a):
else:
return a

def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` of ``a`` in COO form.

    ``a`` is used as-is when it already is a ``coo_matrix``; any other input
    is first converted with ``coo_matrix(a)``.
    """
    a_coo = a if isinstance(a, coo_matrix) else coo_matrix(a)
    return a_coo.row, a_coo.col, a_coo.data, a_coo.shape

def where(self, condition, x=None, y=None):
if x is None and y is None:
return np.where(condition)
Expand Down Expand Up @@ -1768,6 +1811,15 @@ def todense(self, a):
# Currently, JAX does not support sparse matrices
return a

def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` for ``a`` via a scipy COO matrix.

    JAX doesn't support sparse matrices, so this is not expected to be
    called. As a fallback, the array is converted to NumPy and its non-zero
    entries are extracted through ``scipy.sparse.coo_matrix``.
    """
    from scipy.sparse import coo_matrix

    as_coo = coo_matrix(self.to_numpy(a))
    return as_coo.row, as_coo.col, as_coo.data, as_coo.shape

def where(self, condition, x=None, y=None):
if x is None and y is None:
return jnp.where(condition)
Expand Down Expand Up @@ -2340,6 +2392,20 @@ def todense(self, a):
else:
return a

def sparse_coo_data(self, a):
    """Return ``(row, col, data, shape)`` of a torch sparse COO tensor.

    Parameters
    ----------
    a : torch sparse COO tensor
        Tensor to decompose; it is coalesced before its entries are read.

    Returns
    -------
    row, col, data : numpy.ndarray
        Row indices, column indices and values of the coalesced entries,
        converted with ``self.to_numpy``.
    shape : tuple
        Shape of the tensor.
    """
    # Coalesce first so that every (row, col) index appears only once.
    coalesced = a.coalesce()
    # Use the public accessors (valid on a coalesced tensor) rather than the
    # private ``_indices()`` / ``_values()``.
    indices = coalesced.indices()
    values = coalesced.values()

    row = self.to_numpy(indices[0])
    col = self.to_numpy(indices[1])
    data = self.to_numpy(values)

    return row, col, data, tuple(coalesced.shape)

def where(self, condition, x=None, y=None):
if x is None and y is None:
return torch.where(condition)
Expand Down
18 changes: 18 additions & 0 deletions ot/lp/EMD.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,24 @@ enum ProblemType {
int EMD_wrap(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter);
int EMD_wrap_omp(int n1,int n2, double *X, double *Y,double *D, double *G, double* alpha, double* beta, double *cost, uint64_t maxIter, int numThreads);

// Sparse counterpart of EMD_wrap: the cost is passed as an explicit edge list
// (edge_sources / edge_targets / edge_costs, each of length n_edges) and the
// resulting flows are returned through the flow_*_out arrays.
// NOTE(review): the int return value presumably follows the same status
// convention as EMD_wrap -- confirm against the implementation.
int EMD_wrap_sparse(
int n1, // presumably the number of sources (alpha is documented with size n1) -- confirm
int n2, // presumably the number of targets (beta is documented with size n2) -- confirm
double *X, // NOTE(review): presumably the source weights, as in EMD_wrap -- confirm
double *Y, // NOTE(review): presumably the target weights, as in EMD_wrap -- confirm
uint64_t n_edges, // Number of edges in sparse graph
uint64_t *edge_sources, // Source indices for each edge (n_edges)
uint64_t *edge_targets, // Target indices for each edge (n_edges)
double *edge_costs, // Cost for each edge (n_edges)
uint64_t *flow_sources_out, // Output: source indices of non-zero flows
uint64_t *flow_targets_out, // Output: target indices of non-zero flows
double *flow_values_out, // Output: flow values
uint64_t *n_flows_out, // Output: presumably the number of flows written to the flow_*_out arrays -- TODO confirm
double *alpha, // Output: dual variables for sources (n1)
double *beta, // Output: dual variables for targets (n2)
double *cost, // Output: total transportation cost
uint64_t maxIter // Maximum iterations for solver
);


#endif
Loading
Loading