scikit-learn: Decision Tree split looks wrong
Description
When searching for the optimal split on a single feature (age), DecisionTreeClassifier returns a threshold that seems wrong.
Steps/Code to Reproduce
Raw dataset used: https://github.com/juan-barragan/titanic/blob/master/titanic_ds.csv
import pandas as pd
import seaborn
import matplotlib.pyplot as plt
from sklearn import tree
import graphviz
import numpy as np

def gini_by_age(df, t):
    # Group 0: age <= t ("kids"); group 1: age > t ("adults").
    df['age_group'] = df['age'].apply(lambda row : 0 if row <= t else 1)
    kids = df[df['age_group'] == 0]
    kids0 = kids[kids['survived'] == 0]
    kids1 = kids[kids['survived'] == 1]
    adults = df[df['age_group'] == 1]
    adults0 = adults[adults['survived'] == 0]
    adults1 = adults[adults['survived'] == 1]
    # Gini impurity of each group.
    gk = 1 - (len(kids0)**2 + len(kids1)**2)/float(len(kids))**2
    ga = 1 - (len(adults0)**2 + len(adults1)**2)/float(len(adults))**2
    # Unweighted sum of the two groups' impurities.
    return gk + ga

def plot_gini_by_age(df):
    ages = range(6,25)
    y = [gini_by_age(df, a) for a in ages]
    plt.plot(ages, y)
    plt.show()

def use_tree(df):
    # Fit a depth-1 tree on age alone and render it to the file "age".
    X = np.array(df['age']).reshape((len(df['age']),1))
    y = df['survived']
    clf = tree.DecisionTreeClassifier(max_depth=1).fit(X,y)
    dot_data = tree.export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    graph.render("age")

titanic_df = pd.read_csv("titanic_ds.csv")
ages_cov = titanic_df[['age', 'survived']].dropna()
plot_gini_by_age(ages_cov)
use_tree(ages_cov)
print(gini_by_age(ages_cov, 5))
print(gini_by_age(ages_cov, 6))
print(gini_by_age(ages_cov, 8.5))
print(gini_by_age(ages_cov, 15))
Expected Results
The Gini split should be at an age threshold of around 5.
Actual Results
Got 8.5 instead. The four prints (t = 5, 6, 8.5, 15) output:
0.925844132419
0.93508349655
0.937732003001
0.963875889772
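One thing that may be worth checking before treating this as a bug: gini_by_age adds the two groups' impurities without weights, whereas the tree's split search minimizes the impurity of the two children weighted by the share of samples in each child. The two quantities can be minimized at different thresholds. The sketch below computes the weighted version; the helper names gini and weighted_gini and the toy data are made up for illustration and are not part of scikit-learn or of the Titanic file.

```python
import numpy as np


def gini(labels):
    """Gini impurity of a 1-D array of class labels (0.0 for an empty node)."""
    labels = np.asarray(labels)
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / float(labels.size)
    return 1.0 - np.sum(p ** 2)


def weighted_gini(values, labels, threshold):
    """Impurity of the split `values <= threshold`, weighted by child size."""
    values = np.asarray(values, dtype=float)
    labels = np.asarray(labels)
    left = labels[values <= threshold]
    right = labels[values > threshold]
    n = float(labels.size)
    return (left.size / n) * gini(left) + (right.size / n) * gini(right)


# Toy data, invented for illustration only (not taken from titanic_ds.csv).
toy_age = [2, 4, 30, 40, 50, 60]
toy_survived = [1, 1, 0, 1, 0, 0]
print(weighted_gini(toy_age, toy_survived, 5))
```

To compare against the tree on the real data, something like weighted_gini(ages_cov['age'], ages_cov['survived'], t) could be evaluated for t = 5, 6, 8.5 and 15 in place of gini_by_age.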
Versions
Linux-4.4.0-21-generic-x86_64-with-LinuxMint-18-sarah
('Python', '2.7.12 (default, Nov 20 2017, 18:23:56) \n[GCC 5.4.0 20160609]')
('NumPy', '1.11.2')
('SciPy', '0.18.1')
('Scikit-Learn', '0.18')
About this issue
- State: closed
- Created 6 years ago
- Comments: 16 (11 by maintainers)
Could you edit your post to fix the indentation and use triple back-quotes (aka fenced code blocks) to format your snippet? Bonus points if you use syntax highlighting with `py` for Python snippets.
Also, can you try with the latest released version of scikit-learn (0.19.1) to see whether the problem still exists?