The class for distilled surrogate models.
Note
Do not initialize this class on its own. Objects of this class are created automatically by the distill() function from an Interpreter object.
Public fields
interpreter
The interpreter object to use as a standardized wrapper for the model
features
The indices of the features in the data used in the surrogate model
weights
The weights used to recombine the PDPs into a surrogate for the original model
intercept
The intercept term used in the surrogate model's predictions
feature.centers
The baseline (center) value for the effect of each feature, as determined when the surrogate model was fit
center.mean
Logical flag indicating whether predictions are made using mean-centered data
grid
A list of PDPs (partial dependence plots) that determine the surrogate model's predictions.
snap.grid
Logical flag indicating whether predictions use the pre-calculated values stored in grid (TRUE) or are re-predicted (FALSE)
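As a rough illustration of how the weights, intercept, and feature.centers fields might combine, the sketch below ASSUMES an additive form: prediction = intercept + sum over features j of weights[j] * (PDP_j(x_j) - feature.centers[j]). This formula is inferred from the field descriptions and is not taken from distillML's source. The helper surrogate_predict and its arguments are hypothetical, and the PDPs are toy closures standing in for the grid field. Per the argument documentation, an uncentered model has an intercept and feature centers of 0, while a centered model's intercept is the mean of the original model's predictions on the training data.

```r
# Hypothetical helper, not part of distillML. It ASSUMES an additive surrogate:
#   prediction = intercept + sum_j weights[j] * (pdp_j(x_j) - feature.centers[j])
surrogate_predict <- function(newdata, pdps, weights, intercept, feature.centers) {
  # One contribution vector (one entry per row of newdata) for each feature
  contrib <- Map(function(f) {
    weights[[f]] * (pdps[[f]](newdata[[f]]) - feature.centers[[f]])
  }, names(pdps))
  # Sum the feature contributions element-wise and add the intercept
  intercept + Reduce(`+`, contrib)
}

# Made-up PDPs standing in for the grid field (illustration only)
pdps <- list(a = function(x) 2 * x, b = function(x) x^2)
weights <- c(a = 1, b = 0.5)
newdata <- data.frame(a = c(1, 3), b = c(2, 0))

# Centered case: nonzero intercept and feature centers
pred <- surrogate_predict(newdata, pdps, weights,
                          intercept = 10, feature.centers = c(a = 2, b = 1))

# Uncentered case: intercept and feature centers are 0
pred_uncentered <- surrogate_predict(newdata, pdps, weights,
                                     intercept = 0, feature.centers = c(a = 0, b = 0))
print(pred)
print(pred_uncentered)
```

Under this assumed form, centering only shifts each feature's contribution by a constant and moves that mass into the intercept; the shape of each PDP is unchanged.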
Methods
Method new()
Usage
Surrogate$new(
  interpreter,
  features,
  weights,
  intercept,
  feature.centers,
  center.mean,
  grid,
  snap.grid
)
Arguments
interpreter
The interpreter object we want to build a surrogate model for.
features
The indices of features in the training data used for the surrogate model
weights
The weights for each given feature after the surrogate model is fit.
intercept
The baseline value. If uncentered, this is 0, and if centered, this will be the mean of the predictions of the original model on the training data.
feature.centers
The baseline value for the effect of each feature. If uncentered, this is 0.
center.mean
Logical flag indicating whether this is a centered or an uncentered model
grid
A list of data frames containing the pre-calculated values used to generate predictions when snap.grid is TRUE
snap.grid
Logical flag indicating whether to use the previously calculated values in grid (TRUE) or to re-predict using the functions (FALSE)
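To make the snap.grid trade-off concrete, here is a base R sketch under stated assumptions. It ASSUMES that "snapping" means replacing each new feature value with the nearest pre-calculated grid point and reusing that point's stored PDP value, and that re-predicting means evaluating the partial dependence afresh by averaging the original model's predictions over the training data with the feature overwritten. The names model_fn, train, pdp_x1, grid_x1, and predict_pdp are hypothetical and are not part of distillML.

```r
# Hypothetical illustration of what snap.grid toggles; not distillML code.

# Toy "original model" and training data
model_fn <- function(df) 3 * df$x1 + df$x2
train <- data.frame(x1 = c(0, 1, 2, 3), x2 = c(1, 2, 3, 4))

# Partial dependence of x1 at each value v: the average model prediction over
# the training data with x1 overwritten by v
pdp_x1 <- function(v) {
  vapply(v, function(val) {
    tmp <- train
    tmp$x1 <- val
    mean(model_fn(tmp))
  }, numeric(1))
}

# Pre-calculate the PDP on a grid (analogous to one data frame in grid)
grid_x1 <- data.frame(x1 = seq(0, 3, by = 1))
grid_x1$pdp <- pdp_x1(grid_x1$x1)

predict_pdp <- function(v, snap.grid) {
  if (snap.grid) {
    # ASSUMED snapping rule: use the stored PDP value at the nearest grid point
    idx <- vapply(v, function(val) which.min(abs(grid_x1$x1 - val)), integer(1))
    grid_x1$pdp[idx]
  } else {
    # Re-predict: evaluate the PDP at the exact values (one model call per value)
    pdp_x1(v)
  }
}

snapped <- predict_pdp(c(1.2, 2.9), snap.grid = TRUE)
exact <- predict_pdp(c(1.2, 2.9), snap.grid = FALSE)
print(snapped)
print(exact)
```

In this toy setup the PDP is exactly 3v + 2.5, so snapping returns the stored values at grid points 1 and 3 (5.5 and 11.5), while re-predicting returns 6.1 and 11.2. Snapping avoids calling the original model again at prediction time, at the cost of resolution limited by the grid spacing.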
Examples
library(distillML)
library(Rforestry)
set.seed(491)
data <- MASS::crabs
levels(data$sex) <- list(Male = "M", Female = "F")
levels(data$sp) <- list(Orange = "O", Blue = "B")
colnames(data) <- c("Species", "Sex", "Index", "Frontal Lobe",
                    "Rear Width", "Carapace Length", "Carapace Width", "Body Depth")
test_ind <- sample(1:nrow(data), 180)
train_reg <- data[-test_ind,]
test_reg <- data[test_ind,]
forest <- forestry(x = train_reg[, -which(names(train_reg) == "Carapace Width")],
                   y = train_reg[, which(names(train_reg) == "Carapace Width")])
forest_predictor <- Predictor$new(model = forest, data = train_reg,
                                  y = "Carapace Width", task = "regression")
forest_interpret <- Interpreter$new(predictor = forest_predictor)
# Both initializations of a surrogate class result in the same surrogate model
surrogate.model <- distill(forest_interpret)
surrogate.model <- distill(forest_interpret,
                           center.mean = TRUE,
                           features = 1:length(forest_interpret$features),
                           cv = FALSE,
                           snap.grid = TRUE,
                           snap.train = TRUE)