OLS Reconciliation with Spark¶
Feng Li¶
Guanghua School of Management¶
Peking University¶
feng.li@gsm.pku.edu.cn¶
Course home page: https://feng.li/bdcf¶
import os, sys # Ensure all environment variables are properly set
# os.environ["JAVA_HOME"] = os.path.dirname(sys.executable)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
from pyspark.sql import SparkSession # build Spark Session
spark = SparkSession.builder \
    .config("spark.ui.enabled", "false") \
    .config("spark.executor.memory", "16g") \
    .config("spark.executor.cores", "4") \
    .config("spark.cores.max", "32") \
    .config("spark.driver.memory", "30g") \
    .config("spark.sql.shuffle.partitions", "96") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.5") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.dynamicAllocation.minExecutors", "4") \
    .config("spark.dynamicAllocation.maxExecutors", "8") \
    .appName("Spark Forecasting").getOrCreate()
spark
Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/26 19:44:15 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
SparkSession - in-memory
train_sdf = spark.read.csv("../data/tourism/tourism_train.csv", header=True, inferSchema=True)
test_sdf = spark.read.csv("../data/tourism/tourism_test.csv", header=True, inferSchema=True)
forecast_sdf = spark.read.csv("../data/tourism/ets_forecasts.csv", header=True, inferSchema=True)
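A couple of quick sanity checks can catch path or schema problems early (illustrative; the exact counts depend on the files):
# Row counts and the number of distinct series in the training data
print(train_sdf.count(), "training rows")
print(train_sdf.select("Region_Category").distinct().count(), "series")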
train_sdf.show()
+----------+---------------+------------------+
|      date|Region_Category|          Visitors|
+----------+---------------+------------------+
|1998-01-01|       TotalAll|45151.071280099975|
|1998-01-01|           AAll|17515.502379600006|
|1998-01-01|           BAll|10393.618015699998|
|1998-01-01|           CAll| 8633.359046599999|
|1998-01-01|           DAll|3504.3133462000005|
|1998-01-01|           EAll|      3121.6191894|
|1998-01-01|           FAll|1850.7357734999998|
|1998-01-01|           GAll|131.92352909999997|
|1998-01-01|          AAAll|      4977.2096105|
|1998-01-01|          ABAll| 5322.738721600001|
|1998-01-01|          ACAll|3569.6213724000004|
|1998-01-01|          ADAll|      1472.9706096|
|1998-01-01|          AEAll|      1560.5142545|
|1998-01-01|          AFAll|        612.447811|
|1998-01-01|          BAAll|       3854.672582|
|1998-01-01|          BBAll|1653.9957826000002|
|1998-01-01|          BCAll|      2138.7473162|
|1998-01-01|          BDAll|      1395.3775834|
|1998-01-01|          BEAll|1350.8247515000003|
|1998-01-01|          CAAll| 6421.236419000001|
+----------+---------------+------------------+
only showing top 20 rows
test_sdf.show()
+----------+---------------+------------------+
|      date|Region_Category|          Visitors|
+----------+---------------+------------------+
|2016-01-01|       TotalAll|45625.487797300026|
|2016-01-01|           AAll|14631.321547500002|
|2016-01-01|           BAll|11201.033523800006|
|2016-01-01|           CAll| 8495.718019999998|
|2016-01-01|           DAll| 3050.230027900001|
|2016-01-01|           EAll| 6153.800932700001|
|2016-01-01|           FAll|      1484.9254222|
|2016-01-01|           GAll|       608.4583232|
|2016-01-01|          AAAll| 3507.913989999999|
|2016-01-01|          ABAll| 5358.049206900001|
|2016-01-01|          ACAll|2816.0075377999997|
|2016-01-01|          ADAll|1210.4949652999999|
|2016-01-01|          AEAll|      1100.6643246|
|2016-01-01|          AFAll|       638.1915229|
|2016-01-01|          BAAll| 5354.437284900001|
|2016-01-01|          BBAll|1343.8020357999999|
|2016-01-01|          BCAll|1888.8476593999999|
|2016-01-01|          BDAll|1631.0206729999998|
|2016-01-01|          BEAll|       982.9258707|
|2016-01-01|          CAAll|      5494.1272128|
+----------+---------------+------------------+
only showing top 20 rows
forecast_sdf.show()
+----------+---------------+------------------+
|      date|Region_Category|          Forecast|
+----------+---------------+------------------+
|2015-12-01|         AAAAll| 2058.838101212888|
|2016-01-01|         AAAAll|3162.9085270260384|
|2016-02-01|         AAAAll|1744.1909768938476|
|2016-03-01|         AAAAll|2059.3010229302345|
|2016-04-01|         AAAAll|  2060.36170915585|
|2016-05-01|         AAAAll| 1972.482680954841|
|2016-06-01|         AAAAll| 1846.108381522378|
|2016-07-01|         AAAAll| 2151.959971338432|
|2016-08-01|         AAAAll| 1872.768734964083|
|2016-09-01|         AAAAll|2014.0112033543805|
|2016-10-01|         AAAAll|2296.4055573391643|
|2016-11-01|         AAAAll| 2043.693618904705|
|2015-12-01|         AAABus|455.55263433165527|
|2016-01-01|         AAABus| 296.1898590727801|
|2016-02-01|         AAABus| 453.3714726155002|
|2016-03-01|         AAABus|  525.339287022726|
|2016-04-01|         AAABus| 459.2040763240983|
|2016-05-01|         AAABus| 554.7602408402748|
|2016-06-01|         AAABus| 498.8249232185846|
|2016-07-01|         AAABus| 604.1654402560752|
+----------+---------------+------------------+
only showing top 20 rows
%%script echo skipping
# If you want alternative forecasts, change the following code to obtain `forecast_sdf`
from pyspark.sql.functions import explode, col
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, DateType
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from statsmodels.tsa.holtwinters import ExponentialSmoothing

# Define the schema for the forecast output
forecast_schema = StructType([
    StructField("date", DateType(), False),
    StructField("Region_Category", StringType(), False),
    StructField("Forecast", DoubleType(), False)
])

def ets_forecast(pdf):
    """Fit an ETS model for a single region and forecast 12 months ahead."""
    region = pdf["Region_Category"].iloc[0]  # Extract region name
    pdf = pdf.sort_values("date")            # Ensure the time series is sorted
    try:
        ts = pdf["Visitors"].dropna()        # Drop missing values
        if len(ts) >= 24:                    # Require at least two full seasonal cycles
            model = ExponentialSmoothing(ts, trend="add", seasonal="add", seasonal_periods=12)
            fitted_model = model.fit()
            forecast = fitted_model.forecast(steps=12)
        else:
            forecast = [None] * 12           # Not enough data
    except Exception:
        forecast = [None] * 12               # Handle fitting errors

    # Shift forecast dates to the start of each month after the last observation
    last_date = pdf["date"].max()
    future_dates = pd.date_range(start=last_date, periods=12, freq="ME") + MonthBegin(1)

    # Return results as a DataFrame
    return pd.DataFrame({"date": future_dates, "Region_Category": region, "Forecast": forecast})

# Apply the ETS model in parallel using applyInPandas
forecast_sdf = train_sdf.groupBy("Region_Category").applyInPandas(ets_forecast, schema=forecast_schema)

# Show forecasted results
forecast_sdf.show()

# Save forecasts if needed
forecast_sdf.write.csv(os.path.expanduser("~/ets_forecasts.csv"), header=True, mode="overwrite")
skipping
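If you regenerate the forecasts with the cell above, Spark writes them as a directory of CSV part files; they can be read back the same way the course data was loaded (using the same path as in the save step):
# Reload the saved forecasts into a Spark DataFrame
forecast_sdf = spark.read.csv(os.path.expanduser("~/ets_forecasts.csv"),
                              header=True, inferSchema=True)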
MinT-OLS approximation¶
Since PySpark does not provide dense matrix operations like NumPy, we use the MinT-OLS approximation, which assumes the forecast error covariance matrix $W = I$ (the identity matrix). The reconciliation formula then simplifies to:
$$ \tilde{y} = S (S^\top S)^{-1} S^\top \hat{y} $$
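Before the Spark implementation, a minimal NumPy sketch on a toy hierarchy (Total = A + B; not the tourism data) shows what this projection does:
import numpy as np

S = np.array([[1, 1],   # Total
              [1, 0],   # A
              [0, 1]])  # B
y_hat = np.array([100.0, 40.0, 70.0])  # incoherent base forecasts: 40 + 70 != 100

P = S @ np.linalg.inv(S.T @ S) @ S.T   # OLS projection matrix
y_tilde = P @ y_hat                    # reconciled forecasts

print(y_tilde)                                           # approx. [103.33, 36.67, 66.67]
print(np.isclose(y_tilde[0], y_tilde[1] + y_tilde[2]))   # True: coherent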
- `forecast_sdf` contains base forecasts for each `Region_Category` and `date`.
- `test_sdf` is the test set with actual `Visitors` by `Region_Category` and `date`.
- `summing_sdf_long` describes the hierarchy with columns `Parent_Group`, `Region_Category`, and `Weight` (usually 0 or 1).
Compute MinT-OLS reconciliation in PySpark¶
- MinT-OLS is equivalent to solving the linear regression problem
$$ \hat{y} = S\beta + \varepsilon $$
Then $\tilde{y} = S\hat{\beta}$, which means:
- You can use `LinearRegression` from `pyspark.ml.regression` to compute
$$ \hat{\beta} = (S^\top S)^{-1} S^\top \hat{y} $$
on the bottom-level base forecasts
- Then multiply $S\hat{\beta}$ to get the reconciled forecasts.
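Why the regression gives the projection: minimizing $\lVert \hat{y} - S\beta \rVert_2^2$ yields the normal equations
$$ S^\top S \hat{\beta} = S^\top \hat{y} \quad\Rightarrow\quad \hat{\beta} = (S^\top S)^{-1} S^\top \hat{y}, $$
so $\tilde{y} = S\hat{\beta} = S (S^\top S)^{-1} S^\top \hat{y}$, which is exactly the MinT-OLS formula above.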
MinT-OLS via LinearRegression in PySpark¶
- You have `forecast_sdf`: bottom-level base forecasts (`date`, `Region_Category`, `Forecast`)
- You have `summing_sdf_long`: mapping from `Region_Category` → `Parent_Group`
- For each `date`, collect base forecasts into a vector $\hat{y}$ and fit `LinearRegression(labelCol="Forecast", featuresCol="features")`, where `features = S` (the bottom-level structure) and the label is each forecast
- Predict $S\hat{\beta}$ to get the reconciled total forecasts
from pyspark.sql.functions import col, sum as spark_sum

# Load the summing matrix file
summing_matrix_path = "../data/tourism/agg_mat.csv"  # Update with the actual path
summing_sdf = spark.read.csv(summing_matrix_path, header=True, inferSchema=True)

# Convert from wide format to long format (Parent_Group, Region_Category, Weight)
summing_sdf_long = summing_sdf.selectExpr(
    "Parent_Group",
    "stack(" + str(len(summing_sdf.columns) - 1) + ", " +
    ", ".join([f"'{c}', {c}" for c in summing_sdf.columns if c != "Parent_Group"]) +
    ") as (Region_Category, Weight)"
)

# Show the reshaped summing matrix
summing_sdf_long.show()
25/03/26 19:44:22 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
+------------+---------------+------+
|Parent_Group|Region_Category|Weight|
+------------+---------------+------+
|    TotalAll|         AAAHol|   1.0|
|    TotalAll|         AAAVis|   1.0|
|    TotalAll|         AAABus|   1.0|
|    TotalAll|         AAAOth|   1.0|
|    TotalAll|         AABHol|   1.0|
|    TotalAll|         AABVis|   1.0|
|    TotalAll|         AABBus|   1.0|
|    TotalAll|         AABOth|   1.0|
|    TotalAll|         ABAHol|   1.0|
|    TotalAll|         ABAVis|   1.0|
|    TotalAll|         ABABus|   1.0|
|    TotalAll|         ABAOth|   1.0|
|    TotalAll|         ABBHol|   1.0|
|    TotalAll|         ABBVis|   1.0|
|    TotalAll|         ABBBus|   1.0|
|    TotalAll|         ABBOth|   1.0|
|    TotalAll|         ACAHol|   1.0|
|    TotalAll|         ACAVis|   1.0|
|    TotalAll|         ACABus|   1.0|
|    TotalAll|         ACAOth|   1.0|
+------------+---------------+------+
only showing top 20 rows
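If the stack(...) expression looks opaque, here is a tiny self-contained example of the same wide-to-long reshape (toy data and column names, for illustration only):
toy = spark.createDataFrame([("TotalAll", 1.0, 1.0)], ["Parent_Group", "AAAHol", "AAAVis"])
toy.selectExpr(
    "Parent_Group",
    "stack(2, 'AAAHol', AAAHol, 'AAAVis', AAAVis) as (Region_Category, Weight)"
).show()
# Yields one row per (Parent_Group, Region_Category) pair with its Weight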
from pyspark.sql.functions import row_number, first
from pyspark.sql.window import Window

# Add a forecast step index (1-12) per Region_Category
window = Window.partitionBy("Region_Category").orderBy("date")
forecast_with_step_sdf = forecast_sdf.withColumn("step", row_number().over(window))

# Pivot into wide format: Region_Category × [1, ..., 12]
forecast_wide_sdf = forecast_with_step_sdf.groupBy("Region_Category").pivot("step").agg(first("Forecast"))

# From summing_sdf_long (Parent_Group, Region_Category, Weight), build the transposed
# design matrix: Region_Category as rows, Parent_Group as columns
s_matrix_T = summing_sdf_long.groupBy("Region_Category") \
    .pivot("Parent_Group").agg(first("Weight")).fillna(0)
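A quick shape check on the transposed design (illustrative; the exact counts depend on the hierarchy file):
# Rows should be the bottom-level series; columns the Parent_Group aggregates
# plus the Region_Category key itself
print(s_matrix_T.count(), "bottom-level series")
print(len(s_matrix_T.columns) - 1, "Parent_Group columns")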
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression
from pyspark.sql.functions import col, lit

reconciled_forecasts = []

for step in range(1, 13):
    step_col = str(step)

    # Join the step-h base forecasts with the rows of the design matrix
    training_df = forecast_wide_sdf.select("Region_Category", step_col) \
        .join(s_matrix_T, on="Region_Category", how="inner")

    feature_cols = [c for c in training_df.columns if c not in ["Region_Category", step_col]]
    assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")
    assembled_df = assembler.transform(training_df).select(
        "Region_Category", "features", col(step_col).alias("label"))

    # OLS fit for this horizon
    lr = LinearRegression(featuresCol="features", labelCol="label")
    model = lr.fit(assembled_df)

    reconciled = model.transform(assembled_df) \
        .select("Region_Category", col("prediction").alias("Forecast")) \
        .withColumn("step", lit(step))  # Tag with the forecast step number
    reconciled_forecasts.append(reconciled)
25/03/26 19:44:30 WARN Instrumentation: [55c02424] regParam is zero, which might cause numerical instability and overfitting.
25/03/26 19:44:31 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
25/03/26 19:44:32 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.lapack.JNILAPACK
25/03/26 19:44:32 WARN Instrumentation: [55c02424] Cholesky solver failed due to singular covariance matrix. Retrying with Quasi-Newton solver.
25/03/26 19:44:33 ERROR LBFGS: Failure! Resetting history: breeze.optimize.FirstOrderException: Line search zoom failed
25/03/26 19:44:33 ERROR LBFGS: Failure again! Giving up and returning. Maybe the objective is just poorly behaved?
(the regParam, Cholesky, and LBFGS messages repeat similarly for each of the 12 per-step regressions)
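The warnings above indicate that the assembled design matrix has collinear columns, so the Cholesky solve of $S^\top S$ fails and Spark falls back to L-BFGS. One possible mitigation (an assumption on my part, not used in this notebook) is a tiny ridge penalty, which keeps the normal equations well conditioned at the cost of no longer being exact OLS:
# Hypothetical variant of the lr definition in the loop above;
# a small regParam stabilizes the solve (strictly ridge, not OLS)
lr = LinearRegression(featuresCol="features", labelCol="label",
                      regParam=1e-6, elasticNetParam=0.0, solver="normal")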
reconciled_forecasts
[DataFrame[Region_Category: string, Forecast: double, step: int], ... (the same schema repeated for all 12 forecast steps)]
# Union all 12 steps into one DataFrame
from functools import reduce
from pyspark.sql.functions import expr
from pyspark.sql.functions import min as spark_min

reconciled_long_df = reduce(lambda df1, df2: df1.unionByName(df2), reconciled_forecasts)

# Get the first forecast date from the test set
min_date_row = test_sdf.select(spark_min("date").alias("start_date")).collect()[0]
start_date = min_date_row["start_date"].strftime("%Y-%m-%d")

# Add forecast dates based on the step number
reconciled_long_df = reconciled_long_df.withColumn(
    "date", expr(f"add_months(to_date('{start_date}'), step - 1)")
)
reconciled_long_df.show()
25/03/26 19:55:40 WARN DAGScheduler: Broadcasting large task binary with size 4.9 MiB
+---------------+------------------+----+----------+
|Region_Category|          Forecast|step|      date|
+---------------+------------------+----+----------+
|         DDBHol| 82.07811032786951|   1|2016-01-01|
|         FBAVis| 6.656805280557123|   1|2016-01-01|
|         EABVis|414.37511192628403|   1|2016-01-01|
|         BCBOth| 6.196280508794359|   1|2016-01-01|
|         FAAHol|150.30406627752933|   1|2016-01-01|
|         CBCHol| 67.29236126836824|   1|2016-01-01|
|         GABVis| 2.693192208874464|   1|2016-01-01|
|         CDBHol|0.5679754672918733|   1|2016-01-01|
|         DBABus| 5.316852903629325|   1|2016-01-01|
|         BDCBus|20.895111901507477|   1|2016-01-01|
|         DAAVis| 274.7963041881968|   1|2016-01-01|
|         CAAHol| 585.8461906935038|   1|2016-01-01|
|         AEDVis| 50.20844923996796|   1|2016-01-01|
|         BBAHol| 292.4243221681629|   1|2016-01-01|
|         ABABus|112.69985746271419|   1|2016-01-01|
|         CBBVis| 90.58366645794179|   1|2016-01-01|
|         BDBVis| 58.23528346223553|   1|2016-01-01|
|         ADDVis|111.02943173753079|   1|2016-01-01|
|         CCAVis|22.293612674097204|   1|2016-01-01|
|         BEHVis|19.081634731302486|   1|2016-01-01|
+---------------+------------------+----+----------+
only showing top 20 rows
from pyspark.sql.functions import sum as spark_sum

# Attach the hierarchy to the bottom-level reconciled forecasts
reconciled_with_hierarchy_df = reconciled_long_df.join(
    summing_sdf_long, on="Region_Category", how="inner"
)

# Aggregate reconciled forecasts up to each Parent_Group
reconciled_agg_df = reconciled_with_hierarchy_df.withColumn(
    "Weighted_Forecast", col("Forecast") * col("Weight")
).groupBy("Parent_Group", "date").agg(
    spark_sum("Weighted_Forecast").alias("Reconciled_Forecast")
)

# Aggregate the actual visitors the same way for a like-for-like comparison
test_with_hierarchy_df = test_sdf.join(summing_sdf_long, on="Region_Category", how="inner")
test_agg_df = test_with_hierarchy_df.withColumn(
    "Weighted_Actual", col("Visitors") * col("Weight")
).groupBy("Parent_Group", "date").agg(
    spark_sum("Weighted_Actual").alias("Actual_Visitors")
)
# Check date column types
reconciled_agg_df.printSchema()
test_agg_df.printSchema()
root
 |-- Parent_Group: string (nullable = true)
 |-- date: date (nullable = true)
 |-- Reconciled_Forecast: double (nullable = true)

root
 |-- Parent_Group: string (nullable = true)
 |-- date: date (nullable = true)
 |-- Actual_Visitors: double (nullable = true)
reconciled_agg_df.select("date").distinct().orderBy("date").show(20)
test_agg_df.select("date").distinct().orderBy("date").show(20)
+----------+
|      date|
+----------+
|2016-01-01|
|2016-02-01|
|2016-03-01|
|2016-04-01|
|2016-05-01|
|2016-06-01|
|2016-07-01|
|2016-08-01|
|2016-09-01|
|2016-10-01|
|2016-11-01|
|2016-12-01|
+----------+

+----------+
|      date|
+----------+
|2016-01-01|
|2016-02-01|
|2016-03-01|
|2016-04-01|
|2016-05-01|
|2016-06-01|
|2016-07-01|
|2016-08-01|
|2016-09-01|
|2016-10-01|
|2016-11-01|
|2016-12-01|
+----------+
evaluation_df = reconciled_agg_df.join(test_agg_df, on=["Parent_Group", "date"], how="inner")
evaluation_df.show()
+------------+----------+-------------------+------------------+
|Parent_Group|      date|Reconciled_Forecast|   Actual_Visitors|
+------------+----------+-------------------+------------------+
|      GBDBus|2016-01-01| 10.705897775124726|        28.5621226|
|      CCBVis|2016-01-01|  127.8105261721546|       136.5528404|
|      BECVis|2016-01-01|  62.46237010369159|        15.9456186|
|       FAAll|2016-01-01|  349.9323421850451| 763.3077524999999|
|      GACOth|2016-01-01|0.46048233024117735|         9.8127999|
|      EAAOth|2016-01-01| 2.9080098567940524|         5.7992456|
|      DCDBus|2016-01-01| 59.450079656240774|        46.0093885|
|      BDBOth|2016-01-01| 6.4037361653122815|               0.0|
|      AECOth|2016-01-01|  41.90321160149263|        10.9008486|
|       DDHol|2016-01-01| 102.51594592069989|359.95267900000005|
|        EHol|2016-01-01|   757.846086951482|       2812.781809|
|      EAAVis|2016-01-01|  52.72819361817362|        30.1005963|
|      BCBOth|2016-01-01|  6.196280508794359|         6.2829104|
|      BACBus|2016-01-01| 19.415417443459482|               0.0|
|       DAOth|2016-01-01|  212.9483219880252|        28.8082796|
|       CAVis|2016-01-01|  1280.131137158587|      1554.5346126|
|       ABAll|2016-01-01|  2088.705531118374| 5358.049206900001|
|      FBBHol|2016-01-01|  65.01918861007131|       169.5561995|
|      CBDOth|2016-01-01|-0.9901700392533428|               0.0|
|      BDDHol|2016-01-01| 21.003014900333262|        29.0369307|
+------------+----------+-------------------+------------------+
only showing top 20 rows
from pyspark.sql.functions import abs, avg

# Absolute percentage error; note that rows with Actual_Visitors = 0 produce
# a NULL APE in Spark (division by zero) and are skipped by avg()
evaluation_df = evaluation_df.withColumn(
    "APE", abs((col("Reconciled_Forecast") - col("Actual_Visitors")) / col("Actual_Visitors"))
)
# MAPE per group
mape_df = evaluation_df.groupBy("Parent_Group").agg(avg("APE").alias("MAPE"))
mape_df.show()
# Overall MAPE
overall_mape_df = mape_df.agg(avg("MAPE").alias("Overall_MAPE"))
overall_mape_df.show()
+------------+-------------------+
|Parent_Group|               MAPE|
+------------+-------------------+
|      BCBOth| 1.8814916288630412|
|      BDEAll| 0.6319932719308724|
|       CCOth| 1.1371127959837342|
|      CCBAll|0.23007099726028646|
|      FBAVis|  4.193833533298544|
|      DCCAll| 1.0170059992164842|
|      DDBHol| 0.6989044320419716|
|       BCHol| 0.6848869709479803|
|      EABVis| 0.2597990646591048|
|      CBDAll|0.28554250724533875|
|      ADBAll| 0.3511071573477474|
|      GBCAll|  1.822907197252322|
|      FAAHol|0.26537276454971037|
|      BDFAll| 0.3958432127605984|
|      CBCHol|   1.09175933047297|
|      GABVis| 0.5505873369776451|
|       CAVis| 0.3025095665083698|
|      BDBAll| 0.6698985120653421|
|      BEGAll| 0.5083338077197483|
|       DABus|0.40018259421342434|
+------------+-------------------+
only showing top 20 rows
+------------------+
|      Overall_MAPE|
+------------------+
|1.1076102791937525|
+------------------+
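For reference, the same evaluation can be run on the unreconciled ETS forecasts by aggregating the bottom-level base forecasts bottom-up (a sketch reusing the column names and DataFrames defined above; not part of the original analysis):
# Bottom-up aggregation of the raw ETS forecasts, mirroring the steps above
base_agg_df = forecast_sdf.join(summing_sdf_long, on="Region_Category", how="inner") \
    .withColumn("Weighted_Forecast", col("Forecast") * col("Weight")) \
    .groupBy("Parent_Group", "date").agg(
        spark_sum("Weighted_Forecast").alias("Base_Forecast"))

# The inner join on date drops forecast months outside the test window
base_eval_df = base_agg_df.join(test_agg_df, on=["Parent_Group", "date"], how="inner") \
    .withColumn("APE", abs((col("Base_Forecast") - col("Actual_Visitors")) / col("Actual_Visitors")))

base_eval_df.groupBy("Parent_Group").agg(avg("APE").alias("MAPE")) \
    .agg(avg("MAPE").alias("Overall_Base_MAPE")).show()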
mape_df.describe().show()
+-------+------------+------------------+
|summary|Parent_Group|              MAPE|
+-------+------------+------------------+
|  count|         555|               555|
|   mean|        NULL|1.1076102791937525|
| stddev|        NULL|1.3666645729588924|
|    min|      AAAAll|0.1381354920151718|
|    max|    TotalVis|14.210951585958131|
+-------+------------+------------------+
mape_df.orderBy(col("MAPE").desc()).show(10)
+------------+------------------+
|Parent_Group|              MAPE|
+------------+------------------+
|      DCDOth|14.210951585958131|
|      GBBBus|  12.0202762100509|
|      GAAOth| 8.483838150123491|
|      GBDVis| 7.982046555639922|
|      BEGOth| 7.614402179032675|
|      EACOth| 7.401889674658579|
|      DACBus| 6.429812651519794|
|      DCDBus|  6.02555705573783|
|      BEGHol| 5.946961838257663|
|      CCBOth| 5.468544199643766|
+------------+------------------+
only showing top 10 rows
mape_df.orderBy("MAPE").show()
+------------+-------------------+
|Parent_Group|               MAPE|
+------------+-------------------+
|    TotalOth| 0.1381354920151718|
|      AAAAll|0.15061493949148597|
|      BAAAll|0.15118372389621185|
|      CABAll| 0.1558753344427696|
|       AAAll|0.15773976873635803|
|        ABus| 0.1640455398594557|
|        CAll|0.16662182767530095|
|      DAAAll|0.17184814368002907|
|    TotalBus|0.17234491596252519|
|      AFAAll|0.17411943563895874|
|       AFAll|0.17411943563895874|
|       CDAll| 0.1816424382790586|
|       DAAll|0.19273337510343946|
|      EABAll|  0.195400784344689|
|       CAAll|0.20112517708058755|
|       BAAll|0.20431117238316943|
|        EAll|0.20736614485709745|
|    TotalAll|0.20956702567115104|
|       CBAll| 0.2104987151201196|
|        BBus|0.21364902571913846|
+------------+-------------------+
only showing top 20 rows
Multistep reconciliation with linear regression¶
If you stack forecasts from multiple time steps (e.g., 12 months) as separate columns, you can treat reconciliation as a multivariate linear regression problem with a shared design matrix $S$.
This means:
- Each row of $ Y $ is a Region_Category
- Each column of $ Y $ is a forecast time step (e.g., 1st month, 2nd month...)
- You can run one Spark MLlib linear regression per time step, or fit all at once if you flatten it.
What you’re doing conceptually¶
You solve $Y = S\beta + E$ where:
- $Y$: matrix of stacked forecasts (`Region_Category` × horizon)
- $S$: summing matrix (`Parent_Group` × `Region_Category`)
- $\beta$: regression coefficients (`Region_Category` × horizon)
- Reconciled forecasts: $\tilde{Y} = S\hat{\beta}$
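A driver-side sketch of this multivariate variant, assuming the design matrix and the 12 forecast columns fit in driver memory (it reuses forecast_wide_sdf and s_matrix_T from above; not part of the original notebook):
import numpy as np

# Collect S^T (bottom series × parent groups) and Y (bottom series × 12 steps);
# assumes every bottom-level series has a full set of forecasts
pdf_S = s_matrix_T.toPandas().set_index("Region_Category").sort_index()
pdf_Y = forecast_wide_sdf.toPandas().set_index("Region_Category").loc[pdf_S.index]

S_T = pdf_S.to_numpy(dtype=float)   # same design rows as in the per-step loop
Y = pdf_Y.to_numpy(dtype=float)

# One least-squares solve covers all 12 horizons (columns of Y) at once
beta, *_ = np.linalg.lstsq(S_T, Y, rcond=None)   # shape: (n_parents, 12)
Y_tilde = S_T @ beta                             # reconciled, all horizons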
Summary of MinT-OLS¶
- MinT-OLS (simple projection)
- No covariance matrix needed
- Coherent forecasts at Parent_Group level
- Evaluated using MAPE