Random Forests

Random forests are implemented in R in the randomForest package:

    library(randomForest)

    ## Warning: package 'randomForest' was built under R version 2.15.1

    ## randomForest 4.6-7

    ## Type rfNews() to see new features/changes/bug fixes.

    iris = read.table("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        sep = ",", header = FALSE)
    names(iris) = c("sepal.length", "sepal.width", "petal.length", "petal.width",
        "iris.type")
    model = randomForest(iris.type ~ sepal.length + sepal.width, data = iris)
    names(model)

    ##  [1] "call"            "type"            "predicted"
    ##  [4] "err.rate"        "confusion"       "votes"
    ##  [7] "oob.times"       "classes"         "importance"
    ## [10] "importanceSD"    "localImportance" "proximity"
    ## [13] "ntree"           "mtry"            "forest"
    ## [16] "y"               "test"            "inbag"
    ## [19] "terms"

    plot(model)

_images/ensemble_fig_00.png

The red curve is the error rate for the Setosa class, the green and blue curves above it are for Versicolor and Virginica, and the black curve is the out-of-bag (OOB) error rate. Let’s see if it improves with more trees:

    model = randomForest(iris.type ~ sepal.length + sepal.width, data = iris, ntree = 5000)
    plot(model)

_images/ensemble_fig_01.png
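
The curves come from the err.rate component of the fitted model: a matrix with one row per tree, whose first column is the OOB error and whose remaining columns are the per-class error rates. A quick way to confirm which curve is which:

    # column names match the curves in plot(model):
    # "OOB" first, then one column per class
    colnames(model$err.rate)
    tail(model$err.rate, 1)  # error rates after the final tree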

In[20]:

def iris_random_forest():
    # colors for the three iris species (factor levels 1, 2, 3)
    col = {1:'r', 2:'y', 3:'g'}
    # a 500x400 grid covering the (sepal.length, sepal.width) plane
    grid = np.mgrid[3.5:8:500j,2:5:400j]
    gridT = grid.reshape((2,-1)).T
    %R -i gridT colnames(gridT) = c('sepal.length', 'sepal.width')
    %R model = randomForest(iris.type ~ sepal.length + sepal.width, data=iris, ntree=5000)
    %R gridT = data.frame(gridT); labels = predict(model, gridT, type='response')
    %R -o labels -d iris

    # color the plane by predicted class and overlay the training points
    plt.imshow(labels.reshape((500,400)).T, interpolation='nearest', origin='lower', alpha=0.4, extent=[3.5,8,2,5], cmap=pylab.cm.RdYlGn)
    plt.scatter(iris['sepal.length'], iris['sepal.width'], c=[col[t] for t in iris['iris.type']])
    a = plt.gca()
    a.set_xlim((3.5,8))
    a.set_ylim((2,5))

iris_random_forest()
_images/ensemble_fig_02.png
<matplotlib.figure.Figure at 0x9036970>

Boosting

The gbm package implements a version of boosting called gradient boosting, discussed in detail in Chapter 10 of The Elements of Statistical Learning. It is related to AdaBoost.M1 but not identical: it typically takes much smaller steps, and it can use a binomial (logistic) loss rather than AdaBoost.M1's exponential loss. Finally, gbm injects randomness by using only a random sample of the data for each gradient step, which is what allows it to form an out-of-bag (OOB) error estimate. With a small shrinkage it can take many trees to get a satisfactory fit.

    library(gbm)

    ## Loading required package: survival

    ## Loading required package: splines

    ## Loading required package: lattice

    ## Loaded gbm 1.6.3.2

    iris = read.table("http://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data",
        sep = ",", header = FALSE)
    names(iris) = c("sepal.length", "sepal.width", "petal.length", "petal.width",
        "iris.type")
    # 0/1 indicator: is this observation anything other than versicolor?
    iris$Y = (iris$iris.type != "Iris-versicolor")
    model = gbm(Y ~ sepal.length + sepal.width, data = iris, n.trees = 10000)

    ## Iter   TrainDeviance   ValidDeviance   StepSize   Improve
    ##      1        1.2726             nan     0.0010    0.0002
    ##      2        1.2721             nan     0.0010    0.0002
    ##      3        1.2719             nan     0.0010    0.0001
    ##      4        1.2714             nan     0.0010    0.0002
    ##      5        1.2710             nan     0.0010    0.0002
    ##      6        1.2707             nan     0.0010    0.0001
    ##      7        1.2703             nan     0.0010    0.0001
    ##      8        1.2700             nan     0.0010    0.0001
    ##      9        1.2696             nan     0.0010    0.0001
    ##     10        1.2692             nan     0.0010    0.0002
    ##    100        1.2389             nan     0.0010    0.0002
    ##    200        1.2094             nan     0.0010    0.0001
    ##    300        1.1833             nan     0.0010    0.0001
    ##    400        1.1608             nan     0.0010    0.0001
    ##    500        1.1411             nan     0.0010    0.0001
    ##    600        1.1239             nan     0.0010    0.0001
    ##    700        1.1089             nan     0.0010    0.0001
    ##    800        1.0945             nan     0.0010    0.0000
    ##    900        1.0820             nan     0.0010    0.0000
    ##   1000        1.0706             nan     0.0010    0.0000
    ##   1100        1.0599             nan     0.0010    0.0000
    ##   1200        1.0498             nan     0.0010    0.0000
    ##   1300        1.0409             nan     0.0010    0.0000
    ##   1400        1.0327             nan     0.0010   -0.0000
    ##   1500        1.0245             nan     0.0010    0.0000
    ##   1600        1.0172             nan     0.0010    0.0000
    ##   1700        1.0103             nan     0.0010    0.0000
    ##   1800        1.0037             nan     0.0010   -0.0000
    ##   1900        0.9974             nan     0.0010   -0.0000
    ##   2000        0.9912             nan     0.0010    0.0000
    ##   2100        0.9854             nan     0.0010   -0.0000
    ##   2200        0.9802             nan     0.0010   -0.0000
    ##   2300        0.9750             nan     0.0010   -0.0000
    ##   2400        0.9700             nan     0.0010   -0.0000
    ##   2500        0.9648             nan     0.0010    0.0000
    ##   2600        0.9603             nan     0.0010   -0.0000
    ##   2700        0.9557             nan     0.0010   -0.0000
    ##   2800        0.9512             nan     0.0010   -0.0000
    ##   2900        0.9471             nan     0.0010   -0.0000
    ##   3000        0.9432             nan     0.0010   -0.0000
    ##   3100        0.9392             nan     0.0010   -0.0000
    ##   3200        0.9356             nan     0.0010   -0.0000
    ##   3300        0.9325             nan     0.0010    0.0000
    ##   3400        0.9294             nan     0.0010   -0.0000
    ##   3500        0.9263             nan     0.0010   -0.0000
    ##   3600        0.9234             nan     0.0010   -0.0000
    ##   3700        0.9206             nan     0.0010    0.0000
    ##   3800        0.9181             nan     0.0010    0.0000
    ##   3900        0.9152             nan     0.0010   -0.0000
    ##   4000        0.9127             nan     0.0010   -0.0000
    ##   4100        0.9101             nan     0.0010    0.0000
    ##   4200        0.9077             nan     0.0010    0.0000
    ##   4300        0.9053             nan     0.0010   -0.0000
    ##   4400        0.9029             nan     0.0010   -0.0000
    ##   4500        0.9007             nan     0.0010   -0.0000
    ##   4600        0.8988             nan     0.0010   -0.0000
    ##   4700        0.8965             nan     0.0010   -0.0000
    ##   4800        0.8944             nan     0.0010   -0.0000
    ##   4900        0.8925             nan     0.0010   -0.0000
    ##   5000        0.8906             nan     0.0010   -0.0000
    ##   5100        0.8884             nan     0.0010   -0.0000
    ##   5200        0.8866             nan     0.0010   -0.0000
    ##   5300        0.8846             nan     0.0010   -0.0000
    ##   5400        0.8830             nan     0.0010   -0.0000
    ##   5500        0.8813             nan     0.0010   -0.0000
    ##   5600        0.8797             nan     0.0010   -0.0000
    ##   5700        0.8782             nan     0.0010   -0.0000
    ##   5800        0.8768             nan     0.0010   -0.0000
    ##   5900        0.8751             nan     0.0010   -0.0000
    ##   6000        0.8734             nan     0.0010   -0.0000
    ##   6100        0.8720             nan     0.0010   -0.0000
    ##   6200        0.8701             nan     0.0010   -0.0000
    ##   6300        0.8688             nan     0.0010   -0.0000
    ##   6400        0.8674             nan     0.0010   -0.0000
    ##   6500        0.8659             nan     0.0010   -0.0000
    ##   6600        0.8644             nan     0.0010   -0.0000
    ##   6700        0.8632             nan     0.0010   -0.0000
    ##   6800        0.8616             nan     0.0010   -0.0000
    ##   6900        0.8602             nan     0.0010   -0.0000
    ##   7000        0.8588             nan     0.0010   -0.0000
    ##   7100        0.8575             nan     0.0010   -0.0000
    ##   7200        0.8560             nan     0.0010   -0.0000
    ##   7300        0.8548             nan     0.0010   -0.0000
    ##   7400        0.8537             nan     0.0010   -0.0000
    ##   7500        0.8523             nan     0.0010   -0.0000
    ##   7600        0.8510             nan     0.0010   -0.0000
    ##   7700        0.8500             nan     0.0010   -0.0000
    ##   7800        0.8489             nan     0.0010   -0.0001
    ##   7900        0.8478             nan     0.0010   -0.0000
    ##   8000        0.8468             nan     0.0010   -0.0000
    ##   8100        0.8457             nan     0.0010   -0.0000
    ##   8200        0.8447             nan     0.0010   -0.0001
    ##   8300        0.8436             nan     0.0010   -0.0000
    ##   8400        0.8427             nan     0.0010   -0.0000
    ##   8500        0.8417             nan     0.0010   -0.0000
    ##   8600        0.8406             nan     0.0010    0.0000
    ##   8700        0.8395             nan     0.0010   -0.0000
    ##   8800        0.8385             nan     0.0010   -0.0000
    ##   8900        0.8376             nan     0.0010    0.0000
    ##   9000        0.8366             nan     0.0010   -0.0000
    ##   9100        0.8356             nan     0.0010   -0.0000
    ##   9200        0.8348             nan     0.0010   -0.0000
    ##   9300        0.8339             nan     0.0010   -0.0000
    ##   9400        0.8331             nan     0.0010   -0.0000
    ##   9500        0.8321             nan     0.0010   -0.0000
    ##   9600        0.8312             nan     0.0010   -0.0000
    ##   9700        0.8304             nan     0.0010   -0.0000
    ##   9800        0.8296             nan     0.0010   -0.0001
    ##   9900        0.8287             nan     0.0010   -0.0000
    ##  10000        0.8278             nan     0.0010   -0.0000
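
The loss function and the subsampling described above are exposed through gbm's distribution and bag.fraction arguments. A sketch (the settings here are illustrative, not the ones used above): distribution = "adaboost" selects the exponential loss of AdaBoost.M1, while the default for a 0/1 response is the binomial ("bernoulli") loss, and bag.fraction sets the fraction of the data used at each step.

    # a sketch: AdaBoost's exponential loss instead of the default,
    # with half of the data subsampled at each gradient step
    model.ada = gbm(Y ~ sepal.length + sepal.width, data = iris,
        distribution = "adaboost", bag.fraction = 0.5,
        shrinkage = 0.001, n.trees = 10000, verbose = FALSE)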

We can use only an initial fraction of the fitted trees to predict, so-called early stopping; for gbm fits, predict takes an n.trees argument.
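
A minimal sketch in R, using the model fitted above:

    # predict with only the first 100 of the 10000 trees
    pred100 = predict(model, newdata = iris, n.trees = 100)

Here is the resulting decision boundary after 100 trees: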

In[3]:

def iris_boosting(npredict=100, shrinkage=0.001, depth=1):
    # colors for the three iris species (factor levels 1, 2, 3)
    col = {1:'r', 2:'y', 3:'g'}
    grid = np.mgrid[3.5:8:500j,2:5:400j]
    gridT = grid.reshape((2,-1)).T
    %R -i gridT,npredict,shrinkage,depth colnames(gridT) = c('sepal.length', 'sepal.width')
    %R model = gbm(Y ~ sepal.length + sepal.width, data=iris, n.trees=npredict, verbose=FALSE, shrinkage=shrinkage, interaction.depth=depth)
    # predict on the grid and threshold at zero: G(x) = sign(f(x))
    %R gridT = data.frame(gridT); labels = sign(predict(model, gridT, npredict))
    %R -o labels -d iris

    plt.imshow(labels.reshape((500,400)).T, interpolation='nearest', origin='lower', alpha=0.4, extent=[3.5,8,2,5], cmap=pylab.cm.RdYlGn)
    plt.scatter(iris['sepal.length'], iris['sepal.width'], c=[col[t] for t in iris['iris.type']])
    a = plt.gca()
    a.set_xlim((3.5,8))
    a.set_ylim((2,5))

iris_boosting()
_images/ensemble_fig_03.png
<matplotlib.figure.Figure at 0x6a0b7d0>

After 1000 trees:

In[4]:

iris_boosting(1000)
_images/ensemble_fig_04.png
<matplotlib.figure.Figure at 0x69fff90>

After 5000 trees:

In[5]:

iris_boosting(5000)
_images/ensemble_fig_05.png
<matplotlib.figure.Figure at 0x57212b0>

After 10000 trees:

In[6]:

iris_boosting(10000)
_images/ensemble_fig_06.png
<matplotlib.figure.Figure at 0x901ff90>

And finally, after 20000 trees:

In[7]:

iris_boosting(20000)
_images/ensemble_fig_07.png
<matplotlib.figure.Figure at 0x69ffa30>

The step size can also influence the fit. It is controlled by the shrinkage argument to gbm; increasing shrinkage forces the fit to proceed more quickly.

In[8]:

iris_boosting(1000, 0.1)
_images/ensemble_fig_08.png
<matplotlib.figure.Figure at 0x561ed30>

There are lots of bells and whistles in gbm. In this case the gbm fit is an additive model, i.e. the classifier has the form

G(x) = \text{sign}(f_1({\tt sepal.length}) + f_2({\tt sepal.width})).

We can see these functions with

    plot(model, 1)

_images/ensemble_fig_09.png
    plot(model, 2)

_images/ensemble_fig_10.png

We can get a sense of how quickly the error is going down by looking at the OOB error. It doesn’t seem to be improving much after 100 trees.

    gbm.perf(model, method = "OOB")

    ## Warning: OOB generally underestimates the optimal number of iterations
    ## although predictive performance is reasonably competitive. Using
    ## cv.folds>0 when calling gbm usually results in improved predictive
    ## performance.

_images/ensemble_fig_11.png
    ## [1] 22
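
Since gbm.perf returns the estimated optimal number of trees, it can feed directly into prediction; a small sketch:

    best.iter = gbm.perf(model, method = "OOB", plot.it = FALSE)
    # predict using only the estimated optimal number of trees
    pred = predict(model, newdata = iris, n.trees = best.iter)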

By default, gbm fits this additive model. It can include interactions by increasing interaction.depth: with interaction.depth = 2 it fits trees that split on up to two variables each time. These two-way models still have an additive part, which can still be plotted. We will also evaluate the classifier on held-out test data, using only 90% of the data to train the model.

    model = gbm(Y ~ sepal.length + sepal.width, data = iris, n.trees = 1000, verbose = FALSE,
        shrinkage = 0.1, interaction.depth = 2, train.fraction = 0.9)
    plot(model, 1)

_images/ensemble_fig_12.png
    plot(model, 2)

_images/ensemble_fig_13.png
    gbm.perf(model, method = "test")

_images/ensemble_fig_14.png
    ## [1] 1

It seems that for the iris data our test set is a little small. Let’s try gbm on the spam data.

    library(ElemStatLearn)
    data(spam)
    # recode the factor response to numeric: "email" -> 1, "spam" -> 0
    spam$spam = c(1, 0)[unclass(spam$spam)]
    model = gbm(spam ~ ., data = spam, train.fraction = 0.9, interaction.depth = 2,
        n.trees = 1000, shrinkage = 0.1)

    ## Iter   TrainDeviance   ValidDeviance   StepSize   Improve
    ##      1        1.2849          1.1451     0.1000    0.0437
    ##      2        1.2066          1.1515     0.1000    0.0388
    ##      3        1.1364          1.0870     0.1000    0.0341
    ##      4        1.0804          1.0713     0.1000    0.0279
    ##      5        1.0239          1.0358     0.1000    0.0281
    ##      6        0.9798          1.0478     0.1000    0.0214
    ##      7        0.9350          1.0004     0.1000    0.0216
    ##      8        0.8949          0.9987     0.1000    0.0201
    ##      9        0.8602          0.9787     0.1000    0.0173
    ##     10        0.8295          0.9802     0.1000    0.0150
    ##    100        0.2925          0.9425     0.1000    0.0004
    ##    200        0.2348          0.8885     0.1000    0.0001
    ##    300        0.2044          0.8959     0.1000   -0.0002
    ##    400        0.1823          0.9325     0.1000   -0.0002
    ##    500        0.1661          0.9102     0.1000   -0.0002
    ##    600        0.1524          0.9373     0.1000   -0.0001
    ##    700        0.1419          0.9384     0.1000    0.0000
    ##    800        0.1323          0.9551     0.1000   -0.0001
    ##    900        0.1236          0.9720     0.1000   -0.0002
    ##   1000        0.1158          0.9622     0.1000   -0.0001

    gbm.perf(model, method = "test")

_images/ensemble_fig_15.png
    ## [1] 265
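
With 57 predictors it is also natural to ask which ones the boosted model relies on. One of gbm's bells and whistles, summary, reports the relative influence of each variable (it also plots the result by default):

    # relative influence of each predictor, largest first
    summary(model)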