Tourism Demand Forecasting with Spark¶
Feng Li¶
Guanghua School of Management¶
Peking University¶
feng.li@gsm.pku.edu.cn¶
Course home page: https://feng.li/bdcf¶
In [1]:
import os, sys # Ensure all environment variables are properly set
# os.environ["JAVA_HOME"] = os.path.dirname(sys.executable)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable
from pyspark.sql import SparkSession # build Spark Session
spark = SparkSession.builder \
.config("spark.ui.enabled", "false") \
.config("spark.executor.memory", "16g") \
.config("spark.executor.cores", "4") \
.config("spark.cores.max", "32") \
.config("spark.driver.memory", "30g") \
.config("spark.sql.shuffle.partitions", "96") \
.config("spark.memory.fraction", "0.8") \
.config("spark.memory.storageFraction", "0.5") \
.config("spark.dynamicAllocation.enabled", "true") \
.config("spark.dynamicAllocation.minExecutors", "4") \
.config("spark.dynamicAllocation.maxExecutors", "8") \
.appName("Spark Forecasting").getOrCreate()
spark
Setting default log level to "WARN". To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel). 25/03/26 19:36:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Out[1]:
SparkSession - in-memory
Summary of the Data Structure¶
Tourism Data:
- 228 rows (time series data).
- 556 columns (date + 555 individual series representing regions/categories).
Summing Matrix:
- 555 rows (each corresponding to a region/category).
- 305 columns (Parent_Group + 304 child region/category columns).
- The first column (Parent_Group) identifies each series and defines the hierarchical aggregation structure.
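The role of the summing matrix can be sketched with a toy hierarchy (illustrative sizes only; the real matrix maps 304 bottom-level series onto 555 series). Each row of S says which bottom-level series sum into that aggregate:

```python
import numpy as np

# Toy hierarchy: Total = A + B, so S maps the 2 bottom-level
# series onto 3 series (Total, A, B).
S = np.array([
    [1, 1],   # Total = A + B
    [1, 0],   # A is itself
    [0, 1],   # B is itself
])
bottom = np.array([10.0, 5.0])   # bottom-level values for one period
all_series = S @ bottom          # [Total, A, B]
print(all_series.tolist())       # [15.0, 10.0, 5.0]
```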
In [2]:
from pyspark.sql.functions import col, max as spark_max
from pyspark.sql.types import DateType
from pyspark.sql.functions import trunc
# Load the tourism dataset (update the path accordingly)
file_path = "../data/tourism/tourism.csv" # Replace with actual file path
sdf = spark.read.csv(file_path, header=True, inferSchema=True)
# Ensure 'date' column is in proper DateType format
sdf = sdf.withColumn("date", col("date").cast(DateType()))
# Convert wide format to long format
columns_to_melt = [c for c in sdf.columns if c != "date"]
sdf_long = sdf.selectExpr(
"date",
f"stack({len(columns_to_melt)}, " + ", ".join([f"'{c}', {c}" for c in columns_to_melt]) + ") as (Region_Category, Visitors)"
)
# Force all dates in your long-format DataFrame to the start of month, right after loading
sdf_long = sdf_long.withColumn("date", trunc("date", "MM"))
# Find the maximum date in the dataset
max_date = sdf_long.select(spark_max("date")).collect()[0][0]
# Define the threshold date for splitting (last 12 months for testing)
split_date = max_date.replace(year=max_date.year - 1)
# Split into training and testing datasets
train_sdf = sdf_long.filter(col("date") <= split_date)
test_sdf = sdf_long.filter(col("date") > split_date)
# Show results
train_sdf.show(5)
test_sdf.show(5)
# Save if needed
train_sdf.write.csv(os.path.expanduser("~/train_data.csv"), header=True, mode="overwrite")
test_sdf.write.csv(os.path.expanduser("~/test_data.csv"), header=True, mode="overwrite")
25/03/26 19:36:30 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
+----------+---------------+------------------+ | date|Region_Category| Visitors| +----------+---------------+------------------+ |1998-01-01| TotalAll|45151.071280099975| |1998-01-01| AAll|17515.502379600006| |1998-01-01| BAll|10393.618015699998| |1998-01-01| CAll| 8633.359046599999| |1998-01-01| DAll|3504.3133462000005| +----------+---------------+------------------+ only showing top 5 rows +----------+---------------+------------------+ | date|Region_Category| Visitors| +----------+---------------+------------------+ |2016-01-01| TotalAll|45625.487797300026| |2016-01-01| AAll|14631.321547500002| |2016-01-01| BAll|11201.033523800006| |2016-01-01| CAll| 8495.718019999998| |2016-01-01| DAll| 3050.230027900001| +----------+---------------+------------------+ only showing top 5 rows
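The split rule above can be checked in plain Python: subtract one year from the last observed month, train on everything up to that date, and test on the 12 months after it (toy dates, mirroring the dataset's 2016-12-01 maximum):

```python
from datetime import date

max_date = date(2016, 12, 1)                       # last month in the data
split_date = max_date.replace(year=max_date.year - 1)
print(split_date)                                  # 2015-12-01

months = [date(2016, m, 1) for m in range(1, 13)]
test = [d for d in months if d > split_date]       # strictly after the split
print(len(test))                                   # 12 test months
```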
In [3]:
# train_sdf.select("date").distinct().orderBy("date").show(20)
test_sdf.select("date").distinct().orderBy("date").show(20)
+----------+ | date| +----------+ |2016-01-01| |2016-02-01| |2016-03-01| |2016-04-01| |2016-05-01| |2016-06-01| |2016-07-01| |2016-08-01| |2016-09-01| |2016-10-01| |2016-11-01| |2016-12-01| +----------+
Convert wide format to long format¶
The original data are in wide format (one column per series).
We convert them to long format (one row per date and series).
In [4]:
# the original wide format
sdf.toPandas()
Out[4]:
date | TotalAll | AAll | BAll | CAll | DAll | EAll | FAll | GAll | AAAll | ... | GBBBus | GBBOth | GBCHol | GBCVis | GBCBus | GBCOth | GBDHol | GBDVis | GBDBus | GBDOth | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1998-01-01 | 45151.071280 | 17515.502380 | 10393.618016 | 8633.359047 | 3504.313346 | 3121.619189 | 1850.735773 | 131.923529 | 4977.209611 | ... | 0.000000 | 0.000000 | 7.536223 | 0.000000 | 1.628948 | 0.000000 | 0.811856 | 0.000000 | 9.478051 | 0.0 |
1 | 1998-02-01 | 17294.699551 | 5880.367918 | 3855.647839 | 3580.051065 | 1321.257992 | 1826.610676 | 757.079744 | 73.684316 | 1937.229611 | ... | 1.045797 | 0.000000 | 0.000000 | 0.000000 | 5.296459 | 0.000000 | 0.522899 | 0.000000 | 0.000000 | 0.0 |
2 | 1998-03-01 | 20725.114184 | 7086.444392 | 4353.379282 | 4717.676663 | 1521.950007 | 1868.381530 | 900.796622 | 276.485688 | 2117.671851 | ... | 0.000000 | 0.000000 | 2.945006 | 1.425324 | 9.924744 | 3.100121 | 0.000000 | 0.000000 | 0.000000 | 0.0 |
3 | 1998-04-01 | 25388.612353 | 10530.639348 | 5115.865530 | 4924.575204 | 1813.439177 | 1952.612465 | 801.444140 | 250.036488 | 2615.957465 | ... | 11.461824 | 0.000000 | 26.419176 | 13.690603 | 2.312088 | 0.000000 | 0.000000 | 10.958005 | 2.312088 | 0.0 |
4 | 1998-05-01 | 20330.035211 | 7430.373559 | 3820.666426 | 4219.283647 | 1375.082095 | 2616.965317 | 551.377058 | 316.287109 | 2393.145511 | ... | 0.000000 | 0.000000 | 23.789282 | 67.846207 | 1.282767 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
223 | 2016-08-01 | 24100.446632 | 7227.612581 | 4000.422405 | 7536.939346 | 1230.321700 | 2868.505157 | 355.515713 | 881.129730 | 2305.649511 | ... | 7.874448 | 0.000000 | 49.899947 | 8.300225 | 17.633478 | 8.218562 | 4.121631 | 0.408894 | 1.625820 | 0.0 |
224 | 2016-09-01 | 24800.033759 | 6778.363226 | 4132.168965 | 7123.112802 | 1632.732153 | 3327.064770 | 525.302974 | 1281.288869 | 2061.246613 | ... | 0.421996 | 0.000000 | 80.058887 | 37.306013 | 109.024164 | 46.447153 | 11.661140 | 3.557735 | 4.448534 | 0.0 |
225 | 2016-10-01 | 30039.106985 | 8592.998250 | 5719.297913 | 8759.191781 | 1900.476487 | 3704.651986 | 895.180382 | 467.310186 | 2267.174784 | ... | 35.375093 | 0.273247 | 52.156131 | 2.093902 | 50.283538 | 3.319366 | 0.754941 | 0.000000 | 0.000000 | 0.0 |
226 | 2016-11-01 | 27320.918908 | 8663.240960 | 5165.403172 | 6804.359328 | 1543.299435 | 3698.431886 | 852.313563 | 593.870565 | 2786.280116 | ... | 10.723090 | 39.292496 | 8.344998 | 7.697995 | 43.270319 | 0.000000 | 0.000000 | 0.000000 | 5.180648 | 0.0 |
227 | 2016-12-01 | 24604.310774 | 7953.659899 | 5000.483537 | 6049.188583 | 1378.418048 | 2844.820274 | 876.351928 | 501.388506 | 2676.459548 | ... | 0.000000 | 0.000000 | 2.418446 | 0.000000 | 0.762140 | 1.055685 | 0.000000 | 0.000000 | 9.966514 | 0.0 |
228 rows × 556 columns
In [5]:
# the long format
sdf_long.show()
+----------+---------------+------------------+ | date|Region_Category| Visitors| +----------+---------------+------------------+ |1998-01-01| TotalAll|45151.071280099975| |1998-01-01| AAll|17515.502379600006| |1998-01-01| BAll|10393.618015699998| |1998-01-01| CAll| 8633.359046599999| |1998-01-01| DAll|3504.3133462000005| |1998-01-01| EAll| 3121.6191894| |1998-01-01| FAll|1850.7357734999998| |1998-01-01| GAll|131.92352909999997| |1998-01-01| AAAll| 4977.2096105| |1998-01-01| ABAll| 5322.738721600001| |1998-01-01| ACAll|3569.6213724000004| |1998-01-01| ADAll| 1472.9706096| |1998-01-01| AEAll| 1560.5142545| |1998-01-01| AFAll| 612.447811| |1998-01-01| BAAll| 3854.672582| |1998-01-01| BBAll|1653.9957826000002| |1998-01-01| BCAll| 2138.7473162| |1998-01-01| BDAll| 1395.3775834| |1998-01-01| BEAll|1350.8247515000003| |1998-01-01| CAAll| 6421.236419000001| +----------+---------------+------------------+ only showing top 20 rows
The long format in Spark processing¶
Efficient Parallelization in Spark¶
- Spark natively processes data in a row-wise manner, so applying a Pandas UDF on each column separately is inefficient.
- By reshaping into long format, each (Region_Category, Visitors) pair becomes a separate row, allowing Spark to parallelize operations better.
Scalability for Large Datasets¶
- If you have hundreds of columns (as in our dataset), keeping it in wide format means:
- The UDF would need to handle many columns at once.
- Each worker in Spark must load all columns, increasing memory pressure.
- In long format, each time series is processed independently, allowing Spark to distribute tasks across multiple nodes efficiently.
Simpler UDF Application¶
- Spark natively groups and applies UDFs on rows rather than columns.
- Instead of manually iterating over columns in Python (which is slow and not parallelized), Spark can efficiently apply the UDF per region.
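The wide-to-long reshape done with Spark's stack() above has a local pandas analogue, melt(), which may help build intuition (toy data, not the course dataset):

```python
import pandas as pd

# Two dates x two series in wide format ...
wide = pd.DataFrame({
    "date": pd.to_datetime(["1998-01-01", "1998-02-01"]),
    "AAll": [17515.5, 5880.4],
    "BAll": [10393.6, 3855.6],
})
# ... becomes one row per (date, series) pair in long format.
long = wide.melt(id_vars="date", var_name="Region_Category", value_name="Visitors")
print(long.shape)   # (4, 3)
```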
In [6]:
from pyspark.sql.functions import explode, col
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, DateType
import pandas as pd
from pandas.tseries.offsets import MonthBegin
from statsmodels.tsa.holtwinters import ExponentialSmoothing
# Define schema for the forecast output
forecast_schema = StructType([
StructField("date", DateType(), False),
StructField("Region_Category", StringType(), False),
StructField("Forecast", DoubleType(), False)
])
def ets_forecast(pdf):
    """Fit an ETS model for a single region and forecast 12 months ahead."""
    region = pdf["Region_Category"].iloc[0]  # Extract region name
    pdf = pdf.sort_values("date")  # Ensure the time series is sorted
    try:
        # Drop missing values
        ts = pdf["Visitors"].dropna()
        if len(ts) >= 24:  # Require at least 24 observations (two seasonal cycles)
            model = ExponentialSmoothing(ts, trend="add", seasonal="add", seasonal_periods=12)
            fitted_model = model.fit()
            forecast = fitted_model.forecast(steps=12)
        else:
            forecast = [None] * 12  # Not enough data
    except Exception:
        forecast = [None] * 12  # Handle model-fitting errors
    # Shift forecast dates to the start of each month
    last_date = pdf["date"].max()
    future_dates = pd.date_range(start=last_date, periods=12, freq="ME") + MonthBegin(1)
    # Return results as a DataFrame
    return pd.DataFrame({"date": future_dates, "Region_Category": region, "Forecast": forecast})
# Apply the ETS model in parallel using applyInPandas
forecast_sdf = train_sdf.groupBy("Region_Category").applyInPandas(ets_forecast, schema=forecast_schema)
# Show forecasted results
forecast_sdf.show()
# Save forecasts if needed
forecast_sdf.write.csv(os.path.expanduser("~/ets_forecasts.csv"), header=True, mode="overwrite")
/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals. warnings.warn(
+----------+---------------+------------------+ | date|Region_Category| Forecast| +----------+---------------+------------------+ |2016-01-01| AAAAll| 3102.656423754467| |2016-02-01| AAAAll|1680.2657901376683| |2016-03-01| AAAAll|1990.2751366602693| |2016-04-01| AAAAll|1998.0264213694707| |2016-05-01| AAAAll|1901.6160001070255| |2016-06-01| AAAAll| 1774.600837587313| |2016-07-01| AAAAll| 2080.923016745227| |2016-08-01| AAAAll|1801.0122793015398| |2016-09-01| AAAAll|1943.9770740225126| |2016-10-01| AAAAll|2227.8982222051372| |2016-11-01| AAAAll|2002.3244224259395| |2016-12-01| AAAAll| 2034.390487099784| |2016-01-01| AAABus| 297.5805840966566| |2016-02-01| AAABus|458.72259824713456| |2016-03-01| AAABus| 528.3710688119537| |2016-04-01| AAABus| 463.137319957957| |2016-05-01| AAABus| 558.7822468756337| |2016-06-01| AAABus|499.18248190461196| |2016-07-01| AAABus| 608.6718553424128| |2016-08-01| AAABus| 582.5088141373288| +----------+---------------+------------------+ only showing top 20 rows
In [7]:
from pyspark.sql.functions import abs, mean, col
# Join forecasted results with actual test data on date and Region_Category
evaluation_sdf = forecast_sdf.join(
test_sdf, on=["date", "Region_Category"], how="inner"
).withColumn("APE", abs((col("Forecast") - col("Visitors")) / col("Visitors")))
# Compute Mean Absolute Percentage Error (MAPE) for each Region_Category
mape_sdf = evaluation_sdf.groupBy("Region_Category").agg(mean("APE").alias("MAPE"))
# Show MAPE results
mape_sdf.show()
# Save MAPE results if needed
# mape_sdf.write.csv(os.path.expanduser("~/mape_results.csv"), header=True, mode="overwrite")
# Compute overall mean MAPE across all Region_Category
overall_mape = mape_sdf.agg(mean("MAPE").alias("Overall_MAPE"))
# Show the result
overall_mape.show()
+---------------+-------------------+ |Region_Category| MAPE| +---------------+-------------------+ | BCBOth| 1.738077138878954| | BCHol|0.22474998575413727| | BDEAll| 0.4502227513056252| | CBDAll|0.25029198603546415| | CCBAll|0.18626793448841952| | CCOth| 1.1703051474729502| | DCCAll| 0.8223672548904787| | DDBHol| 0.5296737451503019| | EABVis| 0.1905100109796238| | FBAVis| 1.7974654807514898| | ADBAll| 0.2552742929017676| | BDFAll|0.38694702800250763| | CBCHol|0.49020465544431563| | FAAHol| 0.2335170689589696| | GABVis| 0.5897688001440905| | GBCAll| 1.162144118937472| | AEHol| 0.2092658276093707| | BDBAll| 0.6689890326410363| | BDCBus| 2.3347911934651786| | BEGAll| 0.4220375937751745| +---------------+-------------------+ only showing top 20 rows
+------------------+ | Overall_MAPE| +------------------+ |0.9550717947568259| +------------------+
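The per-series MAPE computed with Spark above reduces to a simple formula, sketched here in plain Python for one series:

```python
# MAPE: mean of |forecast - actual| / actual over the test months.
def mape(actuals, forecasts):
    apes = [abs(f - a) / a for a, f in zip(actuals, forecasts)]
    return sum(apes) / len(apes)

# Toy series: 10% error in each of two months -> MAPE of 0.1.
print(mape([100.0, 200.0], [110.0, 180.0]))  # 0.1
```

Note that MAPE is undefined when an actual value is zero, which does occur for some sparse bottom-level series in this dataset; the Spark division then yields null/inf rows that mean() skips or inflates accordingly.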
In [8]:
from pyspark.sql.functions import col
# Load the summing matrix file
summing_matrix_path = "../data/tourism/agg_mat.csv" # Update with actual path
# Read with the header row; the first column (Parent_Group) labels each series
summing_sdf = spark.read.csv(summing_matrix_path, header=True, inferSchema=True)
# Show the cleaned summing matrix
summing_sdf.toPandas()
Out[8]:
Parent_Group | AAAHol | AAAVis | AAABus | AAAOth | AABHol | AABVis | AABBus | AABOth | ABAHol | ... | GBBBus | GBBOth | GBCHol | GBCVis | GBCBus | GBCOth | GBDHol | GBDVis | GBDBus | GBDOth | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | TotalAll | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | ... | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 |
1 | AAll | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 | BAll | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
3 | CAll | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
4 | DAll | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
550 | GBCOth | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 |
551 | GBDHol | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 |
552 | GBDVis | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 |
553 | GBDBus | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 |
554 | GBDOth | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 |
555 rows × 305 columns
In [9]:
from pyspark.sql.functions import expr
# Convert from wide format to long format (Region_Category, Parent_Group, Weight)
summing_sdf_long = summing_sdf.selectExpr(
"Parent_Group",
"stack(" + str(len(summing_sdf.columns) - 1) + ", " +
", ".join([f"'{col}', {col}" for col in summing_sdf.columns if col != "Parent_Group"]) +
") as (Region_Category, Weight)"
)
# Show the reshaped summing matrix
summing_sdf_long.show()
+------------+---------------+------+ |Parent_Group|Region_Category|Weight| +------------+---------------+------+ | TotalAll| AAAHol| 1.0| | TotalAll| AAAVis| 1.0| | TotalAll| AAABus| 1.0| | TotalAll| AAAOth| 1.0| | TotalAll| AABHol| 1.0| | TotalAll| AABVis| 1.0| | TotalAll| AABBus| 1.0| | TotalAll| AABOth| 1.0| | TotalAll| ABAHol| 1.0| | TotalAll| ABAVis| 1.0| | TotalAll| ABABus| 1.0| | TotalAll| ABAOth| 1.0| | TotalAll| ABBHol| 1.0| | TotalAll| ABBVis| 1.0| | TotalAll| ABBBus| 1.0| | TotalAll| ABBOth| 1.0| | TotalAll| ACAHol| 1.0| | TotalAll| ACAVis| 1.0| | TotalAll| ACABus| 1.0| | TotalAll| ACAOth| 1.0| +------------+---------------+------+ only showing top 20 rows
In [10]:
# Bottom up
from pyspark.sql.functions import sum as spark_sum, col
# Join forecast data with transformed summing matrix
forecast_with_hierarchy_sdf = forecast_sdf.join(
summing_sdf_long, on="Region_Category", how="inner"
)
# Aggregate forecasts according to the summing matrix
reconciled_forecast_sdf = forecast_with_hierarchy_sdf.groupBy("date").agg(
spark_sum(col("Forecast") * col("Weight")).alias("Reconciled_Forecast")
)
# Show reconciled forecasts
reconciled_forecast_sdf.show()
# Save reconciled forecasts if needed
# reconciled_forecast_sdf.write.csv(os.path.expanduser("~/reconciled_forecast.csv"), header=True, mode="overwrite")
+----------+-------------------+ | date|Reconciled_Forecast| +----------+-------------------+ |2016-07-01| 218192.70227874012| |2016-08-01| 197075.67915486306| |2016-09-01| 204392.50642857671| |2016-12-01| 198470.65308329317| |2016-10-01| 230859.00134236045| |2016-04-01| 227945.82289731244| |2016-01-01| 369232.9770653847| |2016-02-01| 167778.9789411411| |2016-06-01| 181376.75686524587| |2016-11-01| 199288.84515902033| |2016-05-01| 182015.74284797342| |2016-03-01| 192429.9184313547| +----------+-------------------+
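The join-then-weighted-sum pattern above has a compact pandas analogue, shown here on the same toy hierarchy (Total = A + B; illustrative only):

```python
import pandas as pd

# Bottom-level forecasts for one period.
forecasts = pd.DataFrame({"Region_Category": ["A", "B"],
                          "Forecast": [10.0, 5.0]})
# Long-format summing matrix: one (parent, child, weight) row per entry.
summing_long = pd.DataFrame({
    "Parent_Group":    ["Total", "Total", "A", "B"],
    "Region_Category": ["A", "B", "A", "B"],
    "Weight":          [1.0, 1.0, 1.0, 1.0],
})
# Join forecasts to the matrix, then take the weighted sum per parent.
joined = forecasts.merge(summing_long, on="Region_Category")
reconciled = (joined.assign(w=joined["Forecast"] * joined["Weight"])
                    .groupby("Parent_Group")["w"].sum())
print(reconciled["Total"])  # 15.0
```

Grouping only by date, as in the cell above, collapses everything into one grand total per month; the next cell's groupBy("date", "Parent_Group") is the grouping that actually recovers each level of the hierarchy.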
In [11]:
# Map test data to Parent_Group using summing matrix
test_with_hierarchy_sdf = test_sdf.join(
summing_sdf_long, on="Region_Category", how="inner"
)
# Aggregate forecasts at Parent_Group level
reconciled_forecast_sdf = forecast_with_hierarchy_sdf.groupBy("date", "Parent_Group").agg(
spark_sum(col("Forecast") * col("Weight")).alias("Reconciled_Forecast")
)
# Merge test data with reconciled forecasts
evaluation_sdf = test_sdf.join(
reconciled_forecast_sdf,(test_sdf["date"] == reconciled_forecast_sdf["date"]) &
(test_sdf["Region_Category"] == reconciled_forecast_sdf["Parent_Group"]), how="inner")
evaluation_sdf.show()
+----------+---------------+------------------+----------+------------+-------------------+ | date|Region_Category| Visitors| date|Parent_Group|Reconciled_Forecast| +----------+---------------+------------------+----------+------------+-------------------+ |2016-01-01| DCBHol| 102.8895483|2016-01-01| DCBHol| 121.67840269674329| |2016-01-01| BBAHol| 1022.153257|2016-01-01| BBAHol| 1212.5607176306241| |2016-01-01| AFAHol| 207.9422174|2016-01-01| AFAHol| 239.56228231588986| |2016-01-01| ABAVis| 354.3176582|2016-01-01| ABAVis| 446.078158109554| |2016-01-01| ABAHol| 657.0045098|2016-01-01| ABAHol| 753.4011722631797| |2016-01-01| GAVis|161.90854199999998|2016-01-01| GAVis| 90.84358101715947| |2016-01-01| BEHol| 561.7473459|2016-01-01| BEHol| 509.9969906760645| |2016-01-01| AABus| 248.5930062|2016-01-01| AABus| 327.61548927681565| |2016-01-01| DABAll|26.582756099999997|2016-01-01| DABAll| 53.421767314549896| |2016-01-01| GAll| 608.4583232|2016-01-01| GAll| 369.71107428838764| |2016-02-01| GAABus| 204.0029381|2016-02-01| GAABus| 86.35571813879201| |2016-02-01| GAAVis| 20.835924|2016-02-01| GAAVis| 39.30977548594424| |2016-02-01| EBABus| 442.0968368|2016-02-01| EBABus| 968.5113794609761| |2016-02-01| DBAOth| 0.8202871|2016-02-01| DBAOth| 4.793413619092407| |2016-02-01| CDBHol| 21.2006668|2016-02-01| CDBHol| 10.445053695416235| |2016-02-01| CCABus| 0.0|2016-02-01| CCABus| 5.003942284998754| |2016-02-01| AECVis| 18.6097499|2016-02-01| AECVis| 13.334188823600508| |2016-02-01| ADBBus| 21.5952651|2016-02-01| ADBBus| 61.34592282892598| |2016-02-01| CCVis|263.77205649999996|2016-02-01| CCVis| 178.80184295335545| |2016-02-01| BCAAll|126.06354850000001|2016-02-01| BCAAll| 149.55966641842832| +----------+---------------+------------------+----------+------------+-------------------+ only showing top 20 rows
In [12]:
from pyspark.sql.functions import abs, mean
# Compute Absolute Percentage Error (APE)
evaluation_sdf = evaluation_sdf.withColumn(
"APE", abs((col("Reconciled_Forecast") - col("Visitors")) / col("Visitors"))
)
# Compute MAPE for each Parent_Group
mape_bu_sdf = evaluation_sdf.groupBy("Parent_Group").agg(mean("APE").alias("MAPE"))
# Show MAPE results
mape_bu_sdf.show()
# Compute overall mean MAPE across all Region_Category
overall_mape_bu = mape_bu_sdf.agg(mean("MAPE").alias("Overall_MAPE"))
# Show the result
overall_mape_bu.show()
# Save results if needed
# mape_sdf.write.csv(os.path.expanduser("~/mape_bottom_up.csv"), header=True, mode="overwrite")
+------------+-------------------+ |Parent_Group| MAPE| +------------+-------------------+ | CBDAll| 0.270830763473807| | BCHol|0.21598580491185304| | BCBOth| 1.7380771388789542| | CCBAll| 0.1908010944069485| | DDBHol| 0.5296737451503021| | CCOth| 1.4105528463717814| | FBAVis| 1.7974654807514896| | EABVis|0.19051001097962375| | DCCAll| 0.7543353523263773| | BDEAll| 0.4395809834344526| | GABVis| 0.5897688001440906| | CBCHol| 0.4902046554443155| | ADBAll|0.22560078150472027| | FAAHol|0.23351706895896954| | BDFAll| 0.2650950034661002| | GBCAll| 0.8098570176724825| | CDBHol| 0.4344523647200556| | BEGAll| 0.4589742227798342| | DABus| 0.3129358963433057| | BDCBus| 2.3347911934651786| +------------+-------------------+ only showing top 20 rows
+------------------+ | Overall_MAPE| +------------------+ |0.9492457033222294| +------------------+
In [13]:
mape_bu_sdf.orderBy("MAPE").show()
+------------+--------------------+ |Parent_Group| MAPE| +------------+--------------------+ | GBDOth| 0.02415337882585393| | TotalAll|0.043976383262471615| | BAll| 0.05390238179742062| | TotalVis| 0.05723817716280238| | AAll| 0.06447498916483488| | BAAll| 0.06607394487054989| | AVis| 0.06702550342938024| | BAHol| 0.07941507636947337| | DAAVis| 0.08147907508340964| | BAAAll| 0.0816780773066375| | CAll| 0.08317278641917357| | BVis| 0.09021519344686489| | DAVis| 0.09089992803688406| | EABAll| 0.09109946225903111| | BHol| 0.09915622264317571| | CAAll| 0.09932009843257024| | BAAHol| 0.10309288791756062| | AAAll| 0.10404957643124245| | TotalHol| 0.10639324287391223| | ABVis| 0.1081168121711044| +------------+--------------------+ only showing top 20 rows
Further reading¶
- Forecasting hierarchical and grouped time series https://otexts.com/fpp3/hierarchical.html
- Athanasopoulos, G., Ahmed, R. A., & Hyndman, R. J. (2009). Hierarchical forecasts for Australian domestic tourism. International Journal of Forecasting, 25, 146–166. https://doi.org/10.1016/j.ijforecast.2008.07.004
Lab¶
- Why does the bottom-up method barely improve on the base (unreconciled) forecasts here?
- Repeat the reconciliation using a forecasting method other than ETS.