Skip to content

Commit 5cf2b6d

Browse files
committed
allow users to specify engine type for automlx
1 parent a9a9b05 commit 5cf2b6d

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

ads/opctl/operator/lowcode/forecast/model/automlx.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -81,22 +81,6 @@ def _build_model(self) -> pd.DataFrame:
8181

8282
from automlx import Pipeline, init
8383

84-
cpu_count = os.cpu_count()
85-
try:
86-
if cpu_count < 4:
87-
engine = "local"
88-
engine_opts = None
89-
else:
90-
engine = "ray"
91-
engine_opts = ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
92-
init(
93-
engine=engine,
94-
engine_opts=engine_opts,
95-
loglevel=logging.CRITICAL,
96-
)
97-
except Exception as e:
98-
logger.info(f"Error. Has Ray already been initialized? Skipping. {e}")
99-
10084
full_data_dict = self.datasets.get_data_by_series()
10185

10286
self.models = {}
@@ -112,6 +96,26 @@ def _build_model(self) -> pd.DataFrame:
11296
# Clean up kwargs for pass through
11397
model_kwargs_cleaned, time_budget = self.set_kwargs()
11498

99+
cpu_count = os.cpu_count()
100+
try:
101+
engine_type = model_kwargs_cleaned.pop(
102+
"engine", "local" if cpu_count <= 4 else "ray"
103+
)
104+
engine_opts = (
105+
None
106+
if engine_type == "local"
107+
else ({"ray_setup": {"_temp_dir": "/tmp/ray-temp"}},)
108+
)
109+
init(
110+
engine=engine_type,
111+
engine_opts=engine_opts,
112+
loglevel=logging.CRITICAL,
113+
)
114+
except Exception as e:
115+
logger.info(
116+
f"Error initializing automlx. Has Ray already been initialized? Skipping. {e}"
117+
)
118+
115119
for s_id, df in full_data_dict.items():
116120
try:
117121
logger.debug(f"Running automlx on series {s_id}")

pyproject.toml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,27 +157,26 @@ forecast = [
157157
"oci-cli",
158158
"py-cpuinfo",
159159
"rich",
160-
"autots[additional]",
160+
"autots",
161161
"mlforecast",
162162
"neuralprophet>=0.7.0",
163163
"numpy<2.0.0",
164164
"oci-cli",
165165
"optuna",
166-
"oracle-ads",
167166
"pmdarima",
168167
"prophet",
169168
"shap",
170169
"sktime",
171170
"statsmodels",
172171
"plotly",
173172
"oracledb",
174-
"report-creator==1.0.28",
173+
"report-creator==1.0.32",
175174
]
176175
anomaly = [
177176
"oracle_ads[opctl]",
178177
"autots",
179178
"oracledb",
180-
"report-creator==1.0.28",
179+
"report-creator==1.0.32",
181180
"rrcf==0.4.4",
182181
"scikit-learn",
183182
"salesforce-merlion[all]==2.0.4"
@@ -186,7 +185,7 @@ recommender = [
186185
"oracle_ads[opctl]",
187186
"scikit-surprise",
188187
"plotly",
189-
"report-creator==1.0.28",
188+
"report-creator==1.0.32",
190189
]
191190
feature-store-marketplace = [
192191
"oracle-ads[opctl]",
@@ -202,7 +201,7 @@ pii = [
202201
"scrubadub_spacy",
203202
"spacy-transformers==1.2.5",
204203
"spacy==3.6.1",
205-
"report-creator==1.0.28",
204+
"report-creator==1.0.32",
206205
]
207206
llm = ["langchain>=0.2", "langchain-community", "langchain_openai", "pydantic>=2,<3", "evaluate>=0.4.0"]
208207
aqua = ["jupyter_server"]

0 commit comments

Comments
 (0)