##################################################
##################################################
## chapter 2 of Deep Learning with Python, 2nd Ed, Chollet

##################################################
## imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt


from tensorflow.keras.datasets import mnist
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import regularizers

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

##################################################
## look at data
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

## check basic dimensions (images should be (n, 28, 28), labels (n,))
print(train_images.shape)
print(train_labels.shape)
print(test_images.shape)
print(test_labels.shape)
print(train_labels)
print(test_labels)

ntr = len(train_labels)
nte = len(test_labels)
print(f'the number of train and test are: {ntr}, {nte}')

## check the class balance: fraction of each digit, listed in digit order
print(pd.Series(train_labels).value_counts(normalize=True).sort_index())
print(pd.Series(test_labels).value_counts(normalize=True).sort_index())

## plot the training image at index 4 (i.e. the fifth digit), titled with its label
digit = train_images[4]
plt.imshow(digit, cmap=plt.cm.binary)
plt.title(f'label: {train_labels[4]}')
plt.show()

## check the pixels are in [0,255] (raw uint8 intensities before rescaling)
print(train_images.min())
print(train_images.max())

##################################################
## preparing the image data

## flatten each 28x28 image into a single 28^2 = 784 vector (-1 lets numpy
## infer the length, so this also works for other image sizes), then cast
## to float32 (single precision, not double) and rescale pixels from
## [0, 255] to [0, 1]

train_images = train_images.reshape((ntr, -1)).astype("float32") / 255
test_images = test_images.reshape((nte, -1)).astype("float32") / 255

##################################################
## model

# Two-layer dense network: one 512-unit ReLU hidden layer, then a 10-way
# softmax output with one probability per digit class.
model = keras.Sequential()
model.add(layers.Dense(512, activation="relu"))
# Alternative hidden layer with an L1 weight penalty (swap with the line above to try it):
#model.add(layers.Dense(512, activation="relu", kernel_regularizer=regularizers.l1(0.1)))
model.add(layers.Dense(10, activation="softmax"))

##################################################
## compilation step

# RMSprop optimizer; the "sparse" cross-entropy variant takes the integer
# digit labels directly, so no one-hot encoding is needed.
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)


##################################################
## fit
## The test set is passed as validation_data so Keras reports val_loss /
## val_accuracy after every epoch; it is only evaluated, not trained on.

nhist = model.fit(
    train_images,
    train_labels,
    epochs=10,
    batch_size=128,
    validation_data=(test_images, test_labels),
)

##################################################
## predictions
## only predict the first 10 test images

test_digits = test_images[:10]
predictions = model.predict(test_digits)

# one row of class probabilities per test image: inspect the container
for info in (type(predictions), predictions.shape, predictions.dtype):
    print(info)

## softmax rows should sum to 1 (up to float32 rounding)
print(predictions.sum(axis=1))

# predicted digit = column holding the largest probability
ypred = np.argmax(predictions, axis=1)
print(ypred[:2])
print(predictions[:2])

##################################################
## Evaluating the model on new data

# evaluate() returns [loss, accuracy] because "accuracy" was the only compiled metric
eval_scores = model.evaluate(test_images, test_labels)
test_loss, test_acc = eval_scores
print(f'test acc: {test_acc}')

##################################################
### plot per-epoch loss curves (the model was trained with RMSprop, see compile step)
trL = nhist.history['loss']      # loss on the training data
teL = nhist.history['val_loss']  # loss on validation_data, which here is the test set

epind = range(1, len(teL) + 1)
# labels are attached to each curve so the legend cannot drift out of order
plt.plot(epind, trL, "r--", label='train')
plt.plot(epind, teL, color='black', linestyle='dashed', label='test')
plt.xlabel("epoch number")
plt.ylabel("cross entropy loss")
plt.legend()
# To save the figure, uncomment this BEFORE plt.show(): saving after show()
# typically writes a blank file because the figure has already been closed.
#plt.savefig("keras_simple-mnist.pdf")
plt.show()

##################################################
### predictions on all test data
ppred = model.predict(test_images)
ypred = ppred.argmax(axis=1)  # predicted digit for every test image

## crosstab: rows = true digit, columns = predicted digit
ctab = pd.crosstab(test_labels, ypred, rownames=['true'], colnames=['predicted'])
print(ctab)

# sklearn metrics take (y_true, y_pred); passing them by keyword keeps the
# confusion matrix in the same orientation as ctab (rows = true digit).
print("the accuracy is: ", accuracy_score(y_true=test_labels, y_pred=ypred))
print(confusion_matrix(y_true=test_labels, y_pred=ypred))

