11#!/usr/bin/env python
2- # -*- coding: utf-8 -*--
32
43# Copyright (c) 2024 Oracle and/or its affiliates.
54# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6- import pandas as pd
75import numpy as np
6+ import pandas as pd
87
9- from ads .opctl import logger
108from ads .common .decorator import runtime_dependency
9+ from ads .opctl import logger
1110from ads .opctl .operator .lowcode .forecast .utils import _select_plot_list
11+
12+ from ..const import ForecastOutputColumns , SupportedModels
13+ from ..operator_config import ForecastOperatorConfig
1214from .base_model import ForecastOperatorBaseModel
1315from .forecast_datasets import ForecastDatasets , ForecastOutput
14- from ..operator_config import ForecastOperatorConfig
15- from ..const import ForecastOutputColumns , SupportedModels
1616
1717
1818class MLForecastOperatorModel (ForecastOperatorBaseModel ):
@@ -58,18 +58,25 @@ def _train_model(self, data_train, data_test, model_kwargs):
5858 from mlforecast .target_transforms import Differences
5959
6060 lgb_params = {
61- "verbosity" : - 1 ,
62- "num_leaves" : 512 ,
61+ "verbosity" : model_kwargs . get ( "verbosity" , - 1 ) ,
62+ "num_leaves" : model_kwargs . get ( "num_leaves" , 512 ) ,
6363 }
6464 additional_data_params = {}
6565 if len (self .datasets .get_additional_data_column_names ()) > 0 :
6666 additional_data_params = {
67- "target_transforms" : [Differences ([12 ])],
67+ "target_transforms" : [
68+ Differences ([model_kwargs .get ("Differences" , 12 )])
69+ ],
6870 "lags" : model_kwargs .get ("lags" , [1 , 6 , 12 ]),
6971 "lag_transforms" : (
7072 {
7173 1 : [ExpandingMean ()],
72- 12 : [RollingMean (window_size = 24 )],
74+ 12 : [
75+ RollingMean (
76+ window_size = model_kwargs .get ("RollingMean" , 24 ),
77+ min_samples = 1 ,
78+ )
79+ ],
7380 }
7481 ),
7582 }
@@ -147,7 +154,7 @@ def _train_model(self, data_train, data_test, model_kwargs):
147154 )
148155
149156 self .model_parameters [s_id ] = {
150- "framework" : SupportedModels .MLForecast ,
157+ "framework" : SupportedModels .LGBForecast ,
151158 ** lgb_params ,
152159 }
153160
@@ -204,10 +211,10 @@ def _generate_report(self):
204211 self .datasets .list_series_ids (),
205212 )
206213
207- # Section 2: MlForecast Model Parameters
214+ # Section 2: LGBForecast Model Parameters
208215 sec2_text = rc .Block (
209- rc .Heading ("MlForecast Model Parameters" , level = 2 ),
210- rc .Text ("These are the parameters used for the MlForecast model." ),
216+ rc .Heading ("LGBForecast Model Parameters" , level = 2 ),
217+ rc .Text ("These are the parameters used for the LGBForecast model." ),
211218 )
212219
213220 blocks = [
@@ -221,7 +228,7 @@ def _generate_report(self):
221228
222229 all_sections = [sec1_text , sec_1 , sec2_text , sec_2 ]
223230 model_description = rc .Text (
224- "mlforecast is a framework to perform time series forecasting using machine learning models"
231+ "LGBForecast uses mlforecast framework to perform time series forecasting using machine learning models"
225232 "with the option to scale to massive amounts of data using remote clusters."
226233 "Fastest implementations of feature engineering for time series forecasting in Python."
227234 "Support for exogenous variables and static covariates."
0 commit comments