scikit-learn: Decision Tree split looks wrong
Description
When looking for the optimal split on a single feature (passenger age), a depth-1 DecisionTreeClassifier picks a threshold that seems wrong compared with a Gini impurity computed by hand.
Steps/Code to Reproduce
Raw dataset used: https://github.com/juan-barragan/titanic/blob/master/titanic_ds.csv
```python
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import tree
import graphviz
import numpy as np


def gini_by_age(df, t):
    """Sum of the Gini impurities of the groups age <= t and age > t."""
    # Group 0: passengers with age <= t ("kids"); group 1: everyone else
    age_group = df['age'].apply(lambda age: 0 if age <= t else 1)
    kids = df[age_group == 0]
    kids0 = kids[kids['survived'] == 0]
    kids1 = kids[kids['survived'] == 1]
    adults = df[age_group == 1]
    adults0 = adults[adults['survived'] == 0]
    adults1 = adults[adults['survived'] == 1]
    # Gini impurity of each group: 1 - (p_died**2 + p_survived**2)
    gk = 1 - (len(kids0)**2 + len(kids1)**2) / float(len(kids))**2
    ga = 1 - (len(adults0)**2 + len(adults1)**2) / float(len(adults))**2
    # Unweighted sum of the two group impurities
    return gk + ga


def plot_gini_by_age(df):
    # Evaluate gini_by_age for integer thresholds 6..24 and plot the curve
    ages = range(6, 25)
    y = [gini_by_age(df, a) for a in ages]
    plt.plot(ages, y)
    plt.show()


def use_tree(df):
    # Fit a single-split tree on age alone; X must be 2-D (n_samples, 1)
    X = np.array(df['age']).reshape(-1, 1)
    y = df['survived']
    clf = tree.DecisionTreeClassifier(max_depth=1).fit(X, y)
    # Write the fitted tree to disk via Graphviz (file name "age")
    dot_data = tree.export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    graph.render("age")


titanic_df = pd.read_csv("titanic_ds.csv")
ages_cov = titanic_df[['age', 'survived']].dropna()
plot_gini_by_age(ages_cov)
use_tree(ages_cov)
# print(x) with a single argument works under both Python 2 and Python 3
for t in (5, 6, 8.5, 15):
    print(gini_by_age(ages_cov, t))
```
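To check the hand-computed numbers against what a tree search optimizes, here is a simplified pure-Python sketch of a CART-style Gini search on one numeric feature. It assumes, as in scikit-learn's splitter, that candidate thresholds are midpoints between consecutive distinct sorted values and that samples with value <= t go left; each threshold is scored both by the unweighted sum that `gini_by_age` returns and by the sample-weighted child impurity that the CART criterion minimizes. It is not scikit-learn's implementation: the helper names (`gini`, `candidate_thresholds`, `split_scores`, `best_threshold`) and the eight-passenger toy data are hypothetical, and options such as sample weights or `min_samples_leaf` are ignored.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def gini(labels: Sequence[int]) -> float:
    """Gini impurity 1 - sum_c p_c**2 of a group of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def candidate_thresholds(values: Sequence[float]) -> List[float]:
    """Midpoints between consecutive distinct sorted values."""
    distinct = sorted(set(values))
    return [(a + b) / 2.0 for a, b in zip(distinct, distinct[1:])]


def split_scores(values: Sequence[float], labels: Sequence[int],
                 t: float) -> Tuple[float, float]:
    """Return (unweighted, weighted) child impurity for the split value <= t."""
    left = [y for x, y in zip(values, labels) if x <= t]
    right = [y for x, y in zip(values, labels) if x > t]
    n = len(labels)
    g_left, g_right = gini(left), gini(right)
    unweighted = g_left + g_right  # same quantity gini_by_age returns
    weighted = len(left) / n * g_left + len(right) / n * g_right
    return unweighted, weighted


def best_threshold(values: Sequence[float], labels: Sequence[int],
                   weighted: bool = True) -> float:
    """Threshold with the lowest score; the first one wins ties."""
    best_t, best_score = float("nan"), float("inf")
    for t in candidate_thresholds(values):
        score = split_scores(values, labels, t)[1 if weighted else 0]
        if score < best_score:
            best_t, best_score = t, score
    return best_t


# Hypothetical toy data: ages 1..8 and their survived flags
ages = [1, 2, 3, 4, 5, 6, 7, 8]
survived = [1, 0, 1, 0, 0, 0, 0, 0]

print(best_threshold(ages, survived, weighted=False))  # 1.5
print(best_threshold(ages, survived, weighted=True))   # 3.5
```

On this toy data the two scores disagree: the unweighted sum is lowest when a single pure passenger is split off (threshold 1.5), while the weighted score prefers the split that leaves a pure group of five non-survivors (threshold 3.5).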
Expected Results
The tree should split at an age of around 5, where the hand-computed Gini value is lowest.
Actual Results
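The tree splits at age <= 8.5 instead. The four printed values of gini_by_age are:
- t = 5: 0.925844132419
- t = 6: 0.93508349655
- t = 8.5: 0.937732003001
- t = 15: 0.963875889772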
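The values above are the plain sum of the two child impurities, which is exactly what `gini_by_age` computes. The standard CART Gini criterion instead weights each child by its share of the samples. Writing $n_L$ and $n_R$ for the number of passengers with age $\le t$ and age $> t$, $n = n_L + n_R$, and $p_{L,c}$, $p_{R,c}$ for the fraction of each group with `survived` $= c$:

```math
\begin{aligned}
G_L &= 1 - \left(p_{L,0}^2 + p_{L,1}^2\right), \qquad G_R = 1 - \left(p_{R,0}^2 + p_{R,1}^2\right) \\
\texttt{gini\_by\_age}(t) &= G_L + G_R \\
G_{\mathrm{w}}(t) &= \frac{n_L}{n}\,G_L + \frac{n_R}{n}\,G_R
\end{aligned}
```

Because the unweighted sum gives a small group the same influence as a large one, the two quantities need not reach their minimum at the same threshold, so comparing the tree's threshold with the minimum of `gini_by_age` may not be a like-for-like comparison.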
Versions
- Linux-4.4.0-21-generic-x86_64-with-LinuxMint-18-sarah
- Python 2.7.12 (default, Nov 20 2017, 18:23:56) [GCC 5.4.0 20160609]
- NumPy 1.11.2
- SciPy 0.18.1
- Scikit-Learn 0.18
About this issue
- State: closed
- Created 6 years ago
- Comments: 16 (11 by maintainers)
Could you edit your post to fix the indentation and use triple back-quotes (aka fenced code blocks) to format your snippet? Bonus points if you use syntax highlighting with `py` for Python snippets. Also, can you try with the latest released version of scikit-learn (0.19.1) to see whether the problem still exists?