Skip to content

Commit b2d30cc

Browse files
committed
Fix trade status in vector backtests
1 parent 487b03d commit b2d30cc

File tree

7 files changed

+68
-22
lines changed

7 files changed

+68
-22
lines changed

investing_algorithm_framework/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
Position, TimeFrame, INDEX_DATETIME, MarketCredential, \
1515
PortfolioConfiguration, RESOURCE_DIRECTORY, AWS_LAMBDA_LOGGING_CONFIG, \
1616
Trade, SYMBOLS, RESERVED_BALANCES, APP_MODE, AppMode, DATETIME_FORMAT, \
17-
BacktestDateRange, convert_polars_to_pandas, \
17+
BacktestDateRange, convert_polars_to_pandas, BacktestRun, \
1818
DEFAULT_LOGGING_CONFIG, DataType, DataProvider, \
1919
TradeStatus, TradeRiskType, generate_backtest_summary_metrics, \
2020
APPLICATION_DIRECTORY, DataSource, OrderExecutor, PortfolioProvider, \
@@ -189,4 +189,5 @@
189189
"get_negative_trades",
190190
"get_positive_trades",
191191
"get_number_of_trades",
192+
"BacktestRun"
192193
]

investing_algorithm_framework/domain/models/portfolio/portfolio_snapshot.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,23 +150,24 @@ def to_dict(self, datetime_format=None):
150150
if datetime_format is not None:
151151
created_at = self.created_at.strftime(datetime_format) \
152152
if self.created_at else None
153-
154153
else:
155154
created_at = self.created_at.strftime(DEFAULT_DATETIME_FORMAT)
156155

157156
return {
158-
"metadata": self.metadata,
159-
"portfolio_id": self.portfolio_id,
160-
"trading_symbol": self.trading_symbol,
161-
"pending_value": self.pending_value,
162-
"unallocated": self.unallocated,
163-
"total_net_gain": self.total_net_gain,
164-
"total_revenue": self.total_revenue,
165-
"total_cost": self.total_cost,
166-
"cash_flow": self.cash_flow,
167-
"net_size": self.net_size,
168-
"created_at": created_at,
169-
"total_value": self.total_value,
157+
"metadata": self.metadata if self.metadata else {},
158+
"portfolio_id": self.portfolio_id if self.portfolio_id else "",
159+
"trading_symbol": self.trading_symbol
160+
if self.trading_symbol else "",
161+
"pending_value": self.pending_value if self.pending_value else 0.0,
162+
"unallocated": self.unallocated if self.unallocated else 0.0,
163+
"total_net_gain": self.total_net_gain
164+
if self.total_net_gain else 0.0,
165+
"total_revenue": self.total_revenue if self.total_revenue else 0.0,
166+
"total_cost": self.total_cost if self.total_cost else 0.0,
167+
"cash_flow": self.cash_flow if self.cash_flow else 0.0,
168+
"net_size": self.net_size if self.net_size else 0.0,
169+
"created_at": created_at if created_at else "",
170+
"total_value": self.total_value if self.total_value else 0.0,
170171
}
171172

172173
@staticmethod

