
How to extend the package

In the following, we illustrate how further methods can be added to the counterfactuals package by integrating the featureTweakR package, which implements Feature Tweaking of Tolomei et al. (2017).

(Note that the featureTweakR package has a couple of limitations, e.g., that factors in the training data cause problems or that the algorithm is only applicable to randomForests trained on standardized features. Therefore, featureTweakR is not part of the counterfactuals package, but serves as a suitable example for our purpose here.)

Structure of the counterfactuals package

Before we dive into the implementation details, we briefly explain the structure of the counterfactuals package.

Class diagram

Each counterfactual method is represented by its own R6 class. Depending on whether a counterfactual method supports classification or regression tasks, it inherits from the (abstract) CounterfactualMethodClassif or CounterfactualMethodRegr classes, respectively. Counterfactual methods that support both tasks are split into two separate classes; for instance, as MOC is applicable to classification and regression tasks, we implement two classes: MOCClassif and MOCRegr.

Leaf classes (like MOCClassif and MOCRegr) inherit the find_counterfactuals() method from CounterfactualMethodClassif or CounterfactualMethodRegr, respectively. The key advantage of this approach is that we are able to provide a tailored find_counterfactuals() interface to the task at hand: for classification tasks find_counterfactuals() has two arguments desired_class and desired_prob and for regression tasks it has one argument desired_outcome.

Call graph

The find_counterfactuals() method calls a private run() method (1)—implemented by the leaf classes—which performs the search and returns the counterfactuals as a data.table (2). The find_counterfactuals() method then creates a Counterfactuals object, which contains the counterfactuals and provides several methods for their evaluation and visualization (3).
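This interplay can be sketched with a stripped-down toy version of these classes. The Toy* names are ours and not part of the package; the sketch only illustrates how the public find_counterfactuals() of the base class delegates to a private run() implemented by a leaf class.

```r
library(R6)

# Toy sketch (not the actual package code): a simplified classification base
# class whose public find_counterfactuals() stores the query and delegates
# the search to the private run() method implemented by a leaf class.
ToyCounterfactualMethodClassif = R6Class("ToyCounterfactualMethodClassif",
  public = list(
    find_counterfactuals = function(x_interest, desired_class, desired_prob) {
      private$x_interest = x_interest        # (1) store the query instance
      private$desired_class = desired_class
      private$desired_prob = desired_prob
      cfs = private$run()                    # (2) leaf class performs the search
      # (3) the real package would wrap `cfs` in a Counterfactuals object here;
      # this sketch just returns the raw data.frame
      cfs
    }
  ),
  private = list(
    x_interest = NULL, desired_class = NULL, desired_prob = NULL,
    run = function() stop("abstract method: implement in a leaf class")
  )
)

# A leaf class only needs to override private$run()
ToyLeafClassif = R6Class("ToyLeafClassif",
  inherit = ToyCounterfactualMethodClassif,
  private = list(
    run = function() data.frame(feature = "Sepal.Width", value = 3.5)
  )
)

res = ToyLeafClassif$new()$find_counterfactuals(
  x_interest = iris[1L, ], desired_class = "versicolor", desired_prob = c(0.5, 1)
)
```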

Integrating a new counterfactuals explanation method

To integrate Feature Tweaking, we first need to install featureTweakR and pforeach and load the required libraries.

devtools::install_github("katokohaku/featureTweakR")
# required for featureTweakR
devtools::install_github("hoxo-m/pforeach")
library(counterfactuals)
library(randomForest)
library(featureTweakR)
library(R6)

Class structure

A new leaf class needs at least two methods: initialize() and run(). The method print_parameters() is not mandatory, but strongly recommended as it gives objects of that class an informative print() output.

As elaborated above, the new class inherits from either CounterfactualMethodClassif or CounterfactualMethodRegr, depending on which task it supports. Because Feature Tweaking supports classification tasks, the new FeatureTweakerClassif class inherits from the former.

FeatureTweakerClassif = R6::R6Class("FeatureTweakerClassif", inherit = CounterfactualMethodClassif,
  
  public = list(
    initialize = function() {}
  ),
  
  private = list(
    run = function() {},
    
    print_parameters = function() {}
  )
)

Implement the initialize() method

The initialize() method must have a predictor argument that takes an iml::Predictor object. In addition, it may have further arguments that are specific to the counterfactual method, such as ktree, epsiron, and resample in this case. For argument checks, we recommend the checkmate package. We also fill the print_parameters() method with the parameters of Feature Tweaking.

FeatureTweakerClassif = R6Class("FeatureTweakerClassif", inherit = CounterfactualMethodClassif,
  
  public = list(
    initialize = function(predictor, ktree = NULL, epsiron = 0.1, 
      resample = FALSE) {
      # adds predictor to private$predictor field
      super$initialize(predictor) 
      private$ktree = ktree
      private$epsiron = epsiron
      private$resample = resample
    }
  ),
  
  private = list(
    ktree = NULL,
    epsiron = NULL,
    resample = NULL,
    
    run = function() {},
    
    print_parameters = function() {
      cat(" - epsiron: ", private$epsiron, "\n")
      cat(" - ktree: ", private$ktree, "\n")
      cat(" - resample: ", private$resample)
    }
  )
)

Implement the run() method

The run() method performs the search for counterfactuals. Its structure is completely free, which makes it easy to add new counterfactual methods to the counterfactuals package.

The workflow of finding counterfactuals with the featureTweakR package is explained here and essentially consists of these steps:

# Rule extraction
rules = getRules(rf, ktree = 5L)
# Get e-satisfactory instance 
es = set.eSatisfactory(rules, epsiron = 0.3)
# Compute counterfactuals
tweaked = tweak(
  es, rf, x_interest, label.from = ..., label.to = ..., .dopar = FALSE
)
tweaked$suggest

As long as ktree—the number of trees to parse—is smaller than the total number of trees in the randomForest, the rule extraction is a random process. Hence, these steps can be repeated several times to obtain multiple counterfactuals.
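Such a resampling loop could look like the following sketch. Here search_once() is a hypothetical stand-in for one getRules()/set.eSatisfactory()/tweak() pass; the real run() method would call those featureTweakR functions instead.

```r
# Hypothetical stand-in for one stochastic rule-extraction-and-tweak pass;
# in FeatureTweakerClassif this would be the getRules()/tweak() sequence.
search_once = function(seed) {
  set.seed(seed)
  data.frame(Sepal.Width = round(runif(1L, min = 2, max = 4), 2))
}

# Repeat the stochastic search and stack the (deduplicated) suggestions
resample_counterfactuals = function(n_repetitions) {
  results = lapply(seq_len(n_repetitions), search_once)
  unique(do.call(rbind, results))
}

cfs = resample_counterfactuals(5L)
```

Because each repetition parses a different random subset of trees, the stacked results typically contain several distinct counterfactual candidates.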

FeatureTweakerClassif = R6Class("FeatureTweakerClassif", 
  inherit = CounterfactualMethodClassif,
  
  public = list(
    initialize = function(predictor, ktree = NULL, epsiron = 0.1, 
      resample = FALSE) {
      # adds predictor to private$predictor field
      super$initialize(predictor) 
      private$ktree = ktree
      private$epsiron = epsiron
      private$resample = resample
    }
  ),
  
  private = list(
    ktree = NULL,
    epsiron = NULL,
    resample = NULL,
    
    run = function() {
      # Extract info from private fields
      predictor = private$predictor
      y_hat_interest = predictor$predict(private$x_interest)
      class_x_interest = names(y_hat_interest)[which.max(y_hat_interest)]
      rf = predictor$model
      
      # Search counterfactuals by calling functions in featureTweakR 
      rules = getRules(rf, ktree = private$ktree, resample = private$resample)
      es = set.eSatisfactory(rules, epsiron = private$epsiron)
      tweaks = featureTweakR::tweak(
        es, rf, private$x_interest, label.from = class_x_interest, 
        label.to = private$desired_class, .dopar = FALSE
      )
      tweaks$suggest
    },
    
    print_parameters = function() {
      cat(" - epsiron: ", private$epsiron, "\n")
      cat(" - ktree: ", private$ktree, "\n")
      cat(" - resample: ", private$resample)
    }
  )
)

Use Case

Now that we have implemented FeatureTweakerClassif, let’s look at a short application to the iris dataset.

First, we train a randomForest model on the iris data set and set up the iml::Predictor object, omitting x_interest from the training data.

set.seed(78546)
X = subset(iris, select = -Species)[-130L, ]
y = iris$Species[-130L]
rf = randomForest(X, y, ntree = 20L)
predictor = iml::Predictor$new(rf, data = iris[-130L, ], y = "Species", type = "prob")

For x_interest, the model predicts a probability of 30% for versicolor.

x_interest = iris[130L, ]
predictor$predict(x_interest)
#>   setosa versicolor virginica
#> 1      0        0.3       0.7

Now, we use Feature Tweaking to address the question: “What would need to change in x_interest for the model to predict a probability of at least 60% for versicolor?”

# Set up FeatureTweakerClassif
ft_classif = FeatureTweakerClassif$new(predictor, ktree = 10L, resample = TRUE)

# Find counterfactuals and create Counterfactuals Object
cfactuals = ft_classif$find_counterfactuals(
  x_interest = x_interest, desired_class = "versicolor", desired_prob = c(0.6, 1)
)

Just as for the existing methods, the result is a Counterfactuals object.

Comments

A minor limitation of this basic implementation is that we would not be able to find counterfactuals for a setting with max(desired_prob) < 0.5, since featureTweakR::tweak only searches for instances that would be predicted as desired_class by majority vote. To enable this setting, we would need to change some featureTweakR internal code, but for the sake of clarity we will not do this here.
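If we wanted run() to fail early in such a setting rather than silently return nothing, a small argument guard could be added at the top of the method. check_desired_prob() below is a hypothetical helper of ours, not part of the counterfactuals package or featureTweakR.

```r
# Sketch of an argument guard: featureTweakR searches for instances that win
# a majority vote for desired_class, so probability targets entirely below
# 0.5 are unreachable and should be rejected explicitly.
check_desired_prob = function(desired_prob) {
  if (max(desired_prob) < 0.5) {
    stop("featureTweakR searches by majority vote: max(desired_prob) must be >= 0.5")
  }
  invisible(TRUE)
}

check_desired_prob(c(0.6, 1))     # passes silently
# check_desired_prob(c(0.1, 0.4)) # would raise an error
```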

References

Tolomei, G., Silvestri, F., Haines, A., Lalmas, M.: Interpretable Predictions of Tree-based Ensembles via Actionable Feature Tweaking. In: Proceedings of the 23rd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, pp. 465–474. KDD ’17, ACM, New York, NY, USA (2017).