-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathprophet_model.py
More file actions
99 lines (84 loc) · 3.89 KB
/
prophet_model.py
File metadata and controls
99 lines (84 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import pandas as pd
import numpy as np
import datetime
import matplotlib.pyplot as plt
from scalecast.Forecaster import Forecaster
from prophet import Prophet
from sklearn.base import BaseEstimator, RegressorMixin
import evaluation
import utils
import constants
from preprocess import Preprocess
# Console display settings for pandas: allow up to 500 columns and a
# 1000-character line width when DataFrames are printed.
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)
class ProphetForecaster(BaseEstimator, RegressorMixin):
    """Scikit-learn style estimator wrapping a ``prophet.Prophet`` model.

    The constructor arguments are stored unchanged as attributes and are
    forwarded to ``Prophet`` when :meth:`fit` is called. Every column of the
    feature frame passed to :meth:`fit` is registered as an extra regressor,
    and rows whose ``is_special`` column equals 1 are additionally passed to
    Prophet as holidays.
    """

    def __init__(self,
                 seasonality_mode='additive',
                 daily_seasonality=False,
                 weekly_seasonality=False,
                 seasonality_prior_scale=10.0,
                 holidays_prior_scale=10.0,
                 changepoint_prior_scale=0.05):
        """Store the Prophet hyper-parameters; no model is built yet."""
        self.seasonality_mode = seasonality_mode
        self.daily_seasonality = daily_seasonality
        self.weekly_seasonality = weekly_seasonality
        self.seasonality_prior_scale = seasonality_prior_scale
        self.holidays_prior_scale = holidays_prior_scale
        self.changepoint_prior_scale = changepoint_prior_scale
        # Populated by fit(); None means "not fitted yet".
        self.model = None

    def fit(self, x, y):
        """Build and fit a Prophet model on ``x`` (features) and ``y`` (target).

        Args:
            x: DataFrame of regressors. Its index must support
                ``tz_localize`` (the timezone is dropped before fitting) and
                it must contain an ``is_special`` column.
            y: Named Series sharing ``x``'s index; ``y.name`` is the column
                that becomes Prophet's ``y``.

        Returns:
            self, to allow call chaining.

        Raises:
            KeyError: if ``x`` has no ``is_special`` column.
        """
        data = pd.merge(x, y, left_index=True, right_index=True)
        # Drop the timezone and name the index 'ds' directly. Renaming the
        # index (instead of renaming the column after reset_index) also works
        # when the original index has no name, and ``Index.rename`` returns a
        # new object so the caller's index is never touched.
        data.index = data.index.tz_localize(None).rename('ds')
        data = data.reset_index().rename(columns={y.name: 'y'})

        # Timestamps flagged as special become Prophet holidays, all under one
        # shared label. ``assign`` builds a new frame rather than writing into
        # a slice of ``data``.
        is_special = data['is_special'] == 1
        holidays = data.loc[is_special, ['ds']].assign(holiday='holiday')

        self.model = Prophet(
            holidays=holidays,
            seasonality_mode=self.seasonality_mode,
            daily_seasonality=self.daily_seasonality,
            weekly_seasonality=self.weekly_seasonality,
            seasonality_prior_scale=self.seasonality_prior_scale,
            holidays_prior_scale=self.holidays_prior_scale,
            changepoint_prior_scale=self.changepoint_prior_scale,
        )
        # Every feature column (including 'is_special') is an extra regressor.
        for col in x.columns:
            self.model.add_regressor(col)
        self.model.fit(data)
        return self

    def predict(self, x):
        """Return the fitted model's ``yhat`` forecast for the rows of ``x``.

        Args:
            x: DataFrame with the same regressor columns used in :meth:`fit`;
                its index must support ``tz_localize``. ``x`` itself is left
                unmodified.

        Returns:
            numpy array holding the ``yhat`` column of Prophet's forecast.

        Raises:
            sklearn.exceptions.NotFittedError: if called before :meth:`fit`.
                (It subclasses both ``AttributeError`` and ``ValueError``.)
        """
        if self.model is None:
            from sklearn.exceptions import NotFittedError
            raise NotFittedError(
                "This ProphetForecaster instance is not fitted yet; "
                "call 'fit' before 'predict'."
            )
        # Work on a copy: assigning to ``x.index`` would change the caller's
        # DataFrame as a side effect of predicting.
        future = x.copy()
        future.index = future.index.tz_localize(None).rename('ds')
        future = future.reset_index()
        forecast = self.model.predict(future)
        return forecast['yhat'].values
if __name__ == "__main__":
    # Usage example: fit the forecaster on one train/test split of the
    # preprocessed data and plot the forecast against the observed test values.
    data = utils.import_preprocessed("resources/preprocessed_data.csv")

    start_train = constants.W1_LATEST_WEEK['start_train']
    start_test = constants.W1_LATEST_WEEK['start_test']
    end_long_pred = constants.W1_LATEST_WEEK['end_test']

    # These hyper-parameters were previously defined but never handed to the
    # estimator, so the example silently ran with the constructor defaults.
    params = {
        'seasonality_mode': 'additive',
        'daily_seasonality': True,
        'weekly_seasonality': True,
        'seasonality_prior_scale': 1,
        'holidays_prior_scale': 2,
        'changepoint_prior_scale': 0.01,
    }
    p = ProphetForecaster(**params)

    preprocessed = Preprocess.run(data=data.copy(deep=True),
                                  y_label=constants.DMA_NAMES[1],
                                  start_train=start_train,
                                  start_test=start_test,
                                  end_test=end_long_pred,
                                  cols_to_lag={'Rainfall depth (mm)': 12, constants.DMA_NAMES[1]: 24},
                                  cols_to_move_stat=[],
                                  window_size=24,
                                  cols_to_decompose=[],
                                  norm_method='fixed_window',
                                  labels_cluster=[])
    x_train, y_train, x_test, y_test, scalers, norm_cols, y_labels = preprocessed

    p.fit(x_train, y_train)
    y = p.predict(x_test)

    # Label both curves so observed and forecast values can be told apart.
    plt.plot(y_test.index, y_test.values, label='observed')
    plt.plot(y_test.index, y, label='forecast')
    plt.legend()
    plt.show()