Decision trees

This example applies R's decision tree tools to the iris data and does some simple visualization.

Iris data

    iris = read.table("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        sep = ",", header = FALSE)
    names(iris) = c("sepal.length", "sepal.width", "petal.length", "petal.width",
        "iris.type")
    attach(iris)

We will use the tree library:

    library(tree)

    ## Warning: package 'tree' was built under R version 2.15.1

    stree = tree(iris.type ~ ., data = iris)
    stree

    ## node), split, n, deviance, yval, (yprob)
    ##       * denotes terminal node
    ##
    ##  1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
    ##    2) petal.length < 2.45 50   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##    3) petal.length > 2.45 100 100 Iris-versicolor ( 0.00 0.50 0.50 )
    ##      6) petal.width < 1.75 54  30 Iris-versicolor ( 0.00 0.91 0.09 )
    ##       12) petal.length < 4.95 48  10 Iris-versicolor ( 0.00 0.98 0.02 )
    ##         24) sepal.length < 5.15 5   5 Iris-versicolor ( 0.00 0.80 0.20 ) *
    ##         25) sepal.length > 5.15 43   0 Iris-versicolor ( 0.00 1.00 0.00 ) *
    ##       13) petal.length > 4.95 6   8 Iris-virginica ( 0.00 0.33 0.67 ) *
    ##      7) petal.width > 1.75 46  10 Iris-virginica ( 0.00 0.02 0.98 )
    ##       14) petal.length < 4.95 6   5 Iris-virginica ( 0.00 0.17 0.83 ) *
    ##       15) petal.length > 4.95 40   0 Iris-virginica ( 0.00 0.00 1.00 ) *

We can also plot the tree. The second command adds the split and leaf labels to the plot.

    plot(stree)
    text(stree)

_images/trees_fig_00.png

By default, the tree function splits using the deviance (entropy); we can request the Gini index instead:

    stree = tree(iris.type ~ ., data = iris, split = "gini")
    stree

    ## node), split, n, deviance, yval, (yprob)
    ##       * denotes terminal node
    ##
    ##    1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
    ##      2) petal.length < 1.35 11   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##      3) petal.length > 1.35 139 300 Iris-versicolor ( 0.28 0.36 0.36 )
    ##        6) sepal.width < 2.35 7   6 Iris-versicolor ( 0.00 0.86 0.14 ) *
    ##        7) sepal.width > 2.35 132 300 Iris-virginica ( 0.30 0.33 0.37 )
    ##         14) sepal.width < 2.55 11  10 Iris-versicolor ( 0.00 0.64 0.36 )
    ##           28) petal.length < 4.25 6   0 Iris-versicolor ( 0.00 1.00 0.00 ) *
    ##           29) petal.length > 4.25 5   5 Iris-virginica ( 0.00 0.20 0.80 ) *
    ##         15) sepal.width > 2.55 121 300 Iris-virginica ( 0.32 0.31 0.37 )
    ##           30) petal.length < 1.45 12   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##           31) petal.length > 1.45 109 200 Iris-virginica ( 0.25 0.34 0.41 )
    ##             62) sepal.width < 2.65 5   7 Iris-versicolor ( 0.00 0.60 0.40 ) *
    ##             63) sepal.width > 2.65 104 200 Iris-virginica ( 0.26 0.33 0.41 )
    ##              126) sepal.width < 2.75 9  10 Iris-versicolor ( 0.00 0.56 0.44 ) *
    ##              127) sepal.width > 2.75 95 200 Iris-virginica ( 0.28 0.31 0.41 )
    ##                254) sepal.width < 2.85 14  20 Iris-virginica ( 0.00 0.43 0.57 )
    ##                  508) petal.length < 4.85 7   6 Iris-versicolor ( 0.00 0.86 0.14 ) *
    ##                  509) petal.length > 4.85 7   0 Iris-virginica ( 0.00 0.00 1.00 ) *
    ##                255) sepal.width > 2.85 81 200 Iris-virginica ( 0.33 0.28 0.38 )
    ##                  510) petal.length < 1.55 14   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##                  511) petal.length > 1.55 67 100 Iris-virginica ( 0.19 0.34 0.46 )
    ##                   1022) petal.length < 5.05 38  60 Iris-versicolor ( 0.34 0.61 0.05 )
    ##                     2044) petal.length < 2.75 13   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##                     2045) petal.length > 2.75 25  10 Iris-versicolor ( 0.00 0.92 0.08 )
    ##                       4090) petal.length < 4.75 20   0 Iris-versicolor ( 0.00 1.00 0.00 ) *
    ##                       4091) petal.length > 4.75 5   7 Iris-versicolor ( 0.00 0.60 0.40 ) *
    ##                   1023) petal.length > 5.05 29   0 Iris-virginica ( 0.00 0.00 1.00 ) *

    plot(stree)
    text(stree)

_images/trees_fig_01.png

We can also fit a decision tree using only two of the variables, which makes it easy to visualize the decision regions in two dimensions.

    stree = tree(iris.type ~ petal.width + petal.length, data = iris)
    stree

    ## node), split, n, deviance, yval, (yprob)
    ##       * denotes terminal node
    ##
    ##  1) root 150 300 Iris-setosa ( 0.33 0.33 0.33 )
    ##    2) petal.width < 0.8 50   0 Iris-setosa ( 1.00 0.00 0.00 ) *
    ##    3) petal.width > 0.8 100 100 Iris-versicolor ( 0.00 0.50 0.50 )
    ##      6) petal.width < 1.75 54  30 Iris-versicolor ( 0.00 0.91 0.09 )
    ##       12) petal.length < 4.95 48  10 Iris-versicolor ( 0.00 0.98 0.02 ) *
    ##       13) petal.length > 4.95 6   8 Iris-virginica ( 0.00 0.33 0.67 ) *
    ##      7) petal.width > 1.75 46  10 Iris-virginica ( 0.00 0.02 0.98 )
    ##       14) petal.length < 4.95 6   5 Iris-virginica ( 0.00 0.17 0.83 ) *
    ##       15) petal.length > 4.95 40   0 Iris-virginica ( 0.00 0.00 1.00 ) *

    plot(stree)
    text(stree)

_images/trees_fig_02.png

Here is a visualization of this two-dimensional decision boundary:

In[7]:

%load_ext rmagic
%R -d iris
from matplotlib import pyplot as plt, mlab
col = {1:'r', 2:'y', 3:'g'}
coln = {'Iris-setosa':'r', 'Iris-versicolor':'y', 'Iris-virginica':'g'}

In[8]:

iris.shape
plt.scatter(iris['petal.length'], iris['petal.width'], c=[col[t] for t in iris['iris.type']])
a = plt.gca()
a.set_xlabel('Petal length')
a.set_ylabel('Petal width')

# Here are the regions as described in R's plot above
# There are five terminal leaves, so there are five regions

xf, yf = mlab.poly_between([0,8],[-0.5,-0.5],[0.8,0.8])
plt.fill(xf, yf, coln['Iris-setosa'], alpha=0.3)

xf, yf = mlab.poly_between([0,4.95],[0.8,0.8],[1.75,1.75])
plt.fill(xf, yf, coln['Iris-versicolor'], alpha=0.3)

xf, yf = mlab.poly_between([4.95,8],[0.8,0.8],[1.75,1.75])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)

xf, yf = mlab.poly_between([4.95,8],[1.75,1.75],[3,3])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)

xf, yf = mlab.poly_between([0,4.95],[1.75,1.75],[3,3])
plt.fill(xf, yf, coln['Iris-virginica'], alpha=0.3)

a.set_xlim((0,8))
a.set_ylim((-0.5,3))

Out[8]:

(-0.5, 3)
_images/trees_fig_03.png
<matplotlib.figure.Figure at 0x67cfdd0>

Let's draw the decision boundaries another way, one that will be easier to reuse for later classifiers. Recall that a classifier assigns a label to each point in the feature space. We will evaluate the label on a dense grid and then display the labels as an image.

In[9]:

grid = np.mgrid[-0.5:3.5:500j,0:8:400j]
gridT = grid.reshape((2,-1)).T
gridT.shape

Out[9]:

(200000, 2)

Having created a grid, we will use R's predict function to evaluate the label on this grid.

In[10]:

%%R -i gridT -o labels
colnames(gridT) = c('petal.width', 'petal.length')
gridT = data.frame(gridT)
labels = predict(stree, gridT, type='class')

We underlay the image beneath the scatter plot.

In[11]:

plt.imshow(labels.reshape((500,400)), interpolation='nearest', origin='lower', alpha=0.4, extent=[0,8,-0.5,3], cmap=pylab.cm.RdYlGn)
plt.scatter(iris['petal.length'], iris['petal.width'], c=[col[t] for t in iris['iris.type']])
a = plt.gca()
a.set_xlim((0,8))
a.set_ylim((-0.5,3))

Out[11]:

(-0.5, 3)
_images/trees_fig_04.png
<matplotlib.figure.Figure at 0x67bfb30>

Training and test error

Finally, let's form a test set of size 50 and a training set of size 100. We will fit the tree on the training set and then evaluate the result on the test set.

    test_sample = sample(1:150, 50)
    test_data = iris[test_sample, ]
    training_data = iris[-test_sample, ]
    fit_tree = tree(iris.type ~ ., data = training_data)
    test_predictions = predict(fit_tree, test_data, type = "class")
    test_error = sum(test_predictions != test_data$iris.type)/nrow(test_data)
    test_error

    ## [1] 0.1

    training_error = sum(predict(fit_tree, type = "class") != training_data$iris.type)/nrow(training_data)
    training_error

    ## [1] 0.01

Usually, the test error is slightly higher than the training error, but this does not have to be the case.
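
To see how much the test error depends on the particular random split, we can repeat the split several times and average. This is a minimal sketch reusing the iris data and the tree function from above; the number of repetitions (20) is arbitrary.

    # Repeat the 100/50 split and record training and test error each time
    errors = replicate(20, {
        test_sample = sample(1:150, 50)
        fit = tree(iris.type ~ ., data = iris[-test_sample, ])
        c(training = mean(predict(fit, type = "class") != iris$iris.type[-test_sample]),
            test = mean(predict(fit, iris[test_sample, ], type = "class") != iris$iris.type[test_sample]))
    })
    rowMeans(errors)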

Voting data

Let’s try the tree classifier on our voting data:

    votes = read.table("http://stats202.stanford.edu/data/2011_cleaned_votes.csv",
        header = TRUE, sep = ";")
    dim(votes)

    ## [1] 426 948

    vtree = tree(party ~ ., data = votes)
    vtree

    ## node), split, n, deviance, yval, (yprob)
    ##       * denotes terminal node
    ##
    ## 1) root 426 600 R ( 0.4 0.6 )
    ##   2) numeric_vote675 < 0.5 189   0 D ( 1.0 0.0 ) *
    ##   3) numeric_vote675 > 0.5 237   0 R ( 0.0 1.0 ) *

    plot(vtree)
    text(vtree)

_images/trees_fig_05.png

Apparently, bill 675 is very partisan, achieving perfect separation. Here is the bill. We see that all 239 Republicans voted Y, 187 Democrats voted N, and 6 Democrats did not vote. What if it had been 6 Republicans who did not vote? Where would the split be?
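
For reference, the breakdown in our data can be obtained by cross-tabulating this vote against party (a minimal sketch using the votes data frame loaded above):

    # How the recorded votes on bill 675 break down by party
    table(votes$numeric_vote675, votes$party)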

Health data

The rpart library also has a decision tree fitter. It is more flexible and follows the standard CART approach more closely, though its pruning differs from what is described in the notes. First, we show how to prune a tree as described in the notes.

    health = read.table("http://stats202.stanford.edu/data/health.csv", sep = ",",
        header = TRUE)
    health_no_sample_weight = health[, -c(13)]
    health_tree = tree(country ~ ., data = health_no_sample_weight, mindev = 0.001)
    summary(health_tree)

    ##
    ## Classification tree:
    ## tree(formula = country ~ ., data = health_no_sample_weight, mindev = 0.001)
    ## Variables actually used in tree construction:
    ##  [1] "age"          "weight"       "teeth"        "vegetables"
    ##  [5] "hungry"       "hands_soap"   "fruit"        "height"
    ##  [9] "hands_toilet" "bmi"
    ## Number of terminal nodes:  54
    ## Residual mean deviance:  1.19 = 25100 / 21000
    ## Misclassification error rate: 0.245 = 5163 / 21089

    plot(health_tree)
    text(health_tree)

_images/trees_fig_06.png

Now, let's prune it. Plotting the output of prune.tree shows the deviance of the best subtree of each size:

    plot(prune.tree(health_tree))
    abline(v = 6, col = "red")

_images/trees_fig_07.png

We may want to specify a particular value of the cost-complexity parameter. The plot above shows that a cost-complexity parameter of about 390 corresponds to a tree with 6 terminal nodes.

    pruned_tree = prune.tree(health_tree, k = 390)
    summary(pruned_tree)

    ##
    ## Classification tree:
    ## snip.tree(tree = health_tree, nodes = c(15, 27, 26, 2, 12, 14
    ## ))
    ## Variables actually used in tree construction:
    ## [1] "age"    "weight" "teeth"  "hungry"
    ## Number of terminal nodes:  6
    ## Residual mean deviance:  1.44 = 30400 / 21100
    ## Misclassification error rate: 0.288 = 6077 / 21089

    plot(pruned_tree)
    text(pruned_tree)

_images/trees_fig_08.png

Let's fit the tree using rpart. We see that it has already been pruned. In fact, none of the terminal nodes predicts the class ugh, so no observations are labelled ugh (we check this below).

    library(rpart)
    health.tree = rpart(country ~ ., method = "class", data = health_no_sample_weight)
    summary(health.tree)

    ## Call:
    ## rpart(formula = country ~ ., data = health_no_sample_weight,
    ##     method = "class")
    ##   n= 24662
    ##
    ##        CP nsplit rel error xerror     xstd
    ## 1 0.06746      0    1.0000 1.0000 0.008495
    ## 2 0.02728      2    0.8651 0.8752 0.008221
    ## 3 0.01727      3    0.8378 0.8561 0.008172
    ## 4 0.01000      7    0.7687 0.7830 0.007962
    ##
    ## Node number 1: 24662 observations,    complexity param=0.06746
    ##   predicted class=aeh  expected loss=0.3597
    ##     class counts: 15790  5657  3215
    ##    probabilities: 0.640 0.229 0.130
    ##   left son=2 (19131 obs) right son=3 (5531 obs)
    ##   Primary splits:
    ##       teeth  splits as  LLLLRL,    improve=909.3, (211 missing)
    ##       age    splits as  LLLRRR,    improve=866.2, (314 missing)
    ##       weight < 55.5  to the right, improve=611.4, (2187 missing)
    ##       hungry splits as  LLLRR,     improve=595.0, (269 missing)
    ##       bmi    < 22.84 to the right, improve=515.6, (2187 missing)
    ##
    ## Node number 2: 19131 observations,    complexity param=0.01727
    ##   predicted class=aeh  expected loss=0.2829
    ##     class counts: 13719  3065  2347
    ##    probabilities: 0.717 0.160 0.123
    ##   left son=4 (5837 obs) right son=5 (13294 obs)
    ##   Primary splits:
    ##       age        splits as  LLLRRR,    improve=449.5, (246 missing)
    ##       hungry     splits as  LLLRR,     improve=342.6, (232 missing)
    ##       hands_soap splits as  LLRRR,     improve=324.7, (237 missing)
    ##       weight     < 56.5  to the right, improve=306.2, (1675 missing)
    ##       bmi        < 22.32 to the right, improve=300.3, (1675 missing)
    ##
    ## Node number 3: 5531 observations,    complexity param=0.06746
    ##   predicted class=pih  expected loss=0.5314
    ##     class counts:  2071  2592   868
    ##    probabilities: 0.374 0.469 0.157
    ##   left son=6 (1200 obs) right son=7 (4331 obs)
    ##   Primary splits:
    ##       age    splits as  LLLRRR,    improve=353.0, (68 missing)
    ##       weight < 50.5  to the right, improve=260.5, (512 missing)
    ##       hungry splits as  LRLRR,     improve=177.3, (37 missing)
    ##       bmi    < 23.14 to the right, improve=146.8, (512 missing)
    ##       height < 1.575 to the left,  improve=133.0, (512 missing)
    ##
    ## Node number 4: 5837 observations
    ##   predicted class=aeh  expected loss=0.09474
    ##     class counts:  5284   313   240
    ##    probabilities: 0.905 0.054 0.041
    ##
    ## Node number 5: 13294 observations,    complexity param=0.01727
    ##   predicted class=aeh  expected loss=0.3655
    ##     class counts:  8435  2752  2107
    ##    probabilities: 0.634 0.207 0.158
    ##   left son=10 (4675 obs) right son=11 (8619 obs)
    ##   Primary splits:
    ##       weight     < 55.5  to the right, improve=447.2, (1185 missing)
    ##       bmi        < 23.21 to the right, improve=343.3, (1185 missing)
    ##       hungry     splits as  LLLRR,     improve=322.4, (155 missing)
    ##       hands_soap splits as  LLRRR,     improve=311.1, (160 missing)
    ##       vegetables splits as  LRRRRRL,   improve=251.1, (124 missing)
    ##   Surrogate splits:
    ##       bmi    < 21.51 to the right, agree=0.858, adj=0.632, (0 split)
    ##       height < 1.665 to the right, agree=0.697, adj=0.216, (0 split)
    ##
    ## Node number 6: 1200 observations
    ##   predicted class=aeh  expected loss=0.25
    ##     class counts:   900   224    76
    ##    probabilities: 0.750 0.187 0.063
    ##
    ## Node number 7: 4331 observations,    complexity param=0.02728
    ##   predicted class=pih  expected loss=0.4532
    ##     class counts:  1171  2368   792
    ##    probabilities: 0.270 0.547 0.183
    ##   left son=14 (1574 obs) right son=15 (2757 obs)
    ##   Primary splits:
    ##       weight     < 50.5  to the right, improve=263.9, (399 missing)
    ##       height     < 1.575 to the left,  improve=190.2, (399 missing)
    ##       hungry     splits as  LRLRR,     improve=124.5, (23 missing)
    ##       bmi        < 23.14 to the right, improve=115.5, (399 missing)
    ##       vegetables splits as  LLRRRLL,   improve=101.2, (35 missing)
    ##   Surrogate splits:
    ##       bmi    < 20.55 to the right, agree=0.801, adj=0.503, (0 split)
    ##       height < 1.615 to the right, agree=0.731, adj=0.328, (0 split)
    ##
    ## Node number 10: 4675 observations
    ##   predicted class=aeh  expected loss=0.1765
    ##     class counts:  3850   276   549
    ##    probabilities: 0.824 0.059 0.117
    ##
    ## Node number 11: 8619 observations,    complexity param=0.01727
    ##   predicted class=aeh  expected loss=0.468
    ##     class counts:  4585  2476  1558
    ##    probabilities: 0.532 0.287 0.181
    ##   left son=22 (4832 obs) right son=23 (3787 obs)
    ##   Primary splits:
    ##       hungry     splits as  LLLRR,   improve=265.7, (111 missing)
    ##       vegetables splits as  LRRRRRR, improve=226.7, (79 missing)
    ##       teeth      splits as  LLLR-R,  improve=204.6, (101 missing)
    ##       hands_soap splits as  LLRRR,   improve=189.8, (104 missing)
    ##       fruit      splits as  LRRRRRR, improve=181.0, (110 missing)
    ##   Surrogate splits:
    ##       hands_soap   splits as  LLLLR, agree=0.565, adj=0.020, (101 split)
    ##       hands_toilet splits as  LRLRR, agree=0.563, adj=0.014, (5 split)
    ##
    ## Node number 14: 1574 observations
    ##   predicted class=aeh  expected loss=0.5623
    ##     class counts:   689   447   438
    ##    probabilities: 0.438 0.284 0.278
    ##
    ## Node number 15: 2757 observations
    ##   predicted class=pih  expected loss=0.3032
    ##     class counts:   482  1921   354
    ##    probabilities: 0.175 0.697 0.128
    ##
    ## Node number 22: 4832 observations
    ##   predicted class=aeh  expected loss=0.3551
    ##     class counts:  3116   857   859
    ##    probabilities: 0.645 0.177 0.178
    ##
    ## Node number 23: 3787 observations,    complexity param=0.01727
    ##   predicted class=pih  expected loss=0.5725
    ##     class counts:  1469  1619   699
    ##    probabilities: 0.388 0.428 0.185
    ##   left son=46 (1744 obs) right son=47 (2043 obs)
    ##   Primary splits:
    ##       vegetables splits as  LLRRRLL,   improve=150.70, (38 missing)
    ##       teeth      splits as  LLLR-R,    improve=135.50, (43 missing)
    ##       fruit      splits as  LLRRRRL,   improve= 87.00, (49 missing)
    ##       height     < 1.525 to the left,  improve= 85.22, (491 missing)
    ##       hands_soap splits as  LLRRR,     improve= 76.15, (38 missing)
    ##   Surrogate splits:
    ##       fruit        splits as  LLRRRRL, agree=0.652, adj=0.243, (31 split)
    ##       teeth        splits as  LLLR-R,  agree=0.583, adj=0.092, (5 split)
    ##       hands_soap   splits as  RRLRR,   agree=0.544, adj=0.009, (1 split)
    ##       hands_toilet splits as  RRLLR,   agree=0.541, adj=0.001, (1 split)
    ##
    ## Node number 46: 1744 observations
    ##   predicted class=aeh  expected loss=0.4633
    ##     class counts:   936   473   335
    ##    probabilities: 0.537 0.271 0.192
    ##
    ## Node number 47: 2043 observations
    ##   predicted class=pih  expected loss=0.4391
    ##     class counts:   533  1146   364
    ##    probabilities: 0.261 0.561 0.178

    plot(health.tree)
    text(health.tree)

_images/trees_fig_09.png
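
As a quick check of the claim above that no observations are labelled ugh, we can tabulate the fitted classes (a minimal sketch):

    # Predicted class counts on the training data; ugh should have count zero
    table(predict(health.tree, type = "class"))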

The parameter cp controls this pruning; its default value is 0.01. A split is not attempted unless it decreases the overall lack of fit by at least a factor of cp.

    health.tree = rpart(country ~ ., method = "class", data = health_no_sample_weight,
        cp = 0.005)
    plot(health.tree)
    text(health.tree)

_images/trees_fig_10.png

For rpart we can specify a loss matrix to penalize some mistakes more heavily than others.

    loss_matrix = matrix(c(0, 1, 1, 1, 0, 1, 1, 50, 0), 3, 3)
    loss_matrix

    ##      [,1] [,2] [,3]
    ## [1,]    0    1    1
    ## [2,]    1    0   50
    ## [3,]    1    1    0

    health.tree.cost = rpart(country ~ ., method = "class", data = health_no_sample_weight,
        parms = list(loss = loss_matrix), cp = 0.001)
    plot(health.tree.cost)
    text(health.tree.cost)

_images/trees_fig_11.png

A second fit passes split = "info" inside parms, intended to request the information (entropy) splitting criterion:

    health.tree.cost.info = rpart(country ~ ., method = "class", data = health_no_sample_weight,
        parms = list(loss = loss_matrix, split = "info"), cp = 0.001)
    plot(health.tree.cost.info)
    text(health.tree.cost.info)

_images/trees_fig_12.png

Confusion matrix

    table(predict(health.tree.cost, type = "class"), health_no_sample_weight$country)

    ##
    ##         aeh   pih   ugh
    ##   aeh 14694  3758  1542
    ##   pih   146   858     0
    ##   ugh   950  1041  1673
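
The overall misclassification rate is one minus the proportion of counts on the diagonal of this table (a minimal sketch):

    confusion = table(predict(health.tree.cost, type = "class"), health_no_sample_weight$country)
    1 - sum(diag(confusion))/sum(confusion)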

Although rpart accepts split as an argument, it does not seem to be using it.

    table(predict(health.tree.cost.info, type = "class"), health_no_sample_weight$country)

    ##
    ##         aeh   pih   ugh
    ##   aeh 14765  3402  1628
    ##   pih   186  1094     0
    ##   ugh   839  1161  1587

Training vs. test error for health data

Let's repeat what we did above for the iris data, to see what the training and test errors look like here.
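
A minimal sketch of one way to do this, dropping rows with missing values (via na.omit) so the error computations line up; the 2/3-1/3 split is arbitrary.

    health_complete = na.omit(health_no_sample_weight)
    n = nrow(health_complete)
    test_sample = sample(1:n, round(n/3))
    test_data = health_complete[test_sample, ]
    training_data = health_complete[-test_sample, ]
    fit_tree = tree(country ~ ., data = training_data)
    training_error = mean(predict(fit_tree, type = "class") != training_data$country)
    test_error = mean(predict(fit_tree, test_data, type = "class") != test_data$country)
    c(training = training_error, test = test_error)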

A more complicated tree may fit better (or worse).
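
Continuing from the split in the sketch above, we can grow a much larger tree (using mindev = 0.001 as in the earlier health_tree fit) and compare its errors; again, a minimal sketch.

    # A deliberately over-grown tree on the same training set
    big_tree = tree(country ~ ., data = training_data, mindev = 0.001)
    c(training = mean(predict(big_tree, type = "class") != training_data$country),
        test = mean(predict(big_tree, test_data, type = "class") != test_data$country))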

Finally, let's make a plot of the training and test error.
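
A minimal sketch of one way to do this, assuming we plot error against tree size: prune the large tree from the previous sketch back to each size with prune.tree(best = k) and compute both errors at each size.

    # Number of terminal nodes in the over-grown tree
    n_leaves = sum(big_tree$frame$var == "<leaf>")
    sizes = 2:n_leaves
    errors = sapply(sizes, function(k) {
        pruned = prune.tree(big_tree, best = k)
        c(training = mean(predict(pruned, training_data, type = "class") != training_data$country),
            test = mean(predict(pruned, test_data, type = "class") != test_data$country))
    })
    matplot(sizes, t(errors), type = "l", lty = 1, col = c("blue", "red"),
        xlab = "Number of terminal nodes", ylab = "Misclassification error")
    legend("topright", legend = c("training", "test"), col = c("blue", "red"), lty = 1)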