Visualize a Decision Tree

  • Decision tree algorithms split the data on feature values and grow the tree one split at a time. Let’s visualize a decision tree to get a better understanding of how decision trees work.
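To make the idea of "splits on feature values" concrete, here is a minimal sketch of a tree stored as nested dictionaries and a function that walks it. The feature names are borrowed from the iris dataset, but the thresholds, the `toy_tree` structure and the `predict` helper are hypothetical and chosen only for illustration; they are not taken from a fitted model or from scikit-learn's internals.

```python
# Hypothetical toy tree: thresholds are illustrative, not learned from data.
# Internal nodes test one feature against a threshold; leaves hold a class label.
toy_tree = {
    "feature": "petal length",
    "threshold": 2.5,
    "left": {"label": "setosa"},            # petal length <= 2.5
    "right": {
        "feature": "petal width",
        "threshold": 1.7,
        "left": {"label": "versicolor"},    # petal width <= 1.7
        "right": {"label": "virginica"},    # petal width > 1.7
    },
}


def predict(node, sample):
    """Walk from the root to a leaf, going left when the value is <= threshold."""
    while "label" not in node:
        go_left = sample[node["feature"]] <= node["threshold"]
        node = node["left"] if go_left else node["right"]
    return node["label"]


sample = {"sepal length": 5.1, "sepal width": 3.2,
          "petal length": 1.5, "petal width": 0.5}
print(predict(toy_tree, sample))  # setosa
```

Each prediction follows a single path from the root to a leaf, which is why a drawn tree can be read as a chain of yes/no questions.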

We will build a decision tree model on the iris dataset, which has four features: sepal length, sepal width, petal length and petal width. It has three target classes: setosa, versicolor and virginica.
# Imports
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pandas as pd
import pydotplus
import numpy as np
from IPython.display import Image  

# Load Data
iris = load_iris()

# Create a dataframe
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df['target'] = iris.target
# Let's see target names
targets = iris.target_names
print(targets)
['setosa' 'versicolor' 'virginica']
# Prepare training data for building the model
X_train = df.drop(['target'], axis=1)
y_train = df['target']

# Instantiate the model
cls = DecisionTreeClassifier()

# Train/Fit the model
cls.fit(X_train, y_train)

# Make prediction using the model
X_pred = [5.1, 3.2, 1.5, 0.5]
y_pred = cls.predict([X_pred])

print("Prediction is: {}".format(targets[y_pred]))
Prediction is: ['setosa']
# Create DOT data
dot_data = export_graphviz(cls, out_file=None,
                           class_names=targets, filled=True, rounded=True)

# Draw graph
graph = pydotplus.graph_from_dot_data(dot_data)  

# Show graph
Image(graph.create_png())

As we can see, the decision tree algorithm creates splits on the basis of feature values and keeps growing the tree until every leaf reaches a clear decision. To learn more about how splitting happens and how the best split is selected, check out the article on entropy and the Gini index.
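As a rough illustration of how a candidate split can be scored, the sketch below computes Gini impurity and entropy for a list of labels and compares two splits by their weighted impurity. The helper names (`gini`, `entropy`, `weighted_impurity`) and the label lists are hypothetical; this is a simplified sketch of the idea, not scikit-learn's implementation.

```python
from collections import Counter
from math import log2


def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    n = len(labels)
    return 1.0 - sum((count / n) ** 2 for count in Counter(labels).values())


def entropy(labels):
    """Shannon entropy (in bits) of the class proportions."""
    n = len(labels)
    return -sum((count / n) * log2(count / n) for count in Counter(labels).values())


def weighted_impurity(left, right, impurity):
    """Impurity of a split: child impurities weighted by child size."""
    n = len(left) + len(right)
    return len(left) / n * impurity(left) + len(right) / n * impurity(right)


# Hypothetical labels reaching one node: three setosa, three versicolor.
parent = ["setosa"] * 3 + ["versicolor"] * 3

# Candidate split A separates the classes perfectly.
clean_left, clean_right = parent[:3], parent[3:]

# Candidate split B leaves both children mixed.
mixed_left = ["setosa", "setosa", "versicolor"]
mixed_right = ["setosa", "versicolor", "versicolor"]

print("parent gini:", gini(parent))
print("parent entropy:", entropy(parent))
print("split A gini:", weighted_impurity(clean_left, clean_right, gini))
print("split B gini:", weighted_impurity(mixed_left, mixed_right, gini))
```

A lower weighted impurity means purer children, so under this scoring split A would be preferred over split B. `DecisionTreeClassifier` uses the Gini criterion by default and accepts `criterion="entropy"` as an alternative.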

That's how we visualize a decision tree.

That’s all for this mini tutorial. To sum up, we learned how to train a decision tree on the iris dataset and visualize it.

Hope it was easy, cool and simple to follow. Now it’s over to you.