XGBoost models

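| Function                                                   | Works |
|------------------------------------------------------------|-------|
| `tidypredict_fit()`, `tidypredict_sql()`, `parse_model()`  | ✔     |
| `tidypredict_to_column()`                                  | ✔     |
| `tidypredict_test()`                                       | ✔     |
| `tidypredict_interval()`, `tidypredict_sql_interval()`     | ✗     |
| `parsnip`                                                  | ✔     |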

tidypredict_ functions

library(xgboost)

# Custom logistic-loss objective in the (preds, dtrain) form used for xgboost
# custom objectives.
#
# `preds` holds the raw margin scores and `dtrain` is the xgb.DMatrix whose
# labels are read with getinfo(). The margins are mapped to probabilities with
# the logistic function. The result is a list with the first derivative
# (`grad = p - y`) and second derivative (`hess = p * (1 - p)`) of the log loss
# with respect to the margin.
#
# Note: the model in this example is trained with the built-in
# "binary:logistic" objective, so this function is not called here.
logregobj <- function(preds, dtrain) {
  y <- xgboost::getinfo(dtrain, "label")
  p <- 1 / (1 + exp(-preds))
  list(
    grad = p - y,
    hess = p * (1 - p)
  )
}

# Build the xgboost training matrix from mtcars.
# Every column except `am` becomes a feature, and `am` (transmission:
# 0 = automatic, 1 = manual) is the binary label.
# The label column is excluded by name rather than by position (previously
# `mtcars[, -9]`), so the features stay correct even if the column order
# differs.
xgb_bin_data <- xgboost::xgb.DMatrix(
  as.matrix(mtcars[, names(mtcars) != "am"]),
  label = mtcars$am
)

# Train a 50-round boosted-tree model on `xgb_bin_data` with xgboost's
# built-in "binary:logistic" objective. Trees are capped at depth 2, and the
# initial prediction score is 0.5.
model <- xgboost::xgb.train(
  data = xgb_bin_data,
  nrounds = 50,
  params = list(
    max_depth = 2,
    objective = "binary:logistic",
    base_score = 0.5
  )
)

parsnip

Models fitted with parsnip (using the `xgboost` engine) are also supported by tidypredict:

library(parsnip)

# Fit an xgboost model through parsnip. The spec is a boosted-tree regression
# on the "xgboost" engine, trained to predict `am` from every other mtcars
# column.
p_model <- boost_tree(mode = "regression") %>%
  set_engine(engine = "xgboost") %>%
  fit(formula = am ~ ., data = mtcars)

# Check tidypredict's results for the parsnip fit. `xg_df` passes in the
# xgb.DMatrix built from mtcars earlier.
tidypredict_test(p_model, mtcars, xg_df = xgb_bin_data)
#> tidypredict test results
#> Difference threshold: 1e-12
#> 
#> Fitted records above the threshold: 15
#> 
#> Fit max  difference:
#> Lower max difference:
#> Upper max difference:8.06462707725331e-08

Parse model spec

Here is an example of the parsed model specification returned by `parse_model()`:

# Parse the trained booster into tidypredict's model spec.
# Show only its top two levels of nesting.
pm <- parse_model(model)
str(pm, max.level = 2)
#> List of 2
#>  $ general:List of 7
#>   ..$ model        : chr "xgb.Booster"
#>   ..$ type         : chr "xgb"
#>   ..$ niter        : num 50
#>   ..$ params       :List of 4
#>   ..$ feature_names: chr [1:10] "mpg" "cyl" "disp" "hp" ...
#>   ..$ nfeatures    : int 10
#>   ..$ version      : num 1
#>  $ trees  :List of 42
#>   ..$ 0 :List of 3
#>   ..$ 1 :List of 3
#>   ..$ 2 :List of 3
#>   ..$ 3 :List of 3
#>   ..$ 4 :List of 3
#>   ..$ 5 :List of 3
#>   ..$ 6 :List of 3
#>   ..$ 7 :List of 3
#>   ..$ 8 :List of 3
#>   ..$ 9 :List of 3
#>   ..$ 10:List of 3
#>   ..$ 11:List of 2
#>   ..$ 12:List of 2
#>   ..$ 13:List of 2
#>   ..$ 14:List of 2
#>   ..$ 15:List of 2
#>   ..$ 16:List of 2
#>   ..$ 17:List of 2
#>   ..$ 18:List of 2
#>   ..$ 19:List of 2
#>   ..$ 20:List of 2
#>   ..$ 21:List of 2
#>   ..$ 22:List of 2
#>   ..$ 23:List of 2
#>   ..$ 24:List of 2
#>   ..$ 25:List of 2
#>   ..$ 26:List of 2
#>   ..$ 27:List of 2
#>   ..$ 28:List of 2
#>   ..$ 29:List of 2
#>   ..$ 30:List of 2
#>   ..$ 31:List of 2
#>   ..$ 32:List of 2
#>   ..$ 33:List of 2
#>   ..$ 34:List of 2
#>   ..$ 35:List of 2
#>   ..$ 36:List of 2
#>   ..$ 37:List of 2
#>   ..$ 38:List of 2
#>   ..$ 39:List of 2
#>   ..$ 40:List of 2
#>   ..$ 41:List of 2
#>  - attr(*, "class")= chr [1:3] "parsed_model" "pm_xgb" "list"
# Inspect the first parsed tree. `[1]` keeps it wrapped in a list.
str(pm[["trees"]][1])
#> List of 1
#>  $ 0:List of 3
#>   ..$ :List of 2
#>   .. ..$ prediction: num -0.436
#>   .. ..$ path      :List of 1
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0.429
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE
#>   ..$ :List of 2
#>   .. ..$ prediction: num 0
#>   .. ..$ path      :List of 2
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "qsec"
#>   .. .. .. ..$ val    : num 19.2
#>   .. .. .. ..$ op     : chr "less"
#>   .. .. .. ..$ missing: logi FALSE
#>   .. .. ..$ :List of 5
#>   .. .. .. ..$ type   : chr "conditional"
#>   .. .. .. ..$ col    : chr "wt"
#>   .. .. .. ..$ val    : num 3.18
#>   .. .. .. ..$ op     : chr "more-equal"
#>   .. .. .. ..$ missing: logi TRUE