Skip to content

Commit 686273a

Browse files
committed
Check if dataframe has datetime data type
1 parent 6269027 commit 686273a

File tree

5 files changed

+80
-108
lines changed

5 files changed

+80
-108
lines changed

investing_algorithm_framework/domain/models/trade/trade.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,21 @@ def is_manual_stop_loss_trigger(
209209
# If dataframes are provided, we use the dataframe to calculate
210210
# the stop loss price
211211
if ohlcv_df is not None:
212-
filtered_df = ohlcv_df.filter(
213-
pl.col('Datetime') >= self.opened_at.strftime(
214-
DATETIME_FORMAT
212+
column_type = ohlcv_df['Datetime'].dtype
213+
214+
if isinstance(column_type, pl.Datetime):
215+
filtered_df = ohlcv_df.filter(
216+
pl.col('Datetime') >= self.opened_at
217+
)
218+
else:
219+
filtered_df = ohlcv_df.filter(
220+
pl.col('Datetime') >= self.opened_at.strftime(
221+
DATETIME_FORMAT
222+
)
215223
)
216-
)
224+
217225
prices = filtered_df['Close'].to_numpy()
226+
218227
highest_price = max(prices)
219228
stop_loss_price = highest_price * (1 - stop_loss_percentage / 100)
220229
return current_price <= stop_loss_price

investing_algorithm_framework/services/order_service/order_backtest_service.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import datetime
12
import logging
23

34
import pandas as pd
5+
import polars as pl
46

57
from investing_algorithm_framework.domain import BACKTESTING_INDEX_DATETIME, \
68
OrderStatus, BACKTESTING_PENDING_ORDER_CHECK_INTERVAL, \
@@ -160,9 +162,18 @@ def has_executed(self, order, ohlcv_data_frame):
160162
created_at = order.get_created_at()
161163
order_side = order.get_order_side()
162164
order_price = order.get_price()
165+
column_type = ohlcv_data_frame['Datetime'].dtype
163166

164-
# Filter OHLCV data after the order creation time
165-
ohlcv_data_after_order = ohlcv_data_frame.loc[created_at:]
167+
if isinstance(column_type, pl.Datetime):
168+
ohlcv_data_after_order = ohlcv_data_frame.filter(
169+
pl.col('Datetime') >= created_at
170+
)
171+
else:
172+
ohlcv_data_after_order = ohlcv_data_frame.filter(
173+
pl.col('Datetime') >= created_at.strftime(
174+
self.configuration_service.config["DATETIME_FORMAT"]
175+
)
176+
)
166177

167178
# Check if the order execution conditions are met
168179
if OrderSide.BUY.equals(order_side):

tests/infrastructure/market_data_sources/test_ccxt_ohlcv_backtest_market_data_source.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from unittest import TestCase
44

55
from investing_algorithm_framework.domain import RESOURCE_DIRECTORY, \
6-
BACKTEST_DATA_DIRECTORY_NAME
6+
BACKTEST_DATA_DIRECTORY_NAME, DATETIME_FORMAT
77
from investing_algorithm_framework.infrastructure import \
88
CCXTOHLCVBacktestMarketDataSource
99

@@ -97,6 +97,9 @@ def test_right_columns(self):
9797
timeframe="15m",
9898
window_size=200
9999
)
100+
data_source.config = {
101+
"DATETIME_FORMAT": DATETIME_FORMAT
102+
}
100103
data_source.prepare_data(
101104
config={
102105
RESOURCE_DIRECTORY: self.resource_dir,
@@ -107,7 +110,11 @@ def test_right_columns(self):
107110
)
108111
self.assertEqual(200, data_source.window_size)
109112
self.assertEqual(csv_file_path, data_source._create_file_path())
110-
113+
df = data_source\
114+
.get_data(datetime(year=2023, month=12, day=17, hour=0, minute=0))
115+
self.assertEqual(
116+
["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns
117+
)
111118

112119
# def test_start_date(self):
113120
# start_date = datetime(2023, 12, 1)
@@ -151,26 +158,7 @@ def test_right_columns(self):
151158
# data_source.start_date = datetime(2023, 12, 25)
152159
# data_source.end_date = datetime(2023, 12, 16)
153160
# self.assertTrue(data_source.empty())
154-
#
155-
# def test_get_data(self):
156-
# file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \
157-
# "01:00:00_2023-12-25:00:00.csv"
158-
# datasource = CSVOHLCVMarketDataSource(
159-
# csv_file_path=f"{self.resource_dir}/"
160-
# "market_data_sources/"
161-
# f"{file_name}"
162-
# )
163-
# number_of_runs = 0
164-
#
165-
# while not datasource.empty():
166-
# data = datasource.get_data()
167-
# datasource.start_date = datasource.start_date + timedelta(days=1)
168-
# datasource.end_date = datasource.end_date + timedelta(days=1)
169-
# self.assertTrue(len(data) > 0)
170-
# number_of_runs += 1
171-
#
172-
# self.assertTrue(number_of_runs > 0)
173-
#
161+
174162
# def test_get_identifier(self):
175163
# file_name = "OHLCV_BTC-EUR_BINANCE_15m_2023-12-" \
176164
# "01:00:00_2023-12-25:00:00.csv"
@@ -214,3 +202,4 @@ def test_right_columns(self):
214202
# timeframe="15m"
215203
# )
216204
# self.assertEqual("15m", datasource.get_timeframe())
205+

tests/infrastructure/market_data_sources/test_csv_ohlcv_market_data_source.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,17 @@ def setUp(self) -> None:
3131
def test_right_columns(self):
3232
file_name = "OHLCV_BTC-EUR_BINANCE" \
3333
"_2h_2023-08-07:07:59_2023-12-02:00:00.csv"
34-
CSVOHLCVMarketDataSource(
34+
data_source = CSVOHLCVMarketDataSource(
3535
csv_file_path=f"{self.resource_dir}/"
3636
"market_data_sources/"
3737
f"{file_name}",
3838
window_size=10
3939
)
40+
df = data_source \
41+
.get_data(datetime(year=2023, month=12, day=17, hour=0, minute=0))
42+
self.assertEqual(
43+
["Datetime", "Open", "High", "Low", "Close", "Volume"], df.columns
44+
)
4045

4146
def test_throw_exception_when_missing_column_names_columns(self):
4247
file_name = "OHLCV_BTC-EUR_BINANCE_2h_NO_COLUMNS_2023-" \
@@ -114,6 +119,7 @@ def test_get_data(self):
114119
datasource.start_date = datasource.start_date + timedelta(days=1)
115120
datasource.end_date = datasource.end_date + timedelta(days=1)
116121
self.assertTrue(len(data) > 0)
122+
self.assertAlmostEqual(10, len(data), 2)
117123
self.assertTrue(isinstance(data, DataFrame))
118124
number_of_runs += 1
119125

@@ -170,3 +176,6 @@ def test_get_timeframe(self):
170176
window_size=10,
171177
)
172178
self.assertEqual("15m", datasource.get_timeframe())
179+
180+
def test_get_data(self):
181+
pass

0 commit comments

Comments
 (0)