Visualize a Decision Tree
Decision tree algorithms split the data on feature values and grow the tree one split at a time. Let's visualize a decision tree to get a better understanding of how decision trees work.
We build a decision tree model on the iris dataset, which has four features: sepal length, sepal width, petal length and petal width. It has three target classes: setosa, versicolor and virginica.
# Imports
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz
import pandas as pd
import pydotplus
import numpy as np
from IPython.display import Image
# Load Data
iris = load_iris()
# Create a dataframe
df = pd.DataFrame(iris.data, columns = iris.feature_names)
df['target'] = iris.target
# Let's see target names
targets = iris.target_names
print(targets)
['setosa' 'versicolor' 'virginica']
# Prepare training data for building the model
X_train = df.drop(['target'], axis=1)
y_train = df['target']
# Instantiate the model
cls = DecisionTreeClassifier()
# Train/Fit the model
cls.fit(X_train, y_train)
# Make prediction using the model
X_pred = [5.1, 3.2, 1.5, 0.5]
y_pred = cls.predict([X_pred])
print("Prediction is: {}".format(targets[y_pred]))
Prediction is: ['setosa']
# Create DOT data
# export_graphviz is imported directly above, so call it without the "tree." prefix
dot_data = export_graphviz(cls, out_file=None,
                           feature_names=iris.feature_names,
                           class_names=targets,
                           filled=True, rounded=True,
                           special_characters=True)
# Draw graph
graph = pydotplus.graph_from_dot_data(dot_data)
# Show graph
Image(graph.create_png())
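If Graphviz or pydotplus is not available, the same tree can also be inspected as plain text. The following is a minimal sketch, not part of the original walkthrough: it assumes scikit-learn 0.21 or later (where `export_text` is available), fits on the raw arrays instead of the dataframe, and uses the name `clf` and `random_state=0` purely for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

# Load the data and fit a tree (random_state is an assumption, used only
# so that repeated runs give the same tree)
iris = load_iris()
clf = DecisionTreeClassifier(random_state=0)
clf.fit(iris.data, iris.target)

# Render the learned splits as indented text rules
rules = export_text(clf, feature_names=list(iris.feature_names))
print(rules)
```

Each indented line is one split condition, and the lines starting with `class:` are leaves showing the predicted class index.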
As we can see, the decision tree algorithm creates splits on feature values and keeps growing the tree until every leaf is pure (contains samples of a single class) or cannot be split any further. This is the behavior with the default settings used here, where no depth limit is set. To learn more about how splitting happens and how the best split is selected, check out the article on entropy and Gini index.
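To make the "gini" value shown in each node of the graph more concrete, here is a small sketch that recomputes the Gini impurity of the root node by hand and compares it with what the fitted tree stores. The helper `gini`, the variable names and `random_state=0` are illustrative assumptions, not part of the original code.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier


def gini(class_counts):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    p = np.asarray(class_counts, dtype=float)
    p = p / p.sum()
    return 1.0 - np.sum(p ** 2)


iris = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
t = clf.tree_  # low-level tree structure stored by scikit-learn

# The root node sees all 150 samples: 50 of each class
root_counts = np.bincount(iris.target)
root_gini_manual = gini(root_counts)   # 1 - 3 * (1/3)^2 = 0.667
root_gini_sklearn = t.impurity[0]      # impurity stored for node 0 (the root)

# Leaves are the nodes without children (children_left == -1)
is_leaf = t.children_left == -1
n_leaves = int(is_leaf.sum())

print("Root gini (manual):  {:.3f}".format(root_gini_manual))
print("Root gini (sklearn): {:.3f}".format(root_gini_sklearn))
print("Nodes: {}, leaves: {}".format(t.node_count, n_leaves))
```

A split is useful when the weighted impurity of the child nodes is lower than the impurity of the parent, which is why the gini values in the graph tend to shrink from the root towards the leaves.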
That's all for this mini tutorial. To sum it up, we learned how to visualize a decision tree.
Hope it was easy, cool and simple to follow. Now it's over to you.