|
| 1 | +import yaml |
| 2 | +import pytest |
| 3 | +import numpy as np |
| 4 | +from pathlib import Path |
| 5 | +import json |
| 6 | +from scipy import stats |
| 7 | + |
| 8 | +# TODO: These should probably be fixtures |
| 9 | +from src.wrappers.OsipiBase import OsipiBase |
| 10 | +from utilities.data_simulation.GenerateData import GenerateData |
| 11 | + |
| 12 | + |
| 13 | +def load_config(path): |
| 14 | + """Loads a YAML configuration file.""" |
| 15 | + if not path.exists(): |
| 16 | + return {} |
| 17 | + with open(path, "r") as f: |
| 18 | + return yaml.safe_load(f) |
| 19 | + |
| 20 | +def save_config(config, path): |
| 21 | + """Saves a YAML configuration file.""" |
| 22 | + with open(path, "w") as f: |
| 23 | + yaml.dump(config, f, default_flow_style=False) |
| 24 | + |
| 25 | +def get_algorithms(): |
| 26 | + """Loads the list of algorithms from the JSON file.""" |
| 27 | + algorithms_path = Path(__file__).parent / "algorithms.json" |
| 28 | + with open(algorithms_path, "r") as f: |
| 29 | + return json.load(f)["algorithms"] |
| 30 | + |
| 31 | +def generate_config_for_algorithm(algorithm): |
| 32 | + """Generates reference data for a given algorithm.""" |
| 33 | + fit_count = 300 |
| 34 | + snr = 100 |
| 35 | + rician_noise = True |
| 36 | + |
| 37 | + regions = { |
| 38 | + "Blood RV": {"f": 1.0, "Dp": 0.1, "D": 0.003}, |
| 39 | + "Myocardium LV": {"f": 0.15, "Dp": 0.08, "D": 0.0024}, |
| 40 | + "myocardium RV": {"f": 0.15, "Dp": 0.08, "D": 0.0024}, |
| 41 | + "myocardium ra": {"f": 0.07, "Dp": 0.07, "D": 0.0015}, |
| 42 | + } |
| 43 | + bvals = np.array([0, 5, 10, 20, 30, 50, 75, 100, 150, 200, 300, 400, 500, 600, 700, 800]) |
| 44 | + |
| 45 | + new_config_entry = {} |
| 46 | + |
| 47 | + print(f"Generating reference data for {algorithm}") |
| 48 | + |
| 49 | + if "NET" in algorithm or "DC" in algorithm: |
| 50 | + print(" Skipping deep learning algorithm") |
| 51 | + return None |
| 52 | + |
| 53 | + new_config_entry[algorithm] = {} |
| 54 | + for region_name, data in regions.items(): |
| 55 | + print(f" Running {algorithm} for {region_name}") |
| 56 | + |
| 57 | + rng = np.random.RandomState(42) |
| 58 | + S0 = 1 |
| 59 | + gd = GenerateData(rng=rng) |
| 60 | + D = data["D"] |
| 61 | + f = data["f"] |
| 62 | + Dp = data["Dp"] |
| 63 | + |
| 64 | + try: |
| 65 | + fit = OsipiBase(algorithm=algorithm) |
| 66 | + except Exception as e: |
| 67 | + print(f" Could not instantiate {algorithm}: {e}") |
| 68 | + continue |
| 69 | + |
| 70 | + results = {"f": [], "Dp": [], "D": []} |
| 71 | + for idx in range(fit_count): |
| 72 | + signal = gd.ivim_signal(D, Dp, f, S0, bvals, snr, rician_noise) |
| 73 | + try: |
| 74 | + fit_result = fit.osipi_fit(signal, bvals) |
| 75 | + results["f"].append(fit_result["f"]) |
| 76 | + results["Dp"].append(fit_result["Dp"]) |
| 77 | + results["D"].append(fit_result["D"]) |
| 78 | + except Exception as e: |
| 79 | + print(f" Fit failed for {algorithm} at index {idx}: {e}") |
| 80 | + |
| 81 | + if results["f"]: |
| 82 | + f_mu = float(np.mean(results["f"])) |
| 83 | + Dp_mu = float(np.mean(results["Dp"])) |
| 84 | + D_mu = float(np.mean(results["D"])) |
| 85 | + f_std = float(np.std(results["f"])) |
| 86 | + Dp_std = float(np.std(results["Dp"])) |
| 87 | + D_std = float(np.std(results["D"])) |
| 88 | + |
| 89 | + new_config_entry[algorithm][region_name] = { |
| 90 | + 100: { |
| 91 | + "ground_truth": {"f": float(f), "Dp": float(Dp), "D": float(D)}, |
| 92 | + "acceptance_criteria": { |
| 93 | + "f": {"mean": f_mu, "std_dev": f_std, "mean_tolerance": 0.2, "std_dev_tolerance_percent": 100.0}, |
| 94 | + "Dp": {"mean": Dp_mu, "std_dev": Dp_std, "mean_tolerance": 0.2, "std_dev_tolerance_percent": 100.0}, |
| 95 | + "D": {"mean": D_mu, "std_dev": D_std, "mean_tolerance": 0.002, "std_dev_tolerance_percent": 100.0}, |
| 96 | + } |
| 97 | + } |
| 98 | + } |
| 99 | + return new_config_entry |
| 100 | + |
| 101 | +def run_simulation_batch(algorithm, bvals, ground_truth, snr, batch_size, rician_noise): |
| 102 | + """Runs a batch of simulations and returns the fitted parameters.""" |
| 103 | + fit = OsipiBase(algorithm=algorithm) |
| 104 | + rng = np.random.RandomState() |
| 105 | + gd = GenerateData(rng=rng) |
| 106 | + S0 = 1 |
| 107 | + |
| 108 | + results = {"f": [], "Dp": [], "D": []} |
| 109 | + |
| 110 | + for _ in range(batch_size): |
| 111 | + signal = gd.ivim_signal(ground_truth["D"], ground_truth["Dp"], ground_truth["f"], S0, bvals, snr, rician_noise) |
| 112 | + fit_result = fit.osipi_fit(signal, bvals) |
| 113 | + results["f"].append(fit_result["f"]) |
| 114 | + results["Dp"].append(fit_result["Dp"]) |
| 115 | + results["D"].append(fit_result["D"]) |
| 116 | + return results |
| 117 | + |
| 118 | + |
| 119 | +@pytest.mark.parametrize("algorithm", get_algorithms()) |
| 120 | +def test_statistical_equivalence(algorithm): |
| 121 | + """ |
| 122 | + Main test function to check statistical equivalence. |
| 123 | + """ |
| 124 | + config_path = Path(__file__).parent / "statistical_config.yml" |
| 125 | + config = load_config(config_path) |
| 126 | + |
| 127 | + if algorithm not in config: |
| 128 | + proposed_config_path = Path(__file__).parent / "proposed_statistical_config.yml" |
| 129 | + proposed_config = load_config(proposed_config_path) |
| 130 | + |
| 131 | + new_entry = generate_config_for_algorithm(algorithm) |
| 132 | + if new_entry: |
| 133 | + proposed_config.update(new_entry) |
| 134 | + save_config(proposed_config, proposed_config_path) |
| 135 | + |
| 136 | + pytest.fail(f"Algorithm {algorithm} not in statistical_config.yml. A new entry has been proposed in proposed_statistical_config.yml.") |
| 137 | + |
| 138 | + test_cases = config[algorithm] |
| 139 | + |
| 140 | + # B-values for simulation - this might need to be adjusted based on real data |
| 141 | + bvals = np.array([0, 5, 10, 20, 30, 50, 75, 100, 150, 200, 300, 400, 500, 600, 700, 800]) |
| 142 | + rician_noise = True |
| 143 | + batch_size = 25 |
| 144 | + max_repetitions = 400 # 16 batches of 25 |
| 145 | + alpha = 0.05 # For confidence intervals |
| 146 | + |
| 147 | + print(f"Running tests for algorithm: {algorithm}") |
| 148 | + |
| 149 | + all_tests_passed = True |
| 150 | + for region, snr_configs in test_cases.items(): |
| 151 | + for snr, case_config in snr_configs.items(): |
| 152 | + print(f" Testing Region: {region}, SNR: {snr}") |
| 153 | + |
| 154 | + ground_truth = case_config["ground_truth"] |
| 155 | + acceptance_criteria = case_config["acceptance_criteria"] |
| 156 | + |
| 157 | + all_results = {"f": [], "Dp": [], "D": []} |
| 158 | + |
| 159 | + test_passed = False |
| 160 | + for i in range(max_repetitions // batch_size): |
| 161 | + print(f" Batch {i+1}") |
| 162 | + batch_results = run_simulation_batch(algorithm, bvals, ground_truth, snr, batch_size, rician_noise) |
| 163 | + |
| 164 | + for param in all_results.keys(): |
| 165 | + all_results[param].extend(batch_results[param]) |
| 166 | + |
| 167 | + # Update running statistics and check for early stopping |
| 168 | + passed_criteria = 0 |
| 169 | + for param, criteria in acceptance_criteria.items(): |
| 170 | + values = all_results[param] |
| 171 | + n = len(values) |
| 172 | + if n < 2: |
| 173 | + continue |
| 174 | + |
| 175 | + mean = np.mean(values) |
| 176 | + std_dev = np.std(values, ddof=1) |
| 177 | + |
| 178 | + # Confidence interval for the mean |
| 179 | + mean_ci = stats.t.interval(1 - alpha, n - 1, loc=mean, scale=stats.sem(values)) |
| 180 | + |
| 181 | + # Check if CI is within tolerance |
| 182 | + if (mean_ci[0] > criteria["mean"] - criteria["mean_tolerance"] and |
| 183 | + mean_ci[1] < criteria["mean"] + criteria["mean_tolerance"]): |
| 184 | + |
| 185 | + # Check std dev tolerance |
| 186 | + std_dev_tolerance = criteria["std_dev"] * (criteria["std_dev_tolerance_percent"] / 100.0) |
| 187 | + if abs(std_dev - criteria["std_dev"]) < std_dev_tolerance: |
| 188 | + passed_criteria += 1 |
| 189 | + |
| 190 | + if passed_criteria == len(acceptance_criteria): |
| 191 | + print(" All criteria met, stopping early.") |
| 192 | + test_passed = True |
| 193 | + break |
| 194 | + |
| 195 | + if not test_passed: |
| 196 | + print(f" Test failed for Region: {region}, SNR: {snr}") |
| 197 | + all_tests_passed = False |
| 198 | + |
| 199 | + assert all_tests_passed, "One or more statistical equivalence tests failed." |
0 commit comments