library(tidymodels)
library(kerasnip)
# Data & preprocessing ----

# `diamonds` ships with ggplot2; data() loads it without attaching ggplot2.
data(diamonds, package = "ggplot2")

# initial_split() samples rows at random; fix R's RNG so the train/test
# partition (and everything fitted on it) is reproducible between runs.
set.seed(123)
# A numeric `strata` is binned by rsample (quartiles by default), so both
# partitions cover the whole price range.
diamonds_split <- initial_split(diamonds, strata = price)
diamonds_train <- training(diamonds_split)

# Preprocessing for the network: dummy-encode the factor predictors, drop
# zero-variance columns, then centre and scale every numeric predictor.
diamonds_rec <- recipe(price ~ ., data = diamonds_train) |>
  step_dummy(all_nominal_predictors()) |>
  step_zv(all_predictors()) |>
  step_normalize(all_numeric_predictors())
# Tuning ----

# NOTE(review): `spec` is not defined in this script; presumably a kerasnip
# model specification with tune() placeholders for num_body, body_units and
# body_dropout. Confirm it is created before this point.
wf <- workflow() |>
  add_recipe(diamonds_rec) |>
  add_model(spec)

# Replace the parameter objects of the tunable arguments with explicit ranges.
params <- extract_parameter_set_dials(wf) |>
  update(
    num_body = dials::num_terms(c(1, 4)),
    body_units = dials::hidden_units(c(32, 256)),
    body_dropout = dials::dropout(c(0, 0.4))
  )

# The space-filling design and the fold assignment both draw from R's RNG;
# seed them so the candidate set and resamples are reproducible.
# NOTE(review): set.seed() does not seed TensorFlow, so individual network
# fits may still vary between runs.
set.seed(456)
# grid_space_filling(type = "latin_hypercube") is the dials >= 1.3.0
# replacement for the deprecated grid_latin_hypercube().
grid <- grid_space_filling(params, size = 20, type = "latin_hypercube")
folds <- vfold_cv(diamonds_train, v = 5, strata = price)

# Fit every candidate on every fold: 20 candidates x 5 folds = 100 fits.
res <- tune_grid(wf, resamples = folds, grid = grid)