
Scikit-learn on GPUs with Array API

Thomas J. Fan
@thomasjpfan github.com/thomasjpfan/pydata-nyc-2023-scikit-learn-array-api

GPU support in scikit-learn ⁉️

scikit-learn v1.2 Array API support (2022)

scikit-learn v1.3 Array API support (2023)

Contents

1. scikit-learn API 🖥️

2. Array API Standard 🔬

3. Challenges 🚧

scikit-learn API 🖥️

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
lda_np = LinearDiscriminantAnalysis()
lda_np.fit(X_np, y_np)
y_pred_np = lda_np.predict(X_np)
type(y_pred_np)
# <class 'numpy.ndarray'>

Enabling Array API support in scikit-learn

Global configuration 🌎

import sklearn
import torch
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

sklearn.set_config(array_api_dispatch=True)
X_torch_cpu, y_torch_cpu = torch.asarray(X_np), torch.asarray(y_np)
lda = LinearDiscriminantAnalysis()
lda.fit(X_torch_cpu, y_torch_cpu)
type(lda.predict(X_torch_cpu))
# <class 'torch.Tensor'>

Enabling Array API support in scikit-learn

Context Manager 🎬

import sklearn
import torch
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

with sklearn.config_context(array_api_dispatch=True):
    X_torch_cuda = torch.asarray(X_np, device="cuda")
    y_torch_cuda = torch.asarray(y_np, device="cuda")
    lda = LinearDiscriminantAnalysis()
    lda.fit(X_torch_cuda, y_torch_cuda)
    type(lda.predict(X_torch_cuda))
    # <class 'torch.Tensor'>

Performance 🚀

16-core AMD 5950x CPU and Nvidia RTX 3090 GPU

scikit-learn Nightly Build 🌕

https://scikit-learn.org/dev/modules/array_api.html

Array API Standard 🔬

Array Libraries

Consortium for Python Data API Standards

https://data-apis.org

Extensions 🔌

Vision 🔮

NumPy Code

import numpy as np

def func(x, y):
    out = np.mean(x, axis=0) - 2 * np.std(y, axis=0)
    return out

Array API Code

from array_api_compat import array_namespace

def func(x, y):
    xp = array_namespace(x, y)
    out = xp.mean(x, axis=0) - 2 * xp.std(y, axis=0)
    return out
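To make the vision concrete without extra dependencies, here is a hypothetical, simplified stand-in for `array_namespace` (named `toy_array_namespace`; it is not the array_api_compat function). It relies on the standard's `__array_namespace__` protocol, falls back to the `numpy` module for NumPy versions whose arrays predate that protocol, and refuses to mix libraries. The real helper does more, such as returning wrapped namespaces that make each library's main namespace behave like the standard.

```python
import numpy as np


def toy_array_namespace(*arrays):
    """Hypothetical, simplified stand-in for array_api_compat.array_namespace."""
    namespaces = []
    for a in arrays:
        if hasattr(a, "__array_namespace__"):
            # Array API protocol: the array reports the namespace it belongs to
            ns = a.__array_namespace__()
        elif isinstance(a, np.ndarray):
            # NumPy < 2.0 ndarrays have no __array_namespace__ method
            ns = np
        else:
            raise TypeError(f"Unsupported array type: {type(a).__name__}")
        if ns not in namespaces:
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError("Expected arrays from exactly one array library")
    return namespaces[0]


def func(x, y):
    xp = toy_array_namespace(x, y)
    out = xp.mean(x, axis=0) - 2 * xp.std(y, axis=0)
    return out


x = np.asarray([[1.0, 2.0], [3.0, 4.0]])
# Column means are [2, 3] and column standard deviations are [1, 1]
print(func(x, x))
```

Because `func` only touches `xp`, the same body would run unchanged on any array whose library implements the standard.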

Array API support (2022)

✅

import numpy.array_api as xp
import cupy.array_api as xp

🛑

import numpy as np
import cupy as cp

scikit-learn v1.2 Array API support (2022)

import cupy
import cupy.array_api as xp
import sklearn
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

sklearn.set_config(array_api_dispatch=True)
X_cp, y_cp = cupy.asarray(...), cupy.asarray(...)
X_xp, y_xp = xp.asarray(X_cp), xp.asarray(y_cp)
lda = LinearDiscriminantAnalysis()
lda.fit(X_xp, y_xp)

Meta + Quansight Collaboration

array_api_compat 🚀

Extends the Array API standard to each library's main namespace!

https://github.com/data-apis/array-api-compat


Using array_api_compat 🚀

from array_api_compat import array_namespace

def func(x, y):
    xp = array_namespace(x, y)
    out = xp.mean(x, axis=0) - 2 * xp.std(y, axis=0)
    return out


Works with 🎯

array_api_compat extends:

  • NumPy's ndarray
  • CuPy's ndarray
  • PyTorch's Tensors

Array API implementations

  • NumPy arrays from numpy.array_api
  • CuPy arrays from cupy.array_api

scikit-learn v1.3 Array API support (2023)

import torch
import sklearn
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

sklearn.set_config(array_api_dispatch=True)
X_torch_cpu, y_torch_cpu = torch.asarray(...), torch.asarray(...)
lda = LinearDiscriminantAnalysis()
lda.fit(X_torch_cpu, y_torch_cpu)

Challenges 🚧

  • API Differences 🔌
  • Semantic Differences 🪄
  • Compiled Code 🏗️

API Differences 🔌




Most methods are in the module 📦

NumPy

import numpy as np
y_sum = y.sum(axis=0)

Array API

from array_api_compat import array_namespace
xp = array_namespace(y)
y_sum = xp.sum(y, axis=0)




Most methods are in the module 📦

NumPy

import numpy as np
y = (X.mean(axis=1) > 1.0).any()

Array API

from array_api_compat import array_namespace
xp = array_namespace(X)
y = xp.any(xp.mean(X, axis=1) > 1.0)
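NumPy's main namespace already provides the function forms (`np.sum`, `np.mean`, `np.any`), so one way to migrate is to switch from methods to functions first and confirm the results do not change. A quick check with plain NumPy, where `xp = np` is an assumption standing in for `array_namespace(X)` and `X` is made-up data:

```python
import numpy as np

X = np.asarray([[0.5, 1.0], [2.0, 3.0], [1.5, 0.0]])

# Method style (NumPy only)
y_method = (X.mean(axis=1) > 1.0).any()
sum_method = X.sum(axis=0)

# Function style; xp = np stands in for array_namespace(X)
xp = np
y_function = xp.any(xp.mean(X, axis=1) > 1.0)
sum_function = xp.sum(X, axis=0)

print(y_method == y_function, np.array_equal(sum_method, sum_function))
```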



Matrix Multiplication 🧮

NumPy

import numpy as np
C = np.dot(A, B)

Array API

  • @ is more restrictive than np.dot
C = A @ B
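Two concrete ways `@` (`np.matmul`) differs from `np.dot`, checked with NumPy: it rejects scalar operands, and for arrays with more than two dimensions it treats the leading axes as a batch instead of taking a sum-product over every pairing. The shapes below are arbitrary illustration values.

```python
import numpy as np

A = np.ones((2, 3, 4))
B = np.ones((2, 4, 5))

# np.dot: sum-product of A's last axis with B's second-to-last axis, over every pairing
print(np.dot(A, B).shape)  # (2, 3, 2, 5)

# @: batched matrix multiplication over the leading axis
print((A @ B).shape)  # (2, 3, 5)

# np.dot accepts a scalar operand; @ does not
print(np.dot(2.0, np.eye(2)))
try:
    2.0 @ np.eye(2)
except (TypeError, ValueError) as exc:
    print("matmul rejected a scalar:", type(exc).__name__)
```

Code that relied on `np.dot` with N-d inputs therefore needs more than a mechanical swap to `@`.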



Differences between NumPy and Array API 🎛️

NumPy

import numpy as np
uniques = np.unique(x)
uniques, counts = np.unique(x, return_counts=True)

Array API

from array_api_compat import array_namespace
xp = array_namespace(x)
uniques = xp.unique_values(x)
uniques, counts = xp.unique_counts(x)
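`unique_counts` returns both the unique values and their counts, as a named tuple with `values` and `counts` fields. Below is a sketch of a hypothetical NumPy-only helper, `unique_counts_compat`, that uses `np.unique_counts` when it exists (NumPy 2.0 added the Array API set functions to its main namespace) and otherwise falls back to `np.unique(..., return_counts=True)`.

```python
import numpy as np


def unique_counts_compat(x):
    """Hypothetical helper: return (values, counts) on any NumPy version."""
    if hasattr(np, "unique_counts"):
        # NumPy >= 2.0: Array API spelling, returns a named tuple (values, counts)
        result = np.unique_counts(x)
        return result.values, result.counts
    # Older NumPy: equivalent call through np.unique
    return np.unique(x, return_counts=True)


x = np.asarray([3, 1, 3, 2, 3])
values, counts = unique_counts_compat(x)
print(dict(zip(values.tolist(), counts.tolist())))
```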

Some NumPy API does not exist in Array API 🎚️

NumPy

import numpy as np
x_max = np.nanmax(x, axis=1)

Some NumPy API does not exist in Array API 🎚️

Array API

import numpy
from array_api_compat import array_namespace, device

def xp_nanmax(X, axis=None):
    xp = array_namespace(X)
    # is_numpy_namespace: helper that checks whether xp is NumPy's namespace
    if is_numpy_namespace(xp):
        return xp.asarray(numpy.nanmax(X, axis=axis))
    # Implement using Array API standard (simplified: all-NaN slices give -inf, not NaN)
    mask = xp.isnan(X)
    inf_ = xp.asarray(-xp.inf, device=device(X))
    X_nanmax = xp.max(xp.where(mask, inf_, X), axis=axis)
    return X_nanmax
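The masking trick can be tried with NumPy standing in for `xp`. This variant is a sketch: `xp` is passed explicitly instead of being inferred, and the device argument is omitted because only NumPy is used. It also restores NaN for slices that contain only NaN, matching what `np.nanmax` returns in that case.

```python
import numpy as np


def xp_nanmax_sketch(X, axis=None, xp=np):
    """nanmax built from Array API functions; xp defaults to NumPy for this demo."""
    mask = xp.isnan(X)
    neg_inf = xp.asarray(-xp.inf, dtype=X.dtype)
    nan = xp.asarray(xp.nan, dtype=X.dtype)
    # Replace NaN with -inf so it never wins the max
    X_nanmax = xp.max(xp.where(mask, neg_inf, X), axis=axis)
    # Slices made entirely of NaN have no valid maximum: report NaN
    all_nan = xp.all(mask, axis=axis)
    return xp.where(all_nan, nan, X_nanmax)


X = np.asarray([[1.0, np.nan, 3.0],
                [np.nan, np.nan, np.nan],
                [-2.0, np.nan, -5.0]])
print(xp_nanmax_sketch(X, axis=1))
```

For a non-CPU library, the two constant arrays would also need `device=device(X)`, as in the slide.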

Integer Indexing 🔎

NumPy

import numpy as np
x = np.asarray([[1, 2], [4, 5], [4, 1]])
x[[0, 2]]
# array([[1, 2],
#        [4, 1]])

Array API

  • Added in the 2022.12 standard
import numpy.array_api as xp
x = xp.asarray([[1, 2], [4, 5], [4, 1]])
xp.take(x, xp.asarray([0, 2]), axis=0)
# Array([[1, 2],
#        [4, 1]], dtype=int64)
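In NumPy, `take` along an axis gives the same result as integer-array indexing, so switching to the `take` spelling keeps the computation unchanged:

```python
import numpy as np

x = np.asarray([[1, 2], [4, 5], [4, 1]])
indices = np.asarray([0, 2])

fancy = x[indices]                      # NumPy integer-array indexing
portable = np.take(x, indices, axis=0)  # same rows, Array API spelling

print(np.array_equal(fancy, portable))
```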

Indexing Multiple Dimensions 🔎

NumPy

import numpy as np
x = np.asarray([[1, 2, 3], [4, 5, 6]])
x[1]
# array([4, 5, 6])

Array API

import numpy.array_api as xp
x = xp.asarray([[1, 2, 3], [4, 5, 6]])
x[1]
# IndexError
x[1, :]
# Array([4, 5, 6], dtype=int64)
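Indexing every axis explicitly is also valid NumPy, so the stricter spelling can be used everywhere. The standard's indexing rules also let an ellipsis stand in for the remaining axes; in NumPy all three forms below select the same row:

```python
import numpy as np

x = np.asarray([[1, 2, 3], [4, 5, 6]])

row_implicit = x[1]       # NumPy-only shorthand
row_explicit = x[1, :]    # every axis indexed explicitly
row_ellipsis = x[1, ...]  # ellipsis fills in the remaining axes

print(row_implicit, row_explicit, row_ellipsis)
```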


Random Number Generators 🎮

NumPy

import numpy as np
rng = np.random.default_rng()
x = rng.standard_normal(size=10)

Array API

import numpy as np
from array_api_compat import array_namespace, device
# x: an existing input array from any supported library
rng = np.random.default_rng()
x_np = rng.standard_normal(size=10)
xp = array_namespace(x)
x_xp = xp.asarray(x_np, device=device(x))
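One way to package this pattern is a small helper, sketched here with hypothetical names: `random_normal_like` takes explicit `xp` and `device` arguments instead of calling `array_namespace` and `device`, so it runs with NumPy alone. Because the draws always come from NumPy, a fixed seed yields the same numbers whichever library receives them.

```python
import numpy as np


def random_normal_like(x, size, seed=None, xp=np, device=None):
    """Hypothetical helper: draw on the host with NumPy, then convert to xp."""
    rng = np.random.default_rng(seed)
    x_np = rng.standard_normal(size=size)
    # Only pass device when one was requested (NumPy < 2.0 asarray has no device argument)
    kwargs = {} if device is None else {"device": device}
    # Match the dtype of the reference array x (assumed floating)
    return xp.asarray(x_np, dtype=x.dtype, **kwargs)


x = np.zeros(3, dtype=np.float32)
a = random_normal_like(x, size=5, seed=0)
b = random_normal_like(x, size=5, seed=0)
print(a.dtype, np.array_equal(a, b))
```

With array_api_compat, a caller would pass `xp=array_namespace(x)` and `device=device(x)` instead.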

Order ♟️

import numpy as np
rng = np.random.default_rng()
x = rng.standard_normal(size=(10_000, 10_000))
x_c = np.asarray(x, order="C")
x_f = np.asarray(x, order="F")
%%timeit
_ = x_c.sum(axis=0)
# 36.3 ms ± 1.44 ms per loop
%%timeit
_ = x_f.sum(axis=0)
# 18.8 ms ± 131 µs per loop

The Array API standard has no order argument, so this memory layout cannot be requested portably.
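`%%timeit` is an IPython cell magic. A plain-Python sketch of the same comparison uses the `timeit` module and also checks the layout flags. The array is smaller than on the slide so it runs quickly, and timings will vary by machine.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(size=(2_000, 2_000))
x_c = np.asarray(x, order="C")  # rows are contiguous in memory
x_f = np.asarray(x, order="F")  # columns are contiguous in memory

print(x_c.flags["C_CONTIGUOUS"], x_f.flags["F_CONTIGUOUS"])

# Same values, different memory layout
n_runs = 10
t_c = timeit.timeit(lambda: x_c.sum(axis=0), number=n_runs)
t_f = timeit.timeit(lambda: x_f.sum(axis=0), number=n_runs)
print(f"C order: {t_c / n_runs * 1e3:.2f} ms, F order: {t_f / n_runs * 1e3:.2f} ms")
```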

Semantic Differences 🪄


Type Promotion ♛

NumPy

import numpy as np
x1 = np.asarray([[1, 2], [4, 5]])
x2 = np.asarray([[1, 2]], dtype=np.float32)
x1 + x2
# array([[2., 4.],
#        [5., 7.]])

Array API

import numpy.array_api as xp
x1 = xp.asarray([[1, 2], [4, 5]])
x2 = xp.asarray([[1, 2]], dtype=xp.float32)
x1 + x2
# TypeError: int64 and float32 cannot be type promoted together

Type Promotion ♛

Workaround

import numpy.array_api as xp
x1 = xp.asarray([[1, 2], [4, 5]], dtype=xp.float32)
x2 = xp.asarray([[1, 2]], dtype=xp.float32)
x1 + x2
# Array([[2., 4.],
#        [5., 7.]], dtype=float32)
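A hypothetical helper that generalizes the workaround: cast every input to one shared floating dtype before combining them. This NumPy-only sketch picks the widest floating dtype among the inputs, or float64 if none is floating; in Array API code, `xp.isdtype(dtype, "real floating")` and `xp.astype` would play the roles of `np.issubdtype` and `.astype`. Casting large integers to float32 can lose precision.

```python
import numpy as np


def to_common_float(*arrays):
    """Hypothetical helper: cast all arrays to one shared floating dtype."""
    float_dtypes = [a.dtype for a in arrays if np.issubdtype(a.dtype, np.floating)]
    # Widest floating dtype present, else fall back to float64
    if float_dtypes:
        dtype = max(float_dtypes, key=lambda d: d.itemsize)
    else:
        dtype = np.dtype(np.float64)
    return tuple(a.astype(dtype, copy=False) for a in arrays)


x1 = np.asarray([[1, 2], [4, 5]])
x2 = np.asarray([[1, 2]], dtype=np.float32)
x1, x2 = to_common_float(x1, x2)
print((x1 + x2).dtype)
```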

Type Promotion ♛: Python Scalars

NumPy

import numpy as np
x1 = np.asarray([[1, 2, 3]])
x2 = 1.0
x1 + x2
# array([[2., 3., 4.]])

Array API

import numpy.array_api as xp
x1 = xp.asarray([[1, 2, 3]])
x2 = 1.0
x1 + x2
# TypeError: Python float scalars can only be promoted with floating-point arrays.

Type Promotion ♛: Python Scalars

Workaround

import numpy.array_api as xp
x1 = xp.asarray([[1, 2, 3]], dtype=xp.float32)
x2 = 1.0
x1 + x2
# Array([[2., 3., 4.]], dtype=float32)



Device 📠

NumPy

import numpy as np
y = np.linspace(2.0, 3.0, num=10)

Array API

from array_api_compat import array_namespace, device
xp = array_namespace(x)
y = xp.linspace(2.0, 3.0, num=10, device=device(x))

Compiled Code 🏗️

Compiled Code in scikit-learn? 🏗️

  • Random Forest 🌲🌲🌲
    • RandomForestClassifier
    • RandomForestRegressor
  • Histogram Gradient Boosting 🎄 + 🛹
    • HistGradientBoostingClassifier
    • HistGradientBoostingRegressor
  • Linear Models 📈
    • LogisticRegression
    • PoissonRegressor

Possible Solutions

Works Now 🪄

  • Convert to NumPy and back - SciPy

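Convert to NumPy and back - SciPy

import numpy
from array_api_compat import array_namespace, device

def func(a, b):
    xp = array_namespace(a, b)
    c = xp.sum(a, axis=1) + xp.sum(b, axis=1)
    # numpy.asarray only works for CPU arrays; GPU arrays need an explicit host copy
    c = numpy.asarray(c)
    d = compiled_code_that_only_works_with_numpy(c)
    # Convert back to the input's namespace and device
    d = xp.asarray(d, device=device(a))
    return d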
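`numpy.asarray` fails on GPU arrays because CuPy and PyTorch refuse implicit device-to-host copies. Here is a sketch of a hypothetical `to_numpy` helper that copies explicitly, using CuPy's `.get()` and PyTorch's `.cpu().numpy()`. Detecting the library from the type's module name is an assumption made to avoid importing either library, and `np.cumsum` stands in for the NumPy-only compiled code.

```python
import numpy as np


def to_numpy(a):
    """Hypothetical helper: copy an array from a supported library to host NumPy."""
    if isinstance(a, np.ndarray):
        return a
    module = type(a).__module__.split(".")[0]
    if module == "cupy":
        return a.get()  # explicit device-to-host copy
    if module == "torch":
        return a.detach().cpu().numpy()
    return np.asarray(a)


def call_numpy_only(numpy_func, a, from_numpy=None):
    """Run a NumPy-only routine on `a`; optionally convert the result back."""
    result = numpy_func(to_numpy(a))
    return result if from_numpy is None else from_numpy(result)


out = call_numpy_only(np.cumsum, np.asarray([1, 2, 3]))
print(out)
```

For a CUDA tensor `a`, passing `from_numpy=lambda r: torch.asarray(r, device=a.device)` would move the result back to the GPU.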

Dispatching 🔀

  • uarray - SciPy
  • Plugins - Scikit-learn
  • Array library specific code

Dispatching 🔀

from array_api_compat import array_namespace

def func(a, b, plugin):
    xp = array_namespace(a, b)
    c = xp.sum(a, axis=1) + xp.sum(b, axis=1)
    # plugin: a library-specific backend that runs its own compiled code
    d = plugin.dispatch_to_library(c)
    e = xp.mean(d, axis=0)
    return e

Array library specific code 📚

from array_api_compat import is_cupy_array, is_numpy_array, is_torch_array

def erf(x):
    if is_numpy_array(x):
        import scipy.special
        return scipy.special.erf(x)
    elif is_cupy_array(x):
        import cupyx.scipy.special
        return cupyx.scipy.special.erf(x)
    elif is_torch_array(x):
        import torch
        return torch.special.erf(x)
    else:
        ...
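The `if`/`elif` chain can also be written as a lookup table keyed by the array's library, which is roughly the idea behind dispatching. This is a minimal hypothetical sketch, not uarray or scikit-learn's plugin mechanism. The NumPy entry uses `math.erf` through `np.vectorize` so the example runs without SciPy.

```python
import math

import numpy as np

# Hypothetical registry: function name -> {library name -> implementation}
_REGISTRY = {}


def register(name, library):
    def decorator(impl):
        _REGISTRY.setdefault(name, {})[library] = impl
        return impl
    return decorator


def dispatch(name, x):
    # Top-level module of the array's type, e.g. "numpy", "cupy", "torch"
    library = type(x).__module__.split(".")[0]
    try:
        impl = _REGISTRY[name][library]
    except KeyError:
        raise NotImplementedError(f"{name} is not implemented for {library}") from None
    return impl(x)


@register("erf", "numpy")
def _erf_numpy(x):
    return np.vectorize(math.erf)(x)


print(dispatch("erf", np.asarray([0.0, 1.0])))
```

Supporting another library means registering one more function, for example a `"torch"` entry that returns `torch.special.erf(x)`, without touching `dispatch`.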

Challenges 🚧

  • API Differences 🔌
  • Semantic Differences 🪄
  • Compiled Code 🏗️

Why Adopt the Array API Standard?

Smaller API

Portable

Performance

Performance 🚀

16-core AMD 5950x CPU and Nvidia RTX 3090 GPU

Conclusion

User 🧪

Library Author ✍️


GPU support in scikit-learn ⁉️
