Skip to content

Commit 40b3a70

Browse files
committed
Fix data provider retrieval for permutation tests
1 parent 514590a commit 40b3a70

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

investing_algorithm_framework/app/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1325,9 +1325,10 @@ def run_permutation_test(
13251325

13261326
for data_source in data_sources:
13271327
if DataType.OHLCV.equals(data_source.data_type):
1328+
data_provider = data_provider_service.get(data_source)
13281329
data = data_provider_service.get_data(
13291330
data_source=data_source,
1330-
start_date=backtest_date_range.start_date,
1331+
start_date=data_provider._start_date_data_source,
13311332
end_date=backtest_date_range.end_date
13321333
)
13331334
original_data_combinations.append((data_source, data))

investing_algorithm_framework/services/data_providers/data_provider_service.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,19 @@ def __init__(
322322
self.configuration_service = configuration_service
323323
self.market_credential_service = market_credential_service
324324

325+
def get(self, data_source: DataSource) -> Optional[DataProvider]:
326+
"""
327+
Get a registered data provider by its data source.
328+
329+
Args:
330+
data_source (DataSource): The data source to get the data provider for.
331+
332+
Returns:
333+
Optional[DataProvider]: The registered data provider for
334+
the data source, or None if not found.
335+
"""
336+
return self.data_provider_index.get(data_source)
337+
325338
def get_data(
326339
self,
327340
data_source: DataSource,
@@ -742,6 +755,15 @@ def get_data_files(self):
742755

743756
return data_files
744757

758+
def get_all_registered_data_providers(self) -> List[DataProvider]:
759+
"""
760+
Function to get all registered data providers.
761+
762+
Returns:
763+
List[DataProvider]: A list of all registered data providers.
764+
"""
765+
return self.data_provider_index.get_all()
766+
745767
def reset(self):
746768
"""
747769
Function to reset all the data providers and the data provider

0 commit comments

Comments
 (0)