{ "cells": [ { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Forecasting with Gradient Boosted Tree\n", "\n", "## Feng Li\n", "\n", "### Guanghua School of Management\n", "### Peking University\n", "\n", "\n", "### [feng.li@gsm.pku.edu.cn](feng.li@gsm.pku.edu.cn)\n", "### Course home page: [https://feng.li/bdcf](https://feng.li/bdcf)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Gradient Boosted Tree Regressor\n", "\n", "`GBTRegressor` (**Gradient Boosted Tree Regressor**) in **Spark MLlib** is a **supervised learning algorithm** that builds an ensemble of decision trees using **gradient boosting** to improve predictive accuracy.\n", "\n", "\n", "### How GBTRegressor Works\n", "Gradient Boosted Trees (GBT) work by **training decision trees sequentially**, where:\n", "1. The **first tree** makes an initial prediction.\n", "2. Each subsequent tree learns from the **errors (residuals)** of the previous trees.\n", "3. The final prediction is the sum of all trees’ outputs.\n", "\n", "This technique is effective for **handling non-linear relationships** in data and reducing **bias and variance**." ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Code Example\n", "```python\n", "from pyspark.ml.regression import GBTRegressor\n", "from pyspark.ml.evaluation import RegressionEvaluator\n", "\n", "# Initialize GBT Regressor\n", "gbt = GBTRegressor(featuresCol=\"features\", labelCol=\"sales\", maxIter=50, maxDepth=5, stepSize=0.1)\n", "\n", "# Train the model\n", "model = gbt.fit(train_df)\n", "\n", "# Make predictions\n", "predictions = model.transform(test_df)\n", "\n", "# Evaluate using RMSE\n", "evaluator = RegressionEvaluator(labelCol=\"sales\", predictionCol=\"prediction\", metricName=\"rmse\")\n", "rmse = evaluator.evaluate(predictions)\n", "\n", "print(f\"Root Mean Squared Error (RMSE): {rmse}\")\n", "```" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Key Hyperparameters\n", "| Parameter | Description |\n", "|-----------------|-------------|\n", "| `maxIter` | Number of trees in the ensemble (higher = more complex model) |\n", "| `maxDepth` | Maximum depth of each tree (higher = risk of overfitting) |\n", "| `stepSize` | Learning rate (default `0.1` for stability) |\n", "| `subsamplingRate` | Fraction of data used for each tree (default `1.0`, full dataset) |\n", "| `maxBins` | Number of bins for feature discretization (default `32`) |\n", "| `minInstancesPerNode` | Minimum instances required per node (default `1`) |\n", "\n", "### Recommended Settings\n", "- **For small datasets** → `maxIter=20, maxDepth=3`\n", "- **For large datasets** → `maxIter=50, maxDepth=5`\n", "- **For fine-tuning** → Adjust `stepSize` (`0.05 - 0.2`)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "\n", "## Advantages and Limitations\n", "\n", "✅ **Handles complex non-linear relationships** \n", "✅ **More accurate than a single Decision Tree** \n", "✅ **Built-in feature selection** (important features contribute more) \n", "✅ **Works well with missing values** \n", "\n", "\n", "🚨 **Slower training compared to Random Forest** (sequential training of trees) \n", "🚨 **Prone to overfitting with large `maxDepth`** \n", "🚨 **Not suited for real-time applications** (expensive to update) \n", "\n", "---" ] }, { "cell_type": "code", "execution_count": 1, 
"metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Setting default log level to \"WARN\".\n", "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n", "25/03/11 20:02:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n" ] }, { "data": { "text/html": [ "\n", "
\n", "

SparkSession - in-memory

\n", " \n", "
\n", "

SparkContext

\n", "\n", "

Spark UI

\n", "\n", "
\n", "
Version
\n", "
v3.5.3
\n", "
Master
\n", "
local[*]
\n", "
AppName
\n", "
Optimized Spark
\n", "
\n", "
\n", " \n", "
\n", " " ], "text/plain": [ "" ] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import os, sys # Ensure All environment variables are properly set \n", "# os.environ[\"JAVA_HOME\"] = os.path.dirname(sys.executable)\n", "os.environ[\"PYSPARK_PYTHON\"] = sys.executable\n", "os.environ[\"PYSPARK_DRIVER_PYTHON\"] = sys.executable\n", "\n", "from pyspark.sql import SparkSession # build Spark Session\n", "spark = SparkSession.builder \\\n", " .config(\"spark.ui.enabled\", \"false\") \\\n", " .config(\"spark.executor.memory\", \"16g\") \\\n", " .config(\"spark.executor.cores\", \"4\") \\\n", " .config(\"spark.cores.max\", \"32\") \\\n", " .config(\"spark.driver.memory\", \"30g\") \\\n", " .config(\"spark.sql.shuffle.partitions\", \"96\") \\\n", " .config(\"spark.memory.fraction\", \"0.8\") \\\n", " .config(\"spark.memory.storageFraction\", \"0.5\") \\\n", " .config(\"spark.dynamicAllocation.enabled\", \"true\") \\\n", " .config(\"spark.dynamicAllocation.minExecutors\", \"4\") \\\n", " .config(\"spark.dynamicAllocation.maxExecutors\", \"8\") \\\n", " .appName(\"Optimized Spark\").getOrCreate()\n", "spark" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Rule of Thumb for Optimized Spark Configuration\n", "\n", "\n", "- `.config(\"spark.sql.shuffle.partitions\", \"32 * 3\")`: 3x number of cores\n", "- `.config(\"spark.dynamicAllocation.enabled\", \"true\")`: Dynamic Allocation\n", "- `.config(\"spark.memory.fraction\", \"0.8\")`: 80% memory for execution\n", "- `.config(\"spark.memory.storageFraction\", \"0.5\")`: # 50% for storage (cached data)" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " \r" ] } ], "source": [ "# Load Data\n", "train_df = spark.read.csv(\"../data/m5-forecasting-accuracy/sales_train_evaluation.csv\", header=True, inferSchema=True)\n", "prices_df = spark.read.csv(\"../data//m5-forecasting-accuracy/sell_prices.csv\", header=True, inferSchema=True)\n", "calendar_df = spark.read.csv(\"../data//m5-forecasting-accuracy/calendar.csv\", header=True, inferSchema=True)\n", "\n", "TARGET = 'sales' # Our main target\n", "END_TRAIN = 1913+28 # Last day in train set\n", "MAIN_INDEX = ['id','d'] # We can identify item by these columns" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/11 20:02:32 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. 
This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n", " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Train rows: 30490 59181090\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Stage 15:==========> (12 + 38) / 66]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Total rows in grid_df: 60034810\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 15:========================> (29 + 37) / 66]\r", "\r", " \r" ] } ], "source": [ "from pyspark.sql.functions import col, lit, expr\n", "from pyspark.sql.types import StringType\n", "import numpy as np\n", "# calendar_df.printSchema()\n", "\n", "# Define index columns\n", "index_columns = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']\n", "\n", "# Melt train_df using explode\n", "train_df_long = train_df.selectExpr(\n", " *index_columns, \n", " \"stack(\" + str(len(train_df.columns) - len(index_columns)) + \n", " \"\".join([f\", '{col}', {col}\" for col in train_df.columns if col not in index_columns]) + \") as (d, sales)\"\n", ")\n", "\n", "# Convert \"d\" column format to match Pandas melt\n", "train_df_long = train_df_long.withColumn(\"d\", expr(\"substring(d, 3, length(d)-2)\"))\n", "\n", "# Count rows\n", "print(\"Train rows:\", train_df.count(), train_df_long.count())\n", "\n", "# Create \"test set\" grid for future dates\n", "from pyspark.sql.functions import monotonically_increasing_id\n", "\n", "add_grid = train_df.select(*index_columns).dropDuplicates()\n", "add_grid = add_grid.crossJoin(\n", " spark.createDataFrame([(f\"d_{END_TRAIN+i}\", np.nan) for i in range(1, 29)], [\"d\", TARGET])\n", ")\n", "\n", "# Combine train and test sets\n", "grid_df = train_df_long.union(add_grid)\n", "\n", "# Show memory usage estimate (PySpark does not have direct memory usage functions)\n", "print(f\"Total rows in grid_df: {grid_df.count()}\")" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 27:====================================================> (45 + 3) / 48]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- d: string (nullable = true)\n", " |-- store_id: string (nullable = true)\n", " |-- item_id: string (nullable = true)\n", " |-- id: long (nullable = false)\n", " |-- dept_id: string (nullable = true)\n", " |-- cat_id: string (nullable = true)\n", " |-- state_id: string (nullable = true)\n", " |-- sales: double (nullable = true)\n", " |-- release: integer (nullable = true)\n", " |-- wm_yr_wk: integer (nullable = true)\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "from pyspark.sql import functions as F\n", "from pyspark.sql.types import IntegerType\n", "\n", "# Group by store_id and item_id to get the earliest (min) wm_yr_wk (release week)\n", "release_df = prices_df.groupBy(\"store_id\", \"item_id\").agg(F.min(\"wm_yr_wk\").alias(\"release\"))\n", "\n", "# Merge release_df with grid_df\n", "grid_df = grid_df.join(release_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", "\n", "# Remove release_df to free memory\n", "del release_df\n", "\n", "# Merge with calendar_df to get wm_yr_wk column\n", "calendar_df = calendar_df.withColumn(\"d\", expr(\"substring(d, 3, length(d)-2)\"))\n", "grid_df = grid_df.join(calendar_df.select(\"wm_yr_wk\", \"d\"), on=\"d\", how=\"left\")\n", "\n", "# Remove rows where wm_yr_wk is earlier 
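than the release week (the first week the item appears in sell_prices for that store)\n", "# Note: a NULL comparison is never true in Spark SQL, so rows whose wm_yr_wk found no calendar match are dropped here as well\n", "# Filter: keep wm_yr_wk no earlier 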
than release\n", "grid_df = grid_df.filter(F.col(\"wm_yr_wk\") >= F.col(\"release\"))\n", "\n", "# Reset index equivalent (not needed in PySpark, but ensuring ordering)\n", "grid_df = grid_df.withColumn(\"id\", F.monotonically_increasing_id())\n", "\n", "# Minify the release values\n", "min_release = grid_df.agg(F.min(\"release\")).collect()[0][0] # Get minimum release week\n", "grid_df = grid_df.withColumn(\"release\", (F.col(\"release\") - min_release).cast(IntegerType()))\n", "\n", "# Show the transformed grid_df schema and a few rows\n", "grid_df.printSchema()\n", "# grid_df.show(5)\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from pyspark.sql import functions as F\n", "from pyspark.sql.window import Window\n", "\n", "# Define window partitioning with ORDER BY for sequential computations\n", "store_item_window = Window.partitionBy(\"store_id\", \"item_id\").orderBy(\"wm_yr_wk\")\n", "store_item_month_window = Window.partitionBy(\"store_id\", \"item_id\", \"month\").orderBy(\"wm_yr_wk\")\n", "store_item_year_window = Window.partitionBy(\"store_id\", \"item_id\", \"year\").orderBy(\"wm_yr_wk\")\n", "\n", "# Compute basic aggregations\n", "prices_df = prices_df.withColumn(\"price_max\", F.max(\"sell_price\").over(Window.partitionBy(\"store_id\", \"item_id\")))\n", "prices_df = prices_df.withColumn(\"price_min\", F.min(\"sell_price\").over(Window.partitionBy(\"store_id\", \"item_id\")))\n", "prices_df = prices_df.withColumn(\"price_std\", F.stddev(\"sell_price\").over(Window.partitionBy(\"store_id\", \"item_id\")))\n", "prices_df = prices_df.withColumn(\"price_mean\", F.mean(\"sell_price\").over(Window.partitionBy(\"store_id\", \"item_id\")))\n", "\n", "# Normalize prices (min-max scaling)\n", "prices_df = prices_df.withColumn(\"price_norm\", F.col(\"sell_price\") / F.col(\"price_max\"))\n", "\n", "# Compute distinct counts separately (fix for DISTINCT not allowed in window functions)\n", "price_nunique_df = prices_df.groupBy(\"store_id\", \"item_id\").agg(F.countDistinct(\"sell_price\").alias(\"price_nunique\"))\n", "item_nunique_df = prices_df.groupBy(\"store_id\", \"sell_price\").agg(F.countDistinct(\"item_id\").alias(\"item_nunique\"))\n", "\n", "# Join distinct count results back to prices_df\n", "prices_df = prices_df.join(price_nunique_df, on=[\"store_id\", \"item_id\"], how=\"left\")\n", "prices_df = prices_df.join(item_nunique_df, on=[\"store_id\", \"sell_price\"], how=\"left\")\n", "\n", "# Fix: Select only necessary columns from calendar_df to avoid ambiguity\n", "calendar_prices = calendar_df.select(\n", " F.col(\"wm_yr_wk\"),\n", " F.col(\"month\").alias(\"calendar_month\"), # Renaming to avoid ambiguity\n", " F.col(\"year\").alias(\"calendar_year\")\n", ").dropDuplicates([\"wm_yr_wk\"])\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- wm_yr_wk: integer (nullable = true)\n", " |-- store_id: string (nullable = true)\n", " |-- sell_price: double (nullable = true)\n", " |-- item_id: string (nullable = true)\n", " |-- price_max: double (nullable = true)\n", " |-- price_min: double (nullable = true)\n", " |-- price_std: double (nullable = true)\n", " |-- price_mean: double (nullable = true)\n", " |-- price_norm: double (nullable = true)\n", " |-- price_nunique: long (nullable = true)\n", " |-- item_nunique: long (nullable = true)\n", " |-- 
price_momentum: double (nullable = true)\n", " |-- price_momentum_m: double (nullable = true)\n", " |-- price_momentum_y: double (nullable = true)\n", "\n" ] } ], "source": [ "# Merge calendar information into prices_df\n", "prices_df = prices_df.join(calendar_prices, on=[\"wm_yr_wk\"], how=\"left\")\n", "\n", "# Compute price momentum\n", "prices_df = prices_df.withColumn(\n", " \"price_momentum\",\n", " F.col(\"sell_price\") / F.lag(\"sell_price\", 1).over(store_item_window)\n", ")\n", "prices_df = prices_df.withColumn(\n", " \"price_momentum_m\",\n", " F.col(\"sell_price\") / F.mean(\"sell_price\").over(\n", " Window.partitionBy(\"store_id\", \"item_id\", \"calendar_month\").orderBy(\"wm_yr_wk\")\n", " )\n", ")\n", "prices_df = prices_df.withColumn(\n", " \"price_momentum_y\",\n", " F.col(\"sell_price\") / F.mean(\"sell_price\").over(\n", " Window.partitionBy(\"store_id\", \"item_id\", \"calendar_year\").orderBy(\"wm_yr_wk\")\n", " )\n", ")\n", "\n", "# Drop temporary columns\n", "prices_df = prices_df.drop(\"calendar_month\", \"calendar_year\")\n", "\n", "# Show schema and verify results\n", "prices_df.printSchema()\n", "# prices_df.show(5)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- store_id: string (nullable = true)\n", " |-- item_id: string (nullable = true)\n", " |-- wm_yr_wk: integer (nullable = true)\n", " |-- d: string (nullable = true)\n", " |-- id: long (nullable = false)\n", " |-- dept_id: string (nullable = true)\n", " |-- cat_id: string (nullable = true)\n", " |-- state_id: string (nullable = true)\n", " |-- sales: double (nullable = true)\n", " |-- release: integer (nullable = true)\n", " |-- sell_price: double (nullable = true)\n", " |-- price_max: double (nullable = true)\n", " |-- price_min: double (nullable = true)\n", " |-- price_std: double (nullable = true)\n", " |-- price_mean: double (nullable = true)\n", " |-- price_norm: double (nullable = true)\n", " |-- price_nunique: long (nullable = true)\n", " |-- item_nunique: long (nullable = true)\n", " |-- price_momentum: double (nullable = true)\n", " |-- price_momentum_m: double (nullable = true)\n", " |-- price_momentum_y: double (nullable = true)\n", "\n" ] } ], "source": [ "# Perform Left Join with prices_df\n", "grid_df = grid_df.join(prices_df, on=['store_id', 'item_id', 'wm_yr_wk'], how=\"left\")\n", "\n", "# We don't need prices_df anymore\n", "del prices_df\n", "\n", "# Show Schema and Sample Data\n", "grid_df.printSchema()\n", "# grid_df.show(10)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- d: string (nullable = true)\n", " |-- store_id: string (nullable = true)\n", " |-- item_id: string (nullable = true)\n", " |-- wm_yr_wk: integer (nullable = true)\n", " |-- id: long (nullable = false)\n", " |-- dept_id: string (nullable = true)\n", " |-- cat_id: string (nullable = true)\n", " |-- state_id: string (nullable = true)\n", " |-- sales: double (nullable = true)\n", " |-- release: integer (nullable = true)\n", " |-- sell_price: double (nullable = true)\n", " |-- price_max: double (nullable = true)\n", " |-- price_min: double (nullable = true)\n", " |-- price_std: double (nullable = true)\n", " |-- price_mean: double (nullable = true)\n", " |-- price_norm: double (nullable = true)\n", " |-- price_nunique: 
long (nullable = true)\n", " |-- item_nunique: long (nullable = true)\n", " |-- price_momentum: double (nullable = true)\n", " |-- price_momentum_m: double (nullable = true)\n", " |-- price_momentum_y: double (nullable = true)\n", " |-- date: date (nullable = true)\n", " |-- event_name_1: string (nullable = true)\n", " |-- event_type_1: string (nullable = true)\n", " |-- event_name_2: string (nullable = true)\n", " |-- event_type_2: string (nullable = true)\n", " |-- snap_CA: integer (nullable = true)\n", " |-- snap_TX: integer (nullable = true)\n", " |-- snap_WI: integer (nullable = true)\n", "\n" ] } ], "source": [ "from pyspark.sql import functions as F\n", "from pyspark.sql.types import IntegerType, BooleanType\n", "from math import ceil\n", "\n", "icols = ['date', 'd', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', \n", " 'snap_CA', 'snap_TX', 'snap_WI']\n", "\n", "grid_df = grid_df.join(calendar_df.select(*icols), on=['d'], how=\"left\")\n", "\n", "grid_df.printSchema()\n", "# grid_df.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "[Stage 43:==============================================> (40 + 8) / 48]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "root\n", " |-- d: string (nullable = true)\n", " |-- store_id: string (nullable = true)\n", " |-- item_id: string (nullable = true)\n", " |-- id: long (nullable = false)\n", " |-- dept_id: string (nullable = true)\n", " |-- cat_id: string (nullable = true)\n", " |-- state_id: string (nullable = true)\n", " |-- sales: double (nullable = true)\n", " |-- release: integer (nullable = true)\n", " |-- sell_price: double (nullable = true)\n", " |-- price_max: double (nullable = true)\n", " |-- price_min: double (nullable = true)\n", " |-- price_std: double (nullable = true)\n", " |-- price_mean: double (nullable = true)\n", " |-- price_norm: double (nullable = true)\n", " |-- price_nunique: long (nullable = true)\n", " |-- item_nunique: long (nullable = true)\n", " |-- price_momentum: double (nullable = true)\n", " |-- price_momentum_m: double (nullable = true)\n", " |-- price_momentum_y: double (nullable = true)\n", " |-- date: date (nullable = true)\n", " |-- event_name_1: string (nullable = true)\n", " |-- event_type_1: string (nullable = true)\n", " |-- event_name_2: string (nullable = true)\n", " |-- event_type_2: string (nullable = true)\n", " |-- snap_CA: integer (nullable = true)\n", " |-- snap_TX: integer (nullable = true)\n", " |-- snap_WI: integer (nullable = true)\n", " |-- tm_d: integer (nullable = true)\n", " |-- tm_w: integer (nullable = true)\n", " |-- tm_m: integer (nullable = true)\n", " |-- tm_y: integer (nullable = true)\n", " |-- tm_wm: double (nullable = true)\n", " |-- tm_dw: integer (nullable = true)\n", " |-- tm_w_end: integer (nullable = true)\n", "\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", " \r" ] } ], "source": [ "# Extract Date Features\n", "grid_df = grid_df.withColumn(\"tm_d\", F.dayofmonth(\"date\"))\n", "grid_df = grid_df.withColumn(\"tm_w\", F.weekofyear(\"date\"))\n", "grid_df = grid_df.withColumn(\"tm_m\", F.month(\"date\"))\n", "grid_df = grid_df.withColumn(\"tm_y\", F.year(\"date\"))\n", "\n", "# Normalize `tm_y` (Subtract Minimum Year)\n", "min_year = grid_df.agg(F.min(\"tm_y\")).collect()[0][0]\n", "grid_df = grid_df.withColumn(\"tm_y\", (F.col(\"tm_y\") - min_year))\n", "\n", "# Compute `tm_wm` (Week of 
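Month) from the day of month\n", "# Note: the expression below is kept as a double rather than an exact integer ceiling, which is fine for tree-based models\n", "# tm_wm approximates ceil(tm_d / 7) (Week of 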
Month)\n", "grid_df = grid_df.withColumn(\"tm_wm\", (F.col(\"tm_d\") / 7 + 0.99)) # ceil(x/7)\n", "\n", "# Compute `tm_dw` (Day of Week) and `tm_w_end` (Weekend Indicator)\n", "grid_df = grid_df.withColumn(\"tm_dw\", F.dayofweek(\"date\") - 1) # Adjust to start from Monday=0\n", "grid_df = grid_df.withColumn(\"tm_w_end\", (F.col(\"tm_dw\") >= 5).cast(IntegerType()))\n", "\n", "# Drop `wm_yr_wk` Column\n", "grid_df = grid_df.drop(\"wm_yr_wk\")\n", "\n", "# how Schema & Sample Data\n", "grid_df.printSchema()\n", "# grid_df.show(10)" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Save your prepared features " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/11 20:03:42 WARN DAGScheduler: Broadcasting large task binary with size 1088.4 KiB\n", " \r" ] } ], "source": [ "import os\n", "grid_df.write.mode(\"overwrite\").csv(os.path.expanduser(\"~/m5_features\"))" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "slideshow": { "slide_type": "slide" } }, "outputs": [], "source": [ "from pyspark.ml.feature import VectorAssembler, StringIndexer\n", "from pyspark.ml.regression import GBTRegressor\n", "from pyspark.ml.evaluation import RegressionEvaluator\n", "from pyspark.sql import functions as F\n", "from pyspark.sql.types import DoubleType, IntegerType\n", "\n", "# Feature Selection (Exclude categorical & string columns)\n", "# FEATURES = [\n", "# \"release\", \"sell_price\", \"price_max\", \"price_min\", \"price_std\", \"price_mean\",\n", "# \"price_norm\", \"price_nunique\", \"item_nunique\", \"price_momentum\", \"price_momentum_m\",\n", "# \"price_momentum_y\", \"tm_d\", \"tm_w\", \"tm_m\", \"tm_y\", \"tm_wm\", \"tm_dw\", \"tm_w_end\",\n", "# \"snap_CA\", \"snap_TX\", \"snap_WI\"\n", "#]\n", "\n", "FEATURES = [\n", " \"release\", \"sell_price\", \"price_max\", \"price_min\", \"price_std\", \"price_mean\"\n", "]\n", "\n", "TARGET = \"sales\"\n", "\n", "# Convert sales to DoubleType (required for GBTRegressor)\n", "grid_df = grid_df.withColumn(TARGET, F.col(TARGET).cast(DoubleType()))\n", "\n", "# Replace NULL in values\n", "grid_df = grid_df.na.fill(0)\n", "\n", "# Assemble feature columns into a single 'features' vector\n", "vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol=\"features\")\n", "grid_df = vector_assembler.transform(grid_df)\n", "\n", "# Train-Test Split\n", "# Use 'd' to split data (adjust threshold as needed)\n", "train_df = grid_df.filter(F.col(\"d\") < 1914) # Training Data\n", "test_df = grid_df.filter(F.col(\"d\") >= 1914) # Test/Validation Data" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Checking Features" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "scrolled": false, "slideshow": { "slide_type": "fragment" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Checking for NULL values in feature columns:\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "+-------+----------+---------+---------+---------+----------+\n", "|release|sell_price|price_max|price_min|price_std|price_mean|\n", "+-------+----------+---------+---------+---------+----------+\n", "| 0| 0| 0| 0| 0| 0|\n", "+-------+----------+---------+---------+---------+----------+\n", "\n" ] }, { "name": "stderr", 
"output_type": "stream", "text": [ " \r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Training Data Count: 46027957\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "[Stage 129:==========> (9 + 36) / 48]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Test Data Count: 853720\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 129:==============================================> (41 + 7) / 48]\r", "\r", " \r" ] } ], "source": [ "from pyspark.sql.functions import col, count, when, lit\n", "from pyspark.ml.feature import VectorAssembler\n", "from pyspark.ml.regression import GBTRegressor\n", "\n", "# Step 1: Check which feature columns contain NULL values\n", "print(\"Checking for NULL values in feature columns:\")\n", "train_df.select([count(when(col(c).isNull(), 1)).alias(c) for c in FEATURES]).show()\n", "\n", "# Step 2: Replace NULL values in numeric feature columns with 0\n", "for col_name in FEATURES:\n", " train_df = train_df.withColumn(col_name, when(col(col_name).isNull(), lit(0)).otherwise(col(col_name)))\n", "\n", "# Step 3: Drop NULL values in target column `sales`\n", "train_df = train_df.dropna(subset=[\"sales\"])\n", "\n", "# Step 4: Drop existing `features` column if it exists\n", "if \"features\" in train_df.columns:\n", " train_df = train_df.drop(\"features\")\n", "\n", "# Step 5: Recreate `VectorAssembler` with `handleInvalid=\"skip\"`\n", "vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol=\"features\", handleInvalid=\"skip\")\n", "train_df = vector_assembler.transform(train_df).select(\"features\", \"sales\")\n", "\n", "print(f\"Training Data Count: {train_df.count()}\")\n", "print(f\"Test Data Count: {test_df.count()}\")\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "scrolled": false, "slideshow": { "slide_type": "slide" } }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "25/03/11 20:08:32 WARN DAGScheduler: Broadcasting large task binary with size 1000.7 KiB\n", "25/03/11 20:08:33 WARN DAGScheduler: Broadcasting large task binary with size 1003.3 KiB\n", "25/03/11 20:08:37 WARN DAGScheduler: Broadcasting large task binary with size 1003.8 KiB\n", "25/03/11 20:08:38 WARN DAGScheduler: Broadcasting large task binary with size 1004.4 KiB\n", "25/03/11 20:08:39 WARN DAGScheduler: Broadcasting large task binary with size 1005.5 KiB\n", "25/03/11 20:08:39 WARN DAGScheduler: Broadcasting large task binary with size 1007.8 KiB\n", "25/03/11 20:08:40 WARN DAGScheduler: Broadcasting large task binary with size 1010.5 KiB\n", "25/03/11 20:08:45 WARN DAGScheduler: Broadcasting large task binary with size 1010.9 KiB\n", "25/03/11 20:08:46 WARN DAGScheduler: Broadcasting large task binary with size 1011.5 KiB\n", "25/03/11 20:08:46 WARN DAGScheduler: Broadcasting large task binary with size 1012.6 KiB\n", "25/03/11 20:08:47 WARN DAGScheduler: Broadcasting large task binary with size 1014.7 KiB\n", "25/03/11 20:08:48 WARN DAGScheduler: Broadcasting large task binary with size 1017.1 KiB\n", "25/03/11 20:08:52 WARN DAGScheduler: Broadcasting large task binary with size 1017.6 KiB\n", "25/03/11 20:08:53 WARN DAGScheduler: Broadcasting large task binary with size 1018.2 KiB\n", "25/03/11 20:08:54 WARN DAGScheduler: Broadcasting large task binary with size 1019.3 KiB\n", "25/03/11 20:08:54 WARN DAGScheduler: Broadcasting large task binary with size 1021.6 KiB\n", "25/03/11 20:08:55 WARN DAGScheduler: Broadcasting large task binary with size 1024.2 KiB\n", 
"25/03/11 20:09:00 WARN DAGScheduler: Broadcasting large task binary with size 1024.7 KiB\n", "25/03/11 20:09:01 WARN DAGScheduler: Broadcasting large task binary with size 1025.3 KiB\n", "25/03/11 20:09:01 WARN DAGScheduler: Broadcasting large task binary with size 1026.4 KiB\n", "25/03/11 20:09:02 WARN DAGScheduler: Broadcasting large task binary with size 1028.7 KiB\n", "25/03/11 20:09:03 WARN DAGScheduler: Broadcasting large task binary with size 1031.4 KiB\n", "25/03/11 20:09:08 WARN DAGScheduler: Broadcasting large task binary with size 1031.8 KiB\n", "25/03/11 20:09:08 WARN DAGScheduler: Broadcasting large task binary with size 1032.4 KiB\n", "25/03/11 20:09:09 WARN DAGScheduler: Broadcasting large task binary with size 1033.6 KiB\n", "25/03/11 20:09:09 WARN DAGScheduler: Broadcasting large task binary with size 1035.8 KiB\n", "25/03/11 20:09:10 WARN DAGScheduler: Broadcasting large task binary with size 1038.5 KiB\n", "25/03/11 20:09:15 WARN DAGScheduler: Broadcasting large task binary with size 1039.0 KiB\n", "25/03/11 20:09:16 WARN DAGScheduler: Broadcasting large task binary with size 1039.5 KiB\n", "25/03/11 20:09:16 WARN DAGScheduler: Broadcasting large task binary with size 1040.7 KiB\n", "25/03/11 20:09:17 WARN DAGScheduler: Broadcasting large task binary with size 1043.0 KiB\n", "25/03/11 20:09:18 WARN DAGScheduler: Broadcasting large task binary with size 1045.6 KiB\n", "25/03/11 20:09:22 WARN DAGScheduler: Broadcasting large task binary with size 1046.1 KiB\n", "25/03/11 20:09:23 WARN DAGScheduler: Broadcasting large task binary with size 1046.7 KiB\n", "25/03/11 20:09:24 WARN DAGScheduler: Broadcasting large task binary with size 1047.8 KiB\n", "25/03/11 20:09:24 WARN DAGScheduler: Broadcasting large task binary with size 1050.1 KiB\n", "25/03/11 20:09:25 WARN DAGScheduler: Broadcasting large task binary with size 1052.8 KiB\n", "25/03/11 20:09:30 WARN DAGScheduler: Broadcasting large task binary with size 1053.2 KiB\n", "25/03/11 20:09:31 WARN DAGScheduler: Broadcasting large task binary with size 1053.8 KiB\n", "25/03/11 20:09:31 WARN DAGScheduler: Broadcasting large task binary with size 1054.9 KiB\n", "25/03/11 20:09:32 WARN DAGScheduler: Broadcasting large task binary with size 1057.2 KiB\n", "25/03/11 20:09:32 WARN DAGScheduler: Broadcasting large task binary with size 1059.9 KiB\n", "25/03/11 20:09:37 WARN DAGScheduler: Broadcasting large task binary with size 1060.4 KiB\n", "25/03/11 20:09:38 WARN DAGScheduler: Broadcasting large task binary with size 1060.9 KiB\n", "25/03/11 20:09:38 WARN DAGScheduler: Broadcasting large task binary with size 1062.1 KiB\n", "25/03/11 20:09:39 WARN DAGScheduler: Broadcasting large task binary with size 1064.3 KiB\n", "25/03/11 20:09:40 WARN DAGScheduler: Broadcasting large task binary with size 1067.0 KiB\n", "25/03/11 20:09:45 WARN DAGScheduler: Broadcasting large task binary with size 1067.5 KiB\n", "25/03/11 20:09:45 WARN DAGScheduler: Broadcasting large task binary with size 1068.1 KiB\n", "25/03/11 20:09:46 WARN DAGScheduler: Broadcasting large task binary with size 1069.2 KiB\n", "25/03/11 20:09:47 WARN DAGScheduler: Broadcasting large task binary with size 1071.5 KiB\n", "25/03/11 20:09:47 WARN DAGScheduler: Broadcasting large task binary with size 1074.2 KiB\n", "25/03/11 20:09:52 WARN DAGScheduler: Broadcasting large task binary with size 1074.6 KiB\n", "25/03/11 20:09:53 WARN DAGScheduler: Broadcasting large task binary with size 1075.2 KiB\n", "25/03/11 20:09:53 WARN DAGScheduler: Broadcasting large task binary with 
size 1076.3 KiB\n", "[... further DAGScheduler warnings about broadcasting large task binaries omitted ...]\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "25/03/11 20:11:42 WARN DAGScheduler: Broadcasting large task binary with 
size 1181.4 KiB\n", "25/03/11 20:11:43 WARN DAGScheduler: Broadcasting large task binary with size 1183.6 KiB\n", "25/03/11 20:11:44 WARN DAGScheduler: Broadcasting large task binary with size 1186.3 KiB\n", "25/03/11 20:11:48 WARN DAGScheduler: Broadcasting large task binary with size 1186.8 KiB\n", "25/03/11 20:11:48 WARN DAGScheduler: Broadcasting large task binary with size 1187.3 KiB\n", "25/03/11 20:11:49 WARN DAGScheduler: Broadcasting large task binary with size 1188.5 KiB\n", "25/03/11 20:11:50 WARN DAGScheduler: Broadcasting large task binary with size 1190.8 KiB\n", "WARNING: An illegal reflective access operation has occurred\n", "WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/nfs-share/software/anaconda/2020.02/envs/python3.12/lib/python3.12/site-packages/pyspark/jars/spark-core_2.12-3.5.3.jar) to field java.nio.charset.Charset.name\n", "WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$\n", "WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations\n", "WARNING: All illegal access operations will be denied in a future release\n", "25/03/11 20:12:11 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS\n", "[Stage 1436:===========================================> (39 + 9) / 48]\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Root Mean Squared Error (RMSE): 3.285564654183651\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\r", "[Stage 1436:===================================================> (46 + 2) / 48]\r", "\r", " \r" ] } ], "source": [ "# Train GBT Model\n", "gbt = GBTRegressor(featuresCol=\"features\", labelCol=\"sales\", maxIter=50, maxDepth=5, stepSize=0.1)\n", "model = gbt.fit(train_df)\n", "\n", "# Make Predictions\n", "predictions = model.transform(test_df)\n", "\n", "# Evaluate Model Performance\n", "evaluator = RegressionEvaluator(labelCol=TARGET, predictionCol=\"prediction\", metricName=\"rmse\")\n", "rmse = evaluator.evaluate(predictions)\n", "\n", "print(f\"Root Mean Squared Error (RMSE): {rmse}\")" ] }, { "cell_type": "markdown", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "## Save Model and Load for Future Predictions\n", "\n", "```python\n", "model.write().overwrite().save(\"m5_gbt_forecasting_model\")\n", "predictions.select(\"sales\", \"prediction\").write.mode(\"overwrite\").parquet(\"m5_gbt_predictions.parquet\")\n", "\n", "# Load Model for Future Predictions\n", "from pyspark.ml.regression import GBTRegressionModel\n", "loaded_model = GBTRegressionModel.load(\"m5_gbt_forecasting_model\")\n", "\n", "# Make new predictions with the loaded model\n", "new_predictions = loaded_model.transform(test_df)\n", "new_predictions.show(10)\n", "```\n" ] } ], "metadata": { "celltoolbar": "Slideshow", "kernelspec": { "display_name": "python3.12", "language": "python", "name": "python3.12" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 4 }