Developing with the Plotting API
Scikit-learn defines a simple API for creating visualizations for machine learning. The key features of this API are that calculations run once and that visualizations can be adjusted after the fact. This section is intended for developers who wish to develop or maintain plotting tools. For usage, users should refer to the User Guide.
Plotting API Overview
This compute-once, adjust-later logic is encapsulated in a display object, where the computed data is stored and the plotting is done in a plot method. The display object’s __init__ method contains only the data needed to create the visualization. The plot method takes in parameters that only have to do with visualization, such as a matplotlib axes. It stores the matplotlib artists (for example, the line artist) as attributes, which allows for style adjustments through the display object after plot has been called.

A plot_* helper function accepts both the parameters needed for the computation and the parameters used for plotting. After the helper function creates the display object with the computed values, it calls the display’s plot method.
For example, the RocCurveDisplay defines the following methods and attributes:
class RocCurveDisplay:
    def __init__(self, fpr, tpr, roc_auc, estimator_name):
        ...
        # Store only the precomputed data needed for the visualization.
        self.fpr = fpr
        self.tpr = tpr
        self.roc_auc = roc_auc
        self.estimator_name = estimator_name

    def plot(self, ax=None, name=None, **kwargs):
        ...
        # Keep the matplotlib objects so they can be restyled later.
        self.line_ = ...
        self.ax_ = ax
        self.figure_ = ax.figure
        # Return the display so the helper below can hand it to the caller.
        return self


def plot_roc_curve(estimator, X, y, pos_label=None, sample_weight=None,
                   drop_intermediate=True, response_method="auto",
                   name=None, ax=None, **kwargs):
    # do computation
    viz = RocCurveDisplay(fpr, tpr, roc_auc,
                          estimator.__class__.__name__)
    return viz.plot(ax=ax, name=name, **kwargs)
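
The same pattern can be sketched end to end with a much simpler visualization. The following is a minimal illustration only: ResidualsDisplay and plot_residuals are hypothetical names invented for this sketch and are not part of scikit-learn. It assumes that matplotlib is imported lazily inside plot, so that the computation does not depend on it.

```python
class ResidualsDisplay:
    """Hypothetical display object; not part of scikit-learn."""

    def __init__(self, y_pred, residuals):
        # Only the precomputed data needed for the visualization is stored.
        self.y_pred = y_pred
        self.residuals = residuals

    def plot(self, ax=None, **kwargs):
        if ax is None:
            # Imported lazily so the data can be computed without matplotlib.
            import matplotlib.pyplot as plt

            _, ax = plt.subplots()
        # Axes.plot returns a list of line artists; keep the single line.
        (self.line_,) = ax.plot(self.y_pred, self.residuals, "o", **kwargs)
        self.ax_ = ax
        self.figure_ = ax.figure
        return self


def plot_residuals(y_true, y_pred, ax=None, **kwargs):
    # The computation happens once, here; plot() only draws.
    residuals = [t - p for t, p in zip(y_true, y_pred)]
    viz = ResidualsDisplay(y_pred=list(y_pred), residuals=residuals)
    return viz.plot(ax=ax, **kwargs)
```

With this sketch, disp = plot_residuals(y_true, y_pred) returns the display. Calling disp.line_.set_color("red") would then restyle the existing artist, and disp.plot(ax=other_ax) would redraw the stored data on another axes without recomputing the residuals.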
Read more in ROC Curve with Visualization API and the User Guide.
Plotting with Multiple Axes
Some of the plotting tools, like plot_partial_dependence and PartialDependenceDisplay, support plotting on multiple axes. Two different scenarios are supported:
1. If a list of axes is passed in, plot will check that the number of axes is consistent with the number of axes it expects and then draw on those axes.
2. If a single axes is passed in, that axes defines a space for multiple axes to be placed. In this case, we suggest using matplotlib’s matplotlib.gridspec.GridSpecFromSubplotSpec to split up the space:
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpecFromSubplotSpec
fig, ax = plt.subplots()
gs = GridSpecFromSubplotSpec(2, 2, subplot_spec=ax.get_subplotspec())
ax_top_left = fig.add_subplot(gs[0, 0])
ax_top_right = fig.add_subplot(gs[0, 1])
ax_bottom = fig.add_subplot(gs[1, :])
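
The two scenarios above could be handled by a single helper inside a display's plot method. The following is a rough sketch under stated assumptions: resolve_axes, its n_cols parameter, and the row-by-row layout are hypothetical and do not reproduce scikit-learn's implementation. For brevity it also accepts only a list or tuple of axes, not an ndarray.

```python
import math


def resolve_axes(ax, n_plots, n_cols=2):
    """Hypothetical helper: return a flat list of n_plots axes."""
    if isinstance(ax, (list, tuple)):
        # Scenario 1: the caller supplies the axes; only validate the count.
        if len(ax) != n_plots:
            raise ValueError(f"Expected {n_plots} axes, got {len(ax)}.")
        return list(ax)

    # Scenario 2: a single axes marks the region to subdivide.
    from matplotlib.gridspec import GridSpecFromSubplotSpec

    n_cols = min(n_cols, n_plots)
    n_rows = math.ceil(n_plots / n_cols)
    # Hide the bounding axes so that only the sub-axes are visible.
    ax.set_axis_off()
    gs = GridSpecFromSubplotSpec(
        n_rows, n_cols, subplot_spec=ax.get_subplotspec()
    )
    fig = ax.figure
    # Fill the grid row by row, one new axes per plot.
    return [
        fig.add_subplot(gs[i // n_cols, i % n_cols]) for i in range(n_plots)
    ]
```

Given fig, ax = plt.subplots(), a call such as resolve_axes(ax, 3) would place three new axes inside the region occupied by ax, while resolve_axes([ax1, ax2, ax3], 3) would return the supplied axes unchanged.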
By default, the ax keyword in plot is None. In this case, a single axes is created and the gridspec API is used to create the regions to plot in.

See, for example, plot_partial_dependence, which plots multiple lines and contours using this API. The axes defining the bounding box is saved in a bounding_ax_ attribute. The individual axes created are stored in an axes_ ndarray, corresponding to the axes position on the grid. Positions that are not used are set to None. Furthermore, the matplotlib Artists are stored in lines_ and contours_, where the key is the position on the grid. When a list of axes is passed in, axes_, lines_, and contours_ are 1d ndarrays corresponding to the list of axes passed in.
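
The grid layout described above, with unused positions set to None, can be illustrated with a small NumPy sketch. The helper arrange_on_grid is hypothetical, and the strings stand in for matplotlib axes; this shows only the shape of the stored array, not how scikit-learn builds it.

```python
import numpy as np


def arrange_on_grid(items, n_cols):
    """Hypothetical helper: place items row by row on an object ndarray."""
    n_rows = -(-len(items) // n_cols)  # ceiling division
    # Every position starts as None; unused positions keep that value.
    grid = np.full((n_rows, n_cols), None, dtype=object)
    for i, item in enumerate(items):
        grid[i // n_cols, i % n_cols] = item
    return grid


# Three placeholder "axes" on a grid with two columns.
grid = arrange_on_grid(["ax0", "ax1", "ax2"], n_cols=2)
```

Here grid has shape (2, 2), and the bottom-right position, grid[1, 1], stays None because only three items were placed.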