{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# OLS Reconciliation with Spark\n", "\n", "## Feng Li\n", "\n", "### Guanghua School of Management\n", "### Peking University\n", "\n", "\n", "### [feng.li@gsm.pku.edu.cn](feng.li@gsm.pku.edu.cn)\n", "### Course home page: [https://feng.li/bdcf](https://feng.li/bdcf)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/03/26 19:44:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] }, { "data": { "text/html": [ "\n", "
\n", "

SparkSession - in-memory

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.5.3
\n", "
Master
\n", "
local[*]
\n", "
AppName
\n", "
Spark Forecasting
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os, sys # Ensure All environment variables are properly set \n", "# os.environ[\"JAVA_HOME\"] = os.path.dirname(sys.executable)\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "os.environ[\"PYSPARK_DRIVER_PYTHON\"] = sys.executable\n", "\n", "from pyspark.sql import SparkSession # build Spark Session\n", "spark = SparkSession.builder \\\n", " .config(\"spark.ui.enabled\", \"false\") \\\n", " .config(\"spark.executor.memory\", \"16g\") \\\n", " .config(\"spark.executor.cores\", \"4\") \\\n", " .config(\"spark.cores.max\", \"32\") \\\n", " .config(\"spark.driver.memory\", \"30g\") \\\n", " .config(\"spark.sql.shuffle.partitions\", \"96\") \\\n", " .config(\"spark.memory.fraction\", \"0.8\") \\\n", " .config(\"spark.memory.storageFraction\", \"0.5\") \\\n", " .config(\"spark.dynamicAllocation.enabled\", \"true\") \\\n", " .config(\"spark.dynamicAllocation.minExecutors\", \"4\") \\\n", " .config(\"spark.dynamicAllocation.maxExecutors\", \"8\") \\\n", " .appName(\"Spark Forecasting\").getOrCreate()\n", "spark" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "train_sdf = spark.read.csv(\"../data/tourism/tourism_train.csv\", header=True, inferSchema=True)\n", "test_sdf = spark.read.csv(\"../data/tourism/tourism_test.csv\", header=True, inferSchema=True)\n", "forecast_sdf = spark.read.csv(\"../data/tourism/ets_forecasts.csv\", header=True, inferSchema=True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Visitors|\n", "+----------+---------------+------------------+\n", "|1998-01-01| TotalAll|45151.071280099975|\n", "|1998-01-01| AAll|17515.502379600006|\n", "|1998-01-01| BAll|10393.618015699998|\n", "|1998-01-01| CAll| 8633.359046599999|\n", "|1998-01-01| DAll|3504.3133462000005|\n", "|1998-01-01| EAll| 3121.6191894|\n", "|1998-01-01| FAll|1850.7357734999998|\n", "|1998-01-01| GAll|131.92352909999997|\n", "|1998-01-01| AAAll| 4977.2096105|\n", "|1998-01-01| ABAll| 5322.738721600001|\n", "|1998-01-01| ACAll|3569.6213724000004|\n", "|1998-01-01| ADAll| 1472.9706096|\n", "|1998-01-01| AEAll| 1560.5142545|\n", "|1998-01-01| AFAll| 612.447811|\n", "|1998-01-01| BAAll| 3854.672582|\n", "|1998-01-01| BBAll|1653.9957826000002|\n", "|1998-01-01| BCAll| 2138.7473162|\n", "|1998-01-01| BDAll| 1395.3775834|\n", "|1998-01-01| BEAll|1350.8247515000003|\n", "|1998-01-01| CAAll| 6421.236419000001|\n", "+----------+---------------+------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "train_sdf.show()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Visitors|\n", "+----------+---------------+------------------+\n", "|2016-01-01| TotalAll|45625.487797300026|\n", "|2016-01-01| AAll|14631.321547500002|\n", "|2016-01-01| BAll|11201.033523800006|\n", "|2016-01-01| CAll| 8495.718019999998|\n", "|2016-01-01| DAll| 3050.230027900001|\n", "|2016-01-01| EAll| 6153.800932700001|\n", "|2016-01-01| FAll| 1484.9254222|\n", "|2016-01-01| GAll| 608.4583232|\n", 
"|2016-01-01| AAAll| 3507.913989999999|\n", "|2016-01-01| ABAll| 5358.049206900001|\n", "|2016-01-01| ACAll|2816.0075377999997|\n", "|2016-01-01| ADAll|1210.4949652999999|\n", "|2016-01-01| AEAll| 1100.6643246|\n", "|2016-01-01| AFAll| 638.1915229|\n", "|2016-01-01| BAAll| 5354.437284900001|\n", "|2016-01-01| BBAll|1343.8020357999999|\n", "|2016-01-01| BCAll|1888.8476593999999|\n", "|2016-01-01| BDAll|1631.0206729999998|\n", "|2016-01-01| BEAll| 982.9258707|\n", "|2016-01-01| CAAll| 5494.1272128|\n", "+----------+---------------+------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "test_sdf.show()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Forecast|\n", "+----------+---------------+------------------+\n", "|2015-12-01| AAAAll| 2058.838101212888|\n", "|2016-01-01| AAAAll|3162.9085270260384|\n", "|2016-02-01| AAAAll|1744.1909768938476|\n", "|2016-03-01| AAAAll|2059.3010229302345|\n", "|2016-04-01| AAAAll| 2060.36170915585|\n", "|2016-05-01| AAAAll| 1972.482680954841|\n", "|2016-06-01| AAAAll| 1846.108381522378|\n", "|2016-07-01| AAAAll| 2151.959971338432|\n", "|2016-08-01| AAAAll| 1872.768734964083|\n", "|2016-09-01| AAAAll|2014.0112033543805|\n", "|2016-10-01| AAAAll|2296.4055573391643|\n", "|2016-11-01| AAAAll| 2043.693618904705|\n", "|2015-12-01| AAABus|455.55263433165527|\n", "|2016-01-01| AAABus| 296.1898590727801|\n", "|2016-02-01| AAABus| 453.3714726155002|\n", "|2016-03-01| AAABus| 525.339287022726|\n", "|2016-04-01| AAABus| 459.2040763240983|\n", "|2016-05-01| AAABus| 554.7602408402748|\n", "|2016-06-01| AAABus| 498.8249232185846|\n", "|2016-07-01| AAABus| 604.1654402560752|\n", "+----------+---------------+------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "forecast_sdf.show()" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "skipping\n" ] } ], "source": [ "%%script echo skipping\n", "\n", "# If you want to have alternative forecasts, change the following code to obtain `forecast_sdf`\n", "\n", "from pyspark.sql.functions import explode, col\n", "from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, DateType\n", "import pandas as pd\n", "from pandas.tseries.offsets import MonthBegin\n", "from statsmodels.tsa.holtwinters import ExponentialSmoothing\n", "\n", "# Define schema for the forecast output\n", "forecast_schema = StructType([\n", " StructField(\"date\", DateType(), False),\n", " StructField(\"Region_Category\", StringType(), False),\n", " StructField(\"Forecast\", DoubleType(), False)\n", "])\n", "\n", "\n", "def ets_forecast(pdf):\n", " \"\"\"Fits an ETS model for a single region and forecasts 12 months ahead.\"\"\"\n", " region = pdf[\"Region_Category\"].iloc[0] # Extract region name\n", " pdf = pdf.sort_values(\"date\") # Ensure time series is sorted\n", "\n", " try:\n", " # Drop missing values\n", " ts = pdf[\"Visitors\"].dropna()\n", "\n", " if len(ts) >= 24: # Ensure at least 24 observations\n", " model = ExponentialSmoothing(ts, trend=\"add\", seasonal=\"add\", seasonal_periods=12)\n", " fitted_model = model.fit()\n", " forecast = fitted_model.forecast(steps=12)\n", " else:\n", " forecast = [None] * 12 # Not enough data\n", " except:\n", " forecast = 
[None] * 12 # Handle errors\n", "\n", " # Adjust forecast dates to start of the month\n", " last_date = pdf[\"date\"].max()\n", " future_dates = pd.date_range(start=last_date, periods=12, freq=\"ME\") + MonthBegin(1)\n", "\n", " # Return results as a DataFrame\n", " return pd.DataFrame({\"date\": future_dates, \"Region_Category\": region, \"Forecast\": forecast})\n", "\n", "# Apply the ETS model in parallel using applyInPandas\n", "forecast_sdf = train_sdf.groupBy(\"Region_Category\").applyInPandas(ets_forecast, schema=forecast_schema)\n", "\n", "# Show forecasted results\n", "forecast_sdf.show()\n", "\n", "# Save forecasts if needed\n", "forecast_sdf.write.csv(os.path.expanduser(\"~/ets_forecasts.csv\"), header=True, mode=\"overwrite\")\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## MinT-OLS approximation\n", "\n", "Since PySpark doesn't support matrix operations like NumPy, we'll use the **MinT-OLS approximation**, which assumes the forecast error covariance matrix \\( W = I \\) (identity matrix). This simplifies the formula:\n", "\n", "$$\n", "\\tilde{y} = S (S^\\top S)^{-1} S^\\top \\hat{y}\n", "$$\n", "\n", "\n", "\n", "- `forecast_sdf` contains **base forecasts** for each `Region_Category` and `date`.\n", "- `test_sdf` is your **test set** with actual `Visitors` by `Region_Category` and `date`.\n", "- You have `summing_sdf_long` with: \n", " - `Parent_Group` \n", " - `Region_Category` \n", " - `Weight` (usually 0 or 1)\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Compute MinT-OLS reconciliation in PySpark\n", "\n", "- MinT-OLS is equivalent to solving the linear regression problem \n", "$$\n", "\\hat{y} = S\\beta + \\varepsilon\n", "$$\n", "\n", "Then: $\\tilde{y} = S\\hat{\\beta}$, which means: \n", "\n", "- You can use **`LinearRegression` from `pyspark.ml.regression`** to compute \n", "$$\n", "\\hat{\\beta} = (S^\\top S)^{-1} S^\\top \\hat{y}\n", "$$\n", "\n", "on the **bottom-level base forecasts**\n", "\n", "- Then multiply $ S \\hat{\\beta} $ to get the reconciled forecasts.\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## MinT-OLS via `LinearRegression` in PySpark\n", "\n", "\n", "- You have `forecast_sdf`: bottom-level base forecasts (`date`, `Region_Category`, `Forecast`)\n", "- You have `summing_sdf_long`: mapping from `Region_Category → Parent_Group`\n", "\n", "- For each `date`, collect base forecasts into a vector $ \\hat{y} $, fit\n", " ```python\n", " LinearRegression(labelCol=\"Forecast\", featuresCol=\"features\")\n", " ```\n", " where `features = S` (bottom-level structure) and label is each forecast\n", " \n", "- Predict: $ S \\hat{\\beta} $ to get reconciled total forecasts" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 19:44:22 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+---------------+------+\n", "|Parent_Group|Region_Category|Weight|\n", "+------------+---------------+------+\n", "| TotalAll| AAAHol| 1.0|\n", "| TotalAll| AAAVis| 1.0|\n", "| TotalAll| AAABus| 1.0|\n", "| TotalAll| AAAOth| 1.0|\n", "| TotalAll| AABHol| 1.0|\n", "| TotalAll| AABVis| 1.0|\n", "| TotalAll| AABBus| 1.0|\n", "| TotalAll| AABOth| 1.0|\n", "| TotalAll| ABAHol| 1.0|\n", "| TotalAll| ABAVis| 1.0|\n", "| TotalAll| ABABus| 1.0|\n", "| TotalAll| ABAOth| 1.0|\n", "| TotalAll| ABBHol| 1.0|\n", "| TotalAll| ABBVis| 1.0|\n", "| TotalAll| ABBBus| 1.0|\n", "| TotalAll| ABBOth| 1.0|\n", "| TotalAll| ACAHol| 1.0|\n", "| TotalAll| ACAVis| 1.0|\n", "| TotalAll| ACABus| 1.0|\n", "| TotalAll| ACAOth| 1.0|\n", "+------------+---------------+------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import col, sum as spark_sum\n", "\n", "# Load the summing matrix file\n", "summing_matrix_path = \"../data/tourism/agg_mat.csv\" # Update with actual path\n", "\n", "# Load the summing matrix file (skip the first column)\n", "summing_sdf = spark.read.csv(summing_matrix_path, header=True, inferSchema=True)\n", "\n", "\n", "# Convert from wide format to long format (Region_Category, Parent_Group, Weight)\n", "summing_sdf_long = summing_sdf.selectExpr(\n", " \"Parent_Group\",\n", " \"stack(\" + str(len(summing_sdf.columns) - 1) + \", \" +\n", " \", \".join([f\"'{col}', {col}\" for col in summing_sdf.columns if col != \"Parent_Group\"]) +\n", " \") as (Region_Category, Weight)\"\n", ")\n", "\n", "# Show the reshaped summing matrix\n", "summing_sdf_long.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from pyspark.sql.functions import row_number, first\n", "from pyspark.sql.window import Window\n", "\n", "# Add forecast step index (1-12) per Region_Category\n", "window = Window.partitionBy(\"Region_Category\").orderBy(\"date\")\n", "forecast_with_step_sdf = forecast_sdf.withColumn(\"step\", row_number().over(window))\n", "\n", "# Pivot into wide format: Region_Category × [1,...,12]\n", "forecast_wide_sdf = forecast_with_step_sdf.groupBy(\"Region_Category\").pivot(\"step\").agg(first(\"Forecast\"))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "# From summing_sdf_long: Parent_Group, Region_Category, Weight\n", "# Create design matrix where Region_Category is row, Parent_Groups are columns\n", "s_matrix_T = summing_sdf_long.groupBy(\"Region_Category\") \\\n", " .pivot(\"Parent_Group\").agg(first(\"Weight\")).fillna(0)\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 19:44:30 WARN Instrumentation: [55c02424] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:44:31 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS\n", "25/03/26 19:44:32 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK\n", "25/03/26 19:44:32 WARN Instrumentation: [55c02424] Cholesky solver failed due to singular covariance matrix. 
Retrying with Quasi-Newton solver.\n", "25/03/26 19:44:33 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:44:33 ERROR LBFGS: Failure again! Giving up and returning. Maybe the objective is just poorly behaved?\n", "25/03/26 19:44:40 WARN Instrumentation: [551fd9d3] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:44:41 WARN Instrumentation: [551fd9d3] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:44:47 WARN Instrumentation: [77ef9881] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:44:47 WARN Instrumentation: [77ef9881] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:44:47 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:44:47 ERROR LBFGS: Failure again! Giving up and returning. Maybe the objective is just poorly behaved?\n", "25/03/26 19:44:53 WARN Instrumentation: [40a25d43] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:44:54 WARN Instrumentation: [40a25d43] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:44:54 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:44:54 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:44:54 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:44:54 ERROR LBFGS: Failure again! Giving up and returning. Maybe the objective is just poorly behaved?\n", "25/03/26 19:44:59 WARN Instrumentation: [527231ad] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:00 WARN Instrumentation: [527231ad] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:06 WARN Instrumentation: [c4d24db2] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:07 WARN Instrumentation: [c4d24db2] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:07 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:45:12 WARN Instrumentation: [65615c88] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:13 WARN Instrumentation: [65615c88] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:13 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed\n", "25/03/26 19:45:13 ERROR LBFGS: Failure again! Giving up and returning. Maybe the objective is just poorly behaved?\n", "25/03/26 19:45:19 WARN Instrumentation: [bb3a2485] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:19 WARN Instrumentation: [bb3a2485] Cholesky solver failed due to singular covariance matrix. 
Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:25 WARN Instrumentation: [99ce8314] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:25 WARN Instrumentation: [99ce8314] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:31 WARN Instrumentation: [b9e78f4b] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:32 WARN Instrumentation: [b9e78f4b] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:37 WARN Instrumentation: [7af5a814] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:38 WARN Instrumentation: [7af5a814] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n", "25/03/26 19:45:44 WARN Instrumentation: [0515ac6c] regParam is zero, which might cause numerical instability and overfitting.\n", "25/03/26 19:45:44 WARN Instrumentation: [0515ac6c] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.\n" ] } ], "source": [ "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.ml.regression import LinearRegression\n", "from pyspark.sql.functions import col, lit\n", "\n", "reconciled_forecasts = []\n", "\n", "for step in range(1, 13):\n", " step_col = str(step)\n", "\n", " training_df = forecast_wide_sdf.select(\"Region_Category\", step_col) \\\n", " .join(s_matrix_T, on=\"Region_Category\", how=\"inner\")\n", "\n", " feature_cols = [c for c in training_df.columns if c not in [\"Region_Category\", step_col]]\n", "\n", " assembler = VectorAssembler(inputCols=feature_cols, outputCol=\"features\")\n", " assembled_df = assembler.transform(training_df).select(\"Region_Category\", \"features\", col(step_col).alias(\"label\"))\n", "\n", " lr = LinearRegression(featuresCol=\"features\", labelCol=\"label\")\n", " model = lr.fit(assembled_df)\n", "\n", " reconciled = model.transform(assembled_df) \\\n", " .select(\"Region_Category\", col(\"prediction\").alias(\"Forecast\")) \\\n", " .withColumn(\"step\", lit(step)) # ← Tag with forecast step number\n", "\n", " reconciled_forecasts.append(reconciled)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'/nfs-share/software/anaconda/2020.02/envs/python3.12/bin/python'" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import sys\n", "sys.executable" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/plain": [ "[DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: double, step: int],\n", " DataFrame[Region_Category: string, Forecast: 
double, step: int]]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "reconciled_forecasts" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 19:55:40 WARN DAGScheduler: Broadcasting large task binary with size 4.9 MiB\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+---------------+------------------+----+----------+\n", "|Region_Category| Forecast|step| date|\n", "+---------------+------------------+----+----------+\n", "| DDBHol| 82.07811032786951| 1|2016-01-01|\n", "| FBAVis| 6.656805280557123| 1|2016-01-01|\n", "| EABVis|414.37511192628403| 1|2016-01-01|\n", "| BCBOth| 6.196280508794359| 1|2016-01-01|\n", "| FAAHol|150.30406627752933| 1|2016-01-01|\n", "| CBCHol| 67.29236126836824| 1|2016-01-01|\n", "| GABVis| 2.693192208874464| 1|2016-01-01|\n", "| CDBHol|0.5679754672918733| 1|2016-01-01|\n", "| DBABus| 5.316852903629325| 1|2016-01-01|\n", "| BDCBus|20.895111901507477| 1|2016-01-01|\n", "| DAAVis| 274.7963041881968| 1|2016-01-01|\n", "| CAAHol| 585.8461906935038| 1|2016-01-01|\n", "| AEDVis| 50.20844923996796| 1|2016-01-01|\n", "| BBAHol| 292.4243221681629| 1|2016-01-01|\n", "| ABABus|112.69985746271419| 1|2016-01-01|\n", "| CBBVis| 90.58366645794179| 1|2016-01-01|\n", "| BDBVis| 58.23528346223553| 1|2016-01-01|\n", "| ADDVis|111.02943173753079| 1|2016-01-01|\n", "| CCAVis|22.293612674097204| 1|2016-01-01|\n", "| BEHVis|19.081634731302486| 1|2016-01-01|\n", "+---------------+------------------+----+----------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 380:> (0 + 1) / 1]\r", "\r", " \r" ] } ], "source": [ "# Union all 12 steps into one DataFrame\n", "from functools import reduce\n", "from pyspark.sql.functions import expr\n", "from pyspark.sql.functions import min as spark_min\n", "\n", "reconciled_long_df = reduce(lambda df1, df2: df1.unionByName(df2), reconciled_forecasts)\n", "\n", "# Get the first forecast date from the test set\n", "min_date_row = test_sdf.select(spark_min(\"date\").alias(\"start_date\")).collect()[0]\n", "start_date = min_date_row[\"start_date\"].strftime(\"%Y-%m-%d\")\n", "\n", "# Add forecast dates based on step number\n", "reconciled_long_df = reconciled_long_df.withColumn(\n", " \"date\", expr(f\"add_months(to_date('{start_date}'), step -1 )\")\n", ")\n", "\n", "reconciled_long_df.show()" ] }, { "cell_type": "code", "execution_count": 34, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from pyspark.sql.functions import sum as spark_sum\n", "\n", "reconciled_with_hierarchy_df = reconciled_long_df.join(\n", " summing_sdf_long, on=\"Region_Category\", how=\"inner\"\n", ")\n", "\n", "reconciled_agg_df = reconciled_with_hierarchy_df.withColumn(\n", " \"Weighted_Forecast\", col(\"Forecast\") * col(\"Weight\")\n", ").groupBy(\"Parent_Group\", \"date\").agg(\n", " spark_sum(\"Weighted_Forecast\").alias(\"Reconciled_Forecast\")\n", ")\n", "\n", "test_with_hierarchy_df = test_sdf.join(summing_sdf_long, on=\"Region_Category\", how=\"inner\")\n", "\n", "test_agg_df = test_with_hierarchy_df.withColumn(\n", " \"Weighted_Actual\", col(\"Visitors\") * col(\"Weight\")\n", ").groupBy(\"Parent_Group\", \"date\").agg(\n", " spark_sum(\"Weighted_Actual\").alias(\"Actual_Visitors\")\n", ")" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "slideshow": { "slide_type": 
"slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- Parent_Group: string (nullable = true)\n", " |-- date: date (nullable = true)\n", " |-- Reconciled_Forecast: double (nullable = true)\n", "\n", "root\n", " |-- Parent_Group: string (nullable = true)\n", " |-- date: date (nullable = true)\n", " |-- Actual_Visitors: double (nullable = true)\n", "\n" ] } ], "source": [ "# Check date column types\n", "reconciled_agg_df.printSchema()\n", "test_agg_df.printSchema()\n" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+\n", "| date|\n", "+----------+\n", "|2016-01-01|\n", "|2016-02-01|\n", "|2016-03-01|\n", "|2016-04-01|\n", "|2016-05-01|\n", "|2016-06-01|\n", "|2016-07-01|\n", "|2016-08-01|\n", "|2016-09-01|\n", "|2016-10-01|\n", "|2016-11-01|\n", "|2016-12-01|\n", "+----------+\n", "\n", "+----------+\n", "| date|\n", "+----------+\n", "|2016-01-01|\n", "|2016-02-01|\n", "|2016-03-01|\n", "|2016-04-01|\n", "|2016-05-01|\n", "|2016-06-01|\n", "|2016-07-01|\n", "|2016-08-01|\n", "|2016-09-01|\n", "|2016-10-01|\n", "|2016-11-01|\n", "|2016-12-01|\n", "+----------+\n", "\n" ] } ], "source": [ "reconciled_agg_df.select(\"date\").distinct().orderBy(\"date\").show(20)\n", "test_agg_df.select(\"date\").distinct().orderBy(\"date\").show(20)\n" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:06:24 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "[Stage 495:> (0 + 12) / 12]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+----------+-------------------+------------------+\n", "|Parent_Group| date|Reconciled_Forecast| Actual_Visitors|\n", "+------------+----------+-------------------+------------------+\n", "| GBDBus|2016-01-01| 10.705897775124726| 28.5621226|\n", "| CCBVis|2016-01-01| 127.8105261721546| 136.5528404|\n", "| BECVis|2016-01-01| 62.46237010369159| 15.9456186|\n", "| FAAll|2016-01-01| 349.9323421850451| 763.3077524999999|\n", "| GACOth|2016-01-01|0.46048233024117735| 9.8127999|\n", "| EAAOth|2016-01-01| 2.9080098567940524| 5.7992456|\n", "| DCDBus|2016-01-01| 59.450079656240774| 46.0093885|\n", "| BDBOth|2016-01-01| 6.4037361653122815| 0.0|\n", "| AECOth|2016-01-01| 41.90321160149263| 10.9008486|\n", "| DDHol|2016-01-01| 102.51594592069989|359.95267900000005|\n", "| EHol|2016-01-01| 757.846086951482| 2812.781809|\n", "| EAAVis|2016-01-01| 52.72819361817362| 30.1005963|\n", "| BCBOth|2016-01-01| 6.196280508794359| 6.2829104|\n", "| BACBus|2016-01-01| 19.415417443459482| 0.0|\n", "| DAOth|2016-01-01| 212.9483219880252| 28.8082796|\n", "| CAVis|2016-01-01| 1280.131137158587| 1554.5346126|\n", "| ABAll|2016-01-01| 2088.705531118374| 5358.049206900001|\n", "| FBBHol|2016-01-01| 65.01918861007131| 169.5561995|\n", "| CBDOth|2016-01-01|-0.9901700392533428| 0.0|\n", "| BDDHol|2016-01-01| 21.003014900333262| 29.0369307|\n", "+------------+----------+-------------------+------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:06:28 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "\r", " \r" ] } ], "source": [ "evaluation_df = reconciled_agg_df.join(test_agg_df, on=[\"Parent_Group\", \"date\"], how=\"inner\")\n", "\n", "evaluation_df.show()" ] }, { "cell_type": "code", 
"execution_count": 38, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:06:51 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "25/03/26 20:06:54 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "25/03/26 20:06:55 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+-------------------+\n", "|Parent_Group| MAPE|\n", "+------------+-------------------+\n", "| BCBOth| 1.8814916288630412|\n", "| BDEAll| 0.6319932719308724|\n", "| CCOth| 1.1371127959837342|\n", "| CCBAll|0.23007099726028646|\n", "| FBAVis| 4.193833533298544|\n", "| DCCAll| 1.0170059992164842|\n", "| DDBHol| 0.6989044320419716|\n", "| BCHol| 0.6848869709479803|\n", "| EABVis| 0.2597990646591048|\n", "| CBDAll|0.28554250724533875|\n", "| ADBAll| 0.3511071573477474|\n", "| GBCAll| 1.822907197252322|\n", "| FAAHol|0.26537276454971037|\n", "| BDFAll| 0.3958432127605984|\n", "| CBCHol| 1.09175933047297|\n", "| GABVis| 0.5505873369776451|\n", "| CAVis| 0.3025095665083698|\n", "| BDBAll| 0.6698985120653421|\n", "| BEGAll| 0.5083338077197483|\n", "| DABus|0.40018259421342434|\n", "+------------+-------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:07:11 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "25/03/26 20:07:14 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "25/03/26 20:07:15 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------------+\n", "| Overall_MAPE|\n", "+------------------+\n", "|1.1076102791937525|\n", "+------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "from pyspark.sql.functions import abs, avg\n", "\n", "evaluation_df = evaluation_df.withColumn(\n", " \"APE\", abs((col(\"Reconciled_Forecast\") - col(\"Actual_Visitors\")) / col(\"Actual_Visitors\"))\n", ")\n", "\n", "# MAPE per group\n", "mape_df = evaluation_df.groupBy(\"Parent_Group\").agg(avg(\"APE\").alias(\"MAPE\"))\n", "mape_df.show()\n", "\n", "# Overall MAPE\n", "overall_mape_df = mape_df.agg(avg(\"MAPE\").alias(\"Overall_MAPE\"))\n", "overall_mape_df.show()" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:10:28 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "25/03/26 20:10:30 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "25/03/26 20:10:31 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "25/03/26 20:10:32 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+-------+------------+------------------+\n", "|summary|Parent_Group| MAPE|\n", "+-------+------------+------------------+\n", "| count| 555| 555|\n", "| mean| NULL|1.1076102791937525|\n", "| stddev| NULL|1.3666645729588924|\n", "| min| AAAAll|0.1381354920151718|\n", "| max| TotalVis|14.210951585958131|\n", "+-------+------------+------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "mape_df.describe().show()" ] }, { "cell_type": "code", 
"execution_count": 42, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:16:10 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "25/03/26 20:16:13 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "25/03/26 20:16:14 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+------------------+\n", "|Parent_Group| MAPE|\n", "+------------+------------------+\n", "| DCDOth|14.210951585958131|\n", "| GBBBus| 12.0202762100509|\n", "| GAAOth| 8.483838150123491|\n", "| GBDVis| 7.982046555639922|\n", "| BEGOth| 7.614402179032675|\n", "| EACOth| 7.401889674658579|\n", "| DACBus| 6.429812651519794|\n", "| DCDBus| 6.02555705573783|\n", "| BEGHol| 5.946961838257663|\n", "| CCBOth| 5.468544199643766|\n", "+------------+------------------+\n", "only showing top 10 rows\n", "\n" ] } ], "source": [ "mape_df.orderBy(col(\"MAPE\").desc()).show(10)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:17:42 WARN DAGScheduler: Broadcasting large task binary with size 6.0 MiB\n", "25/03/26 20:17:44 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+-------------------+\n", "|Parent_Group| MAPE|\n", "+------------+-------------------+\n", "| TotalOth| 0.1381354920151718|\n", "| AAAAll|0.15061493949148597|\n", "| BAAAll|0.15118372389621185|\n", "| CABAll| 0.1558753344427696|\n", "| AAAll|0.15773976873635803|\n", "| ABus| 0.1640455398594557|\n", "| CAll|0.16662182767530095|\n", "| DAAAll|0.17184814368002907|\n", "| TotalBus|0.17234491596252519|\n", "| AFAAll|0.17411943563895874|\n", "| AFAll|0.17411943563895874|\n", "| CDAll| 0.1816424382790586|\n", "| DAAll|0.19273337510343946|\n", "| EABAll| 0.195400784344689|\n", "| CAAll|0.20112517708058755|\n", "| BAAll|0.20431117238316943|\n", "| EAll|0.20736614485709745|\n", "| TotalAll|0.20956702567115104|\n", "| CBAll| 0.2104987151201196|\n", "| BBus|0.21364902571913846|\n", "+------------+-------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 20:17:46 WARN DAGScheduler: Broadcasting large task binary with size 7.8 MiB\n", "\r", " \r" ] } ], "source": [ "mape_df.orderBy(\"MAPE\").show()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Multistep reconciliation with linear regression\n", "\n", "\n", "\n", "> If you **stack forecasts from multiple time steps** (e.g., 12 months) as **separate columns**, you can treat it as **a multivariate linear regression problem** with **shared design matrix $ S $**.\n", "\n", "This means:\n", "- **Each row** of $ Y $ is a **Region_Category**\n", "- **Each column** of $ Y $ is a forecast time step (e.g., 1st month, 2nd month...)\n", "- You can run **one Spark MLlib linear regression per time step**, or fit all at once if you flatten it.\n", "\n", "\n", "### What you’re doing conceptually\n", "You solve $Y = S\\beta + E$ where:\n", "- $ Y $: matrix of stacked forecasts (`Region_Category` × `horizon`)\n", "- $ S $: summing matrix (`Parent_Group` × `Region_Category`)\n", "- $ \\beta $: regression coefficients (`Region_Category` × `horizon`)\n", "- Reconciled 
forecasts: $\\tilde{Y} = S\\hat{\\beta}$\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Summary of MinT-OLS\n", "\n", "- **MinT-OLS (simple projection)** \n", "- **No covariance matrix needed** \n", "- **Coherent forecasts at Parent_Group level** \n", "- **Evaluated using MAPE**\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "python3.12", "language": "python", "name": "python3.12" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 4 }