
Pushing Cython to its Limits in Scikit-learn

Thomas J. Fan
@thomasjpfan github.com/thomasjpfan/pydata-nyc-2024-cython-in-scikit-learn



Me

  • Senior Machine Learning Engineer @ Union.ai

  • Maintainer for scikit-learn

Agenda 📓

- Why Cython? 🚀

- Cython 101 🍀

- Scikit-learn Use Cases 🛠️

Why Cython? 🚀

1. Python-Like 🐍

2. Performance

Improve Runtime 🏎️

Reduce Memory Usage 🧠

Performance Uplift

  • HistGradientBoosting*: LightGBM-like performance
  • 2x improvement: LogisticRegression, linear_model module, and GradientBoosting*
  • 20x improvement in cluster, manifold, neighbors, semi_supervised modules
  • TargetEncoder: 4-5x faster runtime and less memory usage
  • Reduce memory usage for validation checks

Profiling 🔬

  • cProfile + snakeviz
  • viztracer
  • memray
  • Scalene

Finding Hot-spots 🔎

cProfile + snakeviz

python -m cProfile -o hist.prof hist.py
snakeviz hist.prof
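
The script being profiled is not shown in the deck; a hypothetical stand-in hist.py, just to give the commands above a target, could look like this:

# hist.py -- hypothetical stand-in for the profiled script
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifier

X, y = make_classification(n_samples=100_000, n_features=20, random_state=0)
HistGradientBoostingClassifier(max_iter=100).fit(X, y)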

Finding Hot-spots 🔎

viztracer

viztracer hist.py
vizviewer result.json

Memory Profiling 🧠

memray

memray run np-copy.py
memray flamegraph memray-np-copy.py.88600.bin

Memory Profiling 🧠

Scalene

scalene np-copy.py

Cython 101 🍀

- Compiling

- Types

- Developer Tips


Compiling

# simple.pyx
def add(x, y):
    return x + y

setup.py

from setuptools import setup
from Cython.Build import cythonize
setup(
    ext_modules=cythonize("simple.pyx"),
)

Build Command

python setup.py build_ext --inplace



Importing from Python code

import simple
result = simple.add(10, 12)
print(result)

Benefits

  • Does not go through the Python Interpreter
# simple.pyx
def add(x, y):
    return x + y




Adding Types

# simple.pyx
def add(x: int, y: int):
    return x + y

Benefits

  • Removes the Python interpreter
  • Compiler can optimize with types
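
For example, a fully typed loop (a small sketch, not from the talk) compiles down to plain C arithmetic with no PyObject boxing:

# typed.pyx -- hypothetical example: typed locals let the C compiler optimize
def sum_squares(int n):
    cdef long total = 0
    cdef int i
    for i in range(n):
        total += i * i  # stays in machine integers for the whole loop
    return total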

Cython Overview




Defining Functions

  • def : Call from Python
  • cdef : Call from Cython
cdef float linear(slope: float, x: float, b: float):
    return slope * x + b

def two_linear(slope: float, x: float, b: float):
    cdef:
        float r1 = linear(slope, x, b)
        float r2 = linear(-slope, x, b)
        float result = r1 + 2 * r2

    return result
  • cpdef : Call from Python & Cython (Scikit-learn uses it like a def)
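
A short sketch of cpdef, mirroring the linear example above (hypothetical name, not from the slides): Cython callers get the fast C entry point, Python callers get a regular function:

cpdef float linear_cp(float slope, float x, float b):
    return slope * x + b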

Developing

Annotation

cython --annotate simple.pyx

Working in Jupyter

Working in Jupyter (Annotation)
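
The Jupyter workflow shown on these slides boils down to the Cython cell magic (a minimal sketch; run each magic in its own cell):

# Cell 1: load the Cython extension
%load_ext Cython

# Cell 2: compile the cell in place; --annotate renders the
# yellow Python-interaction report inline
%%cython --annotate
def add(x: int, y: int):
    return x + y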

Scikit-learn Use Cases 🛠️

Python <-> Cython interface ⚙️

Performance

  • Improve Runtime 🏎️
  • Reduce Memory Usage 🧠

Python <-> Cython interface - NumPy Arrays

Memoryview

%%cython
def add_value(float[:, :] X, float value):
    ...

Call from Python

import numpy as np
y = np.ones(shape=(3, 2), dtype=np.float32)
result = add_value(y, 1.4)

Python Buffer Protocol 🔌



Python <-> Cython interface - NumPy Arrays

Write loops!

%%cython
def add_value(double[:, :] X, double value):
    cdef:
        size_t i, j
        size_t N = X.shape[0]
        size_t M = X.shape[1]

    for i in range(N):
        for j in range(M):
            X[i, j] += value

It's okay! 😆

Scikit-learn Optimizations for memoryviews

Directives!

scikit_learn_cython_args = [
    '-X language_level=3',
    '-X boundscheck=' + boundscheck,
    '-X wraparound=False',
    ...
]

Memoryview directives (boundscheck=True)

Memoryview directives (boundscheck=False)

Memoryview directives (wraparound=True)

Memoryview directives (wraparound=False)
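
What those four slides illustrate, as code (a hedged sketch, not from the deck): boundscheck removes the per-index range check, so out-of-range access becomes undefined behavior, and wraparound removes support for Python-style negative indices:

cimport cython

@cython.boundscheck(False)  # X[i, j] is no longer range-checked
@cython.wraparound(False)   # X[-1] is no longer interpreted as "last element"
def row_sum(double[:, ::1] X, Py_ssize_t i):
    cdef Py_ssize_t j
    cdef double total = 0.0
    for j in range(X.shape[1]):
        total += X[i, j]  # compiles to a bare C array access
    return total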



Cython directives

Define for file 🗃️

# cython: language_level=3
# cython: boundscheck=False
# cython: wraparound=False
cimport cython
...

Globally in build backend 🌎

Scikit-learn's sklearn/meson.build

scikit-learn Global configuration

Dynamically configure boundscheck for testing

scikit_learn_cython_args = [
    '-X language_level=3',
    '-X boundscheck=' + boundscheck,
    '-X wraparound=False',
    ...
]

Returning memoryviews

def _make_unique(...):
    cdef floating[::1] y_out = np.empty(unique_values, dtype=dtype)

    # Computation

    return (
        np.asarray(x_out[:i+1]),
        ...
    )
  • IsotonicRegression

Strides 2D

float[:, ::1] - C contiguous

float[::1, :] - F contiguous



NumPy API

import numpy as np
a = np.ones((2, 3))
print(a.flags)
  C_CONTIGUOUS : True
  F_CONTIGUOUS : False
  OWNDATA : True
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False

Const memoryviews

Support readonly data

cpdef floating _inertia_dense(
    const floating[:, ::1] X,           # IN
    const floating[::1] sample_weight,  # IN
    const floating[:, ::1] centers,     # IN
    const int[::1] labels,              # IN
    int n_threads,
    int single_label=-1,
):
  • KMeans, BisectingKMeans

Const memoryviews

Support readonly data - Use case

from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import HalvingRandomSearchCV
search_cv = HalvingRandomSearchCV(estimator, ..., n_jobs=8)
search_cv.fit(X, y)
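
The connection, as an illustrative sketch: with n_jobs=8, joblib memory-maps large inputs as read-only, and only const memoryviews accept read-only buffers:

import numpy as np

X = np.ones((4, 3), dtype=np.float64)
X.setflags(write=False)  # simulate joblib's read-only memmap
# double[:, ::1] argument       -> ValueError: buffer source array is read-only
# const double[:, ::1] argument -> accepted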


Structs

cdef struct SplitRecord:
    intp_t feature
    intp_t pos


Building trees

cdef SplitRecord current_split

while ...:
    current_split.pos = ...
    current_split.feature = features[f_j]
  • tree module, RandomForest*

Packed Structs for memoryviews

NumPy Structured Dtype

HISTOGRAM_DTYPE = np.dtype([
    ('sum_gradients', Y_DTYPE),  # sum of sample gradients in bin
    ('sum_hessians', Y_DTYPE),   # sum of sample hessians in bin
    ('count', np.uint32),        # number of samples in bin
])

Cython

cdef packed struct hist_struct:
    Y_DTYPE_C sum_gradients
    Y_DTYPE_C sum_hessians
    unsigned int count

Memoryview

hist_struct[:, ::1] histograms = np.empty(
    shape=(self.n_features, self.n_bins),
    dtype=HISTOGRAM_DTYPE,
)




"Cython classes": Extension Types

cdef class Tree:
cdef public intp_t n_features
cdef intp_t* n_classes
cdef public intp_t n_outputs
...
cdef Node* nodes
cdef float64_t* value




"Cython classes": Extension Types

cdef class Tree:
    cdef public intp_t n_features
    cdef intp_t* n_classes
    cdef public intp_t n_outputs
    ...
    cdef Node* nodes
    cdef float64_t* value

Initialize from Python

self.tree_ = Tree(self.n_features_in_, self.n_classes_, self.n_outputs_)
  • DecisionTree*, RandomForest*, GradientBoosting*



"Cython classes": Extension Types

Constructor

cdef class Tree:
def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs):
safe_realloc(&self.n_classes, n_outputs)
...



"Cython classes": Extension Types

Constructor

cdef class Tree:
    def __cinit__(self, intp_t n_features, cnp.ndarray n_classes, intp_t n_outputs):
        safe_realloc(&self.n_classes, n_outputs)
        ...

Destructor

def __dealloc__(self):
    free(self.n_classes)
    free(self.nodes)
    ...

Performance 🏎️

Python's Global Interpreter Lock (GIL) 🔐

Prevents multiple threads from accessing Python objects at the same time
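
A quick way to see the GIL in action (a sketch, not from the talk): CPU-bound pure-Python threads take about as long as running serially, because only one thread executes bytecode at a time:

import threading

def count(n):
    total = 0
    for i in range(n):
        total += i  # executing bytecode holds the GIL

# ~4x the work, but no parallel speedup for pure-Python loops
threads = [threading.Thread(target=count, args=(10_000_000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()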




GIL - Solution

Release the GIL! ⛓️‍💥

trees = Parallel(
    n_jobs=self.n_jobs, ..., prefer="threads",
)(
    delayed(_parallel_build_trees)(...)
    for i, t in enumerate(trees)
)

  • ensemble.RandomForest*

Releasing the GIL in Cython

Context manager!

with nogil:
    builder_stack.push(...)
    ...
    node_id = tree._add_node(...)
    splitter.node_value(...)

Everything in the block must not interact with Python objects



nogil in function definition

Tree builder

with nogil:
    builder_stack.push(...)
    ...
    node_id = tree._add_node(...)
    splitter.node_value(...)

node_value definition

cdef class Splitter:
    cdef void node_value(self, float64_t* dest) noexcept nogil

Must have nogil

Checking for nans or infs

NumPy

has_inf = np.any(np.isinf(X))
has_nan = np.any(np.isnan(X))

Cython

cdef inline FiniteStatus _isfinite_disable_nan(floating* a_ptr,
                                               Py_ssize_t length) noexcept nogil:
    for i in range(length):
        v = a_ptr[i]
        if isnan(v):
            return FiniteStatus.has_nan
        elif isinf(v):
            return FiniteStatus.has_infinite
    return FiniteStatus.all_finite

Used almost everywhere with check_array
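
From the user's side it is just the validation step (a minimal sketch):

import numpy as np
from sklearn.utils import check_array

X = np.array([[1.0, 2.0], [np.nan, 4.0]])
check_array(X)  # raises ValueError because X contains NaN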

OpenMP

Native Parallelism

from cython.parallel cimport prange

for i in prange(data.shape[0], schedule='static', nogil=True, num_threads=n_threads):
    left, right = 0, binning_thresholds.shape[0]
    while left < right:
        middle = left + (right - left - 1) // 2
        if data[i] <= binning_thresholds[middle]:
            right = middle
        else:
            left = middle + 1
    binned[i] = left
  • HistGradientBoosting{Classifier, Regressor}
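
prange only runs in parallel if OpenMP is enabled at build time; a minimal setup.py sketch (hypothetical module name; flags shown for GCC/Clang, MSVC uses /openmp):

from setuptools import Extension, setup
from Cython.Build import cythonize

ext = Extension(
    "binning",                        # hypothetical module name
    sources=["binning.pyx"],
    extra_compile_args=["-fopenmp"],  # enable OpenMP for prange
    extra_link_args=["-fopenmp"],
)
setup(ext_modules=cythonize([ext]))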



Calling SciPy BLAS with Cython

from scipy.linalg.cython_blas cimport sgemm, dgemm

gemm: General Matrix Multiply
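
A minimal sketch of calling dgemm on C-contiguous (row-major) arrays; this is not scikit-learn's actual wrapper. BLAS is column-major, so we swap the operand order to compute B^T A^T = (AB)^T, which stored column-major is exactly AB row-major:

from scipy.linalg.cython_blas cimport dgemm

cdef void matmul(double[:, ::1] A, double[:, ::1] B,
                 double[:, ::1] out) noexcept nogil:
    cdef:
        char no_trans = b'N'
        int m = A.shape[0]
        int k = A.shape[1]
        int n = B.shape[1]
        double alpha = 1.0
        double beta = 0.0
    # Column-major BLAS sees our row-major buffers as transposed matrices
    dgemm(&no_trans, &no_trans, &n, &m, &k, &alpha,
          &B[0, 0], &n, &A[0, 0], &k, &beta, &out[0, 0], &n)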


OpenMP + Cython Blas

with nogil, parallel(num_threads=n_threads):
    for chunk_idx in prange(n_chunks, schedule='static'):
        _update_chunk_dense(...)  # function calls gemm
  • KMeans


C++ (Map)

cdef class IntFloatDict:
    cdef cpp_map[intp_t, float64_t] my_map


Use in Python

d = IntFloatDict(keys, values)
for key, value in zip(keys, values):
    assert d[key] == value
  • AgglomerativeClustering

C++ (Vector)

from libcpp.vector cimport vector

def dbscan_inner(...):
    cdef vector[intp_t] stack

    while True:
        ...
        stack.push_back(v)
        if stack.size() == 0:
            break
  • DBSCAN

C++ (Vector)

def _fit_encoding_fast(...):
    cdef:
        # Gives access to encodings without the GIL
        vector[double*] encoding_vec

    encoding_vec.resize(n_features)

    for feat_idx in range(n_features):
        current_encoding = np.empty(shape=n_categories[feat_idx], dtype=np.float64)
        encoding_vec[feat_idx] = &current_encoding[0]
  • TargetEncoder

C++ Algorithm

from libcpp.algorithm cimport pop_heap
from libcpp.algorithm cimport push_heap

cpdef build(...):
    cdef vector[FrontierRecord] frontier
    while not frontier.empty():
        pop_heap(frontier.begin(), frontier.end(), &_compare_records)
        record = frontier.back()
        frontier.pop_back()
  • tree module, GradientBoosting* & RandomForest*


Fused Types (Intro)

ctypedef fused floating:
    float
    double

Function definition

from cython cimport floating
from libc.math cimport fabs

cdef floating abs_max(int n, const floating* a) noexcept nogil:
    """np.max(np.abs(a))"""
    cdef int i
    cdef floating m = fabs(a[0])
    cdef floating d
    for i in range(1, n):
        d = fabs(a[i])
        if d > m:
            m = d
    return m

Fused Types (Memoryview)

ctypedef fused INT_DTYPE:
    int64_t
    int32_t

ctypedef fused Y_DTYPE:
    int64_t
    int32_t
    float64_t
    float32_t

Function Definition

def _fit_encoding_fast(
    INT_DTYPE[:, ::1] X_int,
    const Y_DTYPE[:] y,
    int64_t[::1] n_categories,
    ...
)
  • TargetEncoder

C++ Vector & Fused types into NumPy Array

Vectors point to data on the heap

from libcpp.vector cimport vector

cdef vector[int64_t] vec

C++ Vector & Fused types into NumPy Array

ctypedef fused vector_typed:
vector[float64_t]
vector[intp_t]
vector[int32_t]
vector[int64_t]
cdef class StdVectorSentinelInt64:
cdef vector[int64_t] vec

C++ Vector & Fused types into NumPy Array

ctypedef fused vector_typed:
    vector[float64_t]
    vector[intp_t]
    vector[int32_t]
    vector[int64_t]

cdef class StdVectorSentinelInt64:
    cdef vector[int64_t] vec

Conversion to NumPy Array

cdef cnp.ndarray vector_to_nd_array(vector_typed* vect_ptr):
    cdef:
        StdVectorSentinel sentinel = _create_sentinel(vect_ptr)
        cnp.ndarray arr = cnp.PyArray_SimpleNewFromData(...)

    Py_INCREF(sentinel)
    cnp.PyArray_SetBaseObject(arr, sentinel)
    return arr


Fused Types on classes

ctypedef fused Partitioner:
    DensePartitioner
    SparsePartitioner

Function definition

cdef inline int node_split_best(
    Splitter splitter,
    Partitioner partitioner,
    ...
):
    partitioner.init_node_split(...)
    while ...:
        partitioner.find_min_max(...)
  • tree module, RandomForest*, GradientBoosting*

Tempita

Code Generation!

cdef class WeightVector{{name_suffix}}(object):
    cdef readonly {{c_type}}[::1] w
    cdef readonly {{c_type}}[::1] aw
    ...

Generated Code

cdef class WeightVector64(object):
    cdef readonly double[::1] w
    cdef readonly double[::1] aw

cdef class WeightVector32(object):
    cdef readonly float[::1] w
    cdef readonly float[::1] aw
  • Perceptron, SGDClassifier, SGDRegressor, PassiveAggressive*
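
A hedged sketch of the Tempita source that expands into those two classes (scikit-learn keeps such templates in .pyx.tp files rendered at build time):

{{py:
# (name_suffix, c_type) pairs to generate
dtypes = [('64', 'double'), ('32', 'float')]
}}
{{for name_suffix, c_type in dtypes}}
cdef class WeightVector{{name_suffix}}(object):
    cdef readonly {{c_type}}[::1] w
    cdef readonly {{c_type}}[::1] aw
{{endfor}}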

Optimizing Performance (Virtual Table)

The Problem

cdef class CyLossFunction:
    def loss(self, ...):
        for i in prange(
            n_samples, schedule='static', nogil=True, num_threads=n_threads
        ):
            loss_out[i] = self.point_loss(y_true[i], raw_prediction[i])

Subclass

cdef class CyHalfSquaredError(CyLossFunction):
    cdef inline double point_loss(
        double y_true,
        double raw_prediction
    ) noexcept nogil:
        return 0.5 * (raw_prediction - y_true) * (raw_prediction - y_true)

Performance regression: the virtual-table call can not be dynamic! ❌

Optimizing Performance (Virtual Table)

Tempita

cdef class {{name}}(CyLossFunction):
    def loss(...):
        for i in prange(
            n_samples, schedule='static', nogil=True, num_threads=n_threads
        ):
            loss_out[i] = {{closs}}(y_true[i], raw_prediction[i]{{with_param}})

Generated Code

cdef class CyHalfSquaredError(CyLossFunction):
    def loss(...):
        for i in prange(
            n_samples, schedule='static', nogil=True, num_threads=n_threads
        ):
            loss_out[i] = closs_half_squared_error(y_true[i], raw_prediction[i])
  • linear_model module, GradientBoosting*, HistGradientBoosting*


Cython Features Covered

Python <-> Cython interface ⚙️

  • Compiling
  • Types
  • Memoryviews (NumPy interaction)
  • "Cython classes"

Performance 🏎️

  • Using SciPy BLAS
  • C++ Objects (Vector, Map, Algorithm)
  • Fused Types
  • Tempita

Learn more @ https://cython.readthedocs.io/en/latest/

Performance Uplift

  • HistGradientBoosting*: LightGBM-like performance
  • 2x improvement: LogisticRegression, linear_model module, and GradientBoosting*

  • 20x improvement in cluster, manifold, neighbors, semi_supervised modules

  • TargetEncoder: 4-5x faster runtime and less memory usage
  • Reduce memory usage for validation checks


Pushing Cython to its Limits in Scikit-learn

Why Cython? 🚀

Cython 101 🍀

Scikit-learn Use Cases 🛠️



Appendix

Other Languages

AOT

  • Ahead-of-time compiled
  • Harder to build
  • Fewer requirements at runtime

Numba

  • Just-in-time compiled
  • Source code is Python
  • Requires compiler at runtime
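
For contrast, the same add in the Numba style (a sketch; assumes numba is installed):

from numba import njit

@njit
def add(x, y):
    # compiled by LLVM at runtime, on the first call
    return x + y

add(10, 12)  # first call triggers JIT compilation, then runs as native code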


Header files

cdef packed struct node_struct:
    Y_DTYPE_C value
    unsigned int count
    intp_t feature_idx
    X_DTYPE_C num_threshold
    ...

Imported from another file

from .common cimport node_struct
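
A minimal sketch of the split, with hypothetical file names and a simplified struct: declarations live in a .pxd "header", implementations cimport them:

# common.pxd -- declarations only
cdef packed struct node_struct:
    double value
    unsigned int count

# predictor.pyx -- pulls in the C-level declaration
from common cimport node_struct

def make_node(double value, unsigned int count):
    cdef node_struct node
    node.value = value
    node.count = count
    return node  # Cython auto-converts C structs to Python dicts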

Free-threading

Method             Data sharing      GIL
-----------------  ----------------  ---------------------------
Multi-threading    Shared            Release GIL or native code
Multi-processing   Pickle & memmap   No issue
Sub-interpreters   Pickle & memmap   No issue
Free-threading     Shared            No issue

Strides 1D

import numpy as np
X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

float[::1] - Contiguous

X_row = X[0, :]
print(X_row.flags)
  C_CONTIGUOUS : True
  F_CONTIGUOUS : True
  ...

Strides 1D

import numpy as np
X = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

float[:] - Non-Contiguous

X_col = X[:, 1]
print(X_col.flags)
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  ...

