Forecasting with Gradient Boosted Tree¶

Feng Li¶

Guanghua School of Management¶

Peking University¶

feng.li@gsm.pku.edu.cn¶

Course home page: https://feng.li/bdcf¶

Gradient Boosted Tree Regressor¶

GBTRegressor (Gradient Boosted Tree Regressor) in Spark MLlib is a supervised learning algorithm that builds an ensemble of decision trees using gradient boosting to improve predictive accuracy.

How GBTRegressor Works¶

Gradient Boosted Trees (GBT) work by training decision trees sequentially, where:

  1. The first tree makes an initial prediction.
  2. Each subsequent tree learns from the errors (residuals) of the previous trees.
  3. The final prediction is the sum of all trees’ outputs.

This technique is effective for handling non-linear relationships in data and reducing bias and variance.

Code Example¶

from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator

# Initialize GBT Regressor
gbt = GBTRegressor(featuresCol="features", labelCol="sales", maxIter=50, maxDepth=5, stepSize=0.1)

# Train the model
model = gbt.fit(train_df)

# Make predictions
predictions = model.transform(test_df)

# Evaluate using RMSE
evaluator = RegressionEvaluator(labelCol="sales", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)

print(f"Root Mean Squared Error (RMSE): {rmse}")

Key Hyperparameters¶

Parameter Description
maxIter Number of trees in the ensemble (higher = more complex model)
maxDepth Maximum depth of each tree (higher = risk of overfitting)
stepSize Learning rate (default 0.1 for stability)
subsamplingRate Fraction of data used for each tree (default 1.0, full dataset)
maxBins Number of bins for feature discretization (default 32)
minInstancesPerNode Minimum instances required per node (default 1)

Recommended Settings¶

  • For small datasets → maxIter=20, maxDepth=3
  • For large datasets → maxIter=50, maxDepth=5
  • For fine-tuning → Adjust stepSize (0.05 - 0.2)

Advantages and Limitations¶

✅ Handles complex non-linear relationships
✅ More accurate than a single Decision Tree
✅ Built-in feature selection (important features contribute more)
✅ Works well with missing values

🚨 Slower training compared to Random Forest (sequential training of trees)
🚨 Prone to overfitting with large maxDepth
🚨 Not suited for real-time applications (expensive to update)


In [1]:
import os, sys # Ensure All environment variables are properly set 
# os.environ["JAVA_HOME"] = os.path.dirname(sys.executable)
os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

from pyspark.sql import SparkSession # build Spark Session
spark = SparkSession.builder \
    .config("spark.ui.enabled", "false") \
    .config("spark.executor.memory", "16g") \
    .config("spark.executor.cores", "4") \
    .config("spark.cores.max", "32") \
    .config("spark.driver.memory", "30g") \
    .config("spark.sql.shuffle.partitions", "96") \
    .config("spark.memory.fraction", "0.8") \
    .config("spark.memory.storageFraction", "0.5") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .config("spark.dynamicAllocation.minExecutors", "4") \
    .config("spark.dynamicAllocation.maxExecutors", "8") \
    .appName("Optimized Spark").getOrCreate()
spark
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/03/11 20:02:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Out[1]:

SparkSession - in-memory

SparkContext

Spark UI

Version
v3.5.3
Master
local[*]
AppName
Optimized Spark

Rule of Thumb for Optimized Spark Configuration¶

  • .config("spark.sql.shuffle.partitions", "32 * 3"): 3x number of cores
  • .config("spark.dynamicAllocation.enabled", "true"): Dynamic Allocation
  • .config("spark.memory.fraction", "0.8"): 80% memory for execution
  • .config("spark.memory.storageFraction", "0.5"): # 50% for storage (cached data)
In [2]:
# Load Data
train_df = spark.read.csv("../data/m5-forecasting-accuracy/sales_train_evaluation.csv", header=True, inferSchema=True)
prices_df = spark.read.csv("../data//m5-forecasting-accuracy/sell_prices.csv", header=True, inferSchema=True)
calendar_df = spark.read.csv("../data//m5-forecasting-accuracy/calendar.csv", header=True, inferSchema=True)

TARGET = 'sales'          # Our main target
END_TRAIN = 1913+28       # Last day in train set
MAIN_INDEX = ['id','d']   # We can identify item by these columns
                                                                                
