11"""Tests for moment algorithms."""
22
3- import numpy as np
3+ import jax . numpy as jnp
44import pytest
55
66from causalprog import algorithms
@@ -26,14 +26,14 @@ def test_expectation_stdev_single_normal_node(
2626 graph = normal_graph (mean , stdev )
2727
2828 # Check within hand-computation
29- assert np .isclose (
29+ assert jnp .isclose (
3030 algorithms .expectation (
3131 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
3232 ),
3333 mean ,
3434 rtol = rtol ,
3535 )
36- assert np .isclose (
36+ assert jnp .isclose (
3737 algorithms .standard_deviation (
3838 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
3939 ),
@@ -79,18 +79,18 @@ def test_mean_stdev_two_node_graph(
7979
8080 graph = two_normal_graph (mean = mean , cov = stdev , cov2 = stdev2 )
8181
82- assert np .isclose (
82+ assert jnp .isclose (
8383 algorithms .expectation (
8484 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
8585 ),
8686 mean ,
8787 rtol = rtol ,
8888 )
89- assert np .isclose (
89+ assert jnp .isclose (
9090 algorithms .standard_deviation (
9191 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
9292 ),
93- np .sqrt (stdev ** 2 + stdev2 ** 2 ),
93+ jnp .sqrt (stdev ** 2 + stdev2 ** 2 ),
9494 rtol = rtol ,
9595 )
9696
@@ -108,7 +108,7 @@ def test_expectation(two_normal_graph, rng_key, samples, rtol):
108108 pytest .xfail ("Test currently too slow" )
109109 graph = two_normal_graph (1.0 , 1.2 , 0.8 )
110110
111- assert np .isclose (
111+ assert jnp .isclose (
112112 algorithms .expectation (
113113 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
114114 ),
@@ -132,7 +132,7 @@ def test_stdev(two_normal_graph, rng_key, samples, rtol):
132132 pytest .xfail ("Test currently too slow" )
133133 graph = two_normal_graph (1.0 , 1.2 , 0.8 )
134134
135- assert np .isclose (
135+ assert jnp .isclose (
136136 algorithms .standard_deviation (
137137 graph , outcome_node_label = "X" , samples = samples , rng_key = rng_key
138138 ),
0 commit comments