
##################################################
### import

### basic 
import matplotlib.pyplot as plt
import numpy as np
import scipy
import pandas as pd
import math

import seaborn as sns; sns.set()
#%matplotlib inline

##sklearn learners
from sklearn.tree import DecisionTreeRegressor 

##sklearn metrics
from sklearn.metrics import mean_squared_error

##sklearn model selection
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import validation_curve
from sklearn.model_selection import GridSearchCV

## to visualize a tree
import pydotplus
from sklearn import tree
import os

##################################################
### read in the Boston housing data
### (older scikit-learn releases also shipped it as sklearn.datasets.load_boston)
#bd = pd.read_csv("http://www.rob-mcculloch.org/data/Boston.csv")
boston_url = "https://bitbucket.org/remcc/rob-data-sets/downloads/Boston.csv"
bd = pd.read_csv(boston_url)

## features: the columns at positions 7 and 12 (intended to be dis and lstat)
xpdf = bd.iloc[:, [7, 12]]
## numpy versions of the features and of the response medv
x, y = xpdf.to_numpy(), bd['medv'].to_numpy()

## no need to standardize the x's: tree splits depend only on the ordering of each feature's values, not on its scale

##################################################
### simple decision tree

# tree with at most 10 bottom nodes (leaves)
tmod = DecisionTreeRegressor(max_leaf_nodes=10)
tmod.fit(x,y)

## look at in-sample fits
yhat = tmod.predict(x)

## fitted vs. observed, with the 45-degree line (yhat = y) in red for reference
plt.scatter(y,yhat,c='blue')
plt.xlabel('y'); plt.ylabel('yhat')
plt.plot(y,y,c='red')
plt.show()

## report the leaf count from the fitted tree itself: counting the distinct
## fitted values would undercount whenever two leaves share the same mean
print("number of bottom  nodes: ",tmod.get_n_leaves())

##################################################
### variable importance
## only two variables here, but for future reference this is how to read the importances
varimp = tmod.feature_importances_
print('variable importances:', varimp)

## the same numbers, labelled with the column names of the feature data frame
feature_labels = xpdf.columns.values
labelled_importances = pd.Series(varimp, index=feature_labels)
print(labelled_importances)

##################################################
### plot a tree
### this is a bit kludgy, but the code below worked on my linux system.

## export the fitted tree as graphviz dot source, labelling the splits with the column names
dot_data = tree.export_graphviz(
    tmod,
    out_file=None,
    feature_names=xpdf.columns.values,
)
## turn the dot source into a graph object and write it to the file tree.png
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png("tree.png")
## display the file by running the unix command eog through the shell
os.system('eog tree.png')


