How Deep Are scikit-learn’s Histogram-based Gradient Boosted Trees?

🗻🚀🎄

Thomas J. Fan

This talk on GitHub: thomasjpfan/2020-richmond-ds-meetup-gradient-boosting

In Chat:

  • Beginner: Type 1
  • Intermediate: Type 2
  • Expert: Type 3

Supervised Learning 📖

$$ y = f(X) $$

  • $X$ of shape (n_samples, n_features)
  • $y$ of shape (n_samples,)

Scikit-learn API 🛠

from sklearn.experimental import enable_hist_gradient_boosting
from sklearn.ensemble import HistGradientBoostingClassifier

clf = HistGradientBoostingClassifier()

clf.fit(X, y)

clf.predict(X)
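
As a runnable sketch (make_classification is only a stand-in dataset for illustration; any tabular X, y works the same way):

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification

# Synthetic stand-in data: 1000 samples, 20 features
X, y = make_classification(n_samples=1_000, n_features=20, random_state=0)

clf = HistGradientBoostingClassifier()
clf.fit(X, y)
clf.predict(X[:5])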

Hist-Gradient-Boosting 🚀

Boosting 🚀

$$ f(X) = h_0(X) + h_1(X) + h_2(X) + … $$

$$ f(X) = \sum_i h_i(X) $$

Hist-Gradient-Boosting 🗻

Gradient 🗻

Regression

  • least_squares
  • least_absolute_deviation
  • poisson

Classification

  • binary_crossentropy
  • categorical_crossentropy
  • auto

Loss Function - least_squares

$$ L(y, f(X)) = \frac{1}{2}||y - f(X)||^2 $$

Gradient

$$ \nabla L(y, f(X)) = -(y - f(X)) $$

Hessian

$$ \nabla^2 L(y, f(X)) = 1 $$
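
A quick numpy check of these three formulas (a sketch; raw_predictions stands in for $f(X)$):

import numpy as np

y = np.array([3.0, 1.0, 2.0])
raw_predictions = np.array([2.5, 1.5, 2.0])  # f(X)

loss = 0.5 * np.sum((y - raw_predictions) ** 2)
gradients = -(y - raw_predictions)           # i.e. f(X) - y
hessians = np.ones_like(y)                   # constant for least_squares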

Gradient Boosting 🗻🚀

  • Initial Condition

$$ f_0(X) = C $$

  • Recursive Condition

$$ f_{m+1}(X) = f_{m}(X) - \eta \nabla L(y, f_{m}(X)) $$

where $\eta$ is the learning rate

Gradient Boosting 🏂 - least_squares

  • Plugging in the gradient for least_squares

$$ f_{m+1}(X) = f_{m}(X) + \eta(y - f_{m}(X)) $$

  • Letting $h_{m}(X)=(y - f_{m}(X))$

$$ f_{m+1}(X) = f_{m}(X) + \eta h_{m}(X) $$

  • We need to learn $h_{m}(X)$!
  • For the next example, let $\eta=1$
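
A minimal sketch of this recursion, fitting shallow regression trees to the residuals (DecisionTreeRegressor stands in for the histogram-based trees, the data is made up, and $\eta=1$ as in the example):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Made-up 1D regression data
X = np.linspace(0, 80, 200).reshape(-1, 1)
y = np.sin(X.ravel() / 10) * 20 + 78

eta = 1.0
f_m = np.full_like(y, y.mean())           # f_0(X) = C
trees = []
for m in range(2):                        # two boosting iterations
    residuals = y - f_m                   # targets for h_m (the negative gradient)
    tree = DecisionTreeRegressor(max_depth=2).fit(X, residuals)
    trees.append(tree)
    f_m = f_m + eta * tree.predict(X)     # f_{m+1} = f_m + eta * h_m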

Gradient Boosting 🏂 - (Example, part 1)

$$ f_0(X) = C $$

Gradient Boosting 🏂 - (Example, part 2)

Gradient Boosting 🏂 - (Example, part 3)

Gradient Boosting 🏂 - (Example, part 4)

$$ f_{m+1}(X) = f_{m}(X) + h_{m}(X) $$

Gradient Boosting 🏂 - (Example, part 5)

Gradient Boosting 🏂 - (Example, part 6)

Gradient Boosting 🏂 - (Example, part 7)

Gradient Boosting 🏂

With two iterations of boosting:

$$ f(X) = C + h_0(X) + h_1(X) $$

Prediction

For example, with $X=40$

$$ f(40) = 78 + h_0(40) + h_1(40) $$

How to learn $h_m(X)$?

🎄🎄🎄🎄🎄🎄🎄

Tree Growing 🌲

  1. For every feature
    1. Sort feature
    2. For every split point
      1. Evaluate split
  2. Pick best split

How to evaluate split?

least_squares

  • Recall Loss, Gradient, Hessian

$$ L(y, f(X)) = \frac{1}{2}||y - f(X)||^2 $$

$$ G = \nabla L(y, f(X)) = -(y - f(X)) $$

$$ H = \nabla^2 L(y, f(X)) = 1 $$

How to evaluate split?

Maximize the Gain!

$$ Gain = \dfrac{1}{2}\left[\dfrac{G_L^2}{H_L+\lambda} + \dfrac{G_R^2}{H_R + \lambda} - \dfrac{(G_L+G_R)^2}{H_L+H_R+\lambda}\right] $$

default $\lambda$: l2_regularization=0
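
As a sketch, the gain of one candidate split can be computed from the per-sample gradients and hessians and a boolean mask for the left partition (the names here are illustrative, not scikit-learn's internals):

import numpy as np

def split_gain(gradients, hessians, go_left, l2_regularization=0.0):
    # Sum gradients/hessians in each partition, then apply the gain formula
    G_L, H_L = gradients[go_left].sum(), hessians[go_left].sum()
    G_R, H_R = gradients[~go_left].sum(), hessians[~go_left].sum()
    lam = l2_regularization
    return 0.5 * (G_L ** 2 / (H_L + lam)
                  + G_R ** 2 / (H_R + lam)
                  - (G_L + G_R) ** 2 / (H_L + H_R + lam))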

Tree Growing 🎄

Are we done?

  1. For every feature
    1. Sort feature
    2. For every split point
      1. Evaluate split
  2. Pick best split

Tree Growing 🎄

Are we done?

  1. For every feature
    1. Sort feature - O(nlog(n))
    2. For every split point - O(n)
      1. Evaluate split
  2. Pick best split

Hist-GradientBoosting

Binning! 🗑

# Original data
[-0.752,  2.7042,  1.3919,  0.5091, -2.0636,
 -2.064, -2.6514,  2.1977,  0.6007,  1.2487, ...]

# Binned data
[4, 9, 7, 6, 2, 1, 0, 8, 6, 7, ...]
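
A rough sketch of the idea with quantile-based thresholds and np.searchsorted (scikit-learn's actual _BinMapper is more careful: it computes thresholds from quantiles of a subsample and reserves a bin for missing values):

import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=1_000)

n_bins = 10
# Thresholds at the inner quantile boundaries of the feature
binning_thresholds = np.quantile(data, np.linspace(0, 1, n_bins + 1)[1:-1])
# Index of the bin each value falls into, stored compactly as uint8
binned = np.searchsorted(binning_thresholds, data, side='left').astype(np.uint8)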

Histograms! 📊

Overview

  1. For every feature
    1. Build histogram O(n)
    2. For every split point - O(n_bins)
      1. Evaluate split
  2. Pick best split
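
Why this is fast: after binning, the per-bin gradient and hessian sums for a feature can be accumulated in a single O(n) pass, and only n_bins candidate splits are left to evaluate. A rough numpy sketch for one feature:

import numpy as np

def build_histogram(binned_feature, gradients, hessians, n_bins=256):
    # Per-bin sums of gradients and hessians, plus sample counts
    sum_gradients = np.bincount(binned_feature, weights=gradients, minlength=n_bins)
    sum_hessians = np.bincount(binned_feature, weights=hessians, minlength=n_bins)
    counts = np.bincount(binned_feature, minlength=n_bins)
    return sum_gradients, sum_hessians, counts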

One More Trick 🎩

Trees = $h_m(X)$ 🎄

$$ f(X) = C + \sum h_{m}(X) $$

Overview of Algorithm 👀

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits
    2. Add tree to predictors
    3. Update gradients and hessians

Implementation? 🤔

  • Pure Python?
  • Numpy?
  • Cython?
  • Cython + OpenMP!

OpenMP! Bin data 🗑

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

OpenMP! Bin data 🗑

# Sequential version: binary-search each value into its bin
for i in range(n_samples):
    left, right = 0, binning_thresholds.shape[0]
    while left < right:
        middle = left + (right - left - 1) // 2
        if data[i] <= binning_thresholds[middle]:
            right = middle
        else:
            left = middle + 1
    binned[i] = left

OpenMP! Bin data 🗑

# sklearn/ensemble/_hist_gradient_boosting/_binning.pyx
for i in prange(n_samples, schedule='static', nogil=True):
    left, right = 0, binning_thresholds.shape[0]
    while left < right:
        middle = left + (right - left - 1) // 2
        if data[i] <= binning_thresholds[middle]:
            right = middle
        else:
            left = middle + 1
    binned[i] = left

OpenMP! Building histograms 🌋

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

OpenMP! Building histograms 🌋

# sklearn/ensemble/_hist_gradient_boosting/histogram.pyx
with nogil:
    for feature_idx in prange(n_features, schedule='static'):
        self._compute_histogram_brute_single_feature(...)

for feature_idx in prange(n_features, schedule='static',
                          nogil=True):
    _subtract_histograms(feature_idx, ...)
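
The subtraction above is the histogram trick: a child's histogram is the parent's histogram minus its sibling's, so only the smaller child needs to be built from the samples. A rough sketch of the idea:

def subtract_histograms(hist_parent, hist_smaller_child):
    # Derive the sibling's histogram without touching the samples.
    # Each histogram here is a dict of per-bin arrays.
    return {key: hist_parent[key] - hist_smaller_child[key]
            for key in ('sum_gradients', 'sum_hessians', 'count')}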

OpenMP! Find best splits ✂️

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

OpenMP! Find best splits ✂️

# sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
for feature_idx in prange(n_features, schedule='static'):
    # For each feature, find best bin to split on

OpenMP! Splitting ✂️

# sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
for thread_idx in prange(n_threads, schedule='static',
                         chunksize=1):
    # splits a partition of node

OpenMP! Update gradients and hessians 🏔

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

OpenMP! Update gradients and hessians 🏔

least_squares

# sklearn/ensemble/_hist_gradient_boosting/_loss.pyx
for i in prange(n_samples, schedule='static', nogil=True):
    gradients[i] = raw_predictions[i] - y_true[i]

Hyper-parameters 📓

Hyper-parameters: Bin Data 🗑

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

Hyper-parameters: Bin Data 🗑

max_bins=255

Hyper-parameters: Loss 📉

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

Hyper-parameters: Loss 📉

  • HistGradientBoostingRegressor

    • loss=least_squares (default)
    • least_absolute_deviation
    • poisson
  • HistGradientBoostingClassifier

    • loss=auto (default)
    • binary_crossentropy
    • categorical_crossentropy
  • l2_regularization=0

Hyper-parameters: Boosting 🏂

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

Hyper-parameters: Boosting 🏂

  • learning_rate=0.1 ($\eta$)
  • max_iter=100

Hyper-parameters: Grow Trees 🎄

  1. Bin data
  2. Make initial predictions (constant)
  3. Calculate gradients and hessians
  4. Grow Trees For Boosting
    1. Find best splits by building histograms
    2. Add tree to predictors
    3. Update gradients and hessians

Hyper-parameters: Grow Trees 🎄

  • max_leaf_nodes=31
  • max_depth=None
  • min_samples_leaf=20

Hyper-parameters: Early Stopping 🛑

  1. Bin data
  2. Split into a validation dataset
  3. Make initial predictions (constant)
  4. Calculate gradients and hessians
  5. Grow Trees For Boosting
    1. Stop if early stop condition is true

Hyper-parameters: Early Stopping 🛑

  • early_stopping='auto' (enabled if n_samples>10_000)
  • scoring='loss'
  • validation_fraction=0.1
  • n_iter_no_change=10
  • tol=1e-7
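
A hedged usage sketch of these parameters (the dataset and values here are only illustrative):

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.datasets import make_classification

X, y = make_classification(n_samples=5_000, random_state=0)

clf = HistGradientBoostingClassifier(
    early_stopping=True,        # force early stopping even for small datasets
    scoring='loss',             # monitor the loss on the held-out set
    validation_fraction=0.1,    # fraction of data held out for early stopping
    n_iter_no_change=10,
    tol=1e-7,
    max_iter=1000,
).fit(X, y)

clf.n_iter_  # number of boosting iterations actually run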

Hyper-parameters: Misc 🎁

  • verbose=0
  • random_state=None
  • export OMP_NUM_THREADS=8

Recently Added Features

  • Missing values (0.22)
  • Monotonic constraints (0.23)
  • Poisson loss (0.23)
  • Categorical features (0.24)

Missing Values (0.22)

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingClassifier
import numpy as np

X = np.array([0, 1, 2, np.nan]).reshape(-1, 1)
y = [0, 0, 1, 1]

gbdt = HistGradientBoostingClassifier(min_samples_leaf=1).fit(X, y)
gbdt.predict(X)
# [0 0 1 1]

Monotonic Constraints (0.23)

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor

X, y = ...

gbdt_no_cst = HistGradientBoostingRegressor().fit(X, y)
gbdt_cst = HistGradientBoostingRegressor(monotonic_cst=[1, 0]).fit(X, y)

Monotonic Constraints (0.23)

from sklearn.inspection import plot_partial_dependence

disp = plot_partial_dependence(
    gbdt_no_cst, X, features=[0], feature_names=['feature 0'], line_kw={...})
plot_partial_dependence(gbdt_cst, X, features=[0], line_kw={...}, ax=disp.axes_)

Poisson Loss (0.23)

hist_poisson = HistGradientBoostingRegressor(loss='poisson')
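
A small runnable sketch with count-valued targets (the data is synthetic and only illustrative; Poisson loss requires non-negative targets):

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(1_000, 3))
y = rng.poisson(lam=np.exp(X[:, 0]))  # non-negative counts

hist_poisson = HistGradientBoostingRegressor(loss='poisson').fit(X, y)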

Categorical Features (0.24)

From the categorical features example:

categorical_mask = ([True] * n_categorical_features +
                    [False] * n_numerical_features)
hist = HistGradientBoostingRegressor(categorical_features=categorical_mask)
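
A self-contained sketch of the same idea (the columns and values are made up; in 0.24 categorical columns are expected to be encoded as non-negative integers below max_bins):

from sklearn.experimental import enable_hist_gradient_boosting  # noqa
from sklearn.ensemble import HistGradientBoostingRegressor
import numpy as np

rng = np.random.default_rng(0)
n_samples = 500
X = np.column_stack([
    rng.integers(0, 4, size=n_samples),  # categorical feature, encoded as 0..3
    rng.normal(size=n_samples),          # numerical feature
])
y = rng.normal(size=n_samples)

categorical_mask = [True, False]         # True marks the categorical column
hist = HistGradientBoostingRegressor(
    categorical_features=categorical_mask).fit(X, y)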

Compared to Other Libraries

XGBoost

conda install -c conda-forge xgboost

from xgboost import XGBClassifier
xgb = XGBClassifier()

  • GPU training
  • Networked parallel training
  • Sparse data

LightGBM

conda install -c conda-forge lightgbm

from lightgbm.sklearn import LGBMClassifier
lgbm = LGBMClassifier()

  • GPU training
  • Networked parallel training
  • Sparse data

CatBoost

conda install -c conda-forge catboost

from catboost.sklearn import CatBoostClassifier
catb = CatBoostClassifier()

  • Focus on categorical features
  • Bagged and smoothed target encoding for categorical features
  • Symmetric trees
  • GPU training
  • Tooling

Benchmark 🚀

HIGGS Boson

  • 8,800,000 samples
  • 28 features
  • binary classification (1 for signal, 0 for background)

Current Benchmark Results

library     time   roc auc   accuracy
sklearn     66s    0.8126    0.7325
lightgbm    42s    0.8125    0.7323
xgboost     45s    0.8124    0.7325
catboost    90s    0.8008    0.7223

Versions

  • xgboost=1.3.0.post0
  • lightgbm=3.1.1
  • catboost=0.24.3

Thank You for Working on This 🎉

  • @hug_nicolas led the development of this algorithm in sklearn
  • All the core developers for reviewing!
  • All the core developers for reviewing!

Conclusion

Future Work

  • Sparse data support
  • Improve performance relative to other frameworks
  • A better way to pass feature-aligned metadata to estimators in a pipeline

Learn more about Histogram-Based Gradient Boosting

pip install scikit-learn==0.24.0rc1