本文于2020-10-10更新。 如发现问题或者有建议,欢迎提交 Issue
gcForest
包调用的是 Python
的gcForest
的资源。
# Set the default chunk option eval = FALSE, so chunks are not evaluated
# unless they override it.
knitr::opts_chunk$set(eval = FALSE)
# Uncomment the next line to install the gcForest package before first use.
# install.packages('gcForest')
1 这是一个新包
library(tidyverse)
library(packagefinder)
library(dlstats)
library(cranly)
# Look up packages related to the search term and keep the hits as a tibble.
sem_pkg <- as_tibble(findPackage('Deep Forest'))

# Lower-case the column names, rank the hits by score, keep each package name
# once, and pass the names on to cran_stats().
sem_pkg_download <-
  sem_pkg %>%
  rename_all(tolower) %>%
  arrange(desc(score)) %>%
  distinct(name) %>%
  # head(100) %>%
  pull(name) %>%
  # cran_stats() accepts a whole character vector, so no map() is needed.
  cran_stats()

# Print the result.
sem_pkg_download
2 数据预处理
# Import scikit-learn through reticulate; it supplies both the demo data set
# and the train/test splitting helper.
sk <- reticulate::import('sklearn')
train_test_split <- sk$model_selection$train_test_split

# Load the iris data from scikit-learn. The loader is called inline and the
# result is named `iris_py` so that neither utils::data() nor the built-in
# datasets::iris object is masked in the global environment.
iris_py <- sk$datasets$load_iris()
x <- iris_py$data # matrix
y <- iris_py$target

# Hold out 33% of the rows for testing. random_state must be an integer (42L,
# not 42) and makes the split reproducible across runs.
data_split <- train_test_split(x, y, test_size = 0.33, random_state = 42L)

# train_test_split returns, in order: x train, x test, y train, y test.
x_tr <- data_split[[1]]
x_te <- data_split[[2]]
y_tr <- data_split[[3]]
y_te <- data_split[[4]]
3 训练模型
library(gcForest)
library(tidyverse)
library(lubridate)
# Build the gcForest classifier and fit it on the training split.
# NOTE(review): shape_1X = 4L presumably matches the four iris feature
# columns -- confirm if the input data changes.
gcforest_m <- gcforest(shape_1X = 4L, window = 2L, tolerance = 0.0)
gcforest_m$fit(x_tr, y_tr)

# Make sure the output directory exists before saving; otherwise the save
# fails on a fresh checkout where 'files' has not been created yet.
model_dir <- 'files'
dir.create(model_dir, showWarnings = FALSE, recursive = TRUE)

# Prefix the file name with today's date as yymmdd (e.g. "201010") so that
# snapshots from different days do not overwrite each other.
date_tag <- today() %>% str_remove_all('-') %>% str_sub(3, -1)
gcf_model <-
  model_save(
    gcforest_m,
    file.path(model_dir, paste(date_tag, 'gcforest_model.model', sep = '_'))
  )

# Reload the most recent snapshot: with a yymmdd prefix the lexicographic
# maximum of the matching file names is the latest date.
latest_model_file <-
  list.files(model_dir) %>%
  str_subset('gcforest_model.model') %>%
  max()
gcf <- model_load(file.path(model_dir, latest_model_file))

# Refit the reloaded model once (the original repeated this call verbatim).
gcf$fit(x_tr, y_tr)
4 预测结果
# Predicted class labels for the held-out rows.
pred_class <- gcforest_m$predict(x_te)
pred_class

# True labels, printed right after the predictions for comparison.
y_te

# Output of predict_proba() for the same rows.
pred_prob <- gcforest_m$predict_proba(x_te)
pred_prob
- 通过 predict_proba 可以查看各类别的预测概率