## ----include = FALSE---------------------------------------------------------- knitr::opts_chunk$set( collapse = TRUE, comment = "#>", eval = requireNamespace("tidymodels", quietly = TRUE) ) ## ----setup, message=FALSE, warning=FALSE-------------------------------------- library(tidymodels) library(bnns) ## ----reg-spec----------------------------------------------------------------- bnn_reg_spec <- mlp( mode = "regression", hidden_units = 5, epochs = 500, activation = "relu" ) %>% set_engine( engine = "bnns", chains = 2, warmup = 250, refresh = 0, seed = 123 ) bnn_reg_spec ## ----reg-fit, eval=FALSE------------------------------------------------------ # bnn_reg_wf <- workflow() %>% # add_model(bnn_reg_spec) %>% # add_formula(mpg ~ hp + wt + cyl + disp) # # # Fit the model # bnn_reg_fit <- fit(bnn_reg_wf, data = mtcars) # # bnn_reg_fit ## ----reg-pred, eval=FALSE----------------------------------------------------- # predictions <- predict(bnn_reg_fit, new_data = mtcars) # head(predictions) ## ----class-spec--------------------------------------------------------------- bnn_class_spec <- mlp( mode = "classification", hidden_units = 4, epochs = 500, activation = "tanh" ) %>% set_engine( engine = "bnns", chains = 1, warmup = 200, refresh = 0, seed = 456 ) ## ----class-fit, eval=FALSE---------------------------------------------------- # iris_rec <- recipe(Species ~ ., data = iris) %>% # step_normalize(all_numeric_predictors()) # # bnn_class_wf <- workflow() %>% # add_model(bnn_class_spec) %>% # add_recipe(iris_rec) # # bnn_class_fit <- fit(bnn_class_wf, data = iris) ## ----class-pred, eval=FALSE--------------------------------------------------- # # 1. Predict hard classes (returns a .pred_class factor column) # class_preds <- predict(bnn_class_fit, new_data = iris, type = "class") # head(class_preds) # # # 2. Predict class probabilities (returns .pred_{Level} columns) # prob_preds <- predict(bnn_class_fit, new_data = iris, type = "prob") # head(prob_preds) ## ----eval-metrics, eval=FALSE------------------------------------------------- # eval_data <- bind_cols(iris, class_preds, prob_preds) # # accuracy(eval_data, truth = Species, estimate = .pred_class) # roc_auc(eval_data, truth = Species, .pred_setosa:.pred_virginica)