Decision trees and regression

Like SVMs, decision trees are versatile machine learning algorithms: they can be used for both classification and regression problems, and they can fit complex datasets. They are also the building blocks of the random forest algorithm.

In this notebook you will learn how to train, visualize and use decision trees. We start with a classification application.


Setup

# needed packages
import os
from io import StringIO  # standard-library replacement for six.StringIO

import matplotlib.pyplot as plt
import numpy as np
import pydot
from IPython.display import Image
from matplotlib.colors import ListedColormap
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# keep random seed stable
np.random.seed(42)

# folder layout used by image_path(); placeholder values (assumed), adjust to your project
PROJECT_ROOT_DIR = "."
CHAPTER_ID = "decision_trees"

# needed functions
def image_path(fig_id):
    """Return the path under which the figure fig_id is stored."""
    return os.path.join(PROJECT_ROOT_DIR, "images", CHAPTER_ID, fig_id)

def plot_reg_tree(X, x_pred, Y, y_pred, title):
    """Plot the training data (X, Y) as blue dots and the tree's predictions as a red line."""
    plt.plot(X, Y, "b.")
    plt.plot(x_pred, y_pred, "r.-", linewidth=2, label=r"$\hat{y}$")
    plt.axis([0, 1, -0.2, 1.1])
    plt.xlabel("$x_1$", fontsize=18)
    plt.ylabel("$y$", fontsize=18, rotation=0)
    plt.legend(loc="upper center", fontsize=18)
    plt.title(title, fontsize=14)

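def plot_decision_boundary(clf, X, y, axes=(0, 7.5, 0, 3), iris=True, legend=False, plot_training=True):
    """Shade the class regions predicted by clf on a 100x100 grid over axes, optionally overlaying the training samples."""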
    x1s = np.linspace(axes[0], axes[1], 100)
    x2s = np.linspace(axes[2], axes[3], 100)
    x1, x2 = np.meshgrid(x1s, x2s)
    X_new = np.c_[x1.ravel(), x2.ravel()]
    y_pred = clf.predict(X_new).reshape(x1.shape)
    custom_cmap = ListedColormap(['#fafab0','#9898ff','#a0faa0'])
    plt.contourf(x1, x2, y_pred, alpha=0.3, cmap=custom_cmap)
    if not iris:
        custom_cmap2 = ListedColormap(['#7d7d58','#4c4c7f','#507d50'])
        plt.contour(x1, x2, y_pred, cmap=custom_cmap2, alpha=0.8)
    if plot_training:
        plt.plot(X[:, 0][y==0], X[:, 1][y==0], "yo", label="Iris-Setosa")
        plt.plot(X[:, 0][y==1], X[:, 1][y==1], "bs", label="Iris-Versicolor")
        plt.plot(X[:, 0][y==2], X[:, 1][y==2], "g^", label="Iris-Virginica")
        plt.axis(axes)
    if iris:
        plt.xlabel("Petal length", fontsize=14)
        plt.ylabel("Petal width", fontsize=14)
    else:
        plt.xlabel(r"$x_1$", fontsize=18)
        plt.ylabel(r"$x_2$", fontsize=18, rotation=0)
    if legend:
        plt.legend(loc="lower right", fontsize=14)

Classification

For demonstration purposes we use the iris dataset, keeping only two features: petal length and petal width.

from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, 2:]  # petal length and width (cm)
y = iris.target       # kind of iris: 0 = setosa, 1 = versicolor, 2 = virginica
print(X.shape, y.shape)
print(iris.feature_names[2:])
print(y)

To perform the classification, we first create a DecisionTreeClassifier object, here called tree_clf, and limit the tree to a maximum depth of 2 (max_depth=2). Then we train it by calling its fit() method on X and y.

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X, y)

Visualisation: Tree plot

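We can visualize the trained DecisionTreeClassifier with the export_graphviz() function. It writes a description of the tree in Graphviz's DOT format, which pydot then renders as a PNG image (this requires Graphviz to be installed).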

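# export the tree in DOT format into an in-memory text buffer
dot_data = StringIO()
export_graphviz(
    tree_clf,
    out_file=dot_data,
    feature_names=iris.feature_names[2:],  # only the two petal features were used
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)
# graph_from_dot_data returns a list of graphs; render the first one as PNG
graph = pydot.graph_from_dot_data(dot_data.getvalue())
Image(graph[0].create_png())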

We can read the tree diagram as follows. Suppose you find an iris flower and want to classify it. You start at the root node (depth 0) at the top, which asks: is the petal length less than or equal to 2.45 cm? If so, you follow the True branch down to the left child node (depth 1). This is a leaf node: it has no children, so no further questions are asked and you simply read off the predicted class, class = setosa. If not, you follow the False branch to the right child node (depth 1), which asks a second question: is the petal width less than or equal to 1.75 cm? Its two children (depth 2) are leaves that predict versicolor (True) and virginica (False).
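The walk from the root to a leaf is just a chain of if/else tests. As a rough sketch, the depth-2 tree can be written out by hand in plain Python. The function name classify_iris is hypothetical, and the thresholds 2.45 cm and 1.75 cm are copied from the tree plot rather than read from tree_clf.

```python
def classify_iris(petal_length, petal_width):
    """Hand-written mirror of the depth-2 tree (thresholds copied from the tree plot)."""
    # depth 0: the root node tests petal length
    if petal_length <= 2.45:
        return "setosa"       # depth 1: left leaf
    # depth 1: the right node tests petal width
    if petal_width <= 1.75:
        return "versicolor"   # depth 2: left leaf
    return "virginica"        # depth 2: right leaf


print(classify_iris(1.4, 0.2))
print(classify_iris(6.0, 1.5))
print(classify_iris(6.0, 2.0))
```

A fitted tree does the same thing internally; the hand-written version only helps to see that every prediction is a handful of threshold comparisons.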

The diagram also shows each node's samples attribute, which counts how many training instances the node applies to. In our example, 50 samples have a petal length of at most 2.45 cm and 100 samples have a petal length greater than 2.45 cm. Of those 100 samples, 54 have a petal width of at most 1.75 cm and 46 have a petal width greater than 1.75 cm.

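The gini attribute measures the impurity of a node. A node is pure (\( G_i = 0 \)) if all the training samples it applies to belong to the same class. The Gini impurity of node \( i \) is calculated as

$$ G_i = 1 - \sum_{k=1}^{n} p_{i,k}^2 $$

where \( p_{i,k} \) is the fraction of class-\( k \) instances among the training instances in node \( i \), and \( n \) is the number of classes. For example, the depth-2 versicolor node contains 0 setosa, 49 versicolor and 5 virginica samples, so its gini score is \( 1 - (0/54)^2 - (49/54)^2 - (5/54)^2 \approx 0.168 \).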
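To check the formula against the numbers in the diagram, here is a small sketch that computes the Gini impurity from class counts. The helper name gini is hypothetical. The class counts per node are derived from the text, assuming the usual 50 flowers per class: the root holds all 150 samples, the pure left node holds the 50 setosa, so the right node holds 50 versicolor and 50 virginica, which split into 49 + 5 = 54 and 1 + 45 = 46.

```python
def gini(counts):
    """Gini impurity G = 1 - sum_k p_k**2 of a node with the given class counts."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)


# class counts [setosa, versicolor, virginica] per node (derived as described above)
nodes = {
    "root (depth 0)": [50, 50, 50],
    "setosa leaf (depth 1)": [50, 0, 0],
    "split node (depth 1)": [0, 50, 50],
    "versicolor leaf (depth 2)": [0, 49, 5],
    "virginica leaf (depth 2)": [0, 1, 45],
}
for name, counts in nodes.items():
    print(f"{name:26s} samples={sum(counts):3d} gini={gini(counts):.3f}")
```

The versicolor leaf reproduces the 0.168 computed above, and the pure setosa leaf gets a score of 0.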

Decision tree boundary plot

Rendering the tree with export_graphviz is rather cumbersome, and the diagram does not immediately show how well the tree separates the classes. A decision boundary plot does: it shades the region the tree assigns to each class and overlays the training samples. The code is below.

plt.figure(figsize=(8, 4))
plot_decision_boundary(tree_clf, X, y)

# draw the two splits by hand, using the thresholds shown in the tree plot
plt.plot([2.45, 2.45], [0, 3], "k-", linewidth=2)        # depth 0: petal length = 2.45
plt.plot([2.45, 7.5], [1.75, 1.75], "k--", linewidth=2)  # depth 1: petal width = 1.75
plt.text(1.30, 1.0, "Depth=0", fontsize=15)
plt.text(3.2, 1.85, "Depth=1", fontsize=13)
plt.legend()

plt.show()

The thick vertical line is the decision boundary of the root node at depth 0 (petal length <= 2.45 cm). The area to its left is pure, containing only Iris setosa, so it cannot be split further. The area to its right is impure, so the depth-1 node splits it at a petal width of 1.75 cm (the dashed line). Because max_depth=2, the tree stops there.

Predicting classes and class probabilities

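Suppose we find a flower with a petal length of 6 cm and a petal width of 1.5 cm. Its petal length is greater than 2.45 cm and its petal width is at most 1.75 cm, so it lands in the versicolor leaf and will be classified as Iris versicolor. predict_proba() returns the estimated probability of each class, in the order setosa, versicolor, virginica: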

tree_clf.predict_proba([[6, 1.5]])

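Indeed, the middle class, Iris versicolor, has the highest probability (49/54 ≈ 0.907). The remaining 5/54 ≈ 0.093 goes to virginica and 0 to setosa: these are exactly the class fractions of the training samples in that leaf.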
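A decision tree computes no new quantity at prediction time: it walks down to a leaf and returns the fraction of training samples of each class in that leaf. A minimal sketch with NumPy, assuming the leaf counts [0, 49, 5] from the tree diagram (the helper name leaf_proba is hypothetical, not part of scikit-learn):

```python
import numpy as np


def leaf_proba(leaf_counts):
    """Class probabilities of a leaf: class counts divided by the leaf's sample count."""
    counts = np.asarray(leaf_counts, dtype=float)
    return counts / counts.sum()


# a flower with petal length 6 cm and width 1.5 cm ends in the versicolor leaf
proba = leaf_proba([0, 49, 5])  # order: setosa, versicolor, virginica
print(proba.round(3))
```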

If we ask the tree to predict the class directly, it returns class 1, which is Iris versicolor:

tree_clf.predict([[6, 1.5]])
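predict() simply picks the class with the highest probability. The sketch below shows that last step with NumPy; the probabilities are hard-coded from the versicolor leaf (not taken from tree_clf), and target_names is assumed to match the order of iris.target_names.

```python
import numpy as np

# class names in the order of iris.target_names (assumed)
target_names = np.array(["setosa", "versicolor", "virginica"])
# probabilities for one flower (one row per flower), hard-coded from the leaf counts
proba = np.array([[0.0, 49 / 54, 5 / 54]])

class_index = proba.argmax(axis=1)  # index of the most probable class in each row
print(class_index)                  # the class index, as returned by predict()
print(target_names[class_index])    # map the index to a readable class name
```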