Decision Tree Classifier - Machine Learning

Decision Tree Classifier

Preliminaries

# Load libraries
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

Load Iris Dataset

# Load the Iris dataset bundled with scikit-learn, then pull out the
# feature matrix (X) and the class labels (y)
iris = datasets.load_iris()
X, y = iris.data, iris.target

Create Decision Tree Using Gini Impurity

# Create a decision tree classifier object that uses the Gini impurity
# criterion; random_state is fixed at 0 so repeated runs are reproducible
clf = DecisionTreeClassifier(
    criterion='gini',
    random_state=0,
)

Train Model

# Train the model: fit the classifier on the feature matrix X and labels y.
# The fitted estimator returned by fit() is kept as `model` and used for the
# predictions below.
model = clf.fit(X, y)

Create Observation To Predict

# Make a new observation: one sample with four feature values, nested in an
# outer list so the input is a single-row, two-dimensional structure
observation = [[5, 4, 3, 2]]

Predict Observation

# Predict the class label of the new observation
model.predict(observation)
array([1])

View Predicted Probabilities

# View the predicted probability of each of the three classes for the
# new observation
model.predict_proba(observation)
array([[ 0.,  1.,  0.]])