Skip to content

Commit 16afa77

Browse files
authored
numpy -> jax.numpy everywhere (#102)
* Reaplce numpy imports in parameter node file * Replace numpy with jnp in MCMC helper functions * Replace np -> jnp in test_moments * np -> jnp in test_parameters * np -> jnp in test_norms * Fix test_parameter node since jnp.allclose can't handle lists
1 parent 9dd7a58 commit 16afa77

File tree

5 files changed

+25
-38
lines changed

5 files changed

+25
-38
lines changed

src/causalprog/graph/node/parameter.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
"""Graph nodes representing parameters."""
22

3-
from __future__ import annotations
4-
5-
import typing
6-
7-
import numpy as np
3+
import jax
4+
import jax.numpy as jnp
5+
import numpy.typing as npt
86
from typing_extensions import override
97

108
from .base import Node
119

12-
if typing.TYPE_CHECKING:
13-
import jax
14-
import numpy.typing as npt
15-
1610

1711
class ParameterNode(Node):
1812
"""
@@ -44,15 +38,15 @@ def __init__(self, *, label: str) -> None:
4438
def sample(
4539
self,
4640
parameter_values: dict[str, float],
47-
sampled_dependencies: dict[str, npt.NDArray[float]],
41+
sampled_dependencies: dict[str, npt.ArrayLike],
4842
samples: int,
4943
*,
5044
rng_key: jax.Array,
51-
) -> npt.NDArray[float]:
45+
) -> npt.ArrayLike:
5246
if self.label not in parameter_values:
5347
msg = f"Missing input for parameter node: {self.label}."
5448
raise ValueError(msg)
55-
return np.full(samples, parameter_values[self.label])
49+
return jnp.full(samples, parameter_values[self.label])
5650

5751
@override
5852
def copy(self) -> Node:

tests/fixtures/numpyro/mcmc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import Callable
44
from typing import Concatenate, TypeAlias
55

6-
import numpy as np
6+
import jax.numpy as jnp
77
import pytest
88
from jax import Array
99
from numpyro.infer import MCMC, NUTS
@@ -77,7 +77,7 @@ def _inner(left_mcmc: MCMC, right_mcmc: MCMC) -> None:
7777
f"Samples on left ({sample_name}) not present on right"
7878
)
7979
# Confirm samples match.
80-
assert np.allclose(sample_values, samples_r[sample_name]), (
80+
assert jnp.allclose(sample_values, samples_r[sample_name]), (
8181
f"Samples '{sample_name}' do not match"
8282
)
8383
for sample_name in samples_r:

tests/test_algorithms/test_moments.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for moment algorithms."""
22

3-
import numpy as np
3+
import jax.numpy as jnp
44
import pytest
55

66
from causalprog import algorithms
@@ -26,14 +26,14 @@ def test_expectation_stdev_single_normal_node(
2626
graph = normal_graph(mean, stdev)
2727

2828
# Check within hand-computation
29-
assert np.isclose(
29+
assert jnp.isclose(
3030
algorithms.expectation(
3131
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
3232
),
3333
mean,
3434
rtol=rtol,
3535
)
36-
assert np.isclose(
36+
assert jnp.isclose(
3737
algorithms.standard_deviation(
3838
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
3939
),
@@ -79,18 +79,18 @@ def test_mean_stdev_two_node_graph(
7979

8080
graph = two_normal_graph(mean=mean, cov=stdev, cov2=stdev2)
8181

82-
assert np.isclose(
82+
assert jnp.isclose(
8383
algorithms.expectation(
8484
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
8585
),
8686
mean,
8787
rtol=rtol,
8888
)
89-
assert np.isclose(
89+
assert jnp.isclose(
9090
algorithms.standard_deviation(
9191
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
9292
),
93-
np.sqrt(stdev**2 + stdev2**2),
93+
jnp.sqrt(stdev**2 + stdev2**2),
9494
rtol=rtol,
9595
)
9696

@@ -108,7 +108,7 @@ def test_expectation(two_normal_graph, rng_key, samples, rtol):
108108
pytest.xfail("Test currently too slow")
109109
graph = two_normal_graph(1.0, 1.2, 0.8)
110110

111-
assert np.isclose(
111+
assert jnp.isclose(
112112
algorithms.expectation(
113113
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
114114
),
@@ -132,7 +132,7 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol):
132132
pytest.xfail("Test currently too slow")
133133
graph = two_normal_graph(1.0, 1.2, 0.8)
134134

135-
assert np.isclose(
135+
assert jnp.isclose(
136136
algorithms.standard_deviation(
137137
graph, outcome_node_label="X", samples=samples, rng_key=rng_key
138138
),
Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
"""Tests for graph module."""
22

3-
from typing import Literal, TypeAlias
3+
import jax.numpy as jnp
44

5-
import numpy as np
6-
7-
from causalprog.graph import DistributionNode, ParameterNode
8-
9-
NormalGraphNodeNames: TypeAlias = Literal["mean", "cov", "X"]
10-
NormalGraphNodes: TypeAlias = dict[
11-
NormalGraphNodeNames, DistributionNode | ParameterNode
12-
]
5+
from causalprog.graph import ParameterNode
136

147

158
def test_parameter_node(rng_key, raises_context):
@@ -18,6 +11,6 @@ def test_parameter_node(rng_key, raises_context):
1811
with raises_context(ValueError("Missing input for parameter")):
1912
node.sample({}, {}, 1, rng_key=rng_key)
2013

21-
assert np.allclose(
22-
node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key)[0], [0.3] * 10
14+
assert jnp.allclose(
15+
node.sample({node.label: 0.3}, {}, 10, rng_key=rng_key), jnp.full((10,), 0.3)
2316
)

tests/test_utils/test_norms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Callable
22

3-
import numpy as np
3+
import jax.numpy as jnp
44
import pytest
55

66
from causalprog.utils.norms import PyTree, l2_normsq
@@ -11,15 +11,15 @@
1111
[
1212
pytest.param(1.0, l2_normsq, 1.0, id="l2^2, scalar"),
1313
pytest.param(
14-
np.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array"
14+
jnp.array([1.0, 2.0, 3.0]), l2_normsq, 14.0, id="l2^2, numpy array"
1515
),
1616
pytest.param(
17-
{"a": 1.0, "b": (np.arange(3), [2.0, (-1.0, 0.0)])},
17+
{"a": 1.0, "b": (jnp.arange(3), [2.0, (-1.0, 0.0)])},
1818
l2_normsq,
19-
1.0 + (np.arange(3) ** 2).sum() + 4.0 + 1.0,
19+
1.0 + (jnp.arange(3) ** 2).sum() + 4.0 + 1.0,
2020
id="l2^2, PyTree",
2121
),
2222
],
2323
)
2424
def test_norm_value(pt: PyTree, norm: Callable[[PyTree], float], expected_value: float):
25-
assert np.allclose(norm(pt), expected_value)
25+
assert jnp.allclose(norm(pt), expected_value)

0 commit comments

Comments
 (0)