Skip to content
31 changes: 27 additions & 4 deletions ot/lp/_network_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import warnings
from scipy.sparse import issparse as scipy_issparse

from ..utils import list_to_array, check_number_threads
from ..backend import get_backend
Expand Down Expand Up @@ -298,10 +299,20 @@ def emd(
a, b = list_to_array(a, b)
nx = get_backend(a, b)

# Check if M is sparse using backend's issparse method
is_sparse = nx.issparse(M)
# Check if M is sparse (either backend sparse or scipy.sparse)
is_sparse = nx.issparse(M) or scipy_issparse(M)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better : implement is_sparse o backends and raise error saying it is not implemented, no need to test after


if is_sparse:
# Check if backend supports sparse matrices
backend_name = nx.__class__.__name__
if backend_name in ["JaxBackend", "TensorflowBackend"]:
raise NotImplementedError(
f"Sparse optimal transport is not supported for {backend_name}. "
"JAX does not have native sparse matrix support, and TensorFlow's "
"sparse implementation is incomplete. Please convert your sparse "
"matrix to dense format using M.toarray() or equivalent before calling emd()."
)

# Extract COO data using backend method - returns numpy arrays
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)

Expand Down Expand Up @@ -572,10 +583,22 @@ def emd2(
a, b = list_to_array(a, b)
nx = get_backend(a, b)

# Check if M is sparse using backend's issparse method
is_sparse = nx.issparse(M)
# Check if M is sparse (either backend sparse or scipy.sparse)
from scipy.sparse import issparse as scipy_issparse

is_sparse = nx.issparse(M) or scipy_issparse(M)

if is_sparse:
# Check if backend supports sparse matrices
backend_name = nx.__class__.__name__
if backend_name in ["JaxBackend", "TensorflowBackend"]:
raise NotImplementedError(
f"Sparse optimal transport is not supported for {backend_name}. "
"JAX does not have native sparse matrix support, and TensorFlow's "
"sparse implementation is incomplete. Please convert your sparse "
"matrix to dense format using M.toarray() or equivalent before calling emd2()."
)

# Extract COO data using backend method - returns numpy arrays
edge_sources, edge_targets, edge_costs, (n1, n2) = nx.sparse_coo_data(M)

Expand Down