title: Pushing Cython to its Limits in Scikit-learn use_katex: False class: title-slide # Pushing Cython to its Limits in Scikit-learn  .larger[Thomas J. Fan]
@thomasjpfan
github.com/thomasjpfan/pydata-nyc-2024-cython-in-scikit-learn
--- class: top
# Me - Senior Machine Learning Engineer @ Union.ai .g.g-middle[ .g-6.g-end[  ] .g-6.g-start[  ] ] -- - Maintainer for scikit-learn .center[  ] --- .g.g-middle[ .g-6[ # Agenda 📓 ## - Why Cython? 🚀 ## - Cython 101 🍀 ## - Scikit-learn Use Cases 🛠️ ] .g-6.g-center[  ] ] --- # Why Cython? 🚀 .g.g-middle[ .g-6.larger[ ## 1. Python-Like 🐍 ## 2. Performance ### Improve Runtime 🏎️ ### Reduce Memory Usage 🧠 ] .g-6.g-center[  ] ] --- class: top # Performance Uplift .g.g-center[ .g-1[] .g-5[  ] .g-5[  ] .g-1[] ] - `HistGradientBoosting*`: **LightGBM**-like performance - **2x improvement**: `LogisticRegression`, `linear_model` module, and `GradientBoosting*` --- class: top # Performance Uplift .g.g-center[ .g-1[] .g-5[  ] .g-5[  ] .g-1[] ] - `HistGradientBoosting*`: **LightGBM**-like performance - **2x improvement**: `LogisticRegression`, `linear_model` module, and `GradientBoosting*` - **20x improvement** in `cluster`, `manifold`, `neighbors`, `semi_supervised` modules - `TargetEncoder` - **4-5x faster runtime** and less memory usage - **Reduce memory usage** for validation checks --- .g.g-middle[ .g-6[ # Profiling 🔬 - `cProfile` + snakeviz - `viztracer` - `memray` - `Scalene` ] .g-6.g-center[  ] ] --- # Finding Hot-spots 🔎 ## `cProfile` + snakeviz ```bash python -m cProfile -o hist.prof hist.py snakeviz hist.prof ```  --- # Finding Hot-spots 🔎 ## `viztracer` ```bash viztracer hist.py vizviewer result.json ```  --- # Memory Profiling 🧠 ## `memray` ```bash memray run np-copy.py memray flamegraph memray-np-copy.py.88600.bin ```  --- # Memory Profiling 🧠 ## `memray`  --- # Memory Profiling 🧠 ## Scalene ```bash scalene np-copy.py ```  --- class: chapter-slide # Cython 101 🍀 --- .g.g-middle[ .g-6[ # Cython 101 🍀 ## - Compiling ## - Types ## - Developer Tips ] .g-6[  ] ] --- class: top
# Compiling ```python # simple.pyx def add(x, y): return x + y ``` -- ## `setup.py` ```python from setuptools import setup from Cython.Build import cythonize setup( ext_modules=cythonize("simple.pyx"), ) ``` -- ## Build Command ```bash python setup.py build_ext --inplace ``` --- class: top
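# Compiling - Quick iteration (sketch)

A sketch, not part of the talk's build setup: `pyximport` (shipped with Cython) compiles a `.pyx` file on import, which is handy while experimenting; real projects such as scikit-learn use a proper build backend instead.

```python
# Sketch: compile simple.pyx on first import, no setup.py step needed.
import pyximport

pyximport.install()

import simple  # triggers compilation of simple.pyx

print(simple.add(10, 12))  # 22
```

---
class: top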
# Importing from Python code ```python import simple result = simple.add(10, 12) print(result) ``` -- ## Benefits - Does not go through the Python bytecode interpreter (runs as compiled C) ```python # simple.pyx def add(x, y): return x + y ``` --- class: top
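# Importing from Python code - Measuring the benefit (sketch)

A rough sketch of how to check this yourself; `add_py` is a hypothetical pure-Python copy of `add`, and the numbers are machine-dependent.

```python
# Compare the compiled extension with a pure-Python equivalent.
import timeit

import simple  # compiled Cython module from the previous slides


def add_py(x, y):
    return x + y


print(timeit.timeit("add(10, 12)", globals={"add": simple.add}))
print(timeit.timeit("add(10, 12)", globals={"add": add_py}))
```

---
class: top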
# Adding Types ```python # simple.pyx def add(x: int, y: int): return x + y ``` -- ## Benefits - Bypasses the Python interpreter - The C compiler can optimize with the type information --- # Cython Overview
.center[  ] --- class: top
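# Adding Types - Typed locals (sketch)

A minimal sketch (not from the scikit-learn codebase): declaring C types for local variables keeps the hot loop free of Python objects.

```python
# arithmetic_sum.pyx -- typed locals keep the loop in pure C
def arithmetic_sum(int n):
    cdef:
        int i
        long long total = 0
    for i in range(n):
        total += i      # C arithmetic, no Python objects
    return total        # converted back to a Python int on return
```

---
class: top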
# Defining Functions - `def` : Call from Python - `cdef` : Call from Cython -- ```python *cdef float linear(slope: float, x: float, b: float): return slope * x + b ``` -- ```python def two_linear(slope: float, x: float, b: float): cdef: * float r1 = linear(slope, x, b) float r2 = linear(-slope, x, b) float result = r1 + 2 * r2 return result ``` -- - `cpdef` : Call from Python & Cython (Scikit-learn uses it like a `def`) --- # Developing ## Annotation ```bash cython --annotate simple.pyx ``` .center[  ] --- # Working in Jupyter .center[  ] --- # Working in Jupyter (Annotation)  --- class: chapter-slide # Scikit-learn Use Cases 🛠️ --- .g.g-middle[ .g-6[ # Scikit-learn Use Cases 🛠️ ## Python <-> Cython interface ⚙️ ## Performance - Improve Runtime 🏎️ - Reduce Memory Usage 🧠 ] .g-6.g-center[  ] ] --- class: top # Python <-> Cython interface - NumPy Arrays .center[  ] -- ## Memoryview ```python %%cython *def add_value(float[:, :] X, float value): ... ``` -- ## Call from Python ```python import numpy as np y = np.ones(shape=(3, 2), dtype=np.float32) result = add_value(y, 1.4) ``` ## Python [Buffer Protocol](https://docs.python.org/3/c-api/buffer.html) 🔌 --- class: top
# Python <-> Cython interface - NumPy Arrays ## Write loops! ```python %%cython def add_value(double[:, :] X, double value): cdef: size_t i, j size_t N = X.shape[0] size_t M = X.shape[1] for i in range(N): for j in range(M): X[i, j] += value ``` -- ## It's okay! 😆 --- # Scikit-learn Optimizations for memoryviews ## Directives! ```python scikit_learn_cython_args = [ '-X language_level=3', '-X boundscheck=' + boundscheck, '-X wraparound=False', ... ] ``` .footnote-back[ [meson.build](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/meson.build#L183-L190) ] --- # Memoryview directives (`boundscheck=True`)  --- # Memoryview directives (`boundscheck=False`)  --- # Memoryview directives (`wraparound=True`)  --- # Memoryview directives (`wraparound=False`)  --- class: top
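# Directives per function (sketch)

Directives can also be scoped to a single function with decorators; this is standard Cython sketched here, not a scikit-learn excerpt.

```python
%%cython
cimport cython


@cython.boundscheck(False)  # skip bounds checks inside this function
@cython.wraparound(False)   # skip negative-index handling
def add_value(double[:, :] X, double value):
    cdef:
        size_t i, j
        size_t N = X.shape[0]
        size_t M = X.shape[1]
    for i in range(N):
        for j in range(M):
            X[i, j] += value
```

---
class: top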
# Cython directives ## Define for file 🗃️ ```python # cython: language_level=3 # cython: boundscheck=False # cython: wraparound=False cimport cython ... ``` -- ## Globally in build backend 🌎 Scikit-learn's [sklearn/meson.build](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/meson.build#L183-L190) --- # scikit-learn Global configuration ## Dynamically configure `boundscheck` for testing ```python scikit_learn_cython_args = [ '-X language_level=3', * '-X boundscheck=' + boundscheck, '-X wraparound=False', ... ] ``` --- # Returning memoryviews ```python def _make_unique(...): cdef floating[::1] y_out = np.empty(unique_values, dtype=dtype) # Computation return( * np.asarray(x_out[:i+1]), ... ) ``` - `IsotonicRegression` .footnote-back[ [_isotonic.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/_isotonic.pyx#L111-L115) ] --- class: top # Strides 2D  -- .g[ .g-6[ #### `float[:, ::1]` - C contiguous  ] .g-6[ ] ] --- class: top # Strides 2D  .g[ .g-6[ #### `float[:, ::1]` - C contiguous  ] .g-6[ #### `float[::1, :]` - F contiguous  ] ] --- class: top
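# Strides 2D - Matching layouts (sketch)

Plain NumPy, not scikit-learn code: if an array's layout does not match a contiguous memoryview declaration (e.g. passing an F-ordered array to `float[:, ::1]` raises a `ValueError`), these helpers return a matching array, copying only when needed.

```python
import numpy as np

X = np.ones((3, 4), dtype=np.float32)   # C contiguous by default -> float[:, ::1]
X_f = np.asfortranarray(X)              # F-contiguous copy        -> float[::1, :]
X_c = np.ascontiguousarray(X_f)         # back to C contiguous     -> float[:, ::1]

print(X_f.flags["F_CONTIGUOUS"], X_c.flags["C_CONTIGUOUS"])  # True True
```

---
class: top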
# NumPy API ```python import numpy as np a = np.ones((2, 3)) print(a.flags) ``` -- ``` C_CONTIGUOUS : True F_CONTIGUOUS : False OWNDATA : True WRITEABLE : True ALIGNED : True WRITEBACKIFCOPY : False ``` --- # Const memoryviews ## Support readonly data ```python cpdef floating _inertia_dense( * const floating[:, ::1] X, # IN const floating[::1] sample_weight, # IN const floating[:, ::1] centers, # IN const int[::1] labels, # IN int n_threads, int single_label=-1, ): ``` - `KMeans`, `BisectingKMeans` .footnote-back[ [cluster/_k_means_common.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/cluster/_k_means_common.pyx#L94-L101) ] --- # Const memoryviews ## Support readonly data - Use case ```python from sklearn.experimental import enable_halving_search_cv from sklearn.model_selection import HalvingRandomSearchCV search_cv = HalvingRandomSearchCV(estimator, ..., n_jobs=8) search_cv.fit(X, y) ```  --- class: top
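# Const memoryviews - Why it matters (sketch)

A sketch with hypothetical functions (not scikit-learn code): joblib can hand workers read-only, memory-mapped arrays, and only the `const` signature accepts them.

```python
%%cython
# Only the const variant accepts read-only buffers.
def total(double[::1] x):
    cdef:
        Py_ssize_t i
        double s = 0.0
    for i in range(x.shape[0]):
        s += x[i]
    return s


def total_const(const double[::1] x):
    cdef:
        Py_ssize_t i
        double s = 0.0
    for i in range(x.shape[0]):
        s += x[i]
    return s
```

```python
import numpy as np

x = np.arange(4.0)
x.flags.writeable = False  # like a read-only memory-mapped input
total_const(x)             # works
total(x)                   # ValueError: buffer source array is read-only
```

---
class: top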
# Structs ```python cdef struct SplitRecord: intp_t feature intp_t pos ``` --
## Building trees ```python cdef SplitRecord current_split while ...: current_split.pos = ... current_split.feature = features[f_j] ``` - `tree` module, `RandomForest*` .footnote-back[ [sklearn/tree/_splitter.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_splitter.pxd#L16) ] --- class: top # Packed Structs for memoryviews ### NumPy [Structured Dtype](https://numpy.org/doc/stable/user/basics.rec.html) ```python HISTOGRAM_DTYPE = np.dtype([ ('sum_gradients', Y_DTYPE), # sum of sample gradients in bin ('sum_hessians', Y_DTYPE), # sum of sample hessians in bin ('count', np.uint32), # number of samples in bin ]) ``` -- ### Cython ```python cdef packed struct hist_struct: Y_DTYPE_C sum_gradients Y_DTYPE_C sum_hessians unsigned int count ``` -- ### Memoryview ```python hist_struct [:, ::1] histograms = np.empty( shape=(self.n_features, self.n_bins), dtype=HISTOGRAM_DTYPE ) ``` .footnote-back[ [ensemble/_hist_gradient_boosting/histogram.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/ensemble/_hist_gradient_boosting/histogram.pyx#L141-L144) ] --- class: top
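# Packed structs - The NumPy side (sketch)

A sketch in plain NumPy (assuming `Y_DTYPE` is `np.float64`): because the packed struct mirrors the structured dtype, histograms filled in Cython can be inspected field-by-field from Python.

```python
import numpy as np

HISTOGRAM_DTYPE = np.dtype([
    ("sum_gradients", np.float64),  # sum of sample gradients in bin
    ("sum_hessians", np.float64),   # sum of sample hessians in bin
    ("count", np.uint32),           # number of samples in bin
])

histograms = np.zeros(shape=(2, 4), dtype=HISTOGRAM_DTYPE)
histograms["count"][0, 1] = 5   # field access from Python
print(histograms[0, 1])         # roughly: (0., 0., 5)
```

---
class: top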
# "Cython classes": Extension Types ```python cdef class Tree: cdef public intp_t n_features cdef intp_t* n_classes cdef public intp_t n_outputs ... cdef Node* nodes cdef float64_t* value ``` -- ## Initialize from Python ```python self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_) ``` - `DecisionTree*`, `RandomForest*`, `GradientBoosting*` .footnote-back[ [tree/_tree.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_tree.pxd#L36-L54) ] --- class: top
# "Cython classes": Extension Types ## Constructor ```python cdef class Tree: def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs): safe_realloc(&self.n_classes, n_outputs) ... ``` -- ## Destructor ```python def __dealloc__(self): free(self.n_classes) free(self.nodes) ... ``` .footnote-back[ [tree/_tree.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_tree.pyx#L783) ] --- class: chapter-slide # Performance 🏎️ --- # Python's Global Interpreter Lock (GIL) 🔐 .g.g-middle[ .g-8[ ## Prevents Python objects from being accessed at the same time ] .g-4[  ] ] --- class: top
# GIL - Solution ## Release the GIL! ⛓️💥 -- ```python trees = Parallel( n_jobs=self.n_jobs, ... prefer="threads", )( * delayed(_parallel_build_trees)(...) for i, t in enumerate(trees) ) ``` `ensemble.RandomForest*` .footnote-back[ [ensemble/_forest.py](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/ensemble/_forest.py#L492) ] --- # Releasing the GIL in Cython ## Context manager! ```python *with nogil: builder_stack.push(...) ... node_id = tree._add_node(...) splitter.node_value(...) ``` Everything in the block must **not interact** with Python objects .footnote-back[ [tree/_tree.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_tree.pyx#L213) ] --- class: top
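# Releasing the GIL - Minimal example (sketch)

A minimal sketch (not scikit-learn code): a typed loop over a memoryview never touches Python objects, so it can run inside `nogil`.

```python
%%cython
# cython: boundscheck=False, wraparound=False
def row_sum(const double[::1] x):
    cdef:
        Py_ssize_t i
        double total = 0.0
    with nogil:
        for i in range(x.shape[0]):
            total += x[i]
    return total
```

Threads calling `row_sum`, e.g. through joblib's threading backend, can now run concurrently.

---
class: top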
# nogil in function definition ## Tree builder ```python with nogil: builder_stack.push(...) ... node_id = tree._add_node(...) * splitter.node_value(...) ``` -- ### `node_value` definition ```python cdef class Splitter: cdef void node_value(self, float64_t* dest) noexcept nogil ``` **Must** have `nogil` .footnote-back[ [tree/_splitter.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_splitter.pxd#L102) ] --- class: top # Checking for nans or infs ### NumPy ```python has_inf = np.any(np.isinf(X)) has_nan = np.any(np.isnan(X)) ``` -- ### Cython ```python cdef inline FiniteStatus _isfinite_disable_nan(floating* a_ptr, Py_ssize_t length) noexcept nogil: for i in range(length): v = a_ptr[i] if isnan(v): return FiniteStatus.has_nan elif isinf(v): return FiniteStatus.has_infinite return FiniteStatus.all_finite ``` Used **almost everywhere** with `check_array` .footnote-back[ [utils/_isfinite.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/utils/_isfinite.pyx#L40-L41) ] --- # OpenMP ## Native Parallelism ```python from cython.parallel cimport prange *for i in prange(data.shape[0], schedule='static', nogil=True, num_threads=n_threads): left, right = 0, binning_thresholds.shape[0] while left < right: middle = left + (right - left - 1) // 2 if data[i] <= binning_thresholds[middle]: right = middle else: left = middle + 1 binned[i] = left ``` - `HistGradientBoosting{Classifier, Regressor}` .footnote-back[ [_hist_gradient_boosting/_binning.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/ensemble/_hist_gradient_boosting/_binning.pyx#L49-L65) ] --- class: top
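# OpenMP - Minimal `prange` (sketch)

A self-contained sketch (not scikit-learn code); the OpenMP flags shown are for GCC and are platform-dependent.

```python
%%cython --compile-args=-fopenmp --link-args=-fopenmp
# cython: boundscheck=False, wraparound=False
from cython.parallel cimport prange


def square_all(const double[::1] x, double[::1] out):
    # each index i is handled by one of the OpenMP threads
    cdef Py_ssize_t i
    for i in prange(x.shape[0], schedule='static', nogil=True):
        out[i] = x[i] * x[i]
```

---
class: top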
# Calling SciPy BLAS with Cython ```python from scipy.linalg.cython_blas cimport sgemm, dgemm ``` ### `gemm`: General Matrix Multiply --
## OpenMP + Cython BLAS ```python with nogil, parallel(num_threads=n_threads): for chunk_idx in prange(n_chunks, schedule='static'): _update_chunk_dense(...) # function calls gemm ``` - `KMeans` .footnote-back[ [cluster/_k_means_lloyd.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/cluster/_k_means_lloyd.pyx#L118) ] --- class: top
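# Calling SciPy BLAS - A smaller example (sketch)

`gemm` needs leading-dimension bookkeeping, so here is a smaller sketch (not scikit-learn code) with `ddot` from the same `scipy.linalg.cython_blas` module: every argument is passed by pointer, following the Fortran convention.

```python
%%cython
from scipy.linalg.cython_blas cimport ddot


def blas_dot(double[::1] x, double[::1] y):
    cdef:
        int n = <int> x.shape[0]
        int inc = 1
    # BLAS ddot: dot product of two double vectors
    return ddot(&n, &x[0], &inc, &y[0], &inc)
```

---
class: top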
# C++ (Map) ### Header ```python cdef class IntFloatDict: cdef cpp_map[intp_t, float64_t] my_map ``` --
### Use in Python ```python d = IntFloatDict(keys, values) for key, value in zip(keys, values): assert d[key] == value ``` - `AgglomerativeClustering` .footnote-back[ [utils/_fast_dict.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/utils/_fast_dict.pxd#L17-L20), [utils/_fast_dict.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/utils/_fast_dict.pyx#L116) ] --- # C++ (Vector) ```python from libcpp.vector cimport vector def dbscan_inner(...): cdef vector[intp_t] stack while True: stack.push_back(v) if stack.size() == 0: break ``` - `DBSCAN` .footnote-back[ [cluster/_dbscan_inner.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/cluster/_dbscan_inner.pyx#L32) ] --- # C++ (Vector) ```python def _fit_encoding_fast(...): cdef: # Gives access to encodings without gil * vector[double*] encoding_vec encoding_vec.resize(n_features) for feat_idx in range(n_features): current_encoding = np.empty(shape=n_categories[feat_idx], dtype=np.float64) * encoding_vec[feat_idx] = & current_encoding[0] ``` - `TargetEncoder` .footnote-back[ [preprocessing/_target_encoder_fast.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/preprocessing/_target_encoder_fast.pyx#L20) ] --- # C++ Algorithm ```python from libcpp.algorithm cimport pop_heap from libcpp.algorithm cimport push_heap cpdef build(...): cdef vector[FrontierRecord] frontier while not frontier.empty(): * pop_heap(frontier.begin(), frontier.end(), &_compare_records) record = frontier.back() frontier.pop_back() ``` - `tree` module, `GradientBoosting*` & `RandomForest*` .footnote-back[ [tree/_tree.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_tree.pyx#L468-L469) ] --- class: top
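# C++ (Vector) - Round trip to Python (sketch)

A minimal sketch (not scikit-learn code), assuming the `-+` flag to compile the cell as C++: the vector holds plain C ints, and values are copied into a Python list only at the boundary.

```python
%%cython -+
from libcpp.vector cimport vector


def collect_evens(int n):
    cdef:
        vector[int] values
        int i
    for i in range(n):
        if i % 2 == 0:
            values.push_back(i)
    # copy back into a Python list at the boundary
    result = []
    for i in range(<int> values.size()):
        result.append(values[i])
    return result
```

---
class: top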
# Fused Types (Intro) ```python ctypedef fused floating: float double ``` -- ## Function definition ```python from cython cimport floating cdef floating abs_max(int n, const floating* a) noexcept nogil: """np.max(np.abs(a))""" cdef int i cdef floating m = fabs(a[0]) cdef floating d for i in range(1, n): d = fabs(a[i]) if d > m: m = d return m ``` .footnote-back[ [linear_model/_cd_fast.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/linear_model/_cd_fast.pyx#L50-L59) ] --- class: top # Fused Types (Memoryview) ```python ctypedef fused INT_DTYPE: int64_t int32_t ctypedef fused Y_DTYPE: int64_t int32_t float64_t float32_t ``` -- ## Function Definition ```python def _fit_encoding_fast( INT_DTYPE[:, ::1] X_int, const Y_DTYPE[:] y, int64_t[::1] n_categories, ... ) ``` - `TargetEncoder` .footnote-back[ [preprocessing/_target_encoder_fast.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/preprocessing/_target_encoder_fast.pyx#L17) ] --- # C++ Vector & Fused types into NumPy Array ## Vectors point to data on the heap ```python from libcpp.vector cimport vector vector[int64_t] vec ```  --- class: top # C++ Vector & Fused types into NumPy Array ```python ctypedef fused vector_typed: vector[float64_t] vector[intp_t] vector[int32_t] vector[int64_t] cdef class StdVectorSentinelInt64: * cdef vector[int64_t] vec ``` -- ## Conversion to NumPy Array ```python cdef cnp.ndarray vector_to_nd_array(vector_typed * vect_ptr): cdef: StdVectorSentinel sentinel = _create_sentinel(vect_ptr) * cnp.ndarray arr = cnp.PyArray_SimpleNewFromData(...) Py_INCREF(sentinel) * cnp.PyArray_SetBaseObject(arr, sentinel) return arr ``` .footnote-back[ [utils/_vector_sentinel.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/utils/_vector_sentinel.pxd#L6-L12) ] --- class: top
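# Fused Types - Dispatch from Python (sketch)

A sketch (not scikit-learn code): one fused-typed definition compiles to a `float` and a `double` specialization, and the caller's dtype picks the right one.

```python
%%cython
from cython cimport floating


def scale_inplace(floating[::1] x, double alpha):
    cdef Py_ssize_t i
    for i in range(x.shape[0]):
        x[i] = x[i] * <floating> alpha
```

```python
import numpy as np

scale_inplace(np.ones(3, dtype=np.float32), 2.0)  # float specialization
scale_inplace(np.ones(3, dtype=np.float64), 2.0)  # double specialization
```

---
class: top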
# Fused Types on classes ```python ctypedef fused Partitioner: DensePartitioner SparsePartitioner ``` -- ## Function definition ```python cdef inline int node_split_best( Splitter splitter, Partitioner partitioner, ... ): partitioner.init_node_split(...) while ...: partitioner.find_min_max(...) ``` - `tree` module, `RandomForest*`, `GradientBoosting*` .footnote-back[ [tree/_splitter.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/tree/_splitter.pyx#L40-L42) ] --- class: top # Tempita ## Code Generation! ```python cdef class WeightVector{{name_suffix}}(object): cdef readonly {{c_type}}[::1] w cdef readonly {{c_type}}[::1] aw ... ``` -- ## Generated Code ```python cdef class WeightVector64(object): cdef readonly double[::1] w cdef readonly double[::1] aw cdef class WeightVector32(object): cdef readonly float[::1] w cdef readonly float[::1] aw ``` - `Perceptron`, `SGDClassifier`, `SGDRegressor`, `PassiveAggressive*` .footnote-back[ [utils/_weight_vector.pxd.tp](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/utils/_weight_vector.pxd.tp#L21-L27) ] --- class: top # Optimizing Performance (Virtual Table) ### The Problem ```python cdef class CyLossFunction: def loss(self, ...) for i in prange( n_samples, schedule='static', nogil=True, num_threads=n_threads ): * loss_out[i] = self.point_loss(y_true[i], raw_prediction[i]) ``` -- ### Subclass ```python cdef class CyHalfSquaredError(CyLossFunction): cdef inline double point_loss( double y_true, double raw_prediction ) noexcept nogil: return 0.5 * (raw_prediction - y_true) * (raw_prediction - y_true) ``` -- ## Performance regression: Can not be dynamic! ❌ --- class: top # Optimizing Performance (Virtual Table) ### Tempita ```python cdef class {{name}}(CyLossFunction): def loss(...): for i in prange( n_samples, schedule='static', nogil=True, num_threads=n_threads ): * loss_out[i] = {{closs}}(y_true[i], raw_prediction[i]{{with_param}}) ``` -- ### Generated Code ```python cdef class CyHalfSquaredError(CyLossFunction): def loss(...): for i in prange( n_samples, schedule='static', nogil=True, num_threads=n_threads ): * loss_out[i] = closs_half_squared_error(y_true[i], raw_prediction[i]) ``` - `linear_model` module, `GradientBoosting*`, `HistGradientBoosting*` .footnote-back[ [_loss/_loss.pyx.tp](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/_loss/_loss.pyx.tp#L1025) ] --- class: top
.g.g-middle[ .g-6[ # Cython Features Covered ] .g-6[  ] ] .g[ .g-6[ ## Python <-> Cython interface ⚙️ - Compiling - Types - Memoryviews (NumPy interaction) - "Cython classes" ] .g-6[ ## Performance 🏎️ - Using SciPy BLAS - C++ Objects (Vector, Map, Algorithm) - Fused Types - Tempita ] ] **Learn more** @ https://cython.readthedocs.io/en/latest/ --- class: top # Performance Uplift .g.g-center[ .g-1[] .g-5[  ] .g-5[  ] .g-1[] ] - `HistGradientBoosting*`: **LightGBM**-like performance - **2x improvement**: `LogisticRegression`, `linear_model` module, and `GradientBoosting*` - **20x improvement** in `cluster`, `manifold`, `neighbors`, `semi_supervised` modules - `TargetEncoder` - **4-5x faster runtime** and less memory usage - **Reduce memory usage** for validation checks --- class: top
.center[ # Pushing Cython to its Limits in Scikit-learn ] .g.g-middle[ .g-6[ ## Why Cython? 🚀 ## Cython 101 🍀 ## Scikit-learn Use Cases 🛠️ ] .g-6.g-center[  ] ] -- - **Material**: [github.com/thomasjpfan/pydata-nyc-2024-cython-in-scikit-learn](https://github.com/thomasjpfan/pydata-nyc-2024-cython-in-scikit-learn) - **Linkedin**: [linkedin.com/in/thomasjpfan/](https://www.linkedin.com/in/thomasjpfan/) - **GitHub**: [github.com/thomasjpfan](https://www.github.com/thomasjpfan) --- class: chapter-slide # Appendix --- # Other Languages .g[ .g-6[ .center[ ## AOT  ] - **Ahead of time** compiled - Harder to build - Less requirements during runtime ] .g-6[ .center[ ## Numba  ] - **Just in time** compiled - Source code is Python - Requires compiler at runtime ] ] --- class: top
# Header files ```python cdef packed struct node_struct: Y_DTYPE_C value unsigned int count intp_t feature_idx X_DTYPE_C num_threshold ... ``` - **`HistGradientBoosting*`**: [ensemble/_hist_gradient_boosting/common.pxd](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/ensemble/_hist_gradient_boosting/common.pxd#L20-L27) -- ## Imported from another file ```python from .common cimport node_struct ``` - [ensemble/_hist_gradient_boosting/_predictor.pyx](https://github.com/scikit-learn/scikit-learn/blob/e9c394232e826e211d3c67a1f1677d47656114cc/sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx#L13) --- # Free-threading | Method | Data sharing | GIL | | ------ | ------------ | --- | | Multi-threading | Shared | Release GIL or native code | | Multi-processing | Pickle & memmap | No issue | | Sub-Interpreters | Pickle & memmap | No issue | | Free-threading | Shared | No Issue | - Requires native library support: https://py-free-threading.github.io/ - Utilities for data sharing: https://github.com/facebookincubator/ft_utils --- class: top # Strides 1D ```python import numpy as np X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) ``` -- ## `float[::1]` - Contiguous .g[ .g-6[  ] .g-6[ ] ] --- class: top # Strides 1D ```python import numpy as np X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) ``` ## `float[::1]` - Contiguous .g[ .g-6[  ] .g-6[ ```python X_row = X[0, :] print(X_row.flags) ``` ``` C_CONTIGUOUS : True F_CONTIGUOUS : True ... ``` ] ] --- class: top # Strides 1D ```python import numpy as np X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) ``` ## `float[:]` - Non-Contiguous .g[ .g-6[  ] .g-6[ ] ] --- class: top # Strides 1D ```python import numpy as np X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]) ``` ## `float[:]` - Non-Contiguous .g[ .g-6[  ] .g-6[ ```python X_col = X[:, 1] print(X_col.flags) ``` ``` C_CONTIGUOUS : False F_CONTIGUOUS : False ... ``` ] ]
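---
class: top

# Strides 1D - The numbers (sketch)

A sketch of the stride values behind the flags above, assuming the default 8-byte integer dtype.

```python
import numpy as np

X = np.array([[0, 1, 2, 3],
              [4, 5, 6, 7],
              [8, 9, 10, 11]])

print(X.strides)        # (32, 8): next row is 32 bytes away, next column 8
print(X[0, :].strides)  # (8,)  -> contiguous, can be typed `[::1]`
print(X[:, 1].strides)  # (32,) -> skips a whole row, needs `[:]`
```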