1 min read

gcForest 使用技巧

本文于2020-10-10更新。 如发现问题或者有建议,欢迎提交 Issue

R 的 gcForest 包是通过调用 Python 版 gcForest 的资源来实现的。

# Turn off evaluation for all chunks: the code in this post is displayed but
# not executed when the document is knitted.
knitr::opts_chunk$set(eval = FALSE)
# Uncomment the next line to install the package before running the examples.
# install.packages('gcForest')

1 这是一个新包

library(tidyverse)
library(packagefinder)
library(dlstats)
library(cranly)
# Search CRAN for packages matching "Deep Forest" and keep the hits as a tibble.
# NOTE(review): recent packagefinder releases only return the result data
# frame when `return.df = TRUE` (otherwise nothing is returned and the
# pipeline below has no `score`/`name` columns) -- confirm against the
# installed packagefinder version.
sem_pkg <-
    'Deep Forest' %>%
        findPackage(return.df = TRUE) %>%
        as_tibble()

# Fetch CRAN download statistics for every matching package, ordered by
# search score (column names are lower-cased first).
sem_pkg_download <-
    sem_pkg %>%
    rename_all(tolower) %>%
    arrange(desc(score)) %>%
    distinct(name) %>%
    # head(100) %>%
    pull(name) %>%
    # cran_stats() accepts a vector of package names, so no map() is needed
    cran_stats()
sem_pkg_download

2 数据预处理

# Import scikit-learn through reticulate and keep handles to the two Python
# callables used below.
sk <- reticulate::import("sklearn")
data <- sk$datasets$load_iris
train_test_split <- sk$model_selection$train_test_split

# Load the iris dataset and separate the features from the labels.
iris <- data()
x <- iris$data # matrix
y <- iris$target

# Hold out 33% of the rows for testing. The result is unpacked positionally
# as: train features, test features, train labels, test labels.
data_split <- train_test_split(x, y, test_size = 0.33)
x_tr <- data_split[[1L]]
x_te <- data_split[[2L]]
y_tr <- data_split[[3L]]
y_te <- data_split[[4L]]

3 训练模型

library(gcForest)
library(tidyverse)
library(lubridate)
# Build the gcForest model and train it on the training split.
gcforest_m <- gcforest(shape_1X = 4L, window = 2L, tolerance = 0.0)
gcforest_m$fit(x_tr, y_tr)

# Make sure the output directory exists before saving; a missing 'files'
# directory would otherwise make the save fail.
dir.create('files', showWarnings = FALSE, recursive = TRUE)

# Save the trained model under 'files/<yymmdd>_gcforest_model.model'.
gcf_model <-
    model_save(
        gcforest_m,
        file.path(
            'files',
            paste(format(today(), '%y%m%d'), 'gcforest_model.model', sep = '_')
        )
    )

# Find the saved models. fixed() matches the file name literally, so the '.'
# is not treated as a regex wildcard.
saved_models <-
    list.files('files') %>%
    str_subset(fixed('gcforest_model.model'))
if (length(saved_models) == 0) {
    stop("No saved gcforest model found in 'files'.", call. = FALSE)
}

# The yymmdd prefix sorts lexicographically, so max() picks the latest file.
gcf <- model_load(file.path('files', max(saved_models)))

# NOTE(review): the original called fit() twice in a row with identical
# arguments; a single call is kept -- confirm nothing relies on the repeat.
gcf$fit(x_tr, y_tr)

4 预测结果

# Predict on the held-out features and print the result.
gcforest_m$predict(x_te)
# Print the held-out labels for a side-by-side comparison with the predictions.
y_te
# Print the predicted probabilities for the held-out features.
gcforest_m$predict_proba(x_te)
  1. 使用 `predict_proba` 可以查看预测的概率。