Distributed ARIMA Forecasting with Spark¶
Feng Li¶
Guanghua School of Management¶
Peking University¶
feng.li@gsm.pku.edu.cn¶
Course home page: https://feng.li/bdcf¶
A split-and-merge example using pandas and statsmodels¶
- Split the full time series into n blocks (equal-length subseries).
- Fit an ARIMA model to each block.
- Collect the ARIMA parameters from every block.
- Average the parameters into a global estimator and use it to forecast manually.
In [2]:
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
# Load the data
df = pd.read_csv("../data/electricity/TOTAL_train.csv", parse_dates=["time"])
series = df["demand"].values
# Split into n equal-length blocks (any remainder rows at the end are dropped)
n_blocks = 5
block_size = len(series) // n_blocks
blocks = [series[i*block_size:(i+1)*block_size] for i in range(n_blocks)]
df
Out[2]:
|        | demand    | time                |
|--------|-----------|---------------------|
| 0      | 12864.000 | 2003-03-01 00:00:00 |
| 1      | 12389.000 | 2003-03-01 01:00:00 |
| 2      | 12155.000 | 2003-03-01 02:00:00 |
| 3      | 12072.000 | 2003-03-01 03:00:00 |
| 4      | 12162.000 | 2003-03-01 04:00:00 |
| ...    | ...       | ...                 |
| 121287 | 15199.857 | 2016-12-31 19:00:00 |
| 121288 | 14503.994 | 2016-12-31 20:00:00 |
| 121289 | 13829.016 | 2016-12-31 21:00:00 |
| 121290 | 13093.205 | 2016-12-31 22:00:00 |
| 121291 | 12370.639 | 2016-12-31 23:00:00 |

121292 rows × 2 columns
In [4]:
# Fit ARIMA on each block
arima_order = (1, 1, 1)
params_list = []
for i, block in enumerate(blocks):
    try:
        model = ARIMA(block, order=arima_order)
        result = model.fit()
        params_list.append(result.params)
    except Exception as e:
        print(f"Block {i} failed: {e}")
        params_list.append(None)

# Collect parameters into a DataFrame.
# NOTE: with d = 1 and no trend, statsmodels returns the parameters in the
# order [ar.L1, ma.L1, sigma2], so the "const" column below actually holds
# the innovation variance sigma2, not an intercept.
param_df = pd.DataFrame([p if p is not None else [np.nan] * 3 for p in params_list],
                        columns=["ar.L1", "ma.L1", "const"])
param_df["block_id"] = range(n_blocks)
param_df
Out[4]:
|   | ar.L1    | ma.L1    | const         | block_id |
|---|----------|----------|---------------|----------|
| 0 | 0.696868 | 0.678450 | 139543.349121 | 0        |
| 1 | 0.696260 | 0.665971 | 137549.845347 | 1        |
| 2 | 0.691261 | 0.661040 | 127060.016145 | 2        |
| 3 | 0.711046 | 0.660685 | 115836.900034 | 3        |
| 4 | 0.728699 | 0.671071 | 94070.948686  | 4        |
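Because every block returns a parameter vector of the same shape, the merge step below is just the equal-weight blockwise mean

$$\hat\theta_{\text{global}} = \frac{1}{K}\sum_{k=1}^{K}\hat\theta_k, \qquad K = 5,$$

the simplest (unweighted) stand-in for the DLSA combination discussed at the end.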
In [6]:
# Drop rows with NaNs (failed models)
valid_params = param_df.dropna()
# Compute the average of each parameter across blocks (DLSA style)
global_estimator = valid_params[["ar.L1", "ma.L1", "const"]].mean()
print("Global ARIMA Parameters (via DLSA-style averaging):")
print(global_estimator)
Global ARIMA Parameters (via DLSA-style averaging):
ar.L1         0.704827
ma.L1         0.667443
const    122812.211867
dtype: float64
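The manual forecast in the next cell iterates the one-step-ahead recursion of an ARIMA(1,1,1) with a constant. Writing $\Delta y_t = y_t - y_{t-1}$, the model is

$$\Delta y_t = c + \phi\,\Delta y_{t-1} + \theta\,\varepsilon_{t-1} + \varepsilon_t,$$

so each forecast step sets $\varepsilon_t$ to its mean of zero and adds the predicted difference back onto the last level. Keep in mind that the code plugs the averaged third parameter in as $c$, even though, as noted above, it is really $\sigma^2$.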
In [7]:
# Use the last block as test set for forecasting
test_block = blocks[-1]
test_block_diff = np.diff(test_block)  # differenced series (d = 1); not used below
# Extract global parameters
phi = global_estimator["ar.L1"]
theta = global_estimator["ma.L1"]
const = global_estimator["const"]
# Initialize for recursive forecast
forecast_horizon = 10
y_history = list(test_block)
eps_history = [0]  # start with zero error
# Generate forecasts
forecasts = []
for h in range(forecast_horizon):
    # One step of the ARIMA(1,1,1) recursion
    y_prev = y_history[-1]
    eps_prev = eps_history[-1]
    eps_t = 0  # assume mean (zero) error
    y_diff_forecast = const + phi * (y_history[-1] - y_history[-2]) + theta * eps_prev + eps_t
    y_forecast = y_prev + y_diff_forecast
    forecasts.append(y_forecast)
    y_history.append(y_forecast)
    eps_history.append(eps_t)
In [8]:
forecasts
Out[8]:
[136165.48530924582, 345203.71456575696, 615351.6644408321, 928571.3807523709,
 1272149.230682896, 1637124.3055957702, 2017180.7172828834, 2407866.858716991,
 2806045.1181367626, 3209504.022807839]

These forecasts diverge because the averaged "const" is really the innovation variance $\sigma^2$ (about 1.2e5), so every step adds a large spurious drift to the level.
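A cleaner way to forecast with the merged coefficients is to hand them back to statsmodels instead of hand-rolling the recursion. This is a minimal sketch, assuming the merged vector keeps statsmodels' ordering [ar.L1, ma.L1, sigma2]; since the ARIMA model is a state-space MLEModel, `.filter(params)` evaluates it at fixed parameters without re-estimating:

```python
import numpy as np
from statsmodels.tsa.arima.model import ARIMA

# Merged global parameters in statsmodels order: [ar.L1, ma.L1, sigma2]
global_params = np.array([0.704827, 0.667443, 122812.211867])

# Evaluate the ARIMA(1,1,1) at the fixed global parameters (no refitting),
# then forecast 10 steps ahead from the end of the test block.
model = ARIMA(test_block, order=(1, 1, 1))
fixed = model.filter(global_params)
print(fixed.forecast(steps=10))
```

Because sigma2 only scales the forecast intervals, this route also sidesteps the drift problem above.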
A full PySpark script¶
A complete pipeline implementing DARIMA-style blockwise ARIMA training, parameter merging, and forecasting.
This script includes:
- Loading the TOTAL_train.csv file.
- Splitting the data into blocks using Spark.
- Fitting ARIMA(1,1,1) models blockwise with a UDF.
- Merging parameters via simple averaging (DLSA-style).
- Forecasting future values using the merged global model.
In [9]:
import os, sys  # ensure all environment variables are properly set
# os.environ["JAVA_HOME"] = os.path.dirname(sys.executable)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
from pyspark.sql import SparkSession # build Spark Session
spark = SparkSession.builder \
.config("spark.ui.enabled", "false") \
.config("spark.executor.memory", "16g") \
.config("spark.executor.cores", "4") \
.config("spark.cores.max", "32") \
.config("spark.driver.memory", "30g") \
.config("spark.sql.shuffle.partitions", "96") \
.config("spark.memory.fraction", "0.8") \
.config("spark.memory.storageFraction", "0.5") \
.config("spark.dynamicAllocation.enabled", "true") \
.config("spark.dynamicAllocation.minExecutors", "4") \
.config("spark.dynamicAllocation.maxExecutors", "8") \
.appName("Spark Forecasting").getOrCreate()
spark
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/31 20:47:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Out[9]:
SparkSession - in-memory
In [12]:
from pyspark.sql.functions import col, monotonically_increasing_id, floor, collect_list, udf
from pyspark.sql.types import ArrayType, DoubleType
import pandas as pd
import numpy as np
from statsmodels.tsa.arima.model import ARIMA
# Step 1: Load data
df = spark.read.csv("../data/electricity/TOTAL_train.csv", header=True, inferSchema=True)
# Step 2: Assign block IDs
# NOTE: row_id % n_blocks assigns rows round-robin, so each block is an
# interleaved (every 5th observation) subseries rather than a contiguous
# segment as in the pandas example above; collect_list also does not
# guarantee row order. This is why the estimates below differ from the
# pandas version. A contiguous variant is sketched after the output.
n_blocks = 5
df = df.withColumn("row_id", monotonically_increasing_id())
df = df.withColumn("block_id", floor(col("row_id") % n_blocks))
# Step 3: Group and collect series per block
grouped_df = df.groupBy("block_id").agg(collect_list("demand").alias("series"))
# Step 4: Define ARIMA UDF
# (statsmodels must be installed on every executor)
def fit_arima(series):
    try:
        model = ARIMA(series, order=(1, 1, 1))
        result = model.fit()
        return result.params.tolist()
    except Exception:
        return [float("nan")] * 3

fit_arima_udf = udf(fit_arima, ArrayType(DoubleType()))
# Step 5: Apply UDF to each block
arima_params_df = grouped_df.withColumn("params", fit_arima_udf(col("series")))
arima_params_df.show()
/home/fli/.local/miniforge3/lib/python3.12/site-packages/statsmodels/tsa/statespace/sarimax.py:966: UserWarning: Non-stationary starting autoregressive parameters found. Using zeros as starting parameters.
  warn('Non-stationary starting autoregressive parameters'
/home/fli/.local/miniforge3/lib/python3.12/site-packages/statsmodels/tsa/statespace/sarimax.py:978: UserWarning: Non-invertible starting MA parameters found. Using zeros as starting parameters.
  warn('Non-invertible starting MA parameters found.'
+--------+--------------------+--------------------+
|block_id|              series|              params|
+--------+--------------------+--------------------+
|       3|[12072.0, 15213.0...|[0.27383518831865...|
|       4|[12162.0, 15646.0...|[0.27356099915817...|
|       1|[12389.0, 13238.0...|[0.27318032124865...|
|       2|[12155.0, 14191.0...|[0.27372551206138...|
|       0|[12864.0, 12569.0...|[0.27322740803644...|
+--------+--------------------+--------------------+
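For a faithful port of the pandas split, the block IDs should mark contiguous segments and each collected series should be explicitly ordered. A minimal sketch, assuming row_id is a gap-free 0..N-1 index (monotonically_increasing_id only guarantees this for single-partition data; zipWithIndex is safer in general):

```python
from pyspark.sql.functions import (col, floor, least, lit, collect_list,
                                   sort_array, struct, expr)

# Contiguous blocks: rows 0..block_size-1 -> block 0, and so on; any
# remainder rows are folded into the last block via least().
n_rows = df.count()
block_size = n_rows // n_blocks
df_contig = df.withColumn(
    "block_id", least(floor(col("row_id") / block_size), lit(n_blocks - 1))
)

# Collect each block in time order: sort (row_id, demand) structs by row_id,
# then keep only the demand values.
grouped_contig = (
    df_contig.groupBy("block_id")
    .agg(sort_array(collect_list(struct("row_id", "demand"))).alias("rows"))
    .withColumn("series", expr("transform(rows, r -> r.demand)"))
)
```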
In [13]:
# Step 6: Convert to Pandas for merging
params_pd = arima_params_df.select("block_id", "params").toPandas()
# As before, the third returned parameter is sigma2, not an intercept
params_df = pd.DataFrame(params_pd["params"].tolist(), columns=["ar.L1", "ma.L1", "const"])
params_df["block_id"] = params_pd["block_id"]
# Step 7: Merge step (average parameters)
valid_params = params_df.dropna()
global_estimator = valid_params[["ar.L1", "ma.L1", "const"]].mean()
print("Global ARIMA Parameters (DLSA-style):\n", global_estimator)
Global ARIMA Parameters (DLSA-style):
ar.L1    2.735059e-01
ma.L1   -9.661642e-01
const    6.174971e+06
dtype: float64
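As an aside, Spark's grouped-map pandas UDFs offer a more idiomatic route than collect_list plus a Python UDF: each group arrives as a pandas DataFrame that can be sorted explicitly, which sidesteps the ordering caveat above. A minimal sketch, assuming Spark >= 3.0 and the hypothetical helper name fit_block:

```python
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA

def fit_block(pdf: pd.DataFrame) -> pd.DataFrame:
    # Sort within the block so the series is in time order before fitting
    pdf = pdf.sort_values("row_id")
    result = ARIMA(pdf["demand"].values, order=(1, 1, 1)).fit()
    ar, ma, sigma2 = result.params
    return pd.DataFrame({"block_id": [pdf["block_id"].iloc[0]],
                         "ar_L1": [ar], "ma_L1": [ma], "sigma2": [sigma2]})

schema = "block_id long, ar_L1 double, ma_L1 double, sigma2 double"
params_spark = df.groupBy("block_id").applyInPandas(fit_block, schema)
```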
In [14]:
# Step 8: Forecast using the global parameters
df_pd = df.orderBy("row_id").toPandas()
series = df_pd["demand"].values
# Last 20% of the ordered series (not the same as Spark block 4, which is interleaved)
test_block = series[-(len(series) // n_blocks):]
phi, theta, const = global_estimator
# Manual forecast (same recursion as the pandas example)
forecast_horizon = 10
y_history = list(test_block)
eps_history = [0]
forecasts_darima = []
for _ in range(forecast_horizon):
    y_prev = y_history[-1]
    y_diff = y_history[-1] - y_history[-2]
    eps_prev = eps_history[-1]
    y_diff_forecast = const + phi * y_diff + theta * eps_prev
    y_forecast = y_prev + y_diff_forecast
    forecasts_darima.append(y_forecast)
    y_history.append(y_forecast)
    eps_history.append(0)

print("Forecasts:", forecasts_darima)
Forecasts: [6187143.597192929, 14050950.928765967, 22376719.1027173, 30828836.2860519,
 39315510.66711336, 47811636.64514864, 56310347.69058615, 64809765.76717326,
 73309377.2209412, 81809041.56450628]

As in the pandas example the path explodes: the merged "const" is really $\sigma^2$ (about 6.2e6 for the interleaved blocks), and adding it as a drift at every step dominates the forecast.
Discussions¶
The version above is a poor man's DARIMA: it merges the block estimates by a simple unweighted average, whereas the DLSA combination weights each block's estimate by its estimated precision.
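A minimal sketch of that weighted merge, assuming each block returns its parameter estimate together with a covariance estimate (result.cov_params() in statsmodels); the names dlsa_merge, thetas, and covs are illustrative, not part of the darima package:

```python
import numpy as np

def dlsa_merge(thetas, covs):
    """Precision-weighted combination of block estimates (DLSA-style).

    thetas: list of parameter vectors, one per block
    covs:   list of matching covariance matrices (e.g. result.cov_params())
    """
    precisions = [np.linalg.inv(S) for S in covs]            # inverse covariances
    total_precision = np.sum(precisions, axis=0)
    weighted_sum = np.sum([P @ t for P, t in zip(precisions, thetas)], axis=0)
    return np.linalg.solve(total_precision, weighted_sum)    # global estimator
```

Blocks with noisier (higher-variance) estimates are thereby down-weighted instead of counting equally.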
Full DARIMA implementation: https://github.com/xqnwang/darima