---
title: "Introduction to parttree"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Introduction to parttree}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

```{r, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  warning = FALSE,
  out.width = "90%",
  # fig.width = 8,
  # dpi = 300,
  asp = 0.625
)
```

## Motivating example: Classifying penguin species

Start by loading the **parttree** package alongside **rpart**, which comes bundled with the base R installation. For the basic examples that follow, I'll use the well-known [Palmer Penguins](https://allisonhorst.github.io/palmerpenguins/) dataset to demonstrate functionality. You can load this dataset via the parent package (as I have here), or import it directly as a CSV [here](https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv).

```{r setup}
library(parttree) # This package
library(rpart)    # For fitting decision trees

# install.packages("palmerpenguins")
data("penguins", package = "palmerpenguins")
head(penguins)
```

```{r thread_control, echo=-1}
data.table::setDTthreads(2)
```

Dataset in hand, let's say that we are interested in predicting penguin _species_ as a function of 1) flipper length and 2) bill length. We could model this as a simple decision tree:

```{r tree}
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)
tree
```

Like most tree-based frameworks, **rpart** comes with a default `plot` method for visualizing the resulting node splits.

```{r rpart_plot}
plot(tree, compress = TRUE)
text(tree, use.n = TRUE)
```

While this is okay, I don't feel that it provides much intuition about the model's prediction on the _scale of the actual data_. In other words, what I'd prefer to see is: How has our tree partitioned the original penguin data?

This is where **parttree** enters the fray.
The package is named for its primary workhorse function `parttree()`, which extracts all of the information needed to produce a nice plot of our tree partitions alongside the original data.

```{r penguin_cl_plot}
ptree = parttree(tree)
plot(ptree)
```

_Et voilà!_ Now we can clearly see how our model has divided up the Cartesian space of the data. Gentoo penguins typically have longer flippers than Chinstrap or Adelie penguins, while the latter have the shortest bills.

From the perspective of the end-user, the `ptree` parttree object is not all that interesting in and of itself. It is simply a data frame that contains the basic information needed for our plot (partition coordinates, etc.). You can think of it as a helpful intermediate object on our way to the visualization of interest.

```{r ptree}
# See also `attr(ptree, "parttree")`
ptree
```

Speaking of visualization, underneath the hood `plot.parttree` calls the powerful [**tinyplot**](https://grantmcdermott.com/tinyplot/) package. All of the latter's various customization arguments can be passed on to our `parttree` plot to make it look a bit nicer. For example:

```{r penguin_cl_plot_custom}
plot(ptree, pch = 16, palette = "classic", alpha = 0.75, grid = TRUE)
```

### Continuous predictions

In addition to discrete classification problems, **parttree** also supports regression trees, where the dependent (outcome) variable is continuous.

```{r penguin_reg_plot}
tree_cont = rpart(body_mass_g ~ flipper_length_mm + bill_length_mm, data = penguins)

tree_cont |>
  parttree() |>
  plot(pch = 16, palette = "viridis")
```

## Supported model classes

Alongside the [**rpart**](https://CRAN.R-project.org/package=rpart) model objects that we have been working with thus far, **parttree** also supports decision trees created by the [**partykit**](https://CRAN.R-project.org/package=partykit) package. Here we see how the latter's `ctree` (conditional inference tree) algorithm yields a slightly more sophisticated partitioning than the former's default.

```{r penguin_ctree_plot}
library(partykit)

ctree(species ~ flipper_length_mm + bill_length_mm, data = penguins) |>
  parttree() |>
  plot(pch = 16, palette = "classic", alpha = 0.5)
```

**parttree** also supports a variety of "frontend" modes that call `rpart::rpart()` as the underlying engine. This includes packages from both the [**mlr3**](https://mlr3.mlr-org.com/) and [**tidymodels**](https://www.tidymodels.org/) ([parsnip](https://parsnip.tidymodels.org/) or [workflows](https://workflows.tidymodels.org/)) ecosystems. Here is a quick demonstration using **parsnip**, where we'll also pull in a different dataset just to change things up a little.

```{r titanic_plot}
set.seed(123) ## For consistent jitter

library(parsnip)
library(titanic) ## Just for a different data set

titanic_train$Survived = as.factor(titanic_train$Survived)

## Build our tree using parsnip (but with rpart as the model engine)
ti_tree =
  decision_tree() |>
  set_engine("rpart") |>
  set_mode("classification") |>
  fit(Survived ~ Pclass + Age, data = titanic_train)

## Now pass to parttree and plot
ti_tree |>
  parttree() |>
  plot(pch = 16, jitter = TRUE, palette = "dark", alpha = 0.7)
```

## ggplot2

The default `plot.parttree` method produces a base graphics plot. But we also support **ggplot2** via a dedicated `geom_parttree()` function. Here we demonstrate with our initial classification tree from earlier.

```{r penguin_cl_ggplot2}
library(ggplot2)
theme_set(theme_linedraw())

## re-using the tree model object from above...
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_point(aes(col = species)) +
  geom_parttree(data = tree, aes(fill=species), alpha = 0.1)
```

Compared to the "native" `plot.parttree` method, note that the **ggplot2** workflow requires a few tweaks:

- We need to plot the original dataset as a separate layer (i.e., `geom_point()`).
- `geom_parttree()` accepts the tree object _itself_, not the result of `parttree()`.^[This is because `geom_parttree(data = )` calls `parttree()` internally.]

Continuous regression trees can also be drawn with `geom_parttree`. However, I recommend adjusting the plot fill aesthetic since your model will likely partition the data into intervals that don't match up exactly with the raw data. The easiest way to do this is by setting your colour and fill aesthetics together as part of the same `scale_colour_*` call.

```{r penguin_reg_ggplot2}
## re-using the tree_cont model object from above...
ggplot(data = penguins, aes(x = flipper_length_mm, y = bill_length_mm)) +
  geom_parttree(data = tree_cont, aes(fill=body_mass_g), alpha = 0.3) +
  geom_point(aes(col = body_mass_g)) +
  scale_colour_viridis_c(aesthetics = c('colour', 'fill')) # NB: Set colour + fill together
```

### Gotcha: (gg)plot orientation

As we have already said, `geom_parttree()` calls the companion `parttree()` function internally, which coerces the **rpart** tree object into a data frame that is easily understood by **ggplot2**. For example, consider our initial "ptree" object from earlier.

```{r ptree_redux}
# ptree = parttree(tree)
ptree
```

Again, the resulting data frame is designed to be amenable to a **ggplot2** geom layer, with columns like `xmin`, `xmax`, etc. specifying aesthetics that **ggplot2** recognizes. (Fun fact: `geom_parttree()` is really just a thin wrapper around `geom_rect()`.) The goal of **parttree** is to abstract away these kinds of details from the user, so that they can just specify `geom_parttree()` (with a valid tree object as the data input) and be done with it. However, while this generally works well, it can sometimes lead to unexpected behaviour in terms of plot orientation. That's because it's hard to guess ahead of time what the user will specify as the x and y variables (i.e.
axes) in their other **ggplot2** layers.^[The default `plot.parttree` method doesn't have this problem because it assigns the x and y variables for both the partitions and raw data points as part of the same function call.] To see what I mean, let's redo our penguin plot from earlier, but this time switch the axes in the main `ggplot()` call.

```{r penguin_cl_ggplot_mismatch}
## First, redo our first plot but this time switch the x and y variables
p3 = ggplot(
  data = penguins,
  aes(x = bill_length_mm, y = flipper_length_mm) ## Switched!
) +
  geom_point(aes(col = species))

## Add on our tree (and some preemptive titling..)
p3 +
  geom_parttree(data = tree, aes(fill = species), alpha = 0.1) +
  labs(
    title = "Oops!",
    subtitle = "Looks like a mismatch between our x and y axes..."
  )
```

As was the case here, this kind of orientation mismatch is normally (hopefully) pretty easy to recognize. To fix it, we can use the `flip = TRUE` argument to flip the orientation of the `geom_parttree` layer.

```{r penguin_cl_ggplot_mismatch_flip}
p3 +
  geom_parttree(
    data = tree, aes(fill = species), alpha = 0.1,
    flip = TRUE ## Flip the orientation
  ) +
  labs(title = "That's better")
```
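
As an aside, because `geom_parttree()` is a thin wrapper around `geom_rect()`, the flipped plot can in principle be approximated by hand. The chunk below is only a rough sketch and not a description of how the package draws things internally. It assumes that `parttree()` accepts a `flip` argument mirroring the one in `geom_parttree()`, and that the resulting data frame has `xmin`, `xmax`, `ymin` and `ymax` columns plus a `species` column for the predicted class. The `ptree_flip` object and the chunk name are my own illustrative choices, and the styling won't match every `geom_parttree()` default.

```{r ptree_geom_rect_sketch}
library(rpart)
library(parttree)
library(ggplot2)

data("penguins", package = "palmerpenguins")
tree = rpart(species ~ flipper_length_mm + bill_length_mm, data = penguins)

## Assumption: parttree() takes flip = TRUE, swapping the x and y partition coordinates
ptree_flip = parttree(tree, flip = TRUE)

## Draw the partitions manually as rectangles on top of the (switched) scatter plot
ggplot(penguins, aes(x = bill_length_mm, y = flipper_length_mm)) +
  geom_point(aes(col = species)) +
  geom_rect(
    data = as.data.frame(ptree_flip), ## plain data frame for the rect layer
    aes(xmin = xmin, xmax = xmax, ymin = ymin, ymax = ymax, fill = species),
    alpha = 0.1, col = "black",
    inherit.aes = FALSE ## don't inherit x/y from the main ggplot() call
  )
```

In everyday use I'd still reach for `geom_parttree(flip = TRUE)`. The manual version is mainly useful for seeing what the partition data frame contributes to the plot.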