Tree-Based Methods for Statistical Learning in R
A Practical Introduction with Applications in R
Brandon M. Greenwell
Chapman & Hall/CRC Data Science Series, CRC Press (2022)
The book follows up most ideas and mathematical concepts with code-based
examples in the R statistical language, with an emphasis on using as few external
packages as possible. For example, users will be exposed to writing their own
random forest and gradient tree boosting functions using simple for loops and
basic tree-fitting software (like rpart and party/partykit), and more. The core
chapters also end with a detailed section on relevant software in both R and
other open source alternatives (e.g., Python, Spark, and Julia), and example usage
on real data sets. While the book mostly uses R, it is meant to be equally
accessible and useful to non-R programmers.
Readers of this book will gain a solid foundation in (and appreciation for)
tree-based methods and how they can be used to solve practical problems and
challenges data scientists often face in applied work.
Features:
• Thorough coverage, from the ground up, of tree-based methods (e.g.,
CART, conditional inference trees, bagging, boosting, and random
forests).
• A companion website containing additional supplementary material
and the code to reproduce every example and figure in the book.
• A companion R package, called treemisc, which contains several data
sets and functions used throughout the book (e.g., there’s an implemen-
tation of gradient tree boosting with LAD loss that shows how to per-
form the line search step by updating the terminal node estimates of a
fitted rpart tree).
• Interesting examples that are of practical use; for example, how to con-
struct partial dependence plots from a fitted model in Spark MLlib
(using only Spark operations), or post-processing tree ensembles via
the LASSO to reduce the number of trees while maintaining, or even
improving performance.
CHAPMAN & HALL/CRC DATA SCIENCE SERIES
Reflecting the interdisciplinary nature of the field, this book series brings together
researchers, practitioners, and instructors from statistics, computer science,
machine learning, and analytics. The series will publish cutting-edge research,
industry applications, and textbooks in data science.
Published Titles
Data Analytics
A Small Data Approach
Shuai Huang and Houtao Deng
Data Science
An Introduction
Tiffany-Anne Timbers, Trevor Campbell and Melissa Lee
Urban Informatics
Using Big Data to Understand and Serve Communities
Daniel T. O’Brien
Brandon M. Greenwell
First edition published 2022
by CRC Press
6000 Broken Sound Parkway NW, Suite 300, Boca Raton, FL 33487-2742
Reasonable efforts have been made to publish reliable data and information, but the author and pub-
lisher cannot assume responsibility for the validity of all materials or the consequences of their use.
The authors and publishers have attempted to trace the copyright holders of all material reproduced
in this publication and apologize to copyright holders if permission to publish in this form has not
been obtained. If any copyright material has not been acknowledged please write and let us know so
we may rectify in any future reprint.
Except as permitted under U.S. Copyright Law, no part of this book may be reprinted, reproduced,
transmitted, or utilized in any form by any electronic, mechanical, or other means, now known or
hereafter invented, including photocopying, microfilming, and recording, or in any information
storage or retrieval system, without written permission from the publishers.
For permission to photocopy or use material electronically from this work, access
www.copyright.com or contact the Copyright Clearance Center, Inc. (CCC), 222 Rosewood
Drive, Danvers, MA 01923, 978-750-8400. For works that are not available on CCC please
contact [email protected].
Trademark notice: Product or corporate names may be trademarks or registered trademarks and are
used only for identification and explanation without intent to infringe.
DOI: 10.1201/9781003089032
Typeset in LM Roman
by KnowledgeWorks Global Ltd.
Publisher’s note: This book has been prepared from camera-ready copy provided by the authors.
To my son, Oliver.
Contents
Preface xiii
1 Introduction 1
1.1 Select topics in statistical and machine learning . . . . . . . 2
1.1.1 Statistical jargon and conventions . . . . . . . . . . . 3
1.1.2 Supervised learning . . . . . . . . . . . . . . . . . . . 4
1.1.2.1 Description . . . . . . . . . . . . . . . . . . . 5
1.1.2.2 Prediction . . . . . . . . . . . . . . . . . . . 6
1.1.2.3 Classification vs. regression . . . . . . . . . 7
1.1.2.4 Discrimination vs. prediction . . . . . . . . . 7
1.1.2.5 The bias-variance tradeoff . . . . . . . . . . . 8
1.1.3 Unsupervised learning . . . . . . . . . . . . . . . . . . 10
1.2 Why trees? . . . . . . . . . . . . . . . . . . . . . . . . . . . . 10
1.2.1 A brief history of decision trees . . . . . . . . . . . . 12
1.2.2 The anatomy of a simple decision tree . . . . . . . . . 14
1.2.2.1 Example: survival on the Titanic . . . . . . . 15
1.3 Why R? . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . 17
1.3.1 No really, why R? . . . . . . . . . . . . . . . . . . . . 17
1.3.2 Software information and conventions . . . . . . . . . 19
1.4 Some example data sets . . . . . . . . . . . . . . . . . . . . . 20
1.4.1 Swiss banknotes . . . . . . . . . . . . . . . . . . . . . 21
1.4.2 New York air quality measurements . . . . . . . . . . 21
1.4.3 The Friedman 1 benchmark problem . . . . . . . . . 23
1.4.4 Mushroom edibility . . . . . . . . . . . . . . . . . . . 24
1.4.5 Spam or ham? . . . . . . . . . . . . . . . . . . . . . . 25
1.4.6 Employee attrition . . . . . . . . . . . . . . . . . . . . 28
1.4.7 Predicting home prices in Ames, Iowa . . . . . . . . . 29
1.4.8 Wine quality ratings . . . . . . . . . . . . . . . . . . 30
1.4.9 Mayo Clinic primary biliary cholangitis study . . . . 31
1.5 There ain’t no such thing as a free lunch . . . . . . . . . . . 35
1.6 Outline of this book . . . . . . . . . . . . . . . . . . . . . . . 35
I Decision trees 37
2 Binary recursive partitioning with CART 39
2.1 Introduction . . . . . . . . . . . . . . . . . . . . . . . . . . . 39
Bibliography 359
Index 381
Preface
Nonetheless, this book does assume some familiarity with the basics of sta-
tistical and machine learning, as well as the R programming language. Useful
references and resources are provided in the introductory material in Chap-
ter 1. While I try to provide sufficient detail and background where possible,
some topics could only be given cursory treatment, though, whenever possi-
ble, I try to point the more ambitious reader in the right direction in terms
of references.
Companion website
Along with the companion website, there’s also a companion R package, called
treemisc [Greenwell, 2021c], that houses a number of the data sets and func-
tions used throughout this book. Installation instructions and documentation
can be found in the package’s GitHub repository at
https://github.com/bgreenwell/treemisc.
Colorblindness
This book contains many visuals in color. I have tried as much as pos-
sible to keep every figure colorblind friendly. For the most part, I use
the Okabe-Ito color palette, designed by Masataka Okabe and Kei Ito
(https://jfly.uni-koeln.de/color/), which is available in R (>=4.0.0);
see ?grDevices::palette.colors for details. If you find any of the visuals
hard to read (whether due to color blindness or not) please consider reporting
it so that it can be corrected in the next available printing/version.
Acknowledgments
I’m extremely grateful to Bradley Boehmke, who back in 2016 asked me to help
him write “Hands-On Machine Learning with R” [Boehmke and Greenwell,
2020]. Without that experience, I would not have had the confidence (nor the
skill or patience) to prepare this book on my own. Thank you, Brad.
Also, a huge thanks to Alex Gutman and Jay Cunningham, who both agreed to
provide feedback on an earlier draft of this book. Their reviews and attention
to detail have ultimately led to a much improved presentation of the material.
Thank you both.
Lastly, I cannot express how much I owe to my wonderful wife Jennifer, and
our three kids: Julia, Lillian, and Oliver. You help inspire all I do and keep
me sane, and I truly appreciate you putting up with me while I worked on
this book.
Brandon M. Greenwell
1
Introduction
Charles Dickens
A Christmas Carol
Ever play a game called twenty questions? If you have, then you should have no
trouble understanding the basics of how decision trees work. A decision tree
is essentially a set of sequential yes or no questions regarding the available
features.

This chapter covers select topics in statistical and machine learning that will
help introduce more advanced topics later in the book. This book
does assume, however, that readers have at least some general background or
exposure to common topics in statistics and machine learning (like hypothesis
testing, cross-validation, and hyperparameter tuning). If you’re looking for
a more thorough overview of statistical and machine learning, I’d suggest
starting with James et al. [2021]. For a deeper dive, go with Hastie et al.
[2009]. Both books are freely available for download, if you choose not to
purchase a hard copy. Harrell [2015] and Matloff [2017], while more statistical
in nature, offer valuable takes on several concepts fundamental to statistical
and machine learning, and I highly recommend each.
1.1.2 Supervised learning

In supervised learning, it is common to assume that the response is related to
a set of predictors through a model of the form

Y = f (x) + ε,    (1.1)

where ε is a random variable with mean zero (i.e., E (ε) = 0) and is assumed
to be independent of x. Note that the response is also a random variable here
since it is a function of ε. The function f (x) is fixed and represents the system-
atic part of the relationship between Y and x. The true relationship between
Y and x is almost always statistical in nature (i.e., not deterministic), and the
additive error ε helps to capture the non-deterministic aspect of this relationship
(e.g., unobserved predictors, measurement error, etc.).
Since we assume E () = 0, it turns out that f (x) can be viewed as a condi-
tional expectation:
E (Y |x) = f (x) ,
where we can interpret f (x) as the mean response for all observations
with predictor values equal to x. In the case of J-class classification (Sec-
tion 1.1.2.3), we can still view f (x) as a conditional mean; in this case it’s
the conditional proportion corresponding to a particular class: E (Y = j|x),
which can be interpreted as an estimate of Pr (Y = j|x)—the probability that
Y belongs to class j given a particular set of predictor values x. In this sense,
class probability estimation is really a regression problem.
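As a tiny illustration of this last point (my own toy example, not from the
book), the sample mean of a 0/1 response is just the proportion of ones, so
estimating a conditional mean for a binary outcome amounts to estimating a
conditional probability:

y <- c(1, 0, 0, 1, 1)  # binary (0/1) response values
mean(y)  # 0.6, the proportion of ones (an estimate of Pr(Y = 1))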
The term “supervised” in supervised learning refers to the fact that we use
labeled training data {yi, xi}, i = 1, 2, . . . , N, and an algorithm to learn a
reasonable mapping between the observed response values, yi, and a set of
predictor values, xi. Without a labeled response column, the task would be
unsupervised (Section 1.1.3), and the analytic goal would be different.
An estimate fˆ (x) of f (x) can be used for either description or prediction (or
both). I’ll briefly discuss the meaning of each in turn next.
1.1.2.1 Description
For example, in the Ames housing data (Section 1.4.7), we may be interested
in determining which predictors are most influential on the predicted sale price
in a fitted model. We may also be interested in how a particular feature (e.g.,
overall house size) functionally relates to the predicted sale price from a fitted
model.
Questions like these are relatively straightforward to glean from simpler mod-
els, like an additive linear model or a simple decision tree. However, this type of
information is often hidden in more complicated nonparametric models—like
neural networks (NNs) and support vector machines (SVMs)—which unfortu-
nately has given rise to the term “black box” models. In Chapter 6, I’ll look at
several model-agnostic techniques that can be helpful in extracting relevant
descriptive information from any supervised learning model.
1.1.2.2 Prediction
As the name implies, prediction tasks are concerned with predicting future or
unobserved outcomes. For example, we may be interested in predicting the sale
price for a new home given a set of relevant features. This could, in theory, be
a useful starting point in setting the listing price for a home, or trying to help
infer whether or not a particular house is under- or over-valued. Great care
must be taken in such problems, however, as the outcome variable (the sale
price of homes, in this case) can be complex in nature and the available data
may not be enough to adequately capture sudden changes in the distribution
of future response values; a bit more on this in Section 1.4.7.
It should be stressed that prediction and description often go hand in hand.
Description helps provide transparency in how a model’s predictions are gen-
erated. Transparency helps reveal potential issues and biases and therefore
can help increase trust or distrust in a model’s predictions. Would you feel
comfortable putting a model into production if you did not have some un-
derstanding as to how different subsets of features contribute to the model’s
predictions?
Single decision trees, while often great descriptors, seldom make for good pre-
dictors, at least when compared to more contemporary techniques. Nonethe-
less, as we’ll see in Part I of this book, single decision trees are sometimes the
right tool for the job, but it just so happens that more accurate decision trees
tend to be harder to interpret. This is especially true for the decision trees dis-
cussed in Chapter 4, which are flexible enough to achieve good performance,
but often pay a price in interpretability.
1.1.2.3 Classification vs. regression
Supervised learning tasks generally fall into one of two categories: classification
or regression. Regression is used in a very general sense here, and often refers
to any supervised learning task with an ordered outcome. Examples of ordered
outcomes might be sale price or wine quality on a scale of 0–10 (essentially,
an ordered category).
In classification, the response is categorical and the objective is to “classify”
new observations into one of J possible categories. In the mushroom clas-
sification example (Section 1.4.4), for instance, the goal is to classify new
mushrooms as either edible or poisonous on the basis of simple observational
attributes about each (like the color and odor of each mushroom). In this ex-
ample, J = 2 (edible or poisonous) and the task is one of binary classification.
When J > 2, the task is referred to as multiclass classification.
1.1.2.4 Discrimination vs. prediction

Pure classification is almost never the goal, as we are usually not interested in
directly classifying observations into one of J categories. Instead, interest often
lies in estimating the conditional probability of class membership. That is to
say, it is often far more informative to estimate the class probabilities
{Pr (Y = j|x0)}, j = 1, 2, . . . , J, as opposed
to predicting the class membership of some observation x0 . Even when the
term “classification” is used, the underlying goal is usually that of estimating
class membership probabilities conditional on the feature valuesd .
Frank Harrell, a prominent biostatistician, couldn’t have said it better:
as it is often used in situations where the true goal is class probability estimation.
Frank Harrell
https://www.fharrell.com/post/classification/

1.1.2.5 The bias-variance tradeoff
The terms overfitting and underfitting are used throughout this book (more
so the former), but what do they mean? Overfitting occurs when your model
is too complex, and has gone past any signal in the data and is starting to
fit the noise. Underfitting, on the other hand, refers to when a model is too
simple and does not adequately capture any of the signal in the data. In both
cases, the model will not generalize well to new data.
A model that is overfitting the learning sample often exhibits lower bias but
has higher variance when compared to a model that is underfitting, which
often exhibits higher bias but lower variance. This tradeoff is more specifically
referred to as the bias-variance tradeoff. Excellent discussions of this topic can
be found in Matloff [2017, Sec. 1.11] and Hastie et al. [2009, Sec. 7.2–7.3]; the
latter provides more of a theoretical view.
For the additive error model (1.1) with constant variance σ², Hastie et al.
[2009] show that the mean square prediction error for an arbitrary observation
x0 can be decomposed into

E[(Y − fˆ(x0))² | x = x0] = σ² + Bias²[fˆ(x0)] + V[fˆ(x0)].
FIGURE 1.1: Fitted mean response from three linear models applied to the
quadratic example. Left: a simple linear model (i.e., degree one polynomial).
Middle: a quadratic model (i.e., the correct model). Right: a 20-th degree
polynomial model.
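To make Figure 1.1 concrete, here is a small simulation sketch (my own code,
with an assumed data-generating mechanism; the book's details for this example
are not shown here) that fits polynomials of degree 1, 2, and 20 to data
simulated from a quadratic mean function:

set.seed(1028)  # arbitrary seed, for reproducibility
n <- 100
x <- sort(runif(n))
y <- 1 + 0.3 * (x - 0.5) ^ 2 + rnorm(n, sd = 0.05)  # assumed quadratic truth
par(mfrow = c(1, 3))  # three side-by-side panels
for (degree in c(1, 2, 20)) {
  fit <- lm(y ~ poly(x, degree = degree))  # polynomial regression fit
  plot(x, y, col = "grey60", ylab = "y")
  lines(x, predict(fit), lwd = 2)  # fitted mean response
}

The degree-one fit underfits (high bias), the quadratic fit matches the assumed
truth, and the degree-20 fit starts chasing the noise (high variance).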
e A decision boundary partitions the predictor space into disjoint sets, one for each class.
Decision boundaries can be useful for understanding and comparing the flexibility and
performance of different classifiers.
f A 1-NN model classifies a new observation according to the class of its nearest neighbor
in the learning sample; in this case, “nearest neighbor” is defined as the closest observation
in the training set as measured by the Euclidean distance.
FIGURE 1.2: Two-class simulation example in two dimensions (features x1 and
x2) comparing the decision boundary of a 1-NN classifier with the Bayes decision
boundary.
1.2 Why trees?

There are a number of great modeling tools available to data scientists. But
what makes a modeling tool good? Is being able to achieve competitive accu-
racy all that matters? Of course not. According to the late Leo Breiman, a
good modeling tool, at a minimum, should:
• be applicable to both classification (binary and multiclass) and regression;
• be competitive in terms of accuracy;
The term “MARS” is trademarked and licensed to Salford Systems (which
was acquired by Minitab in 2017); hence, open source implementations often go by different
names. For example, a popular open source implementation of MARS is the fantastic earth
package [Milborrow, 2021a] in R.
or careful tuning. You can generally just hit “Run” and get something useful.
THAT IS NOT TO SAY THAT YOU SHOULD NOT PUT TIME AND EF-
FORT INTO CLEANING UP YOUR DATA AND CAREFULLY TUNING
THESE MODELS. Rather, trees can work seamlessly in rather messy data
situations (e.g., outliers, missing values, skewness, mixed data types, etc.)
without requiring the level of pre-processing necessary for other algorithms to
“just work” (e.g., neural networks). For example, even if I’m not using a deci-
sion tree for the final model, I will often use it as a first pass as it can give me
a quick and dirty picture of the data, and any serious issues (which can easily
be missed in the exploratory phase) will often be highlighted (e.g., extreme
target leakage or accidentally leaving in an ID column). In other words, trees
make great exploratory tools, especially when dealing with potentially messy
data.
1.2.1 A brief history of decision trees

Decision trees have a long and rich history in both statistics and computer
science, and have been around for many decades. However, decision trees
arguably got their true start in the social sciences. Motivated by the need
for finding interaction effects in complex survey data, Morgan and Sonquist
[1963] developed and published the first decision tree algorithm for regression
called automatic interaction and detection (AID). Starting at the root node,
AID recursively partitions the data into two homogeneous subgroups, called
child nodes, by maximizing the between-node sum of squares, similar to the
process described in Section 2.3. AID continues successively bisecting each
resulting child node until the reduction in the within-node sum of squares is
less than some prespecified threshold.
Messenger and Mandell [1972] extended AID to classification in their theta
automatic interaction detection (THAID) algorithm. The theta criterion used
in THAID to choose splits maximizes the sum of the number of observations
in each modal category.
The chi-square automatic interaction detection (CHAID) algorithm, intro-
duced in Kass [1980], improved upon AID by countering some of its initial
criticisms; CHAID was originally developed for classification and later extended
to also handle regression problems. Similar to the decision tree algorithms dis-
cussed in Chapters 3–4, CHAID employs statistical tests and stopping rules
to select the splitting variables and split points. In particular, CHAID relies
on chi-squared tests, which require discretizing ordered variables into bins.
Compared to AID and THAID, CHAID was unique in that it allowed mul-
tiway splits (which typically require larger sample sizes, otherwise the child
nodes can become too small rather quickly) and included a separate category
for missing values.
Despite the novelty of AID, THAID, and CHAID, it wasn’t until Breiman et al.
[1984] introduced the more general classification and regression tree (CART)
algorithmh , that tree-based algorithms started to catch on in the statistical
community. CART-like decision trees are the topic of Chapter 2. A similar
tree-based algorithm, called C4.5 [Quinlan, 1993], which evolved into the cur-
rent C5.0 algorithmi , has become very similar to CART in many regards;
hence, I focus on CART in this book and discuss the details of C4.5/C5.0 in
the online supplementary material.
CART helped generate renewed interest in partitioning methods, and we’ve
seen that evolution unfold over the last several decades. While the history is
rife with advancements, the first part of this book will focus on three of the
most important tree-based algorithms:
Chapter 2: Classification and regression trees (CART).
Chapter 3: Conditional inference trees (CTree).
h Like MARS, the term “CART” is also trademarked and licensed to Salford Systems;
hence, open source implementations often go by different names.
1.2.2 The anatomy of a simple decision tree

In this section, I’ll look at the basic parts of a typical decision tree (perhaps
tree topology would’ve been a cooler section header).
A typical (binary) decision tree is displayed in Figure 1.3. The tree is made up
of nodes and branches; the path between two consecutive branches is called
an edge. The nodes are the points at which a branch occurs. Here, we have
three internal nodes, labeled I1 , I2 , and I3 , and five terminal (or leaf) nodes,
labeled L1 , L2 , L3 , L4 , and L5 . The tree is binary because it only uses two-
way splits; that is, at each node, a split results in only two branches, labeled
“Yes” and “No”k . The path taken at each internal node depends on whether or
not the corresponding split condition is satisfied. For example, an observation
with x1 = 0.33 and x5 = 1.19 would find itself in terminal node L2 , regardless
of the values of the other predictors.
The split conditions (or just splits) for an ordered predictor x have the form
x < c vs. x ≥ c, where c is in the domain of x (typically the midpoint between
two consecutive x values in the learning sample); note that the same type of
splits are used for ordered factors since we just need to preserve the natural
ordering of the categories (e.g., x < medium vs. x ≥ medium)l . Splits on
(unordered) categorical variables have the form x ∈ S1 vs. x ∈ S2 , where
S = S1 ∪ S2 is the full set of unique categories in x.
Each tree begins with a root node containing the entire learning sample. Start-
ing with the root node, the training data are split into two non-overlapping
groups, one going left, and the other right, depending on whether or not the
first split condition is satisfied by each observation. The process is repeated
on each subgroup (or child node) until each observation reaches a terminal
node.
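As a concrete illustration of how candidate split points for an ordered predictor
arise (a toy example of my own, not taken from the book), the candidate cut
points c are just the midpoints between consecutive unique values observed in
the learning sample:

x <- c(2.1, 3.7, 3.7, 5.0, 8.4)  # toy predictor values
ux <- sort(unique(x))  # unique ordered values
(splits <- (head(ux, -1) + tail(ux, -1)) / 2)  # candidate split points c
#> [1] 2.90 4.35 6.70
x < splits[2]  # observations sent left by the split x < 4.35
#> [1]  TRUE  TRUE  TRUE FALSE FALSE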
j For those interested, QUEST stands for quick, unbiased, and efficient statistical tree and
CRUISE stands for classification rule with unbiased interaction selection and estimation.
k Some decision tree algorithms allow multiway (i.e., > 2) splits, but none are really
1.2.2.1 Example: survival on the Titanic

To further illustrate, let’s look at the tree diagram in Figure 1.4. This CART-
like decision tree was constructed using the well-known Titanic data set and
is trying to separate passengers who survived from those who didn’t using
readily available information about each. The data, which I’ll revisit in Sec-
tion 7.9.3, contain N = 1, 309 observations (i.e., passengers) on the following
six variables:
• survived: binary indicator of passenger survival (the response);
• pclass: integer specifying passenger class (i.e., 1–3);
• age: passenger age in years;
• sex: factor giving the sex of each passenger (i.e., male/female);
• sibsp: integer specifying the number of siblings/spouses aboard;
• parch: integer specifying the number of parents/children aboard.
The variable pclass is commonly treated as nominal categorical, but here the
natural ordering has been taken into account. The tree split the passengers into
six relatively homogeneous groups (terminal nodes) based on four of the above
m Fitted values are just the predicted values for each observation in the learning sample.

FIGURE 1.3: A typical (binary) decision tree. The root node splits on x1 < 0.75;
internal nodes I1, I2, and I3 split on x3 ∈ {a, b, e}, x5 < 0.9, and x1 < 0.25,
respectively; the five terminal nodes are labeled L1–L5.
five available features; parch is the only variable not selected to partition the
data. The terminal nodes (nodes 6–11) each contain a node summary giving
the proportion of surviving passengers in that node. As we’ll see in later
chapters, these proportions can be used as class probability estimates.
For example, the tree diagram estimates that first and second class female
passengers had a 93% chance of survival; the percentage displayed in the
bottom of each node corresponds to the fraction of training observations used
to define that node. Given what you know about the ill-fated Titanic, does
the tree diagram make sense to you? Does it appear that women and children
were given priority and had a higher chance of survival (i.e., “women and
children first”)? Perhaps, unless you were a third-class passenger.
In Part I of this book, I’ll look at how several popular decision tree algorithms
choose which variables to split on (splitters) and how each split condition is
determined (e.g., age < 9.5). Part II of this book will then look at how
to improve the accuracy and generalization performance of a single tree by
combining several hundred or thousand individual trees together.
FIGURE 1.4: CART-like decision tree for the Titanic data. The root node (38%
survival) splits on sex; the male branch is further split on age < 9.5, with older
males split on pclass < 2 and young boys on sibsp < 3, while the female branch
is split on pclass < 3. Each of the six terminal nodes (numbered 6–11) displays
the estimated proportion of survivors (ranging from 0.05 to 0.93) and the
percentage of passengers it contains.
1.3 Why R?
Why not?
I grew up on R (and SAS), but I’ve chosen it for this book primarily for one
reason: it currently provides the best support and access to a wide range of
tree-based methods (both classic and modern-day)n . For example, I’m not
aware of any non-R open source implementation of tree-based methods that
provides full support for nominal categorical variables without requiring them
to be encoded numericallyo . Another good example is conditional inference
trees [Hothorn et al., 2006c], the topic of Chapter 3, which are only imple-
mented in R (as far as I’m aware). Nonetheless, there is at least one Python
example included in this book!
n This book was also written in R (and LaTeX) using the wonderful knitr package Xie
[2021].
o There’s been some progress in areas of scikit-learn and other open source software that
1.3.1 No really, why R?

While I have a strong appreciation for the power of base R (i.e., the core lan-
guage), I’m extremely appreciative of, and often rely on, the amazing ecosys-
tem of contributed packages. Nonetheless, I’ve decided to keep the use of
external packages to a minimum (aside from those packages related to the
core tree-based methods discussed in this book), and instead rely on vanilla
R programming as much as possible. This choice was made for several (highly
opinionated) reasons:
• it will make the book easier to maintain going forward, as the code ex-
amples will (hopefully) continue to work for many years to come without
much modification;
• using standard R programming constructs (e.g., for loops and their apply-
style functional replacements, like lapply()) will make the material easier
to comprehend for non-R programmers, and easier to translate to other
open source languages, like Julia and Python;
• it emphasizes the basic concepts of the methods being introduced, rather
than focusing on current best coding practices and cool packages, of which
there are many. (Please don’t send me hate mail for using sapply() instead
of vapply() or the family of map functions available in purrr [Henry and
Wickham, 2020].)
Note that I’ve tried to be as aggressive as possible in terms of commenting
upon the various code snippets scattered throughout this book and the online
supplementary material. You should pay careful attention to these comments
as they often link a particular line or section of code to a specific step in an
algorithm, try to explain a hacky approach I’m using, and so on.
Each chapter includes one or more software sections, which highlight both
R and non-R implementations of the relevant algorithms under discussion.
Additionally, each chapter also contains R-specific software example sections
(usually at the end of each chapter), which demonstrate use of relevant tree-
specific software on actual data (either simulated or real).
There are many great resources for learning R, but I would argue that the
online manual, “An Introduction to R”, which can be found at
https://cran.r-project.org/doc/manuals/r-release/R-intro.html,
is a great place to start. Book-wise, I think that Matloff [2011] and Wickham
[2019] are some of the best resources for learning R (note that the latter is
freely available to read online). If you’re interested in a more hands-on intro-
duction to statistical and machine learning with R, I’ll happily self-promote
Boehmke and Greenwell [2020].
1.3.2 Software information and conventions
Each chapter contains at least one software section, which points to relevant
implementations of the ideas discussed in the corresponding chapter. And
while this book focuses on R, these sections will also allude to additional
implementations in other open source programming languages, like Python
and Julia. Furthermore, several code snippets are contained throughout the
book to help solidify certain concepts (mostly in R).
Be warned, I occasionally.use.dots in variable and function names—old
R programming habits die hard. Package names are in bold text (e.g., rpart),
inline code and function names are in typewriter font (e.g., sapply()), and file
names are in sans serif font (e.g., path/to/filename.txt). In situations where it
may not be obvious which package a function belongs to, I’ll use the notation
foo::bar(), where bar() is the name of a function in package foo.
I often allude to the documentation and help pages for specific R functions. For
example, you can view the documentation for function foo() in package bar
by typing ?foo::bar or help("foo", package = "bar") at the R console.
It’s a good idea to read these help pages as they will often provide more
useful details, further references, and example usage. For base R functions—
that is, functions available in R’s base package—I omit the package name
(e.g., ?kronecker). I also make heavy use of R’s apply()-family of functions
throughout the book, often for brevity and to avoid longer code snippets based
on for loops. If you’re unfamiliar with these, I encourage you to start with
the help pages for both apply() and lapply().
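For instance (a toy example of my own), lapply() applies a function to each
element of a list and returns a list, while sapply() simplifies the result when
possible:

x <- list(a = 1:3, b = 4:6)
lapply(x, FUN = mean)  # returns a list of means
sapply(x, FUN = mean)  # simplifies to a named numeric vector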
R package vignettes (when available) often provide more in-depth details on
specific functionality available in a particular package. You can browse any
available vignettes for a CRAN package, say foo, by visiting the package’s
homepage on CRAN at
https://cran.r-project.org/package=foo.
You can also use the utils package to view package vignettes during an
active R session. For example, the vignettes accompanying the R package
rpart [Therneau and Atkinson, 2019], which is heavily used in Chapter 2,
can be found at https://CRAN.R-project.org/package=rpart or by typing
utils::browseVignettes("rpart") at the R console.
There’s a myriad of R packages available for fitting tree-based models, and
this book only covers a handful. If you’re not familiar with CRAN’s task
views, you should be. They provide useful guidance on which packages on
CRAN are relevant to a certain topic (e.g., machine learning). The task view
on statistical and machine learning, for example, which can be found at
https://cran.r-project.org/web/views/MachineLearning.html,
lists several R packages useful for fitting tree-based models across a wide vari-
ety of situations. For instance, it lists RWeka [Hornik, 2021] as providing an
open source interface to the J4.8-variant of C4.5 and M5 (see the online sup-
plementary material on the book website). A brief description of all available
task views can be found at https://cran.r-project.org/web/views/.
Keep in mind that the focus of this book is to help you build a deeper under-
standing of tree-based methods, it is not a programming book. Nonetheless,
writing, running, and experimenting with code is one of the best ways to learn
this subject, in my opinion.
This book uses a couple of graphical parameters and themes for plotting that
are set behind the scene. So don’t fret if your plots don’t look exactly the
same when running the code. This book uses a mix of base R and ggplot2
[Wickham et al., 2021a] graphics, though, I think there’s a lattice [Sarkar,
2021] graphic or two floating around somewhere. For ggplot2-based graphics,
I use the theme_bw() theme, which can be set at the top level (i.e., for all
plots) using theme_set(theme_bw()). Most of the base R graphics in this
book use the following par() settings (see ?graphics::par for details on
each argument):
par(
mar = c(4, 4, 0.1, 0.1), # may be different for a handful of figures
cex.lab = 0.95,
cex.axis = 0.8,
mgp = c(2, 0.7, 0),
tcl = -0.3,
las = 1
)
Some of the base R graphics in this book use a slightly different setting for
the mar argument (e.g., to make room for plots that also have a top axis, like
Figure 8.12 on page 349).
1.4 Some example data sets

The examples in this book make use of several data sets, both real and sim-
ulated, and both small and large. Many of the data sets are available in the
treemisc package that accompanies this book (or another R package), but
many are also available for download from the book’s website:
https://bgreenwell.github.io/treebook/datasets.html.
In this section, I’ll introduce a handful of the data sets used in the examples
throughout this book. Some of these data sets are pretty common and are freely
available elsewhere.

1.4.1 Swiss banknotes
The Swiss banknote data [Flury and Riedwyl, 1988] contain measurements
from 200 Swiss 1000-franc banknotes: 100 genuine and 100 counterfeit. There
are six available predictors, each giving the length (in mm) of a different
dimension for each bill (e.g., the length of the diagonal). The response variable
is a 0/1 indicator for whether or not the bill was genuine/counterfeit. This is
a small data set that will be useful when exploring how some classification
trees are constructed. The code snippet below generates a simple scatterplot
matrix of the data, which is displayed in Figure 1.5:
bn <- treemisc::banknote
cols <- palette.colors(3, palette = "Okabe-Ito")
pairs(bn[, 1L:6L], col = adjustcolor(cols[bn$y + 2], alpha.f = 0.5),
      pch = c(1, 2)[bn$y + 1], cex = 0.7)
Note how good some of the features are at discriminating between the two
classes (e.g., top and diagonal). This is a small data set that will be used to
illustrate fundamental concepts in decision tree building in Chapters 2–3.
1.4.2 New York air quality measurements

The New York air quality data contain daily air quality measurements
in New York from May through September of 1973 (153 days). The
data are conveniently available in R’s built-in datasets package; see
?datasets::airquality for details and the original source. The main vari-
ables include:
• Ozone: the mean ozone (in parts per billion) from 1300 to 1500 hours at
Roosevelt Island;
• Solar.R: the solar radiation (in Langleys) in the frequency band 4000–
7700 Angstroms from 0800 to 1200 hours at Central Park;
• Wind: the average wind speed (in miles per hour) at 0700 and 1000 hours
at LaGuardia Airport;
• Temp: the maximum daily temperature (in degrees Fahrenheit) at La
Guardia Airport.
FIGURE 1.5: Scatterplot matrix of the Swiss banknote data. The black cir-
cles and orange triangles correspond to genuine and counterfeit banknotes,
respectively.
The month (1–12) and day of the month (1–31) are also available in the
columns Month and Day, respectively. In these data, Ozone is treated as a
response variable.
This is another small data set that will be useful when exploring how some
regression trees are constructed. A simple scatterplot matrix of the data is
constructed below; see Figure 1.6. The upper diagonal scatterplots each con-
tain a LOWESS smooth p of the data (red curve). Note that there’s a relatively
strong nonlinear relationship between Ozone and both Temp and Wind, com-
pared to the others.
p A LOWESS smoother is a nonparametric smooth based on locally-weighted polynomial
aq <- datasets::airquality
color <- adjustcolor("forestgreen", alpha.f = 0.5)
ps <- function(x, y, ...) {  # custom panel function
  panel.smooth(x, y, col = color, col.smooth = "black",
               cex = 0.7, lwd = 2)
}
pairs(aq, cex = 0.7, upper.panel = ps, col = color)
FIGURE 1.6: Scatterplot matrix of the New York air quality data. Each black
curve in the upper panel represents a LOWESS smoother.
1.4.3 The Friedman 1 benchmark problem

The Friedman 1 benchmark problem concerns data simulated from the regression
model

Y = 10 sin (πX1X2) + 20 (X3 − 0.5)² + 10X4 + 5X5 + ε,    (1.2)

where ε ∼ N (0, σ) and the input features are all independent uniform random
variables on the interval [0, 1]: {Xj}, j = 1, 2, . . . , 10, iid ∼ U (0, 1). Notice how
X6–X10 are unrelated to the response Y.
These data can be generated in R using the mlbench.friedman1 func-
tion from package mlbench [Leisch and Dimitriadou., 2021]. Here, I’ll
use the gen_friedman1 function from package treemisc, which al-
lows you to generate any number of features ≥ 5; similar to the
make_friedman1 function in scikit-learn’s sklearn.datasets module for
Python. See ?treemisc::gen_friedman1 for details. Below, I generate a sam-
ple of N = 5 observations from (1.2) with only seven features (so it prints
nicely):
set.seed(943) # for reproducibility
treemisc::gen_friedman1(5, nx = 7, sigma = 0.1)
#> y x1 x2 x3 x4 x5 x6 x7
#> 1 18.5 0.346 0.853 0.655 0.839 0.293 0.3408 0.0573
#> 2 13.7 0.442 0.691 0.214 0.108 0.543 0.1616 0.5055
#> 3 10.8 0.223 0.789 0.807 0.252 0.257 0.8595 0.7248
#> 4 18.9 0.859 0.520 0.891 0.129 0.936 0.0348 0.7105
#> 5 14.5 0.181 0.590 0.893 0.611 0.415 0.4104 0.2636
From (1.2), it should be clear that features X1 –X5 are the most important!
(The others don’t influence Y at all.) Also, based on the form of the model,
we’d expect X4 to be the most important feature, probably followed by X1
and X2 (both comparably important), with X5 probably being less important.
The influence of X3 is harder to determine due to its quadratic nature, but it
seems likely that this nonlinearity will suppress the variable’s influence over its
observed range (i.e., [0, 1]). Since the true nature of E (Y |x) is known, albeit
somewhat complex (e.g., nonlinear relationships and an explicit interaction
effect), these data are useful in testing out different model interpretability
techniques (at least on numeric features), like those discussed in Section 6.
Since these data are convenient to generate, I’ll use them in a couple of small-
scale simulations throughout this book.
1.4.4 Mushroom edibility

The mushroom edibility data is one of my favorite data sets. It contains 8124
mushrooms described in terms of 22 different physical characteristics, like odor
and spore print color. The response variable (Edibility) is a binary indicator
for whether or not each mushroom is Edible or Poisonous. The data are avail-
able from the UCI Machine Learning repository at
https://archive.ics.uci.edu/ml/datasets/mushroom.

1.4.5 Spam or ham?
These data refer to N = 4,601 emails classified as either spam (i.e., junk
email) or non-spam (i.e., “ham”) that were collected at Hewlett-Packard (HP)
Labs. In addition to the class label, there are 57 predictors giving the relative
frequency of certain words and characters in each email. For example, the
column charDollar gives the relative frequency of dollar signs ($) appearing in
each email. The data are available from the UCI Machine Learning repository
at
https://archive.ics.uci.edu/ml/datasets/spambase.
In R, the data can be loaded from the kernlab package [Karatzoglou et al.,
2019]; see ?kernlab::spam for further details.
Below, I load the data into R, check the frequency of spam and non-spam
emails, then look at the average relative frequency of several different words
and characters between each:
data(spam, package = "kernlab")
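The original summary code isn't reproduced above; the following is a sketch of
the kind of summary described (the class label column in kernlab's spam data is
type; beyond charDollar, remove, and hp, which are mentioned in the text, the
columns shown are my own picks and may differ from the book's):

table(spam$type)  # frequency of spam vs. non-spam emails

# Average relative frequency of a few words/characters by class
aggregate(cbind(charDollar, remove, hp, address, report) ~ type,
          data = spam, FUN = mean)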
FIGURE 1.7: Mosaic plot visualizing the relationship between mushroom ed-
ibility and odor. The area of each tile is proportional to the number of obser-
vations in the particular category.
Notice how the first three variables show a much larger difference between
spam and non-spam emails; we might expect these to be important predictors
(at least compared to the other two) in classifying new HP emails as spam
vs. non-spam. For example, given that these emails all came from Hewlett-
Packard Labs, the fact that the non-spam emails contain a much higher rela-
tive frequency of the word hp makes sense (email spam was not as clever back
in 1998).
As a preview of what’s to come, the code chunk below fits a basic decision
tree with three splits (i.e., it asks three yes or no questions) to a 70% random
sample of the data. It also takes into account the specified assumption that
classifying a non-spam email as spam is five times more costly than classifying
a spam email as non-spam. We’ll learn all about rpart and the steps taken
below in Chapter 2.
library(rpart)
library(treemisc)
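The remainder of the original code chunk isn't reproduced above; here is a rough
sketch of how such a fit could look (the seed, the use of maxdepth to keep the
tree small, and the plotting call are my own choices, not necessarily the book's):

set.seed(1051)  # arbitrary seed, for reproducibility
trn.id <- sample.int(nrow(spam), size = floor(0.7 * nrow(spam)))
spam.trn <- spam[trn.id, ]  # 70% learning sample

# Loss matrix: rows = true class, columns = predicted class, in the order of
# the factor levels (nonspam, spam); classifying non-spam as spam is assumed
# to be five times more costly than the reverse
L <- matrix(c(0, 1, 5, 0), nrow = 2)
spam.tree <- rpart(type ~ ., data = spam.trn, parms = list(loss = L),
                   maxdepth = 3)  # keep the tree shallow
plot(spam.tree); text(spam.tree)  # basic tree diagram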
FIGURE 1.8: Decision tree diagram for a simple classification tree applied to
the email spam learning sample. The root node (60% non-spam, 40% spam) splits
on remove >= 0.01; emails satisfying the split (18% of the sample) are classified
as spam with probability 0.95. The other branch is split further on
charDollar >= 0.088 and then hp < 0.25, giving terminal nodes with estimated
spam probabilities of 0.18, 0.09, and 0.92.
The associated tree diagram is displayed in Figure 1.8. This tree is too simple
and underfits the training data (I’ll re-analyze these data using an ensemble
in Chapter 5). Nonetheless, simple decision trees can often be displayed as a
small set of simple rules. As a set of mutually exclusive and exhaustive rules,
the tree diagram in Figure 1.8 translates to:
Rule 1 (path to terminal node 3)
IF remove >= 0.01
THEN classification = spam (probability = 0.95)
1.4.6 Employee attrition

table(attrition$Attrition)
#>
#> No Yes
#> 1233 237
1.4.7 Predicting home prices in Ames, Iowa

The Ames housing data [De Cock, 2011], which are available in the R pack-
age AmesHousing [Kuhn, 2020], contain information from the Ames As-
sessor’s Office used in computing assessed values for individual residential
properties sold in Ames, Iowa from 2006–2010; online documentation describ-
ing the data set can be found at http://jse.amstat.org/v19n3/decock/
DataDocumentation.txt. These data are often used as a more contemporary
replacement to the often cited—and ethically challenged [Carlisle, 2019]—
Boston housing data [Harrison and Rubinfeld, 1978].
The data set contains N = 2, 930 observations on 81 variables. The response
variable here is the final sale price of the home (Sale_Price). The remaining
80 variables, which I’ll treat as predictors, are a mix of both ordered and
categorical features.
In the code chunk below, I’ll load the data into R and split it into train/test
sets using a 70/30 split, which I’ll use in several examples throughout this
book (note that for plotting purposes, mostly to avoid large numbers on the
y-axis, I’ll rescale the response by dividing by 1,000):
ames <- as.data.frame(AmesHousing::make_ames())
ames$Sale_Price <- ames$Sale_Price / 1000 # rescale response
set.seed(2101) # for reproducibility
trn.id <- sample.int(nrow(ames), size = floor(0.7 * nrow(ames)))
ames.trn <- ames[trn.id, ] # training data/learning sample
ames.tst <- ames[-trn.id, ] # test data
Figure 1.9 shows a scatterplot of sale price vs. above grade (ground) living
area in square feet (Gr_Liv_Area) from the 70% learning sample. Above grade
living area, as we’ll see in later chapters, is arguably one of the more impor-
tant predictors in this data set (as you might expect). It is evident from this
plot that heteroscedasticity is present, with variation in sale price increasing
with home size. Linear models assume constant variance whenever relying on
the usual normal theory standard errors and confidence intervals for interpre-
tation. Outliers are another potential problem.
plot(Sale_Price ~ Gr_Liv_Area, data = ames.trn,
     col = adjustcolor(1, alpha.f = 0.5),
     xlab = "Above grade square footage",
     ylab = "Sale price / 1000")
FIGURE 1.9: Scatterplot of sale price vs. above grade (ground) living area in
square feet from the Ames housing training sample; here you can see five or
so potential outliers.
Note that predictions based solely on these data should not be used alone in
setting the sale price of a home. I mean, they could, but they would likely
not perform well over time. There are many complexities involved in valuing
a home, and housing markets change over time. With the data at hand, it
can be hard to predict such changes, especially during the initial Covid-19
outbreak during which the majority of this book was written (many things
became rather hard to predict and forecast). However, such a model could be
a useful place to start, especially for descriptive purposes.
1.4.8 Wine quality ratings

These data are related to red and white variants of the Portuguese “Vinho
Verde” wine; for details, see Cortez et al. [2009]. Due to privacy and logistic
issues, only physicochemical and sensory variables are available (e.g., there is
no data about grape types, wine brand, wine selling price, etc.). The response
variable here is the wine quality score (quality), which is an ordered integer
in the range 0–10.
The data are available in the R package treemisc and can be used for classi-
fication or regression, but given the ordinal nature of the response, the latter
is more appropriate; see ?treemisc::wine. The data can also be downloaded
from the UCI Machine Learning repository at
https://archive.ics.uci.edu/ml/datasets/wine+quality.
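A quick tabulation of the response gives a feel for its distribution (a sketch
using only the quality column named above; the book's own summary code is not
reproduced here):

wine <- treemisc::wine
table(wine$quality)  # counts of wines at each quality score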
Note that most wines (red or white) are mediocre and relatively few have very
high or low scores. The response here, while truly an integer in the range 0–10,
is often treated as binary by arbitrarily discretizing the ordered response into
“low quality” and “high quality” wines. A more appropriate analysis, which
utilizes the fact that the response is ordered, is given in Section 3.5.2.
1.4.9 Mayo Clinic primary biliary cholangitis study

This example concerns data from a study by the Mayo Clinic on primary
biliary cholangitis (PBC) of the liver conducted between January 1974 and
May 1984; follow-up continued through July 1986. PBC is an autoimmune
disease leading to destruction of the small bile ducts in the liver. There were
a total of N = 418 patients whose survival time and censoring indicator were
known (I’ll discuss what these mean briefly). The goal was to compare the
drug D-penicillamine with a placebo in terms of survival probability. The
drug was ultimately found to be ineffective; see, for example, Fleming and
Harrington [1991, p. 2] and Ahn and Loh [1994] (the latter employs a tree-
based analysis). An additional 16 potential covariates are included which I’ll
investigate further as predictors in Section 3.5.3.
Below, I load the data from the survival package [Therneau, 2021] and do
some prep work. For starters, I’ll only consider the subset of patients who were
randomized into the D-penicillamine and placebo groups; see ?survival::pbc
for details. Second, I’ll consider the small number of subjects who underwent
liver transplant to be censored at the day of transplantq :
library(survival)
q As mentioned in Harrell [2015, Sec. 8.9], liver transplantation was rather uncommon at
the time the data were collected, so it still constitutes a natural history study for PBC.
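The rest of the prep code isn't reproduced above; here is a minimal sketch of
what it could look like (the variable codings follow ?survival::pbc, but the
book's exact code may differ):

pbc2 <- pbc[!is.na(pbc$trt), ]  # randomized subjects only (trt is NA otherwise)
# Recode status: 2 = death becomes the event (1); censored (0) and
# transplant (1) are both treated as censored (0)
pbc2$status <- ifelse(pbc2$status == 2, 1, 0)
table(pbc2$status)  # 187 censored vs. 125 deaths among the 312 randomized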
In this sample, 125 subjects died (i.e., experienced the event of interest) and
the remaining 187 were considered censored (i.e., we only know they did not
die before dropping out, receiving a transplant, or reaching the end of the
study period).
In survival studies (like this one), the dependent variable of interest is often
time until some event occurs; in this example, the event of interest is death.
However, medical studies cannot go on forever, and sometimes subjects drop
out or are otherwise lost to follow-up. In these situations, we may not have
observed the event time, but we at least have some partial information. For
example, some of the subjects may have survived beyond the study period,
or perhaps some dropped out due to other circumstances. Regardless of the
specific reason, we at least have some partial information on these subjects,
which survival analysis (also referred to as time-to-event or reliability analysis)
takes into account.
The scatterplot in Figure 1.10 shows the survival times for the first ten subjects
in the PBC data, with an indicator for whether or not each observation was
censored. The first subject, for example, was recorded dead at t = 400 days,
while subject two was censored at t = 4, 500 days.
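In R's survival package, this pairing of an observed time with an event
indicator is represented by a Surv object; a quick illustration using the two
subjects just mentioned (my own snippet):

library(survival)
Surv(time = c(400, 4500), event = c(1, 0))  # a "+" marks a censored time
#> [1]  400  4500+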
In survival analysis, the response variable typically has the form
Y = min (T, C) ,
where T is the survival time and C is the censoring time. In this book, I’ll only
consider right censoring (the most common form of censoring), where T ≥ Y .
In this case, all we know is that the true event time is at least as large as
the observed timer . For example, if we were studying the failure time of some
motor in a machine, we might have observed a failure at time t = 56 hours,
or perhaps the study ended at t = 100 hours, so all we know is that the true
failure time would have occurred some time after that.
To indicate that a particular observation is censored, we can use a censoring
indicator:
r Left censoring and interval censoring are other common forms of censoring.
FIGURE 1.10: Survival times for the first ten (randomized) subjects in the
Mayo Clinic PBC data.
δ = 1 if T ≤ C, and δ = 0 if T > C (i.e., the observation is censored).

A key quantity in survival analysis is the survival function, S (t) = Pr (T > t),
which describes the probability of surviving longer than time t. The Kaplan-
Meier (or product limit) estimator is a nonparametric statistic used for esti-
mating the survival function in the presence of censoring (if there isn’t any
censoring, then we could just use the ordinary empirical distribution function).
The details are beyond the scope of this book, but the survfit function from
package survival can do the heavy lifting for us.
In the code snippet below, I call survfit to estimate and plot the survival
curves for both the drug and placebo groups; see Figure 1.11. Here, you can see
that the estimated survival curves between the treatment and control group
are similar, indicating that D-penicillamine is rather ineffective. The log-rank
test can be used to test for differences between the survival distributions of
two groups. Some decision tree algorithms for the analysis of survival data use
the log-rank test to help partition the data; see, for example, Segal [1988] and
Leblanc and Crowley [1993].
palette("Okabe-Ito")
plot(survfit(Surv(time, status) ~ trt, data = pbc2), col = 2:3,
conf.int = FALSE, las = 1, xlab = "Days until death",
ylab = "Estimated survival probability")
legend("bottomleft", legend = c("Penicillmain", "Placebo"),
lty = 1, col = 2:3, text.col = 2:3, inset = 0.01, bty = "n")
palette("default")
FIGURE 1.11: Kaplan-Meier estimate of the survival function for the ran-
domized subjects in the Mayo Clinic PBC data by treatment group (i.e., drug
vs. placebo). The median survival times are 3282 days (drug) and 3428 days
(placebo).
In Section 3.5.3, we’ll see how a simple tree-based analysis can estimate the
survival function conditional on a set of predictors, denoted Ŝ (t|x), by parti-
tioning the learning sample into non-overlapping groups with similar survival
rates; here, we’ll see further evidence that D-penicillamine was not effective.

1.5 There ain’t no such thing as a free lunch
Too often, we see papers or hear arguments claiming that some cool new
algorithm A is better than some existing algorithms B and C at doing D. This
is mostly baloney, as any experienced statistician or modeler would tell you
that no one procedure or algorithm is uniformly superior across all situations.
That being said, you should not walk away from this book with the impression
that tree-based methods are superior to any other algorithm or modeling tool.
They are powerful and flexible tools for sure, but that doesn’t always mean
they’re the right tool for the job. Consider them as simply another tool to
include in your modeling and analysis toolbox.
1.6 Outline of this book

This book is about decision trees, both individual trees (Part I) and ensembles
thereof (Part II). There are a large number of decision tree algorithms in
existence, and entire books have even been dedicated to some. Consequently, I
had to be quite selective in choosing the topics to present in detail in this book,
which has mostly been guided by my experiences with tree-based methods
over the years in both academics and industry. As mentioned in Loh [2014],
“There are so many recursive partitioning algorithms in the literature that it
is nowadays very hard to see the wood for the trees.”
I’ll discuss some of the major, and most important tree-based algorithms in
current use today. However, due to time and page constraints, several im-
portant algorithms and extensions didn’t make the final cut, and are instead
discussed in the (free) online supplementary material that can be found on
the book website. These methods include:
• C5.0 [Kuhn and Johnson, 2013, Sec. 14.6], the successor to C4.5 [Quinlan,
1993], which is similar enough to CART that including it in a separate
chapter would be largely redundant with Chapter 2;
• MARS, which was briefly mentioned in Section 1.2 (see Table 1.1), is es-
sentially an extension of linear models (and CART) that automatically captures
nonlinear relationships and interaction effects.
Part I

Decision trees
2
Binary recursive partitioning with CART
I’m always thinking one step ahead, like a carpenter that makes
stairs.
Andy Bernard
The Office
This is arguably the most important chapter in the book. It is long, and
rather involved, but serves as the foundation to more contemporary partition-
ing algorithms, like conditional inference trees (CTree; Chapter 3), generalized,
unbiased, interaction detection, and estimation (GUIDE; Chapter 4), and tree-based
ensembles, such as random forests (Chapter 7) and gradient boosting machines
(Chapter 8).
2.1 Introduction
In this chapter, I’ll discuss one of the most general (and powerful) tree-based
algorithms in current practice: binary recursive partitioning. This treatment of
the subject closely follows the open source routines available in the rpart
package [Therneau and Atkinson, 2019], the details of which can be found in
the corresponding package vignettes, which can be accessed directly from R
using browseVignettes("rpart") (they can also be found on the package's
CRAN landing page at https://cran.r-project.org/package=rpart). The
rpart package, which is discussed in depth in Section 2.9, is a modern open
source implementation of CART.
Open source implementations go by other names; for brevity, I'll use the acronym CART
to refer to the broad class of implementations that follow the original ideas in Breiman et al. [1984],
which includes rpart and scikit-learn's sklearn.tree module.
FIGURE 2.1: Scatterplots of two data sets split into three non-overlapping
rectangular regions. The regions were selected so that the response values
within each were as homogeneous as possible. Left: a binary classification prob-
lem concerning 200 Swiss banknotes that have been identified as either genuine
(purple circles) or counterfeit (yellow triangles). Right: a regression problem
(brighter spots indicate higher average response rates within each bin).
region in the left side of Figure 2.1 by making more partitions, but this would
eventually lead to overfitting.
The term “binary recursive partitioning” is quite descriptive of the general
CART procedure, which I’ll discuss in detail in the next section for the clas-
sification case. The word binary refers to the binary (or two-way) nature of
the splits used to construct the trees (i.e., each split partitions a set of ob-
servations into two non-overlapping subsets). The word recursive refers to the
greedy nature of the algorithm in choosing splits sequentially (i.e., the algo-
rithm does not look ahead to find splits that are globally optimal in any sense;
it only tries to find the next best split). And of course, partitioning refers to
the way splits attempt to partition a set of observations into non-overlapping
subgroups with homogeneous response values.
To begin, let's go back to the Swiss banknote data from Figure 2.1. As
discussed in Section 1.4.1, these data contain six continuous measurements on
200 Swiss 1000-franc banknotes: 100 genuine and 100 counterfeit. The goal
is to use the six available features to classify new Swiss banknotes as either
genuine or counterfeit.
The code chunk below loads the data into R and prints the first few observa-
tions:
head(bn <- treemisc::banknote) # load and peek at data
#> length left right bottom top diagonal y
#> 1 215 131 131 9.0 9.7 141 0
#> 2 215 130 130 8.1 9.5 142 0
#> 3 215 130 130 8.7 9.6 142 0
#> 4 215 130 130 7.5 10.4 142 0
#> 5 215 130 130 10.4 7.7 142 0
#> 6 216 131 130 9.0 10.1 141 0
A tree diagram representation of the Swiss banknote regions from Figure 2.1
is displayed in Figure 2.2. The bottom number in each node gives the fraction
of observations that pass through that node (hence, the root node displays
100%). The values in the middle give the proportion of counterfeit and genuine
banknotes, respectively, and the class printed at the top corresponds to the
larger fraction (i.e., whichever class holds the majority in the node). The
number above each node gives the corresponding node number. This is an
example classification tree that can be used to classify new Swiss banknotes.
For example, any Swiss banknote with bottom >= 9.55 would be classified as
counterfeit (y = 1); note that the split points are rounded for display purposes
in Figure 2.2. The proportion of counterfeit bills in this node is 0.977 and can
be used as an estimate of Pr (Y = 1|x); but more on this later.
From this tree, we can construct three simple rules for classifying new Swiss
banknotes using just the bottom and top length of each bill:
Rule 1 (path to terminal node 2)
IF bottom >= 9.55 (mm)
THEN classification = Counterfeit (probability = 0.977)
Rule 2 (path to terminal node 6)
IF bottom < 9.55 (mm) AND top >= 11.15 (mm)
THEN classification = Counterfeit (probability = 0.765)

Rule 3 (path to terminal node 7)
IF bottom < 9.55 (mm) AND top < 11.15 (mm)
THEN classification = Genuine (probability = 0.989)

But how did CART decide which features to split on and which split point to use for each? Since this is
a binary classification problem, CART searched for the predictor/split com-
binations that “best” separated the genuine banknotes from the counterfeit
ones (I’ll discuss how “best” is determined in the next section).
[FIGURE 2.2 (structure): node 1 (Counterfeit; .50 .50; 100%) splits on bottom < 9.6; node 2 (Counterfeit; .98 .02; 44%) is terminal; node 3 (Genuine; .12 .87; 56%) splits on top < 11 into terminal nodes 6 (Counterfeit; .76 .24; 8%) and 7 (Genuine; .01 .99; 48%).]
FIGURE 2.2: Example decision tree diagram for classifying Swiss banknotes
as counterfeit or genuine.
Let’s first discuss in general how CART finds the “best” split for an ordered
variable. A hypothetical split S of an arbitrary node A into left and right
child nodes, denoted AL and AR , respectively, is shown in Figure 2.3. If A
contains N observations, then S partitions A into subsets AL and AR with
node sizes NL and NR , respectively; note that NL + NR = N . Since the
splitting process we’re about to describe applies to any node in a tree, we can
assume without loss of generality that A is the root node, which contains the
entire learning sample (that is, all of the training data that will be used in
constructing the tree). For now, I’ll assume that all of the features are ordered,
which includes both continuous and ordered categorical variables (I’ll discuss
splits for nominal categorical features in Section 2.4). The first step is to
partition the root node in a way that “best separates” the individual class
labels into two child nodes; I’ll discuss ways to measure how well a particular
split separates the class labels momentarily.
The split S depicted in Figure 2.3 can be summarized via a 2-by-2 contingency
table giving the number of observations from each class that go to the left or
[FIGURE 2.3 (structure): node A splits into left child AL (x ≥ c) and right child AR (x < c).]
FIGURE 2.3: Hypothetical split for some parent node A into two child nodes
using a continuous feature x with split point c.
TABLE 2.1: Confusion table summarizing the split S depicted in Figure 2.3.

             y = 0      y = 1
    AL      N0,AL      N1,AL      NAL
    AR      N0,AR      N1,AR      NAR
            N0,A       N1,A       N
right child node. Table 2.1 gives such a summary for a binary 0/1 outcome.
For example, N0,AL is the number of observations belonging to class y =
0 that went to the left child node. The row and column margins are also
displayed.
CART takes a greedy approach to tree construction. At each step in the
splitting process, CART uses an exhaustive search to look for the next best
(i.e., locally optimal) split, which does not necessarily lead to a globally opti-
mal tree structure. This offers a reasonable trade-off between simplicity and
complexity—otherwise the algorithm would have to consider all future poten-
tial splits at each step, which would lead to a combinatorial explosion. Let’s
turn our attention now to how CART chooses to split a node.
Let’s assume the outcome is binary with J = 2 classes that are arbitrarily
coded as 0/1 (e.g., for failure/success). For a continuous feature x with k
unique values, there are at most k − 1 candidate split points to consider.
[Figure: node impurity for a binary outcome under the cross entropy and Gini index measures.]
You may wonder why I’m not considering misclassification error as a measure
of impurity. As it turns out, misclassification error is not a useful impurity
measure for deciding splits; see, for example, Hastie et al. [2009, Section 9.2.3].
However, misclassification error can be a useful measure of the risk associated
with a tree and is used in decision tree pruning (Section 2.5).
c Technically, we should use pj (A) ∝ πj Nj,A /Nj , where πj represents the true pro-
portion of class j in the population of interest (called the prior for class j), but I’ll come
back to this in Section 2.2.4. For now, let’s take πj = Nj /N , the observed proportion of ob-
servations in the learning sample that belong to class j—this assumption is not always valid
(e.g., when the data have been downsampled), but simplifies the formulas in this section,
so I’ll leave the complexities to Section 2.2.4.
Now that we have some notion of node impurity, we can define a measure
for the quality of a particular split. In essence, the quality of an ordered split
S = {x, c} (see Figure 2.3), often called the gain of S, denoted ∆I (S, A),
is defined as the degree to which the two resulting child nodes, AL and AR ,
reduce the impurity of the parent node A:

∆I (S, A) = p (A) I (A) − p (AL ) I (AL ) − p (AR ) I (AR ),    (2.3)

where I (·) denotes the chosen impurity measure (e.g., the Gini index) and p (·) gives the proportion of the learning sample falling in the corresponding node (taking the class priors to be the observed class proportions).
For binary trees, Breiman [1996c] noted that the Gini index tends to prefer
splits that put the most frequent class into one pure node, and the remaining
classes into the other. Both entropy and the twoing splitting rules, on the
other hand, put their emphasis on balancing the class sizes in the two child
nodes. In problems with a small number of classes (i.e., J = 2), the Gini and
entropy criteria tend to produce similar results.
Géron [2019, pp.183–184] echoes similar thoughts to Breiman’s: “So should
you use Gini impurity or entropy? The truth is, most of the time it does not
make a big difference: they lead to similar trees. Gini impurity is slightly faster
to compute, so it is a good default. However, when they differ, Gini impurity
tends to isolate the most frequent class in its own branch of the tree, while
entropy tends to produce slightly more balanced trees.”
Returning to the Swiss banknote example, our goal is to find the first split
condition that “best” separates the genuine banknotes from the counterfeit
ones. Here, we’ll restrict our attention to just two features: top and bottom,
which give the length (in mm) of the top and bottom edge, respectively. (We’re
restricting attention to these two features because, as we’ll see later, diagonal
is too good a predictor and leads to a less interesting illustration of finding
splits.) Since this is a classification problem, we can use cross-entropy or the
Gini index to measure the goodness of each split; here, we’ll use the Gini index
and leave implementing cross-entropy as an exercise for the reader.
A simple R function for computing the Gini index in the two-class case is
given below. This function takes the binary target values as input, which are
assumed to be coded as 0/1 (which corresponds to genuine/counterfeit, in this
example); compare the function below to (2.2).
gini <- function(y) { # y should be coded as 0/1
p <- mean(y) # proportion of successes (or 1s)
2 * p * (1 - p) # Gini index
}
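For example, applying gini() to the full learning sample—a perfect 50/50 split between genuine and counterfeit banknotes—returns the maximum possible impurity for a binary outcome:

gini(bn$y)  # Gini index for the root node (100 genuine, 100 counterfeit)
#> [1] 0.5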
To find the optimal split S = {x, c}, where x is an ordered (but numeric)
feature and c is in its domain, we need to search through every possible value
of c. This can be done, for example, by searching through the midpoints of
the sorted, unique values of x. For each split, we then need to compute the
weighted impurity of the current (or parent) node, as well as the weighted
impurities of the resulting left and right child nodes; then we find which split
point resulted in the largest gain (2.3).
A simple R function, called splits(), for carrying out these steps is given
below. Here, node is a data frame containing the observations in a particular
node (i.e., a subset of the learning sample), while x and y give the column
names in node corresponding to the (ordered numeric) feature of interest
and the (binary or 0/1) target, respectively. The argument n specifies the
number of observations in the learning sample; this is needed to compute
the probabilities p (A), p (AL ), and p (AR ) used in (2.3). The use of drop =
TRUE in the definitions of the variables left and right ensures the results are
coerced to the lowest possible dimension. The drop argument in subsetting
arrays and matrices is used a lot in this book; see ?`[` and ?drop for
additional details.
splits <- function(node, x, y, n) { # y should be coded as 0/1
  xvals <- sort(unique(node[[x]])) # sorted, unique values
  xvals <- xvals[-length(xvals)] + diff(xvals) / 2 # midpoints
  res <- matrix(nrow = length(xvals), ncol = 2) # to store results
  colnames(res) <- c("cutpoint", "gain")
  for (i in seq_along(xvals)) { # loop through each midpoint
    left <- node[node[[x]] >= xvals[i], y, drop = TRUE] # left child
    right <- node[node[[x]] < xvals[i], y, drop = TRUE] # right child
    gain <- nrow(node) / n * gini(node[[y]]) -  # gain; see (2.3)
      length(left) / n * gini(left) - length(right) / n * gini(right)
    res[i, ] <- c(xvals[i], gain)
  }
  res # matrix of cutpoints and associated gains
}
Let’s test this function out on the full data set (i.e., A is the root node) and
find the optimal split point for bottom. To start, we’ll find the gain that is
associated with each possible split point and plot the results:
res <- splits(bn, x = "bottom", y = "y", n = nrow(bn))
head(res, n = 5) # peek at first five rows
#> cutpoint gain
#> [1,] 7.25 0.00761
#> [2,] 7.35 0.01020
#> [3,] 7.45 0.01948
#> [4,] 7.55 0.03045
#> [5,] 7.65 0.03616
plot(res, type = "b", col = 2, las = 1,
xlab = "Split value for bottom edge length (mm)",
ylab = "Gain") # Figure 2.5
FIGURE 2.5: Reduction to the root node impurity as a function of the split
value c for the bottom edge length (mm).
Figure 2.5 shows the gain (or goodness of split) as a function of the split value c.
We can extract the exact cutpoint associated with the largest gain using
res[which.max(res[, "gain"]), ] # extract row with maximum gain
#> cutpoint gain
#> 9.550 0.358
Here, we can see that the optimal split point for bottom is 9.55 mm. A typical
tree algorithm based on an exhaustive search would do this for each feature and
pick the feature (and corresponding cutpoint) with the largest overall gain. Since all the features in banknote
are continuous, we can just apply splits() to each feature to see which
predictor would be used to first split the training data (i.e., the root node).
To make things easier, let’s write a wrapper function that calls splits() for
any number of features, finds the split point associated with the largest gain for
each, and then returns the best predictor/cutpoint pair. This is accomplished
by the find_best_split() function below:
find_best_split <- function(node, x, y, n) {
res <- matrix(nrow = length(x), ncol = 2) # to store output
rownames(res) <- x # set row names to feature names
colnames(res) <- c("cutpoint", "gain") # column names
for (xname in x) { # loop through each feature
# Compute optimal split
cutpoints <- splits(node, x = xname, y = y, n = n)
res[xname, ] <- cutpoints[which.max(cutpoints[, "gain"]), ]
}
res[which.max(res[, "gain"]), , drop = FALSE]
}
Now we’re ready to start recursively partitioning the banknote data set.
The code chunk below uses find_best_split() on the root node (i.e., the
full learning sample) to find the best split between the features top and
bottom:
features <- c("top", "bottom") # feature names
find_best_split(bn, x = features, y = "y", n = nrow(bn))
#> cutpoint gain
#> bottom 9.55 0.358
Using the Gini index, the best way to separate genuine bills from counterfeit
ones, using only the lengths of the top and bottom edges, is to separate the
banknotes according to whether or not bottom >= 9.55 (mm), which parti-
tions the root node (i.e., full learning sample) into two relatively homogeneous
subgroups (or child nodes):
left <- bn[bn$bottom >= 9.55, ] # left child node
right <- bn[bn$bottom < 9.55, ] # right child node
table(left$y) # class distribution in left child node
#>
#> 0 1
#> 2 86
table(right$y) # class distribution in right child node
#>
#> 0 1
#> 98 14
It makes no difference which node we consider the left or right child node; here
I chose them for consistency with the tree diagram from Figure 2.2. Notice how
the left child node is nearly pure, since 86 of the 88 observations (98%) in that
node are counterfeit. While we could try to further partition this node using
another split, it will likely lead to overfitting. The right node, on the other
hand, is less homogeneous, with 14 of the 112 observations being counterfeit,
and could potentially benefit from further splitting, as shown below:
find_best_split(right, x = features, y = "y", n = nrow(bn))
#> cutpoint gain
#> top 11.1 0.082
The next best split used top with a split value of c = 11.15 (mm) and a
corresponding gain of 0.082. The resulting child nodes from this split are
more homogenous but still not pure.
These two splits match the tree structure from Figure 2.2, which was obtained
using actual tree fitting software, but more on that later. Without any stop-
ping criteria defined, the partitioning algorithm could continue splitting until
all terminal nodes are pure (a saturated tree). In Section 2.5, we’ll discuss
how to select an optimal number of splits (e.g., based on cross-validation).
Saturated (or nearly full grown) trees are not generally useful on their own;
however, in Chapter 5, we’ll discuss a simple ensemble technique for improv-
ing the performance of individual trees by aggregating the results from several
hundred (or even thousand) saturated trees.
Fitted values and predictions for new observations are obtained by passing
records down the tree and seeing which terminal nodes they fall in. Recall
that every terminal node in a fitted tree comprises some subset of the original
training instances. If A is a terminal node, then any observation x (train-
ing or new) that lands in A would be assigned to the majority class in A:
arg maxj∈{1,2,...,J} Nj,A ; tie breaking can be handled in a number of ways
(e.g., drawing straws). The predicted probability of x belonging to class j,
which is often of more interest (and more useful) than the classification of x,
is

P̂r (Y = j | x) = pj (A) = Nj,A /NA ,   j = 1, 2, . . . , J.
In the Swiss banknote tree (Figure 2.2; p. 43), any Swiss banknote with bottom
>= 9.55 (mm) would be classified as counterfeit (since the majority of obser-
vations in the corresponding terminal node are counterfeit) with a predicted
probability of 86/(86 + 2) = 0.977; note that the fitted probabilities in Fig-
ure 2.2 have been rounded to two decimal places, which is why they are not
identical to the results we computed by hand in the previous section.
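To make the use of these node summaries concrete, here's a small, purely illustrative function (not from the book's code) that hard-codes the two splits from Figure 2.2—using the unrounded split points 9.55 and 11.15—and tabulates its classifications against the observed labels:

# Hard-coded classifier based on the tree in Figure 2.2 (illustrative only)
classify_banknote <- function(bottom, top) {
  ifelse(bottom >= 9.55, "Counterfeit",
         ifelse(top >= 11.15, "Counterfeit", "Genuine"))
}
table("predicted" = classify_banknote(bn$bottom, bn$top), "observed" = bn$y)
#>              observed
#> predicted      0  1
#>   Counterfeit  6 99
#>   Genuine     94  1

The counts agree with the terminal node summaries reported by the fitted rpart tree later in Section 2.9.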
In summary, terminal nodes in a CART-like tree are summarized by a single
statistic (or sometimes multiple statistics, like the individual class proportions
for J-class classification), which is then used to obtain fitted values and
predictions—all observations that land in the same terminal
node also receive the same prediction. In classification trees, terminal nodes
can be summarized by the majority class or the individual class proportions
which are then used to generate classifications or predicted class probabili-
ties for each of the J classes, respectively. Similarly, the terminal nodes in a
CART-like regression tree (Section 2.3) can be summarized by the mean or
median response, typically the former.
d For more than two classes (i.e., J > 2), a plurality vote is used.
Frank Harrell
https://www.fharrell.com/post/classification/
Fortunately, CART can flexibly handle imbalanced class labels without chang-
ing the learning sample. At a high level, we can assign specific unequal losses
or penalties on a one-by-one basis to each type of misclassification error; in
binary classification, there are two types of misclassification errors we can
make: misclassify a 0 as a 1 (a false positive) or misclassify a 1 as a 0 (a
false negative). The CART algorithm can account for these unequal losses or
misclassification costs when deciding on splits and making predictions. Unfor-
tunately, it seems that many practitioners are either unaware, or fail to take
advantage of this feature.
Our discussion of splitting nodes in Section 2.2.1 implicitly made several
assumptions about the available data. For instance, estimating pj (A) with
Nj,A /NA , the proportion of observations in node A that belong to class j, as-
sumes the training data are a random sample from some population
of interest. In particular, it assumes that the true prior probability of observ-
ing class j, denoted πj , can be estimated with the observed proportion of class
j observations in the training data; that is, πj ≈ Nj /N . If the observed class
proportions are off (e.g., the data have been downsampled or the minority
class has intentionally been over-sampled to over-represent rare cases), then
Nj,A /NA is no longer a reasonable estimate of pj (A). Instead, we should be
using pj (A) ∝ πj Nj,A /Nj , where we scale the pj (A), j = 1, 2, . . . , J, to sum to one. Note
that if we take πj to be the observed class proportions, then πj = Nj /N and
this reduces to the usual estimate, Nj,A /NA .
Let L be a J × J loss matrix with entries Li,j representing the loss (or cost)
associated with misclassifying an i as a j. We can define the risk of a node A
as
r (A) = Σ_{j=1}^{J} pj (A) × Lj,τA ,    (2.4)
where τA is the class assigned to A, if A were a terminal node, such that this
risk is minimized. Since pj (A) depends on the prior class probabilities, risk is
a function of both misclassification costs and class priors.
As a consequence, we can take misclassification costs into account by absorbing
them into the priors for each class; this is referred to as the altered priors
method. In particular, if
Li,j = Li  for i ≠ j,  and  Li,j = 0  for i = j,
then we can use the prior approach discussed above with the priors altered
according to
π̃i = πi Li / Σ_{j=1}^{J} πj Lj ,    (2.5)
To illustrate, let’s walk through a detailed example using the employee at-
trition data set (Section 1.4.6). Figure 2.6 displays two classification trees
fit to the employee attrition data, each with a max depth of two.e The
only difference is that the tree on the left used the observed class priors
πno = 1233/1470 = 0.839 and πyes = 237/1470 = 0.161 (i.e., it treats both
types of misclassifications as equal). The tree on the right used altered priors
based on the following loss (or misclassification cost) matrix:
e The depth of a decision tree is the maximum of the number of edges from the root node
to each terminal node and is a common tuning parameter; see Section 8.3.2.
            No   Yes
L =   No     0     1
      Yes    8     0   ,
where the rows represent the true class and the columns represent the pre-
dicted class. For example, we’re saying that it is 8 times more costly to mis-
classify a Yes (employee will leave due to attrition) as a No (employee will not
leave due to attrition) than it is to misclassify a No as a Yes. Using this loss
matrix, we can compute the altered priors as follows:
Multiplying each prior by its associated misclassification cost gives π̃no ∝ πno Lno = 0.839 × 1 = 0.839
and π̃yes ∝ πyes Lyes = 0.161 × 8 = 1.289. Rescaling so that π̃no + π̃yes = 1 gives π̃no = 0.394 and π̃yes = 0.606. Notice
how altering the priors resulted in a tree with different splits and node
summaries.
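As a quick numerical check (this code isn't part of the original analysis), we can compute the altered priors directly in R from the observed class frequencies and the loss matrix above:

# Altered priors for the attrition example; see (2.5)
priors <- c(No = 1233, Yes = 237) / 1470  # observed class priors
loss <- c(No = 1, Yes = 8)  # cost of misclassifying each (true) class
round(priors * loss / sum(priors * loss), digits = 3)
#>    No   Yes
#> 0.394 0.606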
The confusion matrix from each tree applied to the learning sample is shown
in Table 2.2. Altering the priors by specifying a higher cost for misclassifying
the Yeses increased the number of true negatives (assuming No represents the
positive class) from 48 to 233, albeit at the expense of decreasing the number
of true positives from 1212 to 163. Finding the right balance is application-
specific and requires a lot of thought and collaboration with subject matter
experts.
TABLE 2.2: Confusion matrices (learning sample) for the two trees in Figure 2.6.

                                Observed class
                        Default priors      Altered priors
                          No      Yes         No      Yes
  Predicted class   No  1212      189        163        4
                   Yes    21       48       1070      233
The tree structure on the left of Figure 2.6 uses the same calculations we
worked through for the Swiss banknote example, so let’s walk through some
of the calculations for the tree on the right.
In any particular node A, we estimate pno (A) ∝ π̃no × Nno,A /Nno and
pyes (A) ∝ π̃yes × Nyes,A /Nyes , which are rescaled to sum to one. For instance,
if A is the root node, we have pno (A) = π̃no = 0.394 since Nno,A /Nno =
1233/1233 = 1. Similarly, pyes (A) = 0.606. We can then calculate the impurity
of the root node using the Gini index: 2 × 0.394 × 0.606 ≈ 0.478.
[FIGURE 2.6 (structure): both trees first split on OverTime = Yes. Left tree (default priors): root is No (.84 .16; 100%); the OverTime = Yes branch (No; .69 .31; 28%) splits on MonthlyIncome < 2475. Right tree (altered priors): root is Yes (.39 .61; 100%); the OverTime = No branch (No; .52 .48; 72%) splits on MonthlyIncome < 11e+3, and the OverTime = Yes branch (Yes; .22 .78; 28%) is terminal.]
FIGURE 2.6: Decision trees for the employee attrition example. Left: default
(i.e., observed) class priors. Right: altered class priors.
If we split the data according to OverTime = Yes (right branch) vs. OverTime
= No (left branch), we have the following:
left <- attrition[attrition$OverTime == "No", ] # left child
right <- attrition[attrition$OverTime == "Yes", ] # right child
table(attrition$Attrition) # class frequencies
#>
#> No Yes
#> 1233 237
table(left$Attrition) # class frequencies (left node)
#>
#> No Yes
#> 944 110
table(right$Attrition) # class frequencies (right node)
#>
#> No Yes
#> 289 127
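Plugging these frequencies into pno (A) ∝ π̃no × Nno,A /Nno and pyes (A) ∝ π̃yes × Nyes,A /Nyes (and rescaling) reproduces the node summaries shown in the right-hand tree of Figure 2.6; the little helper below is my own sketch, not code from the book:

# Rescaled class proportions under the altered priors (sanity check)
altered_probs <- function(n.no, n.yes, priors = c(no = 0.394, yes = 0.606),
                          N = c(no = 1233, yes = 237)) {
  p <- priors * c(n.no, n.yes) / N  # p_j(A) proportional to prior x N_jA / N_j
  p / sum(p)  # rescale to sum to one
}
round(altered_probs(944, 110), digits = 2)  # OverTime = No (left child)
#>   no  yes
#> 0.52 0.48
round(altered_probs(289, 127), digits = 2)  # OverTime = Yes (right child)
#>   no  yes
#> 0.22 0.78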
In Section 2.9.4, we’ll verify these calculations using open source tree software
that follows the same CART-like procedure for altered priors.
This wraps our discussion of CART’s search for the best split for an ordered
variable in classification trees. Before discussing the search for splits on cat-
egorical features, I’ll introduce the concept of a regression tree; that is, a
decision tree with a continuous outcome.
SSEA = SSEAL + SSEAR + (NAL NAR /NA ) (ȳL − ȳR )² ,    (2.7)
where ȳL and ȳR give the sample mean for the left and right child nodes of A,
respectively. This implies that maximizing (2.6) is equivalent to maximizing
the last term in (2.7), which makes sense, since we want the child nodes to be
as different as possible (i.e., a greater difference in the mean responses).
In the regression case, we don’t have to worry about priors or node probabil-
ities. The terminal nodes are summarized by the mean response in each (the
sample median is another possibility), and these are used for producing fitted
values and predictions. For example, if a new observation x were to occupy
some terminal node A, then f̂ (x) = Σ_{i=1}^{NA} yi,A /NA , where yi,A denotes
the i-th response value from the learning sample that resides in terminal node
A.
Aside from being useful in their own right, regression trees, as presented here,
serve as the basic building blocks for gradient tree boosting (Chapter 8), one
of the most powerful tree-based ensemble algorithms available.
Consider, for example, the airquality data frame introduced in Section 1.4.2,
which contains daily air quality measurements in New York from May to
September of 1973. A regression tree with a single split was fit to the data
and the corresponding tree diagram is displayed in the left side of Figure 2.7.
Here, the chosen splitter was temperature (in degrees Fahrenheit). Each node
displays the predicted ozone concentration for all observations that fall in that
node (top number) as well as the proportion of training observations in each
(bottom number). According to this tree, the predicted ozone concentration
is given by the simple rule:
predicted Ozone = 26.544 if Temp < 82.5, and 75.405 if Temp >= 82.5.
The estimated regression surface is plotted in the right side of Figure 2.7.
Note that the estimated prediction surface from a regression tree is essentially
a step function, which makes it hard for decision trees to capture arbitrarily
smooth or linear response surfaces.
[FIGURE 2.7 (structure): Left: regression tree with a single split, Temp >= 83, and terminal node predictions of 27 and 75. Right: the corresponding step-function fit of Ozone versus Temp.]
To manually find the first partition and reconstruct the tree in Figure 2.7, we'll
start by creating a simple function to calculate the within-node SSE. Note that
these data contain a few missing valuesf (or NAs in R), so I set na.rm = TRUE
in order to remove them before computing the results.
sse <- function(y, na.rm = TRUE) {
sum((y - mean(y, na.rm = na.rm)) ^ 2, na.rm = na.rm)
}
Next, I’ll modify the splits() function from Section 2.2.2 to work for the
regression case:
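(The book's listing for this function falls in a gap here, so what follows is a minimal sketch of splits.sse() that mirrors splits() but measures the gain as the reduction in SSE; the name and arguments match the calls used below, though minor details—such as the handling of missing predictor values—may differ from the original.)

splits.sse <- function(node, x, y) { # gain = reduction in SSE
  xvals <- sort(unique(node[[x]])) # sorted, unique values
  xvals <- xvals[-length(xvals)] + diff(xvals) / 2 # midpoints
  res <- matrix(nrow = length(xvals), ncol = 2) # to store results
  colnames(res) <- c("cutpoint", "gain")
  for (i in seq_along(xvals)) { # loop through each midpoint
    left <- node[node[[x]] >= xvals[i], y, drop = TRUE] # left child
    right <- node[node[[x]] < xvals[i], y, drop = TRUE] # right child
    res[i, ] <- c(xvals[i], sse(node[[y]]) - sse(left) - sse(right))
  }
  res
}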
f CARTis actually pretty clever in how it handles missing values in the predictors, but
more on this in Section 2.7.
Before applying this function to the air quality data, I’ll remove the 37 rows
that have a missing response value. The possible split points for Temp, along
with their associated gains, are displayed in Figure 2.8. (To make the y-axis
look nicer on the plot, the gain values were divided by 1,000.)
# Find optimal split for `Temp`
aq <- airquality[!is.na(airquality$Ozone), ]
res <- splits.sse(aq, x = "Temp", y = "Ozone")
res[which.max(res[, "gain"]), ]
#> cutpoint gain
#> 82.5 60158.5
# Plot results
res[, "gain"] <- res[, "gain"] / 1000 # rescale for plotting
plot(res, type = "b", col = 2, las = 1,
xlab = "Temperature split value (degrees Fahrenheit)",
ylab = "Gain/1000")
abline(v = 82.5, lty = 2, col = 2)
To show that temperature is the best primary splitter for the root node, we
can use sapply() to find the optimal cutpoint for all five features:
features <- c("Solar.R", "Wind", "Temp", "Month", "Day")
sapply(features, FUN = function(xname) {
res <- splits.sse(aq, x = xname, y = "Ozone")
res[which.max(res[, "gain"]), ]
})
#> Solar.R Wind Temp Month Day
#> cutpoint 153 6.6 82.5 6.5 24.5
#> gain 29721 50591.2 60158.5 14511.3 10282.8
Clearly, the split associated with the largest gain is Temp, followed by Wind,
Solar.R, Month, and Day.
A regression tree in one predictor produces a step function, as was seen in the
right side of Figure 2.7. The same idea extends to higher dimensions as well.
FIGURE 2.8: Gain as a function of the potential split points for temperature. The
maximum gain occurs at a temperature of 82.5°F (the dashed vertical line).
For example, suppose we considered splitting on Wind next. Using the same
procedures previously described, we would find that the next best partition
occurs in the left child node using Wind with a cutpoint of 7.15 (mph). The
corresponding tree diagram is displayed on the left side of Figure 2.9. If we stop
splitting here, the result is a regression tree in two features. The corresponding
prediction function, displayed on the right side of Figure 2.9, is a surface that’s
constant over each terminal node.
Up to this point, we’ve only considered splits for ordered predictors, which
have the form x < c vs. x ≥ c, where c is in the domain of x. But what about
splits involving nominal categorical features? If x is ordinal (i.e., an ordered
category, like low < medium < high), then we can map its ordered categories
to the integers 1, 2, . . . , J, where J is the number of unique categories, and
split as if x were originally numeric. If x is nominal (i.e., the order of the
categories has no meaning), then we have to consider all possible ways to split
FIGURE 2.9: Regression tree diagram (left) and corresponding regression sur-
face (right) for the air quality data. These are the same splits shown in Fig-
ure 2.1.
x into two mutually disjoint groups. For example, if x took on the categories
{a, b, c}, then we could form a total of three splits:
• x ∈ {a} vs. x ∈ {b, c};
• x ∈ {b} vs. x ∈ {a, c};
• x ∈ {c} vs. x ∈ {a, b}.
For a nominal predictor with J categories, there are a total of 2J−1 − 1 po-
tential splits to search through, which can be computationally prohibitive for
large J; for J ≥ 21, we’d have to search more than a million splits! Fortu-
nately, for ordered or binary outcomes, there is a computational shortcut that
can be exploited for the splitting rules discussed in this chapter (i.e., Gini in-
dex, entropy, and SSE). This is discussed, for example, in Hastie et al. [2009,
Sec. 9.2.4] and the “User Written Split Functions” vignette in package rpart
(use vignette("usercode", package = "rpart") at the R console).
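A quick sanity check of that count:

2^(c(3, 21) - 1) - 1  # number of possible splits for J = 3 and J = 21
#> [1]       3 1048575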
In short, the optimal split for a nominal predictor x at some node A can be
found by first ordering the individual categories of x by their average response
value—for example, the proportion of successes in the binary outcome case—
and then finding the best split using this new ordinal variable.g This reduces
g This is equivalent to performing mean/target encoding [Micci-Barreca, 2001] prior to
searching for the best split at each node; see Section 2.4.3.
[FIGURE 2.10 (structure): node 1 (Edible; .52 .48; 100%) splits on odor = creosote,fishy,foul,musty,pungent,spicy; node 3 (Poison; .00 1.00; 47%) is terminal; node 2 (Edible; .97 .03; 53%) splits on spore.print.color = green into terminal nodes 4 (Edible; .99 .01; 52%) and 5 (Poison; .00 1.00; 1%).]
In the code chunk below, I'll re-encode the binary outcome, Edibility, from
Edible/Poison to 0/1. I'll also remove the veil.type feature because it only takes
on a single value (i.e., it has zero variance) and can contribute nothing to the
partitioning:
m <- treemisc::mushroom # load mushroom data
m$veil.type <- NULL # remove useless feature
m$Edibility <- ifelse(m$Edibility == "Poison", 1, 0)
m2 <- m # make a copy of the original data
To illustrate the main idea, let’s look at a frequency table for the veil.color
predictor, which has four unique categories:
table(m2$veil.color)
#>
#> brown orange white yellow
#> 96 96 7924 8
We need to find the mean response within each category—in this case, the
proportion of poisonous mushrooms—and then map those back to the origi-
nal feature values. For instance we would re-encode all the values of "white"
in veil.color as 0.493 because 3908/7924 ≈ 0.493 of the mushrooms with
veil.color = "white" are poisonous. This can be done in any number of
ways, and here I’ll write a simple function, called ordinalize(), that returns
a list with two components: mapping, which contains the numeric value each cat-
egory gets mapped to, and encoded, which contains the re-encoded feature
values.
ordinalize <- function(x, y) { # convert nominal to ordered
map <- tapply(y, INDEX = x, FUN = mean)
list("mapping" = map, "encoded" = map[x])
}
Next, I’ll write a simple for loop that uses ordinalize() to numerically
re-encode each feature column in the m2 data frame:
xnames <- setdiff(names(m2), "Edibility")
for (xname in xnames) { # mean/target encode each feature
m2[[xname]] <- ordinalize(m2[[xname]], y = m2[["Edibility"]])$encoded
}
Since all the categorical features have been re-encoded numerically, we can
use our previously defined find_best_split() function to partition the data.
Starting with the root node (i.e., the full learning sample), we obtain:
find_best_split(m2, x = xnames, y = "Edibility", n = nrow(m2))
#> cutpoint gain
#> odor 0.517 0.471
# Summarize split
left <- m[m2$odor >= 0.5170068, ]
right <- m[m2$odor < 0.5170068, ]
The first split uses odor, with a mean/target encoded split point of c = 0.517
and a corresponding gain of 0.471. Since the left child node defined above (i.e.,
the rows with an encoded odor value >= 0.517) is pure (in this case, all poisonous),
let's continue partitioning with the right one:
# Ordinalize right child node and find next best split
right.ord <- right
for (xname in xnames) { # mean/target encode each feature
right.ord[[xname]] <-
ordinalize(right.ord[[xname]],
y = right.ord[["Edibility"]])$encoded
}
For example, the split point for odor was 0.517 (the midpoint between 0.034
and 1.00), and every odor category whose re-encoded value is >= 0.517 is sent to the
same side of the first partition; see the first split in Figure 2.10. In Section 2.9.2,
we’ll verify these results (e.g., the computed gain for both splits) using CART-
like software in R.
One drawback of CART-like decision trees is that they tend to favor cate-
gorical features with high cardinality (i.e., large J), even if they are mostly
irrelevant.h For categorical features with large J, for example, there are so
many potential splits that the tree is more likely to find a good split just by
chance. Think about the extreme case where a nominal feature x is different
and unique in every row of the learning sample, like a row ID column. The
split variable selection bias in CART-like decision trees has been discussed
plenty in the literature; see, for example, Breiman et al. [1984, p. 42], Segal
[1988], and Hothorn et al. [2006c] (and the additional references therein).
To illustrate the issue, I added ten random categorical features (cat1–cat10)
to the airquality data set from Section 2.3.1, each with a cardinality of J =
26 (they’re just random letters from the alphabet). A default regression tree
was fit to the data using rpart, and the resulting tree diagram is displayed in
Figure 2.11. Notice that all of the splits, aside from the first, use the completely
irrelevant categorical features that were added! In Section 2.5 we’ll look at a
h This bias actually extends to any predictor with lots of potential split points, whether
ordered or nominal.
general pruning technique that can be helpful in screening out pure noise
variables.
FIGURE 2.11: A decision tree fit to a copy of the air quality data set that
includes ten completely random categorical features, each with cardinality 26.
When dealing with categorical data, we are often concerned with how to en-
code such features. In linear models, for example, we often employ dummy
encoding or effect encoding, depending on the task at hand. Similarly, one-hot-
encoding (OHE), closely related to dummy encoding, is often used in general
machine learning problems outside of (generalized) linear models. And there
are plenty of other ways to encode categorical variables, depending on the
algorithm and task at hand.
As you’ve already seen, decision trees can naturally handle variables of any
type without special encoding, although we did see that a local form of
mean/target encoding can be used to reduce the computational burden im-
posed by nominal categorical splits. Nonetheless, using an encoding strategy,
like OHE, can sometimes improve the predictive performance or interpretabil-
ity of a tree-based model; see Kuhn and Johnson [2013, Sec. 14.7] for a brief
discussion on the use of OHE in tree-based methods. Further, some tree-based
software, like scikit-learn's sklearn.tree module, requires all features to be
numeric—forcing users to employ different encoding schemes for categorical
features. See Boehmke and Greenwell [2020, Chap. 3] for details on different
encoding strategies (with examples in R), and further references.
In the previous sections, we talked about the basics of splitting a node (i.e.,
partitioning some subset of the learning sample). Building a CART-like deci-
sion tree starts by splitting the root node, and then recursively applying the
same splitting procedure to every resulting child node until a saturated tree is
obtained (i.e., all terminal nodes are pure) or other stopping criteria are met.
In essence, the partitioning stops when at least one of the following conditions
are met:
• all the terminal nodes are pure;
• the specified maximum tree depth has been reached;
• the minimum number of observations that must exist in a node in order
for a split to be attempted has been reached;
• no further splits are able to decrease the overall lack of fit by a specified
factor;
• and so forth.
This often results in an overly complex tree structure that overfits the learning
sample; that is, it has low bias, but high variance.
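In rpart, for instance, these stopping criteria map onto arguments of rpart.control(); the sketch below just spells out a few of the relevant arguments together with their default values (see ?rpart.control for the full list):

library(rpart)
ctrl <- rpart.control(
  minsplit = 20,  # min. observations in a node for a split to be attempted
  maxdepth = 30,  # maximum depth of any node in the final tree
  cp = 0.01  # a split must decrease the overall lack of fit by this factor
)
aq.tree <- rpart(Ozone ~ ., data = airquality, control = ctrl)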
To illustrate, consider a random sample of size N = 500, generated from the
following sine wave with Gaussian noise:
Y = sin (X) + ε,
where X ∼ U (0, 2π) and ε ∼ N (0, σ = 0.3). A scatterplot of the data, along
with the true response function, is shown in Figure 2.12.
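Such a sample can be generated with a few lines of R; the seed below is arbitrary, so the simulated values won't match the book's figures exactly:

set.seed(1117)  # arbitrary seed for reproducibility
x <- runif(500, min = 0, max = 2 * pi)  # X ~ U(0, 2*pi)
y <- sin(x) + rnorm(500, mean = 0, sd = 0.3)  # Y = sin(X) + noise
sine <- data.frame("x" = x, "y" = y)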
FIGURE 2.12: Data generated from a simple sine wave with Gaussian noise.
The black curve shows the true mean response E (Y |X = x) = sin (x).
Figure 2.13 shows the prediction function from two regression trees fit to the
same data.i The tree on the left is too complex and has too many splits, and
exhibits high variance, but low bias (i.e., it fits the current sample well, but
the tree structure will vary wildly from one sample to the next because it’s
mostly fitting the noise here); unstable models like this one are often referred
to as unstable learners (more on this in Section 5.1). The tree on the right,
which is a simple decision stump (i.e., a tree with only a single split), is too
simple, and will also not be useful for prediction because it has extremely
high bias, but low variance (i.e., it doesn’t fit the data too well, but the tree
structure will be more stable from sample to sample); such a weak performing
model is often referred to as a weak learner (more on this in Section 5.2).
Neither tree is likely to be accurate when applied to a different sample from
the same model; the ensemble methods discussed in Part II of this book can
improve the performance of both weak and unstable learners. When using a
single decision tree, however, the question we need to answer is, How complex
should we make the tree? Ideally, we should have stopped splitting nodes at
some subtree along the way, but where?
A rather careless approach is to build a tree by only splitting nodes that meet
some threshold on prediction error. However, this is shortsighted because a
low-quality split early on may lead to a very good split later in the tree. The
standard approach to finding an optimal subtree—basically, determining when
i The associated tree diagrams are shown in the top left and bottom right of Figure 2.14.
FIGURE 2.13: Regression trees applied to the sine wave example. Left: this
tree is too complex (i.e., low bias and high variance). Right: this tree is too
simple (i.e., high bias and low variance).
to stop splitting—is cost-complexity pruning (also known as weakest-link pruning), which works with the cost-complexity criterion

Rα (T ) = R (T ) + α|T |,

where R (T ) is the overall risk of tree T (e.g., its misclassification error), |T | is the number of terminal nodes, and α ≥ 0 is a complexity parameter that penalizes larger trees.
FIGURE 2.14: Nested subtrees for the sine wave example. The optimal sub-
tree, chosen via 10-fold cross-validation, is highlighted in green.
For any internal node A of T , let TA denote the branch of T rooted at A, and define

α_A = (R (A) − R (TA )) / (|TA | − 1).

The weakest link is the internal node with the smallest value of α_A, and the first threshold, α1 , is that smallest value; we prune off (collapse) every branch TA for which

α1 ≥ (R (A) − R (TA )) / (|TA | − 1).
This results in the optimal subtree, Tα1 , associated with α = α1 . Starting with
Tα1 , we then continue this process by finding α2 in the same way we found
α1 for the full tree T . The process is continued until reaching the root node.
It might sound confusing, but we’ll walk through the calculations using the
mushroom example in the next section.
The rpart package, which is used extensively throughout this chapter, em-
ploys a slightly friendlier, and rescaled, version of the cost-complexity param-
eter α, which they denote as cp. Specifically, rpart uses
FIGURE 2.15: Relative error based on the test set (black curve) and 10-
fold cross-validation (yellow curve) vs. the number of splits for the sine wave
example. The vertical yellow line shows the optimal number of splits based
on 10-fold cross-validation, while the vertical black line shows the optimal
number of splits based on the independent test set.
Rcp (T ) ≡ R (T ) + cp × |T | × R (T1 ) ,
where T1 is the tree with zero splits (i.e., the root node). Compared to α, cp
is unitless, and a value of cp = 1 will always result in a tree with zero splits.
The complexity parameter, cp, can also be used as a stopping rule during
tree construction. In many open source implementations of CART, whenever
cp > 0, any split that does not decrease the overall lack of fit by a factor of cp
is not attempted. In a regression tree, for instance, this means that the overall
R2 must increase by cp at each step for a split to occur. The main idea is to
reduce computation time by avoiding potentially unworthy splits. However,
this runs the risk of not finding potentially much better splits further down
the tree.
Let's drive the main ideas home by calculating a few α values to prune a
simple tree for the mushroom edibility data. Consider the simple decision
tree displayed in Figure 2.16.
This is a simple tree with only three splits, but we’ll use it to illustrate how
pruning works and how the sequence of α values is computed. For clarity, the
number of observations in each class is displayed within each node, and the
node numbers appear at the top of each node. For example, node 8 contains
4208 edible mushrooms and 24 poisonous ones. The assigned classification,
or majority class, is printed above the class frequencies in each node. This
tree was also built using the observed class priors and equal misclassification
costs; hence, R (T ) is just the proportion of misclassifications in the learning
sample: 24/8124 ≈ 0.003.
Let Ai , i ∈ {1, 2, 3, 4, 5, 8, 9}, denote the seven nodes of the tree in Figure 2.16;
in rpart, the left and right child nodes for any node numbered x are always
numbered 2x and 2x + 1, respectively (the root node always corresponds to x =
1). We can compute the risk of any terminal node as the proportion of misclassified
observations it contains, R (Ai ) = Nj,Ai /NAi , where j indexes the class not assigned to Ai . For
example, the pure terminal nodes A3 , A5 , and A9 all have a risk of zero.
[FIGURE 2.16 (structure): node 1 (Edible; 4208 3916; 100%) splits on odor = creosote,fishy,foul,musty,pungent,spicy; node 3 (Poison; 0 3796; 47%) is terminal; node 2 (Edible; 4208 120; 53%) splits on spore.print.color = green; node 5 (Poison; 0 72; 1%) is terminal; node 4 (Edible; 4208 48; 52%) splits on stalk.color.below.ring = yellow into terminal nodes 8 (Edible; 4208 24; 52%) and 9 (Poison; 0 24; 0%).]
FIGURE 2.16: Classification tree with three splits for the mushroom edibility
data. The overall risk of the tree is 24/8124 ≈ 0.003.
Since αA4 is the smallest, we collapse node A4 , resulting in the next optimal
subtree in the sequence, Tα1 , which is displayed in the left side of Figure 2.17.
The cost-complexity of this tree is Rα1 (Tα1 ) = 0.015. To find α2 , we start with
Tα1 and repeat the process by first finding the smallest α value associated with
the |Tα1 | − 1 = 2 internal nodes of Tα1 (i.e., nodes A1 and A2 ).
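As a back-of-the-envelope check (using the class counts from Figure 2.16, not the book's original derivation), αA4 and the cost-complexity of the pruned tree can be reproduced as follows:

N <- 8124  # total number of mushrooms
R.A4 <- 48 / N  # contribution to the overall risk if A4 were a terminal node
R.T.A4 <- 24 / N  # risk of the branch rooted at A4 (terminal nodes 8 and 9)
(alpha1 <- (R.A4 - R.T.A4) / (2 - 1))  # branch T_A4 has two terminal nodes
#> [1] 0.00295421
48 / N + 3 * alpha1  # cost-complexity of T_alpha1 (three terminal nodes)
#> [1] 0.01477105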
[FIGURE 2.17: The next subtrees in the pruning sequence for the mushroom edibility data. Left: Tα1 , obtained by collapsing node A4 , with a root split on odor and a second split on spore.print.color (terminal nodes 4, 5, and 3). Right: the subtree obtained by also collapsing node A2 , leaving only the root split on odor (terminal nodes 2 and 3).]
2.5.2 Cross-validation
Once the sequence α1 , α2 , ..., αk−1 has been found, we still need to estimate the
overall risk/quality of the corresponding sequence of nested subtrees, Rαi (T ),
for i = 1, 2, . . . , k − 1. Breiman et al. [1984, Chap. 11] suggested picking α
using a separate validation set or k-fold cross-validation. The latter is more
computationally demanding, but tends to be preferred since it makes use of all available
data, and both tend to lead to similar results. The procedure described in
Algorithm 2.1 below follows the implementation in the rpart package in R
(see the “Introduction to Rpart” vignette):
1) Fit the full model to the learning sample to obtain α1 , α2 , ..., αk−1 .
2) Define βi according to
βi = 0,             i = 1
βi = √(αi−1 αi ),   i = 2, 3, . . . , m − 1
βi = ∞,             i = m
Since any value of α in the interval (αi , αi+1 ] results in the same subtree,
we instead consider the sequence of βi ’s, which represent typical values
within each range using the geometric midpoint.
3) Randomly partition the learning sample into k non-overlapping subsets (folds), D1 , D2 , . . . , Dk .
4) For each fold i = 1, 2, . . . , k:
a) Fit the full model to the learning sample, but omit the subset Di ,
and find the sequence of optimal subtrees Tβ1 , Tβ2 , . . . , Tβk .
b) Compute the prediction error from each tree on the validation set Di .
5) Return Tβ from the initial sequence of trees based on the full learning sam-
ple, where β corresponds to the βi associated with the smallest prediction
error in step 4).
When choosing α with k-fold cross-validation, Breiman et al. [1984, Sec. 3.4.3]
recommend using the 1-SE rule, and argue that it is useful in screening out
irrelevant features. The 1-SE rule suggests using the most parsimonious tree
(i.e., the one with fewest splits) whose cross-validation error is no more than
one standard error above the cross-validation error of the best model. This
of course requires an estimate of the standard error during cross-validation.
A heuristic estimate of the standard error can be found in Breiman et al.
[1984, pp. 306–309] or Zhang and Singer [2010, pp. 42–43], but the formula
isn’t pretty! Applying cost-complexity pruning using cross-validation, with
or without the 1-SE rule, would almost surely remove all of the nonsensical
splits seen in Figure 2.11. (In fact, this was the case after applying 10-fold
cross-validation using the 1-SE rule.)
One of the best features of CART is the flexibility with which missing val-
ues can be handled. More traditional statistical models, like linear or logistic
regression, will often discard any observations with missing values. CART,
through the use of surrogate splits, can utilize all observations that have non-
missing response values and at least one non-missing value for the predictors.
Surrogate splits are essentially splits using other available features with non-
missing values. The basic idea, which is fully described in Breiman et al. [1984,
Sec. 5.3], is to estimate (or impute) the missing data point using the other
available features.
Consider the decision stump in Figure 2.18, which corresponds to the optimal
tree for the Swiss banknote data when using all available features.
What if we wanted to classify a new observation which had a missing value for
diagonal? The surrogate approach finds surrogate variables for the missing
splitter by building decision stumps, one for each of the other features (in this
case, length, left, right, bottom, and top), to predict the binary response,
denoted below by y⋆ , formed by the original split:

y⋆ = 0 if diagonal ≥ 140.65, and y⋆ = 1 if diagonal < 140.65.
[FIGURE 2.18 (structure): node 1 (.50 .50; 100%) splits on diagonal < 141 into two nearly pure terminal nodes: node 2 (1.00 .00; 49%) and node 3 (.02 .98; 51%).]
For each feature, the optimal split is chosen using the procedure described in
Section 2.2.1. (Note that when looking for surrogates, we do not bother to
incorporate priors or losses since none are defined for y⋆ .) In addition to the
optimal split for each feature, we also consider the majority rule, which just
uses the majority class. Once the surrogates have been determined, they’re
ranked in terms of misclassification error, and any surrogate that does worse
than the majority class is discarded. Some implementations, like R’s rpart
package, further require surrogate splits to send at least two observations to
each of the left and right child nodes.
Returning to the Swiss banknote example, let’s find the surrogate splits for the
primary split on diagonal depicted in Figure 2.18. We can find the surrogate
splits using the same splitting process as before, albeit with our new target
variable y⋆ :
bn2 <- treemisc::banknote # load Swiss banknote data
bn2$y <- ifelse(bn2$diagonal >= 140.65, 1, 0) # new target
bn2$diagonal <- NULL # remove column
features <- c("length", "left", "right", "bottom", "top")
res <- sapply(features, FUN = function(feature) {
find_best_split(bn2, x = feature, y = "y", n = nrow(bn2))
})
rownames(res) <- c("cutpoint", "gain")
res[, order(res["gain", ], decreasing = TRUE)]
#> bottom right top left length
#> cutpoint 9.550 129.850 10.950 130.050 215.1500
#> gain 0.343 0.169 0.157 0.137 0.0344
Aside from being able to handle missing predictor values directly, classification
trees can be extremely useful in examining patterns of missing data [Harrell,
2015, Sec. 3.2]. For example, CART can be used to describe observations
that tend to have missing values (a description problem). This can be done
by growing a classification tree using a target variable that’s just a binary
indicator for whether or not a variable of interest is missing; see Harrell [2015,
pp. 302–304] for an example using real data in R.
It can also be informative to construct missing value indicators for each pre-
dictor under consideration. Imagine, for example, that you work for a bank
and that part of your job is to help determine who should be denied for a
loan and who should not. A missing credit score on a particular loan appli-
cation might be an obvious red flag, and indicative of somebody with a bad
credit history, hence, an important indicator in determining whether or not
to approve them for a loan. A similar strategy for categorical variables is to
treat missing values as an actual category. As noted in van Buuren [2018,
Sec. 1.3.7], the missing value indicator method may have its uses in particular
situations but fails as a generic method to handle missing data (e.g., it does
not allow for missing data in the response and can lead to biased regression
estimates across a wide range of scenarios).
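For instance, a missing value indicator can be constructed with a single line of code; the data frame and column names below are hypothetical (and the toy data are made up for illustration):

loans <- data.frame(credit.score = c(720, NA, 680, NA))  # toy data
loans$missing.credit.score <- ifelse(is.na(loans$credit.score), 1, 0)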
Imputation—filling in missing values with a reasonable guess—is another com-
mon strategy, and trees make great candidates for imputation models (e.g.,
they’re fully nonparametric and naturally support both classification and re-
gression).
Using CART for the purpose of missing value imputation has been suggested
by several authors; see van Buuren [2018, Sec. 3.5] for details and several
references. A generally useful approach is to use CART to generate multi-
ple imputations [van Buuren, 2018, Sec. 3.5] via the bootstrap methodj (see
Davison and Hinkley [1997] for an overview of different bootstrap methods);
multiple imputation is now widely accepted as one of the best general methods
for dealing with incomplete data [van Buuren, 2018, Sec. 2.1.2].
The basic steps are outlined in Algorithm 2.2; also see ?mice::cart for de-
tails on its implementation in the mice package [van Buuren and Groothuis-
Oudshoorn, 2021]. Here, it is assumed that the response y corresponds to the
predictor with incomplete observations (i.e., contains missing values) and that
the predictors correspond to the original predictors with complete information
(i.e., no missing values).
As described in Doove et al. [2014] and van Buuren [2018, Sec. 3.5], this process
can be repeated m times using the bootstrap to produce m imputed data sets.
As noted in van Buuren [2018, Sec. 3.5], Algorithm 2.2 is a form of predictive
mean matching [van Buuren, 2018, Sec. 3.4], where the “predictive mean” is
instead calculated by CART, as opposed to a regression model. An example
using CART for multiple imputation is provided in Section 7.9.3.
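For a concrete starting point (a minimal sketch, not the book's Section 7.9.3 example), CART-based multiple imputation is available in mice by setting method = "cart":

library(mice)
# Generate m = 5 completed data sets using CART as the imputation model
imp <- mice(airquality, m = 5, method = "cart", seed = 1056, printFlag = FALSE)
aq.complete <- complete(imp, action = 1)  # extract the first completed data set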
But what if you’re using a decision tree as the model, and not just as a means
for imputation: should you rely on surrogate splits or a different strategy,
j Unless stated otherwise, a bootstrap sample refers to a random sample of size N with
replacement from a set of N observations; hence, some of the original observations will be
sampled more than once and some not at all.
1) Using only the complete cases, fit a CART-like tree with y as the response and the complete predictors as features.
2) Pass each observation with a missing value of y down the fitted tree to determine the terminal node it lands in.
3) For each missing y value, randomly draw an observed response value from
the terminal node to which it’s assigned (i.e., the complete response values
from the learning sample that summarize the node) to use for the imputed
value.
in Section 7.9.5.
n Note that the balanced nature of these data is not very realistic, unless roughly half the
Swiss banknotes truly are counterfeit. However, without any additional information about
the true class priors, there’s not much that can be done here.
However, for ease of interpretation, I’ll re-encode the outcome y from 0/1 to
Genuine/Counterfeito :
bn$y <- as.factor(ifelse(bn$y == 0, "Genuine", "Counterfeit")) # re-encode outcome
library(rpart)
# Fit a CART-like tree using top and bottom as the only features
(bn.tree <- rpart(y ~ top + bottom, data = bn))
#> n= 200
#>
#> node), split, n, loss, yval, (yprob)
#> * denotes terminal node
#>
#> 1) root 200 100 Counterfeit (0.5000 0.5000)
#> 2) bottom>=9.55 88 2 Counterfeit (0.9773 0.0227) *
#> 3) bottom< 9.55 112 14 Genuine (0.1250 0.8750)
#> 6) top>=11.1 17 4 Counterfeit (0.7647 0.2353) *
#> 7) top< 11.1 95 1 Genuine (0.0105 0.9895) *
Note that this is the same tree that was displayed in Figure 2.2 (p. 43).
The output from printing an "rpart" object can seem intimidating at first,
especially for large trees, so let’s take a closer look. The output is split into
three sections. The first section gives N , the number of rows in the learning
sample (or root node). The middle section, starting with node), indicates the
format of the tree structure that follows. The last section, starting at 1),
provides a brief summary of the tree structure. All the nodes of the tree
are numbered, with 1) indicating the root node and lines ending with a *
indicating the terminal nodes. The topology of the tree is conveyed through
indented lines; for example, nodes 2) and 3) are nested within 1); the left
and right child nodes for any node numbered x are always numbered 2x and
2x + 1, respectively.
For each node we can also see the split that was used, the number of observa-
tions it captured, the deviance or loss (in this case, the number of observations
misclassified in that node), the fitted value (in this case, the classification given
to observations in that node), and the proportion of each class in the node.
Take node 2), for example. This is a terminal node, the left child of node 1),
and contains 88 of the N = 200 observations (two of which are genuine bank-
notes). Any observation landing in node 2) will be classified as counterfeit
with a predicted probability of 0.977.
o I could leave the response numerically encoded as 0/1, but then I would need to tell
rpart to treat this as a classification problem by setting method = "class" in the call to
rpart().
If you want even more verbose output, with details about each split, you can
use the summary() method:
summary(bn.tree) # print more verbose tree summary
#> Call:
#> rpart(formula = y ~ top + bottom, data = bn)
#> n= 200
#>
#> CP nsplit rel error xerror xstd
#> 1 0.84 0 1.00 1.14 0.0700
#> 2 0.09 1 0.16 0.19 0.0415
#> 3 0.01 2 0.07 0.12 0.0336
#>
#> Variable importance
#> bottom top
#> 66 34
#>
#> Node number 1: 200 observations, complexity param=0.84
#> predicted class=Counterfeit expected loss=0.5 P(node) =1
#> class counts: 100 100
#> probabilities: 0.500 0.500
#> left son=2 (88 obs) right son=3 (112 obs)
#> Primary splits:
#> bottom < 9.55 to the right, improve=71.6, (0 missing)
#> top < 11 to the right, improve=30.7, (0 missing)
#> Surrogate splits:
#> top < 11 to the right, agree=0.685, adj=0.284, (0 split)
#>
#> Node number 2: 88 observations
#> predicted class=Counterfeit expected loss=0.0227 P(node) =0.44
#> class counts: 86 2
#> probabilities: 0.977 0.023
#>
#> Node number 3: 112 observations, complexity param=0.09
#> predicted class=Genuine expected loss=0.125 P(node) =0.56
#> class counts: 14 98
#> probabilities: 0.125 0.875
#> left son=6 (17 obs) right son=7 (95 obs)
#> Primary splits:
#> top < 11.1 to the right, improve=16.40, (0 missing)
#> bottom < 9.25 to the right, improve= 2.42, (0 missing)
#>
#> Node number 6: 17 observations
#> predicted class=Counterfeit expected loss=0.235 P(node) =0.085
#> class counts: 13 4
#> probabilities: 0.765 0.235
#>
#> Node number 7: 95 observations
#> predicted class=Genuine expected loss=0.0105 P(node) =0.475
Here, we can see each primary splitter, along with its corresponding split point
and gain (i.e., a measure of the quality of the split). For example, using bottom
< 9.55 yielded the greatest improvement and was selected as the first primary
split. The reported improvement (improve=71.59091) is N ×∆I (s, A), hence
why it differs from the output of our previously defined splits() function,
which just uses ∆I (s, A); but you can check the math: 71.59091/200 = 0.358,
which is the same value we obtained by hand back in Section 2.2.2. Woot!
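The same arithmetic is easy to reproduce directly from the class counts in the printed tree; the gini() helper below is defined here only for illustration and is not part of rpart or treemisc:

gini <- function(counts) {  # Gini impurity from a vector of class counts
  p <- counts / sum(counts)
  1 - sum(p ^ 2)
}
n <- 200  # root node size
delta <- gini(c(100, 100)) -            # impurity of the root node
  (88 / n) * gini(c(86, 2)) -           # left child (bottom >= 9.55)
  (112 / n) * gini(c(14, 98))           # right child (bottom < 9.55)
c("delta" = delta, "improve" = n * delta)  # ~0.358 and ~71.59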
Before continuing, let’s refit the tree using all available features:
summary(rpart(y ~ ., data = bn, method = "class"))
#> Call:
#> rpart(formula = y ~ ., data = bn, method = "class")
#> n= 200
#>
#> CP nsplit rel error xerror xstd
#> 1 0.98 0 1.00 1.12 0.0702
#> 2 0.01 1 0.02 0.03 0.0172
#>
#> Variable importance
#> diagonal bottom right top left length
#> 28 22 15 14 14 6
#>
#> Node number 1: 200 observations, complexity param=0.98
#> predicted class=Counterfeit expected loss=0.5 P(node) =1
#> class counts: 100 100
#> probabilities: 0.500 0.500
#> left son=2 (102 obs) right son=3 (98 obs)
#> Primary splits:
#> diagonal < 141 to the left, improve=96.1, (0 missing)
#> bottom < 9.55 to the right, improve=71.6, (0 missing)
#> right < 130 to the right, improve=34.3, (0 missing)
#> top < 11 to the right, improve=30.7, (0 missing)
#> left < 130 to the right, improve=27.8, (0 missing)
#> Surrogate splits:
#> bottom < 9.25 to the right, agree=0.910, adj=0.816, (0 split)
#> right < 130 to the right, agree=0.785, adj=0.561, (0 split)
#> top < 11 to the right, agree=0.765, adj=0.520, (0 split)
#> left < 130 to the right, agree=0.760, adj=0.510, (0 split)
#> length < 215 to the left, agree=0.620, adj=0.224, (0 split)
#>
#> Node number 2: 102 observations
#> predicted class=Counterfeit expected loss=0.0196 P(node) =0.51
#> class counts: 100 2
Using all the predictors results in the same decision stump that was displayed
in Figure 2.18. As it turns out, the best tree uses a single split on the length
of the diagonal (in mm) and only misclassifies two of the genuine banknotes
in the learning sample. In addition to the chosen splitter, diagonal, we also
see a description of the competing splits (four by default) and surrogate splits
(five by default); note that these match the surrogate splits I found manually
back in Section 2.7. For example, if I didn’t include diagonal as a potential
feature, then bottom would’ve been selected as the primary splitter because it
gave the next best reduction to weighted impurity (improve=71.59091).
While the rpart package provides plot() and text() methods for plot-
ting and labeling tree diagrams, respectively, the resulting figures are not
as polished as those produced by other packages; for example, rpart.plot
[Milborrow, 2021b] and partykit [Hothorn and Zeileis, 2021]. All the tree
diagrams in this chapter were constructed using a simple wrapper function
around rpart.plot() called tree_diagram(), which is part of treemisc;
see ?rpart.plot::rpart.plot and ?treemisc::tree_diagram for details.
For example, the tree diagram from Figure 2.2 (p. 43) can be constructed
using:
treemisc::tree_diagram(bn.tree)
FIGURE 2.19: Decision stump for the Swiss banknote example (left) and one
of the surrogate splits (right).

Figure 2.19 shows a tree diagram depicting the primary split (left) as well
as the second best surrogate split (right). In the printout from summary(),
we also see the computed agreement and adjusted agreement for each surro-
gate. From Figure 2.19, we can see that the surrogate sends (66 + 91) / 200 ≈
0.785 of the observations in the same direction as the primary split (agree-
ment). The majority rule gets 102 correct, giving an adjusted agreement of
(66 + 91 − 102) / (200 − 102) ≈ 0.561.
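These agreement calculations are simple enough to verify by hand using the node counts shown in Figure 2.19:

(66 + 91) / 200                # agreement: ~0.785
(66 + 91 - 102) / (200 - 102)  # adjusted agreement: ~0.561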
In this section, we’ll use rpart to fit a classification tree to the mushroom
data, and explore a bit more of the output and fitting process. Recall from
Section 1.4.4, that the overall objective is to find a simple rule of thumb (if
possible) for avoiding potentially poisonous mushrooms. For now, I’ll stick
with rpart's defaults (e.g., the splitting rule is the Gini index), but set the
complexity parameter, cp, to zero (cp = 0) for a more complex tree. Although
the tree construction itself is not random, the internal cross-validation results
are, so I'll also set the random number seed before calling rpart():
mushroom <- treemisc::mushroom
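The output that follows refers to a fitted object called mushroom.tree; a call along the following lines produces it (the seed value shown here is arbitrary and only affects the internal cross-validation results):

set.seed(1019)  # arbitrary seed; only the cross-validation results are random
mushroom.tree <- rpart(Edibility ~ ., data = mushroom, cp = 0)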
FIGURE 2.20: Decision tree diagram for classifying the edibility of mush-
rooms.
You can change any of rpart()'s tuning parameters via its control argument, or
by passing them directly to rpart() via the ... (pronounced dot-dot-dot)
argument.q For example, the two calls to rpart() below are equivalent. Each
one fits a classification tree but changes the default complexity parameter
from 0.01 to 0 (cp = 0) and the number of internal cross-validations from ten
to five (xval = 5); see ?rpart::rpart.control for further details about all
the configurable parameters.
ctrl <- rpart.control(cp = 0, xval = 5)  # can also be a named list
tree <- rpart(Edibility ~ ., data = mushroom, control = ctrl)
tree <- rpart(Edibility ~ ., data = mushroom, cp = 0, xval = 5)
Another useful option in rpart() is the parms argument, which controls how
nodes are split in the tree; it must be a named list whenever supplied. Below
we print the $parms component of the fitted mushroom.tree object, which in
this case is a list containing the class priors, loss matrix, and impurity function
used in constructing the tree.
mushroom.tree$parms
#> $prior
#> 1 2
#> 0.518 0.482
#>
#> $loss
#> [,1] [,2]
#> [1,] 0 1
#> [2,] 1 0
#>
#> $split
#> [1] 1
The $prior component defaults to the class frequencies in the root node,
which can easily be verified:
proportions(table(mushroom$Edibility)) # observed class proportions
#>
#> Edible Poison
#> 0.518 0.482
The loss matrix, given by component $loss, defaults to equal losses for
false positives and false negatives (the off-diagonal entries); there's no loss
associated with a correct classification (i.e., the diagonal entries are always
zero). More generally, for a binary outcome the loss matrix has the form

    L = | TN  FP |
        | FN  TP |,

where rows represent the observed classes and columns represent the assigned
classes. Here, TP, FP, FN, and TN stand for true positive, false positive, false
negative, and true negative, respectively; for example, a false negative is the
case in which the tree misclassifies a 1 as a 0. The order of the rows/columns
corresponds to the order of the categories when sorted alphabetically or
numerically.

q In R, functions can have a special ... argument that allows them to accept any number
of additional arguments.
Since there is no cost for a correct classification, we take TP = TN = 0. Set-
ting FP = FN = c, for some constant c (i.e., treating FPs and FNs equally),
will always result in the same splits (although the internal statistics used
in selecting the splits will be scaled differently). When misclassification costs
are not equal, specify the appropriate values in the loss matrix. For example,
the following tree would treat false negatives (i.e., misclassifying poisonous
mushrooms as edible) as five times more costly than false positives (i.e., mis-
classifying edible mushrooms as poisonous). We could also obtain the same
tree by computing the altered priors based on this loss matrix and supplying
them via the parms argument, but this is left as an exercise for the reader.
levels(mushroom$Edibility) # inspect order of levels
(loss <- matrix(c(0, 5, 1, 0), nrow = 2)) # loss matrix
rpart(Edibility ~ ., data = mushroom, parms = list("loss" = loss))
The variable importance scores for the fitted tree can be extracted from its
$variable.importance component:
mushroom.tree$variable.importance
#> odor spore.print.color
#> 3823.407 2834.187
#> gill.color stalk.surface.above.ring
#> 2322.460 2035.816
#> stalk.surface.below.ring ring.type
#> 2030.555 2026.526
#> stalk.color.below.ring stalk.root
#> 53.933 25.600
#> stalk.color.above.ring veil.color
#> 17.546 16.315
#> cap.surface cap.color
#> 15.360 14.032
#> habitat cap.shape
#> 13.409 3.840
#> gill.attachment
#> 0.585
In many cases, predictors that weren’t used in the tree will have a non-zero
importance score. The reason is that surrogate splits are also incorporated
into the calculation. In particular, a variable may effectively appear in the
tree more than once, either as a primary or surrogate splitter. The variable
importance measure for a particular feature is the sum of the gains associated
with each split for which it was the primary variable, plus the gains (adjusted
for agreement) associated with each split for which it was a surrogate. You can
turn off surrogates by setting maxsurrogate = 0 in rpart.control().
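A quick way to see the effect is to refit the tree without surrogates and compare; a minimal sketch (the object name here is arbitrary):

# With maxsurrogate = 0, only primary splits contribute to the importance scores
mushroom.nosur <- rpart(Edibility ~ ., data = mushroom, cp = 0, maxsurrogate = 0)
head(mushroom.nosur$variable.importance)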
How does k-fold cross-validation (Section 2.5.2) in rpart work? The rpart()
function does internal 10-fold cross-validation by default. According to rpart’s
documentation, 10-fold cross-validation is a reasonable default, and has been
shown to be very reliable for screening out “pure noise” features. The num-
ber of folds (k) can be changed, however, using the xval argument in
rpart.control().
You can visualize the cross-validation results of an "rpart" object using
plotcp(), as illustrated in Figure 2.21 for the mushroom.tree object. A good
rule of thumb in choosing cp for pruning is to use the leftmost value for which
the average cross-validation score lies below the horizontal line; this coincides
with the 1-SE rule discussed in Section 2.5.2.1. The columns labeled "xerror"
and "xstd" provide the cross-validated risk and its corresponding standard
error, respectively (Section 2.5).
plotcp(mushroom.tree, upper = "splits", las = 1) # Figure 2.21
mushroom.tree$cptable # print cross-validation results
#> CP nsplit rel error xerror xstd
#> 1 0.96936 0 1.00000 1.000000 0.011501
#> 2 0.01839 1 0.03064 0.030644 0.002777
FIGURE 2.21: Cross-validation results for mushroom.tree from plotcp() (x-axis:
number of splits; y-axis: X-val relative error).
Don’t be confused by the fact that the cp values between printcp() (and
hence the $cptable component of an "rpart" object) and plotcp() don’t
match. The latter just plots the geometric means of the CP column listed in
printcp() (these relate to the βi values used in the k-fold cross-validation
procedure described in Section 2.5). Any cp value between two consecutive
rows will produce the same tree. For instance, any cp value between 0.002 and
0.001 will produce a tree with five splits. Also, these correspond to a scaled
version of the complexity values αi from Section 2.5. Note that rpart scales the
CP column, as well as the error columns, by a factor inversely proportional to
the risk of the root node, so that the associated training error ("rel error")
for the root node is always one (i.e., the first row in the table); which in this
case is 1/ (3916/8124) ≈ 2.075. Dividing through by this scaling factor should
return the raw αi values; the first three correspond to the values I computed
by hand back in Section 2.5:
mushroom.tree$cptable[1L:3L, "CP"] / (8124 / 3916)
#> 1 2 3
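You can compute the geometric means that plotcp() displays directly from the $cptable component; a minimal sketch:

cp <- mushroom.tree$cptable[, "CP"]
sqrt(cp[-length(cp)] * cp[-1L])  # geometric means of adjacent CP values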
FIGURE 2.22: Decision tree diagram for classifying the edibility of mush-
rooms; in this case, the result is a decision stump.
The tree diagram displayed in Figure 2.22 provides us with a handy rule of
thumb for classifying mushrooms as either edible or poisonous. If the mush-
room smells fishy, foul, musty, pungent, spicy, or like creosote, it’s likely poi-
sonous. In other words, if it smells bad, don’t eat it!
In this section, I'll use rpart to build a regression tree for the Ames housing data
(Section 1.4.7). I'll also show how to easily prune an rpart tree using the 1-SE
rule via treemisc's prune_se() function. The code chunk below loads in the
data before splitting it into train/test sets using a 70/30 split:
ames <- as.data.frame(AmesHousing::make_ames())
ames$Sale_Price <- ames$Sale_Price / 1000 # rescale response
set.seed(2101) # for reproducibility
trn.id <- sample.int(nrow(ames), size = floor(0.7 * nrow(ames)))
ames.trn <- ames[trn.id, ] # training data/learning sample
ames.tst <- ames[-trn.id, ] # test data
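The comparison discussed below assumes an overgrown tree fit to the training data and a 1-SE-pruned version of it; a sketch along those lines (the object names, seed, and RMSE helper are assumptions, not the book's exact code):

set.seed(1547)  # arbitrary seed for the internal cross-validation
ames.cart <- rpart(Sale_Price ~ ., data = ames.trn, cp = 0)  # overgrown tree
ames.cart.pruned <- treemisc::prune_se(ames.cart, se = 1)    # prune via 1-SE rule
rmse <- function(pred, obs) sqrt(mean((pred - obs) ^ 2))     # test RMSE helper
c("unpruned" = rmse(predict(ames.cart, newdata = ames.tst), ames.tst$Sale_Price),
  "pruned" = rmse(predict(ames.cart.pruned, newdata = ames.tst), ames.tst$Sale_Price))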
A smaller discrepancy, but the pruned tree is slightly less accurate than the
unpruned tree on the test set. So did pruning really help here? It depends
on how you look at it. Both trees are displayed in Figure 2.23 without text
or labels. The unpruned tree has 169 splits, while pruning with the 1-SE rule
and cross-validation resulted in a subtree with only 33 splits—a much more
parsimonious tree.
FIGURE 2.23: Regression trees for the Ames housing example. Left: unpruned
regression tree. Right: pruned regression tree using 10-fold cross-validation and
the 1-SE rule.
Even with pruning, we still ended up with a subtree that is too complex to
easily interpret. In situations like this, it can be helpful to use different post hoc
interpretability techniques to help the end user understand the model. For
instance, it can be quite informative to look at a plot of variable importance
scores, like the Cleveland dot plot displayed in Figure 2.24; here, the importance
scores are scaled to sum to 1 (see the code chunk below). From the results, we
can see that the overall quality of the home (Overall_Qual) is the most
important feature in the pruned tree.
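A minimal sketch of such a plot, assuming the illustrative ames.cart.pruned fit from above (the book's figure may have been produced differently):

vi <- ames.cart.pruned$variable.importance
vi <- sort(vi / sum(vi))  # scale to sum to 1; ascending order for plotting
dotchart(tail(vi, n = 10), xlab = "Scaled variable importance", pch = 19)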
FIGURE 2.24: Variable importance plot for the top ten predictors in the
pruned decision tree for the Ames housing data.
While variable importance plots can be useful, they don’t tell us anything
about the nature of the relationship between each feature and the pre-
dicted outcome. For instance, how does the above grade square footage
(Gr_Liv_Area) impact the predicted sale price on average? This is precisely
what partial dependence plots (PDPs) can tell us; see Section 6.2.1. In a nut-
shell, PDPs are low-dimensional graphical renderings of the prediction func-
tion so that the relationship between the outcome and predictors of interest
can be more easily understood. PDPs, along with other interpretability tech-
niques, are discussed in more detail in Chapter 6. For now, I’ll just introduce
the pdp package [Greenwell, 2021b], and show how it can be used to help vi-
sualize the relationship between above grade square footage and the predicted
sale price:
library(ggplot2)
library(pdp)
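A sketch of the kind of call that produces such a display, again using the illustrative ames.cart.pruned fit from above (the book's exact code may differ):

pd <- partial(ames.cart.pruned, pred.var = "Gr_Liv_Area", train = ames.trn)
autoplot(pd, rug = TRUE, train = ames.trn) +  # ggplot2-based display with a rug
  ylab("Partial dependence")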
FIGURE 2.25: Partial dependence of predicted sale price on above grade square
footage (Gr_Liv_Area).
Note that the y-axis is on the same scale as the response, and in this case,
represents the averaged predicted sale price across the entire learning sample
for a range of values of Gr_Liv_Area. The rug display (or 1-dimensional plot)
on the x-axis shows the distribution of Gr_Liv_Area in the training data, with
a tick mark at the min/max and each decile. As you would expect, larger size
homes are associated with higher average predicted sales. Full details on the
pdp package are given in Greenwell [2017].
Decision trees, especially smaller ones, can be rather self-explanatory. How-
ever, it is often the case that a usefully discriminating tree is too large to
interpret by inspection. Variable importance scores, PDPs, and other inter-
pretability techniques can be used to help understand any tree, regardless of
size or complexity; these techniques are even more critical for understanding
the output from more complex models, like the tree-based ensembles discussed
in Chapters 5–8.
In this example, I’ll revisit the employee attrition data (Section 1.4.6) and
build a classification tree using rpart with altered priors to help understand
drivers of employee attrition and confirm my previous calculations from Sec-
tion 2.2.4.1.
Figure 2.6 showed two classification trees for the employee attrition data, one
using the default priors and one with altered priors based on a specific loss
matrix with unequal misclassification costs. In rpart, you can specify the loss
matrix, priors, or both—it’s quite flexible!
The next code chunk fits three depth-two classification trees to the employee
attrition data. The first tree (tree1) assumes equal misclassification costs
and uses the default (i.e., observed) class priors. The other two trees use
different, but equivalent, approaches: tree2 uses the previously defined loss
matrix from Section 2.2.4.1, while tree3 uses the associated altered priors
I computed by hand back in Section 2.2.4.1. Although the internal statistics
used in constructing each tree differ slightly, both trees are equivalent in terms
of splits and will make the same classifications. The resulting tree diagrams
are displayed in Figure 2.26.
data(attrition, package = "modeldata")
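The fits follow the pattern sketched below; the loss matrix and altered priors shown here are placeholders standing in for the values derived in Section 2.2.4.1, and the remaining settings are illustrative assumptions:

loss <- matrix(c(0, 5, 1, 0), nrow = 2)  # placeholder loss matrix (FN costlier than FP)
priors <- c(0.5, 0.5)                    # placeholder altered priors (must sum to 1)
tree1 <- rpart(Attrition ~ ., data = attrition, cp = 0, maxdepth = 2)
tree2 <- rpart(Attrition ~ ., data = attrition, cp = 0, maxdepth = 2,
               parms = list(loss = loss))
tree3 <- rpart(Attrition ~ ., data = attrition, cp = 0, maxdepth = 2,
               parms = list(prior = priors))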
The subtle difference between tree2 and tree3 is that the within-node class
proportions for tree2 are not adjusted for cost/loss; hence, the predicted class
probabilities will not match between the two trees. In essence, the tree based
on the loss matrix (tree2) makes classifications using a predicted probability
threshold that differs from the default.u
FIGURE 2.26: Decision trees fit to the employee attrition data set. Left: de-
fault priors and equal costs. Middle: Unequal costs specified via a loss matrix.
Right: altered priors equivalent to the costs associated with the middle tree.
u In general, the default probability threshold for classification is 1/J, where J is the
number of classes.
Although I restricted each tree to a max depth of two, the tree on the far
left side of Figure 2.26 actually coincides with the optimal tree I would've
obtained using 10-fold cross-validation and the 1-SE rule (assuming equal
misclassification costs, of course), and shows that having to work overtime,
as well as having a lower monthly income, was associated with the highest
predicted probability of attrition (p = 0.70).
To further illustrate pruning using the 1-SE rule in rpart, let’s look at a tree
based on altered priors (tree3) with the full set of features. Below, I refit the
same altered priors tree using all available features to maximum depth (i.e.,
intentionally overgrow the tree); the cross-validation results and pruned tree
diagram are displayed in Figure 2.27 (p. 109). The top plot in Figure 2.27
shows the cost-complexity pruning results as a function of the number of
splits (top axis). If I were to prune using the 1-SE rule, I would select the tree
corresponding to the point farthest to the left that’s below the horizontal line
(the horizontal line corresponds to 1-SE above the minimum error). In this
case, we would end up with a tree containing just four splits, as seen in the
bottom plot in Figure 2.27.
library(rpart)
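A sketch of the refit and the two displays in Figure 2.27, continuing with the placeholder priors from the previous sketch (the object name and seed are assumptions):

set.seed(1342)  # arbitrary seed for the internal cross-validation
tree3.full <- rpart(Attrition ~ ., data = attrition, cp = 0,
                    parms = list(prior = priors))
plotcp(tree3.full, upper = "splits")                            # top of Figure 2.27
treemisc::tree_diagram(treemisc::prune_se(tree3.full, se = 1))  # bottom of Figure 2.27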
The goal of this example is to build an image classifier using a simple deci-
sion tree that incorporates additional information about the true class priors
for a multiclass outcome with J = 26 classes. The letter image recognition
data, which are available in R via the mlbench package [Leisch and Dimitriadou,
2021],v contain 20,000 observations on 17 variables. Each observation
corresponds to a distorted black-and-white rectangular pixel image of a capi-
tal letter from the English alphabet (i.e., A–Z). A total of 16 ordered features
was derived from each image (e.g., statistical moments and edge counts), which
were then scaled to be integers in the range 0–15. The objective is to identify
each letter using the 16 ordered features. Frey and Slate [1991] first analyzed
the data using Holland-style adaptive classifiers, and reported an accuracy of
just over 80%.
To start, I’ll load the data and split it into train/test sets using a 70/30
split:
library(treemisc) # for prune_se()
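A sketch of the load-and-split step (the object names and seed are assumptions chosen to match the code that follows):

data(LetterRecognition, package = "mlbench")
lr <- LetterRecognition
set.seed(1051)  # arbitrary seed for the train/test split
trn.id <- sample.int(nrow(lr), size = floor(0.7 * nrow(lr)))
lr.trn <- lr[trn.id, ]   # training data (~70%)
lr.tst <- lr[-trn.id, ]  # test data (~30%)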
Next, I’ll use rpart() to fit a classification tree that’s been pruned using the
1-SE rule with 10-fold cross-validation, and see how accurate it is on the test
sample:
set.seed(1703) # for reproducibility
lr.cart <- rpart(lettr ~ ., data = lr.trn, cp = 0, xval = 10)
lr.cart <- prune_se(lr.cart, se = 1) # prune using 1-SE rule
v The data are also available from the UCI Machine Learning Repository:
https://archive.ics.uci.edu/ml/datasets/Letter+Recognition.
As noted in Matloff [2017, Sec. 5.8], this is unrealistic since some letters tend
to appear much more frequently than others in English text. For example,
assuming balanced classes, rpart() uses a prior probability for the letter “A”
of πA = 1/26 ≈ 0.0384, when in fact πA is closer to 0.0855.w
A proper analysis should take these frequencies into account, which is easy
enough to do with classification trees. Fortunately, the correct letter frequen-
cies are conveniently available in the regtools package [Matloff, 2019]. Below,
I’ll refit the model including the updated (and more realistic) class priors. I’ll
then sample the test data so that the resulting class frequencies more accu-
rately reflect the true prior probabilities of each letter:
data(ltrfreqs, package = "regtools")
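A sketch of how the updated priors might be constructed and supplied; the column layout of ltrfreqs assumed here (letters in the first column, relative frequencies in the second) and the object names are assumptions:

ltrfreqs <- ltrfreqs[order(ltrfreqs[, 1L]), ]    # align with the sorted factor levels
priors <- ltrfreqs[, 2L] / sum(ltrfreqs[, 2L])   # rescale the frequencies to sum to 1
set.seed(1703)  # for reproducibility
lr.cart2 <- rpart(lettr ~ ., data = lr.trn, cp = 0, xval = 10,
                  parms = list(prior = priors))
lr.cart2 <- prune_se(lr.cart2, se = 1)           # prune using the 1-SE rule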
Finally, let’s compare the two CART fits on the modified test set that reflects
the more correct class frequencies:
w Based on the English letter frequencies reported at http://practicalcryptography.com.
While the error of the original model based on equal priors decreased (albeit,
not by much), the tree incorporating true prior information did much better.
Woot!
2.10 Discussion
Small trees are easy to interpret. Decision trees are often hailed as being
simple and interpretable, relative to more complex algorithms. However, this
is really only true for small trees with relatively few splits, like the one from
Figure 2.22 (p. 95).
Trees scale well to large N . Individual decision trees scale incredibly well
to large data sets, especially if most of the features are ordered, or categorical
with relatively few categories. Even in the extreme cases, various shortcuts
and approximations can be used to reduce computational burden (see, for
example, Section 2.4).
The leaves form a natural clustering of the data. All observations that
coinhabit a terminal node necessarily satisfied all the same conditions when
traversing the tree; in this sense, the records within a terminal node should
be similar with respect to the feature values and can be considered nearest
neighbors. We’ll revisit this idea in Section 7.6.
Trees can handle data of all types. Trees can naturally handle data of
mixed types, and categorical features do not necessarily have to be numerically
re-encoded, like in linear regression or neural networks. Trees are also invariant
to monotone transformations of the predictors; that is, they only care about
the rank order of the values of each ordered feature. For example, there’s no
need to apply logarithmic or square root transformations to any of the features
like you might in a linear model.
Automatic variable selection. CART selects variables and splits one step
at a time ("...like a carpenter that makes stairs"); hence the quote at the
beginning of the chapter. If a variable cannot meaningfully partition the data
into homogeneous subgroups, it will not likely be selected as a primary splitter.
If such a variable is selected anyway, the corresponding split will likely get
snipped off during the pruning phase.
Trees can naturally handle missing data. As discussed in Section 2.7,
CART can naturally avoid many of the problems caused by missing data by
using surrogate splits (i.e., backup splitters that can be used whenever missing
values are encountered during prediction).
Trees are completely nonparametric. CART is fully nonparametric. It
does not require any distributional assumptions, and the user does not have
to specify a parametric form for the model, like in linear regression. It can
also automatically handle nonlinear relationships (although it tends to be quite
biased since it uses step functions to approximate potentially smooth surfaces)
and interactions.
Large trees are difficult to interpret. Large tree diagrams, like the ones
in Figure 2.23 (p. 97), can be difficult to interpret and are probably not very
useful to the end user. Fortunately, various interpretability techniques, like
variable importance plots and PDPs, can help alleviate this problem. Such
techniques are the topic of Chapter 6.
CART's splitting algorithm is quite greedy. CART makes splits that are
locally optimal. That is, the algorithm does not look through all possible tree
structures to globally optimize some performance metric; that would be infeasible,
even for a small number of features. Instead, the algorithm recursively
partitions the data by looking for the next best split at each stage. This is
analogous to the difference between forward-stepwise selection and best-subset
selection. Greedy algorithms use a more constrained search and tend to have
lower variance, but often pay the price in bias. Chapters 5–8 discuss several
strategies for improving on the bias-variance tradeoff by combining many decision
trees together.
Splits lower on the tree are less accurate. Data is essentially taken away
after each split, making splits further down the tree less accurate (and noisier)
than splits near the root node. This is part of the reason why binary
splits are used in the first place. While some decision tree algorithms allow
multiway splits (e.g., CHAID and C4.5/C5.0), this is not a good strategy in
general as the data would be fragmented too quickly, and the search for locally
optimal splits becomes more challenging.
Trees contain complex interactions. CART finds splits in a sequential
manner, and all splits in the tree depend on any that came before it. Once
a final tree structure has been identified, the resulting prediction equation
can be written using a linear model in indicator functions. For example, the
prediction equation for the tree diagram in Figure 2.9 (p. 63) can be written
as follows:
    \widehat{Ozone} = 75.41 × I(Temp ≥ 82.50) + 55.60 × I(Temp < 82.50 & Wind < 7.15)
                      + 22.33 × I(Temp < 82.50 & Wind ≥ 7.15),
where I (·) is the indicator function that evaluates to one whenever its argu-
ment is true, and zero otherwise. The right-hand side can be re-written more
generally as f (Temp) + f (Temp, Wind), where the second term explicitly mod-
els an interaction effect between Temp and Wind. As you can imagine, a more
complex tree with a larger number of splits easily leads to a model with high-
order interaction effects. The presence of high-order interaction effects can
make interpreting the main effects (i.e., the effect of individual predictors)
more challenging.
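The prediction equation translates directly into a small R function; a sketch using the coefficients above:

ozone_hat <- function(Temp, Wind) {  # indicator-function form of the fitted tree
  75.41 * (Temp >= 82.5) +
    55.60 * (Temp < 82.5 & Wind < 7.15) +
    22.33 * (Temp < 82.5 & Wind >= 7.15)
}
ozone_hat(Temp = 85, Wind = 10)  # falls in the first terminal node: 75.41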
Biased variable selection. As briefly discussed in Section 2.4.2, CART’s
split selection strategy is biased towards features with many potential split
points, such as categorical predictors with high cardinality. More contempo-
rary decision tree algorithms, like those discussed in Chapters 3–4, are unbi-
ased in this sense.
Trees are essentially step functions. Trees can have a hard time adapting
to smooth and/or linear response surfaces. Recall the twonorm problem from
Section 1.1.2.5, where the optimal decision boundary is linear. I fit a pruned
rpart tree to the same sample using 10-fold cross-validation and the 1-SE
rule; the resulting decision boundary (along with the optimal Bayes rule) is
displayed in Figure 2.28 (p. 110). Of course, I could increase the number of
splits resulting in smaller steps, but in practice this often leads to overfitting
and poor generalizability. This lack of smoothness causes more problems in
the regression setting.
Trees are noisy. A common criticism of decision trees is that they are consid-
ered unstable predictors; this was also noted in the original CART monograph;
see Breiman et al. [1984, Section 5.5.2]. By unstable, I mean high variance or,
in other words, the tree structure (and therefore predictions) can vary, often
wildly, from one sample to the next. For example, at any node in a particular
tree, there may be several competing splits that result in nearly the same
decrease in node impurity and different samples may lead to different choices
among these similarly performing split contenders.
To illustrate, let's look at six independent samples of size N = 3,220 from
the email spam data described in Section 1.4.5 (≈ 70% training sample), and
fit a CART-like tree to each using a maximum of four splits. The results are
displayed in Figure 2.29. Note the difference in structure and split variables
across the six trees.
Fortunately, tree stability isn’t always as problematic as it sounds (or looks).
According to Zhang and Singer [2010, p. 57], “...the real cause for concern [in
practice] regarding tree stability is the psychological effect of the appearance
of a tree.” Even though the structure of the tree can vary from sample to
sample, Breiman et al. [1984, pp. 156–159] argued that competitive trees,
while differing in appearance, can give fairly stable and consistent predictions.
Strategies for improving the stability and performance of decision trees are
discussed in Chapters 5–8.
Instability is not a feature specific to trees, though. For example, traditional
model selection techniques in linear regression—like forward selection, back-
ward elimination, and hybrid variations thereof—all suffer from the same prob-
lem. However, averaging can improve the accuracy of unstable predictors, like
overgrown decision trees, through variance reduction [Breiman, 1996a]; more
on this in Chapter 5.
2.11 Recommended reading

First and foremost, I highly recommend reading the original CART mono-
graph [Breiman et al., 1984]. For a more approachable and thorough discussion
of CART, I'd recommend Berk [2008, Chap. 3] (note that there's now a second
edition of this book). I also recommend reading the vignettes accompanying
the rpart package; R users can launch these from an active R session, as men-
tioned throughout this chapter, but they're also available from rpart's CRAN
landing page: https://cran.r-project.org/package=rpart. Scikit-learn's
sklearn.tree module documentation is also pretty solid:
https://scikit-learn.org/stable/modules/tree.html. There's also a fantastic talk by Dan
Steinberg about CART, called “Data Science Tricks With the Single Decision
Tree,” that can be found on YouTube:
https://www.youtube.com/watch?v=JVbU_tS6zKo&feature=youtu.be.
FIGURE 2.27: Employee attrition decision tree with altered priors. Top: 10-
fold cross-validation results. Bottom: pruned tree using the 1-SE rule (which
corresponds to the left-most point in the 10-fold cross-validation results that
lies beneath the horizontal dotted line).
FIGURE 2.28: Decision boundary from a pruned CART fit to the twonorm data,
along with the optimal Bayes decision boundary.
FIGURE 2.29: CART-like trees applied to six independent samples from the
email spam data; for plotting purposes, each tree is restricted to just four
splits.
3
Conditional inference trees
The trees that are slow to grow bear the best fruit.
Molière
3.1 Introduction
The basic idea behind unbiased recursive partitioning is to separate the split
search into two sequential steps: 1) selecting the primary split variable, then
2) selecting an optimal split point. Typically, the primary splitter is selected
first by comparing appropriate statistical tests (e.g., a chi-square test if both
X and Y are nominal, or a correlation-type test if X and Y are both uni-
variate continuous). Once a splitting variable has been identified, the optimal
split point can be found using any number of approaches, including further
statistical tests, or the impurity measures discussed in Section 2.2.1.
FIGURE 3.1: Split variable selection frequencies for two regression tree pro-
cedures (CART and CTree); the y-axis shows the proportion of times each
feature—all of which are unrelated to the response—was selected to split the
root node in 5,000 Monte Carlo simulations. A horizontal dashed red line is
given at 1/7, the frequency corresponding to unbiased split variable selection.

The idea of using statistical tests for split variable selection in recursive parti-
tioning is not new. In fact, the CTree algorithm discussed later in this chapter
was inspired by a number of approaches that came before it, like CHAID
(Section 1.2.1). While these algorithms help to reduce variable selection bias
compared to exhaustive search techniques like CART, most of them only apply
to special circumstances. CHAID, for example, requires both the features and
the response to be categorical, although, CHAID was eventually extended to
handle ordered outcomes. CHAID can be used with ordered features if they’re
binned into categories; this induces a bias and defeats the purpose of unbiased
split variable selection in the first place. CTree, on the other hand, provides a
unified framework for unbiased recursive partitioning that’s applicable to data
measured on arbitrary scales (e.g., continuous, nominal categorical, censored,
ordinal, and multivariate data). Before we introduce the details of CTree,
it will be helpful to have a basic understanding of the conditional inference
framework it relies on.
The following section involves a bit of mathematical detail and notation,
mostly around linear algebra, and can be skipped by the uninterested
reader.
3.3 A quick digression into conditional inference

This section offers a quick detour into the essentials of a general framework for
conditional inference procedures commonly known as permutation tests; our
discussion follows closely with Hothorn et al. [2006b]a , which is based on the
theoretical results derived by Strasser and Weber [1999] for a special class of
linear statistics. For a more traditional take on (nonparametric) permutation
tests, see, for example, Davison and Hinkley [1997, Sec. 4.3].
Suppose {(X_i, Y_i)}_{i=1}^N is a random sample from some population of interest,
where X and Y are from sample spaces 𝒳 and 𝒴, respectively, which may be
multivariate (hence the bold notation) and measured at arbitrary scales (e.g.,
ordinal, nominal, continuous, etc.). Our primary interest is in testing the null
hypothesis of independence between X and Y, namely,

    H_0: D(Y|X) = D(Y),

which can be assessed using multivariate linear statistics of the form

    T = vec( ∑_{i=1}^N g(X_i) h(Y_i)^⊤ ) ∈ R^{pq×1},                            (3.1)
where vec() is the matrix vectorization operator that I'll explain in the ex-
amples that follow, g: 𝒳 → R^{p×1} is a transformation function specifying
a possible transformation of the X values, and h: 𝒴 → R^{q×1} is called the
influence function and specifies a possible transformation of the Y values.
Appropriate choices for g() and h() allow us to perform general tests of in-
dependence, including the special cases illustrated in Sections 3.3.1–3.3.2.
a An updated version of this paper is freely available in the “A Lego System for Con-
ditional Inference” vignette accompanying the R package coin [Hothorn et al., 2021c]; to
view, use utils::vignette("LegoCondInf", package = "coin") at the R console.
Conditional on S, the set of all possible permutations of the observed responses,
the influence function h has conditional expectation µ_h = E(h(Y_i)|S) =
(1/N) ∑_{i=1}^N h(Y_i) and conditional covariance

    Σ_h = V(h(Y_i)|S) = (1/N) ∑_{i=1}^N [h(Y_i) − µ_h] [h(Y_i) − µ_h]^⊤.

Strasser and Weber [1999] derived the conditional mean (µ) and variance (σ²)
of T|S, which are

    µ = E(T|S) = vec( ( ∑_{i=1}^N g(X_i) ) µ_h^⊤ ) ∈ R^{pq×1},

    σ² = V(T|S) = N/(N − 1) · Σ_h ⊗ ( ∑_{i=1}^N g(X_i) ⊗ g(X_i)^⊤ )             (3.2)
                − 1/(N − 1) · Σ_h ⊗ ( ∑_{i=1}^N g(X_i) ) ⊗ ( ∑_{i=1}^N g(X_i) )^⊤ ∈ R^{pq×pq},
where ⊗ denotes the Kronecker product. While the equations listed in (3.1)–
(3.2) might seem very complex, they simplify quite a bit in many standard
situations—specific examples are given in Sections 3.3.1–3.3.2.
The next step is to construct a test statistic for testing H0 . In order to do this,
we can standardize the (possibly multivariate) linear statistic using µ and Σ.
Let c () be a function that maps T ∈ Rpq×1 to the real line (i.e., a scalar or
single number). Hothorn et al. [2006b] suggest using a quadratic form or the
maximum of the absolute values of the standardized linear statistic:
    c_q = c_q(T, µ, Σ) = (T − µ)^⊤ Σ⁺ (T − µ),
                                                                                (3.3)
    c_m = c_m(T, µ, Σ) = max | (T − µ) / diag(Σ)^{1/2} |,

where Σ⁺ denotes the Moore–Penrose (generalized) inverse of Σ, and the
division, absolute value, and maximum in c_m are taken elementwise.
The simplest case occurs when X and Y are both univariate continuous vari-
ables; in the univariate case, we drop the bold notation and just write X and
Y. In this case, p = q = 1 and T = ∑_{i=1}^N g(X_i) h(Y_i). If we take g() and h()
to be the identity function (e.g., g(X_i) = X_i), then

    T = ∑_{i=1}^N X_i Y_i,

    µ = ( ∑_{i=1}^N X_i ) Ȳ,

    σ² = S_Y² ∑_{i=1}^N X_i² − S_Y² ( ∑_{i=1}^N X_i )² / N,

where Ȳ = ∑_{i=1}^N Y_i / N and S_Y² = ∑_{i=1}^N (Y_i − Ȳ)² / (N − 1) are the sample
mean and variance of Y, respectively. Since T is univariate, the standardized
test statistics (3.3) are c_m = (T − µ) / √Σ and c_q = c_m²; hence, it makes
no difference which test statistic we use in this case, as the results will be
identical.
Let’s revisit the New York air quality data set (Section 1.4.2) to demonstrate
the required computations in R. Let X = Temp and Y = Ozone be the variables
of interest. To test the null hypothesis of general independence between X and
Y at the α = 0.05 level, I’ll use the quadratic test statistic (cq ), and compute a
p-value for the test using an asymptotic χ2 (1) approximation. I’ll also choose
g () and h () to be the identity function. The first line below removes any rows
with a missing response value, which I’ll be using later.
aq <- airquality[!is.na(airquality$Ozone), ]
N <- nrow(aq) # sample size
gX <- aq$Temp # g(X)
gY <- aq$Ozone # h(Y)
Tstat <- sum(gX * gY) # linear statistic
mu <- sum(gX) * mean(gY)
Sigma <- var(gY) * sum(gX ^ 2) - var(gY) * sum(gX) ^ 2 / N
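The standardized test statistic and its asymptotic p-value then follow from the quantities just computed; a minimal sketch of that final step:

chisq <- (Tstat - mu) ^ 2 / Sigma  # quadratic test statistic c_q
c("chisq" = chisq, "pval" = 1 - pchisq(chisq, df = 1))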
Here we would reject the null hypothesis at the α = 0.05 level (p < 0.001)
and conclude that there is some degree of association between Temp and
Ozone.
For comparison, we can use the independence_test() function from
package coin, which provides a flexible implementation of the con-
ditional inference procedures described in Hothorn et al. [2006b]; see
?coin::independence_test for details. This is demonstrated in the code
snippet below:
library(coin)
independence_test(Ozone ~ Temp, data = aq, teststat = "quadratic")
#>
#> Asymptotic General Independence Test
#>
#> data: Ozone by Temp
#> chi-squared = 56, df = 1, p-value = 7e-14
Happily, we obtain the exact same results. Note that cm and cq will only differ
when the linear statistic (3.1) is multivariate.
b Here we assume that the q categories of X define the rows of the contingency table, but
in general, it does not matter.
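The contingency table printed below, which also serves as the multivariate linear statistic in this setting, can be produced with something along these lines (the name ctab is an assumption matching the code that follows):

ctab <- table(mushroom$odor, mushroom$Edibility)  # 9 x 2 contingency table
ctab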
#> Edibility
#> odor Edible Poison
#> almond 400 0
#> anise 400 0
#> creosote 0 192
#> fishy 0 576
#> foul 0 2160
#> musty 0 36
#> none 3408 120
#> pungent 0 256
#> spicy 0 576
(Tstat <- as.vector(ctab)) # multivariate linear statistic
#> [1] 400 400 0 0 0 0 3408 0 0 0
#> [11] 0 192 576 2160 36 120 256 576
For example, for a mushroom with odor = foul (the fifth category when sorted
alphabetically) and Edibility = Poison, the dummy encodings are

    g(X_i) = (0, 0, 0, 0, 1, 0, 0, 0, 0)^⊤   and   h(Y_i) = (0, 1)^⊤.
Again, we can compare the results with the output from coin's
independence_test() function; once again, the results are equivalent.
independence_test(Edibility ~ odor, data = mushroom,
teststat = "quadratic")
#>
#> Asymptotic General Independence Test
#>
#> data: Edibility by
#> odor (almond, anise, creosote, fishy, foul, musty, none, pungent...
#> chi-squared = 7659, df = 8, p-value <2e-16
c Here, I use model.matrix(~ variable.name - 1) to suppress the intercept—a column
of all ones—and ensure that each category of variable.name gets dummy encoded.
In the previous examples, we used the quadratic form of the test statistic
in (3.3) and its asymptotic chi-square distribution, but what about the maxi-
mally selected test statistic (c_m) in (3.3)? When X and Y are both univariate
continuous, or binary variables encoded as 0/1, then the choice between cq and
cm makes no difference, since cq = c2m in this case. However, if X and/or Y
are multivariate (e.g., when X and/or Y are multi-level categorical variables),
then the two statistics can lead to different, although usually similar results.
Some guidance on when one test statistic may be more useful than the other is
given in Hothorn et al. [2006b]. For example, if X and Y are both categorical,
then working with cm and the standardized contingency table can be useful
in gaining insight into the association structure between X and Y . For the
general test of independence, it often doesn’t matter which form of the test
statistic you use. As we’ll see in the next few sections, the CTree algorithm
often defaults to using the quadratic test statistic (i.e., cq ) from (3.3).
1) Test the null hypothesis of independence between the response Y and each
of the candidate predictors X_1, ..., X_m (e.g., using adjusted p-values), and
select the predictor X_j most strongly associated with Y (e.g., the one with
the smallest adjusted p-value).
2) Use X_j, the partitioning variable selected in step 1), to partition the data
into two disjoint subsets (or child nodes), A_L and A_R. For each possible
split S, a standardized test statistic (3.3) is computed, and the partition
associated with the largest test statistic is used to partition the data into
two child nodes.
3) Repeat steps 1)–2) in a recursive fashion on the resulting child nodes until
the global hypothesis in step 1) cannot be rejected at a prespecified α
level.
[1995] and Wright [1992]—the latter discusses the use of adjusted p-values.
To illustrate, let’s write a simple function, called gi.test(), that uses the
conditional inference procedure described in Section 3.3 to test the null hy-
pothesis of general independence between two variables X and Y .e To keep it
simple, this function applies only to univariate continuous variables, and com-
putes an approximate p-value assuming an asymptotic χ2 (1) distribution (see
Section 3.3.1 for details)—although, it would not be too difficult to modify
gi.test() to return approximate p-values using the permutation distribution
instead. The arguments g and h allow for suitable transformations of the vari-
ables x and y, respectively; for example, if the relationship between X and
Y is monotonic, but not necessarily linear, or if we suspect outliers, then we
might consider converting X and/or Y to ranks (e.g., g = rank)—converting
both X and Y to ranks is similar in spirit to conducting a correlation test
based on Spearman’s ρ. Both arguments default to R’s built-in identity()
function, which has no effect on the given values.
gi.test <- function(x, y, g = identity, h = identity) {
xy <- na.omit(cbind(x, y)) # only retain complete cases
gx <- g(xy[, 1L]) # transformation function applied to x
hy <- h(xy[, 2L]) # influence function applied to y
lin <- sum(gx * hy) # linear statistic
mu <- sum(gx) * mean(hy) # conditional expectation
sigma <- var(hy) * sum(gx ^ 2) - # conditional covariance
var(hy) * sum(gx) ^ 2 / length(hy)
c.quad <- ((lin - mu) / sqrt(sigma)) ^ 2 # quadratic test statistic
pval <- 1 - pchisq(c.quad, df = 1) # p-value
c("chisq" = c.quad, "pval" = pval) # return results
}
Continuing with the New York air quality example, let’s see which variable, if
any, is selected to split the root node. Following convention, I'll use α = 0.05
as the significance threshold for the test of the global null hypothesis in step 1)
of Algorithm 3.1.
The following code chunk applies the previously defined gi.test() function
to test the null hypothesis of general independence between each of the five
features and the response—if you skipped Section 3.3, then you can think of
this as a simple test of association that defaults to using a test statistic whose
asymptotic distribution (i.e., the approximate distribution for sufficiently large
N ) is χ2 (1). Note that the p.adjust() function mentioned earlier is used
to adjust the resulting p-values to account for multiple tests using a simple
Bonferroni adjustment:
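A sketch of such a code chunk (the object names are assumptions):

xnames <- setdiff(names(aq), "Ozone")   # the five candidate features
res <- sapply(xnames, FUN = function(x) gi.test(aq[[x]], y = aq$Ozone))
rbind(res, "adj.pval" = p.adjust(res["pval", ], method = "bonferroni"))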
e We could also use the much more flexible independence_test() function from package
coin, but writing your own function can help solidify your basic understanding of how the
procedure actually works.
In this example, the predictor associated with the smallest adjusted p-value
is Temp, and since its adjusted p-value (≈ 3.469 × 10⁻¹³) is less than α = 0.05,
Temp is the first variable that will be used to partition the data. The next
step is to determine the
optimal split point of Temp to use when partitioning the data (step 2) of
Algorithm 3.1), which will be discussed in Section 3.4.2.
Let’s try a binary classification problem as well. If you followed Section 3.3
and paid close attention, then you might have figured out that our gi.test()
function should also work for 0/1 encoded binary variables.
Using the Swiss banknote data (Section 1.4.1), let’s see which, if any, of the
available features can be used to effectively partition the root node—recall
that all the features are numeric and that the binary response (y) is already
coded as 0 (for genuine banknotes) and 1 (for counterfeit banknotes):
bn <- treemisc::banknote # start with the root node
xnames <- setdiff(names(bn), "y") # feature names
res <- sapply(xnames, FUN = function(x) { # test each feature
gi.test(bn[[x]], y = bn[["y"]])
})
t(res) # print transpose of results (nicer printing)
#> chisq pval
#> length 7.52 6.11e-03
#> left 48.89 2.71e-12
#> right 68.51 1.11e-16
#> bottom 118.61 0.00e+00
#> top 72.22 0.00e+00
Hopefully, by this point, you have a basic understanding of how CTree selects
the splitting variable in step 1) of Algorithm 3.1. Let’s now turn our attention
to finding the optimal split condition for the selected splitter.
Once a splitting variable has been selected, the next step is to find the optimal
split point. CTree uses binary splits like those discussed for CART in Chap-
ter 2; in particular, continuous and ordinal variables produce binary splits of
the form x ≤ c vs. x > c, where c is in the domain of x, and categorical
variables produce binary splits of the form x ∈ A vs. x ∉ A, where A is a
non-empty subset of the categories of x.
Continuing with the New York air quality example, let’s find the optimal split
point for Temp, the feature selected previously in step 1) of Algorithm 3.1
(p. 122), to partition the root node. The code chunk below computes the test
statistics for testing H0 : D (Ozone|Temp ≤ c) = D (Ozone), for each unique
value c of Temp; the results are plotted in Figure 3.2.
set.seed(912) # for reproducibility
xvals <- sort(unique(aq$Temp)) # potential cut points
splits <- matrix(0, nrow = length(xvals), ncol = 2)
colnames(splits) <- c("cutoff", "chisq")
for (i in seq_along(xvals)) {
x <- ifelse(aq$Temp <= xvals[i], 0, 1) # binary indicator
y <- aq$Ozone
# Ignore pathological splits or splits that are too small
if (length(table(x)) < 2 || any(table(x) < 7)) {
res <- NA
} else {
res <- gi.test(x, y)["chisq"]
}
splits[i, ] <- c(xvals[i], res) # store cutpoint and test statistic
}
splits <- na.omit(splits)
splits[which.max(splits[, "chisq"]), ]
#> cutoff chisq
#> 82.0 55.3
# Plot the test statistic for each cutoff (Figure 3.2)
plot(splits, type = "b", pch = 19, col = 2, las = 1,
xlab = "Temperature split value (degrees Fahrenheit)",
ylab = "Test statistic")
abline(v = 82, lty = "dashed")
FIGURE 3.2: Test statistics from gi.test() comparing the two groups of
Ozone values for every binary partition using Temp. A dashed line shows the
optimal split point c = 82.
From the results, we see that the maximum of all the test statistics (c_q ≈ 55.3
in the printout above) is associated with the split point c = 82, giving us our
first split in the tree (i.e., Temp <= 82).
Following Algorithm 3.1, we would continue splitting each of the resulting
child nodes until the global null hypothesis in step 1) of Algorithm 3.1 cannot
be rejected at the specified α level. For example, applying the previous code
to the resulting left child node (Temp ≤ 82) would result in a further partition
using Wind ≤ 6.9. I’ll confirm these calculations using specific CTree software
in Section 3.5.1.
3.4.3 Pruning
Unlike CART and other exhaustive search procedures, CTree uses statistical
stopping criteria (e.g., Bonferroni-adjusted p-values) to determine tree size and
does not require pruning—although, pruning can still be beneficial in certain
circumstances (see Section 3.4.5). That’s not to say that CTree doesn’t overfit.
As we’ll see in Sections 3.4.5 and 3.5, the threshold α has a direct impact on
the size and therefore complexity of the tree and is often treated as a tuning
parameter.
Similar to CART, CTree can use the idea of surrogates to handle missing val-
ues, but it is not the default in current implementations of CTree. By default,
observations which can’t be classified to a child node because of missing val-
ues are either 1) randomly assigned according to the node distribution (as in
the newer partykit package [Hothorn and Zeileis, 2021]), or 2) go with the
majority (as in the older party package [Hothorn et al., 2021b]).
Observations with missing values in predictor X are simply ignored when com-
puting the associated test statistics during step 1) of Algorithm 3.1 (p. 122).
Similarly, missing values associated with the splitting variable are also ignored
when computing the test statistics in step 2). Once a split has been found,
surrogates can be constructed using an approach similar to the one described
in Section 2.7 for CART, in particular, creating a binary decision stump using
the binary split in question as the response and trying to find the best splits
associated with it using Algorithm 3.1 (p. 122).
The general test of independence used in CTree will only have high power for
certain directions of deviation from independence, which depend on the choice
of g() and h(). A useful guide
for selecting g () and h () can be found in Table 4 of coin’s “Implement-
ing a Class of Permutation Tests: The coin Package” vignette; to view, use
vignette("Implementation", package = "coin") at the R console.
In the presence of outliers, the general test of independence discussed in Sec-
tion 3.3 would be more powerful at a given sample size and α if g () and
h () converted X and Y to ranks (because ranks are more robust to outlying
observations).
To illustrate, let’s run a quick Monte Carlo experiment. Suppose X = X has
a standard normal distribution and that Y = Y is equal to X with a tad bit of
noise: Y = X + ε, where ε ∼ N(0, σ = 0.1). Figure 3.3 shows a scatterplot for
two random samples of size N = 100 generated from X and Y . The left panel
in Figure 3.3 shows a clear association between X and Y . The right panel
shows a scatterplot of the same sample, but with three of the observations
replaced by outliers. Even with the outliers, there is still a clear relationship
between X and Y .
FIGURE 3.3: Scatterplots of two linearly related variables. Left: original sam-
ple. Right: original sample with three observations replaced by outliers.
The code chunk below applies the gi.test() function, using both the identity
and rank transformations, to 1,000 simulated samples from X and Y at each
of several sample sizes ranging from 5 to 100; note that each sample includes
three outlying Y values. For each sample size, the approximate power of each
test is computed as the proportion of the 1,000 samples for which the null
hypothesis was rejected at the α = 0.05 level. The results are plotted in
Figure 3.4. Clearly, using
ranks provides a more powerful test across the range of sample sizes in this
example.
set.seed(2142) # for reproducibility
N <- seq(from = 5, to = 100, by = 5) # range of sample sizes
res <- sapply(N, FUN = function(n) {
pvals <- replicate(1000, expr = { # simulate 1,000 p-values
x <- rnorm(n, mean = 0) # from each test
y <- x + rnorm(length(x), mean = 0, sd = 0.1)
y[1:3] <- 10 # insert outliers
test1 <- gi.test(x, y) # no transformations
test2 <- gi.test(x, y, g = rank, h = rank) # convert to ranks
c(test1["pval"], test2["pval"]) # extract p-values
})
apply(pvals, MARGIN = 1, FUN = function(x) mean(x < 0.05))
})
FIGURE 3.4: Power vs. sample size for the general test of independence be-
tween two univariate continuous variables, X and Y , using the conditional
inference procedure outlined in Section 3.3. The solid black curve corresponds
to using ranks, whereas the dashed red curve corresponds to the identity (i.e.,
no transformation).
For univariate regression and classification trees, fitted values and predictions
are obtained in the same manner as they are in CART. For example, the
within-terminal node class proportions can be used for class probability esti-
mates. For regression, the within-terminal node sample mean of the response
values can be used as the fitted values or for predicting new observations.
However, CTree is quite flexible and can handle many other situations be-
yond simple classification and regression. With censored outcomes, for exam-
ple, each terminal node can be summarized using the Kaplan-Meier estimator;
see Section 3.5.3. The median survival time can be used for making predic-
tions.
Unlike CART, CTree cannot explicitly take into account prior class probabil-
ities, nor can it account for unequal misclassification costs. However, it can
assign increased weight to specific observations (for brevity, I omitted the case
weights from the formulas in Section 3.3, so see the references listed there for
full details). For instance, we can assign higher weights to observations associ-
ated with higher misclassification costs. There is a drawback to this approach,
however. Increasing the case weights essentially corresponds to increasing the
sample size N in the statistical tests used in Algorithm 3.1, which will re-
sult in smaller p-values. Consequently, the resulting tree can be much larger
since more splits will be significant. Decreasing α and/or employing the tuning
strategy discussed in Section 3.4.5 can help, but it can be a difficult balancing
act.
Conditional inference trees are available in the party package via the
ctree() function.f However, the partykit package contains an improved re-
implementation of ctree() and is recommended for fitting individual condi-
tional inference trees in general. The ctree() function in partykit is much
more modular and written almost entirely in R; this flexibility seems to
come at a price, however, as partykit::ctree() can be much slower than
party::ctree() (the latter of which is implemented in C). It’s also worth
noting that partykit is quite extensible and allows you to coerce tree models
from different sources into a standard format that partykit can work with
(e.g., importing trees saved in the predictive model markup language (PMML)
format [Data Mining Group, 2014]). A good example of this is the R package
C50 [Kuhn and Quinlan, 2021], which provides an interface to C5.0 (C5.0,
which evolved out of C4.5, is discussed in the online complements).
The R package boot [Canty and Ripley, 2021] can be used to carry out general
permutation tests, as well as more general bootstrap procedures; for permu-
tation tests, see the example code in Davison and Hinkley [1997, Sec. 4.3].
f The package name is apparently a play on the words “partition y”.
In earlier sections, we used our own gi.test() function to split the root node
of the airquality data set. Using conditional inference, we found that the
first best split occurred with Temp <= 82. Now, let’s use partykit to apply
Algorithm 3.1, and recursively split the airquality data set until we can no
longer reject the null hypothesis of general independence between Ozone and
any of the five numeric features at the α = 0.05 level.
You can use ctree_control() to specify a number of parameters governing
the CTree algorithm; in the code chunk below, we stick with the defaults; see
?partykit::ctree_control for a description of all the parameters that can
be set. In this chapter, we used the quadratic form of the test statistic, c_quad, for steps 1)–2) of Algorithm 3.1, which is the default in both party and partykit (teststat = "quad"). We also stick with the default Bonferroni-adjusted p-
values. To specify α, set either the alpha or mincriterion arguments in
ctree_control(), where the value of mincriterion corresponds to 1 − α
(only the mincriterion argument is available in party); both packages use a
default significance level of α = 0.05. In party’s implementation of ctree(),
the transformation functions g(·) and h(·) can be specified via the xtrafo and
ytrafo arguments, respectively; in partykit’s implementation, only ytrafo
is available.
Next, I call ctree() to recursively partition the data and plot the resulting
tree diagram using partykit’s built-in plot() method (see Figure 3.5):
library(partykit)
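The code that actually fits the tree does not appear in this excerpt; a minimal sketch consistent with the object names used later in this section (aq and aq.cit) might look like the following (this is an assumption, not necessarily the book’s exact code):

# Remove rows with a missing response, fit a default CTree, and plot it
aq <- airquality[!is.na(airquality$Ozone), ]
aq.cit <- ctree(Ozone ~ ., data = aq)
plot(aq.cit)  # Figure 3.5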
Note that I again removed the rows with missing response values. The fitted
tree contains four splits (i.e., five terminal nodes) on only two predictors:
Temp and Wind. The plot() method for ctree() objects is quite flexible, and
I encourage you to read the documentation in ?partykit::plot. By default,
the terminal nodes are summarized using an appropriate plot that depends on
the scale of the response variable—in this case, boxplots. The p-values from
step 1) of Algorithm 3.1 are printed in each node, along with the selected
splitting variable and the node number:
In partykit, we can print the test statistics and adjusted p-values associ-
ated with any node using the sctest() function from package strucchange
[Zeileis et al., 2019], which is illustrated below; the 1 specifies the node of in-
terest, which, according to the printed output and tree diagram, corresponds
to the root node. These correspond to the tests carried out in step 1) of
Algorithm 3.1. The results match our earlier computations using
gi.test() and p.adjust(), woot! As far as I’m aware, you cannot currently
obtain the test statistics from step 2) in partykit, although this is possible
in party’s implementation of ctree(), which I’ll demonstrate next.
FIGURE 3.5: A default CTree fit to the New York air quality measurements
data set.
strucchange::sctest(aq.cit, 1)
#> Solar.R Wind Temp Month Day
#> statistic 13.34761 4.16e+01 5.61e+01 3.113 0.0201
#> p.value 0.00129 5.56e-10 3.47e-13 0.333 1.0000
When fitting conditional inference trees with party, the nodes() function can
be used to extract a list of nodes of a tree; the where argument specifies the
node ID (i.e., the node numbers used to label the nodes in the associated tree
diagram). Below, I’ll refit the same tree using party::ctree(), extract the
split associated with the root node, and plot the corresponding test statistics
comparing the different cut points of the split variable (in this case, Temp).
Note that party::ctree() only uses the maximally selected statistic (c_max) for step 2) of Algorithm 3.1g, but recall that in the univariate case, c_max² = c_quad, so I’ll square them and compare them to the results I plotted earlier in Figure 3.2 (p. 127). As they should, the results from party::ctree(), which
are displayed in Figure 3.6, match with what I obtained earlier using my
gi.test() function.
aq.cit2 <- party::ctree(Ozone ~ ., data = aq) # refit the same tree
root <- party::nodes(aq.cit2, where = 1)[[1L]] # extract root node
split.stats <- root$psplit$splitstatistic # split statistics
cutpoints <- aq[[root$psplit$variableName]][split.stats > 0]
cq <- split.stats[split.stats > 0] ^ 2
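The plotting code for Figure 3.6 is not shown in this excerpt; a rough base R sketch using the cutpoints and cq objects computed above (an illustration only, not the book’s exact code) is:

# Plot the squared split statistics against the candidate cut points
ord <- order(cutpoints)  # sort by temperature cut point
plot(cutpoints[ord], cq[ord], type = "b", pch = 19,
     xlab = "Temperature split value (degrees Fahrenheit)",
     ylab = "Test statistic")
abline(v = 82, lty = 2)  # optimal split point (c = 82)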
g In contrast, partykit lets you choose which test statistic to use in step 2) of Algorithm 3.1, and defaults to the quadratic form c_quad we used earlier in gi.test().
FIGURE 3.6: Test statistics from party::ctree() comparing the two groups
of Ozone values for every binary partition using Temp. A dashed line shows the
optimal split point c = 82. Compare these results to those from Figure 3.2.
You can coerce decision trees produced by various other implementations into
"party" objects using the partykit::as.party() function. This means, for
example, we can fit a decision tree using rpart, and visualize it using par-
tykit’s plot() method.
To illustrate, the next code chunk fits an rpart tree to the same aq data,
coerces it to a "party" object, and plots the associated tree diagram. Here,
I’ll set the complexity parameter cp to zero (i.e., no penalty on the size of the
tree) and use the default 10-fold cross-validation along with the 1-SE rule to
prune the tree (Section 2.5.2.1). In this example, CART produced a decision
stump (i.e., a tree with only a single split).
set.seed(1525) # for reproducibility
aq.cart <- rpart::rpart(Ozone ~ ., data = aq, cp = 0)
aq.cart.pruned <- treemisc::prune_se(aq.cart, se = 1) # 1-SE rule
plot(partykit::as.party(aq.cart.pruned))
FIGURE 3.7: CART-like decision tree fit to the New York air quality mea-
surements data set. The tree was pruned according to the 1-SE rule discussed
in Section 2.5.2.1.
Next, I’ll fit a (default) conditional inference tree to the red wine quality data (reds) via partykit. Unfortunately, the tree diagram is too large to print neatly on this page, so I’ll show a printout of the fitted tree instead:
(reds.cit <- ctree(quality ~ ., data = reds))
#>
#> Model formula:
#> quality ~ fixed.acidity + volatile.acidity + citric.acid + residua...
#> chlorides + free.sulfur.dioxide + total.sulfur.dioxide +
#> density + pH + sulphates + alcohol
#>
#> Fitted party:
#> [1] root
#> | [2] alcohol <= 10.5
#> | | [3] volatile.acidity <= 0.3
#> | | | [4] sulphates <= 0.7: 5 (n = 27, err = 48%)
#> | | | [5] sulphates > 0.7: 6 (n = 58, err = 41%)
#> | | [6] volatile.acidity > 0.3
#> | | | [7] volatile.acidity <= 0.7
#> | | | | [8] alcohol <= 9.8
#> | | | | | [9] total.sulfur.dioxide <= 39: 5 (n = 171, er...
#> | | | | | [10] total.sulfur.dioxide > 39
#> | | | | | | [11] pH <= 3.4: 5 (n = 205, err = 22%)
#> | | | | | | [12] pH > 3.4: 5 (n = 53, err = 42%)
#> | | | | [13] alcohol > 9.8: 6 (n = 228, err = 54%)
#> | | | [14] volatile.acidity > 0.7
#> | | | | [15] fixed.acidity <= 8.5: 5 (n = 172, err = 26%)
#> | | | | [16] fixed.acidity > 8.5: 5 (n = 69, err = 35%)
#> | [17] alcohol > 10.5
#> | | [18] volatile.acidity <= 0.9
#> | | | [19] sulphates <= 0.6
#> | | | | [20] volatile.acidity <= 0.3: 6 (n = 33, err = 45%)
#> | | | | [21] volatile.acidity > 0.3: 6 (n = 207, err = 45%)
#> | | | [22] sulphates > 0.6
#> | | | | [23] alcohol <= 11.5
#> | | | | | [24] total.sulfur.dioxide <= 49
#> | | | | | | [25] volatile.acidity <= 0.4: 7 (n = 72, e...
To see how well the model performs (on the learning sample), we can cross-
classify the observed quality ratings with the fitted values (i.e., the prediction
from the learning sample):
p <- predict(reds.cit, newdata = reds) # fitted values
table(predicted = p, observed = reds$quality) # contingency table
#> observed
#> predicted 3 4 5 6 7 8
#> 3 0 0 0 0 0 0
#> 4 0 0 0 0 0 0
#> 5 9 33 483 194 5 0
#> 6 1 20 191 361 85 3
#> 7 0 0 7 83 109 15
#> 8 0 0 0 0 0 0
For example, of all the red wines with an observed quality rating of 7, 5 were
predicted to have a quality rating of 5, 85 were predicted to have a quality
rating of 6, and the rest (109) were predicted to have a quality rating of 7.
So which variables seem to be the most predictive of the wine quality
rating? At first glance, alcohol by volume (alcohol) and volatile acidity
(volatile.acidity) seem to be important predictors, as they appear at the
top of the tree and are used multiple times to partition the data. We can
quantify this in CTree using partykit’s varimp() function. This function
computes importance using a permutation-based approach akin to the proce-
dure discussed in Section 6.1.1. For now, just think of the returned importance
scores as an estimate of the decrease in performance as a result of removing
the effect of the predictor in question. By default, performance is measured
by the negative log-likelihood.h
set.seed(2023) # for reproducibility
(vi <- varimp(reds.cit, nperm = 100)) # variable importance scores
dotchart(vi, pch = 19, xlab = "Variable importance") # Figure 3.8
h For ordinal outcomes in CTree, the log-likelihood is defined as (1/N) Σ_{i=1}^{N} log(p_i), where p_i is the proportion of observations in the same node as case i sharing the same class.
FIGURE 3.8: Variable importance plot for the red wine quality conditional
inference tree.
In this example, I’ll revisit the PBC data described in Section 1.4.9. A tree-
based analysis of the data was briefly discussed in Ahn and Loh [1994]. Below
we load the survival package and prepare the data:
i As always, the code to reproduce this plot is available on the book website.
FIGURE 3.9: Partial dependence of predicted quality on alcohol, volatile.acidity, and sulphates.
library(survival)
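The data preparation step itself does not appear in this excerpt; one plausible sketch is shown below. Both the subsetting to the randomized subjects and the 0/1 recoding of status are assumptions here—the actual steps follow Section 1.4.9.

# Hypothetical preparation of the PBC data (see Section 1.4.9 for the details)
pbc2 <- pbc[!is.na(pbc$trt), -1]             # keep randomized subjects; drop id (assumption)
pbc2$status <- as.integer(pbc2$status == 2)  # 1 = death, 0 = otherwise (assumption)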
Using α = 0.05, we would reject the null hypothesis (p < 0.001) and conclude
that the level of serum bilirubin is associated with survival rate. But this
doesn’t tell us much beyond that. Do subjects with higher levels of serum
bilirubin tend to survive longer? To answer questions like this, we can use
CTree to recursively partition the data using conditional inference-based tests
of independence between each feature and the log-rank scores:
(pbc2.cit <- partykit::ctree(Surv(time, status) ~ ., data = pbc2))
#>
#> Model formula:
#> Surv(time, status) ~ trt + age + sex + ascites + hepato + spiders +
#> edema + bili + chol + albumin + copper + alk.phos + ast +
#> trig + platelet + protime + stage
#>
#> Fitted party:
#> [1] root
#> | [2] bili <= 1.9
#> | | [3] edema in 0
#> | | | [4] stage <= 2: Inf (n = 61)
#> | | | [5] stage > 2: 4191 (n = 104)
#> | | [6] edema in 0.5, 1: Inf (n = 16)
#> | [7] bili > 1.9
#> | | [8] protime <= 11.2
#> | | | [9] age <= 44.5
#> | | | | [10] bili <= 5.6: 3839 (n = 29)
#> | | | | [11] bili > 5.6: 1080 (n = 7)
Notice how treatment group (drug) was not selected as a splitting variable
at any node. This is not surprising since Fleming and Harrington [1991, p.
2] concluded that there was no practically significant difference between the
survival times of those taking the placebo and those taking the drug.
We can also display the tree diagram using the plot() method; the results are
displayed in Figure 3.10. For censored outcomes, the Kaplan-Meier estimate of
the survival curve is displayed in each node. The tree diagram in Figure 3.10
makes it clear that subjects with higher serum bilirubin levels tended to have
shorter survival times. What other conclusions can you draw from the tree
diagram?
plot(pbc2.cit) # Figure 3.10
It often “happens that even recent papers do not refer to work carried out from 2000 onward, therefore ignoring more than a decade of active development that may be highly relevant.” Another important factor is software availability. Many
tree algorithms do not have easy to use opensource implementations. For
example, of the 99 tree algorithms considered by Rusch and Zeileis (see their
discussion at the end of Loh [2014]), roughly one-third had free opensource
implementations available (including CART and CTree). CART-like decision
trees are also broadly implemented across a variety of opensource platforms
(see Section 2.9). CTree, on the other hand, is only available in R—as far as
I’m aware.
Should you be concerned about biased variable selection when using CART-
like decision trees? Certainly. However, as pointed out in Loh’s rejoinder in
Loh [2014], “...selection bias may not cause serious harm if a tree model is
used for prediction but not interpretation, in some situations.” While biased
variable selection can lead to more spurious splits on irrelevant features, if the
sample size is large and there are not too many such variables, pruning with
cross-validation is often effective at removing them.
FIGURE 3.10: Conditional inference tree fit to a subset of the Mayo Clinic
PBC data. The terminal nodes are summarized using Kaplan-Meier estimates
of the survival function. The tree diagram highlights potential risk factors
associated with different survival distributions.
4
The hitchhiker’s GUIDE to modern decision trees
Of all the trees we could have hit, we had to get one that hits
back.
J.K. Rowling
Harry Potter and the Chamber of Secrets
4.1 Introduction
GUIDE evolved from earlier tree growing procedures, starting with the fast
and accurate classification tree (FACT) algorithm [Loh and Vanichsetakul,
1988]. FACT was novel at the time for its use of linear discriminant analysis
(LDA) to find splits based on linear combinations of two predictors. FACT
only applies to classification problems, and any node based on linear splits
is partitioned into as many child nodes as there are class labels (i.e., FACT
can use multiway splits). Split variable selection for continuous variables is
based on comparing ANOVA-based F -statistics; LDA is applied to the variable
with the largest F -statistic to find the optimal split point (i.e., x ≤ c vs.
x > c, where c is in the domain of x) to partition the data. Categorical
predictors are converted to ordinal variables by using LDA to project their
dummy encoded vectors onto the largest discriminant coordinate (also called
the canonical variate in canonical analysis); the final splits are expressed back
in the form x ∈ S, where S is a subset of the categories of x. Since FACT
depends on ANOVA F -tests for split variable selection, it is only unbiased if
all the predictors are ordered (i.e., it is biased towards nominal categorical
variables) [Loh, 2014].
The quick, unbiased, and efficient statistical tree (QUEST) procedure [Loh
and Shih, 1997] improves upon the bias in FACT by using chi-square tests for
categorical variables (i.e., by forming contingency tables between each cate-
gorical variable and the response). Like CART, QUEST only permits binary
splits. Let Jt be the number of response categories in any particular node t.
Whenever Jt > 2, QUEST produces binary splits by merging the Jt classes
into two super classes before applying the F - and chi-square tests to select a
splitting variable. The optimal split point for ordered predictors is found us-
ing either an exhaustive search (like in Chapter 2) or quadratic discriminant analysis (QDA). For categorical splitting variables, the optimal split point is found in the same way after converting the dummy encoded vectors to the largest discriminant coordinate as in FACT.
Kim and Loh [2001] introduced the classification rule with unbiased interaction
selection and estimation (CRUISE) algorithm, as a successor to QUEST. In
contrast to QUEST, however, CRUISE allows multiway splits depending on
the number of response categories in a particular node. Also, while QUEST
uses F -tests for ordered variables and chi-square tests for nominal, CRUISE
uses chi-square tests for both after discretizing the ordered variables. For split
variable selection, CRUISE uses a two-step procedure involving testing both
main effects and two-way interactions at each node.
CRUISE was later succeeded by GUIDE [Loh, 2009]a , which improves upon
both QUEST and CRUISE by retaining their strengths and fixing their main
weaknesses. One of the drawbacks of CRUISE is that the number of interaction
tests greatly outnumbers the number of main effect tests; for example, with k
features, there are k main effect tests and k(k − 1)/2 pairwise interaction tests.
Since most of the p-values come from the interaction tests, CRUISE is biased
towards selecting split variables with potentially weak main effects, relative
to the other predictors. GUIDE, on the other hand, restricts the number of
interaction tests to only those predictors whose main effects are significant
based on Bonferroni-adjusted p-values.
The next two sections cover many of the details associated with GUIDE for
standard regression and classification, respectively; but note that GUIDE has
been extended to handle several other situations as well (e.g., censored out-
comes and longitudinal data). In general, GUIDE uses a two-step procedure
when selecting the splitting variables. Consequently, GUIDE involves many
more steps compared to CART (Chapter 2) and CTree (Chapter 3). The
individual steps themselves are not complicated (most of them involve trans-
formations of continuous features and what to do when interaction tests are
significant), but for brevity, I’m only going to cover the nitty-gritty details,
while pointing to useful references along the way.
Note that the official GUIDE software—which is freely available, but not
open source—has evolved quite a lot over the years; the GUIDE program is
discussed briefly in Section 4.9. Consequently, some of the fine details on the
various GUIDE algorithms may have changed since their original publications.
Any important updates are likely to be found in the revision history for the
official GUIDE software:
http://pages.stat.wisc.edu/~loh/treeprogs/guide/history.txt.
If you’re interested in going deeper on GUIDE for regression and classifica-
tion, I encourage you to read Loh [2002] (the official reference to GUIDE for
regression), Loh [2009] (the official reference to GUIDE for classification), Loh
[2011] (an updated overview with some comparisons to other tree algorithms),
and Loh [2012] (variable selection and importance). The current GUIDE soft-
ware manual is also useful and can be obtained from the GUIDE website:
http://pages.stat.wisc.edu/~loh/guide.html.
a GUIDE had already been introduced for regression problems in Loh [2002].
(Discretizing the predictors is only used here for split variable selection; full predictor information is used in selecting the split point and making predictions.)
The details of GUIDE’s split variable selection are somewhat involved, but the basic steps (skipping the two-way interaction tests) are outlined in Algorithm 4.1 (p. 151); GUIDE’s interaction tests are discussed briefly in Section 4.2.2.
Algorithm 4.1 Simplified version of the original GUIDE algorithm for regres-
sion. Note that some of the details may have changed as the official software
has continued to evolve over the years.
1) Start with all of the training data in the current node.
2) Obtain the signed residuals from a constant fit to the data (e.g., the mean response).
3) Discretize each numeric feature into four categories at the within-node sample quartiles; categorical features are left as is.
4) Form a contingency table between the signed residuals and each (discretized) feature, and compute the associated chi-square test of independence (these are often referred to as curvature tests). Select the feature x* with the smallest Bonferroni-adjusted p-value to split the node.
5) Use an exhaustive search to find the best split on x* yielding the greatest reduction in node SSE (see Section 2.3). By default, GUIDE uses univariate splits similar to CART and CTree. In particular, if x* is unordered, splits are of the form x* ∈ S, where S is a subset of the categories of x*. If x* is ordered, splits have the form x* ≤ c, where c is a midpoint in the observed range of x*; for speed, GUIDE will optionally use the within-node sample median of x* for the cutoff c.
6) Recursively apply steps 2)–5) on all the resulting child nodes until all
nodes are pure or suitable stopping criteria are met (e.g., the maximum
number of allowable splits is reached).
A few comments regarding step 4) of Algorithm 4.1 are in order. First, any rows or columns with zero margin totals are removed. Second, to avoid difficulties in computing very small p-valuesc and to account for the fact that the degrees of freedom are not fixed across the chi-square tests, GUIDE sometimes uses a modification of the Wilson-Hilferty transformation [Wilson and Hilferty, 1931] to ensure all the test statistics approximately correspond to a chi-square distribution with one degree of freedom. In particular, if x denotes a chi-square statistic with ν degrees of freedom, define
\[
w_1 = \left(\sqrt{2x} - \sqrt{2\nu - 1} + 1\right)^2 / 2,
\]
\[
w_2 = \max\left[0, \left\{\tfrac{7}{9} + \sqrt{\nu}\left((x/\nu)^{1/3} - 1 + \tfrac{2}{9\nu}\right)\right\}^3\right],
\]
and
\[
w = \begin{cases}
w_2, & \text{if } x < \nu + 10\sqrt{2\nu} \\
(w_1 + w_2)/2, & \text{if } x \ge \nu + 10\sqrt{2\nu} \text{ and } w_2 < x \\
w_1, & \text{otherwise};
\end{cases}
\]
then it follows that Pr(χ²_ν > x) ≈ Pr(χ²_1 > w); this transformation is implemented in the official GUIDE software.
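As a sanity check on the formulas above, here is a small R sketch of the transformation (an illustration only, not GUIDE’s internal code):

# Convert a chi-square statistic `x` with `nu` degrees of freedom to an
# approximately equivalent one-df chi-square statistic `w`
wilson.hilferty <- function(x, nu) {
  w1 <- (sqrt(2 * x) - sqrt(2 * nu - 1) + 1) ^ 2 / 2
  w2 <- max(0, (7 / 9 + sqrt(nu) * ((x / nu) ^ (1 / 3) - 1 + 2 / (9 * nu))) ^ 3)
  if (x < nu + 10 * sqrt(2 * nu)) {
    w2
  } else if (w2 < x) {
    (w1 + w2) / 2
  } else {
    w1
  }
}

# The two upper-tail probabilities below should be close
pchisq(30, df = 10, lower.tail = FALSE)
pchisq(wilson.hilferty(30, nu = 10), df = 1, lower.tail = FALSE)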
To illustrate, let’s return to the New York air quality example introduced
in Section 1.4.2. Below is a simple function, called guide.chisq.test(), for
carrying out steps 2)–4) of Algorithm 4.1. For brevity, and since the degrees
of freedom are the same for each test, it omits the modified Wilson-Hilferty
transformation discussed previously:
guide.chisq.test <- function(x, y) {
  y <- as.factor(sign(y - mean(y)))  # discretize response
  if (is.numeric(x)) {  # discretize numeric features
    bins <- quantile(x, probs = c(0.25, 0.5, 0.75), na.rm = TRUE)
    bins <- c(-Inf, bins, Inf)
    x <- as.factor(findInterval(x, vec = bins))  # quartiles
  }
  tab <- table(y, x)  # form contingency table
  if (any(row.sums <- rowSums(tab) == 0)) {  # check rows
    tab <- tab[-which(row.sums == 0), ]  # omit zero margin totals
  }
  if (any(col.sums <- colSums(tab) == 0)) {  # check columns
    tab <- tab[, -which(col.sums == 0)]  # omit zero margin totals
  }
  chisq.test(tab)$p.value  # return the chi-square test p-value
}
c CTree (Chapter 3) avoids the small p-value problem internally by working with p-values on the log scale.
Next, I omit any rows with missing response values and compute the Bonferroni-adjusted p-values from step 4) of Algorithm 4.1 for each feature:
aq <- airquality[!is.na(airquality$Ozone), ]
pvals <- sapply(setdiff(names(aq), "Ozone"), FUN = function(x) {
guide.chisq.test(aq[[x]], y = aq[["Ozone"]])
})
p.adjust(pvals, method = "bonferroni") # Bonferroni adjusted p-values
#> Solar.R Wind Temp Month Day
#> 2.23e-03 1.40e-06 2.50e-14 2.83e-06 5.88e-01
GUIDE’s interaction test between two features, x_i and x_j, is also based on a chi-square test applied to a contingency table, where the columns correspond to all possible pairs of values between the binned x_i values and x_j. In all of these cases, rows or columns with zeros in the margin are omitted before applying the chi-square test.
When including interaction tests, we need to modify how the split variable x* is selected in step 4) of Algorithm 4.1. If the smallest p-value is from a curvature test, then select the associated predictor to split the node. If the smallest p-value comes from an interaction test, then the choice of splitting variable depends upon whether both features are ordered or not. If x_i and x_j are both ordered, the node is split on each variable at its within-node sample mean (e.g., x_i ≤ x̄_i, where x̄_i is the mean of x_i in the current node). For each of the two candidate splits, a constant (e.g., the mean response) is fit to the resulting nodes, and the split yielding the greatest reduction in SSE (p. 58) is selected to split the node. On the other hand, if either x_i or x_j is nominal, select the variable with the smallest p-value from the associated curvature tests. For details, see Algorithm 2 in Loh [2002].
Using a split variable selected from an interaction test does not guarantee that
the interacting variable will be used to split one of the child nodes. While it
may be intuitive to force this behavior to highlight the specific interaction in
the tree, Loh [2002] argues that letting variables compete at each individual
split can lead to shorter trees.
To illustrate, let’s return to the Ames housing data (Section 1.4.7). Recall
that I initially split the data into train/test sets using a 70/30 split; since
I’m not plotting anything, I did not bother to rescale the response in this
example. Using the GUIDE software (Section 4.9), I built a default regression
tree with stepwise linear regression models in each node.d All variables were
allowed to compete for splits, and all numeric features were allowed to compete
as predictors in the stepwise procedure applied to each node. The tree was
pruned using 10-fold cross-validation along with the 1-SE rule. The resulting
tree diagram is displayed in Figure 4.1—the inner caption is part of the output
from GUIDE and explains the tree diagram.
The 1-SE pruned GUIDE-based tree for the Ames housing data, using non-
constant fits, is substantially smaller than the 1-SE pruned CART tree from
Figure 2.23 (p. 97); it is also far more accurate, with a test set RMSE of
$28,870.78.e For further comparison, CTree, using α = 0.05, resulted in a tree
with 75 terminal nodes and a test RMSE of $35,331.88.
GUIDE will also output a text file containing the variable importance scores
(Section 4.7), estimated regression equation for each terminal node, and more.
d Since linear models are being used to summarize the terminal nodes, it would be wise to consider log-transforming the response first (or using a similar transformation), since it is quite right skewed; but for comparison to tree fits from previous chapters, I elected not to in this example.
e While smaller in size, one could argue that the pruned GUIDE tree is no less inter-
pretable, since the terminal nodes are summarized using regression fits in different subsets
of the predictors.
For example, the sale price of any home with a garage capacity for three or
more cars, excellent basement quality, good basement exposure, and an above
ground living area of less than 2,088 sq. ft. would be estimated according to
the following equation:
\[
\widehat{\texttt{Sale\_Price}} = -149{,}100.00 + 283.30 \times \texttt{First\_Flr\_SF}, \tag{4.1}
\]
where First_Flr_SF is the square footage of the first floor. This corresponds
to terminal node 24 in Figure 4.1.
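For instance, plugging a hypothetical first-floor area of 2,000 sq. ft. into (4.1) gives a predicted sale price of about $417,500:

-149100 + 283.3 * 2000  # predicted sale price (in dollars) for First_Flr_SF = 2000
#> [1] 417500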
The output file from GUIDE also reported that the tree in Figure 4.1 explains
roughly 87.89% of the variance in Sale_Price on the training data (i.e., R² = 0.8789).
GUIDE v.38.0 0.25-SE piecewise linear least-squares regression tree with stepwise variable se-
lection for predicting Sale_Price. Tree constructed with 2051 observations. Maximum num-
ber of split levels is 12 and minimum node sample size is 20. At each split, an observation goes
to the left branch if and only if the condition is satisfied. Set S1 = {Briardale, Brookside,
Edwards, Iowa_DOT_and_Rail_Road, Landmark, Meadow_Village, Mitchell, North_Ames,
Northpark_Villa, Old_Town, Sawyer, South_and_West_of_Iowa_State_University}. Cir-
cles with dashed lines are nodes with no significant split variables. Sample size (in italics)
and mean of Sale_Price printed below nodes. Terminal nodes with means above and below
value of 1.802E+05 at root node are colored yellow and purple respectively. Second best
split variable at root node is Neighborhood.
FIGURE 4.1: Example tree diagram produced by GUIDE for the Ames hous-
ing example. Stepwise linear regression models were fit in each node. The
autogenerated caption produced by the GUIDE software is also included.
GUIDE for classification is not that different from its regression counterpart.
Instead of residuals, the categorical outcome is used directly in the chi-square
tests for split variable selection. Also, once a splitting variable, say x? , is se-
lected, the optimal split point is found using an exhaustive search, similar
to CART’s approach based on a weighted sum of Gini impurities (see Sec-
tion 2.2.1).
Although orthogonal (or binary) splits are more interpretable, Loh [2009]
makes a compelling case for splits based on linear combinations of predic-
tors (which are referred to as either linear splits or oblique splits, since they
are no longer orthogonal to the feature axes). An oblique split on two contin-
uous features, xi and xj , takes the form axi + bxj ≤ c, where a, b, and c are
constants determined from the data; see Loh [2009, Sec. 3] for details.
Using oblique splits can result in smaller trees and greater predictive accuracy. GUIDE only allows linear splits for classification problemsf and restricts them to two variables, x_i and x_j (say), only when an interaction test between x_i and x_j is not significant using another Bonferroni correction. The form of the linear split is chosen using LDA; see Loh [2009, Procedure 3.1] for details. In the official GUIDE software, oblique splits can be given higher or lower priority than orthogonal splits (see Section 4.9). Loh [2009] also mentions that while oblique splits are more powerful than orthogonal splits, it is not necessary to apply them to split each node, which he illustrates with an example on classifying fish species.
Even when linear splits are allowed, Loh [2009] showed that the GUIDE pro-
cedure for classification is still practically unbiased in terms of split variable
selection.
Figure 4.3 shows a scatterplot of the bill length vs. bill depth for the three
species of penguins. While there does seem to be a good deal of separation
between the three species using bill_depth_mm and bill_length_mm, it will
f Breiman et al. [1984, p. 248] argue that splits on linear combinations of predictors are
be challenging for a classification tree that uses splits that are orthogonal to
the x- and y-axes (e.g., CART and CTree). If the data come from a multi-
variate normal distribution with a common covariance matrix across the three
species, then LDA would give the optimal linear decision boundary (if the co-
variance matrices differ between the classes, then QDA would be optimal). If
we cannot make those assumptions, then a tree-based approach using oblique
splits is a good alternative.
FIGURE 4.3: Scatterplot of bill depth (mm) vs. bill length (mm) for the three
species of Palmer penguins.
To illustrate, consider the plots in Figure 4.4, which show the decision bound-
aries from a GUIDE decision tree with linear splits (top left), LDA (top right),
CART (middle left), CTree (middle right), a random forest (bottom left), gra-
dient boosted tree ensemble (bottom right); the latter two are special types of
tree-based ensembles and are discussed in Chapters 7–8. Both GUIDE and
CART were pruned using 10-fold cross-validation with the 1-SE rule (for
CTree, I used the default α = 0.05). Notice the similarity (and simplicity)
of the linear decision boundaries produced by GUIDE and LDA; these models
are likely to generalize better to new data from the same population. Further-
more, GUIDE only misclassified 8 observations, while LDA, CART, and CTree
misclassified 13, 21, and 15 observations, respectively. Compared to CART and
CTree, the tree-based ensembles (bottom row) are a bit more flexible and able
to adapt to linear decision boundaries, but in this case, they’re not as smooth
or simple to explain as the LDA or GUIDE decision boundaries.
The associated tree diagram for the fitted GUIDE tree with linear splits is
shown in Figure 4.5. The GUIDE tree using linear splits is simpler compared
to the associated CART and CTree trees (not shown); the former uses two
splits, while the latter two require three and seven splits, respectively.
GUIDE v.38.0 1.00-SE classification tree for predicting species using linear split priority, estimated priors and unit misclassification costs. Tree constructed with 342 observations. Maximum number of split levels is 10 and minimum node sample size is 3. At each split, an observation goes to the left branch if and only if the condition is satisfied. Intermediate nodes in lightgray indicate linear splits. Predicted classes and sample sizes printed below terminal nodes; class sample proportions for species = Adelie, Chinstrap, and Gentoo, respectively, beside nodes.
FIGURE 4.5: GUIDE-based classification tree with linear splits for the Palmer penguins data. The autogenerated caption produced by the GUIDE software is also included.
Like CART, GUIDE can incorporate different priors and unequal misclassifi-
cation costs. See Section 2.2.4 and Loh [2009] for details.
For classification, GUIDE can also summarize the terminal nodes using non-constant fits, namely kernel density and k-nearest neighbor models. Although they require more computation than classification trees with constant fits, Loh [2009] empirically shows that their prediction accuracy is often relatively high.
Kernel-based fits essentially select the class with the highest kernel density
estimate in each node. For example, if the selected split variable x is ordered,
then for each class in node t, a kernel density estimate fˆ (x) is computed
according to
\[
\hat{f}(x) = \frac{1}{N_t h} \sum_{i=1}^{N_t} \phi\left(\frac{x - x_i}{h}\right),
\]
where φ is the standard normal density function and h is the bandwidth, which GUIDE computes from s and r, the sample standard deviation and interquartile range (IQR) of the observed x values. Some motivation for using this bandwidth, which is more than twice as large as the usual bandwidth recommended for density estimation, is given in Ghosh et al. [2006].
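To make the idea concrete, below is a small, self-contained sketch of a kernel-based node classifier. The Gaussian kernel and the rule-of-thumb bandwidth bw.nrd0() are stand-ins for GUIDE’s actual choices (which, as noted above, depend on s and r), and the function and data names are hypothetical:

node.kernel.classify <- function(x, x.node, y.node, h = bw.nrd0(x.node)) {
  y.node <- as.factor(y.node)
  dens <- sapply(levels(y.node), FUN = function(cl) {
    xs <- x.node[y.node == cl]      # observed x values for class cl in the node
    mean(dnorm((x - xs) / h)) / h   # kernel density estimate of class cl at x
  })
  names(which.max(dens))            # class with the highest estimated density
}

# Toy usage: two classes centered at 0 and 3
set.seed(101)
x.node <- c(rnorm(30, mean = 0), rnorm(20, mean = 3))
y.node <- rep(c("a", "b"), times = c(30, 20))
node.kernel.classify(2.5, x.node, y.node)  # should return "b"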
A similar idea is used when applying k-nearest neighbor fits to each node. In particular, k is given by a simple formula involving the ceiling function, where ⌈x⌉ just means round x up to the nearest integer; the minimum number of neighbors is k = 3, which helps avoid too much trouble in dealing with ties.
4.4 Pruning
The GUIDE algorithm continues to recursively perform splits until some stop-
ping criteria are met (e.g., the minimum number of observations required for
splitting has been reached in each node). Like CART, this will often lead to
overfitting and an overly complex tree with too many splits. To circumvent
the issue, GUIDE adopts the same cost-complexity pruning strategy used in
CART (revisit Section 2.5 for the details). Loh [2002] showed via simulation
that pruning can reduce the variable selection bias in exhaustive search pro-
cedures like CART, provided there are no interaction effects.
For constant fits, the fitted values and predictions are obtained in the same way
as CART and CTree. For example, if y is continuous, then the average within-
node response is used. For classification, the within node class proportions can
serve as predicted class probabilities. For non-constant fits, the fitted values
and predictions for new observations depend on the model fit to each node
(e.g., polynomial regression or k-nearest neighbor). The latter tend to produce
shorter and more accurate trees, but are more computationally intensive and
less interpretable.
Variable importance scores in GUIDE are based on the sum of the weighted
one-df chi-square test statistics across the nodes of the tree, where the weights
are given by the square root of the sample size of the corresponding nodes. In
particular, let q_t(x) be the one-df chi-square statistic associated with predictor x at node t. The variable importance score for x is
\[
\mathrm{VI}(x) = \sum_{t} \sqrt{N_t}\, q_t(x),
\]
where N_t is the sample size in node t; see Loh [2012] for details. Later, Loh et al. [2015] suggest using N_t rather than its square root to weight the chi-
square statistics, which increases the probability that the feature used to split
the root node has the largest importance score. This approach is approxi-
mately unbiased unless there is a mix of both ordered and nominal categori-
cal variables. Loh and Zhou [2021] provide an improved version for regression
that ensures unbiasedness.
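As a toy illustration of the two weighting schemes just described (the node sizes and chi-square statistics below are made up):

Nt <- c(200, 120, 80)    # node sample sizes (hypothetical)
qt <- c(15.2, 3.1, 0.4)  # one-df chi-square statistics for feature x (hypothetical)
sum(sqrt(Nt) * qt)       # square-root weighting of Loh [2012]
sum(Nt * qt)             # weighting suggested in Loh et al. [2015]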
While there are a number of approaches to computing variable importance
(especially from decision trees), few include thresholds for identifying the ir-
relevant (or pure noise) features. In GUIDE, any feature x with a variable
importance score less than the 0.95-quantile of the approximate null distri-
bution of VI (x) is considered unimportant. For example, running GUIDE in
variable importance mode (with default settings, which caps the total number
of splits at four) on the Ames housing data flagged 64 of the 80 features as be-
ing highly important, 4 as being less important, and 12 as being unimportant.
The results are displayed in Figure 4.6 (p. 174).
As previously mentioned, the GUIDE software continues to evolve and the de-
tails mentioned above may not correspond exactly with what’s currently im-
plemented. Fortunately, Loh and Zhou [2021] give a relatively recent account
of GUIDE’s approach to variable importance, and provide a thorough compar-
ison against other tree-based approaches to computing variable importance,
including CART, CTree, and many of the tree-based ensembles discussed later
in this book.
4.8 Ensembles
Although I’ll defer the discussion of ensembles until Chapter 5, it’s worth not-
ing that GUIDE supports two types of GUIDE-based tree ensembles: bagging
(Section 5.1) and random forest (Chapter 7).
GUIDE, along with its predecessors, is not open source and exists as a com-
mand line program. Compiled binaries for QUEST, CRUISE, and GUIDE are
freely available from http://pages.stat.wisc.edu/~loh/guide.html and
are compatible with most major operating systems. If you’re comfortable with
the terminal, GUIDE is straightforward to install and use; see the available
manual for installation specifics and example usage. Although it’s a terminal
application, GUIDE will optionally generate R code that can be used for scor-
ing after a tree has been fit. This makes it easy to run simulations and the
like after the tree has been built. My only criticism of GUIDE is that it’s not
currently available via easy-to-use open source code (like R or Python); if it
were, it’d probably be much more widely adopted by practitioners.
As discussed in the previous section, GUIDE is a command line program that
requires input from the user. For this reason, I’ll limit this section to a single
example; the GUIDE software manual [Loh, 2020] offers plenty of additional
examples. Note that GUIDE can optionally generate R code to reproduce
predictions from the fitted tree model, which can be useful for simulations
and deployment. Further, I used GUIDE v38.0 for all the examples in this
chapter.
The credit card default data [Yeh and Lien, 2009], available from the UCI
Machine Learning Repository at
https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+
clients,
contains demographic and payment information about credit card customers
in Taiwan in the year 2005. The data set contains 30,000 observations on the
following 23 variables:
• default: A binary indicator of whether or not the customer defaulted on
their payment (yes or no).
• limit_bal: Amount of credit (NT dollar) given to the customer.
• sex: The customer’s gender (male or female).
• education: The customer’s level of education (graduate school,
university, high school, or other).
• marriage: The customer’s marital status (married, single, or other).
Note that the categorical variables have been numerically re-encoded. The
next code chunk removes the column ID and cleans up some of the categorical
features by re-encoding them from numeric back to the actual categories based
on the provided column descriptions:
# Remove ID column
credit$id <- NULL
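The re-encoding step itself is not shown in this excerpt; a sketch along the following lines would do the job. The labels follow the UCI documentation, and lumping undocumented codes into "other", as well as the 0/1 coding of default, are assumptions:

# Re-encode categorical features (numeric codes follow the UCI documentation)
credit$sex <- factor(credit$sex, levels = 1:2, labels = c("male", "female"))
credit$education <- ifelse(credit$education %in% 1:4, credit$education, 4)
credit$education <- factor(credit$education, levels = 1:4,
                           labels = c("graduate school", "university",
                                      "high school", "other"))
credit$marriage <- ifelse(credit$marriage %in% 1:3, credit$marriage, 3)
credit$marriage <- factor(credit$marriage, levels = 1:3,
                          labels = c("married", "single", "other"))
credit$default <- factor(credit$default, levels = 0:1,  # assuming 0/1 coding
                         labels = c("no", "yes"))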
Finally, I’ll split the data into train/test sets using a 70/30 split, leaving
21,000 observations for training and 9,000 for estimating the generalization
performance:
set.seed(1342) # for reproducibility
trn.ids <- sample(nrow(credit), size = 0.7 * nrow(credit),
replace = FALSE)
credit.trn <- credit[trn.ids, ]
credit.tst <- credit[-trn.ids, ]
The GUIDE program requires two special text files before it can be
called:
• the data input file;
• a description file.
See the GUIDE reference manual [Loh, 2020] for full details. The data input file
is essentially just a text file containing the training data in a format that can be
consumed by GUIDE. The description file provides some basic metadata, like
the missing value flag and variable roles. These files can be a pain to generate,
especially for data sets with lots of columns, so I included a little helper
function in treemisc to help generate them; see ?treemisc::guide_setup
for argument details.
Below is the code I used to generate the data input and description files for the
credit card default example. By default, numeric columns are used both for
splitting the nodes and for fitting the node regression models for non-constant
fits, and categorical variables are used for splitting only. In my setup, I have
a /guide-v38.0/credit directory containing the GUIDE executable and where
the generated files will be written to:
treemisc::guide_setup(credit.trn, path = "guide-v38.0/credit",
dv = "default", file.name = "credit",
verbose = TRUE)
Running this code generates two files in the guide-v38.0/credit directory:
• credit.txt (the training data input file in the format required by GUIDE);
• credit_desc.txt (the description file).
Below are the contents of the generated credit_desc.txt file. The first line
gives the name of the training data file; if the file is not in the cur-
rent working directory, its full path must be given with quotes (e.g.,
"some/path/to/credit.txt"). The second line specifies the missing value code
(if it contains non-alphanumeric characters, then it too must be quoted). The
remaining lines specify the column number, name, and role for each variable
in the data input file. As you can imagine, creating this file for a data set
with lots of variables can be tedious, hence the reason for writing a helper
function.
credit.txt
NA
2
1 limit_bal n
2 sex c
3 education c
4 marriage c
5 age n
6 pay_0 n
7 pay_2 n
8 pay_3 n
9 pay_4 n
10 pay_5 n
11 pay_6 n
12 bill_amt1 n
13 bill_amt2 n
14 bill_amt3 n
15 bill_amt4 n
16 bill_amt5 n
17 bill_amt6 n
18 pay_amt1 n
19 pay_amt2 n
20 pay_amt3 n
21 pay_amt4 n
22 pay_amt5 n
23 pay_amt6 n
24 default d
With the credit.txt and credit_desc.txt files in hand (and in the appropriate
directories required by GUIDE), we can spin up a terminal and call the GUIDE
program. I’ll omit the details since it’s OS-specific, but the GUIDE reference
manual will take you through each step. Once the program is called, GUIDE
will ask the user for several inputs (e.g., whether to build a classification or
regression tree, whether to use constant or non-constant fits, number of folds
to use for cross-validation, etc.). In the end, GUIDE generates a special input
text file to be consumed by the software.
Below are the contents of the input file for the credit card default example,
called credit_in.txt, highlighting all the options I selected (I basically requested
a default classification tree that’s been pruned using the 1-SE rule with 10-
fold cross-validation, but you can see several of the available options in the
output).
GUIDE (do not edit this file unless you know what you are doing)
38.0 (version of GUIDE that generated this file)
1 (1=model fitting, 2=importance or DIF scoring, 3=data con...
"credit_out.txt" (name of output file)
1 (1=one tree, 2=ensemble)
1 (1=classification, 2=regression, 3=propensity score group...
1 (1=simple model, 2=nearest-neighbor, 3=kernel)
2 (0=linear 1st, 1=univariate 1st, 2=skip linear, 3=skip li...
1 (0=tree with fixed no. of nodes, 1=prune by CV, 2=by test...
"credit_desc.txt" (name of data description file)
10 (number of cross-validations)
1 (1=mean-based CV tree, 2=median-based CV tree)
1.000 (SE number for pruning)
1 (1=estimated priors, 2=equal priors, 3=other priors)
1 (1=unit misclassification costs, 2=other)
2 (1=split point from quantiles, 2=use exhaustive search)
1 (1=default max. number of split levels, 2=specify no. in ...
1 (1=default min. node size, 2=specify min. value in next l...
2 (0=no LaTeX code, 1=tree without node numbers, 2=tree wit...
"credit.tex" (latex file name)
1 (1=color terminal nodes, 2=no colors)
2 (0=#errors, 1=sample sizes, 2=sample proportions, 3=poste...
3 (1=no storage, 2=store fit and split variables, 3=store s...
"credit_splits.txt" (split variable file name)
2 (1=do not save fitted values and node IDs, 2=save in a file)
"credit_fitted.txt" (file name for fitted values and node IDs)
2 (1=do not write R function, 2=write R function)
"credt_pred.R" (R code file)
1 (rank of top variable to split root node)
Now, all you have to do is feed this input file back into the GUIDE program
(again, see the official manual for details). Once the modeling process is com-
plete, you’ll end up with several files depending on the options you specified
during the initial setup. A portion of the output file produced by GUIDE for
my input file is shown below; the corresponding tree diagram is displayed in
Figure 4.7.
Node 1: Intermediate node
A case goes into Node 2 if pay_0 <= 1.5000000
pay_0 mean = -0.15095238E-01
The train and test set accuracies for this tree are 81.85% and 82.21%, respec-
tively (I had to use the R function produced by GUIDE to compute the test
accuracy). Despite the reasonably high accuracy, we have a big problem! If
you didn’t first notice when initially exploring the data on your own, then
hopefully you see it now...the model is biased towards predicting no since
the data are imbalanced (and naturally so, since we’d hope that most people
are not defaulting on their credit card payments).
The original goal was to build a model to predict the probability of defaulting
(default = "yes"), but the train and test accuracy within that specific class
are 32.35% and 33.87%, respectively. By default, GUIDE (and many other
algorithms) treat the misclassification for both types of error (i.e., predicting
a yes as a no and vice versa) as equal. Fortunately, like CART, GUIDE can
incorporate a matrix of misclassification costs into the tree construction (see
Section 2.2.4).
For this example, I specified the cost matrix
\[
L = \begin{pmatrix} 0 & 5 \\ 1 & 0 \end{pmatrix},
\]
with rows and columns ordered as (No, Yes), where L_{i,j} denotes the cost of classifying an observation as class i when it really
(i.e., "no" then "yes"). Re-running the previous program, but specifying the
cost matrix file when prompted, leads to the much more useful tree structure
shown in Figure 4.8. Although the overall test accuracy dropped from 82.21%
to 60.00%, the accuracy within the class of interest (the true positive rate or
sensitivity) increased from 7.49% to 17.79%—a significant improvement. I’ll
leave it to the reader to explore further with linear splits and non-constant
fits to see if the results can be improved further.
This chapter introduced the GUIDE algorithm for building classification and
regression trees. GUIDE was developed to solve three problems often encoun-
tered with exhaustive search procedures (like CART):
1. split variable selection bias;
2. insensitivity to local interactions;
3. overly complex tree structures.
Like CTree, GUIDE solves the first problem by decoupling the search for split
variables from the split point selection using statistical tests; in contrast to
CTree, GUIDE exclusively uses one-df chi-square tests throughout. In select-
ing the splitting variable, GUIDE also looks at two-way interaction effects that
can potentially mask the importance of a split when only main effects are con-
sidered. Moreover, GUIDE can often produce smaller and more accurate trees.
g If you have trouble accessing any of Loh’s papers, many of them are freely available on his website.
FIGURE 4.6: GUIDE-based variable importance scores for the Ames housing
example. GUIDE distinguished between highly important (H), less important
(L), and unimportant (U).
GUIDE v.38.0 1.00-SE classification tree for predicting default using estimated
priors and unit misclassification costs. Tree constructed with 21000 observations.
Maximum number of split levels is 30 and minimum node sample size is 210. At
each split, an observation goes to the left branch if and only if the condition is
satisfied. Predicted classes and sample sizes printed below terminal nodes; class
sample proportion for default = yes beside nodes. Second best split variable at
root node is pay_2.
FIGURE 4.7: GUIDE-based classification tree for the credit card default ex-
ample. The autogenerated caption produced by the GUIDE software is also
included.
GUIDE v.38.0 1.00-SE classification tree for predicting default using estimated priors and specified mis-
classification costs. Tree constructed with 21000 observations. Maximum number of split levels is 30 and
minimum node sample size is 210. At each split, an observation goes to the left branch if and only if the con-
dition is satisfied. Predicted classes and sample sizes printed below terminal nodes; class sample proportion
for default = yes beside nodes. Second best split variable at root node is pay_2.
Tree-based ensembles
5
Ensemble algorithms
You know me, I think there ought to be a big old tree right there.
And let’s give him a friend. Everybody needs a friend.
Bob Ross
The “wisdom of the crowd” refers to the phenomenon whereby the combined answer from a large, diverse group of individuals is as accurate, if not more accurate, than the answer from any one individual from the group. For an interesting example,
try looking up the phrase "Francis Galton Ox weight guessing" in a search
engine. Another neat example is to ask a large number of individuals to guess
how many jelly beans are in a jar, after you’ve eaten a handful, of course. If
you look at the individual guesses, you’ll likely notice that they vary all over
the place. The average guess, however, tends to be closer than most of the
individual guesses.
In a way, ensembles use the same idea to help improve the predictions (i.e.,
guesses) of an individual model and are among the most powerful supervised
learning algorithms in existence. While there are many different types of en-
sembles, they tend to share the same basic structure:
\[
f_B(x) = \beta_0 + \sum_{b=1}^{B} \beta_b f_b(x), \tag{5.1}
\]
where B is the size of the ensemble, and each member of the ensemble fb (x)
(also called a base learner) is a different function of the input variables derived
from the training data.
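To make (5.1) concrete, here is a small, hypothetical sketch of how such an ensemble prediction could be computed in R; the list of fitted base learners (trees), their weights (betas), and the intercept (beta0) are all assumed to be given:

predict_ensemble <- function(trees, betas, beta0 = 0, newdata) {
  # N x B matrix of base learner predictions (one column per base learner)
  preds <- sapply(trees, FUN = function(tree) predict(tree, newdata = newdata))
  beta0 + drop(preds %*% betas)  # weighted sum, as in (5.1)
}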
In this chapter, our interests lie primarily in using decision trees for the base
learners—typically, CART-like decision trees (Chapter 2), but any tree al-
gorithm will work. As discussed in Hastie et al. [2009, Section 10.2], many
supervised learning algorithms (not just ensembles) can be seen as some form
of additive expansion like (5.1). A single decision tree is one such example of an
additive expansion. For a single tree, fb (x) = fb (x; θb ), where θb collectively
represents the splits and split points leading to the b-th terminal node region,
whose prediction is given by βb (i.e., the terminal node mean response for
ordinary regression trees). Other examples include single-hidden-layer neural
networks and MARS [Friedman, 1991], among others.
There exist many different flavors of ensembles, and they all differ in the
following ways:
• the choice of the base learners fb (x) (although, in this book, the base
learners will always be some form of decision tree);
• how the base learners are derived from the training data;
• the method for obtaining the estimated coefficients (or weights), {β_b}_{b=1}^B.
The ensemble algorithms discussed in this book fall into two broad categories,
to be discussed over the next two sections: bagging (Section 5.1; short for bootstrap aggregating) and boosting (Section 5.2). First, I’ll discuss bagging,
one of the simplest approaches to constructing an ensemble.
5.1 Bootstrap aggregating (bagging)
To illustrate the idea, suppose we simulate N = 500 observations from the sine wave model
Y = sin(X) + ε,
where X ∼ U(0, 2π) and ε ∼ N(0, σ = 0.3). Figure 5.1 (left) shows the prediction surface from a single (overfit) decision tree grown to near full depth.a
In contrast, Figure 5.1 (right) shows a bagged ensemble of B = 1000 such trees
whose predictions have been averaged together; here, each tree was induced
from a different bootstrap sample of the original data points. Clearly the indi-
vidual tree is too complex (i.e., low bias and high variance) and will not gen-
eralize well to new samples, but averaging many such trees together resulted
in a smoother, more stable prediction. The MSE from an independent test set
of 10,000 observations was 0.173 for the single tree and 0.1 for the bagged tree
ensemble; the optimal MSE for this example is σ² = 0.3² = 0.09.
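A minimal sketch of the simulation just described follows (the seed is arbitrary, and B is reduced to 100 here to keep the run time short):

library(rpart)

set.seed(1217)  # arbitrary seed
n <- 500
sine <- data.frame(x = runif(n, min = 0, max = 2 * pi))
sine$y <- sin(sine$x) + rnorm(n, sd = 0.3)

# Single overgrown regression tree (per the footnote: minsplit = 2 and cp = 0)
fit1 <- rpart(y ~ x, data = sine, minsplit = 2, cp = 0)

# Bagged ensemble: refit the same overgrown tree to B bootstrap samples and
# average the predictions over a fine grid of x values
B <- 100
xx <- data.frame(x = seq(from = 0, to = 2 * pi, length = 500))
preds <- sapply(seq_len(B), FUN = function(b) {
  boot <- sine[sample(nrow(sine), replace = TRUE), ]
  predict(rpart(y ~ x, data = boot, minsplit = 2, cp = 0), newdata = xx)
})
bagged <- rowMeans(preds)  # averaged (bagged) predictions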
The general steps for bagging classification and regression trees are outlined
in Algorithm 5.1. To help further illustrate, a simple schematic of the process
for building a bagged tree ensemble with four trees is given in Figure 5.2. Note
that bagged tree ensembles can be extended beyond simple classification and
regression trees. For example, it is also possible to bag survival trees [Hothorn et al., 2004].
a In this example, each tree was fit using rpart() with minsplit = 2 and cp = 0.
FIGURE 5.1: Simulated sine wave example (N = 500). Left: a single (over-
grown) regression tree. Right: a bagged ensemble of B = 1000 overgrown re-
gression trees whose predictions have been averaged together; here each tree
was induced from a different bootstrap sample of the original data points. The
individual tree is too complex (i.e., low bias and high variance) but averaging
many such trees together results in a more stable prediction and smoother fit.
1) Start with a training sample, $d_{trn} = \{(x_i, y_i)\}_{i=1}^{N}$, and specify integers
$n_{min}$ (the minimum node size of a particular tree) and B (the number of
trees in the ensemble).
2) For b in 1, 2, . . . , B:
   a) Draw a bootstrap sample $d_{trn}^{\star}$ of size N (i.e., sample N rows with
   replacement) from the original training sample $d_{trn}$.
   b) Optional: Keep track of which observations from $d_{trn}$ were not se-
   lected to be in $d_{trn}^{\star}$; these are called the out-of-bag (OOB) observa-
   tions.
   c) Fit a decision tree $T_b$ to the bootstrap sample $d_{trn}^{\star}$, growing it to near
   maximal depth (as specified by $n_{min}$).
3) Return the trees $\{T_b\}_{b=1}^{B}$.
4) To obtain the bagged prediction for a new case x, denoted $\hat{f}_B(x)$, pass the
observation down each tree—which will result in B separate predictions
(one from each tree)—and aggregate as follows:
   • Classification: $\hat{f}_B(x) = \text{vote}\{T_b(x)\}_{b=1}^{B}$, where $T_b(x)$ is the pre-
   dicted class label for x from the b-th tree in the ensemble (in other
   words, let each tree vote on the classification for x and take a major-
   ity/plurality vote at the end).
   • Regression: $\hat{f}_B(x) = \frac{1}{B}\sum_{b=1}^{B} T_b(x)$ (in other words, we just average
   the predictions for case x across all the trees in the ensemble).
Bagging has the same structural form as (5.1) with $\beta_0 = 0$ and $\{\beta_b = 1/B\}_{b=1}^{B}$,
and where each tree is induced from an independent bootstrap sample of
the original training data and grown to near maximal depth (as specified by
$n_{min}$).
An important aspect of how the trees are constructed in bagging is that they
are induced from independent bootstrap samples, which makes the bagging
procedure trivial to parallelize. See Boehmke and Greenwell [2020, Sec. 10.4]
for details and an example using the Ames housing data (Section 1.4.7) in
R using the wonderful foreach package [Revolution Analytics and Weston,
2020].
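As a rough illustration (not the book's code), a parallel version of a bagging loop using foreach might look like the sketch below; here d.trn and the response column y are hypothetical placeholders for a generic training data frame.

library(doParallel)  # also loads parallel and foreach
library(foreach)
library(rpart)
# d.trn: a generic training data frame with a response column `y` (placeholder)
cl <- makeCluster(4)  # e.g., four worker processes
registerDoParallel(cl)
bagged.trees <- foreach(b = 1:500, .packages = "rpart") %dopar% {
  boot.id <- sample.int(nrow(d.trn), replace = TRUE)       # bootstrap sample
  rpart(y ~ ., data = d.trn[boot.id, ], cp = 0, xval = 0)  # overgrown tree
}
stopCluster(cl)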
FIGURE 5.2: A simple schematic of the process for building a bagged tree
ensemble with four trees.
To illustrate, let’s return to the email spam example first introduced in Sec-
tion 1.4.5. In the code snippet below, we load the data from the kernlab
package and split the observations into train/test sets using the same 70/30
split as before:
data(spam, package = "kernlab")
set.seed(852) # for reproducibility
id <- sample.int(nrow(spam), size = floor(0.7 * nrow(spam)))
spam.trn <- spam[id, ] # training data
spam.tst <- spam[-id, ] # test data
Rather than writing our own bagger function, I’ll construct a bagged tree
ensemble using a basic for loop that stores the individual trees in a list
called spam.bag. Note that I turn off cross-validation (xval = 0) when calling
rpart() to save on computing time. The code is shown below.
library(rpart)
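# The original loop isn't reproduced in this excerpt; the following is a sketch
# consistent with the surrounding text. The number of trees (B) and the rpart
# control settings are my own choices, not necessarily the book's.
B <- 500  # number of trees
spam.bag <- vector("list", length = B)  # to store the individual trees
set.seed(900)  # for reproducibility (arbitrary seed)
for (b in seq_len(B)) {
  boot.id <- sample.int(nrow(spam.trn), replace = TRUE)  # bootstrap rows
  spam.bag[[b]] <- rpart(type ~ ., data = spam.trn[boot.id, ],
                         method = "class", cp = 0, xval = 0)  # overgrown tree
}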
Now that we have the individual trees, each of which was fit to a different
bootstrap sample from the training data, we can obtain predictions and assess
the performance of the ensemble using the test sample. To that end, I’ll loop
through each tree to obtain predictions on the test set (spam.tst), and store
the results in an N × B matrix, one column for each tree in the ensemble. I
then compute the test error as a function of B by cumulatively aggregating
the predictions from trees 1 through B by means of voting (e.g., if we are
computing the bagged prediction using only the first three trees, the final
prediction for each observation will simply be the class with the most
votes across the three trees).
To help with the computations, I’ll write two small helper functions, vote()
and err(), for carrying out the voting and computing the misclassification
error, respectively:
vote <- function(x) names(which.max(table(x)))
err <- function(pred, obs) {
  1 - sum(diag(table(pred, obs))) / length(obs)
}
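The loop/apply code that carries out these steps isn't reproduced in this excerpt; a minimal sketch, reusing the two helpers above and a spam.bag list of fitted trees (whose construction was sketched earlier), might look like the following:

spam.bag.preds <- sapply(spam.bag, FUN = function(tree) {  # N x B matrix
  as.character(predict(tree, newdata = spam.tst, type = "class"))
})
spam.bag.err <- sapply(seq_len(ncol(spam.bag.preds)), FUN = function(b) {
  agg <- apply(spam.bag.preds[, seq_len(b), drop = FALSE], MARGIN = 1,
               FUN = vote)  # majority vote using the first b trees
  err(agg, obs = spam.tst$type)
})
min(spam.bag.err)  # minimum misclassification error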
The results are displayed in Figure 5.3. The error stabilizes after around 200
trees and achieves a minimum misclassification error rate of 4.85% (horizontal
dashed line). For reference, a single tree (pruned using the 1-SE rule) achieved
a test error of 9.99%. Averaging the predictions from several hundred over-
grown trees cut the misclassification error by more than half!
FIGURE 5.3: Test misclassification error for the email spam bagging example.
The error stabilizes after around 200 trees and achieves a minimum misclas-
sification error rate of 4.85% (horizontal dashed line).
While bagging was quite successful in the email spam example, sometimes
bagging can make things worse. For a good discussion on how bagging can
worsen bias and/or variance, see Berk [2008, Sec. 4.5.2–4.5.3].
Inducing trees from independent learning sets that are bootstrap samples
from the original training data imitates the process of building trees on in-
dependent samples of size N from the true underlying population of interest.
While bagging traditionally utilizes bootstrap sampling (i.e., sampling with
replacement) for training the individual base learners, it can sometimes be ad-
vantageous to use subsampling without replacement; Breiman [1999] referred
to this as pasting. In particular, if N is “large enough,” then bagging using
random subsamples of size N/2 (i.e., sampling without replacement) can be
an effective alternative to bagging based on the bootstrap [Friedman and Hall,
2007]. Strobl et al. [2007b] suggest using a subsample size of 0.632 times the
original sample size N, because in bootstrap sampling about 63.2% of the
original observations end up in any particular bootstrap sample.
This is quite fortunate since sampling half the data without replacement is
much more efficient and can dramatically speed up the bagging process. Ap-
plying this to the email spam data from the previous section, which only
required modifying one line of code in the previous example, resulted in a
minimum test error of 5.21%, quite comparable to the previous results using
the bootstrap but much faster to train.
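For instance, in a bagging loop like the one sketched earlier, the only line that changes is the sampling step, which might become:

boot.id <- sample.int(nrow(spam.trn), size = floor(nrow(spam.trn) / 2),
                      replace = FALSE)  # subsample half the rows, w/o replacement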
Another reason why subsampling can sometimes improve the performance
of bagging is through “de-correlation”. Recall that bagging can improve the
performance of unstable learners through variance reduction. As discussed in
more detail in Section 7.2, correlation limits the variance-reducing effect of
averaging. The problem here is that the trees in a bagged ensemble will often
be correlated since they are all induced from bootstrap samples of the same
training set (i.e., they will share similar splits and structure, to some degree).
Using subsamples of size N/2 will help to de-correlate the trees which can
further reduce variance, resulting in improved generalization performance. A
more effective strategy to de-correlate trees in a bagged ensemble is discussed
in Section 7.2.
Bagged tree ensembles are convenient because they don’t require much tun-
ing. That’s not to say that you can’t improve performance by tuning some of
the tree parameters (e.g., tree depth). However, in contrast to gradient tree
boosting (Chapter 8), increasing the number of trees (B) does not necessar-
ily lead to overfitting (see Figure 5.4 on page 193), and isn’t really a tuning
parameter—although, computation time increases with B, so it can be ad-
vantageous to monitor performance on a validation set to determine when
performance has plateaued or reached a point of diminishing returns.
5.1.5 Software
5.2 Boosting
In the binary case, with the outcome coded as ±1, AdaBoost.M1 combines the
classifiers in the sequence using a weighted majority vote:

$$C(x) = \operatorname{sign}\left(\sum_{b=1}^{B} \alpha_b C_b(x)\right),$$

where $\{\alpha_b\}_{b=1}^{B}$ are coefficients that weight the contribution of each respective
base learner $C_b(x)$ and

$$\operatorname{sign}(x) = \begin{cases} +1 & \text{if } x > 0 \\ -1 & \text{if } x < 0 \\ 0 & \text{otherwise} \end{cases}.$$
In essence, classifiers in the sequence with higher accuracy receive more weight
and therefore have more influence on the final classification C (x).
The details of AdaBoost.M1 are given in Algorithm 5.2. The crux of the idea
is this: start with an initial classifier built from the training data using equal
case weights $\{w_i = 1/N\}_{i=1}^{N}$, then increase $w_i$ for those cases that have been
most frequently misclassified. The process is continued a fixed number of times
(B).
Like bagging, boosting is a meta-algorithm that can be applied to any type of
model, but it’s often most successfully applied to shallow decision trees (i.e.,
decision trees with relatively few splits/terminal nodes). While bagging relies
upon aggregating the results from several unstable learners, boosting tends to
benefit from sequentially improving the performance of a weak learner (like
a simple decision stump). In the next section, I’ll code up Algorithm 5.2 and
apply it to the email spam data for comparison with the previously obtained
bagged tree ensemble.
While AdaBoost.M1 was one of the most accurate classifiers at the time^c,
the fact that it only produced a classification was a severe limitation. To that
end, Friedman et al. [2000] generalized the AdaBoost.M1 algorithm so that the
weak learners return a class probability estimate, as opposed to a discrete class
label; the contribution to the final classifier is half the logit-transform of this
probability estimate. They refer to this procedure as Real AdaBoost. Other
generalizations (e.g., to multi-class outcomes) also exist. In Chapter 8, I'll
discuss a much more flexible flavor of boosting, called stochastic gradient tree
boosting, which can naturally handle general outcome types (e.g., continuous,
binary, Poisson counts, censored, etc.).

^c In fact, shortly after its introduction, Leo Breiman referred to AdaBoost as the "...best
off-the-shelf classifier in the world."

For reference, the main steps of Algorithm 5.2 (AdaBoost.M1) are:

1) Initialize case weights $\{w_i = 1/N\}_{i=1}^{N}$.
2) For b = 1, 2, . . . , B:
   a) Fit a classifier $C_b(x)$ to the training data using the current case weights $w_i$.
   b) Compute the weighted misclassification error
      $\mathrm{err}_b = \sum_{i=1}^{N} w_i I\left(y_i \neq C_b(x_i)\right) / \sum_{i=1}^{N} w_i$,
      where $I(\cdot)$ denotes the indicator function.
   c) Compute $\alpha_b = \log\left[\left(1 - \mathrm{err}_b\right)/\mathrm{err}_b\right]$.
   d) Update the case weights: $w_i \leftarrow w_i \exp\left[\alpha_b I\left(y_i \neq C_b(x_i)\right)\right]$, $i = 1, 2, \ldots, N$.
3) Output the final classifier $C(x) = \operatorname{sign}\left(\sum_{b=1}^{B} \alpha_b C_b(x)\right)$.

Following the previous example on bagging, I'll use a simple for loop and
list() to sequentially construct and store the fitted trees, respectively.
For AdaBoost.M1, we also have to collect and store the $\{\alpha_b\}_{b=1}^{B}$ coefficients
in order to make predictions later. Note that predict.rpart() returns a
factor—in this case, with factor levels "-1" and "1"—which needs to be
coerced to numeric before further processing; this is the purpose of the
fac2num() helper function in the code below:
library(rpart)
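# The original loop isn't reproduced in this excerpt; below is a sketch
# consistent with the text (depth-10 trees, -1/+1 outcome coding, and a
# fac2num() helper). The number of iterations B, the seed, and the outcome
# recoding step are my own assumptions.
fac2num <- function(x) as.numeric(as.character(x))  # factor -> numeric
spam.trn$type <- as.factor(ifelse(spam.trn$type == "spam", 1, -1))  # recode
spam.tst$type <- as.factor(ifelse(spam.tst$type == "spam", 1, -1))  # recode
B <- 500  # number of boosting iterations
spam.ada <- vector("list", length = B)  # to store the fitted trees
alpha <- numeric(B)  # to store the AdaBoost.M1 coefficients
w <- rep(1, times = nrow(spam.trn))  # equal case weights (only relative scale matters)
for (b in seq_len(B)) {
  spam.ada[[b]] <- rpart(type ~ ., data = spam.trn, weights = w,
                         maxdepth = 10, method = "class", xval = 0)
  pred <- predict(spam.ada[[b]], newdata = spam.trn, type = "class")
  miss <- as.numeric(fac2num(pred) != fac2num(spam.trn$type))  # 0/1 indicator
  err.b <- sum(w * miss) / sum(w)         # weighted misclassification error
  alpha[b] <- log((1 - err.b) / err.b)    # coefficient for the b-th tree
  w <- w * exp(alpha[b] * miss)           # upweight misclassified cases
  w <- w / sum(w) * length(w)             # rescale for numerical stability
}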
Next, I’ll generate predictions for the test data (spam.tst) using the first
b trees (where b will be varied over the range 1, 2, . . . , B) and compute the
misclassification error for each; note that I’m using the same err() function
defined in the previous example for bagging:
spam.ada.preds <- sapply(seq_len(B), FUN = function(i) {
class.labels <- predict(spam.ada[[i]], newdata = spam.tst,
type = "class")
alpha[i] * fac2num(class.labels)
}) # (N x B) matrix of un-aggregated predictions
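The code that cumulatively aggregates these weighted predictions into spam.ada.err isn't reproduced here; a minimal sketch, reusing err() and the -1/+1 coding assumed earlier, is:

spam.ada.err <- sapply(seq_len(B), FUN = function(b) {
  agg <- sign(rowSums(spam.ada.preds[, seq_len(b), drop = FALSE]))  # weighted vote
  err(agg, obs = fac2num(spam.tst$type))
})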
min(spam.ada.err) # minimum misclassification error
The results are plotted in Figure 5.4, along with those from the previously ob-
tained bagged tree ensembles (i.e., using sampling with/without replacement).
The minimum test error from the AdaBoost.M1 ensemble is 0.041. Compare
this to the bagged tree ensemble based on sampling with replacement, which
achieved a minimum test error of 0.049. In this case, AdaBoost.M1 slightly
outperforms bagging.
For comparison, let’s see how a single depth-10 decision tree—the base learner
for our AdaBoost.M1 ensemble—performs on the same data.
spam.tree.10 <- rpart(type ~ ., data = spam.trn,
maxdepth = 10, method = "class")
pred <- predict(spam.tree.10, newdata = spam.tst, type = "class")
pred <- as.numeric(as.character(pred)) # coerce to numeric
mean(pred != spam.tst$type)
#> [1] 0.12
FIGURE 5.4: Misclassification error on the email spam test set from several
different tree ensembles: 1) an AdaBoost.M1 classifier with depth-10 classifi-
cation trees (black curve), 2) a bagged tree ensemble using max depth trees
and sampling with replacement (yellow curve), and 3) a bagged tree ensemble
using max depth trees and subsampling without replacement (blue curve). The
horizontal dashed lines represent the minimum test error obtained by each
ensemble.

5.2.3 Tuning

Aside from bagging, additive expansions like (5.1) are often fit by minimizing
some loss function^e, like least squares loss, over the training data:

$$\min_{\{\beta_b, \theta_b\}_{b=1}^{B}} \; \sum_{i=1}^{N} L\left(y_i, \sum_{b=1}^{B} \beta_b f_b\left(x_i; \theta_b\right)\right).$$

^e A loss function measures the error in predicting y with f(x).
For many combinations of loss functions and base learners, the solution can
involve complicated and expensive numerical techniques. Fortunately, a simple
approximation can often be used when it is more feasible to solve the opti-
mization problem for a single base learner. This approximate solution is called
stagewise additive modeling, the details of which are listed in Algorithm 5.3
below.
1) Initialize $f_0(x) = 0$.
2) For b = 1, 2, . . . , B:
   a) Compute $\left(\beta_b, \theta_b\right) = \arg\min_{\beta, \theta} \sum_{i=1}^{N} L\left(y_i, f_{b-1}(x_i) + \beta f\left(x_i; \theta\right)\right)$.
   b) Update $f_b(x) = f_{b-1}(x) + \beta_b f\left(x; \theta_b\right)$.

Friedman et al. [2000] show that AdaBoost.M1 (Algorithm 5.2) is equivalent to
forward stagewise additive modeling using the exponential loss function
$L\left(y, f(x)\right) = \exp\left[-y f(x)\right]$.
5.2.5 Software
In terms of predictive performance, a rough ordering that often holds in practice is:
Gradient boosted trees ≥ Random forest > Bagged trees > Single tree.
So, while boosted tree ensembles tend to outperform their bagged counter-
parts, I don’t often find the performance increase to be worth the added com-
plexity and time associated with the additional tuning. It's a trade-off that
we all must take into consideration for the problem at hand. It should also be
noted that sometimes a single decision tree is the right tool for the job, and
an ensemble thereof would be overkill; see, for example, Section 7.9.1.
Recall from Section 2.8 that the relative importance of predictor x is essentially
the sum of the squared improvements over all internal nodes of the tree for
which x was chosen as the partitioning variable. This idea also extends to
ensembles of decision trees, such as bagged and boosted tree ensembles. In
ensembles, the improvement score for each predictor is averaged across all
the trees in the ensemble. Because of the stabilizing effect of averaging, the
aggregated tree-based variable importance score is often more reliable in large
ensembles; see Hastie et al. [2009, p. 368], although, as we’ll see in Chapter 7,
split variable selection bias will also affect the variable importance scores often
produced by tree-based ensembles using CART-like decision trees.
Post-processing a fitted tree ensemble with the LASSO amounts to re-estimating
the coefficients in (5.1) by solving

$$\min_{\{\beta_b\}_{b=1}^{B}} \; \sum_{i=1}^{N} L\left(y_i, \sum_{b=1}^{B} \beta_b \hat{f}_b\left(x_i\right)\right) + \lambda \sum_{b=1}^{B} \left|\beta_b\right|,$$

where $\hat{f}_b\left(x_i\right)$ ($b = 1, 2, \ldots, B$) is the prediction from the b-th tree for ob-
servation i, the $\beta_b$ are fixed, but unknown coefficients to be estimated via the
LASSO, and $\lambda$ is the $L_1$ penalty to be applied.
The wonderful and efficient glmnet package [Friedman et al., 2021] for R can
be used to fit the entire LASSO regularization path^g; that is, it efficiently com-
putes the estimated model coefficients for an entire grid of relevant λ values.
g The glmnet package actually implements the entire elastic net regularization path for
many types of generalized linear models. The LASSO is just a special case of the elastic
net, which combines both the LASSO and ridge (i.e., L2 ) penalties.
To illustrate, let’s return to the Ames housing example. Below, I’ll load the
data into R and apply the same 70/30 split from the previous example. Note
that I continue to rescale Sale_Price by dividing by 1000; this is strictly for
plotting purposes.
ames <- as.data.frame(AmesHousing::make_ames())
ames$Sale_Price <- ames$Sale_Price / 1000 # rescale response
set.seed(2101) # for reproducibility
id <- sample.int(nrow(ames), size = floor(0.7 * nrow(ames)))
ames.trn <- ames[id, ] # training data/learning sample
ames.tst <- ames[-id, ] # test data
ames.xtst <- subset(ames.tst, select = -Sale_Price) # features only
Next, I’ll fit a bagged tree ensemble using the randomForest package
[Breiman et al., 2018] (computational reasons for doing so are discussed in Sec-
tion 5.1.5). Random forests, and their open source implementations, are not dis-
cussed until Chapter 7. For now, just note that the randomForest package,
among others, can be used to implement bagged tree ensembles by tweaking
a special parameter, often referred to as mtry (to be discussed in Section 7.2);
setting this parameter equal to the total number of predictors results
in an ordinary bagged tree ensemble. This will be much more efficient than
relying on the ipred package and will also allow us to obtain predictions from
the individual trees, rather than just the aggregated predictions. Examples of
post-processing an RF and boosted tree ensemble are given in Sections 7.9.2
and 8.9.3, respectively.
Here, I’ll fit two models, each containing B = 500 trees:
• a standard bagged tree ensemble where each tree is fully grown to boot-
strap samples of size N (ames.bag);
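The code that fits these ensembles isn't reproduced in this excerpt; a minimal sketch of the first one (ames.bag), with mtry set equal to the number of predictors as described above, might look like the following (the seed is arbitrary and the other arguments are left at their defaults, which may not match the book's settings):

library(randomForest)
p <- ncol(ames.trn) - 1  # number of predictors (excluding Sale_Price)
set.seed(942)  # for reproducibility (arbitrary seed)
ames.bag <- randomForest(Sale_Price ~ ., data = ames.trn, ntree = 500,
                         mtry = p)  # mtry = p turns the RF into bagging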
Next, I’ll use glmnet to post-process each ensemble using the LASSO. The
following steps are conveniently handled by treemisc’s isle_post() function,
which I’ll use to post-process the ames.bag.6.5 ensemble. But first, I think
it’s prudent to show the individual steps using the ames.bag ensemble.
To start, I’ll compute the individual tree predictions for the train and test
sets and store them in a matrix:
preds.trn <- predict(ames.bag, newdata = ames.trn,
predict.all = TRUE)$individual
preds.tst <- predict(ames.bag, newdata = ames.tst,
predict.all = TRUE)$individual
Next, I’ll use the glmnet() function to fit the entire regularization path using
the training predictions from the B = 500 individual trees:
library(glmnet)
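# The original call isn't shown in this excerpt; the following is a sketch
# consistent with the description that follows (the object name
# lasso.ames.bag matches how it is referenced later in the text):
lasso.ames.bag <- glmnet(x = preds.trn, y = ames.trn$Sale_Price,
                         family = "gaussian",  # least squares loss
                         standardize = FALSE,  # predictions share a scale
                         lower.limits = 0)     # non-negative coefficients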
A few things to note about the above code chunk are in order. First, since this
is a regression problem, I set family = "gaussian" (for least squares) in
the call to glmnet(). Second, since the individual tree predictions are all on
the same scale, there’s no need to standardize the inputs (standardize =
FALSE). Lastly, we could argue that the estimated coefficients (one for each
tree) should be non-negative (lower.limits = 0).
Figure 5.5 shows the regularization path for the estimated coefficients. In par-
ticular, the λ values (on the log scale) are plotted on the x-axis, and the y-axis
corresponds to the estimated coefficient value (one curve per coefficient/tree).
The top axis highlights the number of non-zero coefficients at each particular
value of the penalty parameter λ:
[FIGURE 5.5: The LASSO regularization path for the estimated coefficients (one curve per tree); the top axis gives the number of non-zero coefficients at each value of λ.]
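The code that evaluates test set performance along the regularization path isn't shown in this excerpt; a minimal sketch defining the perf object used in the next chunk (test MSE for each value of λ) might be:

lasso.preds <- predict(lasso.ames.bag, newx = preds.tst)  # one column per lambda
perf <- apply(lasso.preds, MARGIN = 2, FUN = function(yhat) {
  mean((yhat - ames.tst$Sale_Price) ^ 2)  # test MSE
})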
# List of results
ames.bag.post <- as.data.frame(cbind(
"ntree" = lasso.ames.bag$df, perf,
"lambda" = lasso.ames.bag$lambda)
)
According to the test MSE, the optimal value of the penalty parameter λ is
1.077, which corresponds to 97 trees or non-zero coefficients in the LASSO
model (an appreciable reduction from the original 500).
In the next code chunk, I’ll follow the exact same process with the
ames.bag.6.5 ensemble, but using the isle_post() function instead:
library(treemisc)
The overall results from each ensemble are shown in Figure 5.6. Here, I show
the MSE as a function of the number of trees from each model (or non-
zero coefficients in the LASSO). In this example, the simpler ames.bag.6.5
ensemble benefits substantially from post-processing and appears to perform
on par with the ordinary bagged tree ensemble (ames.bag) in terms of MSE,
while requiring only a small fraction of trees and being orders of magnitude
faster to train! The original ensemble (ames.bag) did not see nearly as much
improvement from post-processing.
FIGURE 5.6: MSE for the test data from several bagged tree ensembles. The
dashed lines correspond to the LASSO-based post-processed versions. Clearly,
the ames.bag.6.5 ensemble benefits the most from post-processing, perform-
ing nearly on par with the standard bagged tree ensemble (ames.bag).
Loh [2014] compared the accuracy of single decision trees to tree ensembles
using both real and simulated data sets. He found that, on average, the best
single-tree algorithm was about 10% less accurate than a tree ensemble.
Nonetheless, tree ensembles will not always outperform a simpler individual
tree [Loh, 2009]. These points aside, tree ensembles are a powerful class of mod-
els that are highly competitive in terms of state-of-the-art prediction accuracy.
Chapters 7–8 are devoted to two powerful tree ensemble techniques.
It is also worth pointing out that while tree-based ensembles often outperform
carefully tuned individual trees (like CART, CTree, and GUIDE), they are less
interpretable compared to a single decision tree; hence, they are often referred
to as black box models. Fortunately, post-hoc procedures exist that can help
us peek into the black box to understand the relationships uncovered by the
model and explain their output to others. This is the topic of Chapter 6.
6
Peeking inside the “black box”: post-hoc
interpretability
This chapter is dedicated to select topics from the increasingly popular field
of interpretable machine learning (IML), which easily deserves its own book-
length treatment, and it has; see, for example, Molnar [2019] and Biecek and
Burzykowski [2021] (both of which are freely available online). The methods
covered in this chapter can be categorized into whether they help interpret
a black box model at a global or local (e.g., individual row or prediction)
level.
To be honest, I don’t really like the term “black box,” especially when we
now have access to a rich ecosystem of interpretability tools. For example,
linear regression models are often hailed as interpretable models. Sure, but
this is really only true when the model has a simple form. Once you start
including transformations and interaction effects—which are often required
to boost accuracy and meet assumptions—the coefficients become much less
interpretable.
Tree-based ensembles, especially the ones discussed in the next two chapters,
can provide state-of-the-art performance, and are quite competitive with other
popular supervised learning algorithms, especially on tabular data sets. Even
when tree-based ensembles perform as advertised, there’s a price to be paid
in terms of parsimony, as we lose the ability to summarize the model using a
simple tree diagram. Luckily, there exist a number of post-hoc techniques that
allow us to tease the same information out of an ensemble of trees that we
would ordinarily be able to glean from looking at a simple tree diagram (e.g.,
which variables seem to be the most important, the effect of each predictor,
and potential interactions). Note that the techniques discussed in this chapter
are model-agnostic, meaning they can be applied to any type of supervised
learning algorithm, not just tree-based ensembles. For example, they can also
be used to help interpret neural networks or a more complicated tree structure
that uses linear splits or non-constant models in the terminal nodes.
The next three sections cover post-hoc methods to help comprehend various
aspects of any fitted model:
• feature importance (Section 6.1);
• feature effects (Section 6.2);
• feature contributions (Section 6.3).
For the purposes of this chapter, we can think of variable importance (VI)
as the extent to which a feature has a “meaningful” impact on the predicted
outcome. A more formal definition and treatment can be found in van der
Laan [2006]. Given that point of view, a natural way to assess the impact of
an arbitrary feature xj is to remove it from the training data and examine
the drop in performance that occurs after refitting the model without it. This
procedure is referred to as leave-one-covariate-out (LOCO) importance; see
Hooker et al. [2019] and the references therein.
Obviously, the LOCO importance method is computationally prohibitive for
larger data sets and complex fitting procedures because it requires retrain-
ing the model once more for each dropped feature. In the next section, I’ll
discuss an approximate approach based on reassessing performance after ran-
domly permuting each feature (one at a time). This procedure is referred to
as permutation importance.
While some algorithms, like tree-based models, have a natural way of quanti-
fying the importance of each predictor, it is useful to have a model-agnostic
procedure that can be used for any type of supervised learning algorithm.
This also makes it possible to directly compare the importance of features
across different types of models. In this section, I’ll discuss a popular method
for measuring the importance of predictors in any supervised learning model
called permutation importance.
Given a fitted model and a baseline performance metric, $M_{orig}$, computed on the
training data (or a suitable validation set):
1) For i = 1, 2, . . . , p:
   (a) Permute the values of feature $x_i$ in the training data (or validation set).
   (b) Recompute the performance metric, $M_{perm}$, using the permuted copy of the data.
   (c) Record the difference from baseline using $VI\left(x_i\right) = M_{perm} - M_{orig}$.
The same idea can also be applied to groups of features (i.e.,
permuting more than one feature at a time); this would be useful if features
can be categorized into mutually exclusive groups, for instance, categorical
features that have been one-hot encoded.
6.1.2 Software
To illustrate the basic steps, let’s compute permutation importance scores for
the Ames housing bagged tree ensemble (ames.bag) from Section 5.5.1. I’ll
start by writing a simple function to compute the RMSE, the performance
metric of interest, and use it to obtain a baseline value for computing the
permutation-based importance scores.
rmse <- function(predicted, actual, na.rm = TRUE) {
sqrt(mean((predicted - actual) ^ 2, na.rm = na.rm))
}
(baseline.rmse <- rmse(predict(ames.bag, newdata = ames.trn),
actual = ames.trn$Sale_Price))
#> [1] 10.6
To get more stable VI scores, I’ll use 30 independent permutations for each
predictor; since the permutations are done independently, Algorithm 6.1 can
be trivially parallelized across repetitions or features. This is done using a
nested for loop in the next code chunk:
nperm <- 30  # number of permutations to use per feature
xnames <- names(subset(ames.trn, select = -Sale_Price))
vi <- matrix(nrow = nperm, ncol = length(xnames))
colnames(vi) <- xnames
for (j in colnames(vi)) {
  for (i in seq_len(nrow(vi))) {
    temp <- ames.trn  # temporary copy of training data
    temp[[j]] <- sample(temp[[j]])  # permute feature values
    pred <- predict(ames.bag, newdata = temp)  # score permuted data
    permuted.rmse <- rmse(pred, actual = temp$Sale_Price)
    vi[i, j] <- permuted.rmse - baseline.rmse  # larger implies more important
  }
}
Note that the individual permutation importance scores are computed inde-
pendently of each other, making it relatively straightforward to parallelize the
whole procedure; in fact, many R implementations of Algorithm 6.1, like vip
and iml, have options to do this in parallel using a number of different parallel
backends.
A boxplot of the unaggregated permutation scores for the top ten features,
as measured by the average across all 30 permutations, is displayed in Fig-
ure 6.1. Here, you can see that the overall quality rating of the home and its
above grade square footage are two of the most important predictors of sale
price, followed by neighborhood. A simple dotchart of the average permuta-
tion scores would suffice, but fails to show the variability in the individual VI
scores.
FIGURE 6.1: Permutation-based VI scores for the top ten features in the
Ames housing bagged tree ensemble, as measured by the average across all 30
permutations.
Partial dependence (PD) plots (or PDPs) help visualize the relationship be-
tween a subset of the features (typically 1–3) and the response while account-
ing for the average effect of the other predictors in the model. They are par-
ticularly effective with black box models like random forests, support vector
machines, and neural networks.
Let $x = \{x_1, x_2, \ldots, x_p\}$ represent the predictors in a model whose prediction
function is $\hat{f}(x)$. If we partition x into an interest set, $z_s$, and its complement,
$z_c = x \setminus z_s$, then the "partial dependence" of the response on $z_s$ is defined
as

$$f_s\left(z_s\right) = E_{z_c}\left[\hat{f}\left(z_s, z_c\right)\right] = \int \hat{f}\left(z_s, z_c\right) p_c\left(z_c\right) \, dz_c, \qquad (6.1)$$

where $p_c\left(z_c\right)$ is the marginal probability density of $z_c$. In practice, (6.1) is
estimated by averaging over the training data:

$$\bar{f}_s\left(z_s\right) = \frac{1}{N} \sum_{i=1}^{N} \hat{f}\left(z_s, z_{i,c}\right), \qquad (6.2)$$

where $\{z_{i,c}\}_{i=1}^{N}$ are the values of $z_c$ that occur in the training sample; that
is, we average out the effects of all the other predictors in the model.
A PD plot for a single feature, $x_1$ say, can be constructed as follows. Let
$x_{11}, x_{12}, \ldots, x_{1k}$ be a grid of values for $x_1$ (e.g., a fine grid of sample percentiles).
Then, for i = 1, 2, . . . , k:
   (a) Copy the training data and replace the original values of $x_1$ with the
   constant $x_{1i}$.
   (b) Compute the vector of predicted values from the modified copy of
   the training data.
   (c) Average the predictions together to obtain $\bar{f}_1\left(x_{1i}\right)$.
Plotting the pairs $\{x_{1i}, \bar{f}_1\left(x_{1i}\right)\}_{i=1}^{k}$ gives the PD plot of the response on $x_1$.
For classification problems with J classes, partial dependence is typically dis-
played on a (centered) logit scale:

$$f_j(x) = \log\left[p_j(x)\right] - \frac{1}{J} \sum_{k=1}^{J} \log\left[p_k(x)\right], \quad j = 1, 2, \ldots, J, \qquad (6.3)$$
where pj (x) is the predicted probability for the j-th class. Plotting fj (x)
helps us understand how the log-odds for the j-th class depends on different
subsets of the predictor variables. Nonetheless, there’s no reason partial de-
pendence can’t be displayed on the raw probability scale. The same goes for
ICE plots (Section 6.2.3). A multiclass classification example of PD plots on
the probability scale is given in Section 6.2.6.
To quantify the strength of the interaction between two predictors, $x_j$ and $x_k$,
Friedman and Popescu [2008] proposed the H-statistic

$$H_{jk}^2 = \sum_{i=1}^{N} \left[\bar{f}_{jk}\left(x_{ij}, x_{ik}\right) - \bar{f}_j\left(x_{ij}\right) - \bar{f}_k\left(x_{ik}\right)\right]^2 \Big/ \sum_{i=1}^{N} \bar{f}_{jk}^2\left(x_{ij}, x_{ik}\right). \qquad (6.4)$$

In essence, (6.4) measures the fraction of variance of $\bar{f}_{jk}\left(x_j, x_k\right)$—the joint
partial dependence of y on $x_j$ and $x_k$—not captured by $\bar{f}_j\left(x_j\right)$ and $\bar{f}_k\left(x_k\right)$
(the individual partial dependence of y on $x_j$ and $x_k$, respectively) over the
training data (or representative sample thereof). Note that $H_{jk}^2 \geq 0$, with zero
indicating no interaction between $x_j$ and $x_k$. To determine whether a single
predictor, $x_j$, say, interacts with any other variables, a similar H-statistic can
be computed. Unfortunately, these statistics are not widely implemented; the
R gbm package [Greenwell et al., 2021b] probably has the most efficient
implementation (see ?gbm::interact.gbm for details), but it’s only available
for GBMs (Chapter 8).
According to Friedman and Popescu [2008], only predictors with strong main
effects (e.g., high relative importance) should be examined for potential inter-
actions; the strongest interactions can then be further explored via two-way
PD plots. Be warned, however, that collinearity among predictors can lead to
spurious interactions that are not present in the target function.
A major drawback of the H-statistic (6.4) is that it requires computing both
the individual and joint partial dependence functions, which can be expensive;
the fast recursion method of Section 8.6.1 makes it feasible to compute the H-
statistic for binary decision trees (and ensembles of shallow trees). A simpler
approach, based on just the joint partial dependence function, is discussed in
Greenwell et al. [2018].
6.2.4 Software
PD plots and ICE plots (and many variants thereof) are implemented in
several R packages. Historically, PD plots were only implemented in specific
tree-based ensemble packages, like randomForest [Breiman et al., 2018] and
gbm. However, they were made generally available in package pdp, which
was soon followed by iml and ingredients, among others; these packages
also support ICE plots. The R package ICEbox [Goldstein et al., 2017] pro-
vides the original implementation of ICE plots and several variants thereof,
like c-ICE and d-ICE plots. PD plots and ICE plots were also made available
in scikit-learn’s inspection module, starting with versions 0.22.0 and 0.24.0,
respectively.
Using the Ames housing bagged tree ensemble, I’ll show how to construct
PD plots and ICE curves by hand and using the pdp package. To start, let’s
construct a PD plot for above grade square footage (Gr_Liv_Area), one of
the top predictors according to permutation-based VI scores from Figure 6.1
(p. 207).
The first step is to create a grid of points over which to construct the plot. For
continuous variables, it is sufficient to use a fine enough grid of percentiles,
as is done in the example below. Then, I simply loop through each grid point
and 1) copy the training data, 2) replace all the values of Gr_Liv_Area in the
copy with the current grid value, and 3) score the modified copy of the training
data and average the predictions together. Lastly, I simply plot the grid points
against the averaged predictions obtained from the for loop. The results are
displayed in Figure 6.2 and show a relatively monotonic increasing relationship
between above grade square footage and predicted sale price.
x.grid <- quantile(ames.trn$Gr_Liv_Area, prob = 1:30 / 31)
pd <- numeric(length(x.grid))
for (i in seq_along(x.grid)) {
  temp <- ames.trn  # temporary copy of data
  temp[["Gr_Liv_Area"]] <- x.grid[i]
  pd[i] <- mean(predict(ames.bag, newdata = temp))
}
FIGURE 6.2: Partial dependence of sale price on above grade square footage
for the bagged tree ensemble.
In the next step, I perform a cross-join between the grid of plotting val-
ues (df1) and the original training data with the plotting features removed
(df2)b :
df2 <- subset(ames.trn, select = -c(Gr_Liv_Area, First_Flr_SF))
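The code that builds df1 and performs the cross-join isn't included in this excerpt; a minimal sketch, assuming the same 30 percentiles used earlier for each feature (my own choice), might be:

df1 <- expand.grid(
  Gr_Liv_Area = quantile(ames.trn$Gr_Liv_Area, prob = 1:30 / 31),
  First_Flr_SF = quantile(ames.trn$First_Flr_SF, prob = 1:30 / 31)
)
pd <- merge(df1, df2, by = NULL)  # Cartesian product (cross-join)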
Then, I simply score the data and aggregate by computing the average pre-
diction within each grid point, as shown in the example below:
pd$yhat <- predict(ames.bag, newdata = pd) # might take a few minutes!
pd <- aggregate(yhat ~ Gr_Liv_Area + First_Flr_SF, data = pd,
FUN = mean)
The code snippet below constructs a false color level plot of the data with
contour lines using the built-in lattice package; the results are displayed in
Figure 6.3. Here, you can see the joint effect of both features on the predicted
sale price.
b BE CAREFUL as the resulting data set, which is a Cartesian product, can be quite
large!
library(lattice)
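# The original plotting call isn't shown in this excerpt; a basic version
# (color and formatting choices are my own) might be:
levelplot(yhat ~ Gr_Liv_Area * First_Flr_SF, data = pd, contour = TRUE,
          col.regions = hcl.colors(100))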
FIGURE 6.3: Partial dependence of sale price on above grade and first floor
square footage for the bagged tree ensemble.
It is not wise to draw conclusions from PD plots (and ICE plots) in regions
outside the area of the training data. Greenwell [2017] describes two ways to
mitigate the risk of extrapolation in PD plots: rug displays, like the one I used
in Figure 6.2, and convex hulls (which can be used with bivariate displays,
like in Figure 6.3).
Constructing ICE curves is just as easy; just skip the aggregation step and
plot each of the individual curves. In the example below, I’ll use the pdp
package to construct c-ICE curves showing the partial dependence of sale price
on above grade square footage for each observation in the learning sample.
There's no need to plot a curve for every observation, especially when you have
thousands (or more) data points; here, I'll just plot a random sample of 500
curves. I’ll use the same percentiles to construct the plot as I did for the PD
plot in Figure 6.2 (p. 212) by invoking the quantiles and probs arguments
in the call to partial(); note that partial()’s default is to use an evenly
spaced grid of points across the range of predictor values.
The results are displayed in Figure 6.4; the red line shows the average c-ICE
value at each above grade square footage (i.e., the centered partial depen-
dence). The heterogeneity in the c-ICE curves indicates a potential interaction
effect between Gr_Liv_Area and at least one other feature. The c-ICE curves
also indicate a relatively monotonic increasing relationship for the majority of
houses in the training set, but you can see a few of the curves at the bottom
deviate from this overall pattern.
ice <- partial(ames.bag, pred.var = "Gr_Liv_Area", ice = TRUE,
center = TRUE, quantiles = TRUE, probs = 1:30 / 31)
set.seed(1123) # for reproducibility
samp <- sample.int(nrow(ames.trn), size = 500) # sample 500 homes
autoplot(ice[ice$yhat.id %in% samp, ], alpha = 0.1) +
ylab("Conditional expectation")
FIGURE 6.4: A random sample of 500 c-ICE curves for above grade square
footage using the Ames housing bagged tree ensemble. The curves indicate a
relatively monotonic increasing relationship for the majority of houses in the
sample. The average of the 500 c-ICE curves is shown in red.
For a classification example, I’ll consider Edgar Anderson’s iris data from the
datasets package in R. The iris data frame contains the sepal length, sepal
width, petal length, and petal width (in centimeters) for 50 flowers from each
of three species of iris: setosa, versicolor, and virginica. Below, I fit a bagged
tree ensemble to the data using the randomForest package:
library(randomForest)
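# The original fit isn't shown in this excerpt; a minimal sketch (mtry = 4,
# the total number of features, makes this a bagged tree ensemble; the seed
# is arbitrary):
set.seed(1452)  # for reproducibility
iris.bag <- randomForest(Species ~ ., data = iris, ntree = 500, mtry = 4)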
Note that without the aid of a user-supplied prediction function (via the
pred.fun argument), pdp’s partial() function can only compute partial
dependence with regard to a single class; see Greenwell [2017] for more details
on the use of this package.
Shapley values originate in cooperative game theory, where the problem is to
distribute the total payout of a game among the players in a "fair" way; that
is, so that each player receives their "fair" share. The Shapley value is one
such solution and the only one that uniquely satisfies a particular set of
"fairness properties."
Let v be a characteristic function that assigns a value to each subset of players;
in particular, $v: 2^p \to \mathbb{R}$, where $v(S) = \Delta_S$ and $v(\emptyset) = 0$, with $\emptyset$ denoting
the empty set (i.e., zero players). Let φi (v) be the contribution (or portion of
the total payout) attributed to player i in a particular game with total payout
v (S) = ∆S . The Shapley value satisfies the following properties:
• efficiency: $\sum_{i=1}^{p} \phi_i(v) = \Delta_S$;
• null player: ∀W ⊆ S \ {i} : ∆W = ∆W ∪{i} =⇒ φi (v) = 0;
• symmetry: ∀W ⊆ S \ {i, j} : ∆W ∪{i} = ∆W ∪{j} =⇒ φi (v) = φj (v);
• linearity: If v and w are functions describing two coalitional games, then
φi (v + w) = φi (v) + φi (w).
The above properties can be interpreted as follows:
• the individual player contributions sum to the total payout, hence, are
implicitly normalized;
• if a player does not contribute to the coalition, they receive a payout of
zero;
• if two players have the same impact across all coalitions, they receive equal
payout;
• the local contributions are additive across different games.
Shapley [2016] showed that the unique solution satisfying the above properties
is given by

$$\phi_i(v) = \frac{1}{p!} \sum_{O \in \pi(p)} \left[ v\left(S_O \cup \{i\}\right) - v\left(S_O\right) \right], \quad i = 1, 2, \ldots, p, \qquad (6.5)$$

where $\pi(p)$ is the set of all $p!$ permutations (orderings) of the p players and
$S_O$ is the set of players that precede player i in permutation O. In words,
$v\left(S_O \cup \{i\}\right) - v\left(S_O\right)$ is the marginal contribution of player i to the coalition
$S_O$, and the Shapley value for player i is given by the average of this
contribution over all possible permutations in which the coalition can be formed.
A simple example may help clarify the main ideas. Suppose three friends
(players)—Alex, Brad, and Brandon—decide to go out for drinks after work
(the game). They shared a few pitchers of beer, but nobody paid attention to
how much each person drank (collaborated). What’s a fair way to split the
tab (total payout)?
Suppose we knew the following information, perhaps based on historical happy
hours:
• if Alex drank alone, he’d only pay $10;
• if Brad drank alone, he’d only pay $20;
• if Brandon drank alone, he’d only pay $10;
• if Alex and Brad drank together, they’d only pay $25;
• if Alex and Brandon drank together, they’d only pay $15;
• if Brad and Brandon drank together, they’d only pay $13;
• if Alex, Brad, and Brandon drank together, they’d only pay $30.
With only three players, we can enumerate all possible coalitions. In Table
6.1, I list all possible permutations of the three players and list the marginal
contribution of each. Take the first row, for example. In this particular per-
mutation, we start with Alex. We know that if Alex drinks alone, he’d spend
$10, so his marginal contribution by entering first is $10. Next, we assume
Brad enters the coalition. We know that if Alex and Brad drank together,
they’d pay a total of $25, leaving $15 left over for Brad’s marginal contri-
bution. Similarly, if Brandon joins the party last, his marginal contribution
would be only $5 (the difference between $30 and $25). The Shapley value for
each player is the average across all six possible permutations (these are the
column averages reported in the last row). In this case, Brandon would get
away with the smallest payout (i.e., have to pay the smallest portion of the
total tab). The next time the bartender asks how you want to split the tab,
whip out a pencil, and do the math!
TABLE 6.1: The marginal contribution of each player under each possible ordering of the three players; the last row gives the column averages (i.e., the Shapley values).

Permutation/order of players    Alex     Brad     Brandon
Alex, Brad, Brandon             $10      $15      $5
Alex, Brandon, Brad             $10      $15      $5
Brad, Alex, Brandon             $5       $20      $5
Brad, Brandon, Alex             $17      $20      -$7
Brandon, Alex, Brad             $5       $15      $10
Brandon, Brad, Alex             $17      $3       $10
Shapley contribution:           $10.67   $14.67   $4.67
To translate these ideas to the context of explaining a model prediction for an
observation of interest $x^\star$:

• the total payout/worth ($\Delta_S$) for $x^\star$ is the prediction for $x^\star$ minus the
average prediction for all training observations (the latter is referred to as
the baseline and denoted $\bar{f}$): $\hat{f}\left(x^\star\right) - \bar{f}$;

• the players are the individual feature values of $x^\star$ that collaborate to
receive the payout $\Delta_S$ (i.e., predict a certain value).

The second point, combined with the efficiency property stated in the previous
section, implies that the p Shapley explanations (or feature contributions) for
an observation of interest $x^\star$, denoted $\{\phi_i\left(x^\star\right)\}_{i=1}^{p}$, are inherently standard-
ized since $\sum_{i=1}^{p} \phi_i\left(x^\star\right) = \hat{f}\left(x^\star\right) - \bar{f}$.
Several methods exist for estimating Shapley values in practice. The most
common is arguably Tree SHAP [Lundberg et al., 2020], an efficient implemen-
tation of exact Shapley values for decision trees and ensembles thereof.
Tree SHAP is a fast and exact method to estimate Shapley values for tree-
based models (including tree ensembles), under several different possible as-
sumptions about feature dependence. The specifics of Tree SHAP are beyond
the scope of this book, so I'll defer to Lundberg et al. [2020] for the details.
It's implemented in the Python shap module, and embedded in several tree-
based modeling packages across several open source languages (like xgboost
[Chen et al., 2021] and lightgbm [Shi et al., 2022]); we'll see an example of it
in action in Section 8.9.4.
In the following section, I’ll discuss a general way to estimate Shapley values
for any supervised learning model using a simple Monte Carlo approach.
Except in special circumstances, like Tree SHAP, computing the exact Shapley
value is computationally infeasible in most applications. To that end, Štrum-
belj and Kononenko [2014] suggest a Monte Carlo approximation, which I’ll
call Sample SHAP for short, that assumes independent featuresc . Their ap-
proach is described in Algorithm 6.3 below.
Here, a single estimate of the contribution of feature xi to f (x? )− f¯ is nothing
more than the difference between two predictions, where each prediction is
based on a set of “Frankenstein instances”d that are constructed by swapping
out values between the instance being explained (x? ) and an instance selected
at random from the training data (w? ). To help stabilize the results, the
procedure is repeated a large number, say, R, times, and the results averaged
together:
1) For j = 1, 2, . . . , R:
   a) Sample an instance $w^\star$ at random from the training data, along with a
   random permutation of the feature indices.
   b) Construct two "Frankenstein instances": $b_1$, which takes its values from
   $x^\star$ for $x_i$ and every feature preceding it in the permutation, and from
   $w^\star$ for the remaining features; and $b_2$, which is identical to $b_1$ except
   that the value of $x_i$ is also taken from $w^\star$.
   c) Compute the difference $\phi_i^{(j)}\left(x^\star\right) = \hat{f}\left(b_1\right) - \hat{f}\left(b_2\right)$.
2) Return the average $\phi_i\left(x^\star\right) = \frac{1}{R} \sum_{j=1}^{R} \phi_i^{(j)}\left(x^\star\right)$.
c While Sample SHAP, along with many other common Shapley value procedures, as-
sumes independent features, several arguments can be made in favor of this assumption;
see, for example, Chen et al. [2020] and the references therein.
d The terminology used here takes inspiration from Molnar [2019, p. 231].
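The sample.shap() function used in the example below isn't shown in this excerpt; the following is a minimal sketch of what such a helper might look like, implementing Algorithm 6.3 with the same call signature used below (the book's actual implementation may differ):

sample.shap <- function(f, obj, R, x, feature, X) {
  phi <- numeric(R)  # to store the Monte Carlo contributions
  N <- nrow(X)
  for (j in seq_len(R)) {
    w <- X[sample.int(N, size = 1), ]  # random instance from the training features
    ord <- sample(names(X))            # random ordering of the features
    swap <- ord[seq_len(which(ord == feature))]  # `feature` and its predecessors
    b1 <- w
    b1[swap] <- x[swap]        # b1: values from x for `swap`, from w otherwise
    b2 <- b1
    b2[feature] <- w[feature]  # b2: identical to b1, except `feature` from w
    phi[j] <- f(obj, newdata = b1) - f(obj, newdata = b2)
  }
  mean(phi)  # average over the R repetitions
}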
To illustrate, let’s continue with the Ames housing example (ames.bag). Be-
low, I use the sample.shap() function to estimate the contribution of the
value of Gr_Liv_Area to the prediction of the first observation in the learning
sample (ames.trn):
X <- subset(ames.trn, select = -Sale_Price) # features only
set.seed(2207) # for reproducibility
sample.shap(predict, obj = ames.bag, R = 100, x = X[1, ],
feature = "Gr_Liv_Area", X = X)
#> [1] -6.7
So, having Gr_Liv_Area = 1474 helped push the predicted sale price down
toward the baseline average; in this case, the baseline average is just the
average predicted sale price across the entire training set: f¯ = $181.53 (don’t
forget that I rescaled the response in this example).
If there are p features and m instances to be explained, this requires 2 ×
R × p × m predictions (or calls to the scoring function f ). In practice, this
can be quite computationally demanding, especially since R needs to be large
enough to produce good approximations to each φi (x? ). How large does R
need to be to produce accurate explanations? It depends on the variance of
each feature in the observed training data, but typically R ∈ [30, 100] will
suffice. The R package fastshap [Greenwell, 2021a] provides an optimized
implementation of Algorithm 6.3 that only requires 2mp calls to f ; see the
package documentation for details.
Sample SHAP can be computationally prohibitive if you need to explain large
data sets (optimized or not). Fortunately, you often only need to explain a
handful of predictions, the most extreme ones, for example. However, generat-
ing explanations for the entire training set, or a large enough sample thereof,
can be useful for generating aggregated global model summaries. For example,
Shapley-based dependence plots [Lundberg et al., 2020] show how a feature’s
value impacts the prediction of every observation in a data set of interest.
6.3.3 Software
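The code producing pred and highest in the snippet below isn't included in this excerpt; presumably it looks something like this sketch, which scores the training data and picks out the home with the largest predicted sale price:

pred <- predict(ames.bag, newdata = ames.trn)  # training set predictions
highest <- which.max(pred)  # index of the largest predicted sale price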
pred[highest]
#> 433
#> 503
# fastshap needs to know how to compute predictions from your model
pfun <- function(object, newdata) predict(object, newdata = newdata)
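The explain() call that generates the contribution plot below isn't shown either; a sketch along the following lines seems plausible (the use of newdata, the object name, and the choice of nsim are my assumptions):

library(fastshap)
ex.highest <- explain(ames.bag, X = X, newdata = X[highest, , drop = FALSE],
                      nsim = 100, pred_wrapper = pfun)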
[Figure: Shapley-based feature contributions for the home with the highest predicted sale price; the displayed feature values are Overall_Qual = Very_Excellent, Gr_Liv_Area = 2674, Neighborhood = Northridge_Heights, Total_Bsmt_SF = 2630, First_Flr_SF = 2674, Fireplaces = 2, Garage_Cars = 3, Latitude = 42.1, Longitude = -93.7, and Lot_Area = 13693.]
Next, I'll construct a Shapley dependence plot for Gr_Liv_Area using fast-
shap with R = 50 Monte Carlo repetitions. The results are displayed in the
figure below. As with Figures 6.2 and 6.4, the predicted sale price tends to increase
with above grade square footage. As with the c-ICE curves in Figure 6.4, the
increasing dispersion in the plot indicates a potential interaction with at least
one other feature. Coloring the Shapley dependence plot by the values of an-
other feature can help visualize such an interaction, if you know what you’re
looking for.
ex <- explain(ames.bag, feature_names = "Gr_Liv_Area", X = X,
nsim = 50, pred_wrapper = pfun)
IML is on the rise, and so is IML-related open source software. There are
simply too many methods and useful packages to discuss in one chapter, so I
have only covered a handful. If you're looking for more, I'd recommend starting
with the IML awesome list hosted by Patrick Hall at
https://github.com/jphall663/awesome-machine-learning-
interpretability.
A good resource for R users is Maksymiuk et al. [2021]. And of course, Molnar
[2019] is a freely available resource, filled with intuitive explanations and links
to relevant software in both R and Python. Molnar et al. [2021] is also worth
reading, as they discuss a number of pitfalls to watch out for when using
model-agnostic interpretation methods.
7
Random forests
7.1 Introduction
Random forests (RFs) are essentially bagged tree ensembles with an added
twist, and they tend to provide similar accuracy to many state-of-the-art su-
pervised learning algorithms on tabular data, while being relatively less diffi-
cult to tune. In other words, RFs tend to be competitive right out of the box.
But be warned, RFs—like any statistical and machine learning algorithm—
enjoy their fair share of disadvantages. As we’ll see in this chapter, RFs also
include many bells and whistles that data scientists can leverage for non-
prediction tasks, like detecting anomalies/outliers, imputing missing values,
and so forth.
Recall that a bagged tree ensemble (Section 5.1) consists of hundreds (some-
times thousands) of independently grown decision trees, where each tree is
trained on a different bootstrap sample from the original training data. Each
tree is intentionally grown deep (low bias), and variance is reduced by aver-
aging the predictions across all the trees in the ensemble. For classification, a
plurality vote among the individual trees is used.
Unfortunately, correlation limits the variance-reducing effect of averaging.
Take the following example for illustration. Suppose $\{X_i\}_{i=1}^{N} \overset{iid}{\sim} \left(\mu, \sigma^2\right)$ is
a random sample from some distribution with mean $\mu$ and variance $\sigma^2$, and let
$\bar{X} = \frac{1}{N}\sum_{i=1}^{N} X_i$ be the sample mean. Then $V\left(\bar{X}\right) = \sigma^2 / N$; that is, averaging
reduces the variance of the sample elements. This of course assumes that the $X_i$
are uncorrelated. If the pairwise correlation between any two observations
is $\rho = \rho\left(X_i, X_j\right)$ ($i \neq j$), then

$$V\left(\bar{X}\right) = \rho\sigma^2 + \frac{1-\rho}{N}\sigma^2,$$

which approaches $\rho\sigma^2$ as N increases; in other words, when the $X_i$ are correlated,
averaging more of them cannot reduce the variance below $\rho\sigma^2$.
FIGURE 7.1: 100 simulated averages from samples of size N = 30 with pair-
wise correlation increasing from zero to one.
FIGURE 7.2: Six bagged decision trees applied to the email spam training
data. The path to terminal node 15 is highlighted in each tree.
Luckily, Leo Breiman and Adele Cutler thought of a clever way to reduce
correlation in a bagged tree ensemble; that is, make the trees more diverse. The
idea is to limit the potential splitters at each node in a tree to a random subset
of the available predictors, which will often result in a much more diverse
ensemble of trees. In essence, bagging constructs a diverse tree ensemble by
introducing randomness into the rows via sampling with replacement, while
an RF further increases tree diversity by also introducing randomness into the
columns via subsampling the features.
1) Start with a training sample, $d_{trn}$, and specify integers $n_{min}$ (the mini-
mum node size), B (the number of trees in the forest), and $m_{try} \le p$ (the
number of predictors to select at random as candidate splitters prior to
splitting the data at each node in each tree).
2) For b in 1, 2, . . . , B:
   (a) Select a bootstrap sample $d_{trn}^{\star}$ of size N from the training data $d_{trn}$.
   (b) Optional: Keep track of which observations from the original train-
   ing data were not selected to be in the bootstrap sample; these are
   called the out-of-bag (OOB) observations.
   (c) Fit a decision tree $T_b$ to the bootstrap sample $d_{trn}^{\star}$ according to the
   following rules:
      (i) At each node, select $m_{try}$ of the p available predictors at random
      and find the best split using only those candidate splitters.
      (ii) Continue recursively splitting each terminal node until the min-
      imum node size $n_{min}$ is reached.
3) Return the "forest" of trees $\{T_b\}_{b=1}^{B}$.
4) To obtain the RF prediction for a new case x, pass the observation down
each tree and aggregate as follows:
   • Classification: $\hat{C}_B^{rf}(x) = \text{vote}\{\hat{C}_b(x)\}_{b=1}^{B}$, where $\hat{C}_b(x)$ is the pre-
   dicted class label for x from the b-th tree in the forest (in other
   words, let each tree vote on the classification for x and take the ma-
   jority/plurality vote).
   • Regression: $\hat{f}_B^{rf}(x) = \frac{1}{B}\sum_{b=1}^{B} \hat{f}_b(x)$ (in other words, we just average
   the predictions for case x across all the trees in the forest).
There are two common ways to obtain predicted class probabilities from an RF:

1) Take the proportion of votes for each class over the entire forest.
2) Average the class probabilities from each tree in the forest. (In this case,
$n_{min}$ should be considered a tuning parameter; see, for example, Malley
et al. [2012].)
The first approach can be problematic. For example, suppose the probabil-
ity that x belongs to class j is $\Pr\left(Y = j \mid x\right) = 0.91$. If each tree correctly
predicts class j for x, then the vote-based estimate is $\widehat{\Pr}\left(Y = j \mid x\right) = 1$, which is incorrect. If $n_{min} = 1$, the
two approaches are equivalent and neither will produce consistent estimates
of the true class probabilities (see, for example, Malley et al. [2012]). So which
approach is better for probability estimation? Hastie et al. [2009, p. 283] ar-
gue that the second method tends to provide improved estimates of the class
probabilities with lower variance, especially for small B.
Malley et al. [2012] make a similar argument for the binary case, but from
a different perspective. In particular, they suggest treating the 0/1 outcome
as numeric and fitting a regression forest using the standard MSE splitting
criterion (an example of a so-called probability machine). It seems strange to
use MSE on a 0/1 outcome, right? Not really. Recall from Section 2.2.1 that
the Gini index for binary outcomes is equivalent to using the MSE. Malley
et al. recommend using a minimum node size equal to 10% of the number
of training cases: $n_{min} = \lfloor 0.1 \times N \rfloor$. However, for probability estimation, it
seems natural to treat nmin as a tuning parameter. Devroye et al. [1997,
Chap. 21–22] provide some guidance on the choice of nmin for consistent
probability estimation in decision trees.
The predicted probabilities can be converted to class predictions (i.e., by com-
paring each probability to some threshold), which gives us an alternative to
hard voting called soft voting. In soft voting, we classify x to the class with
the largest averaged class probability. This approach to classification in RFs
tends to be more accurate since predicted probabilities closer to zero or one
are given more weight during the averaging step; hence, soft voting attaches
more weight to votes with higher confidence (or smaller standard errors; Sec-
tion 7.7).
Consider, for illustration, a simulated binary outcome from the Mease model,
whose true probability of class 1 membership is given by

$$p(x) = \Pr\left(Y = 1 \mid x\right) = \begin{cases} 1, & r(x) < 8 \\ \dfrac{28 - r(x)}{20}, & 8 \le r(x) < 28 \\ 0, & r(x) \ge 28 \end{cases} \qquad (7.1)$$
where r (x) is the Euclidean distance from x = (x1 , x2 ) to the point (25, 25).
A sample of N = 1000 observations from the Mease model is displayed in
Figure 7.3; note that the observed 0/1 outcomes were generated according
to the above probability rule p (x). (As always, the code to reproduce the
simulation is available on the companion website.)
[FIGURE 7.3: A sample of N = 1000 observations generated from the Mease model (7.1).]
Figures 7.4–7.5 display the results of the simulation. In Figure 7.4, the median
predicted probability across all 250 simulations was computed and plotted vs.
the true probability of class 1 membership; the dashed 45-degree line corre-
sponds to perfect agreement. Here it is clear that the regression forest (i.e.,
treating the 0/1 outcome as continuous and building trees using the MSE split-
ting criterion) outperforms the classification forest (except when nmin = 1, in
which case they are equivalent.) This is also evident from Figure 7.5, which
shows the distribution of the MSE between the predicted class probabilities
and the true probabilities for each case. In essence, for binary outcomes, re-
gression forests produce consistent estimates of the true class probabilities.a
This goes to show that mtry isn't the only important tuning parameter when
it comes to probability estimation.

a By consistent, I mean that $\widehat{\Pr}\left(Y = 1 \mid x\right) \to \Pr\left(Y = 1 \mid x\right)$ as $N \to \infty$.
[FIGURE 7.4: Median predicted probability of class 1 membership vs. the true probability, for the classification and regression forests at several minimum node sizes; the dashed 45-degree line corresponds to perfect agreement.]
FIGURE 7.5: Mean squared errors (MSEs) from the Mease simulation. Here
the MSE between the predicted probabilities and the true probabilities for
each simulation are displayed using boxplots. Clearly, in this example, the
regression forest (RF) with nmin > 1 produces more accurate class probability
estimates.
To help solidify the basic concepts of an RF, let's construct one from scratch.^b To do that, we need a decision tree implementation that will allow us to randomly select a subset of features for consideration at each node in the tree. Such arguments are available in the sklearn.tree module in Python, as well as R's party and partykit packages; unfortunately, this option is not currently available in rpart. In this example, I'll go with party, since its ctree() function is faster, albeit less flexible, than partykit's implementation.
Below is the definition for a function called crforest(), which constructs a conditional random forest (CRF) [Hothorn et al., 2006a, Strobl et al., 2008a, 2007b], that is, an RF using conditional inference trees (Chapter 3) for the base learners. The oob argument will come into play in Section 7.3, so just ignore that part of the code for now. Note that the function returns a list of fitted CTrees that we can aggregate later for the purposes of prediction.
b The code I'm about to show is for illustration purposes only; it will not be nearly as efficient as the RF implementations discussed later in this chapter.
crforest <- function(X, y, mtry = NULL, B = 5, oob = TRUE) {
min.node.size <- if (is.factor(y)) 1 else 5
N <- nrow(X) # number of observations
p <- ncol(X) # number of features
train <- cbind(X, "y" = y) # training data frame
fo <- as.formula(paste("y ~ ", paste(names(X), collapse = "+")))
if (is.null(mtry)) { # use default definition
mtry <- if (is.factor(y)) sqrt(p) else p / 3
mtry <- floor(mtry) # round down to nearest integer
}
# CTree parameters; basically force the tree to have maximum depth
ctrl <- party::ctree_control(mtry = mtry, minbucket = min.node.size,
minsplit = 10, mincriterion = 0)
forest <- vector("list", length = B) # to store each tree
for (b in 1:B) { # fit trees to bootstrap samples
boot.samp <- sample(1:N, size = N, replace = TRUE)
forest[[b]] <- party::ctree(fo, data = train[boot.samp, ],
control = ctrl)
if (isTRUE(oob)) { # store row indices for OOB data
attr(forest[[b]], which = "oob") <-
setdiff(1:N, unique(boot.samp))
}
}
forest # return the "forest" (i.e., list) of trees
}
Let’s test out the function on the Ames housing data, using the same 70/30
split from previous examples (Section 1.4.7). Here, I’ll fit a default CRF (i.e.,
mtry = ⌊p/3⌋ and nmin = 5) using our new crforest() function. (Be warned,
this code may take a few minutes to run; the code on the book website includes
an optional progress bar and the ability to run in parallel using the foreach
package [Revolution Analytics and Weston, 2020].)
X <- subset(ames.trn, select = -Sale_Price) # feature columns
set.seed(1408) # for reproducibility
ames.crf <- crforest(X, y = ames.trn$Sale_Price, B = 300)
To obtain predictions from the fitted model, we can just loop through each
tree, extract the predictions, and then average them together at the end. This
can be done with a simple for loop, which is demonstrated in the code chunk
below. Here, I obtain the averaged predictions from ames.crf on the test data
and compute the test RMSE.
B <- length(ames.crf) # number of trees in forest
preds.tst <- matrix(nrow = nrow(ames.tst), ncol = B)
for (b in 1:B) { # store predictions from each tree in a matrix
preds.tst[, b] <- predict(ames.crf[[b]], newdata = ames.tst)
}
pred.tst <- rowMeans(preds.tst) # average predictions across trees
Rather than reporting the test RMSE for the entire forest, we can compute it
for each sub-forest of size b ≤ B to see how it changes as the forest grows. We
can do this using a simple for loop, as demonstrated in the code chunk below.
(Note that I use drop = FALSE here so that the subset matrix of predictions
doesn’t lose its dimension when b = 1.)
rmse.tst <- numeric(B) # to store RMSEs
for (b in 1:B) {
pred <- rowMeans(preds.tst[, 1:b, drop = FALSE], na.rm = TRUE)
rmse.tst[b] <- rmse(pred, obs = ames.tst$Sale_Price, na.rm = TRUE)
}
The above test RMSEs are displayed in Figure 7.6 (black curve). For compar-
ison, I also included the test error for a single CTree fit (horizontal dashed
line). Here, the CRF clearly outperforms the single tree, and the test error
stabilizes after about 50 trees. Next, I’ll discuss an internal cross-validation
strategy based on the OOB data.
applies to bagging and boosting when sampling is involved, regardless of whether the sampling is done with or without replacement.
e This discussion also applies to subsampling without replacement.
    Pr(case i ∉ bootstrap sample b) = (1 − 1/N)^N .

As N → ∞, it can be shown that (1 − 1/N)^N → e^−1 ≈ 0.368. In other words, on
average, each bootstrap sample contains approximately 1 − e−1 ≈ 0.632 of the
original training records; the remaining e−1 ≈ 0.368 observations are OOB and
can be used as an independent validation set for the corresponding tree. This
is rather straightforward to observe without a mathematical derivation. The
code below computes the proportion of non-OOB observations in B = 10000
bootstrap samples of size N = 100, and averages the results together:
set.seed(1226) # for reproducibility
N <- 100 # sample size
obs <- 1:N # original observations
res <- replicate(10000, sample(obs, size = N, replace = TRUE))
inbag <- apply(res, MARGIN = 2, FUN = function(boot.sample) {
mean(obs %in% boot.sample) # proportion in bootstrap sample
})
mean(inbag)
#> [1] 0.634
Let w_{b,i} = 1 if observation i is OOB in the b-th tree and zero otherwise. Further, if we let B_i = Σ_{b=1}^{B} w_{b,i} be the number of trees in the forest for which observation i is OOB, then the OOB prediction for the i-th training observation is given by

    ŷ_i^OOB = (1 / B_i) Σ_{b : w_{b,i} = 1} ŷ_{ib},   i = 1, 2, . . . , N.        (7.2)
The OOB error estimate is just the error computed from these OOB predic-
tions. (See [Hastie et al., 2009, Sec. 7.11] for a more general discussion on
using the bootstrap to estimate prediction error and its apparent bias.)
To illustrate, I’m going to compute the OOB RMSE for the CRF I previously
fit to the Ames housing data. There are numerous ways in which this can be
done programmatically given our setup; I chose the easy route. Recall that
each tree in our ames.crf object contains an attribute called "oob" which stores
the row numbers for the training records that were OOB for that particular
tree. From these we can easily construct an N × B matrix, where the (i, b)-th element is given by ŷ_{ib} if w_{b,i} = 1, and NA if w_{b,i} = 0.
The reason for using NAs in place of the predictions for the non-OOB obser-
vations will hopefully become apparent soon.
preds.oob <- matrix(nrow = nrow(ames.trn), ncol = B) # OOB predictions
for (b in 1:B) { # WARNING: Might take a minute or two!
oob.rows <- attr(ames.crf[[b]], which = "oob") # OOB row IDs
preds.oob[oob.rows, b] <-
predict(ames.crf[[b]], newdata = ames.trn[oob.rows, ])
}
# Peek at results
preds.oob[1:3, 1:6]
Peeking at the first few rows and columns you can see that the first training ob-
servation (which corresponds to the first row in the above matrix) was OOB in
the first and sixth trees (since the rest of the columns are NA), whereas the sec-
ond observation was OOB for trees two and six, so I obtained the correspond-
ing OOB predictions for these. Next, I compute ŷ_i^OOB as in Equation (7.2) by computing the row means of our matrix preds.oob; setting na.rm = TRUE
in the call to rowMeans() ensures that the NAs in the matrix aren’t counted,
so that the average is taken only over the OOB predictions (i.e., the correct
denominator Bi will be used). Note that the OOB error is slightly larger than
the test error I computed earlier; this is typical in many common settings, as
noted in Janitza and Hornung [2018].
pred.oob <- rowMeans(preds.oob, na.rm = TRUE)
rmse(pred.oob, obs = ames.trn$Sale_Price, na.rm = TRUE)
#> [1] 26.6
Similar to what I did in the previous section, I can compute the OOB RMSE
as a function of the number of trees in the forest. The results are displayed
in Figure 7.6, along with the test RMSEs from the same forest (black curve)
and test error from a single CTree fit (horizontal blue line). Here, we can see
that the OOB error is consistently higher than the test error, but both begin
to stabilize at around 50 trees.
rmse.oob <- numeric(B) # to store RMSEs
for (b in 1:B) {
pred <- rowMeans(preds.oob[, 1:b, drop = FALSE], na.rm = TRUE)
rmse.oob[b] <- rmse(pred, obs = ames.trn$Sale_Price, na.rm = TRUE)
}
FIGURE 7.6: RMSEs for the Ames housing example: the CRF test RMSEs
(black curve), the CRF OOB RMSEs (yellow curve), and the test RMSE from
a single CTree fit (horizontal blue line).
As noted in Hastie et al. [2009], the OOB error estimate is almost identical to
that obtained by N -fold cross-validation, where N is the number of rows in
the learning sample; this is also referred to as leave-one-out cross-validation
(LOOCV). Hence, algorithms that produce OOB data can be fit in one se-
quence, with cross-validation being performed along the way. The OOB error
can be monitored during fitting and the training can stop once the OOB error
has “stabilized”. In the Ames housing example (Figure 7.6) it can be seen that
the test and OOB errors both stabilize after around 50 trees.
While the OOB error is computationally cheap, Janitza and Hornung [2018]
observed that it tends to overestimate the true error in many practical situa-
tions, including
• when the class frequencies are reasonably balanced in classification set-
tings;
FIGURE 7.7: OOB and test error vs. mtry for the Friedman benchmark data
using N = 2000 with a 50/50 split. The dashed line indicates the standard
default for regression; in this case, mtry = 3.
    VI(x) = (1/B) Σ_{b=1}^{B} VI_{T_b}(x),        (7.3)

where VI_{T_b}(x) is the relative importance of x in tree T_b (Section 2.8). Since averaging helps to stabilize variance, VI(x) tends to be more reliable than VI_{T_b}(x) [Hastie et al., 2009, p. 368].
The split variable selection bias inherent in CART-like decision trees also
affects the impurity-based importance measure in their ensembles (7.3). The
bias tends to result in higher variable importance scores for predictors with
more potential split points (e.g., categorical variables with many categories).
Several authors have proposed methods for eliminating the bias when the Gini
index is used as the splitting criterion; see, for example, Sandri and Zuccolotto
[2008] and the references therein. An interesting (and rather simple) approach is to decompose the importance of predictor x_i in tree T_b as

    VI_{T_b}(x_i) = VI_{T_b}^{true}(x_i) + VI_{T_b}^{bias}(x_i),

where VI_{T_b}^{true}(x_i) is the part attributable to informative splits and is related to the "true" importance of x_i, and VI_{T_b}^{bias}(x_i) is the part attributable to uninformative splits and is a source of bias. The algorithm they propose attempts to eliminate the bias in VI_{T_b}(x_i) by subtracting off an estimate of VI_{T_b}^{bias}(x_i).
This is done many times and the results averaged together. The basic steps
are outlined in Algorithm 7.2 below.
1) For r = 1, 2, . . . , R:
3) Use Equation 7.3 to compute both VI (xi ) and VI (zi ); that is, com-
pute the usual impurity-based variable importance measure for each
predictor xi and pseudo predictor zi , for i = 1, 2, . . . , p.
Algorithm 7.2 can be used to correct biased variable importance scores from
a single CART-like tree or an ensemble thereof. Also, while the original al-
gorithm was developed for the Gini-based importance measure, Sandri and
Zuccolotto [2008] suggest it is also effective at eliminating bias for other im-
purity measures, like cross-entropy and SSE. One of the drawbacks of Algo-
rithm 7.2, however, is that it effectively doubles the number of predictors
to 2p and requires multiple (R) iterations. This can be computationally pro-
hibitive for large data sets, especially for tree-based ensembles. Fortunately,
Nembrini et al. [2018] proposed a similar technique specific to RFs that only
requires a single replication. I’ll omit the details, but the procedure is avail-
able in the ranger package for R (which has also been ported to Python and
is available in the skranger package [Flynn, 2021]); an example is given in
Figure 7.8.
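For reference, requesting the corrected measure in ranger is just a matter of setting the importance argument; a minimal sketch using the Ames training data from earlier ("impurity_corrected" is ranger's name for the Nembrini et al. procedure):

library(ranger)
rfo.imp <- ranger(Sale_Price ~ ., data = ames.trn,
                  importance = "impurity_corrected")
head(sort(importance(rfo.imp), decreasing = TRUE))  # corrected importance scores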
Even though our quick-and-dirty crforest() function in Section 7.2.3 used
bootstrap sampling, the actual CRF procedure described in Strobl et al.
[2007b], and implemented in R packages party and partykit, defaults to
growing trees on random subsamples of the training data without replace-
ment (by default, the size of each sample is given by ⌊0.632N⌋), as opposed
to bootstrapping. Strobl et al. [2007b] showed that this effectively removes
the bias in CRFs due to the presence of predictor variables that vary in their
scale of measurement or their number of categories.
RFs offer an additional (and unbiased) variable importance method; the ap-
proach is quite similar to the more general permutation approach discussed
in Section 6.1.1, but it’s based on permuting observations in the OOB data
instead. The idea is that if predictor x is important, then the OOB error will
go up when x is perturbed in the OOB data. In particular, we start by com-
puting the OOB error for each tree. Then, each predictor is randomly shuffled
in the OOB data, and the OOB errors are computed again. The difference in
the two errors is recorded for the OOB data, then averaged across all trees in
the forest.
As with the more general permutation-based importance measure, these scores
can be unreliable in certain situations; for example, when the predictor vari-
ables vary in their scale of measurement or their number of categories [Strobl
et al., 2007a], or when the predictors are highly correlated [Strobl et al., 2008b].
Additionally, the corrected Gini-based importance discussed in Nembrini et al.
[2018] has the advantage of being faster to compute and more memory effi-
cient.
Figure 7.8 shows the results from three different RF variable importance measures on the simulation example from Section 3.1 (the simulation comparing the split variable selection bias between CART and CTree). Here, we can see that the traditional Gini-based variable importance measure is biased towards the categorical variables, while the corrected Gini and permutation-based variable importance scores are relatively unbiased.
FIGURE 7.8: Average feature importance ranking for three RF-based variable
importance measures. Left: the traditional Gini-based measure. Middle: the
OOB-based permutation measure. Right: The corrected Gini-based measure.
The next two sections discuss more contemporary permutation schemes for
RFs that deserve some consideration.
So far in this book, we’ve mainly discussed tree-based methods for supervised
learning problems. However, not every problem is supervised. For example,
it is often of interest to understand how the data clusters—that is, whether
the rows of the data form any “interesting” groups. (This is an application of
unsupervised learning.) Many clustering methods rely on computing the pair-
wise distances between any two rows in the data, but the challenge becomes
choosing the right distance metric. Euclidean distance (i.e., the “ordinary”
straight-line, or “as the crow flies” distance between two points), for exam-
ple, is quite sensitive to the scale of the inputs. It’s also rather awkward to
compute the Euclidean distance between two rows of data when the features
are a mix of both numeric and categorical types. Fortunately, other distance
(or distance-like) measures are available which more naturally apply to mixed
data types.
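Gower's distance is one such measure; a minimal sketch using the cluster package, assuming a hypothetical data frame dat with a mix of numeric and factor columns:

library(cluster)
d <- daisy(dat, metric = "gower")    # pairwise dissimilarities in [0, 1]
hc <- hclust(d, method = "ward.D2")  # feed into any distance-based clustering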
Another useful output that can be obtained from an RF, provided it’s imple-
mented, are pairwise case proximities. RF proximities are distance-like mea-
sures of how similar any two observations are, and can be used for
• clustering in supervised and unsupervised (Section 7.6.3) settings;
Outliers (or anomalies) are generally defined as cases that are removed from
the main body of the data. In the context of an RF, Leo Breiman defined
outliers as cases whose proximities to all other cases in the data are generally
small. For classification, he proposed a simple measure of “outlyingness” based
on the RF proximity values. Define the average proximity from case m in class
j to the rest of the training data in class j as
    prox⋆(m) = Σ_{k ∈ class j} prox²(m, k),

where the sum is over all training instances belonging to class j. The outlyingness of case m in class j to all other cases in class j is defined as

    out(m, j) = N / prox⋆(m).
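A sketch of this measure in R, assuming an N × N proximity matrix prox and a vector of class labels y (both hypothetical objects):

outlyingness <- function(prox, y) {
  N <- nrow(prox)
  sapply(1:N, FUN = function(m) {
    same <- setdiff(which(y == y[m]), m)  # other cases in the same class
    N / sum(prox[m, same] ^ 2)            # N / prox*(m)
  })
}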
FIGURE 7.9: Proximity-based outlier scores for the Swiss banknote data. The
largest outlier score corresponds to observation 101, which was a counterfeit
banknote that was mislabeled as genuine.
Many decision tree algorithms can naturally handle missing values; CART
and CTree, for example, employ surrogate splits to handle missing values (Sec-
tion 2.7). Unfortunately, the idea does not carry over to RFs. I suppose that
makes sense: searching for surrogates would greatly increase the computation
time of the RF algorithm. Although some RF software can handle missing val-
ues without casewise deletion (e.g., the h2o package in both R and Python),
most often they have to be imputed or otherwise dealt with.
Breiman also developed a clever way to use RF proximities for imputing miss-
ing values. The idea is to first impute missing values with a simple method
(such as using the mean or median for numeric predictors and the most com-
mon value for categorical ones). Next, fit an initial RF to the N complete
observations and generate the N × N proximity matrix. For a numeric fea-
ture, the initial imputed values can be updated using a weighted mean over
the non-missing values where the weights are given by the proximities. For
categorical variables, the imputed values are updated using the most frequent
non-missing values where frequency is weighted by the proximities. Then just
iterate until some convergence criterion is met (typically 4–6 runs). In other
words, this method just imputes missing values using a weighted mean/mode
with more weight on non-missing cases.
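For a single numeric feature, one update of this scheme might look like the following sketch, where prox (the N × N proximity matrix) and x (the feature vector with missing entries) are hypothetical objects:

prox_impute <- function(x, prox) {
  miss <- which(is.na(x))  # rows needing imputation
  obs <- which(!is.na(x))  # rows with observed values
  for (i in miss) {
    w <- prox[i, obs]                 # proximities to the non-missing cases
    x[i] <- sum(w * x[obs]) / sum(w)  # proximity-weighted mean
  }
  x
}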
Breiman [2002] noted that the OOB estimate of error in RFs tends to be
overly optimistic when fit to training data that has been imputed. As with
proximity-based outlier detection, this approach to imputation does not scale
well (especially since it requires fitting multiple RFs and computing prox-
imities). Further, this imputation method is often not as accurate as more
contemporary techniques, like those implemented in the R package mice [van
Buuren and Groothuis-Oudshoorn, 2021].
Perhaps the biggest drawback to proximity-based imputation, like many other
imputation methods, is that it only generates a single completed data set. As
discussed in van Buuren [2018, Chap. 1], our level of confidence in a particular
imputed value can be expressed as the variation across a number of completed
data sets. In Section 7.9.3, I’ll use the CART-based multiple imputation pro-
cedure discussed in Section 2.7.1 and show how we can have confidence in the
interpretation of the RF output by incorporating the variability associated
with multiple imputation runs.
As it turns out, RFs can be used in unsupervised settings as well (i.e., when
there is no defined response variable). In this case, the goal is to cluster the
data, that is, see if the rows from the learning sample form any "interesting"
groups.
In an unsupervised RF, the idea is to formulate a two-class problem. The first
class corresponds to the original data, while the second class corresponds to
a synthetic data set generated from the original sample. There are two ways
to generate the synthetic data corresponding to the second class [Liaw and
Wiener, 2002]:
1) a bootstrap sample is generated from each predictor column of the original
data;
2) a random sample is generated uniformly from the range of each predictor
column of the original data.
These two data sets are then stacked on top of each other, and an ordinary
RF is used to build a binary classifier to try and distinguish between the real
and synthetic data. (A necessary drawback here is that the resulting data set
is twice as large as the original learning sample.) If the OOB misclassification
error rate in the new two-class problem is, say, ≥ 40%, then the columns look
too much like independent variables in the eyes of the RF; in other words, the
dependencies among the columns do not play a large role in discriminating
between the two classes. On the other hand, if the OOB misclassification rate
is lower, then the dependencies are playing an important role. If there is some discrimination between the two classes, then the resulting proximity matrix (computed on the original observations) can be used for clustering.
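A minimal sketch of method 1), assuming the original predictors are stored in a data frame X:

X.syn <- as.data.frame(lapply(X, FUN = function(column) {
  sample(column, size = length(column), replace = TRUE)  # bootstrap each column
}))
train <- rbind(X, X.syn)  # stack the original and synthetic rows
train$y <- factor(rep(c("original", "synthetic"), each = nrow(X)))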
I then fit an RF of 1000 trees using the newly created binary indicator y and
generated proximities for the original (i.e., first 200) observations. So how well
did the unsupervised RF cluster the data? Well, we could convert the prox-
imity matrix into a dissimilarity matrix and feed it into any distance-based
clustering algorithm. Another approach, which I’ll take here, is to visual-
ize the dissimilarities using multidimensional scaling (MDS). MDS is one of
many methods for displaying (transformed) multidimensional data in a lower-
dimensional space; for details, see Johnson and Wichern [2007, Sec. 12.6].
Essentially, MDS takes a set of dissimilarities—one minus the proximities, in
this case—and returns a set of points such that the distances between the
points are approximately equal to the dissimilarities. Figure 7.10 shows the
best-fitting two-dimensional representation. Here you can see a clear separa-
tion between the genuine bills (black) and counterfeit bills (yellow).
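The MDS step itself is a one-liner, assuming prox holds the proximity matrix for the original observations:

mds <- cmdscale(as.dist(1 - prox), k = 2)  # dissimilarities = 1 - proximity
plot(mds, xlab = "Scaling coordinate 1", ylab = "Scaling coordinate 2")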
FIGURE 7.10: Best-fitting two-dimensional MDS representation of the unsupervised RF dissimilarities for the Swiss banknote data; the genuine bills (black) separate clearly from the counterfeit bills (yellow).
proximity (i.e., are more similar) to x0 . Note that most open source RF soft-
ware provide the option to specify case weights for the training observations,
which are used to weight each row when taking bootstrap samples (but many
implementations do not provide proximities). While the idea of case-specific
RFs makes sense, it has a couple of limitations. First off, it requires fitting
Ntst + 1 RFs, which can be expensive whenever N or Ntst are large. Second,
it requires computing N × Ntst proximities from an RF, which aren’t always
available from software.
Case-specific RFs are relatively straightforward to implement with tradi-
tional RF software, provided you can compute proximity scores.^f The R pack-
age ranger provides an implementation of case-specific RFs. I applied this
methodology to the Ames housing example (Section 1.4.7), which actually re-
sulted in a slight increase to the test RMSE when compared to a traditional
RF; the code to reproduce the example is available on the companion website
for this book.
Using a similar technique to OOB error estimation, Wager et al. [2014] pro-
posed a method for estimating the variance of an RF prediction using a tech-
nique called the jackknife. The jackknife procedure is very similar to LOOCV,
but specifically used for estimating the variance of a statistic of interest. If
we have a statistic, θ̂, estimated from N training records, then the jackknife
estimate of the variance of θ̂ is given by:
    V̂_jack(θ̂) = ((N − 1)/N) Σ_{i=1}^{N} (θ̂_(i) − θ̂_(·))²,        (7.4)

where θ̂_(i) is the statistic of interest computed from all N training observations except observation i, and θ̂_(·) = Σ_{i=1}^{N} θ̂_(i) / N.
For brevity, let f̂(x) = f̂_B^rf(x), for some arbitrary observation x (see Algorithm 7.1). A natural jackknife variance estimate for the RF prediction f̂(x) is given by
f Even if you don’t have access to an implementation of RFs that can compute proximi-
ties, they’re still obtainable as long as you can compute terminal node assignments for new
observations (i.e., compute which terminal node a particular observation falls in for each of
the B trees), which is readily available in most RF software. See Section 7.6 for details.
    V̂_jack(f̂(x)) = ((N − 1)/N) Σ_{i=1}^{N} (f̂_(i)(x) − f̂(x))².        (7.5)

This is derived under the assumption that B = ∞ trees were averaged together in the forest, which, of course, is never the case. Consequently, (7.5) has a positive bias. Fortunately, the same B bootstrap samples used to derive the forest can also be used to provide the bias-corrected variance estimate

    V̂_jack^BC(f̂(x)) = V̂_jack(f̂(x)) − (e − 1)(N/B) v̂(x),        (7.6)

where

    v̂(x) = (1/B) Σ_{b=1}^{B} (f̂_b(x) − f̂(x))².        (7.7)
Switching back to the email spam data, let’s compute jackknife-based standard
errors for the test set predicted class probabilities. Following Wager et al.
[2014], I fit an RF using B = 20,000 trees and three different values for mtry:
5, 19 (based on Breiman’s default for classification), and 57 (an ordinary
bagged tree ensemble).
The predicted class probabilities for type = "spam", based on the test data,
from each RF are displayed in Figure 7.11 (x-axis), along with their bias-
corrected jackknife estimated standard errors (y-axis). Notice how the mis-
classified cases (solid black points) tend to correspond to observations where
the predicted class probability is closer to 0.5. It also appears that the more
constrained RF with mtry = 5 produced smaller standard errors, while the
default RF (mtry = 19) and bagged tree ensemble (mtry = 57) produced no-
ticeably larger standard errors, with the bagged tree ensemble performing the
worst.
FIGURE 7.11: Predicted probabilities for type = "spam" on the test data (x-axis) and their bias-corrected jackknife standard errors (y-axis) for each value of mtry; misclassified cases are shown as solid black points.
predictive accuracy, even if they do make the tree harder to interpret. Many
decision tree algorithms support linear splits (e.g., CART and GUIDE) and
Breiman [2001] even proposed a variant of RFs that employed linear splits
based on random linear combinations of the predictors. This approach did not
gain the same traction as the traditional RF algorithm based on univariate
splits. In fact, I’m not aware of any open source RF implementations that
support his original approach based on random coefficients.
Menze et al. [2011] proposed a variant, called oblique random forests (ORFs),
that explicitly learned optimal split directions at internal nodes using linear
discriminative models^g, as opposed to random linear combinations. Similar
to random rotation ensembles (Section 7.8.3), ORFs tend to have a smoother
topology; see, for example, Figure 7.16. Menze et al. [2011] even go as far as to
recommend the use of ORFs over the traditional RF when applied to mostly
numeric features. Nonetheless, the idea of non-axis oriented splits in RFs has
still not caught on. The only open source implementation of ORFs that I’m
aware of is in the R package obliqueRF [Menze and Splitthoff, 2012], which
has not been updated since 2012.
A more recent approach, called projection pursuit random forest (PPforest)
[da Silva et al., 2021a] uses splits based on linear combinations of randomly
chosen inputs. Each linear combination is found by optimizing a projection
pursuit index [Friedman and Tukey, 1974] to get a projection of the features
that best separates the classes; hence, this method is also only suitable for
classification. PPforests are implemented in the R package PPforest [da Silva
et al., 2021b]. Individual projection pursuit trees (PPtrees) [Lee et al., 2013],
which are used as the base learners in a PPforest, can be fit using the R
package PPtreeViz [Lee, 2019] (which seems to have superseded the older
PPtree package).
restricts itself to linear splits in only two features at a time to help with interpretation and
reduce the impact of missing values.
    F̂(y | x) = Σ_{i=1}^{N} w_i(x) I(Y_i ≤ y),
show the corresponding 95% prediction bounds. (Note that the prediction
intervals here are pointwise prediction intervals.)
The most expensive house in the test set sold for a (rescaled) sale price of
$610. A traditional RF estimated a conditional mean sale price of $472.59,
whereas the QRF produced a conditional median sale price of $479.07 with a
0.025 quantile of $255.24 and a 0.975 quantile of $745. Here, the QRF gives
a much better sense of the variability in the predicted outcome, as well as a
sense of the skewness of its distribution.
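A QRF along these lines can be fit with ranger's quantile regression support; a sketch using the Ames train/test split from earlier (see ?ranger and ?predict.ranger for the quantreg and type = "quantiles" arguments):

library(ranger)
qrf <- ranger(Sale_Price ~ ., data = ames.trn, quantreg = TRUE)
qs <- predict(qrf, data = ames.tst, type = "quantiles",
              quantiles = c(0.025, 0.5, 0.975))$predictions
head(qs)  # pointwise 95% prediction bounds and the predicted median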
FIGURE 7.12: Rescaled sale prices for each home in the test set, along with
the predicted 0.025, 0.5, and 0.975 quantiles from a QRF. To enhance visu-
alization, the observations were ordered according to the length of the corre-
sponding prediction intervals, and the mean of the upper and lower end of the
prediction interval is subtracted from all observations and prediction intervals.
Before talking about rotation forests and random rotation ensembles, it might
help to briefly discuss rotation matrices. A rotation matrix R of dimension p
is a p × p square transformation matrix that’s used to perform a rotation in
p-dimensional Euclidean space.
A common application of rotation matrices in statistics is principal component
analysis (PCA). The details of PCA are beyond the scope of this book, but
the interested reader is pointed to Johnson and Wichern [2007, Chap. 8],
among others. While PCA has many use cases, it is really just an unsupervised
where X_{1i} ∼ U(0, 1) and ε_i ∼ N(0, 1) (both iid). Further, let X be the 100 × 2 matrix whose first and second columns are given by X_{1i} and X_{2i}, respectively. As
a rotation in two dimensions, PCA finds the rotation of the axes that yields
maximum variance. The rotated axes for this example are shown in Figure 7.13
(middle). Notice that the first (i.e., yellow) axis is aligned with the direction
of maximum variance in the sample. An alternative would be to rotate the
data points themselves (right side of Figure 7.13). In this case, the variable
loadings from PCA form a 2×2 rotation matrix, R, that can be used to rotate
X so that the direction of maximal variance aligns with the first (i.e., yellow)
axis; this is shown in the right side of Figure 7.13. The rotated matrix is given
by X′ = XR. Notice how the relative position of the points between x1 and
x2 is preserved, albeit rotated about the axes.
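A quick demonstration of the rotation, assuming X is the 100 × 2 matrix described above:

pca <- prcomp(X)   # principal component analysis
R <- pca$rotation  # 2 x 2 rotation (loading) matrix
X.rot <- scale(X, center = TRUE, scale = FALSE) %*% R  # rotated data: X' = XR
apply(X.rot, MARGIN = 2, FUN = var)  # maximal variance is now along the first axis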
FIGURE 7.13: Data generated from a simple linear regression model. Left:
Original data points and axes. Middle: Original data with rotated axes (notice
how the first/yellow axis aligns with the direction of maximum variance in the
sample). Right: Rotated data points on the original axes (here the data are
rotated so that the direction of maximal variance aligns with the first/yellow
axis).
So what does any of this have to do with RFs? Recall that the key to accuracy
with model averaging is diversity. In an RF, diversity is achieved by choosing
a random subset of predictors prior to each split in every tree. A rotation
forest [Rodríguez et al., 2006], on the other hand, introduces diversity to a
bagged tree ensemble by using PCA to construct a rotated feature space prior
to the construction of each tree. Rotating the feature space allows adaptive
nonparametric learning algorithms, like decision trees, to learn potentially
interesting patterns in the data that might have gone unnoticed in the original
feature space. Applying PCA to all the predictors prior to the construction of each tree, even when using sampling with replacement, won't be enough to
diversify the ensemble. Instead, prior to the construction of each tree, the
predictor set is randomly split into K subsets, PCA is run separately on each,
and a new set of linearly extracted features is constructed by pooling all the
principal components (i.e., the rotated data points). K is treated as a tuning
parameter, but the value of K that results in roughly three features per subset
seems to be the suggested default [Kuncheva and Rodríguez, 2007]. In other words, a rotation forest is a bagged tree ensemble with a random feature transformation applied to the predictors prior to constructing each tree; in this sense, PCA acts as a feature extraction method applied to random subsets of the features. Other feature extraction methods have also been considered, but PCA was found to be the most suitable [Kuncheva and Rodríguez, 2007].
Rotation forests have been shown to be competitive with RFs and can achieve
better performance on data sets with mostly quantitative variables; although,
this seems to be true mostly for smaller ensemble sizes [Rodríguez et al., 2006,
Kuncheva and Rodríguez, 2007]. However, most comparative studies I’ve seen
seem to focus on classification accuracy, which we know is not the most appropriate metric for comparing models in classification settings.
Rotation forests are available in the R package rotationForest [Ballings and
Van den Poel, 2017].
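The core rotation step is simple enough to sketch without the package; the helper below (rotation_matrix() is an illustrative name, not part of rotationForest) splits the predictors into K random subsets, runs PCA on each, and pools the loadings into a block-diagonal rotation matrix:

rotation_matrix <- function(X, K = 3) {
  p <- ncol(X)
  groups <- split(sample(1:p), rep_len(1:K, p))  # K random subsets of columns
  R <- matrix(0, nrow = p, ncol = p)
  for (g in groups) {
    # PCA loadings for this subset form one block of the rotation matrix
    R[g, g] <- prcomp(X[, g, drop = FALSE], scale. = TRUE)$rotation
  }
  R  # rotate the features via as.matrix(X) %*% R prior to fitting each tree
}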
In this section, I’ll use the Gaussian mixture data from Hastie et al. [2009]
to compare the results of an RF, rotation forest, and random rotation forest.
The data for each class come from a mixture of ten normal distributions,
FIGURE 7.14: Scatterplot of x1 vs. x2 . The original data are shown in black
and a dashed black line gives the direction of maximum variance. The rotated
points under PCA are shown in dark yellow along with the new axis of maximal
variance; notice how in two dimensions this shifts the points to have maximal
variance along the x-axis. The rest of the colors display the data points under
random rotations.
Note that this is not a data frame, but rather a list with several components;
for a description of each, see ?treemisc::load_eslmix.
The code chunk below constructs a scatterplot of the training data (i.e., com-
ponent x) along with the Bayes decision boundary^h; see Figure 7.15. The
Bayes error rate for these data—that is, the theoretically optimal error rate—
is 0.210.
x <- as.data.frame(eslmix$x) # training data
xnew <- as.data.frame(eslmix$xnew) # evenly spaced grid of points
x$y <- as.factor(eslmix$y) # coerce to factor for plotting
xnew$prob <- eslmix$prob # Pr(Y = 1 | xnew)
# Colorblind-friendly palette
oi.cols <- unname(palette.colors(8, palette = "Okabe-Ito"))
h Since the true class-conditional densities are known for each of the two classes, the theoretically optimal decision boundary can be computed exactly. This makes it useful to compare classifiers visually in terms of their estimated decision boundaries.
FIGURE 7.15: Simulated mixture data with optimal (i.e., Bayes) decision
boundary.
implementation only uses regression trees for the base learners, hence, only re-
gression and binary classification are supported. For the latter, the probability
machine approach discussed in Section 7.2.1 is implemented.
The resulting decision boundaries from each forest are in Figure 7.16. Here,
you can see that the axis-oriented nature of the individual trees in a traditional
RF leads to a decision boundary with an axis-oriented flavor (i.e., the decision
boundary is rather “boxy”). The RF also exhibits more signs of overfitting,
as suggested by the little islands of decision boundaries. On the other hand,
using feature rotation (with PCA or random rotations) prior to building each
tree results in a noticeably smoother and non-axis-oriented decision boundary.
The test error rates for the RF, rotation forest, and random rotation forest,
under this random seed, are 0.235, 0.239, and 0.226, respectively. (As always,
the code to reproduce this example is available on the companion website for
this book.)
FIGURE 7.16: Traditional RF vs. random rotation forest on the mixture data
from Hastie et al. [2009]. The random rotation forest produces a noticeably
smoother decision boundary than the axis-oriented decision boundary from a
traditional RF.
decrease variance, but sometimes at the cost of additional bias [Geurts et al.,
2006]; this is especially true if the data contain irrelevant features. To combat
the extra bias, extra-trees utilize the full learning sample to grow each tree,
rather than bootstrap replicas (another subtle difference from the bagging
and RF algorithms). Note that bootstrap sampling can be used in extra-trees
ensembles, but Geurts et al. argue that it can often lead to significant drops in accuracy.
The primary tuning parameters for an extra-trees ensemble are K and nmin ,
where K is the number of random splits to consider for each candidate splitter,
and nmin is the minimum node size, which is a common parameter in many
tree-based models and can act as a smoothing parameter. A common default
for K is
    K = √p for classification, and K = p for regression,
largest reduction in node impurity). Note that nmin has the same defaults as
it does in an RF (Section 7.2).
The extra-trees ensemble still makes use of the RF mtry parameter, but note
that, in the extreme case where K = 1, an extra-trees tree is unsupervised
in that the response variable is not needed in determining any of the splits.
Such a totally randomized tree [Geurts et al., 2006] can be useful in detecting
potential outliers and anomalies, as will be discussed in Section 7.8.5. Extra-
trees can be fit in R via the ranger package. In Python, an implementation of
the extra-trees algorithm is provided by the sklearn.ensemble module.
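In ranger, for example, an extra-trees ensemble is requested through the splitrule argument; a sketch using the Ames training data from earlier:

library(ranger)
xtrees <- ranger(Sale_Price ~ ., data = ames.trn, splitrule = "extratrees",
                 num.random.splits = 1,  # the K parameter discussed above
                 replace = FALSE, sample.fraction = 1)  # full learning sample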
    s(x, N) = 2^{−[(1/B) Σ_{b=1}^{B} h_b(x)] / c(N)},        (7.9)

where N is the sample size, h_b(x) is the path length to x in the b-th tree, and c(N) is the average path length of unsuccessful searches in a binary tree constructed from N observations, given by

    c(N) = 2H(N − 1) − 2(N − 1)/N,

where H(i) denotes the i-th harmonic number.
An extended isolation forest [Hariri et al., 2021] improves the consistency and
reliability of the anomaly score produced by a standard isolation forest by
using random oblique splits (in this case, hyperplanes with random slopes)—
as opposed to axis-oriented splits—which often results in improved anomaly
scores. The tree in Figure 7.17 is from an extended isolation forest fit using
the eif package [Hariri et al., 2021], which is also available in R.
To illustrate the basic use of an isolation forest, I'll use a data set from Kaggle^j containing anonymized credit card transactions, labeled as fraudulent or
genuine, obtained over a 48 hour period in September of 2013; the data can
be downloaded from Kaggle at
https://www.kaggle.com/mlg-ulb/creditcardfraud.
Recognizing fraudulent credit card transactions is an important task for credit
card companies to ensure that customers are not charged for items that they
did not purchase. Since fraudulent transactions are relatively rare (as one
j Kaggle is an online community of data scientists and machine learning practitioners
who can find and publish data sets and enter competitions to solve data science challenges;
for more, visit https://www.kaggle.com/.
would hope), the data are highly imbalanced, with 492 frauds (0.17%) out of
the N = 284,807 transactions.
For reasons of confidentiality, the original features have been transformed
using PCA, resulting in 28 numeric features labeled V1, V2, ..., V28. Two addi-
tional variables, Time (the number of seconds that have elapsed between each
transaction and the first transaction in the data set) and Amount (the trans-
action amount), are also available. These are labeled data, with the binary
outcome Class taking on values of 0 or 1, where a 1 represents a fraudulent
transaction. While this can certainly be framed as a supervised learning prob-
lem, I’ll only use the class label to measure the performance of our isolation
forest-based anomaly detection, which will be unsupervised. I would argue
that it is probably more often that you will be dealing with unlabeled data of
this nature, as it is rather challenging to accurately label each transaction in
a large database.
To start, I’ll split the data into train/test samples using only N = 10,000 ob-
servations (3.51%) for training; the remaining 274,807 observations (96.49%)
will be used as a test set. However, before doing so, I’m going to shuffle the
rows just to make sure they are in random order first. (Assume I’ve already
read the data into a data frame called ccfraud.)
# ccfraud <- data.table::fread("some/path/to/ccfraud.csv")
#>
#> 0 1
#> 0.9982 0.0018
#>
#> 0 1
#> 0.99828 0.00172
Next, I’ll use the isotree package [Cortes, 2022] to fit a default isolation forest
to the training set and provide anomaly scores for the test set. (Notice how
I exclude the true class labels (column 31) when constructing the isolation
forest!)
library(isotree)
#> 1 2 3 4 5 6
#> 0.320 0.341 0.324 0.325 0.340 0.325
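A minimal isotree workflow along these lines might look like the following sketch; ccfraud.trn is an assumed name for the training split, and the argument names are worth checking against the isotree documentation:

iso <- isolation.forest(subset(ccfraud.trn, select = -Class))           # unsupervised fit
scores <- predict(iso, newdata = subset(ccfraud.tst, select = -Class))  # anomaly scores
head(round(scores, 3))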
k Precision (or positive predictive value) is directly proportional to the prevalence of the positive outcome. PR curves are not appropriate for case-control studies (which includes case-control sampling, like down sampling, with imbalanced data sets) and should only be used when the true class priors are reflected in the data.
l The cumulative gains (or lift) chart shows the fraction of the overall number of positive cases (here, frauds) identified as a function of the proportion of the sample inspected, after ordering by the predicted (anomaly) score.
head(cbind(recall, precision))
#> recall precision
#> [1,] 1 0.00172
#> [2,] 1 0.00172
#> [3,] 1 0.00172
#> [4,] 1 0.00172
#> [5,] 1 0.00173
#> [6,] 1 0.00173
# Compute data for lift chart
ord <- order(scores, decreasing = TRUE)
y <- ccfraud.tst$Class[ord] # order according to sorted scores
prop <- seq_along(y) / length(y)
lift <- cumsum(y) / sum(ccfraud.tst$Class) # convert to proportion
head(cbind(prop, lift))
#> prop lift
#> [1,] 3.64e-06 0.00000
#> [2,] 7.28e-06 0.00000
#> [3,] 1.09e-05 0.00000
#> [4,] 1.46e-05 0.00000
#> [5,] 1.82e-05 0.00000
#> [6,] 2.18e-05 0.00211
FIGURE 7.18: Precision vs. recall (left) and proportion of anomalies identified vs. proportion of the sample inspected (right) for the isolation forest scores on the credit card fraud test data.
We can take the analysis a step further by using Shapley values (Section 6.3.1)
to help explain the observations with the highest/lowest anomaly scores,
whichever is of more interest. To illustrate, let’s estimate the feature con-
tributions for the test observation with the highest anomaly score. Keep in
mind that the features in this data set have been anonymized using PCA, so
we won’t be able to understand much of the output from a contextual perspec-
tive, but the idea applies to any application of anomaly detection based on a
model that produces anomaly scores, like isolation forests. I’m just treating
the scores as ordinary predictions and applying Shapley values in the usual
way.
In the code chunk below, I find the observation in the test data that corre-
sponds to the highest anomaly score. Here, we see that max.x corresponds to
an actual instance of fraud (Class = 1) and was assigned an anomaly score
of 0.843. The average anomaly score on the training data is 0.336, for a differ-
ence of 0.507. The question we want to try and answer is: how did each feature
contribute to this difference of 0.507? This is precisely the type of question that
Shapley values can help with.
max.id <- which.max(scores) # row ID for max anomaly score
(max.x <- ccfraud.tst[max.id, ])
#> Time V1 V2 V3 V4 V5 V6 V7 V8
#> 1: 166198 -35.5 -31.9 -48.3 15.3 -114 73.3 121 -27.3
#> V9 V10 V11 V12 V13 V14 V15 V16 V17
#> 1: -3.87 -12 6.85 -9.19 7.13 -6.8 8.88 17.3 -7.17
#> V18 V19 V20 V21 V22 V23 V24 V25 V26
#> 1: -1.97 5.5 -54.5 -21.6 5.71 -1.58 4.58 4.55 3.42
#> V27 V28 Amount Class
#> 1: 31.6 -15.4 25691 0
max(scores)
#> [1] 0.843
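One way to estimate such contributions (not necessarily how Figure 7.19 was produced) is with the fastshap package, treating the anomaly score as the prediction to be explained; ccfraud.trn is again an assumed object name, and the argument names should be checked against ?fastshap::explain:

library(fastshap)
pfun <- function(object, newdata) {  # prediction wrapper: returns anomaly scores
  predict(object, newdata = newdata)
}
ex <- explain(iso, X = subset(ccfraud.trn, select = -Class), nsim = 100,
              pred_wrapper = pfun,
              newdata = subset(max.x, select = -Class))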
FIGURE 7.19: Estimated feature contributions for the test observation with the highest anomaly score. There's a dashed vertical line at zero to differentiate between features with a positive/negative contribution. In this case, all feature values contributed positively to the difference f̂(x⋆) − E[f̂(x)] = 0.501.
RFs are available in numerous software packages, both open source and proprietary. The
R packages randomForest, ranger, and randomForestSRC [Ishwaran and
Kogalur, 2022] implement the traditional RF algorithm for classification and
regression; the latter two also support survival analysis, as well as several
other extensions. It’s important to point out that ranger’s implementation
of the RF algorithm treats categorical variables as ordered by default; for
No! These data are easy and an ensemble would be overkill here. Remember,
the original goal of the problem was to come up with an accurate but simple
rule for determining the edibility of a mushroom. This was easily accomplished
using a single decision tree (e.g., CART with some manual pruning) or a rule-
based model like CORELS; see, for example, Figure 2.22.
In Section 5.5, I showed how the LASSO can be used to effectively post-
process a tree-based ensemble by essentially zeroing out the predictions from
some of the trees and reweighting the rest. The idea is that we can often
reduce the number of trees quite substantially without sacrificing much in
the way of performance. A smaller number of trees means we could, at least
in theory, compute predictions faster, which has important implications for
model deployment (e.g., when trying to score large data sets on a regular
basis). However, unless we have a way to remove the zeroed out trees from
the fitted RF object, we can’t really reap all the benefits. This is the purpose of
the new deforest() function in the ranger package, which I'll demonstrate
in this section using the Ames housing example.
Keep in mind that this method of post-processing is not specific to bagged
tree ensembles and RFs, and can be fruitfully applied to other types of en-
sembles as well; see Section 8.9.3 for an example using a gradient boosted tree
ensemble.
To start, I’ll load a few packages, prep the data, and create a helper function
for computing the RMSE as a function of the number of trees in a ranger-
based RF:
library(ranger)
library(treemisc) # for isle_post() function
# Load the Ames housing data and split into train/test sets
ames <- as.data.frame(AmesHousing::make_ames())
ames$Sale_Price <- ames$Sale_Price / 1000 # rescale response
set.seed(2101) # for reproducibility
trn.id <- sample.int(nrow(ames), size = floor(0.7 * nrow(ames)))
ames.trn <- ames[trn.id, ] # training data/learning sample
ames.tst <- ames[-trn.id, ] # test data
xtst <- subset(ames.tst, select = -Sale_Price) # test features only
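Under assumed settings (RFO as a default ranger fit and RFO.4.5 as a larger ensemble of shallow trees grown on small subsamples), the two fits might look something like the sketch below; these are illustrative values, not necessarily the ones used for Figure 7.20:

rfo <- ranger(Sale_Price ~ ., data = ames.trn)      # default RF
rfo.4.5 <- ranger(Sale_Price ~ ., data = ames.trn,  # assumed settings
                  num.trees = 1000, max.depth = 4, sample.fraction = 0.05)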
The test RMSE for the RFO model is comparable to the test RMSE from the
conditional RF fit in Section 7.2.3. In comparison, the RFO.4.5 model has a
much larger test RMSE, which we might have expected given the shallowness
of each tree and the tiny fraction of the learning sample each was built from.
Consequently, the RFO.4.5 model finished training in only a fraction of the
time it took the RFO model. As we’ll see shortly, post-processing will help
improve the performance of RFO.4.5 so that it is comparable to RFO in terms
of performance, while substantially reducing the number of trees (i.e., compa-
rable performance, faster training time, and fewer trees in the end).
Next, I’ll obtain the individual tree predictions from each forest and post-
process them using the LASSO via treemisc’s isle_post() function. Note
that k-fold cross-validation can be used here instead of (or in conjunction
with) a test set; see ?treemisc::isle_post for details. For brevity, I’ll use a
simple prediction wrapper, called treepreds(), to compute and extract the
individual tree predictions from each RF model:
treepreds <- function(object, newdata) {
p <- predict(object, data = newdata, predict.all = TRUE)
p$predictions # return predictions component
}
The results are plotted in Figure 7.20. Here, we can see that both mod-
els benefited from post-processing, but the RFO model only experienced a
marginal increase in performance compared to RFO.4.5. Is the slightly better
performance in the default RFO model enough to justify its larger training
time? Maybe in this particular example, but for larger data sets, the differ-
ence in training time can be huge, making it extremely worthwhile. For the
post-processed RFO.4.5 model, the test RMSE is minimized using only 93
(reweighted) trees.
palette("Okabe-Ito")
plot(rmse.rfo, type = "l", ylim = c(20, 50),
las = 1, xlab = "Number of trees", ylab = "Test RMSE")
lines(rmse.rfo.4.5, col = 2)
lines(sqrt(rfo.post$results$mse), col = 1, lty = 2)
lines(sqrt(rfo.4.5.post$results$mse), col = 2, lty = 2)
legend("topright", col = c(1, 2, 1, 2), lty = c(1, 1, 2, 2),
legend = c("RFO", "RFO.4.5","RFO (post)", "RFO.4.5 (post)"),
inset = 0.01, bty = "n")
palette("default")
To make this useful in practice, we need a way to remove trees from a fit-
ted RF (i.e., to “deforest” the forest of trees). This could vastly speed up
prediction time and reduce the memory footprint of the final model. Fortu-
nately, the ranger package includes such a function; see ?ranger::deforest
for details.
In the code snippet below, I “deforest” the RFO.4.5 ensemble by removing
trees corresponding to the zeroed-out LASSO coefficients, which requires estimating the optimal value for the penalty parameter λ (it might be helpful to read the help page for ?glmnet::coef.glmnet):
FIGURE 7.20: Test RMSE for the RFO and RFO.4.5 fits. The dashed lines
correspond to the post-processed versions of each model. Note how the RFO
model only experienced a marginal increase in performance compared to the
RFO.4.5 model.
res <- rfo.4.5.post$results # post-processing results on test set
lambda <- res[which.min(res$mse), "lambda"] # optimal penalty parameter
coefs <- coef(rfo.4.5.post$lasso.fit, s = lambda)[, 1L]
int <- coefs[1L] # intercept
tree.coefs <- coefs[-1L] # no intercept
trees <- which(tree.coefs == 0) # trees to remove
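The removal itself is then a single call to deforest(); the which.trees argument name below is an assumption on my part, so check ?ranger::deforest for the exact interface:

rfo.4.5.def <- deforest(rfo.4.5, which.trees = trees)  # assumed argument name
format(object.size(rfo.4.5.def), units = "MB")  # compare to the original fit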
Notice the impact this had on reducing the overall size of the fitted model.
This can often lead to a much more compact model that’s easier to save and
load when memory requirements are a concern.
We can’t just use the “deforested” tree ensemble directly; remember, the
estimated LASSO coefficients imply a reweighting of the remaining trees!
To obtain the reweighted predictions from the “deforested” model, we need
to do a bit more work. Here, I’ll create a new prediction function, called
predict.def(), that will compute the reweighted predictions from the re-
maining trees using the estimated LASSO coefficients—similar to how predic-
tions in a linear model are computed.
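A sketch of what such a function might look like, assuming the "deforested" fit contains only the retained trees and that int and the nonzero entries of tree.coefs from above are supplied:

predict.def <- function(object, newdata, intercept, coefs) {
  # Per-tree predictions from the retained trees (columns align with coefs)
  p <- predict(object, data = newdata, predict.all = TRUE)$predictions
  drop(intercept + p %*% coefs)  # reweight and combine, like a linear model
}
# Hypothetical usage:
# predict.def(rfo.4.5.def, ames.big, intercept = int,
#             coefs = tree.coefs[tree.coefs != 0])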
To test it out, I’ll stack the learning sample (ames.trn) on top of itself 100
times, resulting in N = 205,100 observations for scoring. Below, I compare the
prediction times for both the original (i.e., non-processed) and “deforested”
RFO.4.5 fits:
ames.big <- # stack data on top of itself 100 times
do.call("rbind", args = replicate(100, ames.trn, simplify = FALSE))
The final model contains only 93 trees and achieved a test RMSE of 26.59,
while also being orders of magnitude faster to initially train. The computa-
tional advantages are easier to appreciate on even larger data sets.
In summary, I used the LASSO to post-process and “deforest” a large ensemble
of shallow trees (which trained relatively fast), producing a much smaller
ensemble with fewer trees that scores faster compared to the default RFO.
While the default RFO model had a slightly smaller test RMSE of 24.72
compared to the "deforested" RFO.4.5 test RMSE of 26.59, the difference is
arguably negligible (especially when you take the differences in both training
and scoring time into account).
Note that roughly 20.09% of the values for age, the age in years of the pas-
senger, are missing:
sapply(t3, FUN = function(x) mean(is.na(x)))
#> survived pclass age sex sibsp parch
#> 0.000 0.000 0.201 0.000 0.000 0.000
Following Harrell [2015, Sec. 12.4], I use a decision tree to investigate which
kinds of passengers tend to have a missing value for age. In the example be-
low, I use the partykit package to apply the CTree algorithm (Chapter 3)
using a missing value indicator for age as the response. From the tree out-
put we can see that third-class passengers had the highest rate of missing
age values (29.3%), followed by first-class male passengers with no siblings or
spouses aboard (22.8%). This makes sense, since males and third-class pas-
sengers supposedly had the least likelihood of survival (“women and children
first").
o A description of the original source of these data is provided in Harrell [2015, p. 291].
library(partykit)
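A sketch of the kind of model being described, with an assumed formula (the exact call may differ):

# Model a missing-value indicator for age as a function of the other features
mia <- ctree(factor(is.na(age)) ~ pclass + sex + sibsp + parch, data = t3)
plot(mia)  # terminal nodes show the proportion of missing age values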
Next, I’ll use the CART-based multiple imputation procedure outlined in Sec-
tion 2.7.1 to perform m = 21 separate imputations for each missing age value.
Why did I choose m = 21? White et al. [2011] propose setting m ≥ 100f , where
f is the fraction of incomplete casesp . Since age is the only missing variable,
with f = 0.201, I chose m = 21. Using multiple different imputations will give
us an idea of the sensitivity of the results of our (yet to be fit) RF.
library(mice)
p When f ≥ 0.03, Harrell [2015, p. 57] suggests setting m = max (5, 100f )
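The imputation call itself might look something like the sketch below (the seed and printFlag settings are assumptions):

imp <- mice(t3, method = "cart", m = 21, printFlag = FALSE,
            seed = 1125)  # returns a multiply imputed data set ("mids" object)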
FIGURE 7.21: Nonparametric density estimate of age for the complete cases
(blue line) and 15 imputed data sets.
Nonparametric densities for passenger age are given in Figure 7.21. There is
one density for each of the imputed data sets (red curves) and one density for
the original complete cases (blue curve). The overall distributions are compa-
rable, but there is certainly some variance across the m = 21 imputation runs.
Our goal is to run an RF analysis for each of the m = 21 completed data sets
and inspect the variability of the results. For instance, I might graphically
show the m = 21 variable importance scores for each feature, along with the
mean or median.
Next, I call complete() (from package mice) to produce a list of the m = 21
completed data sets which I can use to carry on with the analysis. The only
difference is that I'll perform the same analysis on each of the completed data sets.^q
t3.mice <- complete(
data = imp, # "mids" object (multiply imputed data set)
action = "all", # return list of all imputed data sets
include = FALSE # don't include original data (i.e., data with NAs)
)
length(t3.mice) # returns a list of completed data sets
#> [1] 21
q This approach is probably not ideal in situations where the analysis is expensive (e.g.,
because the data are “big” and the model is expensive to tune). In such cases, you may
have to settle for a smaller, less optimal value for m.
For comparison, let’s look at the results from using the proximity-based RF
imputation procedure discussed in Section 7.6.2. The code snippet below uses
rfImpute() from package randomForest to handle the proximity-based im-
putation. The results are plotted along with those from MICE in Figure 7.22.
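A sketch of the proximity-based step using randomForest's rfImpute() (the number of iterations and trees are assumptions):

library(randomForest)
set.seed(1125)  # for reproducibility
t3.rf <- rfImpute(as.factor(survived) ~ ., data = t3, iter = 5, ntree = 500)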
# Plot results
palette("Okabe-Ito")
plot(x[, 1], type = "n", xlim = c(1, length(na.id)), ylim = c(0, 100),
las = 1, ylab = "Imputed value")
for (i in 1:m) {
lines(x[, i], col = adjustcolor(1, alpha.f = 0.1))
}
lines(rowMeans(x[, 1:m]), col = 1, lwd = 2)
lines(x[, m + 1], lwd = 2, col = 2)
legend("topright", legend = c("MICE: CART", "RF: proximity"), lty = 1,
col = 1:2, bty = "n")
palette("default")
Here, you can see that the imputed values from both procedures are similar,
but that multiple imputations provide a range of plausible values. Also, there
are a few instances where there’s a bit of a gap between the imputed values for
the two procedures. For example, consider observations 956 and 959, whose
records are printed below. The first passenger is recorded to be a third-class
female with three siblings (or spouses) and one parent (or child) aboard. This
individual is likely a child. The proximity-based imputation imputed the age for this passenger as 17.522 years, whereas MICE gives an average of 4.571 years and a plausible range of 0.75–8.00 years.
FIGURE 7.22: Imputed values for the 263 missing age values. The yellow
line corresponds to the proximity-based imputation. The light gray lines cor-
respond to the 15 different imputation runs using MICE, and the black line
corresponds to their average.
Similarly, the proximity method imputed the age for case 959 (a third-class female with four parents or children aboard) as
23.52 whereas MICE gave an average of 40.238. Which imputations do you
think are more plausible?
t3[c(956, 959), ]
#> survived pclass age sex sibsp parch
#> 956 0 3 NA female 3 1
#> 959 0 3 NA female 0 4
With the m = 21 completed data sets in hand, I can proceed with the RF
analysis. The idea here is to fit an RF separately to each completed data
set. I’ll then look at the variance of the results (e.g., OOB error, variable
importance scores, etc.) to judge its sensitivity to the different plausible im-
putations. Below, I use the ranger package to fit a (default) RF/probability
machine to each of the m = 21 completed data sets. In anticipation of look-
ing at the sensitivity of the variable importance scores, I set importance = "permutation" in the call to ranger():
# Obtain a list of probability forests, one for each imputed data set
set.seed(2147) # for reproducibility
rfos <- lapply(t3.mice, FUN = function(x) {
ranger(as.factor(survived) ~ ., data = x, probability = TRUE,
importance = "permutation")
})
The OOB errors from each model are comparable; that’s a good start! The
average OOB Brier score is 0.134, with a standard deviation of 0.001.
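These summaries can be computed directly from the list of fitted forests; for a probability forest, ranger stores the OOB Brier score in the prediction.error component:
oob.brier <- sapply(rfos, FUN = function(rfo) rfo$prediction.error)
c("mean" = mean(oob.brier), "sd" = sd(oob.brier))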
Next, I’ll look at variable importance. With multiple imputation I think the
most sensible thing to do is to just plot the variable importance scores from
each run together, so that you can see the variability in the results:
# Compute list of VI scores, one for each model. Note: can use
#`FUN = ranger::importance` to be safe
vis <- lapply(rfos, FUN = importance)
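The plotting code further below assumes a data frame pdps holding the partial dependence of survival on age, by sex and passenger class, for each of the m = 21 models. A minimal sketch of how it might be built (the prediction wrapper, grid settings, and the column m are assumptions):
library(pdp)
pd.fun <- function(object, newdata) {  # average predicted probability of survival
  mean(predict(object, data = newdata)$predictions[, 2])
}
pdps <- do.call(rbind, lapply(seq_along(rfos), FUN = function(i) {
  pd <- partial(rfos[[i]], pred.var = c("age", "pclass", "sex"),
                pred.fun = pd.fun, cats = "pclass", train = t3.mice[[i]])
  pd$m <- i  # imputation ID; used for grouping in the plot
  pd
}))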
Next, I plot the results. There’s some R-ninja trickery happening in the code
chunk below in order to get the plot I want. Using ggplot2, I want to group
a set of line plots by two variables, but color by just one of them. We could paste the two grouping variables together into a new column, but base R's interaction() function accomplishes this for us; see ?interaction for details.
The results are displayed in Figure 7.24; compare this to Figure 12.22 in
Harrell [2015, p. 308]. I also included a rug representation (i.e., 1-d plot) in
each panel showing the deciles (i.e., the 10-th percentile, 20-th percentile, etc.)
of passenger age from the original (incomplete) training set. This helps guide
where the plots are potentially extrapolating. Using deciles means that 10%
of the observations lie between any two consecutive rug marks; see Greenwell
[2017] for some remarks on the importance of avoiding extrapolation when
r I’m using cats = "pclass" here to treat pclass as categorical since it’s restricted to
# Plot results
deciles <- quantile(t3$age, prob = 1:9/10, na.rm = TRUE)
ggplot(pdps, aes(age, yhat, color = sex,
group = interaction(m, sex))) +
geom_line(alpha = 0.3) +
geom_rug(aes(age), data = data.frame("age" = deciles),
sides = "b", inherit.aes = FALSE) +
labs(x = "Age (years)", y = "Survival probability") +
facet_wrap(~ pclass) +
scale_colour_manual(values = c("black", "orange")) + # Okabe-Ito
theme_bw() +
theme(legend.title = element_blank(),
legend.position = "top")
[FIGURE 7.24: Partial dependence of survival probability on passenger age, separately for females and males, faceted by passenger class (1-3); one curve per imputed data set, with decile rug marks for age along the bottom.]
While there’s some variability in the results, the overall patterns are clear.
First-class females had the best chances of surviving, regardless of age. Passenger age seems to have a stronger negative effect on passenger
survivability for males compared to females, regardless of class. The difference
in survivability between males and females is less pronounced for third-class
passengers. Do you agree? What other conclusions can you draw from this
plot?
Finally, let’s use Shapley values (Section 6.3.1) to help us understand in-
dividual passenger predictions. To illustrate, let’s focus on a single (hypo-
thetical/fictional) passenger. Everyone, especially those who haven’t seen the
movie, meet Jack:
jack.dawson <- data.frame(
#survived = 0L, # in case you haven't seen the movie
pclass = 3L, # using `3L` instead of `3` to treat as integer
age = 20.0,
sex = factor("male", levels = c("female", "male")),
sibsp = 0L,
parch = 0L
)
Here, I’ll use the fastshap package for computing Shapley values, but you can
use any Shapley value package you like (e.g., R package iml). First, we need
to set up a prediction wrapper—this is a function that tells fastshap how
to extract predictions from the fitted ranger model, which is the purpose of
function pfun() below. Next, I compute approximate feature contributions for
Jack’s predicted probability of survival using 1000 Monte-Carlo repetitions,
which is done for each of the m =21 completed data sets:
library(fastshap)
(I guesstimated some of Jack's inputs, based on the movie I saw in seventh grade.)
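The contributions themselves can be computed by looping explain() over the fitted forests; a sketch is given below (the prediction wrapper pfun(), the seed, and the post-hoc renaming of columns to the "feature=value" format shown in the output are assumptions):
pfun <- function(object, newdata) {  # tells fastshap how to get predictions
  predict(object, data = newdata)$predictions[, 2]  # Pr(survived = 1)
}
set.seed(1351)  # for reproducibility
ex.jack <- do.call(rbind, lapply(seq_along(rfos), FUN = function(i) {
  explain(rfos[[i]], X = subset(t3.mice[[i]], select = -survived),
          nsim = 1000, pred_wrapper = pfun, newdata = jack.dawson)
}))  # one row of feature contributions per completed data set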
#> # A tibble: 21 x 5
#> `pclass=3` `age=20` `sex=male` `sibsp=0` `parch=0`
#> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -0.0836 -0.0136 -0.141 0.00721 -0.0174
#> 2 -0.0796 -0.0222 -0.144 0.0109 -0.00967
#> 3 -0.0743 -0.000271 -0.144 0.00995 -0.0170
#> 4 -0.0709 -0.0132 -0.139 0.00740 -0.0126
#> 5 -0.0807 -0.0192 -0.134 0.00768 -0.0159
#> 6 -0.0807 -0.0134 -0.136 0.0103 -0.0159
#> 7 -0.0840 -0.00355 -0.145 0.00999 -0.0147
#> 8 -0.0874 0.0110 -0.136 0.0103 -0.0254
#> 9 -0.0754 -0.00982 -0.143 0.00449 -0.0233
#> 10 -0.0663 -0.000338 -0.144 0.00519 -0.0165
#> # ... with 11 more rows
Fortunately, again, the results are relatively stable across imputations. A sum-
mary of the overall Shapley explanations, along with Jack’s predictions, is
shown in Figure 7.25. Here, we can see that Jack being a third-class male
contributed the most to his poor predicted probability of survival, aside from
him not being able to fit on the floating door that Rose was hogging...
# Jack's predicted probability of survival across all imputed
# data sets
pred.jack <- data.frame("pred" = sapply(rfos, FUN = function(rfo) {
pfun(rfo, jack.dawson)
}))
FIGURE 7.25: Predicted probability of survival for Jack across all imputed
data sets (left) and their corresponding Shapley-based feature contributions
(right).
7.9.4 Example: class imbalance (the good, the bad, and the
ugly)
Unless the trees are grown to purity, RFs generally produce consistent and
well-calibrated probabilities [Malley et al., 2012, Niculescu-Mizil and Caruana,
2005], while boosted trees (Chapter 8) do not. However, we'll see shortly that
that’s not always the case. Furthermore, calibration curves can be useful in
comparing the performance of fitted probability models.
Real binary data are often unbalanced. For example, in modeling loan defaults,
the target class (default on a loan) is often underrepresented. This is expected
since we would hope that most people don’t default on their loan over the
course of paying it off. However, many practitioners perceive class imbalance
as an issue that affects "accuracy." In actuality, the problem usually arises when the data have been artificially balanced [Matloff, 2017, p. 193].
In this example, I’m going to simulate some unbalanced data. In particular,
I’m going to convert the Friedman 1 benchmark regression data (Section 1.4.3)
into a binary classification problem using a latent variable model. Essentially,
I’ll treat the observed response values as the linear predictor of a logistic re-
gression model and convert them to probabilities. We can then use a binomial
random number generator to simulate the observed class labels. The impor-
tant thing to remember about this simulation study is that we have access to
the true underlying class probabilities!
A simple function to convert the Friedman 1 regression problem into a binary
classification problem, as described above, is given below. (Note that the line
d$y <- d$y - 23 shifts the intercept term and effectively controls the balance
of the generated 0/1 outcomes—here, it was chosen to obtain a 0/1 class
balance of roughly 0.95/0.05.)
gen_binary <- function(...) {
d <- treemisc::gen_friedman1(...) # regression data
d$y <- d$y - 23 # shift intercept
d$prob <- plogis(d$y) # inverse logit to obtain class probabilities
#d$prob <- exp(d$y) / (1 + exp(d$y)) # same as above
d$y <- rbinom(nrow(d), size = 1, prob = d$prob) # 0/1 outcomes
d
}
# Generate samples
set.seed(1921) # for reproducibility
trn <- gen_binary(100000) # training data
tst <- gen_binary(100000) # test data
By default, most classifiers (RFs included) estimate the class priors from the data using πi = Ni/N, for i = 1, 2; this was discussed
for CART-like decision trees in Section 2.2.4. There are three scenarios to
consider:
a) the data form a representative sample, and the observed class frequencies
reflect the true class priors in the population (the good);
b) the class frequencies have been artificially balanced, but the true class
frequencies/priors are known (the bad);
c) the class frequencies have been artificially balanced, and the true class
frequencies/priors are unknown (the ugly).
In the code chunk below I use an independent sample of size N = 106 to esti-
mate π1 (i.e., the prevalence of observations in class 1 in the population):
(pi1 <- proportions(table(gen_binary(1000000)$y))["1"])
#> 1
#> 0.0498
Next, I’ll define a simple calibration function that can be used for isotonic
calibration; for a brief overview of different calibration methods, see Niculescu-
Mizil and Caruana [2005], Kull et al. [2017]. Note that there are many R and
Python libraries for calibration; for example, val.prob() from R package
rms [Harrell, Jr., 2021] and the sklearn.calibration module.
isocal <- function(prob, y) { # isotonic calibration function
ord <- order(prob)
prob <- prob[ord] # put probabilities in increasing order
y <- y[ord]
prob.cal <- isoreg(prob, y)$yf # fitted values
data.frame("original" = prob, "calibrated" = prob.cal)
}
To start, let’s fit a default RF to the original (i.e., unbalanced) learning sample.
Note that I exclude the prob column when specifying the model formula.
library(ranger)
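A sketch of the fit itself, mirroring the call used later for the down-sampled data (the seed is an assumption):
set.seed(928)  # for reproducibility
rfo1 <- ranger(y ~ . - prob, data = trn, probability = TRUE)
rfo1$prediction.error  # OOB Brier score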
The OOB prediction error (in this case, the Brier score) is 0.026. The Brier
score on the test data can also be computed, but since I have access to the
true probabilities, I might as well compare them with the predictions too (for
this, I’ll compute the MSE between the predicted and true probabilities). In
this case, we see that the Brier score on the test set is comparable to the OOB
Brier score.
prob1 <- predict(rfo1, data = tst)$predictions[, 2]
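Since the simulated test set carries both the observed outcomes (y) and the true probabilities (prob), both quantities can be computed along these lines:
mean((prob1 - tst$y)^2)     # Brier score on the test set
mean((prob1 - tst$prob)^2)  # MSE between predicted and true probabilities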
Looking at a single metric (or metrics) does not paint a full picture, so it can
be helpful to look at specific visualizations, like calibration curves, to further
assess the accuracy of the model’s predicted probabilities (lift charts can also
be useful). The leftmost plot in Figure 7.26 shows the actual vs. predicted
probabilities for the test set, as well as the isotonic-based calibration curve
from the above RF. In this case, the RF seems to be doing a reasonable job
in terms of accuracy. The model seems well-calibrated for probabilities below
0.5, but seems to have a slight negative bias for probabilities above 0.5, which
makes sense since most of the probability mass is concentrated near zero (as
we might have expected given the true class frequencies).
To naively combat the perceived issue of unbalanced class labels, the learning
sample is often artificially rebalanced (e.g., using down sampling) so that the
class outcomes have roughly the same distribution. In general, THIS IS A BAD
IDEA for probability models, and can lead to serious bias in the predicted
probabilities—in fact, any algorithm that requires you to remove good data
to optimize performance is suspect. Nonetheless, sometimes the data have
been artificially rebalanced in a preprocessing step outside of our control, or
maybe you decided to down sample the data to reduce computation time (in
which case, you should try to preserve the original class frequencies, or at least
store them for adjustments later). In any case, let’s see what happens to our
predictions when we down sample the majority class.
Let’s try this out on an RF fit to a down sampled version of the training data.
Below I artificially balance the classes by removing rows corresponding to the
dominant class (i.e., y = 0):
trn.1 <- trn[trn$y == 1, ]
trn.0 <- trn[trn$y == 0, ]
trn.down <- rbind(trn.0[seq_len(nrow(trn.1)), ], trn.1)
table(trn.down$y)
#>
#> 0 1
#> 5018 5018
Next, I’ll fit another (default) RF, but this time to the down samples training
set. I then apply the adjustment formula to the predicted probabilities for the
positive class in the test set:
set.seed(1146) # for reproducibility
rfo2 <- ranger(y ~ . - prob, data = trn.down, probability = TRUE)
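The adjustment itself is sketched below using the standard prior-correction formula for a sample that has been artificially balanced to 50/50; it may differ in detail from the adjustment formula referenced earlier:
prob2 <- predict(rfo2, data = tst)$predictions[, 2]  # unadjusted Pr(y = 1 | x)
# Rescale using the true prior (pi1) relative to the 50/50 training prior
prob2.adj <- (prob2 * pi1) / (prob2 * pi1 + (1 - prob2) * (1 - pi1))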
Figure 7.26 shows the predicted vs. true probabilities across three different
cases: 1) predicted probabilities (prob1) from an RF applied to the original
training data (left display), 2) predicted probabilities (prob2) from an RF
applied to a down-sampled version of the original training data, but adjusted
using the original class frequencies (middle display), and 3) predicted proba-
bilities (prob2) from an RF applied to a down-sampled version of the original
training data with no adjustment (right display).
FIGURE 7.26: True vs. predicted probabilities for the test set from an RF
fit to the original and down sampled (i.e., artificially balanced) training sets.
Left: probabilities from the original RF. Middle: Probabilities from the down
sampled RF with post-hoc adjustment. Right: Probabilities from the down
sampled RF. The calibration curve is shown in orange, with the blue curve
representing perfect calibration.
The post-hoc adjustment (middle display of Figure 7.26) seems to provide some relief, but requires access to the true class priors (or good estimates thereof).
Basically, it is ill-advised to choose a model based on a metric that forces a
classification based on an arbitrary threshold. Instead, choose a model using
a proper scoring rule (e.g., log loss or the Brier score) that makes use of the full range of predicted probabilities and is optimized when the true probabilities are recovered. Down sampling, adjusted or not, seems to severely under- or overestimate the true class probabilities.
In this section, I’ll look at a well-known bank marketing data set available
from the UC Irvine Machine Learning Repository [Moro et al., 2014]. The data
concern the direct marketing campaigns of a Portuguese banking institution,
which were based on phone calls. Often, more than one contact to the same
client was required, in order to assess if the product, a bank term deposit,
would be subscribed or not; hence, this is a binary classification problem with
response variable y taking on values yes/no according to whether or not the client subscribed to the bank term deposit. For details and a
description of the different columns, visit https://archive.ics.uci.edu/
ml/datasets/Bank+Marketing.
Furthermore, I’ll do the analysis via one of the R front ends to Apache Spark
and MLlib: SparkR [Venkataraman et al., 2016]. Starting with Spark 3.1.1,
SparkR also provides a distributed data frame implementation that supports
common data wrangling operations like selection, filtering, aggregation, etc.
(similar to using data.table or dplyr with R data frames, but on large data
sets that don’t fit into memory). For instructions on installing Spark, visit
https://spark.apache.org/docs/latest/sparkr.html.
While the bank marketing data contain 21 columns on 41,188 records, this is
by no means “Spark territory.” However, you may find yourself in a situation
where you need to use Spark MLlib for scalable analytics, so I think it’s useful
to show how to perform common statistical and machine learning tasks in
Spark and MLlib, like fitting RFs, assessing performance, and computing PD
plots.
To start, I’ll download a zipped file containing a directory with the full bank
marketing data set. The code downloads the zipped file into a temporary
directory, unzips it, and reads in the CSV file of interest. (If the following
code does not work for you, then you may find it easier to just manually
download the bank-additional-full.csv and read it into R on your own.)
(This analysis can easily be translated to any other front end to Spark, including sparklyr or pyspark.)
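A sketch of the download-and-read step; the file URL below points at the UCI archive location for bank-additional.zip and may have moved, in which case download the file manually from the page referenced above:
url <- paste0("https://archive.ics.uci.edu/ml/machine-learning-databases/",
              "00222/bank-additional.zip")
tmp <- tempfile(fileext = ".zip")
download.file(url, destfile = tmp, mode = "wb")
unzip(tmp, exdir = tempdir())
bank <- read.csv(file.path(tempdir(), "bank-additional",
                           "bank-additional-full.csv"),
                 sep = ";", stringsAsFactors = TRUE)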
Next, I’ll clean up the data a bit. First off, I’ll replace the dots in the column
names with underscores; Spark does not like dots in column names! Second,
I’ll coerce the response (y) from a factor (no/yes) to a binary indicator (0/1)
and treat it as numeric to fit a probability forest/machine. Finally, I’ll re-
move the column called duration. Too often have I seen online analyses of
the same data, only for the analyst to be fooled into thinking that duration
is a useful indicator of whether or not a client will subscribe to a bank term
deposit. If you take care and read the data documentation, you’d notice that
the value of duration is not known before a call is made to a client. In other
words, the value of duration is not known at prediction time and therefore
cannot be used to train a model. This is a textbook example of target leak-
age. KNOW YOUR DATA! Finally, the data are split into train/test sets
using a 50/50 split; I could do this manually, but here I’ll use the caret
package’s createDataPartition() function, which uses stratified sampling
to ensure that the distribution of classes is similar between the resulting par-
titions:
names(bank) <- gsub("\\.", replacement = "_", x = names(bank))
bank$y <- ifelse(bank$y == "yes", 1, 0)
bank$duration <- NULL # remove target leakage
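A sketch of the 50/50 stratified split (the seed is an assumption):
library(caret)
set.seed(1601)  # for reproducibility
trn.id <- createDataPartition(factor(bank$y), p = 0.5, list = FALSE)
bank.trn <- bank[trn.id, ]   # training data
bank.tst <- bank[-trn.id, ]  # test data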
(You can also start an R session with SparkR already available from the terminal by running ./bin/sparkR from your Spark home folder; for details, see https://spark.apache.org/docs/latest/sparkr.html.)
to tell library() the location of the package. (Note that the code snippet below may
need to change for you depending on where you have Spark installed; for me,
it’s in C:\spark\spark-3.0.1-bin-hadoop2.7\R\lib.)
library(SparkR, lib.loc = "C:\\spark")
library(ggplot2)
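A sketch of the Spark session setup and model fit follows; the master, seed, and numTrees values are assumptions, and because y is treated as numeric, a regression forest serves as the probability forest:
sparkR.session(master = "local[*]")  # start a local Spark session
bank.trn.sdf <- createDataFrame(bank.trn)  # copy the train/test sets to Spark
bank.tst.sdf <- createDataFrame(bank.tst)
bank.rfo <- spark.randomForest(bank.trn.sdf, y ~ ., type = "regression",
                               numTrees = 100, seed = 1453)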
To assess the performance of the probability forest, I can compute the Brier
score on the test set. A couple of things are worth noting about the code chunk
below. First, the predict() method, when applied to a SparkR MLlib model,
returns the predictions along with the original columns from the supplied
Spark DataFrame. Second, note that I have to compute the Brier score using
Spark DataFrame operations (like SparkR's summarize() function, in this case).
p <- predict(bank.rfo, newData = bank.tst.sdf) # Pr(Y=yes|x)
head(summarize(p, brier_score = mean((p$prediction - p$y)^2)))
#> brier_score
#> 1 0.07815544
The AUC on the test set for this model, if you care purely about discrimina-
tion, is 0.798, which is in line with some of the even more advanced analy-
ses I’ve seen on these data. Nice! In addition, Figure 7.27 shows an isotonic
regression-based calibration curve (left) and cumulative gains chart, both com-
(Even if discrimination is the goal, AUC does not take into account the prior class probabilities and is not necessarily appropriate in situations with severe class imbalance; in this case, the area under the PR curve would be more informative [Davis and Goadrich, 2006].)
puted from the test data. The model seems reasonably calibrated (as we would
hope from a probability forest). The cumulative gains chart tells us, for exam-
ple, that we could expect roughly 1,500 subscriptions by contacting the top
20% of clients with the highest predicted probability of subscribing.
FIGURE 7.27: Graphical assessment of the performance on the test set. Left:
isotonic regression-based calibration curve. Right: Cumulative gains chart.
Nonetheless, I can use some regular expression (regex) magic to parse the
output into a more friendly data frame. (By no means do I claim this to be
the best solution; I’m certainly no regexpert—ba dum tsh.)
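The string being parsed here is assumed to come from the fitted model's summary, which reports the feature importances as a single string of the form "(<k>,[indices],[importances])":
vi <- summary(bank.rfo)$featureImportances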
vi <- substr(vi, start = regexpr(",", text = vi)[1] + 1,
stop = nchar(vi) - 1)
vi <- gsub("\\[", replacement = "c(", x = vi)
(It is not clearly documented which variable importance metric MLlib uses in its implementation of RFs, but I suspect it's the impurity-based metric (Section 7.5) since, as far as I'm aware, MLlib's RF implementation does not support the notion of OOB observations (Section 7.3).)
(Recall that the 3 month Euribor rate is the interest rate at which a selection of European banks lend one another funds (denominated in euros), whereby the loans have a 3 month maturity.)
(You can find a similar example using dplyr with the sparklyr front end to Spark here: https://github.com/bgreenwell/pdp/issues/97.)
construct a new Spark DataFrame containing only the plotting values for
euribor3m. Then, we just need to create a Cartesian product with the original
training data (excluding the variable euribor3m), or representative sample
thereof. This is accomplished in the next code chunk.
A word of caution is in order. Even though Spark is designed to work with
large data sets in a distributed fashion, Cartesian products can still be costly!
Hence, if your learning sample is quite large (e.g., in the millions), which is
probably the case if you’re using MLlib, then keep in mind that you don’t
necessarily need to utilize the entire training sample for computing partial
dependence and the like. If you have 50 million training records, for example,
then consider only using a small fraction, say, 10,000, for constructing feature
effect plots.
euribor3m.grid <- as.DataFrame(unique( # DataFrame of unique quantiles
approxQuantile(bank.trn.sdf, cols = "euribor3m",
probabilities = 1:29 / 30, relativeError = 0)
))
names(euribor3m.grid) <- "euribor3m"
Finally, we can compute the partial dependence values by aggregating the pre-
dictions using a simple grouping operator combined with a summary function
(for PD plots, we just average the predictions). The results are displayed in
Figure 7.28. Here you can see that the predicted probability of subscribing tends to decrease as the Euribor 3 month rate increases (note that the y-axis is on the probability scale).
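A sketch of the aggregation step that produces the pd data frame plotted below; the function names follow SparkR, but the details are assumptions:
bank.trn.rest <- select(bank.trn.sdf,  # training data without euribor3m
                        setdiff(columns(bank.trn.sdf), "euribor3m"))
cart <- crossJoin(euribor3m.grid, bank.trn.rest)  # Cartesian product
preds <- predict(bank.rfo, newData = cart)
pd <- collect(summarize(groupBy(preds, "euribor3m"),
                        yhat = mean(preds$prediction)))
pd <- pd[order(pd$euribor3m), ]  # sort by the plotting variable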
ggplot(pd, aes(x = euribor3m, y = yhat)) +
geom_line() +
geom_rug(data = as.data.frame(euribor3m.grid),
aes(x = euribor3m), inherit.aes = FALSE) +
xlab("Euribor 3 month rate") +
ylab("Partial dependence") +
theme_bw()
[FIGURE 7.28: Partial dependence of the probability of subscribing on the Euribor 3 month rate, with decile rug marks along the bottom.]
Yikes, that was a long chapter, but necessarily so. While RFs were originally
introduced in Breiman [2001], many of the ideas have been seen before. For
example, the term “random forest” was actually coined by Ho [1995], who used
the random subspace method to combine trees grown in random subspaces
of the original features. Breiman [2001] references several other attempts to
further improve bagging by introducing more diversity among the trees in a
“forest.”
Leo Breiman was a phenomenal statistician (and theoretical probabilist) who
had a profound impact on the field of statistical and machine learning. If
you’re interested in more of his work, especially on the development of RF,
and the many collaborations it involved, see Cutler [2010]. Adele Cutler, a
close collaborator with Breiman on RFs, still maintains their original RF web-
site at https://www.stat.berkeley.edu/~breiman/RandomForests/. This
website is still one of the best references to understanding Breiman’s origi-
nal RF and includes links to several relevant papers and the original Fortran
source code.
To end this chapter, I’ll leave you with a quote listed under the philosophy
section of Breiman and Cutler’s RF website, which applies more generally
than just to RFs:
Emily Dickinson
I like RFs because they’re powerful and flexible, yet conceptually simple: av-
erage a bunch of “de-correlated” trees together in the hopes of producing an
accurate prediction. However, RF is not always the most efficient or most
accurate tree ensemble to use. Gradient tree boosting provides another rich
and flexible class of tree-based ensembles that, at a high level, I think is also
conceptually simple. However, with gradient tree boosting, the devil is in the
details.
In Section 5.2.1, we were introduced to AdaBoost.M1, a particularly sim-
ple boosting algorithm for binary classification problems. Boosting initially
started off as a way to improve the performance of weak binary classifiers.
Over time, boosting has evolved into an incredibly flexible procedure that,
like RFs, can handle a wide array of supervised learning problems.
In this chapter, I'll walk through the basics of the currently most popular flavor
of boosting: stochastic gradient boosting, also known as a gradient boosting
machine (or GBM for short). Although GBMs are meant to be more general,
in this book, GBM generally refers to stochastic gradient boosted decision
trees.
(This flavor of boosting goes by several names in the literature. For example, the R package gbm [Greenwell et al., 2021b] fits this class of models; its name stands for generalized boosted models.)
L(θ) = Σ_{i=1}^N L[y_i, f(x_i; θ)],

and the goal is to find the parameter values that minimize it:

θ* = arg min_θ L(θ).    (8.1)
A popular method for solving (8.1) is the method of steepest descent, a special
case of the more general method of gradient descent. Gradient descent is a
general class of iterative optimization algorithms that express the solution to
(8.1) as a sum of components:
θ* = Σ_{b=0}^B θ_b,
where θ_0 is some initial guess and {θ_b}_{b=1}^B are successive "steps" or "boosts" towards the optimal solution θ*, found through the update equation

θ_b = θ_{b-1} + γ Δθ_{b-1},

where Δθ_{b-1} = −∂L(θ)/∂θ, evaluated at θ = θ_{b-1}, is the negative gradient of L(θ) with respect to θ and represents the direction of "steepest descent" of L(θ), and γ > 0 is the step size taken in that direction. Steepest descent methods differ in how γ is determined.
Here’s a pretty common analogy for explaining gradient descent without any
math or fancy notation. Imagine trying to reach the bottom of a large hill (i.e.,
trying to find the global minimum) blindfolded. Without being able to see,
and assuming you’re not playing Marco Polo, you’ll have to rely on what you
can feel on the ground around you (i.e., local information) to find your way
to the bottom. Ideally you’d feel around the ground at your current location
(θ b−1 ) to get a sense of the direction of steepest descent (∆θ b−1 )—the fastest
way down—and proceed in that direction while periodically reassessing which
direction to go using the current local information available. How far you go
in each direction is determined by your step size γb .
This simple analogy glosses over a number of details, like getting stuck in a hole
(i.e., finding a local minimum), but hopefully the basic idea of gradient descent
is relatively clear: to find the global minimum of L (θ), we take incremental
steps in the direction of steepest descent, provided by the negative gradient
of L (θ) evaluated at the current point. The step size to take at each iteration
can be fixed or estimated by solving another (one-dimensional) minimization problem:

γ_b = arg min_{γ > 0} L(θ_{b-1} + γ Δθ_{b-1}).
Now, what does steepest descent have to do with boosting decision trees?
Imagine trying to find some generic prediction function f such that
f* = arg min_f L(f),
where L(f) = Σ_{i=1}^N L[y_i, f(x_i)] is a loss function evaluated over the learning sample and encourages f to fit the data well [James et al., 2021, p. 302]. In contrast to (8.1), the "parameters" to be optimized here are the N fitted values of f at the training observations, collected in the vector f ∈ R^N:
f = {f(x_i)}_{i=1}^N.

As before, the solution is expressed as a sum of components,

f_B = Σ_{b=0}^B f_b,   f_b ∈ R^N,

with steepest descent updates

f_b = f_{b-1} − γ g_b,

where

g_b = { ∂L(f)/∂f(x_i) |_{f(x_i) = f_{b-1}(x_i)} }_{i=1}^N.    (8.3)
In gradient tree boosting, the b-th "step" is obtained by fitting a regression tree to the negative gradient values. A regression tree with J terminal nodes can be written as

f(x; θ, R) = Σ_{j=1}^J θ_j I(x ∈ R_j),

where θ = {θ_j}_{j=1}^J represents the terminal node estimates (i.e., the mean response in each terminal node), R = {R_j}_{j=1}^J represents the disjoint regions that form the J terminal nodes, and I(·) is the usual indicator function that evaluates to one whenever its argument is true (and zero otherwise). Fitting such a model (typically a CART-like tree, although any regression tree would, in theory, work here) to the observed negative gradient means we can define it at new data points. Using regression trees, the update becomes
f_b(x) = f_{b-1}(x) + γ_b Σ_{j=1}^J θ_{jb} I(x ∈ R_{jb})
       = f_{b-1}(x) + Σ_{j=1}^J γ_{jb} I(x ∈ R_{jb}).
Consequently, the line search for choosing the step size γb is equivalent to
updating the terminal node estimates using the specified loss function:
{γ_{jb}}_{j=1}^J = arg min_{{γ_j}_{j=1}^J} Σ_{i=1}^N L(y_i, f_{b-1}(x_i) + Σ_{j=1}^J γ_j I(x_i ∈ R_{jb})).    (8.4)
Following Friedman [2001], since the J terminal node regions are disjoint, we
can rewrite (8.4) as
γ_{jb} = arg min_γ Σ_{x_i ∈ R_{jb}} L[y_i, f_{b-1}(x_i) + γ],    (8.5)
which is the optimal constant update for each terminal node region {R_{jb}}_{j=1}^J, based on the specified loss function L and the current iteration's fit f_{b-1}(x_i).
Solving (8.5) is equivalent to fitting a generalized linear model with an off-
set [Efron and Hastie, 2016, p. 349]. This step is quite important since, for
some loss functions, the original terminal node estimates will not be accu-
rate enough. For example, with least absolute deviation (LAD) loss (see Sec-
tion 8.2.0.1), the observed negative gradient, g_b, only takes on values in {−1, 1}; hence, the fitted values are not likely to be very accurate. In sum-
mary, the “line search” step (8.5) modifies the terminal node estimates of the
current fit to minimize loss.
(Roughly speaking, an offset is an adjustment term (in this case, a fixed constant) to be added to the predictions in a model; this is more common in generalized linear models, where it's added to the linear predictor with a fixed coefficient of one, rather than an estimated coefficient.)
For LS loss, for example, solving (8.5) gives

γ_{jb} = arg min_γ Σ_{x_i ∈ R_{jb}} ([y_i − f_{b-1}(x_i)] − γ)²
       = arg min_γ Σ_{x_i ∈ R_{jb}} (r_{i,b-1} − γ)²
       = (1/N_{jb}) Σ_{x_i ∈ R_{jb}} r_{i,b-1},
where ri,b−1 and Njb are the i-th residual and number of observations in the
j-th terminal node for the b-th iteration, respectively. This results in the mean
of the residuals in each terminal node at the b-th iteration, which is precisely
what the original regression tree induced at iteration b uses for prediction (i.e.,
the terminal node summaries). In other words, for the special case of LS loss,
the original terminal node estimates at the b-th iteration are already optimal,
and so no update (i.e., line search) is needed.
For LAD loss, L (f ) = |y − f (x)|, and (8.5) results in the median of the cur-
rent residuals in the j-th terminal node at the b-th iteration. Solving (8.5) can
be difficult for general loss functions, like those often used in binary or multi-
nomial classification settings, and fast approximations are often employed (see
Section 8.2.0.1).
The full gradient tree boosting algorithm is presented in Algorithm 8.1. Note
that this is the original gradient tree boosting algorithm proposed in Friedman
[2001]. Several variations have been proposed in the literature, each with their
own enhancements, and I’ll discuss some of these modifications in the sections
that follow.
y*_{ib} = −∂L(y_i, f(x_i))/∂f(x_i), evaluated at f(x_i) = f_{b-1}(x_i), for i = 1, 2, . . . , N.

d) Update f_b(x) as

f_b(x) ← f_{b-1}(x) + Σ_{j=1}^{J_b} γ_{jb} I(x ∈ R_{jb}).
where ỹ = 2y − 1 ∈ {−1, +1} and f refers to half the log odds for y = +1. With binomial deviance, there is no closed-form solution to the line search (8.5) in Algorithm 8.1, and approximations are often used instead. For example, a single Newton-Raphson step yields

γ_{jb} = ( Σ_{x_i ∈ R_{jb}} y*_{ib} ) / ( Σ_{x_i ∈ R_{jb}} |y*_{ib}| (2 − |y*_{ib}|) ),

where the y*_{ib} are the pseudo-residuals (i.e., the observed negative gradient) for the b-th iteration.
The final approximation fb(x), which is half the logit for y = +1, can be
inverted to produce a predicted probability. The binomial deviance can also
be generalized to the case of multiclass classification [Friedman, 2001].
More specialized loss functions also exist when dealing with other types of
outcome variables. For example, Poisson loss (Table 8.1) can be used when
modeling counts (which are non-negative integers). Ridgeway [1999] showed how gradient boosting is extendable to the exponential family, via likelihood-
based loss functions, as well as Cox proportional hazards regression models
for censored outcomes. Greg Ridgeway is also the original creator of the R
package gbm, which is arguably the first open source implementation of gradient boosted decision trees.
TABLE 8.1: Common loss functions for gradient tree boosting. The top and
bottom sections list common loss functions used for ordered and binary out-
comes, respectively.
(This re-encoding is done for computational efficiency and also results in the same population minimizer as exponential loss (i.e., AdaBoost.M1 from Section 5.2.4); see Bühlmann and Hothorn [2007] for details.)
(The exponential family includes many common loss functions as a special case; for example, the Gaussian family is equivalent to using LS loss, the Laplace distribution is equivalent to using LAD loss, and the Bernoulli/binomial family is equivalent to using binomial deviance, or log loss.)
(For Poisson loss in Table 8.1, y_i ∈ {0, 1, 2, ...} is a non-negative integer; e.g., the number of people killed by mule or horse kicks in the Prussian army per year, or the number of calls to a customer support center on a particular day.)
(For binomial deviance in Table 8.1, ỹ_i ∈ {−1, 1} and f(x_i) refers to half the log odds for y = +1.)
Another subtle difference from bagging or RFs is that GBMs always use re-
gression trees for the base learner, even for classification problems! It makes
sense if you think about it: gradient tree boosting involves fitting a tree to the
negative gradient of the loss function (i.e., pseudo residuals) at each iteration,
which is always ordered and treated as continuous.
Many implementations also accept case weights. With LS loss, for example, the weighted objective becomes

(1 / Σ_{i=1}^N w_i) Σ_{i=1}^N w_i [y_i − f(x_i)]²,

where {w_i}_{i=1}^N are positive case weights that affect the terminal node esti-
mates. For example, a terminal node in a regression tree with values 2, 3, and
8 would be given an estimate of (2 + 3 + 8)/3 = 4.333 under equal case weights
(the default). However, if the corresponding case weights were 1, 5, and 1, then
the terminal node estimate would become [(1 × 2) + (5 × 3) + (1 × 8)] /(1 +
5 + 1) = 3.571. See Kriegler and Berk [2010] for an example on boosted quan-
tile regression with case weights for small area estimation. The discussion in
Berk [2008, Sec. 6.5] is also worth reading.
Unlike bagging and RFs, GBMs can overfit as the number of trees in the
ensemble (B) increases, all else held constant. This is evident from Figure 8.1,
which shows the training error (black curve) and 5-fold CV error (yellow
curve) for a sequence of boosted regression stumps fit to the Ames housing
data (Section 1.4.7) using LS loss. While the training error will continue to
decrease to zero as the number of trees (B) increases, at some point the
validation error will start to increase. Consequently, it is important to tune
the number of boosting iterations B.
So how many boosting iterations should you try? In part, it depends on a num-
ber of things, including the values set for other hyperparameters. In essence,
you want B large enough to adequately minimize loss and small enough to
avoid overfitting. For smaller data sets, like the PBC data (Section 1.4.9), it’s
easy enough to fix B to an arbitrarily large value (B = 2000, say) and use a
method like cross-validation to select an optimal value, assuming B was large
enough to begin with. In the Ames housing example, the optimal number of trees found by 5-fold CV is B = 83 (the dashed blue vertical line in Figure 8.1). For larger data sets, this
approach can be wasteful, which is where early stopping comes in.
The idea of early stopping is rather simple. At each iteration, we keep track
of the overall performance using some form of cross-validations or a separate
validation set. When the model stops improving by a prespecified amount, the
FIGURE 8.1: Gradient boosted decision stumps for the Ames housing exam-
ple. The training error (black curve) continues to decrease as the number of
trees increases, while the error based on 5-fold CV eventually starts to in-
crease after B = 83 trees (dashed blue vertical line), indicating a problem
with overfitting for larger values of B.
process “stops early” and the model is considered to have converged. Early
stopping is really a feature of the implementation, not of the GBM algorithm
itself, and so it is not necessarily supported in all implementations of gradi-
ent tree boosting. Note that the concept of early stopping can be applied to
other iterative or ensemble methods as well, like RFs (Chapter 7). Implemen-
tations of GBMs that support early stopping are discussed in Section 8.9. The
next section deals with a tuning parameter that’s intimately connected to the
number of boosted trees, the learning rate.
With shrinkage, the update in step d) of Algorithm 8.1 becomes

f_b(x) ← f_{b-1}(x) + ν Σ_{j=1}^{J_b} γ_{jb} I(x ∈ R_{jb}),

where ν ∈ (0, 1] is the learning rate (or shrinkage factor); smaller values require more trees (larger B) but often generalize better.
Here, I’ll look at a brief example using the ALS data from Efron and Hastie
[2016, p. 349]. A description of the data, along with the original source and
download instructions, can be found at
https://web.stanford.edu/~hastie/CASI/.
The data concern N = 1,822 observations on amyotrophic lateral sclerosis
(ALS or Lou Gehrig’s disease) patients. The goal is to predict ALS progression
over time, as measured by the slope (or derivative) of a functional rating score
(dFRS), using 369 available predictors obtained from patient visits. The data
were originally part of the DREAM-Phil Bowen ALS Predictions Prize4Life
challenge. The winning solution [Küffner et al., 2015] used a tree-based ensem-
ble quite similar to an RF, while Efron and Hastie [2016, Chap. 17] analyzed
the data using GBMs (as I’ll do in this chapter). I’ll show a fuller analysis of
these data in Sections 8.9.2–8.9.3.
Figure 8.2 shows the performance of a (very) basic implementation of gra-
dient tree boosting with LS loss using treemisc’s lsboost() function (see
Section 8.5) applied to the ALS data. Here, we can see the test MSE as a
function of the number of trees using two different learning rates: 0.02 (black
curve) and 0.50 (yellow curve) (following Efron and Hastie [2016, p. 339],
these are boosted regression trees of depth three). Using ν = 0.50 results
in overfitting much quicker. The performance curve for ν = 0.50 is also less
smooth than for ν = 0.02. While not spectacularly different, using ν = 0.02
results in a slightly more accurate model (in terms of MSE on the test set),
but requires far more trees. For comparison (and as a sanity check against
treemisc’s overly simplistic lsboost() function), I also included the re-
sults from a popular open source implementation of gradient boosting called
FIGURE 8.2: Gradient boosted depth-three regression trees for the ALS data
using two different learning rates: 0.02 (black curve) and 0.50 (yellow curve).
Left: results from our own lsboost() function. Right: results from XGBoost.
The number of terminal nodes J controls the size (i.e., complexity) of each
tree in gradient tree boosting and plays the role of an important tuning pa-
rameter for capturing interaction effects. Alternatively, you can control tree
size by specifying the maximum depth. In general, a tree with maximum depth d can capture interactions involving up to d features. Note that a tree of depth d will have at most 2^d terminal nodes and 2^d − 1 splits. Similarly, a binary tree with J terminal nodes contains J − 1 splits and can capture interactions involving up to J − 1 features. (An interaction involving k features is often called a k-way interaction effect; hence, J = 2, a single-split stump, corresponds to an additive model with no interaction effects.) The documentation for scikit-learn's implementation of GBMs notes that con-
trolling tree size with J seems to give comparable results to using d = J − 1
“...but is significantly faster to train at the expense of a slightly higher training
error.”
(See the "Controlling the tree size" section of the scikit-learn documentation at https://scikit-learn.org/stable/modules/ensemble.)
I don’t think that grid searches are all that useful for GBMs, and tend to
be too costly for large data sets, especially if early stopping is not available.
A simple and effective tuning strategy for GBMs is to leave the tree-specific
hyperparameters at their defaults (discussed in the previous sections) and
tune the boosting parameters. A rule of thumb proposed by Greg Ridgeway
is to set shrinkage as small as possible while still being able to fit the model
in a reasonable amount of time and storage. For example, aim for 3,000–
10,000 iterations with shrinkage rates between 0.01–0.001; use early stopping,
if it’s available. More elaborate tuning strategies for GBMs are discussed in
Boehmke and Greenwell [2020, Chap. 12].
FIGURE 8.3: Effect of subsampling in GBMs on the ALS data. In this case,
randomly subsampling the columns (yellow curve) slightly outperforms ran-
domly subsampling the rows (black curve).
Let’s apply the lsboost() function to the Ames housing data. Below, I use the
same train/test split for the Ames housing data we’ve been using throughout
this book, then call lsboost() to fit a GBM to the training set; here, I’ll use
a shrinkage factor of ν = 0.1:
library(treemisc)
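A sketch of the lsboost() call follows; ames.trn and ames.tst are assumed to be the usual Ames train/test split from earlier chapters, and the argument names (ntree, shrinkage, depth) are assumptions, so consult ?treemisc::lsboost for the actual interface:
ames.bst <- lsboost(subset(ames.trn, select = -Sale_Price),
                    y = ames.trn$Sale_Price, ntree = 500,
                    shrinkage = 0.1, depth = 4)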
The test RMSE as a function of the number of trees in the ensemble is com-
puted below using the previously defined predict() method; the results are
shown in Figure 8.4 (black curve). For brevity, the code uses sapply() to
essentially iterate cumulatively through the B = 500 trees and computes the
test RMSE for the first tree, first two trees, etc. For comparison, the test RM-
SEs from a default RF are also computed and displayed in Figure 8.4 (yellow
curve). In this example, the GBM slightly outperforms the RF.
set.seed(1128) # for reproducibility
ames.rfo <- # fit a default RF for comparison
randomForest(subset(ames.trn, select = -Sale_Price),
y = ames.trn$Sale_Price, ntree = 500,
# Monitor test set performance (MSE, in this case)
xtest = subset(ames.tst, select = -Sale_Price),
ytest = ames.tst$Sale_Price)
# Compute RMSEs from both models on the test set as a function of the
# number of trees in each ensemble (i.e., B = 1, 2, ..., 500)
rmses <- matrix(nrow = 500, ncol = 2) # to store results
colnames(rmses) <- c("GBM", "RF")
rmses[, "GBM"] <- sapply(seq_along(ames.bst$trees), FUN = function(B) {
pred <- predict(ames.bst, newdata = ames.tst, ntree = B)
rmse(pred, obs = ames.tst$Sale_Price)
}) # add GBM results
rmses[, "RF"] <- sqrt(ames.rfo$test$mse) # add RF results
FIGURE 8.4: Root mean-squared error for the Ames housing test set as a function of B, the number of trees in the ensemble. Here, I show both a GBM (black curve) and a default RF (yellow curve). In this case, gradient tree boosting with LS loss, a shrinkage of ν = 0.1, and a maximum tree depth of d = 4 (black curve) slightly outperforms a default RF (yellow curve).

8.6 Interpretability
Let’s illustrate with the email spam example (Section 1.4.5). Here, I used the
R package gbm to fit a GBM using log loss, B = 4,043 depth-two regression trees (found using 5-fold cross-validation), and a shrinkage factor of ν = 0.01.
To gain an appreciation for the computational speed-up of the recursion
method (which is implemented in gbm), I computed Friedman’s H-statistic
for all 1,596 pairwise interactions, which took roughly five minutes! The largest
pairwise interaction occurred between address and receive. The partial de-
pendence of the log-odds of spam on the joint frequencies of address and
receive is displayed in Figure 8.5. Using the fast recursion method, this took
roughly a quarter of a second to compute, compared to the brute force method,
which took almost 500 seconds.
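A sketch of those computations is given below; the object names spam.gbm, spam.trn, and best.iter (the fitted model, training data, and CV-selected number of trees) are assumptions:
library(pdp)

# Friedman's H-statistic for every pairwise interaction (this is the slow part)
pairs <- combn(spam.gbm$var.names, m = 2)  # all 1,596 feature pairs
h <- apply(pairs, MARGIN = 2, FUN = function(pair) {
  interact.gbm(spam.gbm, data = spam.trn, i.var = pair, n.trees = best.iter)
})

# Partial dependence of the log-odds of spam on the top pair (fast recursion)
pd <- partial(spam.gbm, pred.var = c("address", "receive"),
              recursive = TRUE, n.trees = best.iter)
plotPartial(pd, levelplot = FALSE, drape = TRUE)  # 3-D surface (Figure 8.5)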
(See also https://scikit-learn.org/stable/modules/partial_dependence.html.)
FIGURE 8.5: Partial dependence of log-odds of spam on joint frequency of
address and receive.
Returning to the bank marketing example from Section 7.9.5, I fit a GBM with
and without a decreasing monotonic constraint on euribor3m, the Euribor 3
month rate. In both cases, I used 5-fold cross-validation to fit a GBM with a
maximum of 3,000 trees using a shrinkage rate of ν = 0.01 and a maximum
depth of dmax = 3. The partial dependence of the probability of subscribing
on euribor3m from each model is displayed in Figure 8.6. Both figures tell the
same story: the predicted probability of subscribing tends to decrease as the
euribor 3 month rate increases. However, it may make sense here to assume the
relationship to be monotonic decreasing, as in the left side of Figure 8.6. This
can improve interpretability and understanding by incorporating domain expertise, for example, by removing some of the noise, like the little spike in the right side of Figure 8.6 near euribor3m = 1. Compare these to the
RF-based PDP from Figure 7.28.
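A sketch of how the constrained fit might look using gbm's var.monotone argument; the object names and seed are assumptions, and bank.trn is the R training set from the bank marketing example:
library(gbm)
xnames <- setdiff(names(bank.trn), "y")       # predictor names
mono <- ifelse(xnames == "euribor3m", -1, 0)  # -1 = monotone decreasing
set.seed(1612)  # for reproducibility
bank.gbm.mono <- gbm(y ~ ., data = bank.trn, distribution = "bernoulli",
                     n.trees = 3000, interaction.depth = 3, shrinkage = 0.01,
                     cv.folds = 5, var.monotone = mono)
best.iter <- gbm.perf(bank.gbm.mono, method = "cv", plot.it = FALSE)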
[FIGURE 8.6: Partial dependence of the probability of subscribing on the Euribor 3 month rate, from the GBM fit with (left) and without (right) a monotone decreasing constraint on euribor3m.]
(Recall that the 3 month Euribor rate is the interest rate at which a selection of European banks lend each other funds (denominated in euros), whereby the loans have a 3 month maturity.)
There are two strategies to consider when growing an individual decision tree that we have yet to discuss:
• Level-wise (also referred to as depth-wise or depth first) tree induction
is used by many common decision tree algorithms (e.g., CART and
C4.5/C5.0, but this probably depends on the implementation) and grows
a tree level by level in a fixed order; that is, each node splits the data by
prioritizing the nodes closer to the root node.
• Leaf-wise tree induction (also referred to as best-first splitting), on the
other hand, grows a tree by splitting the node whose split leads to the
largest reduction of impurity.
When grown to maximum depth, both strategies result in the same tree struc-
ture; the difference occurs when trees are restricted to a maximum depth
or number of terminal nodes. Leaf-wise tree induction, while not specific to
boosting, has primarily only been evaluated in that context; see, for example,
Friedman [2001] and Shi [2007].
Figure 8.7 gives an example of a tree grown level-wise (left) and leaf-wise
(right). Notice how the overall tree structures are the same, but the order
in which the splits are made (i.e., S1 –S4 ) is different. In general, level-wise
growth tends to work better for smaller data sets, whereas leaf-wise tends
to overfit. Leaf-wise growth tends to excel in larger data sets where it is
considerably faster than level-wise growth. This is why some modern GBM
implementations—like LightGBM (Section 8.8)—default to growing trees leaf-
wise.
[FIGURE 8.7: An example tree grown level-wise (left) and leaf-wise (right); the splits S1-S4 are numbered in the order they are made, and τ1-τ5 denote the terminal nodes.]
Finding the optimal split for a numeric feature in a decision tree can be slow
when dealing with many unique values; the more unique values a numeric
predictor has, the more split points the tree algorithm has to search through.
A much faster alternative is to bucket the numeric features into bins using
histograms.
The idea is to first bin the input features into integer-valued bins (255–256 bins
seems to be the default across many implementations) which can tremendously
reduce the number of split points to search through. Histogram binning is im-
plemented in a number of popular GBM implementations, including XGBoost
(Section 8.8.1), LightGBM (Section 8.8.2), and the sklearn.ensemble mod-
ule. LightGBM’s online documentation lists several earlier references to this
approach; visit https://lightgbm.readthedocs.io/en/latest/Features.html for details.
g(E[y | x]) = β_0 + Σ_j f_j(x_j),    (8.6)
where g is a link function that connects the random and systematic components
(e.g., adapts the GAM to different settings such as classification, regression,
or Poisson regression), and fj is a function of predictor xj . Compared to a
traditional GAM, an EBM:
• estimates each feature function fj (xj ) using tree-based ensembles, like
gradient tree boosting or bagging;
• can automatically detect and include pairwise interaction terms of the
form fij (xi , xj ).
The overall boosting procedure is restricted to train on one feature at a time
in a “round-robin” fashion using a small learning rate to ensure that feature
334 Gradient boosting machines
order does not matter, which helps limit the effect of collinearity or strong
dependencies among the features.
EBMs are considered “glass box” or highly interpretable models because
the contribution of each feature (or pairwise interaction) to a final predic-
tion can be visualized and understood by plotting fj (xj ), similar to a PD
plot (Section 6.2.1). And since EBMs are additive models, each feature con-
tributes to predictions in a modular way that makes it easy to reason about
the contribution of each feature to the prediction [Nori et al., 2019]. The
simple additive structure of an EBM comes at the cost of longer training
times. However, at the end of model fitting, the individual trees can be
dropped and only the fj (xj ) and fij (xi , xj ) need to be retained, which makes
EBMs faster at execution time. EBMs are available in the interpret pack-
age for Python. For more info, check out the associated GitHub repository at
https://github.com/interpretml/interpret.
XGBoost [Chen and Guestrin, 2016] is one of the most popular and scalable
implementations of GBMs. While XGBoost follows the same principles as the
standard GBM algorithm, there are some important differences, a few of which
are listed below:
• more stringent regularization to help prevent overfitting;
• a novel sparsity-aware split finding algorithm;
• weighted quantile sketch for fast and approximate tree learning;
• parallel tree building (across nodes within a tree);
• exploits out-of-core processing for maximum scalability on a single ma-
chine;
" P 2 2 2 #
1
P P
i∈IL gi i∈IR gi i∈II gi
Lsplit = +P −P − γ. (8.7)
2 i∈IL hi + λ i∈IR hi + λ i∈II hi + λ
P
feature is found by searching through this reduced set of candidate split val-
ues. For details, see Chen and Guestrin [2016]. A modification for weighted
data, called a weighted quantile sketch, is also discussed in Chen and Guestrin
[2016].
Around early 2017, XGBoost introduced fast histogram binning (Section 8.7.2)
to even further push the boundaries of scale and computation speed. In con-
trast to the original approximate tree learning strategy, which generates a
new set of bins for each iteration, the histogram method re-uses the bins over
multiple iterations, and therefore is far better suited for large data sets. XG-
Boost also introduced the option to grow trees leaf-wise, as opposed to just
level-wise (the default), which can also speed up fitting, albeit, at the risk of
potentially overfitting the training data (Section 8.7.1).
Sparse data are common in many situations, including the presence of miss-
ing values and one-hot encoding. In such cases, efficiency can be obtained
by making the algorithm aware of any sparsity patterns. XGBoost handles
sparsity by learning an optimal “default” direction at each split in a tree.
When an observation is missing for one of the split variables, for example, it
is simply passed down the default branch. For details, see Chen and Guestrin
[2016].
One drawback of XGBoost is that it does not currently handle categorical
variables—they have to be re-encoded numerically (e.g., using one-hot encod-
ing). However, at the time of writing this book, XGBoost has experimental
support for categorical variables, although it’s currently quite limited. An ex-
ample of using XGBoost is given in Section 8.9.4. Note that XGBoost can also
be used to fit RFs in a distributed fashion; see the XGBoost documentation
for details.
LightGBM [Ke et al., 2017] offers many of the same advantages as XG-
Boost, including sparse optimization, parallel tree building, a plethora of
loss functions, enhanced regularization, bagging, histogram binning, and early
stopping. A major difference between the two is that LightGBM defaults
to building trees leaf-wise (or best-first). Unlike XGBoost, LightGBM can
more naturally handle categorical features in a way similar to what’s de-
scribed in Section 2.4. In addition, the LightGBM algorithm utilizes two novel
techniques, gradient-based one-side sampling (GOSS) and exclusive feature
bundling (EFB).
GOSS reduces the number of observations by excluding rows with small gra-
dients, while the remaining instances are used to estimate the information
gain for each split; the idea is that observations with larger gradients play a
more important role in split selection. EFB, on the other hand, reduces the number of features by bundling mutually exclusive features (i.e., features that rarely take nonzero values at the same time) into a single feature.
8.8.3 CatBoost
The CatBoost developers (https://catboost.ai/) claim that it works reasonably well out of the box, with less time needing to be spent on hyperparameter tuning.
In CatBoost, a process called quantization is applied to numeric features,
whereby values are divided into disjoint ranges or buckets—this is similar to
the approximate tree growing algorithm in XGBoost whereby numeric fea-
tures are binned. Before each split, categorical variables are converted to nu-
meric using a strategy similar to mean target encoding, called ordered target
statistics, which avoids the problem of target leakage and reduces overfitting.
CatBoost is available as a command line application (written in C++), and R and Python interfaces are also available. For further details and resources,
visit the CatBoost website at https://catboost.ai/.
(Both XGBoost and LightGBM now support categorical features.)
In this example, I’ll return to the PBC data introduced in Section 1.4.9, where
the goal is to model survival in patients with the autoimmune disease PBC. In
Section 3.5.3, I fit a CTree model to the randomized subjects using log-rank
scores. Here, I’ll use the GBM framework to boost a Cox proportional hazards
(Cox PH) model; see Ridgeway [1999] for details.
The Cox PH model is one of the most widely used models for the analysis of survival data. It is a semi-parametric model in the sense that it makes a parametric assumption regarding the effect of the predictors on the hazard function (or hazard rate) at time t, often denoted λ(t), but makes no assumption regarding the shape of λ(t); since little is often known about λ(t) in practice, the Cox PH model is quite useful. The hazard rate, also referred to as the force of mortality or instantaneous failure rate, is related to the probability that the event (e.g., death or failure) will occur in the next instant of time, given that the event has not yet occurred [Harrell, 2015, Sec. 17.3]; it's not a true probability since λ(t) can exceed one. Studying the hazard rate helps us understand the nature of risk over time.
(MART, which evolved into a product called TreeNet(tm), is proprietary software available from Salford Systems, which is currently owned by Minitab; visit https://www.minitab.com/en-us/predictive-analytics/treenet/ for details.)
Extending Algorithm 8.1 to maximize Cox’s log-partial likelihood (which is
akin to minimizing an appropriate loss function) allows us to relax the linearity
assumption, which assumes that the (possibly transformed) predictors are
linearly related to the log hazard, and fit a richer class of models based on
regression trees. Below, I load the survival package and recreate the same
pbc2 data frame I used in Section 3.5.3:
library(survival)
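A sketch of the pbc2 preprocessing (the exact steps from Section 3.5.3 may differ): keep only the randomized subjects, drop the subject ID, and treat transplant as censored.
pbc2 <- pbc[!is.na(pbc$trt), ]                 # randomized subjects only
pbc2$id <- NULL                                # drop subject ID
pbc2$status <- ifelse(pbc2$status == 2, 1, 0)  # 1 = death, 0 = censored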
Next, I’ll fit a boosted PH regression model using the gbm package; de-
tails on the deviance/loss function used in Algorithm 8.1 can be found in
the gbm package vignette: vignette("gbm", package = "gbm"). Here, I use
B = 3000, a maximum tree depth of three, and a learning rate of ν = 0.001.
The optimal number of trees (best.iter) is determined using 5-fold cross-
validation.
library(gbm)
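The fitting code is omitted above; a sketch consistent with the settings just described (the seed and the use of a full additive formula are assumptions) might look like the following:

set.seed(2153)  # assumed seed, for reproducibility
pbc2.gbm <- gbm(Surv(time, status) ~ ., data = pbc2,
                distribution = "coxph",  # boost a Cox PH model
                n.trees = 3000,          # B = 3000
                interaction.depth = 3,   # maximum tree depth of three
                shrinkage = 0.001,       # learning rate
                cv.folds = 5)            # 5-fold cross-validation

# Optimal number of trees according to the cross-validation error
best.iter <- gbm.perf(pbc2.gbm, method = "cv")

The variable importance scores plotted in Figure 8.8 can then be obtained with, for example, summary(pbc2.gbm, n.trees = best.iter).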
FIGURE 8.8: Variable importance plot from the boosted Cox PH model ap-
plied to the PBC data.
It looks as though serum bilirubin (mg/dl) (bili) is the most influential feature in the fitted model. We can easily investigate a handful of important
features by constructing PDPs or ICE plots. In the next code chunk, I con-
struct c-ICE plots (Section 6.2.3) for the top four features. The results are
displayed in Figure 8.9; the average curve (i.e., partial dependence, albeit cen-
tered) is shown in red. Note that while gbm has built-in support for partial
dependence using the recursion method (Section 8.6.1), it does not support
ICE plots; hence, I’m using the brute force approach (recursive = FALSE)
via the pdp package. The code essentially creates a list of plots, which is
displayed in a 2-by-2 grid using the gridExtra package [Auguie, 2017]:
library(ggplot2)
library(pdp)
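The original listing is truncated in this copy; the sketch below, which assumes the top four features from Figures 8.8-8.9 are bili, copper, age, and albumin, conveys the idea:

library(gridExtra)

top4 <- c("bili", "copper", "age", "albumin")  # assumed top four features
ice.plots <- lapply(top4, FUN = function(feature) {
  # Brute-force c-ICE curves (recursive = FALSE) for a single feature
  pd <- partial(pbc2.gbm, pred.var = feature, ice = TRUE, center = TRUE,
                recursive = FALSE, train = pbc2, n.trees = best.iter)
  autoplot(pd, alpha = 0.1)  # centered ICE curves for this feature
})
grid.arrange(grobs = ice.plots, nrow = 2)  # display in a 2-by-2 grid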
FIGURE 8.9: Main effect plots from a gradient boosted Cox PH model on the
PBC data.
FIGURE 8.10: Partial dependence of log hazard on the joint value of bili
and age.
library(fastshap)  # provides explain()

# Estimate feature contributions for newx using 1,000 Monte Carlo reps; pfun
# (a prediction wrapper returning predictions at best.iter) and max.id (the
# row index of the observation of interest) are assumed to be defined in
# earlier code not shown here
X <- pbc2[, pbc2.gbm$var.names]  # feature columns only
newx <- pbc2[max.id, pbc2.gbm$var.names]
set.seed(1408)  # for reproducibility
(ex <- explain(pbc2.gbm, X = X, nsim = 1000, pred_wrapper = pfun,
               newdata = newx))
#> # A tibble: 1 x 17
#> trt age sex ascites hepato spiders
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 -0.00168 0.313 -0.000342 0.282 -0.0175 -0.00492
#> # ... with 11 more variables: edema <dbl>, bili <dbl>,
#> # chol <dbl>, albumin <dbl>, copper <dbl>,
#> # alk.phos <dbl>, ast <dbl>, trig <dbl>,
#> # platelet <dbl>, protime <dbl>, stage <dbl>
The waterfall chart in Figure 8.11 displays the individual feature contributions that bridge the gap between this subject's predicted log hazard and the overall average baseline—in essence, it shows how this subject went from the average baseline log hazard of -1 to their much higher prediction of 2.701. Note that the waterfallchart() function produces a lattice graphic, which behaves differently from base R graphics; hence, I use the ladd() function from package mosaic [Pruim et al., 2021] to add specific details to the plot (e.g., text labels and additional reference lines).
library(waterfall)
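The original charting code is not reproduced above. The rough sketch below only conveys the approach; the data frame construction, the formula orientation passed to waterfallchart(), and the ladd() call are all assumptions:

library(mosaic)  # for ladd()

# Assemble "feature=value" labels and their Shapley contributions (assumed
# object names: `ex` from explain() and `newx` from the previous chunk)
contrib <- data.frame(
  feature = paste0(names(newx), "=", sapply(newx, as.character)),
  shap = unlist(ex[1, ])
)
waterfallchart(shap ~ feature, data = contrib)  # lattice-based waterfall chart
ladd(panel.abline(h = 0, lty = 2))  # e.g., add a reference line at zero
                                    # (adjust to the chart's orientation)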
FIGURE 8.11: Waterfall chart of the individual feature contributions (on the log hazard scale), going from the overall baseline to this subject's predicted log hazard f(x).
Now I can fit the actual model by calling the fit() method on our ngb object. Here, I supply the test set via the validation arguments and use early stopping to determine at which iteration (or tree) the procedure should stop; in particular, training will stop if there has been no improvement in the negative log-likelihood (the default scoring rule in ngboost) for five consecutive rounds.
_ = ngb.fit(X_trn, Y=als_trn["dFRS"], X_val=X_tst,
Y_val=als_tst["dFRS"], early_stopping_rounds=5)
#> == Early stopping achieved.
Recall from Section 5.5 that it is possible to reduce the number of base learners
in a fitted ensemble via post-processing with the LASSO (hopefully without
sacrificing accuracy). We saw this using the Ames housing data with a bagged
tree ensemble and RF in Sections 5.5.1 and 7.9.2, respectively.
It may seem redundant to include another ISLE post-processing example, but
there’s a subtle difference that can be overlooked with GBMs: the initial fit
f0 (x) in Step 1) of Algorithm 8.1 essentially represents an offset.
To illustrate, I’ll continue with the ALS example from Section 8.9.2o . Be-
low, I read in the ALS data and split it into train/test sets using the pro-
vided testset indicator column; the -1 keeps the indicator column out of the
train/test sets.
# Read in the ALS data
url <- "https://web.stanford.edu/~hastie/CASI_files/DATA/ALS.txt"
als <- read.table(url, header = TRUE)
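The splitting code itself is omitted above; a minimal sketch, assuming testset is a logical indicator stored in the first column of als (hence the -1), is:

als.trn <- als[!als$testset, -1]  # training data (drop the testset indicator)
als.tst <- als[als$testset, -1]   # test data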
Next, I call on our lsboost() function to fit a GBM using B = 1000 depth-2
trees with a shrinkage factor of ν = 0.01 and a subsampling rate of 50%.
I also compute test predictions from each individual tree and compute the
cumulative MSE as a function of the number of trees. (Warning, the model
fitting here will take a few minutes.)
library(treemisc)
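The fitting code is not shown here; since lsboost() is the function we developed earlier in the chapter, the sketch below simply mirrors the stated settings (the argument names, object names, and the choice to drop dFRS from the feature matrix are assumptions):

als.bst <- lsboost(subset(als.trn, select = -dFRS), y = als.trn$dFRS,
                   ntree = 1000, depth = 2, shrinkage = 0.01,
                   subsample = 0.5)  # argument names assumed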
Next, I’ll call upon the isle_post() function from package treemisc to post-
process our fitted GBM using the LASSO. There’s one important difference
between this example and the previous ones applied to bagging and RF: with
o A special thanks to Trevor Hastie for clarification and sharing code from Efron and
Hastie [2016, pp. 346–347], which greatly helped in producing this analysis (which is a
detailed recreation of their example) and building out treemisc’s isle_post() function.
FIGURE 8.12: Coefficient paths from the LASSO-based post-processing of the boosted ALS model; each path corresponds to an individual tree, and the top axis gives the number of nonzero coefficients (i.e., trees retained) at each value of λ.
boosting, we need to make sure we include the initial fit f0 (x), which is stored
in the "init" component of the output from lsboost(). Recall that for LS
loss f0 (x) = ȳ, the mean response in the training data. This can be done in
one of two ways:
1) arbitrarily add it to the predictions from the first tree;
2) include it as an offset when fitting the LASSO and generating predictions.
In this example, I’ll include the initial fit as an offset in the call to
isle_post().
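The actual call to isle_post() is not shown in this copy. To make the offset idea concrete, the sketch below performs the post-processing directly with glmnet; the matrices tree.preds.trn and tree.preds.tst (individual tree predictions for the train and test sets) are hypothetical stand-ins for objects computed in code not shown here:

library(glmnet)

# f0(x) = ybar is stored in the "init" component of lsboost()'s output
f0 <- als.bst$init

# LASSO on the individual tree predictions, with f0 supplied as an offset
cvfit <- cv.glmnet(tree.preds.trn, y = als.trn$dFRS,
                   offset = rep(f0, nrow(tree.preds.trn)), nfolds = 5)

# The offset must also be supplied when generating predictions
pred <- predict(cvfit, newx = tree.preds.tst,
                newoffset = rep(f0, nrow(tree.preds.tst)), s = "lambda.min")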
The results are displayed in Figure 8.12, which contains the coefficient paths as a function of the L1 penalty as λ varies. The top axis indicates the number of nonzero coefficients (i.e., number of trees) at the current value of λ. Here, the smallest test error for the LASSO-based post-processed GBM is 25.9% and corresponds to 84 trees; see Figure 8.13. The post-processing has significantly reduced the number of trees in this example, resulting in a substantially more parsimonious model while maintaining accuracy. Sweet!
library(treemisc)
FIGURE 8.13: Test MSE as a function of the number of trees from the full
GBM (black curve) and LASSO-based post-processed results (yellow curve).
Here, we can see that by re-weighting the trees using an L1 penalty, which enables some of the trees to be dropped entirely, we end up with a smaller, more parsimonious model without degrading performance on the test set.
In this example, I’ll revisit the bank marketing data I analyzed back in Sec-
tion 7.9.5 using an RF in Spark. Here, I’ll fit a GBM using the scalable XG-
Boost library and show the benefits to early stopping. For brevity, I’ll omit
the necessary code already shown in Section 7.9.5. To that end, assume we
already have the full data set loaded into a data frame called bank.
Next, similar to the RF-based analysis, I’ll clean up some of the columns and
column names. Since XGBoost requires all the data to be numericp , we have
to re-encode the categorical features. The binary variables I’ll convert to 0/1,
while categorical variables with higher cardinality will be transformed using
one-hot encoding (OHE)q . First, I’ll deal with the binary variables and column
names:
names(bank) <- gsub("\\.", replacement = "_", x = names(bank))
bank$y <- ifelse(bank$y == "yes", 1, 0)
bank$contact <- ifelse(bank$contact == "telephone", 1, 0)
bank$duration <- NULL # remove target leakage
Next, I’ll deal with the one-hot encoding. There are several packages that can
help with this (e.g., caret’s dummyVars() function); I’ll do the transforma-
tion using pure data.table. The code below identifies the remaining categor-
ical variables (cats) and uses data.table’s melt() and dcast() functions
to handle the heavy lifting; the left hand side of the generated formula (fo)
tells dcast() which variables to not OHE (i.e., the binary and non-categorical
features):
bank$id <- seq_len(nrow(bank)) # need a unique row identifier
cats <- names(which(sapply(bank, FUN = class) == "character"))
lhs <- paste(setdiff(names(bank), cats), collapse = "+")
fo <- as.formula(paste(lhs, "~ variable + value"))
bank <- as.data.table(bank) # coerce to data.table
bank.ohe <- dcast(melt(bank, id.vars = setdiff(names(bank), cats)),
formula = fo, fun = length)
bank$id <- bank.ohe$id <- NULL
Now that we have the data encoded properly for XGBoost, we can split the
data into train/test sets using the same 50/50 split as before:
set.seed(1056) # for reproducibility
trn.id <- caret::createDataPartition(bank.ohe$y, p = 0.5, list = FALSE)
bank.trn <- data.matrix(bank.ohe[trn.id, ]) # training data
bank.tst <- data.matrix(bank.ohe[-trn.id, ]) # test data
p While XGBoost has limited (and experimental) support for categorical features, this
does not seem to be accessible via the R interface, at least at the time of writing this book.
q Several of the categorical features are technically ordinal (e.g., day_of_week) and should arguably be encoded in a way that respects that ordering; for simplicity, I treat them all as nominal here.
XGBoost does not work with R data frames. The xgb.train() function,
in particular, only accepts data as an "xgb.DMatrix" object. An XGBoost
DMatrix is an internal data structure used by XGBoost, which is optimized
for both memory efficiency and training speed; see ?xgboost::xgb.DMatrix
for details. We can create such an object using the xgboost function
xgb.DMatrix() (note that I separate the predictors and response in the calls
to xgb.DMatrix()):
library(xgboost)
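The construction code is not shown above; a minimal sketch, separating the predictors from the response column y, is given below:

# Separate the predictors from the response and build the DMatrix objects
xnames <- setdiff(colnames(bank.trn), "y")
dm.trn <- xgb.DMatrix(data = bank.trn[, xnames], label = bank.trn[, "y"])
dm.tst <- xgb.DMatrix(data = bank.tst[, xnames], label = bank.tst[, "y"])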
Finally, I can fit an XGBoost model. I’ll fit two in total, one without early
stopping and one with, starting with the no early stopping version below.
But first, I’ll define a “watch list,” which is just a named list of data sets to
use for evaluating model performance after each iteration that we can use to
determine the optimal number of trees (k-fold cross-validation could also be
used via xgboost’s xgb.cv() function):
watch.list <- list(train = dm.trn, eval = dm.tst)
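The parameter list and the first (no early stopping) call to xgb.train() are omitted above; the sketch below is consistent with the later code, but the specific parameter values (objective, evaluation metric, learning rate, and tree depth) are assumptions:

# Assumed parameters; RMSE on the predicted probabilities is the square
# root of the Brier score (cf. Figure 8.14)
params <- list(objective = "binary:logistic", eval_metric = "rmse",
               eta = 0.01, max_depth = 3)

set.seed(1100)  # for reproducibility
bank.xgb.1 <- xgb.train(params, data = dm.trn, nrounds = 3000, verbose = 0,
                        watchlist = watch.list)

# Optimal number of trees according to the test set RMSE
best.iter <- which.min(bank.xgb.1$evaluation_log$eval_rmse)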
Out of 3,000 total iterations, we really only needed to build 1,296 trees; fitting the full 3,000 can be expensive for large data sets (regardless of which fancy scalable implementation you use). While XGBoost is incredibly efficient, it is still wasteful to fit more trees than necessary. To that end, I can turn on early stopping (Section 8.3.1.1) to halt training once it detects
the potential for overfitting. In XGBoost, early stopping will halt training
if model performance has not improved for a specified number of iterations
(early_stopping_rounds).
In the code chunk below, I fit the same model (random seed and all), but tell
XGBoost to stop the training process if the performance on the test set (as
specified in the watch list) has not improved for 150 consecutive iterations
(5% of the total number of requested iterations)r :
set.seed(1100) # for reproducibility
(bank.xgb.2 <-
xgb.train(params, data = dm.trn, nrounds = 3000, verbose = 0,
watchlist = watch.list, early_stopping_rounds = 150))
In this case, using early stopping resulted in the same optimal number of trees (i.e., 1,296), but only required 1,446 boosting iterations (or trees) in total,
a decent savings in terms of both computation time and storage space (1.7
Mb for early stopping compared to 3.6 Mb for the full model)! The overall
training results are displayed in Figure 8.14 below.
palette("Okabe-Ito")
plot(bank.xgb.1$evaluation_log[, c(1, 2)], type = "l",
xlab = "Number of trees",
ylab = "RMSE (square root of Brier score)")
lines(bank.xgb.1$evaluation_log[, c(1, 3)], type = "l", col = 2)
abline(v = best.iter, col = 2, lty = 2)
abline(v = bank.xgb.2$niter, col = 3, lty = 2)
legend("topright", legend = c("Train", "Test"), inset = 0.01, bty = "n",
lty = 1, col = 1:2)
palette("default")
r Some sources suggest using 10% of the total number of requested iterations, but provide no evidence or citations as to why.
FIGURE 8.14: RMSE (essentially, the square root of the Brier score) from
an XGBoost model fit to the bank marketing data. According to the inde-
pendent test set (yellow curve), the optimal number of trees is 1,296 (verti-
cal dashed yellow line). Early stopping, which reached the same conclusion,
would’ve stopped training at 1,446 trees (vertical dashed blue line), which
roughly halves the training time in this case.
Below, I compute Tree SHAP values for the entire training set and use that
to form Shapley-based variable importance scores; here, I’ll follow Lundberg
et al. [2020] and the shap module’s approach by computing the mean abso-
lute Shapley value for each column. (Note that it is not necessary to use the
entire learning sample for doing this, and a large enough subsample should
often suffice, especially when dealing with hundreds of thousands or millions
of records.) A dot plot of the top 10 Shapley-based importance scores is dis-
played in Figure 8.15. Note that I need to specify the optimal number of trees
(ntreelimit = best.iter) when calling predict():
shap.trn <- predict(bank.xgb.1, newdata = dm.trn, ntreelimit = best.iter,
predcontrib = TRUE, approxcontrib = FALSE)
shap.trn <- shap.trn[, -which(colnames(shap.trn) == "BIAS")]
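The aggregation and plotting code is not shown; following the mean absolute Shapley value approach described above, a simple sketch is:

# Shapley-based importance: mean absolute Shapley value for each feature
shap.vi <- sort(colMeans(abs(shap.trn)), decreasing = TRUE)

# Dot plot of the top 10 scores (largest at the top)
dotchart(rev(head(shap.vi, n = 10)), xlab = "mean(|Shapley value|)", pch = 19)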
FIGURE 8.15: Dot plot of the top 10 Shapley-based variable importance scores (mean absolute Shapley value) from the XGBoost model fit to the bank marketing data.
A Shapley dependence plot for age is displayed in Figure 8.16; such plots can be explored further by using another feature (or features) to help color the plot.
shap.age <- data.frame("age" = bank.trn[, "age"],
"shap" = shap.trn[, "age"])
FIGURE 8.16: Shapley dependence plot for age from the XGBoost model fit
to the bank marketing data; a nonparametric smooth is also shown (yellow
curve). Any point below the horizontal dashed blue line corresponds to a
negative contribution to the predicted outcome.
8.10 Final thoughts
GBMs are a powerful class of machine learning algorithms that can achieve
state-of-the-art performance, provided you train them properly. Due to the
existence of efficient libraries (like XGBoost and Microsoft’s LightGBM) and
the shallower nature of the individual trees, GBMs can also scale incredibly
well; see, for example, Pafka [2021]. For these reasons, GBMs are quite popular
in applied practice and are often used in the winning entries for many super-
vised learning competitions with tabular data sets. Just keep in mind that,
unlike RFs, GBMs are quite sensitive to several tuning parameters (e.g., the
learning rate and number of boosting iterations), and these models should be
carefully tuned (ideally, with some form of early stopping, especially if you’re
working with a large learning sample, using a fairly small learning rate with a
large number of boosting iterations, and/or tuning lots of parameters).
Bibliography
Michel Ballings and Dirk Van den Poel. rotationForest: Fit and Deploy Rota-
tion Forest Models, 2017. URL https://CRAN.R-project.org/package=
rotationForest. R package version 0.1.3.
Richard A. Berk. Statistical Learning from a Regression Perspective. Springer
Series in Statistics. Springer New York, 2008. ISBN 9780387775012.
Gérard Biau, Luc Devroye, and Gábor Lugosi. Consistency of random forests
and other averaging classifiers. Journal of Machine Learning Research, 9
(66):2015–2033, 2008. URL http://jmlr.org/papers/v9/biau08a.html.
Przemyslaw Biecek and Hubert Baniecki. ingredients: Effects and Importances
of Model Ingredients, 2021. URL https://CRAN.R-project.org/package=
ingredients. R package version 2.2.0.
Przemyslaw Biecek and Tomasz Burzykowski. Explanatory Model Analysis.
Chapman and Hall/CRC, New York, 2021. ISBN 9780367135591. URL
https://pbiecek.github.io/ema/.
Przemyslaw Biecek, Alicja Gosiewska, Hubert Baniecki, and Adam Izdeb-
ski. iBreakDown: Model Agnostic Instance Level Variable Attributions,
2021. URL https://CRAN.R-project.org/package=iBreakDown. R pack-
age version 2.0.1.
Rico Blaser and Piotr Fryzlewicz. Random rotation ensembles. Journal of
Machine Learning Research, 17(4):1–26, 2016. URL http://jmlr.org/
papers/v17/blaser16a.html.
Bradley Boehmke and Brandon Greenwell. Hands-On Machine Learning
with R. Chapman & Hall/CRC the R series. CRC Press, 2020. ISBN
9781138495685.
Leo Breiman. Bagging predictors. Machine Learning, 24(2):123–140, 1996a.
URL https://doi.org/10.1007/BF00058655.
Leo Breiman. Heuristics of instability and stabilization in model selection.
The Annals of Statistics, 24(6):2350–2383, 1996b.
Leo Breiman. Technical note: Some properties of splitting criteria. Machine
Learning, 24:41–47, 1996c.
Leo Breiman. Pasting small votes for classification in large databases and on-
line. Machine Learning, 36(1):85–103, 1999. doi: 10.1023/A:1007563306331.
URL https://doi.org/10.1023/A:1007563306331.
Leo Breiman. Random forests. Machine Learning, 45(1):5–32, 2001. URL
https://doi.org/10.1023/A:1010933404324.
Leo Breiman. Manual on setting up, using, and understanding random forests
v3.1. Technical report, 2002. URL https://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf.
Paulo Cortez, António Cerdeira, Fernando Almeida, Telmo Matos, and José
Reis. Modeling wine preferences by data mining from physicochemical
properties. Decision Support Systems, 47(4):547–553, 2009. ISSN 0167-
9236. doi: https://doi.org/10.1016/j.dss.2009.05.016. URL http://www.
sciencedirect.com/science/article/pii/S0167923609001377.
Mark Culp, Kjell Johnson, and George Michailidis. ada: The R Package
Ada for Stochastic Boosting, 2016. URL https://CRAN.R-project.org/
package=ada. R package version 2.0-5.
Adele Cutler. Remembering Leo Breiman. The Annals of Applied Statistics,
4(4):1621–1633, 2010. doi: 10.1214/10-AOAS427. URL https://doi.org/
10.1214/10-AOAS427.
Natalia da Silva, Dianne Cook, and Eun-Kyung Lee. A projection pursuit forest algorithm for supervised classification. Journal of Computational and Graphical Statistics, 0(0):1–13, 2021a. doi: 10.1080/10618600.2020.1870480. URL https://doi.org/10.1080/10618600.2020.1870480.
Natalia da Silva, Eun-Kyung Lee, and Di Cook. PPforest: Projection Pursuit
Classification Forest, 2021b. URL https://github.com/natydasilva/
PPforest. R package version 0.1.2.
Data Mining Group. Predictive model markup language, 2014. URL http:
//www.dmg.org/. Version 4.2.
Jesse Davis and Mark Goadrich. The relationship between precision-recall
and roc curves. In Proceedings of the 23rd International Conference on
Machine Learning, ICML ’06, pages 233–240, New York, NY, USA, 2006.
Association for Computing Machinery. ISBN 1595933832. doi: 10.1145/
1143844.1143874. URL https://doi.org/10.1145/1143844.1143874.
Anthony C. Davison and David V. Hinkley. Bootstrap Methods and Their
Application. Cambridge Series in Statistical and Probabilistic Mathematics.
Cambridge University Press, 1997. ISBN 9780521574716.
Dean De Cock. Ames, Iowa: Alternative to the Boston housing data as an end of semester regression project. Journal of Statistics Education, 19(3), 2011. doi: 10.1080/10691898.2011.11889627. URL https://doi.org/10.1080/10691898.2011.11889627.
Luc Devroye, László Györfi, and Gábor Lugosi. A Probabilistic Theory of Pat-
tern Recognition. Stochastic Modelling and Applied Probability. Springer
New York, 1997. ISBN 9780387946184.
Stephan Dlugosz. rpart.LAD: Least Absolute Deviation Regression Trees,
2020. URL https://CRAN.R-project.org/package=rpart.LAD. R pack-
age version 0.1.2.
Rémi Domingues, Maurizio Filippone, Pietro Michiardi, and Jihane Zouaoui.
A comparative evaluation of outlier detection algorithms: Experiments and
analyses. Pattern Recognition, 74:406–421, 2018. ISSN 0031-3203. doi:
https://doi.org/10.1016/j.patcog.2017.09.037. URL https://doi.org/10.
1016/j.patcog.2017.09.037.
Lisa Doove, Stef Van Buuren, and Elise Dusseldorp. Recursive partitioning
for missing data imputation in the presence of interaction effects. Com-
putational Statistics & Data Analysis, 72:92–104, 2014. ISSN 0167-9473.
doi: 10.1016/j.csda.2013.10.025. URL https://doi.org/10.1016/j.csda.
2013.10.025.
Anna Veronika Dorogush, Vasily Ershov, and Andrey Gulin. Catboost:
gradient boosting with categorical features support, 2018. URL https:
//arxiv.org/abs/1810.11363.
Matt Dowle and Arun Srinivasan. data.table: Extension of ‘data.frame‘,
2021. URL https://CRAN.R-project.org/package=data.table. R pack-
age version 1.14.2.
Tony Duan, Anand Avati, Daisy Yi Ding, Khanh K. Thai, Sanjay Basu, An-
drew Y. Ng, and Alejandro Schuler. Ngboost: Natural gradient boosting for
probabilistic prediction, 2020. URL https://arxiv.org/abs/1910.03225.
Bradley Efron and Trevor Hastie. Computer Age Statistical Inference:
Algorithms, Evidence, and Data Science. Institute of Mathematical
Statistics Monographs. Cambridge University Press, 2016. doi: 10.1017/
CBO9781316576533.
Ad Feelders. Handling missing data in trees: Surrogate splits or statistical
imputation? In Principles of Data Mining and Knowledge Discovery, pages
329–334, 03 2000. ISBN 978-3-540-66490-1. doi: 10.1007/978-3-540-48247-
5_38. URL https://doi.org/10.1007/978-3-540-48247-5_38.
Aaron Fisher, Cynthia Rudin, and Francesca Dominici. All models are wrong,
but many are useful: Learning a variable’s importance by studying an entire
class of prediction models simultaneously. 2018. doi: 10.48550/ARXIV.1801.
01489. URL https://arxiv.org/abs/1801.01489.
Thomas R. Fleming and David P. Harrington. Counting Processes and Sur-
vival Analysis. Wiley Series in Probability and Statistics. Wiley, 1991. ISBN
9780471522188.
Christopher Flynn. Python bindings for C++ ranger random forests, 2021.
URL https://github.com/crflynn/skranger. Python package version
0.3.2.
Yoav Freund and Robert E. Schapire. Experiments with a new boosting
algorithm. In Proceedings of the Thirteenth International Conference on
Machine Learning, ICML’96, pages 148–156, San Francisco, CA, USA, 1996.
Morgan Kaufmann Publishers Inc. ISBN 1558604197.
Peter W. Frey and David J. Slate. Letter recognition using Holland-style
adaptive classifiers. Machine Learning, 6(2):161–182, 1991. URL https:
//doi.org/10.1007/BF00114162.
Jerome Friedman, Trevor Hastie, Rob Tibshirani, Balasubramanian
Narasimhan, Kenneth Tay, Noah Simon, and James Yang. glmnet: Lasso
and Elastic-Net Regularized Generalized Linear Models, 2021. URL https:
//CRAN.R-project.org/package=glmnet. R package version 4.1-3.
Jerome H. Friedman. Multivariate adaptive regression splines. The Annals of
Statistics, 19(1):1–67, 03 1991. doi: 10.1214/aos/1176347963. URL https:
//doi.org/10.1214/aos/1176347963.
Jerome H. Friedman. Greedy function approximation: A gradient boosting
machine. The Annals of Statistics, 29(5):1189–1232, 2001. URL https:
//doi.org/10.1214/aos/1013203451.
Jerome H. Friedman. Stochastic gradient boosting. Computational Statis-
tics & Data Analysis, 38(4):367–378, 2002. ISSN 0167-9473. doi: https:
//doi.org/10.1016/S0167-9473(01)00065-2. URL https://doi.org/10.
1016/S0167-9473(01)00065-2.
Jerome H. Friedman and Peter Hall. On bagging and nonlinear es-
timation. Journal of Statistical Planning and Inference, 137(3):669–
683, 2007. ISSN 0378-3758. doi: https://doi.org/10.1016/j.jspi.2006.
06.002. URL http://www.sciencedirect.com/science/article/pii/
S0378375806001339. Special Issue on Nonparametric Statistics and Re-
lated Topics: In honor of M.L. Puri.
Jerome H. Friedman and Bogdan E. Popescu. Importance sampled learning
ensembles. Technical report, Stanford University, Department of Statistics,
2003. URL https://statweb.stanford.edu/~jhf/ftp/isle.pdf.
Jerome H. Friedman and Bogdan E. Popescu. Predictive learning via rule
ensembles. The Annals of Applied Statistics, 2(3):916–954, 2008. ISSN
19326157. URL https://doi.org/10.2307/30245114.
Torsten Hothorn, Kurt Hornik, Mark A. van de Wiel, and Achim Zeileis. A
lego system for conditional inference. The American Statistician, 60(3):
257–263, 2006b. URL https://doi.org/10.1198/000313006X118430.
Torsten Hothorn, Kurt Hornik, and Achim Zeileis. Unbiased recursive parti-
tioning: A conditional inference framework. Journal of Computational and
Graphical Statistics, 15(3):651–674, 2006c.
Yoon Dong Lee, Dianne Cook, Ji won Park, and Eun-Kyung Lee. Pptree:
Projection pursuit classification tree. Electronic Journal of Statistics, 7:
Wei-Yin Loh and Wei Zheng. Regression trees for longitudinal and multire-
sponse data. The Annals of Applied Statistics, 7(1):495–522, 2013. doi:
10.1214/12-AOAS596. URL https://doi.org/10.1214/12-AOAS596.
Wei-Yin Loh and Peigen Zhou. The GUIDE Approach to Subgroup Identi-
fication, pages 147–165. Springer International Publishing, Cham, 2020.
ISBN 978-3-030-40105-4. doi: 10.1007/978-3-030-40105-4_6. URL https:
//doi.org/10.1007/978-3-030-40105-4_6.
Wei-Yin Loh and Peigen Zhou. Variable importance scores. Journal of Data
Science, 19(4):569–592, 2021. ISSN 1680-743X. doi: 10.6339/21-JDS1023.
URL https://doi.org/10.6339/21-JDS1023.
Wei-Yin Loh, Xu He, and Michael Man. A regression tree approach to identi-
fying subgroups with differential treatment effects. Statistics in Medicine,
34(11):1818–1833, 2015. URL https://doi.org/10.1002/sim.6454.
Wei-Yin Loh, John Eltinge, Moon Jung Cho, and Yuanzhi Li. Classification
and regression trees and forests for incomplete data from sample surveys.
Statistica Sinica, 29(1):431–453, 2019. ISSN 10170405, 19968507. doi: 10.
5705/ss.202017.0225. URL https://doi.org/10.5705/ss.202017.0225.
Wei-Yin Loh, Qiong Zhang, Wenwen Zhang, and Peigen Zhou. Missing data, imputation and regression trees. Statistica Sinica, 30:1697–1722, 2020.
Scott M. Lundberg and Su-In Lee. A unified approach to interpreting model
predictions. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fer-
gus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Infor-
mation Processing Systems 30, pages 4765–4774. Curran Associates, Inc.,
2017. URL http://papers.nips.cc/paper/7062-a-unified-approach-
to-interpreting-model-predictions.pdf.
Scott M. Lundberg, Gabriel Erion, Hugh Chen, Alex DeGrave, Jordan M.
Prutkin, Bala Nair, Ronit Katz, Jonathan Himmelfarb, Nisha Bansal, and
Su-In Lee. From local explanations to global understanding with explainable
ai for trees. Nature Machine Intelligence, 2(1):2522–5839, 2020.
Javier Luraschi, Kevin Kuo, Kevin Ushey, J. J. Allaire, Hossein Falaki,
Lu Wang, Andy Zhang, Yitao Li, and The Apache Software Founda-
tion. sparklyr: R Interface to Apache Spark, 2021. URL https://spark.
rstudio.com/. R package version 1.7.3.
Szymon Maksymiuk, Alicja Gosiewska, and Przemyslaw Biecek. Landscape
of r packages for explainable artificial intelligence, 2021. URL https://
arxiv.org/abs/2009.13248.
James D. Malley, Jochen Kruppa, Abhijit Dasgupta, Karen Godlove Malley, and Andreas Ziegler. Probability machines: consistent probability estimation using nonparametric learning machines. Methods of Information in Medicine.
Lloyd S. Shapley. 17. A Value for n-Person Games, pages 307–318. Princeton
University Press, 2016. URL https://doi.org/10.1515/9781400881970-
018.
Haijian Shi. Best-first Decision Tree Learning. Master's thesis, Hamilton, New Zealand, 2007. URL https://hdl.handle.net/10289/2317.
Tao Shi, David Seligson, Arie Belldegrun, Aarno Palotie, and Steve Horvath.
Tumor classification by tissue microarray profiling: random forest cluster-
ing applied to renal cell carcinoma. Modern Pathology, 18:547–57, 05
2005. doi: 10.1038/modpathol.3800322. URL https://doi.org/10.1038/
modpathol.3800322.
Yu Shi, Guolin Ke, Damien Soukhavong, James Lamb, Qi Meng, Thomas
Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, Tie-Yan Liu,
and Nikita Titov. lightgbm: Light Gradient Boosting Machine, 2022. URL
https://github.com/Microsoft/LightGBM. R package version 3.3.2.
Julia Silge, Fanny Chow, Max Kuhn, and Hadley Wickham. rsample: Gen-
eral Resampling Infrastructure, 2021. URL https://CRAN.R-project.
org/package=rsample. R package version 0.1.1.
Nora Sleumer. Hyperplane arrangements: construction, visualization and applications. PhD thesis, Swiss Federal Institute of Technology, 1969.
Helmut Strasser and Christian Weber. On the asymptotic theory of permu-
tation statistics. Mathematical Methods of Statistics, 2(27), 1999.
Carolin Strobl, Anne-Laure Boulesteix, Achim Zeileis, and Torsten Hothorn.
Bias in random forest variable importance measures: Illustrations, sources
and a solution. BMC Bioinformatics, 8(25), 2007a. doi: 10.1186/1471-2105-
8-25.
Carolin Strobl, Anne-Laure Boulesteix, Achim Zeileis, and Torsten Hothorn.
Bias in random forest variable importance measures: Illustrations, sources
and a solution. BMC Bioinformatics, 8(25), 2007b. URL https://doi.
org/10.1186/1471-2105-8-25.
Carolin Strobl, Anne-Laure Boulesteix, Thomas Kneib, Thomas Augustin,
and Achim Zeileis. Conditional variable importance for random forests.
BMC Bioinformatics, 9(307), 2008a. doi: 10.1186/1471-2105-9-307.
Carolin Strobl, Anne-Laure Boulesteix, Thomas Kneib, Thomas Augustin, and
Achim Zeileis. Conditional variable importance for random forests. BMC
Bioinformatics, 9(307), 2008b. URL https://doi.org/10.1186/1471-
2105-9-307.
The Pandas Development Team. pandas-dev/pandas: Pandas, February 2020.
URL https://doi.org/10.5281/zenodo.3509134.
Terry Therneau and Beth Atkinson. rpart: Recursive Partitioning and Regres-
sion Trees, 2019. URL https://CRAN.R-project.org/package=rpart. R
package version 4.1-15.
Terry M. Therneau. survival: Survival Analysis, 2021. URL https://github.
com/therneau/survival. R package version 3.2-13.
Julie Tibshirani, Susan Athey, Erik Sverdrup, and Stefan Wager. grf: Gener-
alized Random Forests, 2021. URL https://github.com/grf-labs/grf.
R package version 2.0.2.
Robert Tibshirani. Regression shrinkage and selection via the lasso. Journal
of the Royal Statistical Society. Series B (Methodological), 58(1):267–288,
1996. ISSN 00359246. URL http://www.jstor.org/stable/2346178.
Stef van Buuren. Flexible Imputation of Missing Data. Chapman &
Hall/CRC Interdisciplinary Statistics. CRC Press, Taylor & Francis Group,
2018. ISBN 9781138588318. URL https://books.google.com/books?id=
bLmItgEACAAJ.
Stef van Buuren and Karin Groothuis-Oudshoorn. mice: Multivariate Impu-
tation by Chained Equations, 2021. URL https://CRAN.R-project.org/
package=mice. R package version 3.14.0.
Mark J. van der Laan. Statistical inference for variable importance. The
International Journal of Biostatistics, 2(1), 2006. URL https://doi.org/
10.2202/1557-4679.1008.
Shivaram Venkataraman, Zongheng Yang, Davies Liu, Eric Liang, Hossein
Falaki, Xiangrui Meng, Reynold Xin, Ali Ghodsi, Michael Franklin, Ion
Stoica, and Matei Zaharia. Sparkr: Scaling r programs with spark. In
Proceedings of the 2016 International Conference on Management of Data,
SIGMOD ’16, pages 1099–1104, New York, NY, USA, 2016. Association
for Computing Machinery. ISBN 9781450335317. doi: 10.1145/2882903.
2903740. URL https://doi.org/10.1145/2882903.2903740.
James Verbus. Detecting and preventing abuse on linkedin using isola-
tion forests, Aug. 2019. URL https://engineering.linkedin.com/blog/
2019/isolation-forest.
Erik Štrumbelj and Igor Kononenko. Explaining prediction models and indi-
vidual predictions with feature contributions. Knowledge and Information
Systems, 31(3):647–665, 2014. URL https://doi.org/10.1007/s10115-
013-0679-x.
Stefan Wager, Trevor Hastie, and Bradley Efron. Confidence intervals for
random forests: The jackknife and the infinitesimal jackknife. Journal of
Machine Learning Research, 15(48):1625–1651, 2014. URL http://jmlr.
org/papers/v15/wager14a.html.
Ian R. White, Patrick Royston, and Angela M. Wood. Multiple imputation
using chained equations: Issues and guidance for practice. Statistics in
Medicine, 30(4):377–399, 2011. doi: 10.1002/sim.4067. URL https://doi.
org/10.1002/sim.4067.
Hadley Wickham. Advanced R, Second Edition. Chapman & Hall/CRC The
R Series. CRC Press, 2019. ISBN 9781351201308. URL https://adv-
r.hadley.nz/.
Hadley Wickham and Jennifer Bryan. readxl: Read Excel Files, 2019. URL
https://CRAN.R-project.org/package=readxl. R package version 1.3.1.
Hadley Wickham, Winston Chang, Lionel Henry, Thomas Lin Pedersen,
Kohske Takahashi, Claus Wilke, Kara Woo, Hiroaki Yutani, and Dewey
Dunnington. ggplot2: Create Elegant Data Visualisations Using the Gram-
mar of Graphics, 2021a. URL https://CRAN.R-project.org/package=
ggplot2. R package version 3.3.5.
Hadley Wickham, Romain François, Lionel Henry, and Kirill Müller. dplyr: A
Grammar of Data Manipulation, 2021b. URL https://CRAN.R-project.
org/package=dplyr. R package version 1.0.7.
Edwin B. Wilson and Margaret M. Hilferty. The distribution of chi-square.
Proceedings of the National Academy of Sciences of the United States of
America, 17(12):684–688, 1931.
Marvin N. Wright, Stefan Wager, and Philipp Probst. ranger: A Fast Im-
plementation of Random Forests, 2021. URL https://github.com/imbs-
hl/ranger. R package version 0.13.1.
Paul S. Wright. Adjusted p-values for simultaneous inference. Biometrics,
48(4):1005–1013, 1992. doi: 10.2307/2532694. URL https://doi.org/10.
2307/2532694.
Yihui Xie. knitr: A General-Purpose Package for Dynamic Report Generation
in R, 2021. URL https://yihui.org/knitr/. R package version 1.36.
Ruo Xu, Dan Nettleton, and Daniel J. Nordman. Case-specific random forests. Journal of Computational and Graphical Statistics, 25(1):49–65, 2016.
I-Cheng Yeh and Che hui Lien. The comparisons of data mining techniques
for the predictive accuracy of probability of default of credit card clients.
Expert Systems with Applications, 36(2, Part 1):2473–2480, 2009. doi: https:
//doi.org/10.1016/j.eswa.2007.12.020. URL https://doi.org/10.1016/
j.eswa.2007.12.020.
Achim Zeileis, Friedrich Leisch, Kurt Hornik, and Christian Kleiber. struc-
change: Testing, Monitoring, and Dating Structural Changes, 2019. URL
https://CRAN.R-project.org/package=strucchange. R package version
1.5-2.
Haozhe Zhang, Joshua Zimmerman, Dan Nettleton, and Daniel J. Nordman.
Random forest prediction intervals. The American Statistician, 74(4):392–
406, 2020. doi: 10.1080/00031305.2019.1585288. URL https://doi.org/
10.1080/00031305.2019.1585288.
Heping Zhang and Burton H. Singer. Recursive Partitioning and Applications.
Springer New York, New York, NY, 2010. ISBN 978-1-4419-6824-1. URL
https://doi.org/10.1007/978-1-4419-6824-1_3.
Huan Zhang, Si Si, and Cho-Jui Hsieh. Gpu-acceleration for large-scale tree
boosting, 2017. URL https://arxiv.org/abs/1706.08359.
Index
glmnet, 196, 199, 200
grDevices, xv
graphics, 20
grf, 257
gridExtra, 341
h2o, 244, 252, 277
iBreakDown, 223
iml, 206, 207, 211, 223, 226, 292
ingredients, 206, 211
ipred, 188, 197
isotree, 272, 275
kernlab, 25, 184
knitr, 17
lattice, 20, 213, 345
lightgbm, 220, 339
mboost, 339
mice, 81, 253, 285
microbenchmark, 198, 278
mlbench, 24, 103
mmpf, 206
modeldata, 28
mosaic, 345
obliqueRF, 259
partykit, 88, 128, 132–134, 136, 138, 139, 237, 247, 249, 277, 283, 324, 325
party, 128, 132–135, 237, 247, 249, 277
pdp, 98, 99, 211, 214, 216, 223, 289, 290, 329, 341, 343
purrr, 18
randomForestSRC, 276
randomForest, 197, 211, 215, 266, 276, 277, 286
ranger, 247, 248, 250, 256, 257, 269, 276–278, 280, 287, 289, 292
regtools, 104
rms, 296
rotationForest, 263, 266
rpart.LAD, 84
rpart.plot, 88
rpartScore, 84
rpart, 19, 27, 39, 40, 42, 55, 59, 63, 67, 73, 75, 77, 80, 83–85, 88, 89, 91–96, 100, 102, 107, 108, 136, 188, 190, 194, 237, 298, 324, 325
rsample, 301
sparkR, 301
sparklyr, 213, 277, 300, 304
strucchange, 134
survival, 31, 34, 140, 340
titanic, 283
treemisc, xvi, 20, 24, 25, 30, 84, 88, 90, 96, 138, 152, 199, 251, 263, 265, 266, 279, 320, 323, 325, 348
treevalues, 84
tree, 83
utils, 19
vip, 206, 207
waterfall, 344
xgboost, 220, 339, 352
random forest, 36, 159, 164, 182
random forest (RF), 229, 231–234, 236, 237, 239, 243–269, 276–280, 283–287, 289, 293–300, 302, 303, 306, 307, 309, 317–323, 327–329, 331, 333, 337, 347, 348, 351, 357
random subspace, 306
random variable, 4
regression, 7
regular expression, 303
reliability analysis, 32
ridge regression, 336
rotation forests, 261
rotation matrix, 261, 262, 264
rule-based models, 36
rules, 28
sensitivity, 172
Shapley value, 217–221, 223, 226, 275, 292
target leakage, 12
THAID, see theta automatic interaction detection
theta automatic interaction detection, 13
time-to-event, 32
true positive rate, 172
twenty questions, 1
two-sample test, 115, 126, 133
type I error, 128, 130