Skip to content

Commit 400ae3c

Browse files
committed
Fix permutation testing saving and opening
1 parent 663ae50 commit 400ae3c

File tree

1 file changed

+53
-2
lines changed

1 file changed

+53
-2
lines changed

investing_algorithm_framework/domain/backtesting/backtest_permutation_test.py

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import List, Dict
66
import numpy as np
77
import pandas as pd
8+
from datetime import timezone
89

910
from .backtest_metrics import BacktestMetrics
1011

@@ -91,6 +92,15 @@ def summary(
9192
) -> Dict[str, Dict[str, float]]:
9293
"""
9394
Return a summary of real values, mean permuted values, and p-values.
95+
96+
Args:
97+
metrics (List[str]): List of metric names to include
98+
in the summary. If None, uses DEFAULT_METRICS.
99+
100+
Returns:
101+
Dict[str, Dict[str, float]]: A dictionary where each key
102+
is a metric name and the value is another dictionary
103+
with keys 'real', 'permuted_mean', and 'p_value'.
94104
"""
95105

96106
if metrics is None:
@@ -123,7 +133,7 @@ def summary(
123133
continue
124134

125135
summary_dict[metric] = {
126-
"real": real_value,
136+
"real": float(real_value),
127137
"permuted_mean": float(np.mean(dist)),
128138
"p_value": self.p_values.get(metric, None),
129139
}
@@ -140,6 +150,13 @@ def save(self, path: str) -> None:
140150
Returns:
141151
None
142152
"""
153+
def ensure_iso(value):
154+
if hasattr(value, "isoformat"):
155+
if value.tzinfo is None:
156+
value = value.replace(tzinfo=timezone.utc)
157+
return value.isoformat()
158+
return value
159+
143160
os.makedirs(path, exist_ok=True)
144161

145162
# Save the real metrics
@@ -149,12 +166,23 @@ def save(self, path: str) -> None:
149166
os.makedirs(permuted_dir, exist_ok=True)
150167

151168
for i, pm in enumerate(self.permutated_metrics):
152-
pm.save(os.path.join(permuted_dir, f"permuted_{i}"))
169+
pm.save(os.path.join(permuted_dir, f"permuted_{i}.json"))
153170

154171
# Save the P-values
155172
with open(os.path.join(path, "p_values.json"), "w") as f:
156173
json.dump(self.p_values, f)
157174

175+
# Create a metadata file to store additional info such as
176+
# date range name, start and end dates
177+
metadata = {
178+
"backtest_start_date": ensure_iso(self.backtest_start_date),
179+
"backtest_date_range_name": self.backtest_date_range_name,
180+
"backtest_end_date": ensure_iso(self.backtest_end_date),
181+
}
182+
183+
with open(os.path.join(path, "metadata.json"), "w") as f:
184+
json.dump(metadata, f)
185+
158186
@staticmethod
159187
def open(path: str) -> "BacktestPermutationTest":
160188
"""
@@ -189,10 +217,33 @@ def open(path: str) -> "BacktestPermutationTest":
189217
with open(p_values_path, "r") as f:
190218
p_values = json.load(f)
191219

220+
# Load metadata
221+
metadata_path = os.path.join(path, "metadata.json")
222+
backtest_start_date = None
223+
backtest_end_date = None
224+
backtest_date_range_name = None
225+
226+
if os.path.exists(metadata_path):
227+
with open(metadata_path, "r") as f:
228+
metadata = json.load(f)
229+
230+
backtest_start_date = pd.to_datetime(
231+
metadata.get("backtest_start_date"), utc=True
232+
)
233+
backtest_end_date = pd.to_datetime(
234+
metadata.get("backtest_end_date"), utc=True
235+
)
236+
backtest_date_range_name = metadata.get(
237+
"backtest_date_range_name"
238+
)
239+
192240
return BacktestPermutationTest(
193241
real_metrics=real_metrics,
194242
permutated_metrics=permutated_metrics,
195243
p_values=p_values,
244+
backtest_start_date=backtest_start_date,
245+
backtest_end_date=backtest_end_date,
246+
backtest_date_range_name=backtest_date_range_name
196247
)
197248

198249
def create_directory_name(self) -> str:

0 commit comments

Comments
 (0)