knitr::opts_chunk$set(warning = FALSE, message = FALSE, eval = FALSE)
library(tidyverse)
`SuperLearner` is an algorithm that uses cross-validation to estimate the performance of multiple machine learning models, or of the same model with different settings. It then creates an optimal weighted average of those models, also called an "ensemble", using the test-data performance.
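Conceptually, the ensemble prediction is just a weighted average of the individual learners' predictions. A minimal sketch with made-up numbers (the predictions and weights below are hypothetical, not from a fitted model):

# Hypothetical predictions from three learners for one observation
preds <- c(ranger = 0.80, ksvm = 0.65, bayesglm = 0.72)
# Hypothetical ensemble weights estimated by SuperLearner (they sum to 1)
weights <- c(0.5, 0.0, 0.5)
# The ensemble prediction is the weighted average
sum(preds * weights) # 0.76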
library(SuperLearner)
Pima Indian Women data set
The `type` column indicates the presence of diabetes. It is a binary `Yes`/`No` column (tabulated just after the code below), which means that it follows a binomial distribution.
# Get the `MASS` library
library(MASS)
# Train and test sets
train <- Pima.tr
test <- Pima.te
# Print out the first lines of `train`
head(train)
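To confirm that the outcome really is binary, tabulate it:

# Tabulate the binary outcome
table(train$type)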
# Recode the `type` factor ("No"/"Yes") as 0/1 outcomes
y <- as.numeric(train[, 8]) - 1
ytest <- as.numeric(test[, 8]) - 1
# Keep the seven predictor columns as data frames
x <- data.frame(train[, 1:7])
xtest <- data.frame(test[, 1:7])
Models available for stacking
listWrappers()
`ranger` is a random forest package, `kernlab` is an SVM package, `arm` provides Bayesian models (Bayes Generalized Linear Models, GLM), and `ipred` is a bagging package. These packages must be installed before the corresponding wrappers can be used.
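A minimal way to install them all at once (assuming none are present yet):

# Install the packages backing the wrappers used below
install.packages(c("ranger", "kernlab", "arm", "ipred"))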
set.seed(150)
# Fit a single learner (ranger random forest)
single.model <- SuperLearner(Y = y,
                             X = x,
                             family = binomial(),
                             SL.library = list("SL.ranger"))
single.model
# Set the seed
set.seed(150)
# Fit the ensemble model
model <- SuperLearner(Y = y,
                      X = x,
                      family = binomial(),
                      SL.library = list("SL.ranger",
                                        "SL.ksvm",
                                        "SL.ipredbagg",
                                        "SL.bayesglm"))
# Return the model
model
What is the risk factor? And what does its variation mean?
Ranger and KSVM have a coefficient of zero, which means that they are no longer weighted as part of the ensemble. The `Coef` values are normalized, so the weights sum to one.
> Hello Jiaxiang, The risk factor is simply the error of the model being fit. SuperLearner calculates the error of each model put into the ensemble stack. This is a gauge that it then uses to determine the coefficients or weights of each model. Behind the scenes, SuperLearner is using cross-validation to determine the risk of each model. If you have a risk of 0.16, then this can be interpreted as having 0.16 error or 0.84 accuracy. [@Gremmell2018]
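The per-learner cross-validated risks and the resulting weights are stored on the fitted object, so you can inspect them directly (using the `model` object fitted above):

# Cross-validated risk of each learner in the stack
model$cvRisk
# Ensemble weights; with the default method they sum to one
model$coef
sum(model$coef)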
# Set the seed
set.seed(150)
# Get V-fold cross-validated risk estimate
cv.model <- CV.SuperLearner(Y = y,
                            X = x,
                            V = 5,
                            family = binomial(),
                            SL.library = list("SL.ranger",
                                              "SL.ksvm",
                                              "SL.ipredbagg",
                                              "SL.bayesglm"))
# Print out the summary statistics
summary(cv.model)
plot(cv.model)
It's easy to see that Bayes GLM performs best on average, while KSVM performs worst and shows far more variation than the other models.
# Predict on the held-out test set
predictions <- predict.SuperLearner(model, newdata = xtest)
# Ensemble (weighted) predictions
predictions$pred %>% head()
# Per-learner predictions
predictions$library.predict %>% head()
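As a sanity check: with the default `method.NNLS`, the ensemble column should equal the per-learner predictions combined with the fitted weights. A quick sketch using the objects above:

# Manually combine the per-learner predictions with the ensemble weights
manual <- predictions$library.predict %*% model$coef
# Should match the ensemble predictions
head(cbind(manual, predictions$pred))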
Note that calling plain `predict()` does not work here; you need `predict.SuperLearner()`.
# Load the package
library(dplyr)
# Recode probabilities
conv.preds <- ifelse(predictions$pred >= 0.5, 1, 0)
# Load in `caret`
library(caret)
# Create the confusion matrix (both inputs must be factors)
cm <- confusionMatrix(as.factor(conv.preds), as.factor(ytest))
# Return the confusion matrix
cm
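The overall accuracy can be pulled straight out of the `caret` confusion matrix object:

# Extract the overall accuracy
cm$overall["Accuracy"]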
Tuning the hyperparameters
# Custom ranger wrapper: 1000 trees and mtry = 2
SL.ranger.tune <- function(...){
  SL.ranger(..., num.trees = 1000, mtry = 2)
}
# Custom ipredbagg wrapper: 250 bagged trees
SL.ipredbagg.tune <- function(...){
  SL.ipredbagg(..., nbagg = 250)
}
# Set the seed
set.seed(150)
# Tune the model
cv.model.tune <- CV.SuperLearner(Y = y,
                                 X = x,
                                 V = 5,
                                 family = binomial(),
                                 SL.library = list("SL.ranger",
                                                   "SL.ksvm",
                                                   "SL.ipredbagg",
                                                   "SL.bayesglm",
                                                   "SL.ranger.tune",
                                                   "SL.ipredbagg.tune"))
# Get summary statistics
summary(cv.model.tune)
plot(cv.model.tune)
With the cross-validation above done, we continue by refitting the ensemble with the tuned wrappers included.
# Set the seed
set.seed(150)
# Create the tuned model
model.tune <- SuperLearner(Y = y,
                           X = x,
                           family = binomial(),
                           SL.library = list("SL.ranger",
                                             "SL.ksvm",
                                             "SL.ipredbagg",
                                             "SL.bayesglm",
                                             "SL.ranger.tune",
                                             "SL.ipredbagg.tune"))
# Return the tuned model
model.tune
`SL.bayesglm` and `SL.ipredbagg.tune` are now the only algorithms weighted in the ensemble.
# Gather predictions for the tuned model
predictions.tune <- predict.SuperLearner(model.tune, newdata = xtest)
# Recode predictions
conv.preds.tune <- ifelse(predictions.tune$pred >= 0.5, 1, 0)
# Return the confusion matrix (again converting both inputs to factors)
confusionMatrix(as.factor(conv.preds.tune), as.factor(ytest))
Alternatively, `create.Learner()` builds the tuned wrappers for you, instead of writing them by hand:
# Generate tuned wrappers programmatically
learner <- create.Learner("SL.ranger", params = list(num.trees = 1000, mtry = 2))
learner2 <- create.Learner("SL.ipredbagg", params = list(nbagg = 250))
learner
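The generated wrapper is placed in the calling environment under an auto-generated name (e.g. `SL.ranger_1`); that name is what `learner$names` holds and what goes into `SL.library` below:

# The auto-generated wrapper name(s)
learner$names
# Inspect the generated function itself
get(learner$names)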
# Set the seed
set.seed(150)
# Create a second tuned model
cv.model.tune2 <- CV.SuperLearner(Y = y,
                                  X = x,
                                  V = 5,
                                  family = binomial(),
                                  SL.library = list("SL.ranger",
                                                    "SL.ksvm",
                                                    "SL.ipredbagg",
                                                    "SL.bayesglm",
                                                    learner$names,
                                                    learner2$names))
# Get summary statistics
summary(cv.model.tune2)
plot(cv.model.tune2)
In practice, though, I don't really understand blending and the like in depth; this post merely demonstrates the toolkit.