Decision trees are a great way to learn patterns in data. They are extremely interpretable, which makes them a useful tool for communicating findings in data science applications. Moreover, extensions of these simpler tree-based methods exist that are competitive in prediction accuracy. We begin with the basics through an example.
Regression and Classification Trees
Assume we are trying to predict the salary of a baseball player based on two
characteristics: the number of years they played in the major leagues, and the
number of hits they had last season. A possible tree that we might end up
with can be seen in the figure below. There are two places where the tree is split:
Years<4.5 and Hits<117.5. These are the tree's internal nodes, and we call these split conditions 'splitting rules'.
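To make this concrete, here is a minimal sketch of how a tree like this could be fit with scikit-learn. The data below is synthetic and only stands in for the real baseball data, so the fitted splits will not match the figure exactly.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Synthetic stand-in for the baseball data: Years and Hits predicting log-salary.
rng = np.random.default_rng(0)
years = rng.integers(1, 20, size=200)             # years in the major leagues
hits = rng.integers(40, 200, size=200)            # hits last season
log_salary = np.log(50 + 20 * years + 0.5 * hits + rng.normal(0, 10, size=200))

X = np.column_stack([years, hits])
tree = DecisionTreeRegressor(max_depth=2).fit(X, log_salary)

# Print the learned splitting rules and leaf predictions.
print(export_text(tree, feature_names=["Years", "Hits"]))
```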
How exactly do we make a prediction given a new observation? Consider a
player named Bob who has played 6 years in the major leagues and got 96 hits
last season. According to the tree, the first split checks whether the number
of years this player has played in the major leagues is less than 4.5. If it is
less than 4.5, we take the left branch; otherwise, we take the right (you could swap the left/right directions as long as you are consistent). Since Bob has played more than 4.5 years, we take the right branch. Now we arrive at another splitting rule, which asks whether the number of hits this player got last season is less than 117.5. Bob had 96 hits last season; thus, we travel to the left and arrive at a terminal, or leaf, node, from which we make our prediction. In this case, we predict Bob to make \(e^{6} \times 1000 \approx 403{,}429\) dollars.
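This traversal can be written out directly as nested conditionals. In the sketch below, the 6.00 leaf reproduces Bob's prediction from the text; the other two leaf values are placeholders for whatever the fitted tree stores in those leaves.

```python
import math

def predict_log_salary(years: float, hits: float) -> float:
    """Follow the splitting rules from the figure down to a leaf."""
    if years < 4.5:
        return 5.11      # R1: Years < 4.5 (placeholder leaf value)
    if hits < 117.5:
        return 6.00      # R2: Years >= 4.5 and Hits < 117.5
    return 6.74          # R3: Years >= 4.5 and Hits >= 117.5 (placeholder)

# Bob: 6 years, 96 hits -> right branch, then left branch -> R2.
# The tree predicts log(salary in $1000s), so convert back to dollars.
print(math.exp(predict_log_salary(6, 96)) * 1000)   # ~403,429
```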
The next important concept to address is how a tree partitions the predictor
space. Examining the above figure, we see three regions annotated by R1, R2, R3.
We denote the observations within each region as \(y_{R_i} = \{y_j : x_j \in R_i\}\). Each of these regions corresponds to a terminal node in the tree. Herein lies the key
difference between regression and classification trees: for regression trees, we
take a quantitative statistic such as the mean of all data points within the
region, whereas for classification trees, we take a statistic over qualitative data
such as the mode. For instance, a possible statistic for R1 is the mean of all observations that lie in R1, \(\hat{y}_{R_1} = \frac{1}{|y_{R_1}|}\sum_{y_j \in y_{R_1}} y_j\).
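As a small illustration with made-up numbers, a leaf turns the training observations in its region into a single prediction: the mean for a regression tree, the most common class for a classification tree.

```python
from collections import Counter
from statistics import mean

# Made-up responses whose predictors fall in region R1.
r1_quantitative = [4.9, 5.2, 5.1, 5.3]          # regression: predict the mean
r1_qualitative = ["low", "low", "high", "low"]  # classification: predict the mode

regression_prediction = mean(r1_quantitative)
classification_prediction = Counter(r1_qualitative).most_common(1)[0][0]

print(regression_prediction)       # 5.125
print(classification_prediction)   # 'low'
```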
Pruning
An important question to ask when building a tree is when to stop. In principle, a tree can have a terminal node for every single data point in the training set and fit that set perfectly; in this case, we will almost certainly overfit. Pruning is a technique that mitigates this
problem by decreasing the number of terminal nodes by removing branches.
Naturally, we should remove branches so that the error rate increases as little as possible. One method that accomplishes this is called cost complexity pruning, which minimizes the augmented loss function
$$\sum_{m=1}^{|T|} \sum_{x_i \in R_m} (y_i - \hat{y}_{R_m})^2 + \alpha |T|$$
where |T| is the number of terminal nodes in the tree, \(R_m\) is the region indexed
by \(m\), \(y_i\) are the training observations, \(\hat{y}_{R_m}\) is the estimate for region \(R_m\) and
\(\alpha\) is a non-negative tuning parameter. The additional \(\alpha |T|\) term penalizes trees with many terminal nodes, so a very deep tree incurs a higher loss even though it may fit the training set better.
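For a concrete, made-up example, consider a tree with \(|T| = 2\) terminal nodes and \(\alpha = 0.5\); the penalized loss is simply the residual sum of squares plus \(\alpha\) times the number of leaves.

```python
# Each region maps to (training responses in that region, leaf estimate).
regions = {
    "R1": ([5.0, 5.2, 4.8], 5.0),
    "R2": ([6.1, 5.9], 6.0),
}
alpha = 0.5

rss = sum((y - y_hat) ** 2 for ys, y_hat in regions.values() for y in ys)
penalized_loss = rss + alpha * len(regions)   # RSS + alpha * |T|

print(round(rss, 2), round(penalized_loss, 2))   # 0.1 and 0.1 + 0.5 * 2 = 1.1
```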
Since \(\alpha\) is not fixed, it is typical to learn it using cross-validation. This is accomplished through a grid search to identify values of \(\alpha\) that result in a sequence of successively shallower trees; cross-validation can then be performed on each tree in the sequence to determine which is optimal. See Figure 3 for an example of pruning.
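One way to carry out this search is scikit-learn's cost-complexity pruning path, which returns the sequence of \(\alpha\) values (each yielding a successively smaller subtree) to cross-validate. In this sketch, X and y stand for the training predictors and responses, assumed to be defined elsewhere.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import cross_val_score

def prune_by_cv(X, y, cv=5):
    """Grid-search the pruning alphas by cross-validation and refit the best tree."""
    path = DecisionTreeRegressor(random_state=0).cost_complexity_pruning_path(X, y)
    best_alpha, best_score = 0.0, -np.inf
    for alpha in path.ccp_alphas:
        tree = DecisionTreeRegressor(random_state=0, ccp_alpha=alpha)
        score = cross_val_score(tree, X, y, cv=cv).mean()   # default scoring: R^2
        if score > best_score:
            best_alpha, best_score = alpha, score
    return DecisionTreeRegressor(random_state=0, ccp_alpha=best_alpha).fit(X, y)
```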
Pros vs. Cons of Trees
As mentioned in the beginning, trees are easy to interpret. Furthermore, the
way trees make predictions is thought to mirror human decision-making. Trees
can also be displayed graphically, a desirable trait for presentations in real-world applications. Despite these advantages, trees tend to suffer from lower prediction accuracy than other methods, high variance, and a lack of robustness (a small change in the data can lead to a big change in the resulting tree). See this article if you are interested in learning about extensions to tree-based methods that address these problems.