Skip to content

Commit a0a66ea

Browse files
committed
Fix download function
1 parent 1a01897 commit a0a66ea

File tree

3 files changed

+96
-17
lines changed

3 files changed

+96
-17
lines changed

examples/download_data.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,5 @@
99
end_date="2023-12-31",
1010
pandas=False,
1111
save=True,
12+
storage_path="./data"
1213
)
13-
14-
print(btceur_ohlcv.head())
15-
print(type(btceur_ohlcv))

investing_algorithm_framework/infrastructure/data_providers/ccxt.py

Lines changed: 95 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import os.path
12
from time import sleep
23
from typing import Union
34
import logging
@@ -534,7 +535,13 @@ def get_data(
534535
# to the specified storage path.
535536
# This is a placeholder for that logic.
536537
self.save_data_to_storage(
537-
data=data, storage_path=storage_path
538+
symbol=symbol,
539+
market=market,
540+
start_date=start_date,
541+
end_date=end_date,
542+
time_frame=time_frame,
543+
data=data,
544+
storage_path=storage_path
538545
)
539546

540547
if pandas:
@@ -745,7 +752,7 @@ def retrieve_data_from_storage(
745752
time_frame: str = None,
746753
start_date: datetime = None,
747754
end_date: datetime = None
748-
) -> pl.DataFrame:
755+
) -> pl.DataFrame | None:
749756
"""
750757
Function to retrieve data from the storage path.
751758
@@ -760,18 +767,39 @@ def retrieve_data_from_storage(
760767
Returns:
761768
pl.DataFrame: The retrieved data in Polars DataFrame format.
762769
"""
763-
self.data = pl.read_csv(
764-
storage_path,
765-
schema_overrides={"Datetime": pl.Datetime},
766-
low_memory=True
770+
771+
if not os.path.isdir(storage_path):
772+
return None
773+
774+
file_name = self._create_filename(
775+
symbol=symbol,
776+
market=market,
777+
time_frame=time_frame,
778+
start_date=start_date,
779+
end_date=end_date
767780
)
768-
first_row = self.data.head(1)
769-
last_row = self.data.tail(1)
770-
self._start_date_data_source = first_row["Datetime"][0]
771-
self._end_date_data_source = last_row["Datetime"][0]
781+
782+
file_path = os.path.join(storage_path, file_name)
783+
784+
if os.path.exists(file_path):
785+
try:
786+
data = pl.read_csv(file_path, has_header=True)
787+
return data
788+
except Exception as e:
789+
logger.error(
790+
f"Error reading data from {file_path}: {e}"
791+
)
792+
return None
793+
794+
return None
772795

773796
def save_data_to_storage(
774797
self,
798+
symbol,
799+
market,
800+
start_date: datetime,
801+
end_date: datetime,
802+
time_frame: str,
775803
data: pl.DataFrame,
776804
storage_path: str,
777805
):
@@ -785,7 +813,32 @@ def save_data_to_storage(
785813
Returns:
786814
None
787815
"""
788-
# Placeholder for actual implementation
816+
if storage_path is None:
817+
raise OperationalException(
818+
"Storage path is not set. Please set the storage path "
819+
"before saving data."
820+
)
821+
822+
if not os.path.isdir(storage_path):
823+
os.makedirs(storage_path)
824+
825+
symbol = symbol.upper().replace('/', '_')
826+
filename = self._create_filename(
827+
symbol=symbol,
828+
market=market,
829+
time_frame=time_frame,
830+
start_date=start_date,
831+
end_date=end_date
832+
)
833+
storage_path = os.path.join(storage_path, filename)
834+
if os.path.exists(storage_path):
835+
os.remove(storage_path)
836+
837+
# Create the file
838+
if not os.path.exists(storage_path):
839+
with open(storage_path, 'w'):
840+
pass
841+
789842
data.write_csv(storage_path)
790843

791844
def __repr__(self):
@@ -794,3 +847,34 @@ def __repr__(self):
794847
f"symbol={self.symbol}, time_frame={self.time_frame}, "
795848
f"window_size={self.window_size})"
796849
)
850+
851+
@staticmethod
852+
def _create_filename(
853+
symbol: str,
854+
market: str,
855+
time_frame: str,
856+
start_date: datetime,
857+
end_date: datetime
858+
) -> str:
859+
"""
860+
Creates a filename for the data file based on the parameters.
861+
The date format is YYYYMMDDHH for both start and end dates.
862+
863+
Args:
864+
symbol (str): The symbol of the data.
865+
market (str): The market of the data.
866+
time_frame (str): The time frame of the data.
867+
start_date (datetime): The start date of the data.
868+
end_date (datetime): The end date of the data.
869+
870+
Returns:
871+
str: The generated filename.
872+
"""
873+
symbol = symbol.upper().replace('/', '_')
874+
start_date_str = start_date.strftime('%Y%m%d%H')
875+
end_date_str = end_date.strftime('%Y%m%d%H')
876+
filename = (
877+
f"{symbol}_{market}_{time_frame}_{start_date_str}_"
878+
f"{end_date_str}.csv"
879+
)
880+
return filename

investing_algorithm_framework/services/market_data_source_service/data_provider_service.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,10 @@ def register(self, symbol, market) -> DataProvider:
7272
Returns:
7373
None
7474
"""
75-
print("Registering data provider for ohlcv data")
7675
matches = []
7776

7877
for data_provider in self.data_providers:
7978

80-
print("checking data provider", data_provider)
8179
if data_provider.supports(market, symbol):
8280
matches.append(data_provider)
8381

@@ -336,7 +334,6 @@ def find_data_provider(
336334
"""
337335
data_provider = None
338336

339-
print(data_type)
340337
if TradingDataType.OHLCV.equals(data_type):
341338
# Check if there is already a registered data provider
342339
data_provider = self.ohlcv_data_provider_index.get(

0 commit comments

Comments
 (0)