investing_algorithm_framework/domain/models/trade/trade.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,14 +283,14 @@ def to_dict(self, datetime_format=None):
283283
"trading_symbol": self.trading_symbol,
284284
"status": self.status,
285285
"amount": self.amount,
286-
"remaining": self.remaining,
286+
"remaining": self.remaining if self.remaining is not None else 0,
287287
"open_price": self.open_price,
288288
"last_reported_price": self.last_reported_price,
289289
"opened_at": opened_at,
290290
"closed_at": closed_at,
291291
"updated_at": updated_at,
292-
"net_gain": self.net_gain,
293-
"cost": self.cost,
292+
"net_gain": self.net_gain if self.net_gain is not None else 0,
293+
"cost": self.cost if self.cost is not None else 0,
294294
"stop_losses": [
295295
stop_loss.to_dict(datetime_format=datetime_format)
296296
for stop_loss in self.stop_losses

investing_algorithm_framework/services/backtesting/backtest_service.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,13 +165,19 @@ def create_vector_backtest(
165165
portfolio_configurations = []
166166
portfolio_configurations.append(
167167
PortfolioConfiguration(
168+
identifier="vector_backtest",
168169
market=market,
169170
trading_symbol=trading_symbol,
170171
initial_balance=initial_amount
171172
)
172173
)
173174

175+
portfolio_configuration = portfolio_configurations[0]
176+
174177
trading_symbol = portfolio_configurations[0].trading_symbol
178+
portfolio = Portfolio.from_portfolio_configuration(
179+
portfolio_configuration
180+
)
175181

176182
# Load vectorized backtest data
177183
data = self._data_provider_service.get_vectorized_backtest_data(
@@ -188,7 +194,9 @@ def create_vector_backtest(
188194
index = pd.Index([])
189195

190196
most_granular_ohlcv_data_source = \
191-
self._get_most_granular_ohlcv_data_source(strategy.data_sources)
197+
BacktestService.get_most_granular_ohlcv_data_source(
198+
strategy.data_sources
199+
)
192200
most_granular_ohlcv_data = self._data_provider_service.get_ohlcv_data(
193201
symbol=most_granular_ohlcv_data_source.symbol,
194202
start_date=backtest_date_range.start_date,
@@ -212,6 +220,8 @@ def create_vector_backtest(
212220
granular_ohlcv_data_order_by_symbol = {}
213221
snapshots = [
214222
PortfolioSnapshot(
223+
trading_symbol=trading_symbol,
224+
portfolio_id=portfolio.identifier,
215225
created_at=backtest_date_range.start_date,
216226
unallocated=initial_amount,
217227
total_value=initial_amount,
@@ -346,7 +356,7 @@ def create_vector_backtest(
346356
{
347357
"orders": trade_orders,
348358
"closed_at": current_date,
349-
"trade_status": TradeStatus.CLOSED,
359+
"status": TradeStatus.CLOSED,
350360
"updated_at": current_date,
351361
"net_gain": net_gain_val
352362
}
@@ -392,6 +402,7 @@ def create_vector_backtest(
392402
# total_net_gain = total_value - initial_amount
393403
snapshots.append(
394404
PortfolioSnapshot(
405+
portfolio_id=portfolio.identifier,
395406
created_at=interval_datetime,
396407
unallocated=unallocated,
397408
total_value=unallocated + allocated,
@@ -580,7 +591,7 @@ def create_backtest(
580591
)
581592

582593
@staticmethod
583-
def _get_most_granular_ohlcv_data_source(data_sources):
594+
def get_most_granular_ohlcv_data_source(data_sources):
584595
"""
585596
Get the most granular data source from a list of data sources.
586597

investing_algorithm_framework/services/metrics/win_rate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ def get_win_rate(trades: List[Trade]) -> float:
5555
trades = [
5656
trade for trade in trades if TradeStatus.CLOSED.equals(trade.status)
5757
]
58+
print(len(trades))
5859
positive_trades = sum(1 for trade in trades if trade.net_gain > 0)
5960
total_trades = len(trades)
6061

tests/scenarios/vectorized_backtests/test_multiple_vectorized_backtests.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import pandas as pd
44
from datetime import datetime, timedelta, timezone
55
from unittest import TestCase
6-
from typing import Dict, Any, List
6+
from typing import Dict, Any
77

88
from pyindicators import ema, rsi, crossover, crossunder, macd
99

1010
from investing_algorithm_framework import TradingStrategy, DataSource, \
1111
TimeUnit, DataType, create_app, BacktestDateRange, PositionSize, \
12-
Algorithm, RESOURCE_DIRECTORY, SnapshotInterval
12+
TradeStatus, RESOURCE_DIRECTORY, SnapshotInterval
1313

1414

1515

@@ -290,6 +290,12 @@ def test_run(self):
290290
self.assertNotEqual(0, len(run.get_trades(target_symbol="ETH")))
291291
self.assertNotEqual(0, len(run.get_trades(target_symbol="BTC")))
292292

293+
# Get first trade
294+
trade = run.get_trades(target_symbol="BTC")[0]
295+
self.assertEqual("BTC", trade.target_symbol)
296+
self.assertEqual("EUR", trade.trading_symbol)
297+
self.assertTrue(TradeStatus.CLOSED.equals(trade.status))
298+
293299
def test_run_without_data_sources_initialization(self):
294300
start_time = time.time()
295301
# RESOURCE_DIRECTORY should always point to the parent directory/resources
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
from unittest import TestCase
3+
4+
from investing_algorithm_framework import create_backtest_metrics, BacktestRun
5+
6+
class TestGenerateMetrics(TestCase):
7+
def setUp(self):
8+
# Must point to /tests/resources
9+
self.resource_directory = os.path.abspath(
10+
os.path.join(os.path.dirname(__file__), '..', '..', 'resources')
11+
)
12+
self.test_data_directory = os.path.join(
13+
self.resource_directory, 'test_data'
14+
)
15+
self.backtest_run_directory = os.path.join(
16+
self.test_data_directory, 'backtest_runs'
17+
)
18+
19+
def test_generate_metrics(self):
20+
# This is a placeholder for the actual test implementation
21+
backtest_run = BacktestRun.open(
22+
os.path.join(self.backtest_run_directory, 'backtest_run_one')
23+
)
24+
backtest_metrics = create_backtest_metrics(
25+
backtest_run, risk_free_rate=0.024
26+
)

0 commit comments

Comments
 (0)