The package contains two workhorses to calculate SHAP values for any model:

- permshap(): Exact permutation SHAP algorithm of [1]. Recommended for up to 8-10 features.
- kernelshap(): Kernel SHAP algorithm of [2] and [3]. By default, exact Kernel SHAP is used for up to p = 8 features, and an almost exact hybrid algorithm otherwise.

Furthermore, the function additive_shap() produces SHAP values for additive models fitted via lm(), glm(), mgcv::gam(), mgcv::bam(), gam::gam(), survival::coxph(), or survival::survreg(). It is exponentially faster than permshap() and kernelshap(), with identical results when the background dataset of the latter equals the full training data.
Kernel SHAP was introduced in [2] as an approximation of permutation SHAP [1]. For up to about ten features, exact calculations are realistic for both algorithms. Since exact Kernel SHAP is still only an approximation of exact permutation SHAP, the latter should be preferred in this case, even though the results are often very similar.
A situation where the two approaches give different results: The model has interactions of order three or higher.
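To make this concrete, here is a small hedged sketch (not taken from the package documentation; the toy model, the use of iris, and the object names are illustrative): we fit a linear model with a three-way interaction and compare the exact results of both algorithms.

library(kernelshap)

# Toy model with a three-way interaction (illustrative choice of data and features)
xvars3 <- c("Petal.Length", "Petal.Width", "Sepal.Width")
fit3 <- lm(Sepal.Length ~ Petal.Length * Petal.Width * Sepal.Width, data = iris)
X3 <- iris[xvars3]

ps3 <- permshap(fit3, X3, bg_X = iris)    # exact permutation SHAP
ks3 <- kernelshap(fit3, X3, bg_X = iris)  # exact Kernel SHAP (p = 3 features)

# Inspect the differences: per the statement above, the three-way interaction
# can make the two matrices of SHAP values deviate slightly
head(ps3$S - ks3$S)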
The typical workflow is as follows:

1. Select a set of rows X to be explained. If the training dataset is small, simply use the full training data for this purpose. X should only contain feature columns.
2. Optionally, select a background dataset bg_X used to calculate marginal means. For this purpose, set aside 50 to 500 rows from the training data. If the training data is small, use the full training data. In cases with a natural “off” value (like MNIST digits), this can also be a single row with all values set to the off value. If not specified, at most bg_n = 200 rows are randomly sampled from X.
3. Apply kernelshap(object, X, bg_X = NULL, ...) or permshap(object, X, bg_X = NULL, ...) to calculate SHAP values. Runtime is proportional to nrow(X), while memory consumption scales linearly in nrow(bg_X).
Remarks

- Case weights are supported via the argument bg_w.
- By changing the defaults of kernelshap(), the iterative pure sampling approach in [3] can be enforced (see the short sketch after the basic usage example below).
- The additive_shap() explainer is easier to use: Only the model and X are required.
# From CRAN
install.packages("kernelshap")

# Or the development version:
devtools::install_github("ModelOriented/kernelshap")
Let’s model diamond prices with a random forest. As an alternative, you could use the {treeshap} package in this situation.
library(kernelshap)
library(ggplot2)
library(ranger)
library(shapviz)
diamonds <- transform(
  diamonds,
  log_price = log(price),
  log_carat = log(carat)
)

xvars <- c("log_carat", "clarity", "color", "cut")

fit <- ranger(
  log_price ~ log_carat + clarity + color + cut,
  data = diamonds,
  num.trees = 100,
  seed = 20
)
fit  # OOB R-squared 0.989

# 1) Sample rows to be explained
set.seed(10)
X <- diamonds[sample(nrow(diamonds), 1000), xvars]

# 2) Optional: Select background data. If not specified, a random sample of 200 rows
#    from X is used
bg_X <- diamonds[sample(nrow(diamonds), 200), ]

# 3) Crunch SHAP values for all 1000 rows of X (54 seconds)
#    Note: Since the number of features is small, we use permshap()
system.time(
  ps <- permshap(fit, X, bg_X = bg_X)
)
ps

# SHAP values of first observations:
#       log_carat     clarity       color         cut
# [1,]  1.1913247  0.09005467 -0.13430720 0.000682593
# [2,] -0.4931989 -0.11724773  0.09868921 0.028563613

# Kernel SHAP gives almost the same:
system.time(  # 49 s
  ks <- kernelshap(fit, X, bg_X = bg_X)
)
ks
#       log_carat    clarity       color         cut
# [1,]  1.1911791  0.0900462 -0.13531648 0.001845958
# [2,] -0.4927482 -0.1168517  0.09815062 0.028255442

# 4) Analyze with our sister package {shapviz}
ps <- shapviz(ps)
sv_importance(ps)
sv_dependence(ps, xvars)
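Two of the remarks from above can be illustrated with the objects of this example. This is a hedged sketch rather than part of the original example; it assumes that the arguments bg_w, exact, and hybrid_degree of kernelshap() behave as described in the package help of your installed version.

# Hedged sketch, reusing fit, X, and bg_X from above

# Case weights of the background data via bg_w (hypothetical weights proportional to carat)
ks_w <- kernelshap(fit, X, bg_X = bg_X, bg_w = bg_X$carat)

# Enforce the iterative, purely sampling-based approach of [3]
# (assumes exact = FALSE and hybrid_degree = 0 switch off the exact and hybrid strategies)
ks_sampling <- kernelshap(fit, X, bg_X = bg_X, exact = FALSE, hybrid_degree = 0)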
{kernelshap} can deal with almost any situation. We will show some of the flexibility here. The first two examples require you to run at least up to Step 2 of the “Basic Usage” code.
Parallel computing is supported via {foreach}. Note that this does not work with all models, and that there is no progress bar.
On Windows, sometimes not all packages or global objects are passed to the parallel sessions. Often, this can be fixed via parallel_args; see the generalized additive model below.
library(doFuture)
library(mgcv)
registerDoFuture()
plan(multisession, workers = 4) # Windows
# plan(multicore, workers = 4) # Linux, macOS, Solaris
fit <- gam(log_price ~ s(log_carat) + clarity * color + cut, data = diamonds)

system.time(  # 9 seconds in parallel
  ps <- permshap(fit, X, parallel = TRUE, parallel_args = list(.packages = "mgcv"))
)
ps

# SHAP values of first observations:
#      log_carat    clarity       color         cut
# [1,]   1.26801  0.1023518 -0.09223291 0.004512402
# [2,]  -0.51546 -0.1174766  0.11122775 0.030243973

# Because there are no interactions of order above 2, Kernel SHAP gives the same:
system.time(  # 27 s non-parallel
  ks <- kernelshap(fit, X, bg_X = bg_X)
)
all.equal(ps$S, ks$S)
# [1] TRUE

# Now the usual plots:
sv <- shapviz(ps)
sv_importance(sv, kind = "bee")
sv_dependence(sv, xvars)
In this {keras} example, we show how to use a tailored predict() function that complies with the interface expected by permshap() and kernelshap(), i.e., a function of the model and the data. The results are not fully reproducible though.
library(keras)
nn <- keras_model_sequential()
nn |>
  layer_dense(units = 30, activation = "relu", input_shape = 4) |>
  layer_dense(units = 15, activation = "relu") |>
  layer_dense(units = 1)

nn |>
  compile(optimizer = optimizer_adam(0.001), loss = "mse")

cb <- list(
  callback_early_stopping(patience = 20),
  callback_reduce_lr_on_plateau(patience = 5)
)

nn |>
  fit(
    x = data.matrix(diamonds[xvars]),
    y = diamonds$log_price,
    epochs = 100,
    batch_size = 400,
    validation_split = 0.2,
    callbacks = cb
  )

pred_fun <- function(mod, X)
  predict(mod, data.matrix(X), batch_size = 1e4, verbose = FALSE)

system.time(  # 60 s
  ps <- permshap(nn, X, bg_X = bg_X, pred_fun = pred_fun)
)

ps <- shapviz(ps)
sv_importance(ps, show_numbers = TRUE)
sv_dependence(ps, xvars)
The additive explainer extracts the additive contribution of each feature from a model of suitable class.
fit <- lm(log(price) ~ log(carat) + color + clarity + cut, data = diamonds)

shap_values <- additive_shap(fit, diamonds) |>
  shapviz()

sv_importance(shap_values)
sv_dependence(shap_values, v = "carat", color_var = NULL)
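To connect this with the claim from the introduction (identical results to permshap() and kernelshap() when their background data equals the full training data), here is a hedged sketch on a small dataset; the use of iris and the object names are illustrative and not part of the original example.

# Additive model: additive_shap() vs. permshap() with the full training data as background
fit_add <- lm(Sepal.Length ~ ., data = iris)
xvars_add <- c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")

s_add <- additive_shap(fit_add, iris)
s_perm <- permshap(fit_add, iris[xvars_add], bg_X = iris)

# The two matrices of SHAP values should essentially agree, since the model is additive
all.equal(s_add$S, s_perm$S)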
{kernelshap} supports multivariate predictions, for instance probabilistic classification. Here, we use the iris data (no need to run code from above).
library(kernelshap)
library(ranger)
library(shapviz)
set.seed(1)
# Probabilistic classification
fit_prob <- ranger(Species ~ ., data = iris, probability = TRUE)

ps_prob <- permshap(fit_prob, X = iris[-5]) |>
  shapviz()

sv_importance(ps_prob)
sv_dependence(ps_prob, "Petal.Length")
Meta-learning packages like {tidymodels}, {caret}, or {mlr3} are straightforward to use. The following examples additionally show that the ... arguments of permshap() and kernelshap() are passed to predict().
library(kernelshap)
library(tidymodels)
set.seed(1)
iris_recipe <- iris |>
  recipe(Species ~ .)

mod <- rand_forest(trees = 100) |>
  set_engine("ranger") |>
  set_mode("classification")

iris_wf <- workflow() |>
  add_recipe(iris_recipe) |>
  add_model(mod)

fit <- iris_wf |>
  fit(iris)

system.time(  # 4 s
  ps <- permshap(fit, iris[-5], type = "prob")
)
ps

# Some values
# $.pred_setosa
#      Sepal.Length Sepal.Width Petal.Length Petal.Width
# [1,]   0.02186111 0.012137778    0.3658278   0.2667667
# [2,]   0.02628333 0.001315556    0.3683833   0.2706111
library(kernelshap)
library(caret)
fit <- train(
  Sepal.Length ~ .,
  data = iris,
  method = "lm",
  tuneGrid = data.frame(intercept = TRUE),
  trControl = trainControl(method = "none")
)

ps <- permshap(fit, iris[-1])
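As in the earlier examples, the result can be post-processed with {shapviz}; this short follow-up is not part of the original snippet.

# Optional follow-up: visualize the caret explanations with {shapviz}
library(shapviz)

sv <- shapviz(ps)
sv_importance(sv)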
library(kernelshap)
library(mlr3)
library(mlr3learners)
set.seed(1)
task_classif <- TaskClassif$new(id = "1", backend = iris, target = "Species")
learner_classif <- lrn("classif.rpart", predict_type = "prob")
learner_classif$train(task_classif)

x <- learner_classif$selected_features()

# Don't forget to pass predict_type = "prob" to mlr3's predict()
ps <- permshap(
  learner_classif, X = iris, feature_names = x, predict_type = "prob"
)
ps
# $setosa
#      Petal.Length Petal.Width
# [1,]    0.6666667           0
# [2,]    0.6666667           0
[1] Erik Štrumbelj and Igor Kononenko. Explaining prediction models and individual predictions with feature contributions. Knowledge and Information Systems 41, 2014.
[2] Scott M. Lundberg and Su-In Lee. A Unified Approach to Interpreting Model Predictions. Advances in Neural Information Processing Systems 30, 2017.
[3] Ian Covert and Su-In Lee. Improving KernelSHAP: Practical Shapley Value Estimation Using Linear Regression. Proceedings of The 24th International Conference on Artificial Intelligence and Statistics, PMLR 130:3457-3465, 2021.