In [3]:
from pyspark.sql.functions import col, lit, expr
from pyspark.sql.types import StringType
import numpy as np
# calendar_df.printSchema()

# Define index columns
index_columns = ['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']

# Melt train_df using explode
train_df_long = train_df.selectExpr(
    *index_columns, 
    "stack(" + str(len(train_df.columns) - len(index_columns)) + 
    "".join([f", '{col}', {col}" for col in train_df.columns if col not in index_columns]) + ") as (d, sales)"
)

# Convert "d" column format to match Pandas melt
train_df_long = train_df_long.withColumn("d", expr("substring(d, 3, length(d)-2)"))

# Count rows
print("Train rows:", train_df.count(), train_df_long.count())

# Create "test set" grid for future dates
from pyspark.sql.functions import monotonically_increasing_id

add_grid = train_df.select(*index_columns).dropDuplicates()
add_grid = add_grid.crossJoin(
    spark.createDataFrame([(f"d_{END_TRAIN+i}", np.nan) for i in range(1, 29)], ["d", TARGET])
)

# Combine train and test sets
grid_df = train_df_long.union(add_grid)

# Show memory usage estimate (PySpark does not have direct memory usage functions)
print(f"Total rows in grid_df: {grid_df.count()}")
25/03/11 20:02:32 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                
Train rows: 30490 59181090
[Stage 15:==========>                                            (12 + 38) / 66]
Total rows in grid_df: 60034810
[Stage 15:========================>                              (29 + 37) / 66]

                                                                                
In [4]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType

# Group by store_id and item_id to get the earliest (min) wm_yr_wk (release week)
release_df = prices_df.groupBy("store_id", "item_id").agg(F.min("wm_yr_wk").alias("release"))

# Merge release_df with grid_df
grid_df = grid_df.join(release_df, on=["store_id", "item_id"], how="left")

# Remove release_df to free memory
del release_df

# Merge with calendar_df to get wm_yr_wk column
calendar_df = calendar_df.withColumn("d", expr("substring(d, 3, length(d)-2)"))
grid_df = grid_df.join(calendar_df.select("wm_yr_wk", "d"), on="d", how="left")

# Remove rows where wm_yr_wk is earlier than release
grid_df = grid_df.filter(F.col("wm_yr_wk") >= F.col("release"))

# Reset index equivalent (not needed in PySpark, but ensuring ordering)
grid_df = grid_df.withColumn("id", F.monotonically_increasing_id())

# Minify the release values
min_release = grid_df.agg(F.min("release")).collect()[0][0]  # Get minimum release week
grid_df = grid_df.withColumn("release", (F.col("release") - min_release).cast(IntegerType()))

# Show the transformed grid_df schema and a few rows
grid_df.printSchema()
# grid_df.show(5)
[Stage 27:====================================================>   (45 + 3) / 48]
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)

                                                                                
In [5]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Define window partitioning with ORDER BY for sequential computations
store_item_window = Window.partitionBy("store_id", "item_id").orderBy("wm_yr_wk")
store_item_month_window = Window.partitionBy("store_id", "item_id", "month").orderBy("wm_yr_wk")
store_item_year_window = Window.partitionBy("store_id", "item_id", "year").orderBy("wm_yr_wk")

# Compute basic aggregations
prices_df = prices_df.withColumn("price_max", F.max("sell_price").over(Window.partitionBy("store_id", "item_id")))
prices_df = prices_df.withColumn("price_min", F.min("sell_price").over(Window.partitionBy("store_id", "item_id")))
prices_df = prices_df.withColumn("price_std", F.stddev("sell_price").over(Window.partitionBy("store_id", "item_id")))
prices_df = prices_df.withColumn("price_mean", F.mean("sell_price").over(Window.partitionBy("store_id", "item_id")))

# Normalize prices (min-max scaling)
prices_df = prices_df.withColumn("price_norm", F.col("sell_price") / F.col("price_max"))

# Compute distinct counts separately (fix for DISTINCT not allowed in window functions)
price_nunique_df = prices_df.groupBy("store_id", "item_id").agg(F.countDistinct("sell_price").alias("price_nunique"))
item_nunique_df = prices_df.groupBy("store_id", "sell_price").agg(F.countDistinct("item_id").alias("item_nunique"))

# Join distinct count results back to prices_df
prices_df = prices_df.join(price_nunique_df, on=["store_id", "item_id"], how="left")
prices_df = prices_df.join(item_nunique_df, on=["store_id", "sell_price"], how="left")

