Thomas J. Fan
@thomasjpfan
github.com/thomasjpfan/pydata-nyc-2024-cython-in-scikit-learn
Cython-powered performance improvements in scikit-learn:

- HistGradientBoosting*: LightGBM-like performance
- 2x improvement: LogisticRegression, the linear_model module, and GradientBoosting*
- 20x improvement in the cluster, manifold, neighbors, and semi_supervised modules
- TargetEncoder: 4-5x runtime and less memory usage

Profiling tools: cProfile + snakeviz, viztracer, memray, Scalene.

cProfile + snakeviz:

    python -m cProfile -o hist.prof hist.py
    snakeviz hist.prof
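If snakeviz is unavailable, the same profile file can be inspected with the standard library's pstats module; a minimal sketch:

    import pstats

    # hist.prof was written by `python -m cProfile -o hist.prof hist.py`
    stats = pstats.Stats("hist.prof")
    stats.sort_stats("cumulative").print_stats(10)  # top 10 by cumulative time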
viztracer:

    viztracer hist.py
    vizviewer result.json

memray:

    memray run np-copy.py
    memray flamegraph memray-np-copy.py.88600.bin

Scalene:

    scalene np-copy.py
Compiling a first Cython file:

    # simple.pyx
    def add(x, y):
        return x + y

setup.py:

    from setuptools import setup
    from Cython.Build import cythonize

    setup(
        ext_modules=cythonize("simple.pyx"),
    )

Build:

    python setup.py build_ext --inplace

Use from Python:

    import simple

    result = simple.add(10, 12)
    print(result)

Adding types:

    # simple.pyx
    def add(x: int, y: int):
        return x + y
Defining functions:

- def: Call from Python
- cdef: Call from Cython
- cpdef: Call from Python & Cython (scikit-learn uses it like a def)

    cdef float linear(slope: float, x: float, b: float):
        return slope * x + b

    def two_linear(slope: float, x: float, b: float):
        cdef:
            float r1 = linear(slope, x, b)
            float r2 = linear(-slope, x, b)
            float result = r1 + 2 * r2
        return result
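A minimal cpdef sketch (scale and scale_twice are hypothetical examples, not scikit-learn code): Cython callers get a direct C call, while Python callers go through a generated wrapper:

    cpdef float scale(float x, float factor):
        return x * factor

    def scale_twice(float x, float factor):
        # Cython-to-Cython call: compiled as a plain C function call
        return scale(scale(x, factor), factor)

From Python, scale(2.0, 3.0) also works, through the wrapper.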
Generate an annotated HTML report showing where the code interacts with Python:

    cython --annotate simple.pyx
Typed memoryviews in a Jupyter cell:

    %%cython
    def add_value(float[:, :] X, float value):
        ...

    import numpy as np

    y = np.ones(shape=(3, 2), dtype=np.float32)
    result = add_value(y, 1.4)

The full implementation, with typed loops over a double memoryview:

    %%cython
    def add_value(double[:, :] X, double value):
        cdef:
            size_t i, j
            size_t N = X.shape[0]
            size_t M = X.shape[1]

        for i in range(N):
            for j in range(M):
                X[i, j] += value
Compiler directives control how Cython generates code. They can be set per file in a comment header:

    # cython: language_level=3
    # cython: boundscheck=False
    # cython: wraparound=False
    cimport cython
    ...

- boundscheck=True: check memoryview indices on every access (safer)
- boundscheck=False: skip bounds checking (faster)
- wraparound=True: support Python-style negative indexing (safer)
- wraparound=False: assume no negative indexing (faster)

Scikit-learn's sklearn/meson.build sets them globally, keeping boundscheck configurable for testing:

    scikit_learn_cython_args = [
        '-X language_level=3',
        '-X boundscheck=' + boundscheck,
        '-X wraparound=False',
        ...
    ]
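As an illustration of the trade-off (not scikit-learn code): with bounds checking and wraparound disabled, Cython emits no index checks, so an out-of-range index would read arbitrary memory instead of raising IndexError:

    # cython: boundscheck=False
    # cython: wraparound=False

    def sum_all(double[::1] x):
        cdef:
            Py_ssize_t i
            double total = 0.0

        for i in range(x.shape[0]):
            total += x[i]  # no bounds check is generated for this access
        return total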
Cython functions can build and return Python objects, as in IsotonicRegression's _make_unique:

    def _make_unique(...):
        cdef floating[::1] y_out = np.empty(unique_values, dtype=dtype)

        # Computation
        return (
            np.asarray(x_out[:i+1]),
            ...
        )
Memoryview contiguity:

- float[:, ::1]: C contiguous
- float[::1, :]: F contiguous

NumPy reports contiguity in the array flags:

    import numpy as np

    a = np.ones((2, 3))
    print(a.flags)

      C_CONTIGUOUS : True
      F_CONTIGUOUS : False
      OWNDATA : True
      WRITEABLE : True
      ALIGNED : True
      WRITEBACKIFCOPY : False
const memoryviews declare read-only inputs, as in KMeans and BisectingKMeans:

    cpdef floating _inertia_dense(
        const floating[:, ::1] X,           # IN
        const floating[::1] sample_weight,  # IN
        const floating[:, ::1] centers,     # IN
        const int[::1] labels,              # IN
        int n_threads,
        int single_label=-1,
    ):

This matters for parallel code: joblib memory-maps large inputs for its workers, and the resulting arrays are read-only:

    from sklearn.experimental import enable_halving_search_cv
    from sklearn.model_selection import HalvingRandomSearchCV

    search_cv = HalvingRandomSearchCV(estimator, ..., n_jobs=8)
    search_cv.fit(X, y)
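A minimal sketch of why const matters (total is a hypothetical function, not scikit-learn code): a read-only array is rejected by a plain memoryview but accepted by a const one:

    def total(const double[:, ::1] X):
        cdef:
            Py_ssize_t i, j
            double s = 0.0

        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                s += X[i, j]
        return s

Called with a read-only array:

    import numpy as np

    X = np.ones((100, 4))
    X.flags.writeable = False  # simulate joblib's read-only memmap
    total(X)  # works; with `double[:, ::1] X` this raises ValueError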
C structs, used in the tree module and RandomForest*:

    cdef struct SplitRecord:
        intp_t feature
        intp_t pos

    cdef SplitRecord current_split

    while ...:
        current_split.pos = ...
        current_split.feature = features[f_j]
A NumPy structured dtype paired with a packed C struct:

    HISTOGRAM_DTYPE = np.dtype([
        ('sum_gradients', Y_DTYPE),  # sum of sample gradients in bin
        ('sum_hessians', Y_DTYPE),   # sum of sample hessians in bin
        ('count', np.uint32),        # number of samples in bin
    ])

    cdef packed struct hist_struct:
        Y_DTYPE_C sum_gradients
        Y_DTYPE_C sum_hessians
        unsigned int count

The structured array can then back a memoryview of structs:

    hist_struct [:, ::1] histograms = np.empty(
        shape=(self.n_features, self.n_bins),
        dtype=HISTOGRAM_DTYPE
    )
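Struct fields are then addressed by name through the memoryview; a hedged sketch (accumulate is hypothetical, not scikit-learn's histogram code):

    cdef void accumulate(hist_struct [:, ::1] histograms,
                         int feature_idx, int bin_idx,
                         Y_DTYPE_C gradient, Y_DTYPE_C hessian) noexcept nogil:
        # Each element of the memoryview is a hist_struct; its fields
        # behave like attributes.
        histograms[feature_idx, bin_idx].sum_gradients += gradient
        histograms[feature_idx, bin_idx].sum_hessians += hessian
        histograms[feature_idx, bin_idx].count += 1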
Extension types hold C-level attributes, including raw pointers, as in the Tree behind DecisionTree*, RandomForest*, and GradientBoosting*:

    cdef class Tree:
        cdef public intp_t n_features
        cdef intp_t* n_classes
        cdef public intp_t n_outputs
        ...
        cdef Node* nodes
        cdef float64_t* value

Instantiated from Python:

    self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)

__cinit__ allocates the C-level memory and __dealloc__ releases it:

    cdef class Tree:
        def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs):
            safe_realloc(&self.n_classes, n_outputs)
            ...

        def __dealloc__(self):
            free(self.n_classes)
            free(self.nodes)
            ...
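The same pairing in a minimal, self-contained sketch (DoubleBuffer is hypothetical):

    from libc.stdlib cimport malloc, free

    cdef class DoubleBuffer:
        cdef double* data
        cdef Py_ssize_t size

        def __cinit__(self, Py_ssize_t size):
            self.data = <double*>malloc(size * sizeof(double))
            if self.data == NULL:
                raise MemoryError()
            self.size = size

        def __dealloc__(self):
            # C attributes are zero-initialized before __cinit__ runs,
            # so free(NULL) here is safe even if __cinit__ failed.
            free(self.data)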
Python-level threading with joblib, as in ensemble.RandomForest*:

    trees = Parallel(
        n_jobs=self.n_jobs, ...
        prefer="threads",
    )(
        delayed(_parallel_build_trees)(...)
        for i, t in enumerate(trees)
    )
Releasing the GIL: everything in the block must not interact with Python objects:

    with nogil:
        builder_stack.push(...)
        ...
        node_id = tree._add_node(...)
        splitter.node_value(...)

The node_value definition must itself have nogil:

    cdef class Splitter:
        cdef void node_value(self, float64_t* dest) noexcept nogil
Checking for non-finite values with NumPy takes multiple passes and temporary arrays:

    has_inf = np.any(np.isinf(X))
    has_nan = np.any(np.isnan(X))

The Cython version makes a single pass, stops early, and is used almost everywhere via check_array:

    cdef inline FiniteStatus _isfinite_disable_nan(floating* a_ptr, Py_ssize_t length) noexcept nogil:
        for i in range(length):
            v = a_ptr[i]
            if isnan(v):
                return FiniteStatus.has_nan
            elif isinf(v):
                return FiniteStatus.has_infinite
        return FiniteStatus.all_finite
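At the Python level, the check is reachable through sklearn.utils.assert_all_finite:

    import numpy as np
    from sklearn.utils import assert_all_finite

    X = np.array([[1.0, 2.0], [np.nan, 4.0]])
    assert_all_finite(X)  # raises ValueError because of the NaN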
OpenMP with prange, used for binning in HistGradientBoosting{Classifier, Regressor}:

    from cython.parallel cimport prange

    for i in prange(data.shape[0], schedule='static', nogil=True, num_threads=n_threads):
        left, right = 0, binning_thresholds.shape[0]

        while left < right:
            middle = left + (right - left - 1) // 2
            if data[i] <= binning_thresholds[middle]:
                right = middle
            else:
                left = middle + 1

        binned[i] = left
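prange needs OpenMP enabled at build time; a minimal setup.py sketch under that assumption (the parsum module name is hypothetical; -fopenmp is for GCC/Clang, MSVC uses /openmp):

    from setuptools import Extension, setup
    from Cython.Build import cythonize

    ext = Extension(
        "parsum",
        sources=["parsum.pyx"],
        extra_compile_args=["-fopenmp"],
        extra_link_args=["-fopenmp"],
    )
    setup(ext_modules=cythonize([ext]))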
SciPy exposes BLAS routines to Cython; gemm is General Matrix Multiply:

    from scipy.linalg.cython_blas cimport sgemm, dgemm

KMeans combines gemm with OpenMP:

    with nogil, parallel(num_threads=n_threads):
        for chunk_idx in prange(n_chunks, schedule='static'):
            _update_chunk_dense(...)  # function calls gemm
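A hedged sketch of calling dgemm directly (matmul is illustrative, not scikit-learn code). BLAS is column-major, so the memoryviews are Fortran-ordered and the leading dimensions follow that layout:

    from scipy.linalg.cython_blas cimport dgemm

    def matmul(double[::1, :] A, double[::1, :] B, double[::1, :] C):
        # Computes C = A @ B for Fortran-ordered float64 arrays.
        cdef:
            char trans = b'N'
            int m = A.shape[0]
            int k = A.shape[1]
            int n = B.shape[1]
            double alpha = 1.0
            double beta = 0.0

        dgemm(&trans, &trans, &m, &n, &k, &alpha,
              &A[0, 0], &m, &B[0, 0], &k, &beta, &C[0, 0], &m)

The arrays must be created with order='F' (e.g. np.zeros((m, n), order='F')) to satisfy the [::1, :] memoryviews.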
C++ containers: a std::map as a dict, used in AgglomerativeClustering:

    cdef class IntFloatDict:
        cdef cpp_map[intp_t, float64_t] my_map

    d = IntFloatDict(keys, values)

    for key, value in zip(keys, values):
        assert d[key] == value
A C++ vector as a stack in DBSCAN:

    from libcpp.vector cimport vector

    def dbscan_inner(...):
        cdef vector[intp_t] stack

        while True:
            stack.push_back(v)
            if stack.size() == 0:
                break
A vector of pointers gives nogil access to NumPy-backed buffers in TargetEncoder:

    def _fit_encoding_fast(...):
        cdef:
            # Gives access to encodings without gil
            vector[double*] encoding_vec

        encoding_vec.resize(n_features)
        for feat_idx in range(n_features):
            current_encoding = np.empty(shape=n_categories[feat_idx], dtype=np.float64)
            encoding_vec[feat_idx] = &current_encoding[0]
C++ heap algorithms in the tree module, used by GradientBoosting* & RandomForest*:

    from libcpp.algorithm cimport pop_heap
    from libcpp.algorithm cimport push_heap

    cpdef build(...):
        cdef vector[FrontierRecord] frontier

        while not frontier.empty():
            pop_heap(frontier.begin(), frontier.end(), &_compare_records)
            record = frontier.back()
            frontier.pop_back()
Fused types are Cython's generics. The built-in floating:

    ctypedef fused floating:
        float
        double

    from cython cimport floating

    cdef floating abs_max(int n, const floating* a) noexcept nogil:
        """np.max(np.abs(a))"""
        cdef int i
        cdef floating m = fabs(a[0])
        cdef floating d
        for i in range(1, n):
            d = fabs(a[i])
            if d > m:
                m = d
        return m

Custom fused types in TargetEncoder:

    ctypedef fused INT_DTYPE:
        int64_t
        int32_t

    ctypedef fused Y_DTYPE:
        int64_t
        int32_t
        float64_t
        float32_t

    def _fit_encoding_fast(
        INT_DTYPE[:, ::1] X_int,
        const Y_DTYPE[:] y,
        int64_t[::1] n_categories,
        ...
    )
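Cython compiles one specialization per fused dtype and, for def functions, selects it from the buffer passed at runtime; a minimal sketch (scale_inplace is hypothetical):

    from cython cimport floating

    def scale_inplace(floating[::1] x, floating factor):
        # Compiled once for float and once for double; the specialization
        # is chosen from x's dtype at call time.
        cdef Py_ssize_t i

        for i in range(x.shape[0]):
            x[i] *= factor

Both np.float32 and np.float64 arrays are accepted without a copy or cast.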
C++ vectors can be returned to Python as NumPy arrays:

    from libcpp.vector cimport vector

    vector[int64_t] vec

    ctypedef fused vector_typed:
        vector[float64_t]
        vector[intp_t]
        vector[int32_t]
        vector[int64_t]

    cdef class StdVectorSentinelInt64:
        cdef vector[int64_t] vec

A sentinel object owns the vector and is installed as the array's base, so the data never needs to be copied:

    cdef cnp.ndarray vector_to_nd_array(vector_typed * vect_ptr):
        cdef:
            StdVectorSentinel sentinel = _create_sentinel(vect_ptr)
            cnp.ndarray arr = cnp.PyArray_SimpleNewFromData(...)

        Py_INCREF(sentinel)
        cnp.PyArray_SetBaseObject(arr, sentinel)
        return arr
Fused types can also range over extension types, as in the tree module (RandomForest*, GradientBoosting*):

    ctypedef fused Partitioner:
        DensePartitioner
        SparsePartitioner

    cdef inline int node_split_best(
        Splitter splitter,
        Partitioner partitioner,
        ...
    ):
        partitioner.init_node_split(...)

        while ...:
            partitioner.find_min_max(...)
Tempita templates generate dtype-specialized classes. The template:

    cdef class WeightVector{{name_suffix}}(object):
        cdef readonly {{c_type}}[::1] w
        cdef readonly {{c_type}}[::1] aw
        ...

Generated code:

    cdef class WeightVector64(object):
        cdef readonly double[::1] w
        cdef readonly double[::1] aw

    cdef class WeightVector32(object):
        cdef readonly float[::1] w
        cdef readonly float[::1] aw

Used in Perceptron, SGDClassifier, SGDRegressor, and PassiveAggressive*.
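A sketch of how such a template can be expanded with Cython's bundled Tempita (the standalone script is illustrative; scikit-learn drives this from its build system):

    from Cython import Tempita

    template = Tempita.Template("""
    {{for name_suffix, c_type in [('64', 'double'), ('32', 'float')]}}
    cdef class WeightVector{{name_suffix}}(object):
        cdef readonly {{c_type}}[::1] w
        cdef readonly {{c_type}}[::1] aw
    {{endfor}}
    """)

    print(template.substitute())  # write the output to a .pyx file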
Tempita also removes virtual dispatch from hot loops. The class-based version calls a method per element:

    cdef class CyLossFunction:
        def loss(self, ...):
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[i] = self.point_loss(y_true[i], raw_prediction[i])

    cdef class CyHalfSquaredError(CyLossFunction):
        cdef inline double point_loss(
            double y_true,
            double raw_prediction
        ) noexcept nogil:
            return 0.5 * (raw_prediction - y_true) * (raw_prediction - y_true)

The Tempita version substitutes the concrete loss function directly into the loop:

    cdef class {{name}}(CyLossFunction):
        def loss(...):
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[i] = {{closs}}(y_true[i], raw_prediction[i]{{with_param}})

Generated code:

    cdef class CyHalfSquaredError(CyLossFunction):
        def loss(...):
            for i in prange(
                n_samples, schedule='static', nogil=True, num_threads=n_threads
            ):
                loss_out[i] = closs_half_squared_error(y_true[i], raw_prediction[i])

Used in the linear_model module, GradientBoosting*, and HistGradientBoosting*.
Learn more @ https://cython.readthedocs.io/en/latest/
In summary, Cython in scikit-learn delivers:

- HistGradientBoosting*: LightGBM-like performance
- 2x improvement: LogisticRegression, the linear_model module, and GradientBoosting*
- 20x improvement in the cluster, manifold, neighbors, and semi_supervised modules
- TargetEncoder: 4-5x runtime and less memory usage

Declarations are shared between modules through .pxd files. HistGradientBoosting* defines its structs in ensemble/_hist_gradient_boosting/common.pxd:

    cdef packed struct node_struct:
        Y_DTYPE_C value
        unsigned int count
        intp_t feature_idx
        X_DTYPE_C num_threshold
        ...

Other files cimport from it:

    from .common cimport node_struct
Options for running code in parallel:

| Method           | Data sharing    | GIL                        |
|------------------|-----------------|----------------------------|
| Multi-threading  | Shared          | Release GIL or native code |
| Multi-processing | Pickle & memmap | No issue                   |
| Sub-interpreters | Pickle & memmap | No issue                   |
| Free-threading   | Shared          | No issue                   |
float[::1] - Contiguous. A row slice of a C-contiguous array is still contiguous:

    import numpy as np

    X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

    X_row = X[0, :]
    print(X_row.flags)

      C_CONTIGUOUS : True
      F_CONTIGUOUS : True
      ...

float[:] - Non-Contiguous. A column slice is strided, so it only matches a non-contiguous memoryview:

    X_col = X[:, 1]
    print(X_col.flags)

      C_CONTIGUOUS : False
      F_CONTIGUOUS : False
      ...