Skip to content

Commit 6e785cf

Browse files
authored
Create demand_prediction_model.py
1 parent 44ff63b commit 6e785cf

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

models/demand_prediction_model.py

Lines changed: 55 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,55 @@
1+
import pandas as pd
2+
from sklearn.model_selection import train_test_split, GridSearchCV
3+
from sklearn.linear_model import LinearRegression
4+
from sklearn.ensemble import RandomForestRegressor
5+
from sklearn.metrics import mean_squared_error, r2_score
6+
import joblib
7+
import numpy as np
8+
9+
# Load your dataset.
# For demonstration we build a small synthetic dataset in memory; in a real
# scenario this would be loaded from a CSV file or a database instead.
data = {
    'historical_demand': [100, 150, 200, 250, 300, 350, 400, 450, 500, 550],
    'seasonality': [1, 1, 1, 2, 2, 2, 3, 3, 3, 4],  # Example feature
    'promotion': [0, 1, 0, 1, 0, 1, 0, 1, 0, 1],  # Example feature: promotional activity
}

df = pd.DataFrame(data)

# Target variable: the next period's demand. shift(-1) moves each following
# row's demand up by one position; the last row has no successor, so its NaN
# is dropped.
y = df['historical_demand'].shift(-1).dropna()

# Features: the current period's observations. Rows are selected by y's index
# rather than a hard-coded [:-1] slice, so X and y stay aligned row-for-row
# even if dropna() removes a row other than the last one.
X = df.loc[y.index, ['historical_demand', 'seasonality', 'promotion']]
24+
25+
# Split the data into training and testing sets.
# The rows are consecutive periods and y is the *next* period's demand, so the
# split has to respect time order: shuffle=False keeps the first 80% of rows
# for training and holds out the most recent 20% for testing. A shuffled split
# would train on periods that come after the ones being tested (look-ahead
# leakage) and overstate how well the model forecasts.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=False)
27+
28+
# Model selection and hyperparameter tuning using GridSearchCV.
# Every combination of the values below is evaluated (3 * 4 * 3 = 36 candidates).
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20, 30],
    'min_samples_split': [2, 5, 10],
}

# Using Random Forest Regressor; random_state is fixed for reproducible fits.
model = RandomForestRegressor(random_state=42)

# Cross-validation needs at least 2 folds and cannot use more folds than there
# are training rows. Cap the usual 5 folds at the row count and fail early with
# a clear message instead of an error from deep inside the search.
n_train_rows = len(X_train)
if n_train_rows < 2:
    raise ValueError(
        f'Need at least 2 training rows for cross-validation, got {n_train_rows}'
    )
cv_folds = min(5, n_train_rows)

# Grid search for hyperparameter tuning.
# Candidates are ranked by negated MSE (scikit-learn scorers are "higher is
# better"); n_jobs=-1 runs the fits in parallel on all available cores.
grid_search = GridSearchCV(
    estimator=model,
    param_grid=param_grid,
    cv=cv_folds,
    scoring='neg_mean_squared_error',
    verbose=2,
    n_jobs=-1,
)
grid_search.fit(X_train, y_train)
41+
42+
# Take the estimator that the grid search selected as best.
best_model = grid_search.best_estimator_

# Score it on the held-out test rows: predictions first, then both metrics.
y_pred = best_model.predict(X_test)
mse, r2 = mean_squared_error(y_test, y_pred), r2_score(y_test, y_pred)

# Report the metrics, one per line.
print(f'Mean Squared Error: {mse}', f'R^2 Score: {r2}', sep='\n')
52+
53+
# Save the model.
# joblib.dump does not create missing directories, so make sure the target
# directory exists first; otherwise running the script from a working
# directory without a models/ folder fails with FileNotFoundError.
from pathlib import Path

model_path = 'models/demand_prediction_model.pkl'
Path(model_path).parent.mkdir(parents=True, exist_ok=True)
joblib.dump(best_model, model_path)

# Report the path that was actually written, not just the bare file name.
print(f'Demand prediction model saved as {model_path}')

0 commit comments

Comments
 (0)