Skip to content

Commit 1067229

Browse files
committed
Add metadata attributes
1 parent 19e2fac commit 1067229

File tree

4 files changed

+61
-27
lines changed

4 files changed

+61
-27
lines changed

investing_algorithm_framework/domain/backtesting/backtest.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ def get_backtest_run(
9191
return run
9292
return None
9393

94+
def get_all_backtest_permutation_tests(
95+
self
96+
) -> List[BacktestPermutationTest]:
97+
"""
98+
Retrieve all BacktestPermutationTest instances from the backtest.
99+
100+
Returns:
101+
List[BacktestPermutationTest]: A list of all
102+
BacktestPermutationTest instances.
103+
"""
104+
return self.backtest_permutation_tests
105+
94106
def get_backtest_permutation_test(
95107
self, date_range: BacktestDateRange
96108
) -> Union[BacktestPermutationTest, None]:

investing_algorithm_framework/domain/backtesting/backtest_permutation_test.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def save(self, path: str) -> None:
137137
os.makedirs(path, exist_ok=True)
138138

139139
# Save the real metrics
140-
self.real_metrics.save(os.path.join(path, "original_metrics"))
140+
self.real_metrics.save(os.path.join(path, "original_metrics.json"))
141141

142142
permuted_dir = os.path.join(path, "permuted_metrics")
143143
os.makedirs(permuted_dir, exist_ok=True)
@@ -154,36 +154,33 @@ def open(path: str) -> "BacktestPermutationTest":
154154
"""
155155
Load the permutation test results from disk (JSON + Parquet).
156156
"""
157-
with open(os.path.join(path, "results.json"), "r") as f:
158-
results = json.load(f)
157+
original_metrics = os.path.join(path, "original_metrics.json")
159158

160159
# Rehydrate BacktestMetrics
161-
real_metrics = BacktestMetrics(**results["real_metrics"])
162-
permutated_metrics = [
163-
BacktestMetrics(**pm) for pm in results["permutated_metrics"]
164-
]
165-
166-
# Reload DataFrames
167-
ohlcv_original_datasets = {}
168-
ohlcv_permutated_datasets = {}
169-
for file in os.listdir(path):
170-
if file.startswith("original_") and file.endswith(".parquet"):
171-
key = file.replace("original_", "").replace(".parquet", "")
172-
ohlcv_original_datasets[key] = pd.read_parquet(
173-
os.path.join(path, file)
174-
)
175-
elif file.startswith("permuted_") and file.endswith(".parquet"):
176-
key = file.replace("permuted_", "").replace(".parquet", "")
177-
ohlcv_permutated_datasets[key] = pd.read_parquet(
178-
os.path.join(path, file)
179-
)
160+
real_metrics = BacktestMetrics.open(original_metrics)
161+
162+
permuted_dir = os.path.join(path, "permuted_metrics")
163+
164+
permutated_metrics = []
165+
if os.path.exists(permuted_dir):
166+
for fname in os.listdir(permuted_dir):
167+
if fname.startswith("permuted_"):
168+
pm = BacktestMetrics.open(
169+
os.path.join(permuted_dir, fname)
170+
)
171+
permutated_metrics.append(pm)
172+
173+
p_values_path = os.path.join(path, "p_values.json")
174+
p_values = {}
175+
176+
if os.path.exists(p_values_path):
177+
with open(p_values_path, "r") as f:
178+
p_values = json.load(f)
180179

181180
return BacktestPermutationTest(
182181
real_metrics=real_metrics,
183182
permutated_metrics=permutated_metrics,
184-
p_values=results["p_values"],
185-
ohlcv_original_datasets=ohlcv_original_datasets,
186-
ohlcv_permutated_datasets=ohlcv_permutated_datasets
183+
p_values=p_values,
187184
)
188185

189186
def create_directory_name(self) -> str:

investing_algorithm_framework/domain/backtesting/backtest_run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class BacktestRun:
8282
number_of_positions: int = 0
8383
backtest_metrics: BacktestMetrics = None
8484
backtest_date_range_name: str = None
85+
data_sources: List[dict] = field(default_factory=list)
8586

8687
def to_dict(self) -> dict:
8788
"""

investing_algorithm_framework/domain/models/portfolio/portfolio_snapshot.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from dateutil import parser
22
from investing_algorithm_framework.domain.models.base_model import BaseModel
3+
from investing_algorithm_framework.domain.constants import \
4+
DEFAULT_DATETIME_FORMAT
35

46

57
class PortfolioSnapshot(BaseModel):
@@ -17,7 +19,8 @@ def __init__(
1719
total_value=None,
1820
cash_flow=None,
1921
created_at=None,
20-
position_snapshots=None
22+
position_snapshots=None,
23+
metadata=None,
2124
):
2225
self.portfolio_id = portfolio_id
2326
self.trading_symbol = trading_symbol
@@ -29,6 +32,7 @@ def __init__(
2932
self.net_size = net_size
3033
self.total_cost = total_cost
3134
self.cash_flow = cash_flow
35+
self.metadata = metadata if metadata is not None else {}
3236

3337
if created_at is not None and isinstance(created_at, str):
3438
self.created_at = parser.parse(created_at)
@@ -144,9 +148,18 @@ def to_dict(self, datetime_format=None):
144148
if self.created_at else None
145149

146150
else:
147-
created_at = self.created_at
151+
created_at = self.created_at.strftime(DEFAULT_DATETIME_FORMAT)
148152

149153
return {
154+
"metadata": self.metadata,
155+
"portfolio_id": self.portfolio_id,
156+
"trading_symbol": self.trading_symbol,
157+
"pending_value": self.pending_value,
158+
"unallocated": self.unallocated,
159+
"total_net_gain": self.total_net_gain,
160+
"total_revenue": self.total_revenue,
161+
"total_cost": self.total_cost,
162+
"cash_flow": self.cash_flow,
150163
"net_size": self.net_size,
151164
"created_at": created_at,
152165
"total_value": self.total_value,
@@ -169,4 +182,15 @@ def from_dict(data):
169182
net_size=data.get("net_size", 0.0),
170183
created_at=created_at,
171184
total_value=data.get("total_value", 0.0),
185+
trading_symbol=data.get(
186+
"trading_symbol", None
187+
),
188+
portfolio_id=data.get("portfolio_id", None),
189+
pending_value=data.get("pending_value", 0.0),
190+
unallocated=data.get("unallocated", 0.0),
191+
total_net_gain=data.get("total_net_gain", 0.0),
192+
total_revenue=data.get("total_revenue", 0.0),
193+
total_cost=data.get("total_cost", 0.0),
194+
cash_flow=data.get("cash_flow", 0.0),
195+
metadata=data.get("metadata", {})
172196
)

0 commit comments

Comments
 (0)