# Fix: Select only necessary columns from calendar_df to avoid ambiguity
calendar_prices = calendar_df.select(
    F.col("wm_yr_wk"),
    F.col("month").alias("calendar_month"),  # Renaming to avoid ambiguity
    F.col("year").alias("calendar_year")
).dropDuplicates(["wm_yr_wk"])
In [6]:
# Merge calendar information into prices_df
prices_df = prices_df.join(calendar_prices, on=["wm_yr_wk"], how="left")

# Compute price momentum
prices_df = prices_df.withColumn(
    "price_momentum",
    F.col("sell_price") / F.lag("sell_price", 1).over(store_item_window)
)
prices_df = prices_df.withColumn(
    "price_momentum_m",
    F.col("sell_price") / F.mean("sell_price").over(
        Window.partitionBy("store_id", "item_id", "calendar_month").orderBy("wm_yr_wk")
    )
)
prices_df = prices_df.withColumn(
    "price_momentum_y",
    F.col("sell_price") / F.mean("sell_price").over(
        Window.partitionBy("store_id", "item_id", "calendar_year").orderBy("wm_yr_wk")
    )
)

# Drop temporary columns
prices_df = prices_df.drop("calendar_month", "calendar_year")

# Show schema and verify results
prices_df.printSchema()
# prices_df.show(5)
root
 |-- wm_yr_wk: integer (nullable = true)
 |-- store_id: string (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- item_id: string (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)

In [7]:
# Perform Left Join with prices_df
grid_df = grid_df.join(prices_df, on=['store_id', 'item_id', 'wm_yr_wk'], how="left")

# We don't need prices_df anymore
del prices_df

# Show Schema and Sample Data
grid_df.printSchema()
# grid_df.show(10)
root
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)
 |-- d: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)

In [8]:
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, BooleanType
from math import ceil

icols = ['date', 'd', 'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2', 
         'snap_CA', 'snap_TX', 'snap_WI']

grid_df = grid_df.join(calendar_df.select(*icols), on=['d'], how="left")

