
The class for distilled surrogate models.

Note

Do not initialize this class on its own. It is created automatically when the distill function is called on an Interpreter object.

Public fields

interpreter

The interpreter object to use as a standardized wrapper for the model

features

The indices of the features in the data used in the surrogate model

weights

The weights used to recombine the PDPs into a surrogate for the original model

intercept

The intercept term we use for our predictions

feature.centers

The center value determined for each feature's effect (0 for an uncentered model)

center.mean

Boolean value that determines whether we use the mean-centered data for our predictions

grid

A list of PDPs (partial dependence plots) that determine our predictions.

snap.grid

Boolean that determines whether predictions use the precomputed grid points stored in grid instead of re-predicting with the functions
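The fields above describe a surrogate that recombines per-feature PDPs with weights, an intercept, and per-feature centers. A minimal sketch of that recombination follows, assuming the prediction for one observation is intercept + the sum over j of weights[j] * (PDP_j(x_j) - feature.centers[j]). This formula is an assumption for illustration, and surrogate_predict_sketch is a hypothetical helper, not part of distillML.

```r
# Hypothetical helper, not part of distillML. Assumed formula:
#   prediction_i = intercept + sum_j weights[j] * (pdp_values[i, j] - feature_centers[j])
# pdp_values:      n x p matrix; column j holds PDP_j evaluated at each x_ij
# weights:         length-p numeric vector, one weight per feature
# intercept:       scalar baseline added to every prediction
# feature_centers: length-p vector subtracted from each PDP column
#                  (all zeros for an uncentered model)
surrogate_predict_sketch <- function(pdp_values, weights, intercept,
                                     feature_centers = rep(0, ncol(pdp_values))) {
  stopifnot(ncol(pdp_values) == length(weights),
            length(feature_centers) == length(weights))
  centered <- sweep(pdp_values, 2, feature_centers)  # PDP_j(x_ij) - c_j
  as.vector(intercept + centered %*% weights)        # one prediction per row
}

# Usage: two observations, two features
pdp <- matrix(c(1, 2,
                3, 4), nrow = 2, byrow = TRUE)
surrogate_predict_sketch(pdp, weights = c(0.5, 2), intercept = 10,
                         feature_centers = c(1, 1))
# returns c(12, 17)
```

Keeping the centers separate from the weights means an uncentered model is simply the special case with intercept 0 and all centers 0.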

Methods


Method new()

Usage

Surrogate$new(
  interpreter,
  features,
  weights,
  intercept,
  feature.centers,
  center.mean,
  grid,
  snap.grid
)

Arguments

interpreter

The interpreter object we want to build a surrogate model for.

features

The indices of features in the training data used for the surrogate model

weights

The weights for each given feature after the surrogate model is fit.

intercept

The baseline value. If uncentered, this is 0, and if centered, this will be the mean of the predictions of the original model on the training data.

feature.centers

The baseline value for the effect of each feature. If uncentered, this is 0.

center.mean

A boolean value indicating whether this model is centered or uncentered

grid

A list of data frames containing the pre-calculated values used to generate predictions if snap.grid is TRUE

snap.grid

Boolean that determines if we use previously calculated values or re-predict using the functions.
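With snap.grid = TRUE, predictions come from previously calculated values in grid rather than from re-predicting. A minimal sketch of one way such a lookup can work is shown below: each new feature value is snapped to the nearest precomputed grid point and that point's stored PDP value is returned. The nearest-point rule, the helper name snap_to_grid_sketch, and the column names value and pdp are all assumptions for illustration; the actual layout of the data frames in grid may differ.

```r
# Hypothetical helper, not part of distillML: returns the precomputed PDP
# value at the grid point nearest to each new feature value.
# grid_df is assumed to have columns `value` (grid point) and `pdp`
# (the PDP evaluated at that point); these names are illustrative only.
snap_to_grid_sketch <- function(x, grid_df) {
  # index of the closest grid point for each element of x
  # (which.min breaks ties by taking the first match)
  nearest <- vapply(x, function(v) which.min(abs(grid_df$value - v)),
                    integer(1))
  grid_df$pdp[nearest]
}

# Usage: a PDP precomputed at four grid points
grid_df <- data.frame(value = c(0, 1, 2, 3),
                      pdp   = c(5.0, 5.5, 6.5, 8.0))
snap_to_grid_sketch(c(0.2, 1.6, 10), grid_df)
# returns c(5.0, 6.5, 8.0)
```

Snapping avoids calling the model again at prediction time, at the cost of resolution: any value between grid points, or outside the grid's range, receives the PDP value of its nearest grid point.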

Returns

A surrogate model object that we can use for predictions


Method clone()

The objects of this class are cloneable with this method.

Usage

Surrogate$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

library(distillML)
library(Rforestry)
set.seed(491)
data <- MASS::crabs

# Relabel the factor levels and give the columns readable names
levels(data$sex) <- list(Male = "M", Female = "F")
levels(data$sp) <- list(Orange = "O", Blue = "B")
colnames(data) <- c("Species", "Sex", "Index", "Frontal Lobe",
                    "Rear Width", "Carapace Length", "Carapace Width", "Body Depth")

# Hold out 180 rows as a test set and train on the rest
test_ind <- sample(seq_len(nrow(data)), 180)
train_reg <- data[-test_ind, ]
test_reg <- data[test_ind, ]

# Fit a random forest that predicts Carapace Width from the other columns
forest <- forestry(x = train_reg[, -which(names(train_reg) == "Carapace Width")],
                   y = train_reg[, which(names(train_reg) == "Carapace Width")])

# Wrap the model in a Predictor, then in an Interpreter
forest_predictor <- Predictor$new(model = forest, data = train_reg,
                                  y = "Carapace Width", task = "regression")

forest_interpret <- Interpreter$new(predictor = forest_predictor)

# Both calls produce the same surrogate model;
# the second states the argument values explicitly
surrogate.model <- distill(forest_interpret)
surrogate.model <- distill(forest_interpret,
                           center.mean = TRUE,
                           features = seq_along(forest_interpret$features),
                           cv = FALSE,
                           snap.grid = TRUE,
                           snap.train = TRUE)
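The distill call above produces weights that recombine the PDPs into a surrogate for the forest, with the intercept set to the mean of the original model's predictions when the model is centered. This page does not say how the weights are fitted (or what cv controls), so the sketch below uses plain ordinary least squares purely as a stand-in; distill may use a different, for example penalized or constrained, procedure. Taking the column means of the PDP values as the feature centers is also an assumption, and fit_weights_sketch is a hypothetical helper.

```r
# Hypothetical sketch, not distillML's fitting code. Assumes:
#   * pdp_values is an n x p matrix of PDP values at the training rows
#   * model_preds are the original model's predictions on those rows
# Weights are chosen by ordinary least squares so that
#   intercept + (pdp_values - feature_centers) %*% weights ~ model_preds
fit_weights_sketch <- function(pdp_values, model_preds) {
  intercept <- mean(model_preds)  # centered baseline: mean original prediction
  # subtract each column's mean (assumed feature centers)
  centered <- scale(pdp_values, center = TRUE, scale = FALSE)
  # OLS of the centered predictions on the centered PDP columns
  fit <- lm.fit(centered, model_preds - intercept)
  list(intercept = intercept,
       feature_centers = unname(attr(centered, "scaled:center")),
       weights = unname(fit$coefficients))
}

# Usage: synthetic predictions that are exactly linear in the PDP columns
pdp <- cbind(c(1, 2, 3, 4, 5), c(2, 1, 4, 3, 6))
preds <- 3 + 2 * pdp[, 1] - 0.5 * pdp[, 2]
fit <- fit_weights_sketch(pdp, preds)
fit$weights  # approximately c(2, -0.5)
```

Because both the PDP columns and the predictions are centered before the fit, the least-squares weights are the same as those of a regression with a free intercept, and the mean prediction plays the role of the intercept field.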