Skip to content

Commit 4917c1f

Browse files
committed
Initial testing rewrite commit, let's just try it
1 parent d12644d commit 4917c1f

File tree

4 files changed

+307
-11
lines changed

4 files changed

+307
-11
lines changed

.github/workflows/analysis.yml

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,22 +158,25 @@ jobs:
158158
needs: merge
159159
steps:
160160
- uses: actions/checkout@v4
161-
- name: Set up R
162-
uses: r-lib/actions/setup-r@v2
163-
with:
164-
use-public-rspm: true
165-
- name: Install R dependencies
166-
uses: r-lib/actions/setup-r-dependencies@v2
161+
- name: Set up Python
162+
uses: actions/setup-python@v5
167163
with:
168-
packages: |
169-
any::tidyverse
170-
any::assertr
164+
python-version: "3.11"
165+
- name: Install dependencies
166+
run: |
167+
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
168+
pip install -r requirements.txt
171169
- name: Download artifacts
172170
uses: actions/download-artifact@v4
173171
with:
174172
name: Data
175173
- name: Test against previous results
176-
run: Rscript --vanilla tests/IVIMmodels/unit_tests/compare.r test_output.csv test_reference.csv tests/IVIMmodels/unit_tests/reference_output.csv test_results.csv
174+
env:
175+
TEST_OUTPUT_CSV: test_output.csv
176+
TEST_REFERENCE_CSV: test_reference.csv
177+
REFERENCE_OUTPUT_CSV: tests/IVIMmodels/unit_tests/reference_output.csv
178+
TEST_RESULTS_CSV: test_results.csv
179+
run: python -m pytest tests/IVIMmodels/unit_tests/test_statistical_comparison.py
177180
- name: Upload data
178181
uses: actions/upload-artifact@v4
179182
if: always()

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ sphinx_rtd_theme
1818
pytest-json-report
1919
statsmodels
2020
ivimnet
21-
nlopt
21+
nlopt
22+
pyyaml
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Configuration for statistical tests of IVIM models.
2+
# This file defines the acceptance criteria for each test case, including
3+
# ground truth values, reference means, standard deviations, and tolerances.
4+
5+
ETP_SRI_LinearFitting:
6+
"Blood RV":
7+
100:
8+
ground_truth:
9+
f: 1.0
10+
Dp: 0.1
11+
D: 0.003
12+
acceptance_criteria:
13+
f:
14+
mean: 0.9038
15+
std_dev: 0.3742
16+
mean_tolerance: 0.2
17+
std_dev_tolerance_percent: 100.0
18+
Dp:
19+
mean: 0.0326
20+
std_dev: 0.0907
21+
mean_tolerance: 0.2
22+
std_dev_tolerance_percent: 100.0
23+
D:
24+
mean: 0.0012
25+
std_dev: 0.0020
26+
mean_tolerance: 0.002
27+
std_dev_tolerance_percent: 100.0
28+
"Myocardium LV":
29+
100:
30+
ground_truth:
31+
f: 0.15
32+
Dp: 0.08
33+
D: 0.0024
34+
acceptance_criteria:
35+
f:
36+
mean: 0.1504
37+
std_dev: 0.1653
38+
mean_tolerance: 0.1
39+
std_dev_tolerance_percent: 50.0
40+
Dp:
41+
mean: 0.0515
42+
std_dev: 0.1271
43+
mean_tolerance: 0.1
44+
std_dev_tolerance_percent: 50.0
45+
D:
46+
mean: 0.0024
47+
std_dev: 0.0003
48+
mean_tolerance: 0.001
49+
std_dev_tolerance_percent: 50.0
50+
"myocardium RV":
51+
100:
52+
ground_truth:
53+
f: 0.15
54+
Dp: 0.08
55+
D: 0.0024
56+
acceptance_criteria:
57+
f:
58+
mean: 0.1504
59+
std_dev: 0.1653
60+
mean_tolerance: 0.1
61+
std_dev_tolerance_percent: 50.0
62+
Dp:
63+
mean: 0.0515
64+
std_dev: 0.1271
65+
mean_tolerance: 0.1
66+
std_dev_tolerance_percent: 50.0
67+
D:
68+
mean: 0.0024
69+
std_dev: 0.0003
70+
mean_tolerance: 0.001
71+
std_dev_tolerance_percent: 50.0
72+
"myocardium ra":
73+
100:
74+
ground_truth:
75+
f: 0.07
76+
Dp: 0.07
77+
D: 0.0015
78+
acceptance_criteria:
79+
f:
80+
mean: 0.0749
81+
std_dev: 0.0736
82+
mean_tolerance: 0.2
83+
std_dev_tolerance_percent: 100.0
84+
Dp:
85+
mean: 0.0678
86+
std_dev: 0.2695
87+
mean_tolerance: 0.2
88+
std_dev_tolerance_percent: 100.0
89+
D:
90+
mean: 0.0015
91+
std_dev: 0.0001
92+
mean_tolerance: 0.002
93+
std_dev_tolerance_percent: 100.0
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
import yaml
2+
import pytest
3+
import numpy as np
4+
from pathlib import Path
5+
import json
6+
from scipy import stats
7+
8+
# TODO: These should probably be fixtures
9+
from src.wrappers.OsipiBase import OsipiBase
10+
from utilities.data_simulation.GenerateData import GenerateData
11+
12+
13+
def load_config(path):
14+
"""Loads a YAML configuration file."""
15+
if not path.exists():
16+
return {}
17+
with open(path, "r") as f:
18+
return yaml.safe_load(f)
19+
20+
def save_config(config, path):
21+
"""Saves a YAML configuration file."""
22+
with open(path, "w") as f:
23+
yaml.dump(config, f, default_flow_style=False)
24+
25+
def get_algorithms():
26+
"""Loads the list of algorithms from the JSON file."""
27+
algorithms_path = Path(__file__).parent / "algorithms.json"
28+
with open(algorithms_path, "r") as f:
29+
return json.load(f)["algorithms"]
30+
31+
def generate_config_for_algorithm(algorithm):
32+
"""Generates reference data for a given algorithm."""
33+
fit_count = 300
34+
snr = 100
35+
rician_noise = True
36+
37+
regions = {
38+
"Blood RV": {"f": 1.0, "Dp": 0.1, "D": 0.003},
39+
"Myocardium LV": {"f": 0.15, "Dp": 0.08, "D": 0.0024},
40+
"myocardium RV": {"f": 0.15, "Dp": 0.08, "D": 0.0024},
41+
"myocardium ra": {"f": 0.07, "Dp": 0.07, "D": 0.0015},
42+
}
43+
bvals = np.array([0, 5, 10, 20, 30, 50, 75, 100, 150, 200, 300, 400, 500, 600, 700, 800])
44+
45+
new_config_entry = {}
46+
47+
print(f"Generating reference data for {algorithm}")
48+
49+
if "NET" in algorithm or "DC" in algorithm:
50+
print(" Skipping deep learning algorithm")
51+
return None
52+
53+
new_config_entry[algorithm] = {}
54+
for region_name, data in regions.items():
55+
print(f" Running {algorithm} for {region_name}")
56+
57+
rng = np.random.RandomState(42)
58+
S0 = 1
59+
gd = GenerateData(rng=rng)
60+
D = data["D"]
61+
f = data["f"]
62+
Dp = data["Dp"]
63+
64+
try:
65+
fit = OsipiBase(algorithm=algorithm)
66+
except Exception as e:
67+
print(f" Could not instantiate {algorithm}: {e}")
68+
continue
69+
70+
results = {"f": [], "Dp": [], "D": []}
71+
for idx in range(fit_count):
72+
signal = gd.ivim_signal(D, Dp, f, S0, bvals, snr, rician_noise)
73+
try:
74+
fit_result = fit.osipi_fit(signal, bvals)
75+
results["f"].append(fit_result["f"])
76+
results["Dp"].append(fit_result["Dp"])
77+
results["D"].append(fit_result["D"])
78+
except Exception as e:
79+
print(f" Fit failed for {algorithm} at index {idx}: {e}")
80+
81+
if results["f"]:
82+
f_mu = float(np.mean(results["f"]))
83+
Dp_mu = float(np.mean(results["Dp"]))
84+
D_mu = float(np.mean(results["D"]))
85+
f_std = float(np.std(results["f"]))
86+
Dp_std = float(np.std(results["Dp"]))
87+
D_std = float(np.std(results["D"]))
88+
89+
new_config_entry[algorithm][region_name] = {
90+
100: {
91+
"ground_truth": {"f": float(f), "Dp": float(Dp), "D": float(D)},
92+
"acceptance_criteria": {
93+
"f": {"mean": f_mu, "std_dev": f_std, "mean_tolerance": 0.2, "std_dev_tolerance_percent": 100.0},
94+
"Dp": {"mean": Dp_mu, "std_dev": Dp_std, "mean_tolerance": 0.2, "std_dev_tolerance_percent": 100.0},
95+
"D": {"mean": D_mu, "std_dev": D_std, "mean_tolerance": 0.002, "std_dev_tolerance_percent": 100.0},
96+
}
97+
}
98+
}
99+
return new_config_entry
100+
101+
def run_simulation_batch(algorithm, bvals, ground_truth, snr, batch_size, rician_noise):
102+
"""Runs a batch of simulations and returns the fitted parameters."""
103+
fit = OsipiBase(algorithm=algorithm)
104+
rng = np.random.RandomState()
105+
gd = GenerateData(rng=rng)
106+
S0 = 1
107+
108+
results = {"f": [], "Dp": [], "D": []}
109+
110+
for _ in range(batch_size):
111+
signal = gd.ivim_signal(ground_truth["D"], ground_truth["Dp"], ground_truth["f"], S0, bvals, snr, rician_noise)
112+
fit_result = fit.osipi_fit(signal, bvals)
113+
results["f"].append(fit_result["f"])
114+
results["Dp"].append(fit_result["Dp"])
115+
results["D"].append(fit_result["D"])
116+
return results
117+
118+
119+
@pytest.mark.parametrize("algorithm", get_algorithms())
120+
def test_statistical_equivalence(algorithm):
121+
"""
122+
Main test function to check statistical equivalence.
123+
"""
124+
config_path = Path(__file__).parent / "statistical_config.yml"
125+
config = load_config(config_path)
126+
127+
if algorithm not in config:
128+
proposed_config_path = Path(__file__).parent / "proposed_statistical_config.yml"
129+
proposed_config = load_config(proposed_config_path)
130+
131+
new_entry = generate_config_for_algorithm(algorithm)
132+
if new_entry:
133+
proposed_config.update(new_entry)
134+
save_config(proposed_config, proposed_config_path)
135+
136+
pytest.fail(f"Algorithm {algorithm} not in statistical_config.yml. A new entry has been proposed in proposed_statistical_config.yml.")
137+
138+
test_cases = config[algorithm]
139+
140+
# B-values for simulation - this might need to be adjusted based on real data
141+
bvals = np.array([0, 5, 10, 20, 30, 50, 75, 100, 150, 200, 300, 400, 500, 600, 700, 800])
142+
rician_noise = True
143+
batch_size = 25
144+
max_repetitions = 400 # 16 batches of 25
145+
alpha = 0.05 # For confidence intervals
146+
147+
print(f"Running tests for algorithm: {algorithm}")
148+
149+
all_tests_passed = True
150+
for region, snr_configs in test_cases.items():
151+
for snr, case_config in snr_configs.items():
152+
print(f" Testing Region: {region}, SNR: {snr}")
153+
154+
ground_truth = case_config["ground_truth"]
155+
acceptance_criteria = case_config["acceptance_criteria"]
156+
157+
all_results = {"f": [], "Dp": [], "D": []}
158+
159+
test_passed = False
160+
for i in range(max_repetitions // batch_size):
161+
print(f" Batch {i+1}")
162+
batch_results = run_simulation_batch(algorithm, bvals, ground_truth, snr, batch_size, rician_noise)
163+
164+
for param in all_results.keys():
165+
all_results[param].extend(batch_results[param])
166+
167+
# Update running statistics and check for early stopping
168+
passed_criteria = 0
169+
for param, criteria in acceptance_criteria.items():
170+
values = all_results[param]
171+
n = len(values)
172+
if n < 2:
173+
continue
174+
175+
mean = np.mean(values)
176+
std_dev = np.std(values, ddof=1)
177+
178+
# Confidence interval for the mean
179+
mean_ci = stats.t.interval(1 - alpha, n - 1, loc=mean, scale=stats.sem(values))
180+
181+
# Check if CI is within tolerance
182+
if (mean_ci[0] > criteria["mean"] - criteria["mean_tolerance"] and
183+
mean_ci[1] < criteria["mean"] + criteria["mean_tolerance"]):
184+
185+
# Check std dev tolerance
186+
std_dev_tolerance = criteria["std_dev"] * (criteria["std_dev_tolerance_percent"] / 100.0)
187+
if abs(std_dev - criteria["std_dev"]) < std_dev_tolerance:
188+
passed_criteria += 1
189+
190+
if passed_criteria == len(acceptance_criteria):
191+
print(" All criteria met, stopping early.")
192+
test_passed = True
193+
break
194+
195+
if not test_passed:
196+
print(f" Test failed for Region: {region}, SNR: {snr}")
197+
all_tests_passed = False
198+
199+
assert all_tests_passed, "One or more statistical equivalence tests failed."

0 commit comments

Comments
 (0)