{ "cells": [ { "cell_type": "markdown", "id": "c7930d23-0fb5-49e4-b25e-cc2e11195d03", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Distributed ARIMA Forecasting with Spark\n", "\n", "## Feng Li\n", "\n", "### Guanghua School of Management\n", "### Peking University\n", "\n", "\n", "### [feng.li@gsm.pku.edu.cn](feng.li@gsm.pku.edu.cn)\n", "### Course home page: [https://feng.li/bdcf](https://feng.li/bdcf)" ] }, { "cell_type": "markdown", "id": "97384b2a-089f-470f-8885-58a8c63ab745", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## A split-and-merge example using pandas and statsmodels\n", "\n", "- Split the full time series into n blocks (equal-length subseries).\n", "\n", "- Fit an ARIMA model to each block.\n", "\n", "- Collect ARIMA parameters.\n", "\n", "- Manual Forecast with ARIMA Global Estimator" ] }, { "cell_type": "code", "execution_count": 2, "id": "2fa2c56a-571d-425b-95dd-aa7329c73d0d", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
demandtime
012864.0002003-03-01 00:00:00
112389.0002003-03-01 01:00:00
212155.0002003-03-01 02:00:00
312072.0002003-03-01 03:00:00
412162.0002003-03-01 04:00:00
.........
12128715199.8572016-12-31 19:00:00
12128814503.9942016-12-31 20:00:00
12128913829.0162016-12-31 21:00:00
12129013093.2052016-12-31 22:00:00
12129112370.6392016-12-31 23:00:00
\n", "

121292 rows × 2 columns

\n", "
" ], "text/plain": [ " demand time\n", "0 12864.000 2003-03-01 00:00:00\n", "1 12389.000 2003-03-01 01:00:00\n", "2 12155.000 2003-03-01 02:00:00\n", "3 12072.000 2003-03-01 03:00:00\n", "4 12162.000 2003-03-01 04:00:00\n", "... ... ...\n", "121287 15199.857 2016-12-31 19:00:00\n", "121288 14503.994 2016-12-31 20:00:00\n", "121289 13829.016 2016-12-31 21:00:00\n", "121290 13093.205 2016-12-31 22:00:00\n", "121291 12370.639 2016-12-31 23:00:00\n", "\n", "[121292 rows x 2 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "import numpy as np\n", "from statsmodels.tsa.arima.model import ARIMA\n", "\n", "# Load the data\n", "df = pd.read_csv(\"../data/electricity/TOTAL_train.csv\", parse_dates=[\"time\"])\n", "series = df[\"demand\"].values\n", "\n", "# Split into n blocks\n", "n_blocks = 5\n", "block_size = len(series) // n_blocks\n", "blocks = [series[i*block_size:(i+1)*block_size] for i in range(n_blocks)]\n", "\n", "df" ] }, { "cell_type": "code", "execution_count": 4, "id": "ce26d036-87d4-4207-97c1-0f972ef92bfb", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ar.L1ma.L1constblock_id
00.6968680.678450139543.3491210
10.6962600.665971137549.8453471
20.6912610.661040127060.0161452
30.7110460.660685115836.9000343
40.7286990.67107194070.9486864
\n", "
" ], "text/plain": [ " ar.L1 ma.L1 const block_id\n", "0 0.696868 0.678450 139543.349121 0\n", "1 0.696260 0.665971 137549.845347 1\n", "2 0.691261 0.661040 127060.016145 2\n", "3 0.711046 0.660685 115836.900034 3\n", "4 0.728699 0.671071 94070.948686 4" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit ARIMA on each block\n", "arima_order = (1, 1, 1)\n", "params_list = []\n", "\n", "for i, block in enumerate(blocks):\n", " try:\n", " model = ARIMA(block, order=arima_order)\n", " result = model.fit()\n", " params_list.append(result.params)\n", " except Exception as e:\n", " print(f\"Block {i} failed: {e}\")\n", " params_list.append(None)\n", "\n", "# Collect parameters into DataFrame (corrected version)\n", "param_df = pd.DataFrame([p if p is not None else [np.nan]*3 for p in params_list],\n", " columns=[\"ar.L1\", \"ma.L1\", \"const\"])\n", "param_df[\"block_id\"] = range(n_blocks)\n", "\n", "param_df" ] }, { "cell_type": "code", "execution_count": 6, "id": "2efe6821-106f-46f9-97c8-857074d4d241", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Global ARIMA Parameters (via DLSA-style averaging):\n", "ar.L1 0.704827\n", "ma.L1 0.667443\n", "const 122812.211867\n", "dtype: float64\n" ] } ], "source": [ "# Drop rows with NaNs (failed models)\n", "valid_params = param_df.dropna()\n", "\n", "# Compute the average of each parameter across blocks (DLSA style)\n", "global_estimator = valid_params[[\"ar.L1\", \"ma.L1\", \"const\"]].mean()\n", "\n", "print(\"Global ARIMA Parameters (via DLSA-style averaging):\")\n", "print(global_estimator)" ] }, { "cell_type": "code", "execution_count": 7, "id": "d630a07d-813f-4a7a-a330-108d89551bd6", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [], "source": [ "# Use the last block as test set for forecasting\n", "test_block = blocks[-1]\n", "test_block_diff = np.diff(test_block) # since d=1\n", "\n", "# Extract global parameters\n", "phi = global_estimator[\"ar.L1\"]\n", "theta = global_estimator[\"ma.L1\"]\n", "const = global_estimator[\"const\"]\n", "\n", "# Initialize for recursive forecast\n", "forecast_horizon = 10\n", "y_last = test_block[-1]\n", "y_history = list(test_block)\n", "eps_history = [0] # start with zero error\n", "\n", "# Generate forecasts\n", "forecasts = []\n", "for h in range(forecast_horizon):\n", " # Simulate ARIMA(1,1,1) forecast\n", " y_prev = y_history[-1]\n", " eps_prev = eps_history[-1]\n", " eps_t = 0 # assume mean error\n", " y_diff_forecast = const + phi * (y_history[-1] - y_history[-2]) + theta * eps_prev + eps_t\n", " y_forecast = y_prev + y_diff_forecast\n", " forecasts.append(y_forecast)\n", " y_history.append(y_forecast)\n", " eps_history.append(eps_t)" ] }, { "cell_type": "code", "execution_count": 8, "id": "b9fbe2b0-cf19-48eb-90b4-bc7177bab90a", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "[136165.48530924582,\n", " 345203.71456575696,\n", " 615351.6644408321,\n", " 928571.3807523709,\n", " 1272149.230682896,\n", " 1637124.3055957702,\n", " 2017180.7172828834,\n", " 2407866.858716991,\n", " 2806045.1181367626,\n", " 3209504.022807839]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "forecasts" ] }, { "cell_type": "markdown", "id": "9e313b0f-e64b-4eb1-8d3d-bda5b97aedf4", "metadata": { 
"editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ " ## A full PySpark script \n", " \n", " - implementing the DARIMA-style blockwise ARIMA training, parameter merging, and forecasting pipeline.\n", "\n", " - This script includes:\n", "\n", " - Loading your TOTAL_train.csv file.\n", " - Splitting the data into blocks using Spark.\n", " - Fitting ARIMA(1,1,1) models blockwise with a UDF.\n", " - Merging parameters via simple averaging (DLSA-style).\n", " - Forecasting future values using the merged global model." ] }, { "cell_type": "code", "execution_count": 9, "id": "c7ae37ba-526d-4047-b162-261c714cdc01", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/03/31 20:47:10 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] }, { "data": { "text/html": [ "\n", "
\n", "

SparkSession - in-memory

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.5.4
\n", "
Master
\n", "
local[*]
\n", "
AppName
\n", "
Spark Forecasting
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os, sys # Ensure All environment variables are properly set \n", "# os.environ[\"JAVA_HOME\"] = os.path.dirname(sys.executable)\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "os.environ[\"PYSPARK_DRIVER_PYTHON\"] = sys.executable\n", "\n", "from pyspark.sql import SparkSession # build Spark Session\n", "spark = SparkSession.builder \\\n", " .config(\"spark.ui.enabled\", \"false\") \\\n", " .config(\"spark.executor.memory\", \"16g\") \\\n", " .config(\"spark.executor.cores\", \"4\") \\\n", " .config(\"spark.cores.max\", \"32\") \\\n", " .config(\"spark.driver.memory\", \"30g\") \\\n", " .config(\"spark.sql.shuffle.partitions\", \"96\") \\\n", " .config(\"spark.memory.fraction\", \"0.8\") \\\n", " .config(\"spark.memory.storageFraction\", \"0.5\") \\\n", " .config(\"spark.dynamicAllocation.enabled\", \"true\") \\\n", " .config(\"spark.dynamicAllocation.minExecutors\", \"4\") \\\n", " .config(\"spark.dynamicAllocation.maxExecutors\", \"8\") \\\n", " .appName(\"Spark Forecasting\").getOrCreate()\n", "spark" ] }, { "cell_type": "code", "execution_count": 12, "id": "7c1edeef-af6b-4aa8-9980-83b2fd971957", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/fli/.local/miniforge3/lib/python3.12/site-packages/statsmodels/tsa/statespace/sarimax.py:966: UserWarning: Non-stationary starting autoregressive parameters found. Using zeros as starting parameters.\n", " warn('Non-stationary starting autoregressive parameters'\n", "/home/fli/.local/miniforge3/lib/python3.12/site-packages/statsmodels/tsa/statespace/sarimax.py:978: UserWarning: Non-invertible starting MA parameters found. 
Using zeros as starting parameters.\n", " warn('Non-invertible starting MA parameters found.'\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+--------+--------------------+--------------------+\n", "|block_id| series| params|\n", "+--------+--------------------+--------------------+\n", "| 3|[12072.0, 15213.0...|[0.27383518831865...|\n", "| 4|[12162.0, 15646.0...|[0.27356099915817...|\n", "| 1|[12389.0, 13238.0...|[0.27318032124865...|\n", "| 2|[12155.0, 14191.0...|[0.27372551206138...|\n", "| 0|[12864.0, 12569.0...|[0.27322740803644...|\n", "+--------+--------------------+--------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] } ], "source": [ "from pyspark.sql.functions import col, monotonically_increasing_id, floor, collect_list, udf\n", "from pyspark.sql.types import ArrayType, DoubleType\n", "import pandas as pd\n", "import numpy as np\n", "from statsmodels.tsa.arima.model import ARIMA\n", "\n", "# Step 1: Load data\n", "df = spark.read.csv(\"../data/electricity/TOTAL_train.csv\", header=True, inferSchema=True)\n", "\n", "# Step 2: Assign block IDs\n", "n_blocks = 5\n", "df = df.withColumn(\"row_id\", monotonically_increasing_id())\n", "df = df.withColumn(\"block_id\", floor(col(\"row_id\") % n_blocks))\n", "\n", "# Step 3: Group and collect series per block\n", "grouped_df = df.groupBy(\"block_id\").agg(collect_list(\"demand\").alias(\"series\"))\n", "\n", "# Step 4: Define ARIMA UDF\n", "def fit_arima(series):\n", " try:\n", " model = ARIMA(series, order=(1,1,1))\n", " result = model.fit()\n", " return result.params.tolist()\n", " except:\n", " return [float(\"nan\")] * 3\n", "\n", "fit_arima_udf = udf(fit_arima, ArrayType(DoubleType()))\n", "\n", "# Step 5: Apply UDF to each block\n", "arima_params_df = grouped_df.withColumn(\"params\", fit_arima_udf(col(\"series\")))\n", "arima_params_df.show()" ] }, { "cell_type": "code", "execution_count": 13, "id": "9c641212-b370-4aaa-8b22-b8dae1bf0261", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 11:> (0 + 1) / 1]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Global ARIMA Parameters (DLSA-style):\n", " ar.L1 2.735059e-01\n", "ma.L1 -9.661642e-01\n", "const 6.174971e+06\n", "dtype: float64\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " " ] } ], "source": [ "# Step 6: Convert to Pandas for merging\n", "params_pd = arima_params_df.select(\"block_id\", \"params\").toPandas()\n", "params_df = pd.DataFrame(params_pd[\"params\"].tolist(), columns=[\"ar.L1\", \"ma.L1\", \"const\"])\n", "params_df[\"block_id\"] = params_pd[\"block_id\"]\n", "\n", "# Step 7: Merge step (average parameters)\n", "valid_params = params_df.dropna()\n", "global_estimator = valid_params[[\"ar.L1\", \"ma.L1\", \"const\"]].mean()\n", "print(\"Global ARIMA Parameters (DLSA-style):\\n\", global_estimator)" ] }, { "cell_type": "code", "execution_count": 14, "id": "256037e2-ba0a-42cb-b99f-6da94ba0c576", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Global ARIMA Parameters (DLSA-style):\n", " ar.L1 2.735059e-01\n", "ma.L1 -9.661642e-01\n", "const 6.174971e+06\n", "dtype: float64\n", "Forecasts: [6187143.597192929, 14050950.928765967, 22376719.1027173, 30828836.2860519, 39315510.66711336, 47811636.64514864, 56310347.69058615, 64809765.76717326, 73309377.2209412, 
81809041.56450628]\n" ] } ], "source": [ "# Step 8: Forecast using the global parameters\n", "df_pd = df.orderBy(\"row_id\").toPandas()\n", "series = df_pd[\"demand\"].values\n", "test_block = series[-(len(series)//n_blocks):]\n", "phi, theta, const = global_estimator\n", "\n", "# Manual forecast\n", "forecast_horizon = 10\n", "y_history = list(test_block)\n", "eps_history = [0]\n", "forecasts_darima = []\n", "\n", "for _ in range(forecast_horizon):\n", " y_prev = y_history[-1]\n", " y_diff = y_history[-1] - y_history[-2]\n", " eps_prev = eps_history[-1]\n", "\n", " y_diff_forecast = const + phi * y_diff + theta * eps_prev\n", " y_forecast = y_prev + y_diff_forecast\n", "\n", " forecasts_darima.append(y_forecast)\n", " y_history.append(y_forecast)\n", " eps_history.append(0)\n", "\n", "print(\"Forecasts:\", forecasts_darima)" ] }, { "cell_type": "markdown", "id": "d12189a5-5fd7-4811-b49b-95c782d6e5e4", "metadata": { "editable": true, "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "## Discussions\n", "\n", "- The above version is the poor man's DARIMA\n", "\n", "- Full DARIMA implementation https://github.com/xqnwang/darima" ] } ], "metadata": { "kernelspec": { "display_name": "Python3.12 (PySpark3.5.4)", "language": "python", "name": "pyspark" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.8" } }, "nbformat": 4, "nbformat_minor": 5 }