
The Core Idea: From Keras Layers to Tidymodels Specs

The keras3 package builds deep learning models imperatively, layer by layer, which is a powerful and flexible approach. The tidymodels ecosystem, by contrast, is built around declarative model specifications. You state which model you want and which of its parameters you want to tune, rather than constructing it step by step.

kerasnip bridges this gap with a simple but powerful concept: layer blocks. You define the components of your neural network (for example, an input block, a dense block, and a dropout block) as plain R functions. kerasnip then uses these blocks as building material to create a brand-new parsnip model specification function for you.

This new function behaves like any other parsnip model (e.g., rand_forest() or linear_reg()). It fits into workflows and can be tuned with the tune package.

We’ll start by loading kerasnip, tidymodels, and keras3:

library(kerasnip)
library(tidymodels)
library(keras3)

## ── Attaching packages ────────────────────────────────────── tidymodels 1.3.0 ──
##  broom        1.0.9      recipes      1.3.1
##  dials        1.4.1      rsample      1.3.1
##  dplyr        1.1.4      tibble       3.3.0
##  ggplot2      3.5.2      tidyr        1.3.1
##  infer        1.0.9      tune         1.3.0
##  modeldata    1.4.0      workflows    1.2.0
##  parsnip      1.3.2      workflowsets 1.1.1
##  purrr        1.1.0      yardstick    1.3.2
## ── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
##  purrr::discard() masks scales::discard()
##  dplyr::filter()  masks stats::filter()
##  dplyr::lag()     masks stats::lag()
##  recipes::step()  masks stats::step()
## 
## Attaching package: 'keras3'
## The following object is masked from 'package:yardstick':
## 
##     get_weights

Example 1: Building and Fitting a Basic MLP

Let’s start by building a simple Multi-Layer Perceptron (MLP) for a regression task using the mtcars dataset.

Step 1: Define the Layer Blocks

We need three blocks:

  1. An input block to initialize the model and define the input shape. kerasnip passes the input_shape argument automatically during fitting.

  2. A dense block for the hidden layers. We give it a units argument so we can control the number of neurons.

  3. An output block for the final prediction. For regression, this is typically a single neuron with a linear activation.

# 1. The input block must initialize the model.
# input_shape is passed automatically by the fit engine.
mlp_input_block <- function(model, input_shape) {
  keras_model_sequential(input_shape = input_shape)
}

# 2. A block for hidden layers. units will become a tunable parameter.
mlp_dense_block <- function(model, units = 32) {
  model |>
    layer_dense(units = units, activation = "relu")
}

# 3. The output block for a regression model.
mlp_output_block <- function(model) {
  model |>
    layer_dense(units = 1)
}
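Each block is just a function that takes the current model and returns an extended one, so assembling a network amounts to calling the blocks in order. The sketch below illustrates that idea without a Keras backend. The "model" is a character vector of layer descriptions, and the dense block is applied `num_dense` times. `assemble_model()` and the `mock_*` blocks are hypothetical stand-ins written for this illustration and are not part of kerasnip. kerasnip's internal assembly may differ in detail.

```r
# Hypothetical stand-ins for the Keras blocks: each takes the current
# "model" (here a character vector of layer descriptions) and returns it extended.
mock_input_block <- function(model, input_shape) {
  sprintf("input(shape = %d)", input_shape)
}
mock_dense_block <- function(model, units = 32) {
  c(model, sprintf("dense(units = %d, activation = relu)", units))
}
mock_output_block <- function(model) {
  c(model, "dense(units = 1)")
}

# Hypothetical assembler: the input block starts the model, the dense block
# is repeated num_dense times, and the output block closes it.
assemble_model <- function(input_shape, num_dense = 1, dense_units = 32) {
  model <- mock_input_block(NULL, input_shape = input_shape)
  for (i in seq_len(num_dense)) {
    model <- mock_dense_block(model, units = dense_units)
  }
  mock_output_block(model)
}

# mtcars has 10 predictors once mpg is the outcome
assemble_model(input_shape = 10, num_dense = 2, dense_units = 64)
```

With `num_dense = 2`, the result holds four entries: the input, two identical 64-unit dense layers, and the single-unit output. This is why a count argument is enough to vary network depth.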

Step 2: Create the Model Specification

Now, we use create_keras_sequential_spec() to generate a new model function, which we’ll call basic_mlp(). We provide our layer blocks in the order they should be assembled.

create_keras_sequential_spec(
  model_name = "basic_mlp",
  layer_blocks = list(
    input = mlp_input_block,
    dense = mlp_dense_block,
    output = mlp_output_block
  ),
  mode = "regression" 
) 

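This call has a side effect: a new function, basic_mlp(), is now available in our environment. Notice its arguments. kerasnip automatically created num_dense (to control the number of dense layers) and dense_units (from the units argument in our mlp_dense_block). The printed specification below shows that it also adds num_input and num_output, a learn_rate argument, and the compile_* and fit_* arguments covered under Advanced Customization.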
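The naming pattern described above can be reproduced with base R's formals(). This sketch assumes the convention visible in the printed specifications. Each block gets one num_<block> argument, and every other block argument becomes <block>_<arg>. The exceptions are model and input_shape, which are supplied internally. derive_arg_names() is a hypothetical helper for illustration, not kerasnip's implementation.

```r
# Block definitions copied from this vignette. Their bodies call keras3
# functions, but formals() only inspects the argument lists, so keras3
# does not need to be loaded for this sketch.
mlp_input_block <- function(model, input_shape) {
  keras_model_sequential(input_shape = input_shape)
}
mlp_dense_block <- function(model, units = 32) {
  model |> layer_dense(units = units, activation = "relu")
}
mlp_output_block <- function(model) {
  model |> layer_dense(units = 1)
}
tunable_dropout_block <- function(model, rate = 0.2) {
  model |> layer_dropout(rate = rate)
}

# Hypothetical helper mimicking the naming convention seen in print(spec)
derive_arg_names <- function(layer_blocks) {
  block_names <- names(layer_blocks)
  # One num_<block> argument per block sets how many times it is repeated
  num_args <- paste0("num_", block_names)
  # Each remaining block argument becomes <block>_<argument>
  block_args <- lapply(block_names, function(nm) {
    args <- setdiff(names(formals(layer_blocks[[nm]])), c("model", "input_shape"))
    if (length(args) == 0) character(0) else paste0(nm, "_", args)
  })
  c(num_args, unlist(block_args))
}

derive_arg_names(list(
  input = mlp_input_block,
  dense = mlp_dense_block,
  output = mlp_output_block
))
```

For the basic blocks, this returns num_input, num_dense, num_output, and dense_units. Adding the dropout block (as tunable_mlp does later) yields num_dropout and dropout_rate in the same positions where print(tune_spec) shows them.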

Step 3: Use the Spec in a Workflow

We can now use basic_mlp() like any other parsnip model. Let’s define a model with two hidden layers of 64 units each, a learning rate of 0.01, and 50 training epochs.

spec <- basic_mlp(
  num_dense = 2,
  dense_units = 64,
  fit_epochs = 50,
  learn_rate = 0.01 
) |>
  set_engine("keras")

print(spec) 
## basic mlp Model Specification (regression)
## 
## Main Arguments:
##   num_input = structure(list(), class = "rlang_zap")
##   num_dense = 2
##   num_output = structure(list(), class = "rlang_zap")
##   dense_units = 64
##   learn_rate = 0.01
##   fit_batch_size = structure(list(), class = "rlang_zap")
##   fit_epochs = 50
##   fit_callbacks = structure(list(), class = "rlang_zap")
##   fit_validation_split = structure(list(), class = "rlang_zap")
##   fit_validation_data = structure(list(), class = "rlang_zap")
##   fit_shuffle = structure(list(), class = "rlang_zap")
##   fit_class_weight = structure(list(), class = "rlang_zap")
##   fit_sample_weight = structure(list(), class = "rlang_zap")
##   fit_initial_epoch = structure(list(), class = "rlang_zap")
##   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
##   fit_validation_steps = structure(list(), class = "rlang_zap")
##   fit_validation_batch_size = structure(list(), class = "rlang_zap")
##   fit_validation_freq = structure(list(), class = "rlang_zap")
##   fit_verbose = structure(list(), class = "rlang_zap")
##   fit_view_metrics = structure(list(), class = "rlang_zap")
##   compile_optimizer = structure(list(), class = "rlang_zap")
##   compile_loss = structure(list(), class = "rlang_zap")
##   compile_metrics = structure(list(), class = "rlang_zap")
##   compile_loss_weights = structure(list(), class = "rlang_zap")
##   compile_weighted_metrics = structure(list(), class = "rlang_zap")
##   compile_run_eagerly = structure(list(), class = "rlang_zap")
##   compile_steps_per_execution = structure(list(), class = "rlang_zap")
##   compile_jit_compile = structure(list(), class = "rlang_zap")
##   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
## 
## Computational engine: keras

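Arguments printed as structure(list(), class = "rlang_zap") were not supplied, so they are left at their defaults.

We’ll use a simple recipe to normalize the predictors and combine it with our model spec in a workflow.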

# Suppress verbose Keras output for the vignette 
options(keras.fit_verbose = 0) 
 
rec <- recipe(mpg ~ ., data = mtcars) |>
  step_normalize(all_numeric_predictors())

wf <- workflow() |>
  add_recipe(rec) |>
  add_model(spec)

set.seed(123) 
fit_obj <- fit(wf, data = mtcars) 
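step_normalize() centers each numeric predictor to mean zero and scales it to standard deviation one, using statistics estimated from the training data. The base-R sketch below approximates that transformation for mtcars and is only meant to make it concrete. In the workflow, the recipe handles it, including applying the training means and standard deviations to new data at prediction time.

```r
# All columns except the outcome (mpg) are predictors
predictors <- mtcars[, setdiff(names(mtcars), "mpg")]

# Statistics estimated once from the training data
train_means <- colMeans(predictors)
train_sds <- apply(predictors, 2, sd)

# Center and scale every predictor column
normalized <- as.data.frame(scale(predictors, center = train_means, scale = train_sds))

# New data is transformed with the training statistics, not its own
new_rows <- predictors[1:5, ]
new_normalized <- scale(new_rows, center = train_means, scale = train_sds)
```

Reusing the training statistics matters because a handful of new rows would otherwise be rescaled against their own, much noisier, means and standard deviations.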

Step 4: Make Predictions

Predictions work just as you’d expect in tidymodels.

predictions <- predict(fit_obj, new_data = mtcars[1:5, ]) 
## 1/1 - 0s - 38ms/step
print(predictions)
## # A tibble: 5 × 1
##   .pred
##   <dbl>
## 1  18.9
## 2  17.8
## 3  23.7
## 4  17.1
## 5  16.6

Example 2: Tuning the Model Architecture

The real power of kerasnip comes from its ability to tune not just hyperparameters (like learning rate or dropout), but the architecture of the network itself.

Let’s create a more complex tunable specification where we let tune find the optimal number of dense layers, the number of units in those layers, and the rate for a final dropout layer.

Step 1: Define Blocks and Create a New Spec

First, we’ll define an additional block for dropout and then create a new model specification, tunable_mlp, that includes it.

tunable_dropout_block <- function(model, rate = 0.2) {
  model |>
    layer_dropout(rate = rate)
}

create_keras_sequential_spec(
  model_name = "tunable_mlp",
  layer_blocks = list(
    input = mlp_input_block,
    dense = mlp_dense_block,
    dropout = tunable_dropout_block,
    output = mlp_output_block
  ),
  mode = "regression"
)

Step 2: Define a Tunable Specification

We call our new tunable_mlp() function, setting each argument we want to optimize to tune(). A single dropout layer (num_dropout = 1) sits between the dense layers and the output.

tune_spec <- tunable_mlp(
  num_dense = tune(),
  dense_units = tune(),
  num_dropout = 1,
  dropout_rate = tune(),
  fit_epochs = 20
) |>
  set_engine("keras")

print(tune_spec)
## tunable mlp Model Specification (regression)
## 
## Main Arguments:
##   num_input = structure(list(), class = "rlang_zap")
##   num_dense = tune()
##   num_dropout = 1
##   num_output = structure(list(), class = "rlang_zap")
##   dense_units = tune()
##   dropout_rate = tune()
##   learn_rate = structure(list(), class = "rlang_zap")
##   fit_batch_size = structure(list(), class = "rlang_zap")
##   fit_epochs = 20
##   fit_callbacks = structure(list(), class = "rlang_zap")
##   fit_validation_split = structure(list(), class = "rlang_zap")
##   fit_validation_data = structure(list(), class = "rlang_zap")
##   fit_shuffle = structure(list(), class = "rlang_zap")
##   fit_class_weight = structure(list(), class = "rlang_zap")
##   fit_sample_weight = structure(list(), class = "rlang_zap")
##   fit_initial_epoch = structure(list(), class = "rlang_zap")
##   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
##   fit_validation_steps = structure(list(), class = "rlang_zap")
##   fit_validation_batch_size = structure(list(), class = "rlang_zap")
##   fit_validation_freq = structure(list(), class = "rlang_zap")
##   fit_verbose = structure(list(), class = "rlang_zap")
##   fit_view_metrics = structure(list(), class = "rlang_zap")
##   compile_optimizer = structure(list(), class = "rlang_zap")
##   compile_loss = structure(list(), class = "rlang_zap")
##   compile_metrics = structure(list(), class = "rlang_zap")
##   compile_loss_weights = structure(list(), class = "rlang_zap")
##   compile_weighted_metrics = structure(list(), class = "rlang_zap")
##   compile_run_eagerly = structure(list(), class = "rlang_zap")
##   compile_steps_per_execution = structure(list(), class = "rlang_zap")
##   compile_jit_compile = structure(list(), class = "rlang_zap")
##   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
## 
## Computational engine: keras

Step 3: Set up the Tuning Grid

We create a workflow as before, then use helper functions from dials to define the search space for each parameter.

tune_wf <- workflow() |>
  add_recipe(rec) |>
  add_model(tune_spec)

# Define the tuning grid. 
# `num_terms()` is the dials function for `num_*` parameters.
# `hidden_units()` is the dials function for `*_units` parameters.
params <- extract_parameter_set_dials(tune_wf) |>
  update(
    num_dense = dials::num_terms(c(1, 3)),
    dense_units = dials::hidden_units(c(8, 64)),
    dropout_rate = dials::dropout(c(0.1, 0.5))
  )
grid <- grid_regular(params, levels = 2) 
print(grid) 
## # A tibble: 8 × 3
##   num_dense dense_units dropout_rate
##       <int>       <int>        <dbl>
## 1         1           8          0.1
## 2         3           8          0.1
## 3         1          64          0.1
## 4         3          64          0.1
## 5         1           8          0.5
## 6         3           8          0.5
## 7         1          64          0.5
## 8         3          64          0.5
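grid_regular() takes `levels` evenly spaced values from each parameter's range and crosses them. Two levels for three parameters therefore gives 2^3 = 8 candidate architectures. With two levels, the values are simply the range endpoints set in update(). The base-R expand.grid() call below reproduces the same table in the same row order, shown here only to make the crossing explicit.

```r
# Two levels per parameter: the endpoints of each range passed to update()
candidates <- expand.grid(
  num_dense = c(1L, 3L),
  dense_units = c(8L, 64L),
  dropout_rate = c(0.1, 0.5)
)
nrow(candidates)  # 2 * 2 * 2 = 8
```

Raising `levels` grows the grid multiplicatively. Three levels, for example, would give 27 candidates, each trained once per resample.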

Step 4: Run the Tuning

We use tune_grid() with resamples to evaluate each combination of architectural parameters.

set.seed(456) 
folds <- vfold_cv(mtcars, v = 3) 
 
# The control argument is used to prevent saving predictions, which 
# can be large for Keras models. 
tune_res <- tune_grid(
  tune_wf,
  resamples = folds,
  grid = grid,
  control = control_grid(save_pred = FALSE) 
) 
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 40ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 42ms/step
## 1/1 - 0s - 34ms/step
## 1/1 - 0s - 43ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 34ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 40ms/step
## 1/1 - 0s - 34ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 33ms/step
## 1/1 - 0s - 41ms/step
## 1/1 - 0s - 34ms/step
## 1/1 - 0s - 42ms/step
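The 24 progress lines above are consistent with one prediction call per fitted model. Each of the 8 candidates is trained on each of 3 folds and scored on the held-out rows, giving 8 × 3 = 24 fits. The sketch below shows that bookkeeping in base R. The random fold assignment illustrates the general idea behind vfold_cv(), where every row lands in exactly one assessment fold. It is not rsample's own code, so the specific fold memberships will differ from the vignette's.

```r
set.seed(456)            # only for reproducibility of this sketch
n_rows <- nrow(mtcars)   # 32 cars
v <- 3

# Assign each row to exactly one of v assessment folds, as evenly as possible
fold_id <- sample(rep(seq_len(v), length.out = n_rows))
fold_sizes <- as.vector(table(fold_id))

# Every candidate in the grid is fit once per fold
n_candidates <- 8
n_fits <- n_candidates * v
```

With 32 rows and 3 folds, the assessment sets hold 11, 11, and 10 cars. Each candidate is therefore judged on only about ten held-out observations per fold.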

Step 5: Analyze the Results

We can now see which architecture performed best.

show_best(tune_res, metric = "rmse")
## # A tibble: 5 × 9
##   num_dense dense_units dropout_rate .metric .estimator  mean     n std_err
##       <int>       <int>        <dbl> <chr>   <chr>      <dbl> <int>   <dbl>
## 1         3          64          0.1 rmse    standard    4.70     3   0.627
## 2         3          64          0.5 rmse    standard    5.34     3   0.626
## 3         1          64          0.1 rmse    standard   10.6      3   0.996
## 4         1          64          0.5 rmse    standard   12.7      3   0.824
## 5         3           8          0.5 rmse    standard   14.4      3   1.94 
## # ℹ 1 more variable: .config <chr>
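Each row of show_best() summarizes one candidate across the resamples. The mean column is the average of the per-fold RMSE values, and std_err is their standard deviation divided by sqrt(n), with n = 3 folds here. The sketch below computes the same summary in base R for a deliberately trivial stand-in model that predicts the analysis-set mean of mpg. This avoids refitting the Keras models. The fold assignment is an illustrative random split, so the numbers will not match the table above.

```r
# RMSE between observed and predicted values
rmse <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

set.seed(456)
fold_id <- sample(rep(1:3, length.out = nrow(mtcars)))

# Per-fold RMSE for a trivial stand-in model: predict the analysis-set mean
fold_rmse <- sapply(1:3, function(k) {
  analysis <- mtcars$mpg[fold_id != k]
  assessment <- mtcars$mpg[fold_id == k]
  rmse(assessment, rep(mean(analysis), length(assessment)))
})

# Summary columns analogous to those reported by show_best()
summary_row <- data.frame(
  .metric = "rmse",
  mean = mean(fold_rmse),
  n = length(fold_rmse),
  std_err = sd(fold_rmse) / sqrt(length(fold_rmse))
)
```

Because n is only 3, std_err is a rough indication of fold-to-fold variability rather than a precise uncertainty estimate.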

The results show that tune evaluated different network depths (num_dense), widths (dense_units), and dropout rates. In this run, the three-layer networks with 64 units per layer had the lowest cross-validated RMSE, and the lower dropout rate (0.1) performed best. This shows how kerasnip brings architectural tuning into the standard tidymodels workflow.

Advanced Customization

kerasnip provides a clean API for passing arguments directly to Keras’s compile() and fit() methods.

  • Compile Arguments: Pass any argument to keras3::compile() by prefixing it with compile_. For example, to change the loss function you would use compile_loss = "mae".
  • Fit Arguments: Pass any argument to keras3::fit() by prefixing it with fit_. For example, to set a validation split and add a callback, you would use fit_validation_split = 0.2 and fit_callbacks = list(...).
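The prefix convention amounts to a simple routing rule. Arguments starting with compile_ go to compile(), those starting with fit_ go to fit(), and the prefix is stripped. The base-R sketch below illustrates that rule on a plain list of arguments. split_keras_args() is a hypothetical helper, not kerasnip's internal code, and it ignores any additional handling kerasnip may apply.

```r
# Hypothetical helper: route prefixed arguments to compile() or fit()
split_keras_args <- function(args) {
  route <- function(prefix) {
    selected <- args[startsWith(names(args), prefix)]
    # Strip the prefix so the names match keras3's own argument names
    names(selected) <- sub(paste0("^", prefix), "", names(selected))
    selected
  }
  list(compile = route("compile_"), fit = route("fit_"))
}

spec_args <- list(
  num_dense = 2,
  dense_units = 32,
  compile_loss = "mae",
  fit_epochs = 100,
  fit_validation_split = 0.2
)
routed <- split_keras_args(spec_args)
str(routed)
```

Architecture arguments such as num_dense and dense_units fall into neither group, since they describe how the layer blocks are assembled rather than how the model is compiled or trained.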

Here is an example of using these arguments to specify a different loss function, a validation split, and an early stopping callback.

adv_spec <- basic_mlp(
  num_dense = 2,
  dense_units = 32,
  fit_epochs = 100,
  # Arguments for keras3::compile()
  compile_loss = "mae",
  # Arguments for keras3::fit()
  fit_validation_split = 0.2,
  fit_callbacks = list(
    keras3::callback_early_stopping(patience = 5)
  )
) |>
  set_engine("keras")

print(adv_spec)
## basic mlp Model Specification (regression)
## 
## Main Arguments:
##   num_input = structure(list(), class = "rlang_zap")
##   num_dense = 2
##   num_output = structure(list(), class = "rlang_zap")
##   dense_units = 32
##   learn_rate = structure(list(), class = "rlang_zap")
##   fit_batch_size = structure(list(), class = "rlang_zap")
##   fit_epochs = 100
##   fit_callbacks = list(keras3::callback_early_stopping(patience = 5))
##   fit_validation_split = 0.2
##   fit_validation_data = structure(list(), class = "rlang_zap")
##   fit_shuffle = structure(list(), class = "rlang_zap")
##   fit_class_weight = structure(list(), class = "rlang_zap")
##   fit_sample_weight = structure(list(), class = "rlang_zap")
##   fit_initial_epoch = structure(list(), class = "rlang_zap")
##   fit_steps_per_epoch = structure(list(), class = "rlang_zap")
##   fit_validation_steps = structure(list(), class = "rlang_zap")
##   fit_validation_batch_size = structure(list(), class = "rlang_zap")
##   fit_validation_freq = structure(list(), class = "rlang_zap")
##   fit_verbose = structure(list(), class = "rlang_zap")
##   fit_view_metrics = structure(list(), class = "rlang_zap")
##   compile_optimizer = structure(list(), class = "rlang_zap")
##   compile_loss = mae
##   compile_metrics = structure(list(), class = "rlang_zap")
##   compile_loss_weights = structure(list(), class = "rlang_zap")
##   compile_weighted_metrics = structure(list(), class = "rlang_zap")
##   compile_run_eagerly = structure(list(), class = "rlang_zap")
##   compile_steps_per_execution = structure(list(), class = "rlang_zap")
##   compile_jit_compile = structure(list(), class = "rlang_zap")
##   compile_auto_scale_loss = structure(list(), class = "rlang_zap")
## 
## Computational engine: keras

This prefix convention gives you full control over the Keras training process while keeping the model specification function's signature clean and focused on the tunable parameters.