grid_df.printSchema()
# grid_df.show()
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- wm_yr_wk: integer (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
 |-- date: date (nullable = true)
 |-- event_name_1: string (nullable = true)
 |-- event_type_1: string (nullable = true)
 |-- event_name_2: string (nullable = true)
 |-- event_type_2: string (nullable = true)
 |-- snap_CA: integer (nullable = true)
 |-- snap_TX: integer (nullable = true)
 |-- snap_WI: integer (nullable = true)

In [9]:
# Extract Date Features
grid_df = grid_df.withColumn("tm_d", F.dayofmonth("date"))
grid_df = grid_df.withColumn("tm_w", F.weekofyear("date"))
grid_df = grid_df.withColumn("tm_m", F.month("date"))
grid_df = grid_df.withColumn("tm_y", F.year("date"))

# Normalize `tm_y` (Subtract Minimum Year)
min_year = grid_df.agg(F.min("tm_y")).collect()[0][0]
grid_df = grid_df.withColumn("tm_y", (F.col("tm_y") - min_year))

# Compute `tm_wm` (Week of Month)
grid_df = grid_df.withColumn("tm_wm", (F.col("tm_d") / 7 + 0.99))  # ceil(x/7)

# Compute `tm_dw` (Day of Week) and `tm_w_end` (Weekend Indicator)
grid_df = grid_df.withColumn("tm_dw", F.dayofweek("date") - 1)  # Adjust to start from Monday=0
grid_df = grid_df.withColumn("tm_w_end", (F.col("tm_dw") >= 5).cast(IntegerType()))

# Drop `wm_yr_wk` Column
grid_df = grid_df.drop("wm_yr_wk")

# how Schema & Sample Data
grid_df.printSchema()
# grid_df.show(10)
[Stage 43:==============================================>         (40 + 8) / 48]
root
 |-- d: string (nullable = true)
 |-- store_id: string (nullable = true)
 |-- item_id: string (nullable = true)
 |-- id: long (nullable = false)
 |-- dept_id: string (nullable = true)
 |-- cat_id: string (nullable = true)
 |-- state_id: string (nullable = true)
 |-- sales: double (nullable = true)
 |-- release: integer (nullable = true)
 |-- sell_price: double (nullable = true)
 |-- price_max: double (nullable = true)
 |-- price_min: double (nullable = true)
 |-- price_std: double (nullable = true)
 |-- price_mean: double (nullable = true)
 |-- price_norm: double (nullable = true)
 |-- price_nunique: long (nullable = true)
 |-- item_nunique: long (nullable = true)
 |-- price_momentum: double (nullable = true)
 |-- price_momentum_m: double (nullable = true)
 |-- price_momentum_y: double (nullable = true)
 |-- date: date (nullable = true)
 |-- event_name_1: string (nullable = true)
 |-- event_type_1: string (nullable = true)
 |-- event_name_2: string (nullable = true)
 |-- event_type_2: string (nullable = true)
 |-- snap_CA: integer (nullable = true)
 |-- snap_TX: integer (nullable = true)
 |-- snap_WI: integer (nullable = true)
 |-- tm_d: integer (nullable = true)
 |-- tm_w: integer (nullable = true)
 |-- tm_m: integer (nullable = true)
 |-- tm_y: integer (nullable = true)
 |-- tm_wm: double (nullable = true)
 |-- tm_dw: integer (nullable = true)
 |-- tm_w_end: integer (nullable = true)

                                                                                

Save your prepared features¶

In [10]:
import os
grid_df.write.mode("overwrite").csv(os.path.expanduser("~/m5_features"))
25/03/11 20:03:42 WARN DAGScheduler: Broadcasting large task binary with size 1088.4 KiB
                                                                                
In [11]:
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml.regression import GBTRegressor
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType

# Feature Selection (Exclude categorical & string columns)
# FEATURES = [
#    "release", "sell_price", "price_max", "price_min", "price_std", "price_mean",
#    "price_norm", "price_nunique", "item_nunique", "price_momentum", "price_momentum_m",
#    "price_momentum_y", "tm_d", "tm_w", "tm_m", "tm_y", "tm_wm", "tm_dw", "tm_w_end",
#    "snap_CA", "snap_TX", "snap_WI"
#]

FEATURES = [
    "release", "sell_price", "price_max", "price_min", "price_std", "price_mean"
]

TARGET = "sales"

# Convert sales to DoubleType (required for GBTRegressor)
grid_df = grid_df.withColumn(TARGET, F.col(TARGET).cast(DoubleType()))

# Replace NULL in values
grid_df = grid_df.na.fill(0)

# Assemble feature columns into a single 'features' vector
vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol="features")
grid_df = vector_assembler.transform(grid_df)

# Train-Test Split
# Use 'd' to split data (adjust threshold as needed)
train_df = grid_df.filter(F.col("d") < 1914)  # Training Data
test_df = grid_df.filter(F.col("d") >= 1914)  # Test/Validation Data

Checking Features¶

In [12]:
from pyspark.sql.functions import col, count, when, lit
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import GBTRegressor

# Step 1: Check which feature columns contain NULL values
print("Checking for NULL values in feature columns:")
train_df.select([count(when(col(c).isNull(), 1)).alias(c) for c in FEATURES]).show()

# Step 2: Replace NULL values in numeric feature columns with 0
for col_name in FEATURES:
    train_df = train_df.withColumn(col_name, when(col(col_name).isNull(), lit(0)).otherwise(col(col_name)))

# Step 3: Drop NULL values in target column `sales`
train_df = train_df.dropna(subset=["sales"])

# Step 4: Drop existing `features` column if it exists
if "features" in train_df.columns:
    train_df = train_df.drop("features")

# Step 5: Recreate `VectorAssembler` with `handleInvalid="skip"`
vector_assembler = VectorAssembler(inputCols=FEATURES, outputCol="features", handleInvalid="skip")
train_df = vector_assembler.transform(train_df).select("features", "sales")

print(f"Training Data Count: {train_df.count()}")
print(f"Test Data Count: {test_df.count()}")
Checking for NULL values in feature columns:
                                                                                
+-------+----------+---------+---------+---------+----------+
|release|sell_price|price_max|price_min|price_std|price_mean|
+-------+----------+---------+---------+---------+----------+
|      0|         0|        0|        0|        0|         0|
+-------+----------+---------+---------+---------+----------+

                                                                                
Training Data Count: 46027957
[Stage 129:==========>                                            (9 + 36) / 48]
Test Data Count: 853720
[Stage 129:==============================================>        (41 + 7) / 48]

                                                                                
In [13]:
# Train GBT Model
gbt = GBTRegressor(featuresCol="features", labelCol="sales", maxIter=50, maxDepth=5, stepSize=0.1)
model = gbt.fit(train_df)

