{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Tourism Demand Forecasting with Spark\n", "\n", "\n", "## Feng Li\n", "\n", "### Guanghua School of Management\n", "### Peking University\n", "\n", "\n", "### [feng.li@gsm.pku.edu.cn](feng.li@gsm.pku.edu.cn)\n", "### Course home page: [https://feng.li/bdcf](https://feng.li/bdcf)" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/03/26 19:36:24 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] }, { "data": { "text/html": [ "\n", "
\n", "

SparkSession - in-memory

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.5.3
\n", "
Master
\n", "
local[*]
\n", "
AppName
\n", "
Spark Forecasting
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os, sys # Ensure All environment variables are properly set \n", "# os.environ[\"JAVA_HOME\"] = os.path.dirname(sys.executable)\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "os.environ[\"PYSPARK_DRIVER_PYTHON\"] = sys.executable\n", "\n", "from pyspark.sql import SparkSession # build Spark Session\n", "spark = SparkSession.builder \\\n", " .config(\"spark.ui.enabled\", \"false\") \\\n", " .config(\"spark.executor.memory\", \"16g\") \\\n", " .config(\"spark.executor.cores\", \"4\") \\\n", " .config(\"spark.cores.max\", \"32\") \\\n", " .config(\"spark.driver.memory\", \"30g\") \\\n", " .config(\"spark.sql.shuffle.partitions\", \"96\") \\\n", " .config(\"spark.memory.fraction\", \"0.8\") \\\n", " .config(\"spark.memory.storageFraction\", \"0.5\") \\\n", " .config(\"spark.dynamicAllocation.enabled\", \"true\") \\\n", " .config(\"spark.dynamicAllocation.minExecutors\", \"4\") \\\n", " .config(\"spark.dynamicAllocation.maxExecutors\", \"8\") \\\n", " .appName(\"Spark Forecasting\").getOrCreate()\n", "spark" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Summary of the Data Structure\n", "\n", "- Tourism Data:\n", "\n", " - 228 rows (time series data).\n", " - 556 columns (date + 555 individual series representing regions/categories).\n", "\n", "- Summing Matrix:\n", "\n", " - 555 rows (each corresponding to a region/category).\n", " - 305 columns (Parent_Group + 304 child region/category columns).\n", " - The first column (Parent_Group) correctly defines hierarchical aggregation." 
] }, { "cell_type": "code", "execution_count": 2, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/26 19:36:30 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Visitors|\n", "+----------+---------------+------------------+\n", "|1998-01-01| TotalAll|45151.071280099975|\n", "|1998-01-01| AAll|17515.502379600006|\n", "|1998-01-01| BAll|10393.618015699998|\n", "|1998-01-01| CAll| 8633.359046599999|\n", "|1998-01-01| DAll|3504.3133462000005|\n", "+----------+---------------+------------------+\n", "only showing top 5 rows\n", "\n", "+----------+---------------+------------------+\n", "| date|Region_Category| Visitors|\n", "+----------+---------------+------------------+\n", "|2016-01-01| TotalAll|45625.487797300026|\n", "|2016-01-01| AAll|14631.321547500002|\n", "|2016-01-01| BAll|11201.033523800006|\n", "|2016-01-01| CAll| 8495.718019999998|\n", "|2016-01-01| DAll| 3050.230027900001|\n", "+----------+---------------+------------------+\n", "only showing top 5 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import col, max as spark_max\n", "from pyspark.sql.types import DateType\n", "from pyspark.sql.functions import trunc\n", "\n", "\n", "# Load the tourism dataset (update the path accordingly)\n", "file_path = \"../data/tourism/tourism.csv\" # Replace with actual file path\n", "sdf = spark.read.csv(file_path, header=True, inferSchema=True)\n", "\n", "# Ensure 'date' column is in proper DateType format\n", "sdf = sdf.withColumn(\"date\", col(\"date\").cast(DateType()))\n", "\n", "# Convert wide format to long format\n", "columns_to_melt = [c for c in sdf.columns if c != \"date\"]\n", 
"sdf_long = sdf.selectExpr(\n", " \"date\",\n", " f\"stack({len(columns_to_melt)}, \" + \", \".join([f\"'{c}', {c}\" for c in columns_to_melt]) + \") as (Region_Category, Visitors)\"\n", ")\n", "\n", "# Force all dates in your long-format DataFrame to the start of month, right after loading\n", "sdf_long = sdf_long.withColumn(\"date\", trunc(\"date\", \"MM\"))\n", "\n", "# Find the maximum date in the dataset\n", "max_date = sdf_long.select(spark_max(\"date\")).collect()[0][0]\n", "\n", "# Define the threshold date for splitting (last 12 months for testing)\n", "split_date = max_date.replace(year=max_date.year - 1)\n", "\n", "# Split into training and testing datasets\n", "train_sdf = sdf_long.filter(col(\"date\") <= split_date)\n", "test_sdf = sdf_long.filter(col(\"date\") > split_date)\n", "\n", "# Show results\n", "train_sdf.show(5)\n", "test_sdf.show(5)\n", "\n", "# Save if needed\n", "train_sdf.write.csv(os.path.expanduser(\"~/train_data.csv\"), header=True, mode=\"overwrite\")\n", "test_sdf.write.csv(os.path.expanduser(\"~/test_data.csv\"), header=True, mode=\"overwrite\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+\n", "| date|\n", "+----------+\n", "|2016-01-01|\n", "|2016-02-01|\n", "|2016-03-01|\n", "|2016-04-01|\n", "|2016-05-01|\n", "|2016-06-01|\n", "|2016-07-01|\n", "|2016-08-01|\n", "|2016-09-01|\n", "|2016-10-01|\n", "|2016-11-01|\n", "|2016-12-01|\n", "+----------+\n", "\n" ] } ], "source": [ "# train_sdf.select(\"date\").distinct().orderBy(\"date\").show(20)\n", "test_sdf.select(\"date\").distinct().orderBy(\"date\").show(20)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Convert wide format to long format\n", "\n", "- The orignial data are in long format\n", "\n", "- We convert it in long format" ] }, { "cell_type": "code", "execution_count": 4, 
"metadata": { "scrolled": false, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dateTotalAllAAllBAllCAllDAllEAllFAllGAllAAAll...GBBBusGBBOthGBCHolGBCVisGBCBusGBCOthGBDHolGBDVisGBDBusGBDOth
01998-01-0145151.07128017515.50238010393.6180168633.3590473504.3133463121.6191891850.735773131.9235294977.209611...0.0000000.0000007.5362230.0000001.6289480.0000000.8118560.0000009.4780510.0
11998-02-0117294.6995515880.3679183855.6478393580.0510651321.2579921826.610676757.07974473.6843161937.229611...1.0457970.0000000.0000000.0000005.2964590.0000000.5228990.0000000.0000000.0
21998-03-0120725.1141847086.4443924353.3792824717.6766631521.9500071868.381530900.796622276.4856882117.671851...0.0000000.0000002.9450061.4253249.9247443.1001210.0000000.0000000.0000000.0
31998-04-0125388.61235310530.6393485115.8655304924.5752041813.4391771952.612465801.444140250.0364882615.957465...11.4618240.00000026.41917613.6906032.3120880.0000000.00000010.9580052.3120880.0
41998-05-0120330.0352117430.3735593820.6664264219.2836471375.0820952616.965317551.377058316.2871092393.145511...0.0000000.00000023.78928267.8462071.2827670.0000000.0000000.0000000.0000000.0
..................................................................
2232016-08-0124100.4466327227.6125814000.4224057536.9393461230.3217002868.505157355.515713881.1297302305.649511...7.8744480.00000049.8999478.30022517.6334788.2185624.1216310.4088941.6258200.0
2242016-09-0124800.0337596778.3632264132.1689657123.1128021632.7321533327.064770525.3029741281.2888692061.246613...0.4219960.00000080.05888737.306013109.02416446.44715311.6611403.5577354.4485340.0
2252016-10-0130039.1069858592.9982505719.2979138759.1917811900.4764873704.651986895.180382467.3101862267.174784...35.3750930.27324752.1561312.09390250.2835383.3193660.7549410.0000000.0000000.0
2262016-11-0127320.9189088663.2409605165.4031726804.3593281543.2994353698.431886852.313563593.8705652786.280116...10.72309039.2924968.3449987.69799543.2703190.0000000.0000000.0000005.1806480.0
2272016-12-0124604.3107747953.6598995000.4835376049.1885831378.4180482844.820274876.351928501.3885062676.459548...0.0000000.0000002.4184460.0000000.7621401.0556850.0000000.0000009.9665140.0
\n", "

228 rows × 556 columns

\n", "
" ], "text/plain": [ " date TotalAll AAll BAll CAll \\\n", "0 1998-01-01 45151.071280 17515.502380 10393.618016 8633.359047 \n", "1 1998-02-01 17294.699551 5880.367918 3855.647839 3580.051065 \n", "2 1998-03-01 20725.114184 7086.444392 4353.379282 4717.676663 \n", "3 1998-04-01 25388.612353 10530.639348 5115.865530 4924.575204 \n", "4 1998-05-01 20330.035211 7430.373559 3820.666426 4219.283647 \n", ".. ... ... ... ... ... \n", "223 2016-08-01 24100.446632 7227.612581 4000.422405 7536.939346 \n", "224 2016-09-01 24800.033759 6778.363226 4132.168965 7123.112802 \n", "225 2016-10-01 30039.106985 8592.998250 5719.297913 8759.191781 \n", "226 2016-11-01 27320.918908 8663.240960 5165.403172 6804.359328 \n", "227 2016-12-01 24604.310774 7953.659899 5000.483537 6049.188583 \n", "\n", " DAll EAll FAll GAll AAAll ... \\\n", "0 3504.313346 3121.619189 1850.735773 131.923529 4977.209611 ... \n", "1 1321.257992 1826.610676 757.079744 73.684316 1937.229611 ... \n", "2 1521.950007 1868.381530 900.796622 276.485688 2117.671851 ... \n", "3 1813.439177 1952.612465 801.444140 250.036488 2615.957465 ... \n", "4 1375.082095 2616.965317 551.377058 316.287109 2393.145511 ... \n", ".. ... ... ... ... ... ... \n", "223 1230.321700 2868.505157 355.515713 881.129730 2305.649511 ... \n", "224 1632.732153 3327.064770 525.302974 1281.288869 2061.246613 ... \n", "225 1900.476487 3704.651986 895.180382 467.310186 2267.174784 ... \n", "226 1543.299435 3698.431886 852.313563 593.870565 2786.280116 ... \n", "227 1378.418048 2844.820274 876.351928 501.388506 2676.459548 ... \n", "\n", " GBBBus GBBOth GBCHol GBCVis GBCBus GBCOth \\\n", "0 0.000000 0.000000 7.536223 0.000000 1.628948 0.000000 \n", "1 1.045797 0.000000 0.000000 0.000000 5.296459 0.000000 \n", "2 0.000000 0.000000 2.945006 1.425324 9.924744 3.100121 \n", "3 11.461824 0.000000 26.419176 13.690603 2.312088 0.000000 \n", "4 0.000000 0.000000 23.789282 67.846207 1.282767 0.000000 \n", ".. ... ... ... ... ... ... 
\n", "223 7.874448 0.000000 49.899947 8.300225 17.633478 8.218562 \n", "224 0.421996 0.000000 80.058887 37.306013 109.024164 46.447153 \n", "225 35.375093 0.273247 52.156131 2.093902 50.283538 3.319366 \n", "226 10.723090 39.292496 8.344998 7.697995 43.270319 0.000000 \n", "227 0.000000 0.000000 2.418446 0.000000 0.762140 1.055685 \n", "\n", " GBDHol GBDVis GBDBus GBDOth \n", "0 0.811856 0.000000 9.478051 0.0 \n", "1 0.522899 0.000000 0.000000 0.0 \n", "2 0.000000 0.000000 0.000000 0.0 \n", "3 0.000000 10.958005 2.312088 0.0 \n", "4 0.000000 0.000000 0.000000 0.0 \n", ".. ... ... ... ... \n", "223 4.121631 0.408894 1.625820 0.0 \n", "224 11.661140 3.557735 4.448534 0.0 \n", "225 0.754941 0.000000 0.000000 0.0 \n", "226 0.000000 0.000000 5.180648 0.0 \n", "227 0.000000 0.000000 9.966514 0.0 \n", "\n", "[228 rows x 556 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# the original wide format\n", "sdf.toPandas()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Visitors|\n", "+----------+---------------+------------------+\n", "|1998-01-01| TotalAll|45151.071280099975|\n", "|1998-01-01| AAll|17515.502379600006|\n", "|1998-01-01| BAll|10393.618015699998|\n", "|1998-01-01| CAll| 8633.359046599999|\n", "|1998-01-01| DAll|3504.3133462000005|\n", "|1998-01-01| EAll| 3121.6191894|\n", "|1998-01-01| FAll|1850.7357734999998|\n", "|1998-01-01| GAll|131.92352909999997|\n", "|1998-01-01| AAAll| 4977.2096105|\n", "|1998-01-01| ABAll| 5322.738721600001|\n", "|1998-01-01| ACAll|3569.6213724000004|\n", "|1998-01-01| ADAll| 1472.9706096|\n", "|1998-01-01| AEAll| 1560.5142545|\n", "|1998-01-01| AFAll| 612.447811|\n", "|1998-01-01| BAAll| 3854.672582|\n", "|1998-01-01| BBAll|1653.9957826000002|\n", "|1998-01-01| 
BCAll| 2138.7473162|\n", "|1998-01-01| BDAll| 1395.3775834|\n", "|1998-01-01| BEAll|1350.8247515000003|\n", "|1998-01-01| CAAll| 6421.236419000001|\n", "+----------+---------------+------------------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "# the long format\n", "sdf_long.show()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## The long format in Spark processing \n", "\n", "### Efficient Parallelization in Spark\n", "\n", "- Spark natively processes data in a row-wise manner, so applying a Pandas UDF on each column separately is inefficient.\n", "- By reshaping into long format, each (Region_Category, Visitors) pair becomes a separate row, allowing Spark to parallelize operations better.\n", "\n", "### Scalability for Large Datasets\n", "\n", "- If you have hundreds of columns (as in our dataset), keeping it in wide format means:\n", " - The UDF would need to handle many columns at once.\n", " - Each worker in Spark must load all columns, increasing memory pressure.\n", "- In long format, each time series is processed independently, allowing Spark to distribute tasks across multiple nodes efficiently.\n", "\n", "### Simpler UDF Application\n", "\n", "- Spark natively groups and applies UDFs on rows rather than columns.\n", "- Instead of manually iterating over columns in Python (which is slow and not parallelized), Spark can efficiently apply the UDF per region." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+\n", "| date|Region_Category| Forecast|\n", "+----------+---------------+------------------+\n", "|2016-01-01| AAAAll| 3102.656423754467|\n", "|2016-02-01| AAAAll|1680.2657901376683|\n", "|2016-03-01| AAAAll|1990.2751366602693|\n", "|2016-04-01| AAAAll|1998.0264213694707|\n", "|2016-05-01| AAAAll|1901.6160001070255|\n", "|2016-06-01| AAAAll| 1774.600837587313|\n", "|2016-07-01| AAAAll| 2080.923016745227|\n", "|2016-08-01| AAAAll|1801.0122793015398|\n", "|2016-09-01| AAAAll|1943.9770740225126|\n", "|2016-10-01| AAAAll|2227.8982222051372|\n", "|2016-11-01| AAAAll|2002.3244224259395|\n", "|2016-12-01| AAAAll| 2034.390487099784|\n", "|2016-01-01| AAABus| 297.5805840966566|\n", "|2016-02-01| AAABus|458.72259824713456|\n", "|2016-03-01| AAABus| 528.3710688119537|\n", "|2016-04-01| AAABus| 463.137319957957|\n", "|2016-05-01| AAABus| 558.7822468756337|\n", "|2016-06-01| AAABus|499.18248190461196|\n", "|2016-07-01| AAABus| 608.6718553424128|\n", "|2016-08-01| AAABus| 582.5088141373288|\n", "+----------+---------------+------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", " \r" ] } ], "source": [ "from pyspark.sql.functions import explode, col\n", "from pyspark.sql.types import StructType, StructField, StringType, ArrayType, DoubleType, DateType\n", "import pandas as pd\n", "from pandas.tseries.offsets import MonthBegin\n", "from statsmodels.tsa.holtwinters import ExponentialSmoothing\n", "\n", "# Define schema for the forecast output\n", "forecast_schema = StructType([\n", " StructField(\"date\", DateType(), False),\n", " StructField(\"Region_Category\", StringType(), False),\n", " StructField(\"Forecast\", DoubleType(), False)\n", "])\n", "\n", "\n", "def ets_forecast(pdf):\n", " \"\"\"Fits an ETS model for a single region and forecasts 12 months ahead.\"\"\"\n", " region = pdf[\"Region_Category\"].iloc[0] # Extract region name\n", " pdf = pdf.sort_values(\"date\") # Ensure time series is sorted\n", "\n", " try:\n", " # Drop missing values\n", " ts = pdf[\"Visitors\"].dropna()\n", "\n", " if len(ts) >= 24: # Ensure at least 24 observations\n", " model = ExponentialSmoothing(ts, trend=\"add\", seasonal=\"add\", seasonal_periods=12)\n", " fitted_model = model.fit()\n", " forecast = fitted_model.forecast(steps=12)\n", " else:\n", " forecast = [None] * 12 # Not enough data\n", " except:\n", " forecast = [None] * 12 # Handle errors\n", "\n", " # Adjust forecast dates to start of the month\n", " last_date = pdf[\"date\"].max()\n", " future_dates = pd.date_range(start=last_date, periods=12, freq=\"ME\") + MonthBegin(1)\n", "\n", " # Return results as a DataFrame\n", " return 
pd.DataFrame({\"date\": future_dates, \"Region_Category\": region, \"Forecast\": forecast})\n", "\n", "# Apply the ETS model in parallel using applyInPandas\n", "forecast_sdf = train_sdf.groupBy(\"Region_Category\").applyInPandas(ets_forecast, schema=forecast_schema)\n", "\n", "# Show forecasted results\n", "forecast_sdf.show()\n", "\n", "# Save forecasts if needed\n", "forecast_sdf.write.csv(os.path.expanduser(\"~/ets_forecasts.csv\"), header=True, mode=\"overwrite\")\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+---------------+-------------------+\n", "|Region_Category| MAPE|\n", "+---------------+-------------------+\n", "| BCBOth| 1.738077138878954|\n", "| BCHol|0.22474998575413727|\n", "| BDEAll| 0.4502227513056252|\n", "| CBDAll|0.25029198603546415|\n", "| CCBAll|0.18626793448841952|\n", "| CCOth| 1.1703051474729502|\n", "| DCCAll| 0.8223672548904787|\n", "| DDBHol| 0.5296737451503019|\n", "| EABVis| 0.1905100109796238|\n", "| FBAVis| 1.7974654807514898|\n", "| ADBAll| 0.2552742929017676|\n", "| BDFAll|0.38694702800250763|\n", "| CBCHol|0.49020465544431563|\n", "| FAAHol| 0.2335170689589696|\n", "| GABVis| 0.5897688001440905|\n", "| GBCAll| 1.162144118937472|\n", "| AEHol| 0.2092658276093707|\n", "| BDBAll| 0.6689890326410363|\n", "| BDCBus| 2.3347911934651786|\n", "| BEGAll| 0.4220375937751745|\n", "+---------------+-------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------------+\n", "| Overall_MAPE|\n", "+------------------+\n", "|0.9550717947568259|\n", "+------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "from pyspark.sql.functions import abs, mean, col\n", "\n", "# Join forecasted results with actual test data on date and Region_Category\n", "evaluation_sdf = forecast_sdf.join(\n", " test_sdf, on=[\"date\", \"Region_Category\"], how=\"inner\"\n", ").withColumn(\"APE\", abs((col(\"Forecast\") - col(\"Visitors\")) / col(\"Visitors\")))\n", "\n", "# Compute Mean Absolute Percentage Error (MAPE) for each Region_Category\n", "mape_sdf = evaluation_sdf.groupBy(\"Region_Category\").agg(mean(\"APE\").alias(\"MAPE\"))\n", "\n", "# Show MAPE results\n", "mape_sdf.show()\n", "# Save MAPE results if needed\n", "# mape_sdf.write.csv(os.path.expanduser(\"~/mape_results.csv\"), header=True, mode=\"overwrite\")\n", "\n", "# Compute overall mean MAPE across all Region_Category\n", "overall_mape = mape_sdf.agg(mean(\"MAPE\").alias(\"Overall_MAPE\"))\n", "\n", "# Show the result\n", "overall_mape.show()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " 
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Parent_GroupAAAHolAAAVisAAABusAAAOthAABHolAABVisAABBusAABOthABAHol...GBBBusGBBOthGBCHolGBCVisGBCBusGBCOthGBDHolGBDVisGBDBusGBDOth
0TotalAll1.01.01.01.01.01.01.01.01.0...1.01.01.01.01.01.01.01.01.01.0
1AAll1.01.01.01.01.01.01.01.01.0...0.00.00.00.00.00.00.00.00.00.0
2BAll0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
3CAll0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
4DAll0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.00.0
..................................................................
550GBCOth0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.01.00.00.00.00.0
551GBDHol0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.01.00.00.00.0
552GBDVis0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.01.00.00.0
553GBDBus0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.01.00.0
554GBDOth0.00.00.00.00.00.00.00.00.0...0.00.00.00.00.00.00.00.00.01.0
\n", "

555 rows × 305 columns

\n", "
" ], "text/plain": [ " Parent_Group AAAHol AAAVis AAABus AAAOth AABHol AABVis AABBus \\\n", "0 TotalAll 1.0 1.0 1.0 1.0 1.0 1.0 1.0 \n", "1 AAll 1.0 1.0 1.0 1.0 1.0 1.0 1.0 \n", "2 BAll 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "3 CAll 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "4 DAll 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... \n", "550 GBCOth 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "551 GBDHol 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "552 GBDVis 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "553 GBDBus 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "554 GBDOth 0.0 0.0 0.0 0.0 0.0 0.0 0.0 \n", "\n", " AABOth ABAHol ... GBBBus GBBOth GBCHol GBCVis GBCBus GBCOth \\\n", "0 1.0 1.0 ... 1.0 1.0 1.0 1.0 1.0 1.0 \n", "1 1.0 1.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "3 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "4 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", ".. ... ... ... ... ... ... ... ... ... \n", "550 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 1.0 \n", "551 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "552 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "553 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "554 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0 0.0 \n", "\n", " GBDHol GBDVis GBDBus GBDOth \n", "0 1.0 1.0 1.0 1.0 \n", "1 0.0 0.0 0.0 0.0 \n", "2 0.0 0.0 0.0 0.0 \n", "3 0.0 0.0 0.0 0.0 \n", "4 0.0 0.0 0.0 0.0 \n", ".. ... ... ... ... 
\n", "550 0.0 0.0 0.0 0.0 \n", "551 1.0 0.0 0.0 0.0 \n", "552 0.0 1.0 0.0 0.0 \n", "553 0.0 0.0 1.0 0.0 \n", "554 0.0 0.0 0.0 1.0 \n", "\n", "[555 rows x 305 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from pyspark.sql.functions import col\n", "\n", "# Load the summing matrix file\n", "summing_matrix_path = \"../data/tourism/agg_mat.csv\" # Update with actual path\n", "\n", "# Load the summing matrix file (skip the first column)\n", "summing_sdf = spark.read.csv(summing_matrix_path, header=True, inferSchema=True)\n", "\n", "# Show the cleaned summing matrix\n", "summing_sdf.toPandas()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "+------------+---------------+------+\n", "|Parent_Group|Region_Category|Weight|\n", "+------------+---------------+------+\n", "| TotalAll| AAAHol| 1.0|\n", "| TotalAll| AAAVis| 1.0|\n", "| TotalAll| AAABus| 1.0|\n", "| TotalAll| AAAOth| 1.0|\n", "| TotalAll| AABHol| 1.0|\n", "| TotalAll| AABVis| 1.0|\n", "| TotalAll| AABBus| 1.0|\n", "| TotalAll| AABOth| 1.0|\n", "| TotalAll| ABAHol| 1.0|\n", "| TotalAll| ABAVis| 1.0|\n", "| TotalAll| ABABus| 1.0|\n", "| TotalAll| ABAOth| 1.0|\n", "| TotalAll| ABBHol| 1.0|\n", "| TotalAll| ABBVis| 1.0|\n", "| TotalAll| ABBBus| 1.0|\n", "| TotalAll| ABBOth| 1.0|\n", "| TotalAll| ACAHol| 1.0|\n", "| TotalAll| ACAVis| 1.0|\n", "| TotalAll| ACABus| 1.0|\n", "| TotalAll| ACAOth| 1.0|\n", "+------------+---------------+------+\n", "only showing top 20 rows\n", "\n" ] } ], "source": [ "from pyspark.sql.functions import expr\n", "\n", "# Convert from wide format to long format (Region_Category, Parent_Group, Weight)\n", "summing_sdf_long = summing_sdf.selectExpr(\n", " \"Parent_Group\",\n", " \"stack(\" + str(len(summing_sdf.columns) - 1) + \", \" +\n", " \", \".join([f\"'{col}', {col}\" for col in summing_sdf.columns if col 
!= \"Parent_Group\"]) +\n", " \") as (Region_Category, Weight)\"\n", ")\n", "\n", "# Show the reshaped summing matrix\n", "summing_sdf_long.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------+-------------------+\n", "| date|Reconciled_Forecast|\n", "+----------+-------------------+\n", "|2016-07-01| 218192.70227874012|\n", "|2016-08-01| 197075.67915486306|\n", "|2016-09-01| 204392.50642857671|\n", "|2016-12-01| 198470.65308329317|\n", "|2016-10-01| 230859.00134236045|\n", "|2016-04-01| 227945.82289731244|\n", "|2016-01-01| 369232.9770653847|\n", "|2016-02-01| 167778.9789411411|\n", "|2016-06-01| 181376.75686524587|\n", "|2016-11-01| 199288.84515902033|\n", "|2016-05-01| 182015.74284797342|\n", "|2016-03-01| 192429.9184313547|\n", "+----------+-------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "# Bottom-up reconciliation: aggregate base forecasts through the summing matrix\n", "from pyspark.sql.functions import sum as spark_sum, col\n", "# Join forecast data with the long-format summing matrix\n", "forecast_with_hierarchy_sdf = forecast_sdf.join(\n", " summing_sdf_long, on=\"Region_Category\", how=\"inner\"\n", ")\n", "\n", "# Sum the weighted forecasts for each date (over all parent groups)\n", "reconciled_forecast_sdf = forecast_with_hierarchy_sdf.groupBy(\"date\").agg(\n", " spark_sum(col(\"Forecast\") * col(\"Weight\")).alias(\"Reconciled_Forecast\")\n", ")\n", "\n", "# Show reconciled forecasts\n", "reconciled_forecast_sdf.show()\n", "\n", "# Save reconciled forecasts if needed\n", "# reconciled_forecast_sdf.write.csv(os.path.expanduser(\"~/reconciled_forecast.csv\"), header=True, mode=\"overwrite\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+----------+---------------+------------------+----------+------------+-------------------+\n", "| date|Region_Category| Visitors| date|Parent_Group|Reconciled_Forecast|\n", "+----------+---------------+------------------+----------+------------+-------------------+\n", "|2016-01-01| DCBHol| 102.8895483|2016-01-01| DCBHol| 121.67840269674329|\n", "|2016-01-01| BBAHol| 1022.153257|2016-01-01| BBAHol| 1212.5607176306241|\n", "|2016-01-01| AFAHol| 207.9422174|2016-01-01| AFAHol| 239.56228231588986|\n", "|2016-01-01| ABAVis| 354.3176582|2016-01-01| ABAVis| 446.078158109554|\n", "|2016-01-01| ABAHol| 657.0045098|2016-01-01| ABAHol| 753.4011722631797|\n", "|2016-01-01| GAVis|161.90854199999998|2016-01-01| GAVis| 90.84358101715947|\n", "|2016-01-01| BEHol| 561.7473459|2016-01-01| BEHol| 509.9969906760645|\n", "|2016-01-01| AABus| 248.5930062|2016-01-01| AABus| 327.61548927681565|\n", "|2016-01-01| DABAll|26.582756099999997|2016-01-01| DABAll| 53.421767314549896|\n", "|2016-01-01| GAll| 608.4583232|2016-01-01| GAll| 369.71107428838764|\n", "|2016-02-01| GAABus| 204.0029381|2016-02-01| GAABus| 86.35571813879201|\n", "|2016-02-01| GAAVis| 20.835924|2016-02-01| GAAVis| 39.30977548594424|\n", "|2016-02-01| EBABus| 442.0968368|2016-02-01| EBABus| 968.5113794609761|\n", 
"|2016-02-01| DBAOth| 0.8202871|2016-02-01| DBAOth| 4.793413619092407|\n", "|2016-02-01| CDBHol| 21.2006668|2016-02-01| CDBHol| 10.445053695416235|\n", "|2016-02-01| CCABus| 0.0|2016-02-01| CCABus| 5.003942284998754|\n", "|2016-02-01| AECVis| 18.6097499|2016-02-01| AECVis| 13.334188823600508|\n", "|2016-02-01| ADBBus| 21.5952651|2016-02-01| ADBBus| 61.34592282892598|\n", "|2016-02-01| CCVis|263.77205649999996|2016-02-01| CCVis| 178.80184295335545|\n", "|2016-02-01| BCAAll|126.06354850000001|2016-02-01| BCAAll| 149.55966641842832|\n", "+----------+---------------+------------------+----------+------------+-------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "# Map test data to Parent_Group using summing matrix\n", "test_with_hierarchy_sdf = test_sdf.join(\n", " summing_sdf_long, on=\"Region_Category\", how=\"inner\"\n", ")\n", "\n", "# Aggregate forecasts at Parent_Group level\n", "reconciled_forecast_sdf = forecast_with_hierarchy_sdf.groupBy(\"date\", \"Parent_Group\").agg(\n", " spark_sum(col(\"Forecast\") * col(\"Weight\")).alias(\"Reconciled_Forecast\")\n", ")\n", "\n", "# Merge test data with reconciled forecasts \n", "evaluation_sdf = test_sdf.join(\n", " reconciled_forecast_sdf,(test_sdf[\"date\"] == reconciled_forecast_sdf[\"date\"]) &\n", " (test_sdf[\"Region_Category\"] == reconciled_forecast_sdf[\"Parent_Group\"]), how=\"inner\")\n", "\n", "evaluation_sdf.show()\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+-------------------+\n", "|Parent_Group| MAPE|\n", "+------------+-------------------+\n", "| CBDAll| 0.270830763473807|\n", "| BCHol|0.21598580491185304|\n", "| BCBOth| 1.7380771388789542|\n", "| CCBAll| 0.1908010944069485|\n", "| DDBHol| 0.5296737451503021|\n", "| CCOth| 1.4105528463717814|\n", "| FBAVis| 1.7974654807514896|\n", "| EABVis|0.19051001097962375|\n", "| DCCAll| 0.7543353523263773|\n", "| BDEAll| 0.4395809834344526|\n", "| GABVis| 0.5897688001440906|\n", "| CBCHol| 0.4902046554443155|\n", "| ADBAll|0.22560078150472027|\n", "| FAAHol|0.23351706895896954|\n", "| BDFAll| 0.2650950034661002|\n", "| GBCAll| 0.8098570176724825|\n", "| CDBHol| 0.4344523647200556|\n", "| BEGAll| 0.4589742227798342|\n", "| DABus| 0.3129358963433057|\n", "| BDCBus| 2.3347911934651786|\n", "+------------+-------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. 
Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------------+\n", "| Overall_MAPE|\n", "+------------------+\n", "|0.9492457033222294|\n", "+------------------+\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "from pyspark.sql.functions import abs, mean # pyspark's abs replaces the Python builtin in this cell\n", "\n", "# Compute Absolute Percentage Error (APE)\n", "evaluation_sdf = evaluation_sdf.withColumn(\n", " \"APE\", abs((col(\"Reconciled_Forecast\") - col(\"Visitors\")) / col(\"Visitors\"))\n", ")\n", "\n", "# Compute MAPE for each Parent_Group\n", "mape_bu_sdf = evaluation_sdf.groupBy(\"Parent_Group\").agg(mean(\"APE\").alias(\"MAPE\"))\n", "\n", "# Show MAPE results\n", "mape_bu_sdf.show()\n", "\n", "# Compute the overall mean MAPE across all Parent_Group series\n", "overall_mape_bu = mape_bu_sdf.agg(mean(\"MAPE\").alias(\"Overall_MAPE\"))\n", "\n", "# Show the result\n", "overall_mape_bu.show()\n", "\n", "# Save results if needed\n", "# mape_bu_sdf.write.csv(os.path.expanduser(\"~/mape_bottom_up.csv\"), header=True, mode=\"overwrite\")\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ 
"/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n", "/nfs-share/home/2406184221/.local/lib/python3.12/site-packages/statsmodels/tsa/holtwinters/model.py:918: ConvergenceWarning: Optimization failed to converge. Check mle_retvals.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+------------+--------------------+\n", "|Parent_Group| MAPE|\n", "+------------+--------------------+\n", "| GBDOth| 0.02415337882585393|\n", "| TotalAll|0.043976383262471615|\n", "| BAll| 0.05390238179742062|\n", "| TotalVis| 0.05723817716280238|\n", "| AAll| 0.06447498916483488|\n", "| BAAll| 0.06607394487054989|\n", "| AVis| 0.06702550342938024|\n", "| BAHol| 0.07941507636947337|\n", "| DAAVis| 0.08147907508340964|\n", "| BAAAll| 0.0816780773066375|\n", "| CAll| 0.08317278641917357|\n", "| BVis| 0.09021519344686489|\n", "| DAVis| 0.09089992803688406|\n", "| EABAll| 0.09109946225903111|\n", "| BHol| 0.09915622264317571|\n", "| CAAll| 0.09932009843257024|\n", "| BAAHol| 0.10309288791756062|\n", "| AAAll| 0.10404957643124245|\n", "| TotalHol| 0.10639324287391223|\n", "| ABVis| 0.1081168121711044|\n", "+------------+--------------------+\n", "only showing top 20 rows\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "mape_bu_sdf.orderBy(\"MAPE\").show()" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## 
Further reading\n", "\n", "- Forecasting hierarchical and grouped time series: https://otexts.com/fpp3/hierarchical.html\n", "- Athanasopoulos, G., Ahmed, R. A., & Hyndman, R. J. (2009). Hierarchical forecasts for Australian domestic tourism. International Journal of Forecasting, 25, 146–166. https://doi.org/10.1016/j.ijforecast.2008.07.004" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Lab\n", "\n", "- Why is the bottom-up method not as accurate as the base forecasts?\n", "- Use a forecasting method other than ETS for the reconciliation." ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "python3.12", "language": "python", "name": "python3.12" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 4 }