##################################################
##  imports
import numpy as np
import pandas as pd

from keras import models
from keras import layers
from keras import regularizers

from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

import matplotlib.pyplot as plt
import seaborn; seaborn.set()

##################################################
### data
# Load the MNIST train/test splits; each row holds the 784 pixel values
# with the digit label in the last column.
trainD = pd.read_csv("http://www.rob-mcculloch.org/data/mnist-train.csv")
trD = trainD.to_numpy()
testD = pd.read_csv("http://www.rob-mcculloch.org/data/mnist-test.csv")
teD = testD.to_numpy()

# split label (last column) from the pixel features
ytr = trD[:,-1]; xtr = trD[:,0:-1]
yte = teD[:,-1]; xte = teD[:,0:-1]
nx = xtr.shape[1]  # number of pixel features

# distribution of y
print(pd.Series(ytr).value_counts())
# range of x (pixels appear to be raw intensities;
# NOTE(review): consider scaling to [0,1] before fitting — confirm intent)
print('min x:',xtr.min())
print('max x:',xtr.max())

# one-hot encode y: fit on the training labels, then reuse the SAME fitted
# encoder on the test labels. Bug fix: the original called fit_transform on
# the test set too, which re-fits the encoder and could yield a different
# class ordering if the label sets differed.
one_hot = LabelBinarizer()
ytrD = one_hot.fit_transform(ytr)
yteD = one_hot.transform(yte)

##################################################
### keras

nmod = models.Sequential()

# L1 penalty on the hidden-layer weights
# (a much smaller value such as 1e-6 was also tried)
l1_pen = 1e-4

# two hidden layers, 200 tanh units each, with L1 weight regularization;
# only the first layer declares the input shape
for layer_ix in range(2):
    extra = {'input_shape': (nx,)} if layer_ix == 0 else {}
    nmod.add(layers.Dense(units=200, activation='tanh',
                          kernel_regularizer=regularizers.l1(l1_pen),
                          **extra))

# output layer: softmax over the 10 digit classes (multinomial output)
nmod.add(layers.Dense(units=10, activation='softmax'))

# multinomial cross-entropy loss with the adam optimizer
# ('rmsprop' was the alternative tried)
nmod.compile(loss='categorical_crossentropy',
             optimizer='adam',
             metrics=['accuracy'])

# fit, tracking the held-out test set as validation data each epoch
nhist = nmod.fit(xtr, ytrD, epochs=100, verbose=1, batch_size=100,
                 validation_data=(xte, yteD))


##################################################
### plot
# training vs. validation cross-entropy by epoch
train_loss = nhist.history['loss']
val_loss = nhist.history['val_loss']

epochs = range(1, len(train_loss) + 1)
plt.plot(epochs, train_loss, "r--")
plt.plot(epochs, val_loss, "b--")
plt.xlabel("epoch number")
plt.ylabel("cross entropy loss")
plt.legend(['train', 'test'])
#plt.show()
plt.savefig("keras_simple-mnist.pdf")

# predicted class = argmax of the softmax probabilities, row by row
phat = nmod.predict(xte)
yhat = np.argmax(phat, axis=1)  # vectorized; was a Python list comprehension

## cross-tabulation of predicted (rows) vs. actual (columns) digits
ctab = pd.crosstab(pd.Series(yhat), pd.Series(yte))
print(ctab)

# accuracy is symmetric in its arguments, but (y_true, y_pred) is the
# documented sklearn order — value is unchanged
print("the accuracy is: ", accuracy_score(yte, yhat))
#print(confusion_matrix(yte, yhat))

