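I'm running into a problem using tidymodels with xgboost in a workflow. After applying a recipe that includes step_dummy() to convert categorical variables into dummy variables, I get the following error when I try to predict: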
Error in `validate_column_names()`:
! The following required columns are missing: 'A', 'B', 'C', 'D'.
Here is a simplified version of my code:
library(tidymodels)
library(xgboost)
library(dplyr)

set.seed(123)
datensatz <- tibble(
  outcome = rnorm(100, mean = 60, sd = 10),
  A = factor(sample(c("h", "i", "j"), 100, replace = TRUE)),
  B = factor(sample(c("e", "f", "g"), 100, replace = TRUE)),
  C = factor(sample(1:3, 100, replace = TRUE)),
  D = factor(sample(c("a", "b"), 100, replace = TRUE))
)

# Split into training and test sets
data_split <- initial_split(datensatz, prop = 0.75)
train_data <- training(data_split)
test_data <- testing(data_split)

# Recipe
recipe_obj <- recipe(outcome ~ ., data = train_data) %>%
  step_dummy(all_nominal(), -all_outcomes()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

# Manually prep the recipe and bake the test data
prepared_recipe <- prep(recipe_obj)
test_data_prepared <- bake(prepared_recipe, new_data = test_data)

# XGBoost model specification
xgboost_spec <- boost_tree(
  trees = 1000,
  tree_depth = 6,
  min_n = 10,
  loss_reduction = 0.01,
  sample_size = 0.8,
  mtry = 0.8,
  learn_rate = 0.01
) %>%
  set_mode("regression") %>%
  set_engine("xgboost", count = FALSE, colsample_bytree = 0.8)

# Workflow
workflow_obj <- workflow() %>%
  add_recipe(recipe_obj) %>%
  add_model(xgboost_spec)

# Train the model
xgboost_fit <- fit(workflow_obj, data = train_data)

# Predict on the prepared test data (the error occurs here)
predictions <- predict(xgboost_fit, new_data = test_data_prepared)

# Results
predictions
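I suspect the problem is that step_dummy() removes the original categorical columns (A, B, C, D) and replaces them with dummy variables, while the workflow still seems to require the original columns at prediction time.
How can I resolve this and make sure the prediction step correctly uses the dummy variables created by step_dummy()?
Additional information: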
I'm using the `xgboost` engine within the `tidymodels` framework.
The error message suggests that the workflow expects the original categorical variables, but these are no longer present after applying `step_dummy()`.
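If you use a recipe inside a workflow, you don't need to prep() the recipe or bake() the test dataset by hand. The fitted workflow applies the trained recipe itself at prediction time, which is why it looks for the original columns A, B, C and D. So you can remove the two lines that create prepared_recipe and test_data_prepared, and predict with
predict(xgboost_fit, new_data = test_data)
instead of
predict(xgboost_fit, new_data = test_data_prepared)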
Created on 2024-08-30 with reprex v2.1.1
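For reference, the fix can be sketched end to end as follows. This is a minimal sketch, not the asker's exact setup: the model specification is deliberately simplified (`trees = 100`, `learn_rate = 0.1` and the default engine arguments are assumptions made to keep the example short), and `all_nominal_predictors()` stands in for `all_nominal(), -all_outcomes()`. The last step, using `extract_recipe()` to look at the baked columns, is optional and only there to show that the dummy variables are created inside the workflow.

```r
library(tidymodels)
library(xgboost)

set.seed(123)
datensatz <- tibble(
  outcome = rnorm(100, mean = 60, sd = 10),
  A = factor(sample(c("h", "i", "j"), 100, replace = TRUE)),
  B = factor(sample(c("e", "f", "g"), 100, replace = TRUE)),
  C = factor(sample(1:3, 100, replace = TRUE)),
  D = factor(sample(c("a", "b"), 100, replace = TRUE))
)

data_split <- initial_split(datensatz, prop = 0.75)
train_data <- training(data_split)
test_data <- testing(data_split)

# The recipe is left unprepped; the workflow preps it during fit()
recipe_obj <- recipe(outcome ~ ., data = train_data) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

# Simplified specification (assumed values, chosen only for brevity)
xgboost_spec <- boost_tree(trees = 100, learn_rate = 0.1) %>%
  set_mode("regression") %>%
  set_engine("xgboost")

xgboost_fit <- workflow() %>%
  add_recipe(recipe_obj) %>%
  add_model(xgboost_spec) %>%
  fit(data = train_data)

# Pass the raw test data; the workflow bakes it with the trained recipe
predictions <- predict(xgboost_fit, new_data = test_data)

# Optional: inspect the columns the model actually received
baked_test <- xgboost_fit %>%
  extract_recipe() %>%
  bake(new_data = test_data)
names(baked_test)
```

If you do need the preprocessed data for something else, pulling the trained recipe out of the fitted workflow as above keeps it consistent with what the model was trained on, rather than prepping a second copy of the recipe separately.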