This vignette shows usage of EIX
package for titanic data. This dataset was copied from stablelearner
package. With EIX
package we explain XGBoost classification model concerning the survival problem. More details about EIX
package here.
## Warning: package 'data.table' was built under R version 4.0.4
gender | age | class | embarked | country | fare | sibsp | parch | survived |
---|---|---|---|---|---|---|---|---|
male | 42 | 3rd | Southampton | United States | 7.11 | 0 | 0 | no |
male | 13 | 3rd | Southampton | United States | 20.05 | 0 | 2 | no |
male | 16 | 3rd | Southampton | United States | 20.05 | 1 | 1 | no |
female | 39 | 3rd | Southampton | England | 20.05 | 1 | 1 | yes |
female | 16 | 3rd | Southampton | Norway | 7.13 | 0 | 0 | yes |
male | 25 | 3rd | Southampton | United States | 7.13 | 0 | 0 | yes |
library("xgboost")
param <- list(objective = "binary:logistic", max_depth = 2)
xgb_model <- xgboost(sparse_matrix, params = param, label = titanic_data[, "survived"] == "yes", nrounds = 50, verbose = FALSE)
## [00:13:12] WARNING: amalgamation/../src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
First let’s plot the model.
## Warning: ggrepel: 5 unlabeled data points (too many overlaps). Consider
## increasing max.overlaps
Next we explore interactions using interactions()
functions and its plot.
interactions<-interactions(xgb_model, sparse_matrix, option = "interactions")
head(interactions, 15)
## Parent Child sumGain frequency
## 1: class3rd fare 45.577697 5
## 2: sibsp age 23.668270 3
## 3: genderfemale age 23.227908 4
## 4: class3rd age 22.997468 5
## 5: genderfemale fare 13.975253 2
## 6: countrySwitzerland fare 11.700150 2
## 7: fare classrestaurant staff 7.023322 1
## 8: age countryUnited States 4.857722 1
## 9: countrySwitzerland age 4.250423 2
## 10: classrestaurant staff class3rd 4.190052 1
## 11: classdeck crew age 4.165662 2
## 12: fare genderfemale 4.100826 1
## 13: age sibsp 4.055897 1
## 14: age fare 4.045337 1
## 15: genderfemale class2nd 3.013065 1
## Feature sumGain sumCover meanGain meanCover frequency
## 1: genderfemale 840.50 5042.000 60.030 360.100 14
## 2: class3rd 205.10 2172.000 20.510 217.200 10
## 3: classdeck crew 115.50 2102.000 19.250 350.300 6
## 4: age 90.00 3870.000 5.294 227.600 17
## 5: fare 58.59 2996.000 3.255 166.400 18
## 6: embarkedCherbourg 48.64 787.500 24.320 393.800 2
## 7: class3rd:fare 45.58 682.100 9.116 136.400 5
## 8: countryUnited States 27.46 519.500 13.730 259.800 2
## 9: classrestaurant staff 27.41 1970.000 4.569 328.300 6
## 10: sibsp:age 23.67 684.900 7.889 228.300 3
## 11: genderfemale:age 23.23 849.400 5.807 212.400 4
## 12: class3rd:age 23.00 971.200 4.599 194.200 5
## 13: sibsp 19.56 921.200 6.519 307.100 3
## 14: genderfemale:fare 13.98 348.900 6.988 174.500 2
## 15: countrySwitzerland:fare 11.70 6.044 5.850 3.022 2
## mean5Gain
## 1: 159.700
## 2: 37.050
## 3: 22.900
## 4: 12.700
## 5: 6.433
## 6: 24.320
## 7: 9.116
## 8: 13.730
## 9: 5.247
## 10: 7.889
## 11: 5.807
## 12: 4.599
## 13: 6.519
## 14: 6.988
## 15: 5.850
Let’s see an explanation of the prediction for an 18-year-old from England who has traveled 3rd class.
data <- titanic_data[27,]
new_observation <- sparse_matrix[27,]
wf<-waterfall(xgb_model, new_observation, data, option = "interactions")
wf
## contribution
## xgboost: intercept -0.819
## xgboost: gender = 2 1.714
## xgboost: class = 3 -1.437
## xgboost: country:fare = 15:8.06 -1.011
## xgboost: country = 15 -0.635
## xgboost: embarked = 4 0.310
## xgboost: age:class = 18:3 0.119
## xgboost: class:age = 3:18 -0.116
## xgboost: fare = 8.06 -0.103
## xgboost: age:country = 18:15 0.102
## xgboost: class:fare = 3:8.06 0.096
## xgboost: age = 18 -0.094
## xgboost: age:embarked = 18:4 0.089
## xgboost: gender:class = 2:3 -0.043
## xgboost: sibsp:age = 0:18 -0.022
## xgboost: sibsp = 0 0.021
## xgboost: gender:age = 2:18 0.020
## xgboost: fare:age = 8.06:18 0.019
## xgboost: gender:fare = 2:8.06 -0.018
## xgboost: age:sibsp = 18:0 0.010
## xgboost: parch:fare = 0:8.06 0.002
## xgboost: parch = 0 0.001
## xgboost: prediction -1.794