|
1 | 1 | #!/usr/bin/env python |
2 | | -# -*- coding: utf-8 -*-- |
3 | 2 |
|
4 | 3 | # Copyright (c) 2024 Oracle and/or its affiliates. |
5 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
6 | | -import pandas as pd |
7 | 5 | import numpy as np |
| 6 | +import pandas as pd |
8 | 7 |
|
9 | | -from ads.opctl import logger |
10 | 8 | from ads.common.decorator import runtime_dependency |
| 9 | +from ads.opctl import logger |
11 | 10 | from ads.opctl.operator.lowcode.forecast.utils import _select_plot_list |
| 11 | + |
| 12 | +from ..const import ForecastOutputColumns, SupportedModels |
| 13 | +from ..operator_config import ForecastOperatorConfig |
12 | 14 | from .base_model import ForecastOperatorBaseModel |
13 | 15 | from .forecast_datasets import ForecastDatasets, ForecastOutput |
14 | | -from ..operator_config import ForecastOperatorConfig |
15 | | -from ..const import ForecastOutputColumns, SupportedModels |
16 | 16 |
|
17 | 17 |
|
18 | 18 | class MLForecastOperatorModel(ForecastOperatorBaseModel): |
@@ -58,18 +58,25 @@ def _train_model(self, data_train, data_test, model_kwargs): |
58 | 58 | from mlforecast.target_transforms import Differences |
59 | 59 |
|
60 | 60 | lgb_params = { |
61 | | - "verbosity": -1, |
62 | | - "num_leaves": 512, |
| 61 | + "verbosity": model_kwargs.get("verbosity", -1), |
| 62 | + "num_leaves": model_kwargs.get("num_leaves", 512), |
63 | 63 | } |
64 | 64 | additional_data_params = {} |
65 | 65 | if len(self.datasets.get_additional_data_column_names()) > 0: |
66 | 66 | additional_data_params = { |
67 | | - "target_transforms": [Differences([12])], |
| 67 | + "target_transforms": [ |
| 68 | + Differences([model_kwargs.get("Differences", 12)]) |
| 69 | + ], |
68 | 70 | "lags": model_kwargs.get("lags", [1, 6, 12]), |
69 | 71 | "lag_transforms": ( |
70 | 72 | { |
71 | 73 | 1: [ExpandingMean()], |
72 | | - 12: [RollingMean(window_size=24)], |
| 74 | + 12: [ |
| 75 | + RollingMean( |
| 76 | + window_size=model_kwargs.get("RollingMean", 24), |
| 77 | + min_samples=1, |
| 78 | + ) |
| 79 | + ], |
73 | 80 | } |
74 | 81 | ), |
75 | 82 | } |
|
0 commit comments