# Make Predictions
predictions = model.transform(test_df)

# Evaluate Model Performance
evaluator = RegressionEvaluator(labelCol=TARGET, predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)

print(f"Root Mean Squared Error (RMSE): {rmse}")
25/03/11 20:08:32 WARN DAGScheduler: Broadcasting large task binary with size 1000.7 KiB
25/03/11 20:08:33 WARN DAGScheduler: Broadcasting large task binary with size 1003.3 KiB
25/03/11 20:08:37 WARN DAGScheduler: Broadcasting large task binary with size 1003.8 KiB
25/03/11 20:08:38 WARN DAGScheduler: Broadcasting large task binary with size 1004.4 KiB
25/03/11 20:08:39 WARN DAGScheduler: Broadcasting large task binary with size 1005.5 KiB
25/03/11 20:08:39 WARN DAGScheduler: Broadcasting large task binary with size 1007.8 KiB
25/03/11 20:08:40 WARN DAGScheduler: Broadcasting large task binary with size 1010.5 KiB
25/03/11 20:08:45 WARN DAGScheduler: Broadcasting large task binary with size 1010.9 KiB
25/03/11 20:08:46 WARN DAGScheduler: Broadcasting large task binary with size 1011.5 KiB
25/03/11 20:08:46 WARN DAGScheduler: Broadcasting large task binary with size 1012.6 KiB
25/03/11 20:08:47 WARN DAGScheduler: Broadcasting large task binary with size 1014.7 KiB
25/03/11 20:08:48 WARN DAGScheduler: Broadcasting large task binary with size 1017.1 KiB
25/03/11 20:08:52 WARN DAGScheduler: Broadcasting large task binary with size 1017.6 KiB
25/03/11 20:08:53 WARN DAGScheduler: Broadcasting large task binary with size 1018.2 KiB
25/03/11 20:08:54 WARN DAGScheduler: Broadcasting large task binary with size 1019.3 KiB
25/03/11 20:08:54 WARN DAGScheduler: Broadcasting large task binary with size 1021.6 KiB
25/03/11 20:08:55 WARN DAGScheduler: Broadcasting large task binary with size 1024.2 KiB
25/03/11 20:09:00 WARN DAGScheduler: Broadcasting large task binary with size 1024.7 KiB
25/03/11 20:09:01 WARN DAGScheduler: Broadcasting large task binary with size 1025.3 KiB
25/03/11 20:09:01 WARN DAGScheduler: Broadcasting large task binary with size 1026.4 KiB
25/03/11 20:09:02 WARN DAGScheduler: Broadcasting large task binary with size 1028.7 KiB
25/03/11 20:09:03 WARN DAGScheduler: Broadcasting large task binary with size 1031.4 KiB
25/03/11 20:09:08 WARN DAGScheduler: Broadcasting large task binary with size 1031.8 KiB
25/03/11 20:09:08 WARN DAGScheduler: Broadcasting large task binary with size 1032.4 KiB
25/03/11 20:09:09 WARN DAGScheduler: Broadcasting large task binary with size 1033.6 KiB
25/03/11 20:09:09 WARN DAGScheduler: Broadcasting large task binary with size 1035.8 KiB
25/03/11 20:09:10 WARN DAGScheduler: Broadcasting large task binary with size 1038.5 KiB
25/03/11 20:09:15 WARN DAGScheduler: Broadcasting large task binary with size 1039.0 KiB
25/03/11 20:09:16 WARN DAGScheduler: Broadcasting large task binary with size 1039.5 KiB
25/03/11 20:09:16 WARN DAGScheduler: Broadcasting large task binary with size 1040.7 KiB
25/03/11 20:09:17 WARN DAGScheduler: Broadcasting large task binary with size 1043.0 KiB
25/03/11 20:09:18 WARN DAGScheduler: Broadcasting large task binary with size 1045.6 KiB
25/03/11 20:09:22 WARN DAGScheduler: Broadcasting large task binary with size 1046.1 KiB
25/03/11 20:09:23 WARN DAGScheduler: Broadcasting large task binary with size 1046.7 KiB
25/03/11 20:09:24 WARN DAGScheduler: Broadcasting large task binary with size 1047.8 KiB
25/03/11 20:09:24 WARN DAGScheduler: Broadcasting large task binary with size 1050.1 KiB
25/03/11 20:09:25 WARN DAGScheduler: Broadcasting large task binary with size 1052.8 KiB
25/03/11 20:09:30 WARN DAGScheduler: Broadcasting large task binary with size 1053.2 KiB
25/03/11 20:09:31 WARN DAGScheduler: Broadcasting large task binary with size 1053.8 KiB
25/03/11 20:09:31 WARN DAGScheduler: Broadcasting large task binary with size 1054.9 KiB
25/03/11 20:09:32 WARN DAGScheduler: Broadcasting large task binary with size 1057.2 KiB
25/03/11 20:09:32 WARN DAGScheduler: Broadcasting large task binary with size 1059.9 KiB
25/03/11 20:09:37 WARN DAGScheduler: Broadcasting large task binary with size 1060.4 KiB
25/03/11 20:09:38 WARN DAGScheduler: Broadcasting large task binary with size 1060.9 KiB
25/03/11 20:09:38 WARN DAGScheduler: Broadcasting large task binary with size 1062.1 KiB
25/03/11 20:09:39 WARN DAGScheduler: Broadcasting large task binary with size 1064.3 KiB
25/03/11 20:09:40 WARN DAGScheduler: Broadcasting large task binary with size 1067.0 KiB
25/03/11 20:09:45 WARN DAGScheduler: Broadcasting large task binary with size 1067.5 KiB
25/03/11 20:09:45 WARN DAGScheduler: Broadcasting large task binary with size 1068.1 KiB
25/03/11 20:09:46 WARN DAGScheduler: Broadcasting large task binary with size 1069.2 KiB
25/03/11 20:09:47 WARN DAGScheduler: Broadcasting large task binary with size 1071.5 KiB
25/03/11 20:09:47 WARN DAGScheduler: Broadcasting large task binary with size 1074.2 KiB
25/03/11 20:09:52 WARN DAGScheduler: Broadcasting large task binary with size 1074.6 KiB
25/03/11 20:09:53 WARN DAGScheduler: Broadcasting large task binary with size 1075.2 KiB
25/03/11 20:09:53 WARN DAGScheduler: Broadcasting large task binary with size 1076.3 KiB
25/03/11 20:09:54 WARN DAGScheduler: Broadcasting large task binary with size 1078.6 KiB
25/03/11 20:09:55 WARN DAGScheduler: Broadcasting large task binary with size 1081.3 KiB
25/03/11 20:10:00 WARN DAGScheduler: Broadcasting large task binary with size 1081.8 KiB
25/03/11 20:10:00 WARN DAGScheduler: Broadcasting large task binary with size 1082.3 KiB
25/03/11 20:10:01 WARN DAGScheduler: Broadcasting large task binary with size 1083.5 KiB
25/03/11 20:10:02 WARN DAGScheduler: Broadcasting large task binary with size 1085.7 KiB
25/03/11 20:10:02 WARN DAGScheduler: Broadcasting large task binary with size 1088.4 KiB
25/03/11 20:10:07 WARN DAGScheduler: Broadcasting large task binary with size 1088.9 KiB
25/03/11 20:10:08 WARN DAGScheduler: Broadcasting large task binary with size 1089.5 KiB
25/03/11 20:10:08 WARN DAGScheduler: Broadcasting large task binary with size 1090.6 KiB
25/03/11 20:10:09 WARN DAGScheduler: Broadcasting large task binary with size 1092.9 KiB
25/03/11 20:10:10 WARN DAGScheduler: Broadcasting large task binary with size 1095.5 KiB
25/03/11 20:10:15 WARN DAGScheduler: Broadcasting large task binary with size 1096.0 KiB
25/03/11 20:10:15 WARN DAGScheduler: Broadcasting large task binary with size 1096.6 KiB
25/03/11 20:10:16 WARN DAGScheduler: Broadcasting large task binary with size 1097.7 KiB
25/03/11 20:10:16 WARN DAGScheduler: Broadcasting large task binary with size 1100.0 KiB
25/03/11 20:10:17 WARN DAGScheduler: Broadcasting large task binary with size 1102.4 KiB
25/03/11 20:10:22 WARN DAGScheduler: Broadcasting large task binary with size 1102.8 KiB
25/03/11 20:10:23 WARN DAGScheduler: Broadcasting large task binary with size 1103.4 KiB
25/03/11 20:10:23 WARN DAGScheduler: Broadcasting large task binary with size 1104.5 KiB
25/03/11 20:10:24 WARN DAGScheduler: Broadcasting large task binary with size 1106.5 KiB
25/03/11 20:10:25 WARN DAGScheduler: Broadcasting large task binary with size 1108.8 KiB
25/03/11 20:10:30 WARN DAGScheduler: Broadcasting large task binary with size 1109.3 KiB
25/03/11 20:10:30 WARN DAGScheduler: Broadcasting large task binary with size 1109.9 KiB
25/03/11 20:10:31 WARN DAGScheduler: Broadcasting large task binary with size 1111.0 KiB
25/03/11 20:10:31 WARN DAGScheduler: Broadcasting large task binary with size 1113.3 KiB
25/03/11 20:10:32 WARN DAGScheduler: Broadcasting large task binary with size 1116.0 KiB
25/03/11 20:10:37 WARN DAGScheduler: Broadcasting large task binary with size 1116.4 KiB
25/03/11 20:10:38 WARN DAGScheduler: Broadcasting large task binary with size 1117.0 KiB
25/03/11 20:10:38 WARN DAGScheduler: Broadcasting large task binary with size 1118.2 KiB
25/03/11 20:10:39 WARN DAGScheduler: Broadcasting large task binary with size 1120.4 KiB
25/03/11 20:10:40 WARN DAGScheduler: Broadcasting large task binary with size 1122.8 KiB
25/03/11 20:10:45 WARN DAGScheduler: Broadcasting large task binary with size 1123.3 KiB
25/03/11 20:10:45 WARN DAGScheduler: Broadcasting large task binary with size 1123.8 KiB
25/03/11 20:10:46 WARN DAGScheduler: Broadcasting large task binary with size 1125.0 KiB
25/03/11 20:10:46 WARN DAGScheduler: Broadcasting large task binary with size 1127.0 KiB
25/03/11 20:10:47 WARN DAGScheduler: Broadcasting large task binary with size 1129.3 KiB
25/03/11 20:10:52 WARN DAGScheduler: Broadcasting large task binary with size 1129.7 KiB
25/03/11 20:10:52 WARN DAGScheduler: Broadcasting large task binary with size 1130.3 KiB
25/03/11 20:10:53 WARN DAGScheduler: Broadcasting large task binary with size 1131.4 KiB
25/03/11 20:10:53 WARN DAGScheduler: Broadcasting large task binary with size 1133.7 KiB
25/03/11 20:10:54 WARN DAGScheduler: Broadcasting large task binary with size 1136.4 KiB
25/03/11 20:10:59 WARN DAGScheduler: Broadcasting large task binary with size 1136.9 KiB
25/03/11 20:10:59 WARN DAGScheduler: Broadcasting large task binary with size 1137.4 KiB
25/03/11 20:11:00 WARN DAGScheduler: Broadcasting large task binary with size 1138.6 KiB
25/03/11 20:11:00 WARN DAGScheduler: Broadcasting large task binary with size 1140.8 KiB
25/03/11 20:11:01 WARN DAGScheduler: Broadcasting large task binary with size 1143.5 KiB
25/03/11 20:11:06 WARN DAGScheduler: Broadcasting large task binary with size 1144.0 KiB
25/03/11 20:11:06 WARN DAGScheduler: Broadcasting large task binary with size 1144.6 KiB
25/03/11 20:11:07 WARN DAGScheduler: Broadcasting large task binary with size 1145.7 KiB
25/03/11 20:11:07 WARN DAGScheduler: Broadcasting large task binary with size 1148.0 KiB
25/03/11 20:11:08 WARN DAGScheduler: Broadcasting large task binary with size 1150.6 KiB
25/03/11 20:11:12 WARN DAGScheduler: Broadcasting large task binary with size 1151.1 KiB
25/03/11 20:11:13 WARN DAGScheduler: Broadcasting large task binary with size 1151.7 KiB
25/03/11 20:11:14 WARN DAGScheduler: Broadcasting large task binary with size 1152.8 KiB
25/03/11 20:11:14 WARN DAGScheduler: Broadcasting large task binary with size 1155.1 KiB
25/03/11 20:11:15 WARN DAGScheduler: Broadcasting large task binary with size 1157.8 KiB
25/03/11 20:11:20 WARN DAGScheduler: Broadcasting large task binary with size 1158.3 KiB
25/03/11 20:11:20 WARN DAGScheduler: Broadcasting large task binary with size 1158.8 KiB
25/03/11 20:11:21 WARN DAGScheduler: Broadcasting large task binary with size 1160.0 KiB
25/03/11 20:11:21 WARN DAGScheduler: Broadcasting large task binary with size 1162.2 KiB
25/03/11 20:11:22 WARN DAGScheduler: Broadcasting large task binary with size 1164.9 KiB
25/03/11 20:11:27 WARN DAGScheduler: Broadcasting large task binary with size 1165.4 KiB
25/03/11 20:11:27 WARN DAGScheduler: Broadcasting large task binary with size 1166.0 KiB
25/03/11 20:11:28 WARN DAGScheduler: Broadcasting large task binary with size 1167.1 KiB
25/03/11 20:11:29 WARN DAGScheduler: Broadcasting large task binary with size 1169.4 KiB
25/03/11 20:11:29 WARN DAGScheduler: Broadcasting large task binary with size 1172.0 KiB
25/03/11 20:11:34 WARN DAGScheduler: Broadcasting large task binary with size 1172.5 KiB
25/03/11 20:11:35 WARN DAGScheduler: Broadcasting large task binary with size 1173.1 KiB
25/03/11 20:11:35 WARN DAGScheduler: Broadcasting large task binary with size 1174.2 KiB
25/03/11 20:11:36 WARN DAGScheduler: Broadcasting large task binary with size 1176.5 KiB
25/03/11 20:11:37 WARN DAGScheduler: Broadcasting large task binary with size 1179.2 KiB
25/03/11 20:11:41 WARN DAGScheduler: Broadcasting large task binary with size 1179.7 KiB
25/03/11 20:11:41 WARN DAGScheduler: Broadcasting large task binary with size 1180.2 KiB
25/03/11 20:11:42 WARN DAGScheduler: Broadcasting large task binary with size 1181.4 KiB
25/03/11 20:11:43 WARN DAGScheduler: Broadcasting large task binary with size 1183.6 KiB
25/03/11 20:11:44 WARN DAGScheduler: Broadcasting large task binary with size 1186.3 KiB
25/03/11 20:11:48 WARN DAGScheduler: Broadcasting large task binary with size 1186.8 KiB
25/03/11 20:11:48 WARN DAGScheduler: Broadcasting large task binary with size 1187.3 KiB
25/03/11 20:11:49 WARN DAGScheduler: Broadcasting large task binary with size 1188.5 KiB
25/03/11 20:11:50 WARN DAGScheduler: Broadcasting large task binary with size 1190.8 KiB
WARNING: An illegal reflective access operation has occurred
WARNING: Illegal reflective access by org.apache.spark.util.SizeEstimator$ (file:/nfs-share/software/anaconda/2020.02/envs/python3.12/lib/python3.12/site-packages/pyspark/jars/spark-core_2.12-3.5.3.jar) to field java.nio.charset.Charset.name
WARNING: Please consider reporting this to the maintainers of org.apache.spark.util.SizeEstimator$
WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations
WARNING: All illegal access operations will be denied in a future release
25/03/11 20:12:11 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
[Stage 1436:===========================================>          (39 + 9) / 48]
Root Mean Squared Error (RMSE): 3.285564654183651
[Stage 1436:===================================================>  (46 + 2) / 48]

                                                                                

Save Model and Load for Future Predictions¶

model.write().overwrite().save("m5_gbt_forecasting_model")
predictions.select("sales", "prediction").write.mode("overwrite").parquet("m5_gbt_predictions.parquet")

# Load Model for Future Predictions
from pyspark.ml.regression import GBTRegressionModel
loaded_model = GBTRegressionModel.load("m5_gbt_forecasting_model")

# Make new predictions with the loaded model
new_predictions = loaded_model.transform(test_df)
new_predictions.show(10)

Further Reading¶

  • Januschowski, T., Wang, Y., Torkkola, K., Erkkilä, T., Hasson, H., & Gasthaus, J. (2022). Forecasting with trees. International Journal of Forecasting, 38(4), 1473-1481. https://doi.org/10.1016/j.ijforecast.2021.10.004