
A wrapper class for generic ML algorithms (xgboost, random forests, BART, rpart, etc.) that standardizes the predictions of different algorithms so that they are compatible with the interpretability functions.

The required arguments are `model`, `data`, and `y`. The other arguments are optional and depend on the use case. `type` should be used only when no prediction function (`predict.func`) is specified.

The predictions must be numeric values for regression, or probabilities for classification. For classification problems with more than two categories, the output is the vector of predicted probabilities for the category specified by `class`. Because this wrapper is intended for ML interpretability, other types of predictions (e.g. predictions that return the predicted factor level) are not allowed.
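As an illustration of these output rules, here is a minimal sketch of a checker. `check_predictions()` is a hypothetical helper written for this example, not a distillML function, and the `glm` fit on `mtcars` is only a stand-in classifier. Numeric probabilities pass, while a factor of predicted labels is rejected.

```r
# Hypothetical checker (NOT part of distillML) encoding the output rules:
# predictions must be numeric, and for classification they must be
# probabilities in [0, 1].
check_predictions <- function(preds, task) {
  if (!is.numeric(preds)) {
    stop("predictions must be numeric, not ", class(preds)[1])
  }
  if (task == "classification" && any(preds < 0 | preds > 1, na.rm = TRUE)) {
    stop("classification predictions must be probabilities in [0, 1]")
  }
  invisible(TRUE)
}

# Stand-in binary classifier: transmission type (am) from weight.
fit <- glm(am ~ wt, data = mtcars, family = binomial)

# Allowed: predicted probabilities of am == 1.
probs <- predict(fit, newdata = mtcars, type = "response")

# Not allowed: hard class labels returned as a factor.
labels <- factor(ifelse(probs > 0.5, "manual", "automatic"))

check_predictions(probs, "classification")     # passes silently
# check_predictions(labels, "classification")  # would stop(): a factor is not numeric
```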

Note

This class wraps a machine learning model to provide a standardized prediction method across different models. A prediction method must be constructed, with `type` as an optional argument.
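The rule described on this page (use a user-supplied prediction function if there is one, otherwise the model's generic `predict` method, forwarding `type` only in the latter case) can be sketched as follows. `standardized_predict()` is a hypothetical helper for illustration, not distillML's internal code; in particular, raising an error when both `predict.func` and `type` are given is an assumption of this sketch.

```r
# Hypothetical helper, NOT distillML's implementation. It shows the
# precedence rule: a user-supplied predict.func wins, and `type` is only
# forwarded to the model's generic predict() when no predict.func is given.
standardized_predict <- function(model, newdata, predict.func = NULL, type = NULL) {
  if (!is.null(predict.func)) {
    # Assumption of this sketch: combining both is treated as an error.
    if (!is.null(type)) stop("`type` should only be used without `predict.func`")
    preds <- predict.func(newdata)
  } else if (is.null(type)) {
    preds <- predict(model, newdata = newdata)
  } else {
    preds <- predict(model, newdata = newdata, type = type)
  }
  # Return a plain numeric vector with one entry per row of newdata.
  as.numeric(preds)
}

fit <- lm(mpg ~ wt + hp, data = mtcars)

# Generic predict() method vs. an equivalent custom prediction function.
p1 <- standardized_predict(fit, mtcars)
p2 <- standardized_predict(fit, mtcars, predict.func = function(d) predict(fit, d))
```

Both calls return the same vector here; the design point is that callers of the interpretability functions only ever see a numeric vector, whichever path produced it.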

Public fields

data

The data that was used to train the model. This should be a data frame matching the one the model was given for training, including the label or outcome.

model

The object corresponding to the trained model that we want to make a Predictor object for. If this model doesn't have a generic predict method, the user has to provide a custom predict function that accepts a data frame.

task

The prediction task the model is trained to perform (`classification` or `regression`).

class

The class for which we get predictions. We specify this to get the predictions (such as probabilities) for an observation being in a specific class (e.g. Male or Female). This parameter is necessary for classification predictions with more than a single vector of predictions.

prediction.function

An optional prediction function, to be supplied if the model doesn't have a generic `predict` method. It should take a data frame and return a vector with one prediction for each observation in the data frame.

y

The name of the outcome feature in the `data` data frame.

Methods


Method new()

Usage

Predictor$new(
  model = NULL,
  data = NULL,
  predict.func = NULL,
  y = NULL,
  task = NULL,
  class = NULL,
  type = NULL
)

Arguments

model

The object corresponding to the trained model that we want to make a Predictor object for. If this model doesn't have a generic predict method, the user has to provide a custom predict function that accepts a data frame.

data

The data that was used to train the model. This should be a data frame matching the one the model was given for training, including the label or outcome.

predict.func

An optional prediction function, to be supplied if the model doesn't have a generic `predict` method. It should take a data frame and return a vector with one prediction for each observation in the data frame.
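A minimal sketch of such a function, assuming a model stored as a bare list of coefficients (so R has no `predict` method to dispatch to). The names `bare_model` and `my_predict_func` are hypothetical, made up for this illustration. The function takes a data frame and returns one numeric prediction per row.

```r
# Hypothetical model object: a bare list of coefficients with no class,
# so there is no generic predict() method for it.
coefs <- coef(lm(mpg ~ wt + hp, data = mtcars))
bare_model <- list(intercept = coefs[["(Intercept)"]],
                   slopes    = coefs[c("wt", "hp")])

# Custom prediction function: takes a data frame and returns a numeric
# vector with one prediction per observation.
my_predict_func <- function(newdata) {
  X <- as.matrix(newdata[, names(bare_model$slopes), drop = FALSE])
  as.numeric(bare_model$intercept + X %*% bare_model$slopes)
}

preds <- my_predict_func(mtcars)
```

A function of this shape is what could be passed as `predict.func` to `Predictor$new()`, in which case `type` would be left unset.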

y

The name of the outcome feature in the `data` data frame.

task

The prediction task the model is trained to perform (`classification` or `regression`).

class

The class for which we get predictions. We specify this to get the predictions (such as probabilities) for an observation being in a specific class (e.g. Male or Female). This parameter is necessary for classification predictions with more than a single vector of predictions.
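To see why `class` is needed, consider a three-class model whose `predict` method returns a probability matrix with one column per class. This sketch uses `MASS::lda` on `iris` purely as a stand-in multi-class model; `target_class` and `class_prob_func` are hypothetical names. Selecting one column reduces the matrix to the single probability vector the wrapper expects.

```r
# Stand-in multi-class classifier (MASS ships with R as a recommended package).
fit <- MASS::lda(Species ~ ., data = iris)

# predict() on an lda fit returns $posterior, a matrix with one column of
# probabilities per class; choosing a class selects a single column.
target_class <- "versicolor"
class_prob_func <- function(newdata) {
  post <- predict(fit, newdata = newdata)$posterior
  as.numeric(post[, target_class])
}

p <- class_prob_func(iris)
```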

type

The type of predictions to return (e.g. 'response' for predicted probabilities in classification). This argument should only be used if no `predict.func` is specified.
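For models whose generic `predict` method takes a `type` argument, the choice matters. As one example, assuming a logistic regression fitted with base R's `glm`: the default `type = "link"` returns log-odds, while `type = "response"` returns the probabilities this wrapper requires for classification.

```r
# Logistic regression: transmission type (am) from weight.
fit <- glm(am ~ wt, data = mtcars, family = binomial)

# Default type = "link": log-odds, which are not valid probabilities.
link_preds <- predict(fit, newdata = mtcars)

# type = "response": predicted probabilities in (0, 1).
resp_preds <- predict(fit, newdata = mtcars, type = "response")

# The two are related by the logistic function: plogis(log-odds) = probability.
all.equal(plogis(link_preds), resp_preds)
```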

Returns

A `Predictor` object.


Method clone()

The objects of this class are cloneable with this method.

Usage

Predictor$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

# Load the packages used in this example
library(distillML)
library(Rforestry)
set.seed(491)
data <- MASS::crabs

# Give the factor levels and columns descriptive names
levels(data$sex) <- list(Male = "M", Female = "F")
levels(data$sp) <- list(Orange = "O", Blue = "B")
colnames(data) <- c("Species", "Sex", "Index", "Frontal Lobe",
                    "Rear Width", "Carapace Length", "Carapace Width", "Body Depth")

# Hold out 20% of the rows as a test set
test_ind <- sample(1:nrow(data), nrow(data) %/% 5)
train_reg <- data[-test_ind, ]
test_reg <- data[test_ind, ]

# Fit a random forest that predicts Carapace Width from the other features
forest <- forestry(x = train_reg[, -which(names(train_reg) == "Carapace Width")],
                   y = train_reg[, which(names(train_reg) == "Carapace Width")])

# Wrap the fitted forest in a Predictor for regression on Carapace Width
forest_predictor <- Predictor$new(model = forest, data = train_reg,
                                  y = "Carapace Width", task = "regression")