dpl <- TRUE     # TRUE: each plotting section opens a pdf file via if(dpl) pdf(...)
#dpl <- FALSE   # FALSE: plots go to the current graphics device instead
source("notes-funs.R")  # project helpers; presumably defines printfl() used below

##################################################
if (1) {
  cat("### get data\n")
  library(ElemStatLearn)

  print(names(mixture.example)) # mixture example from ElemStatLearn

  x <- mixture.example$x       # n x 2 matrix of [x1, x2] (training inputs)
  y <- mixture.example$y       # training labels
  xnew <- mixture.example$xnew # test grid points
  px1 <- mixture.example$px1   # grid values along x1
  px2 <- mixture.example$px2   # grid values along x2

  # xnew is expected to be expand.grid(px1, px2), with x1 varying fastest.
  # fitplot() and cplot() reshape grid predictions into a
  # length(px1) x length(px2) matrix, so both coordinates AND the row order matter.
  cat("these should be the same\n")
  print(summary(sort(unique(xnew[, 1])) - px1))
  print(summary(sort(unique(xnew[, 2])) - px2))

  gridOK <- isTRUE(all.equal(unname(as.matrix(xnew)),
                             unname(as.matrix(expand.grid(px1, px2)))))
  cat("xnew identical to expand.grid(px1,px2):", gridOK, "\n")
  if (!gridOK) {
    warning("xnew is not expand.grid(px1, px2); grid-based plots will be wrong",
            call. = FALSE)
  }
}
##################################################
if (1) {
  cat("### plot data\n")
  # Scatter of the training data, coral for class y == 1 and cornflowerblue otherwise.
  if (dpl) {
    pdf(file = "plot-xy-sim-esl.pdf", height = 10, width = 12)
  }
  plot(x,
       col = ifelse(y == 1, "coral", "cornflowerblue"),
       xlab = "x1", ylab = "x2",
       cex.axis = 1.5, cex.lab = 1.5,
       pch = 16, cex = 1.5)
  legend("topright",
         legend = c("y=1", "y=0"),
         col = c("coral", "cornflowerblue"),
         pch = c(16, 16), cex = 1.5)
  if (dpl) {
    dev.off()
  }
}

############################################################
# Fit a logistic regression of y on every other column of dtrain, predict on
# the test grid, and plot the 0.5-probability decision boundary together with
# the training points and the predicted class of each grid point.
#
# dtrain: data.frame of transformed training features plus a 0/1 column `y`
# xtest : data.frame of the same transformed features for the test grid
# x     : original (untransformed) training coordinates [x1, x2], for plotting
# xt    : original (untransformed) test-grid coordinates [x1, x2], for plotting.
#         Assumed to be in expand.grid(px1, px2) order (x1 varying fastest),
#         because the predictions are reshaped into a
#         length(px1) x length(px2) matrix for contour().
# Returns the vector of fitted probabilities on xtest.
fitplot <- function(dtrain, xtest, x, xt) {
  px1 <- sort(unique(xt[, 1]))
  px2 <- sort(unique(xt[, 2]))

  lrfit <- glm(y ~ ., data = dtrain, family = "binomial")
  phat <- predict(lrfit, xtest, type = "response")
  probMat <- matrix(phat, length(px1), length(px2))
  contour(px1, px2, probMat, levels = .5, xlab = "x1", ylab = "x2",
          cex.axis = 1.5, cex.lab = 1.5)
  # Colour training points by the labels the model was fit on (dtrain's y),
  # not by whatever `y` happens to exist in the calling environment.
  points(x, col = ifelse(dtrain[["y"]] == 1, "coral", "cornflowerblue"),
         pch = 16, cex = 1.5)
  # Grid points coloured by predicted class.
  points(xt, col = ifelse(phat > .5, "coral", "cornflowerblue"),
         pch = ".", cex = 3)
  phat
}

## Draw the 0.5-probability decision boundary for predictions made on a grid.
##
## phat    : fitted probabilities on expand.grid(px1, px2) (px1 varying
##           fastest), so they fill a length(px1) x length(px2) matrix
##           column by column
## px1, px2: grid values along x1 and x2
## Every grid point is also marked: coral where phat > .5, cornflowerblue otherwise.
cplot <- function(phat, px1, px2) {
  gridPts <- expand.grid(px1, px2)
  ptCols <- ifelse(phat > .5, "coral", "cornflowerblue")
  surface <- matrix(phat, nrow = length(px1), ncol = length(px2))

  contour(px1, px2, surface,
          levels = .5, xlab = "x1", ylab = "x2",
          cex.axis = 1.5, cex.lab = 1.5)
  points(gridPts, col = ptCols, pch = ".", cex = 3)
}
##################################################
if (1) {
  cat("### fit logits\n")

  if (dpl) pdf(file = "logit-fits.pdf", height = 8, width = 12)

  # 2 x 3 panel: one linear logit plus five polynomial logits of rising degree
  par(mfrow = c(2, 3))

  # linear logit on the raw coordinates
  dftrain <- data.frame(x1 = x[, 1], x2 = x[, 2], y)
  dftest <- data.frame(x1 = xnew[, 1], x2 = xnew[, 2])
  fitplot(dftrain, dftest, x, xnew)
  title(main = "linear", cex.main = 1.5)

  # Polynomial logits. poly(x1, x2, degree = d, raw = TRUE) returns every
  # monomial x1^i * x2^j with 1 <= i + j <= d, i.e. d * (d + 3) / 2 columns;
  # for d = 15 that is 15*14/2 + 2*15 = 135, and dim(dftest) is 6831 x 135.
  polyDegs <- c(2, 3, 7, 10, 15)
  polyTitles <- c("second order", "third order", "7th order",
                  "10th order", "15th order")
  for (k in seq_along(polyDegs)) {
    dftrain <- data.frame(x = poly(x[, 1], x[, 2],
                                   degree = polyDegs[k], raw = TRUE), y)
    dftest <- data.frame(x = poly(xnew[, 1], xnew[, 2],
                                  degree = polyDegs[k], raw = TRUE))
    fitplot(dftrain, dftest, x, xnew)
    title(main = polyTitles[k], cex.main = 1.5)
  }

  if (dpl) dev.off()
}

##################################################
if (1) {
  cat("### fit Lasso\n")

  if (dpl) pdf(file = "lasso-fit.pdf", height = 10, width = 12)

  # degree-15 raw polynomial features in (x1, x2) for training data and the grid
  degL <- 15
  dftrain <- data.frame(x = poly(x[, 1], x[, 2], degree = degL, raw = TRUE), y)
  dftest <- data.frame(x = poly(xnew[, 1], xnew[, 2], degree = degL, raw = TRUE))
  library(glmnet)
  # glmnet needs numeric matrices; y is the last column of dftrain, so drop it.
  bigXtrain <- as.matrix(dftrain[, -ncol(dftrain)])
  bigXtest <- as.matrix(dftest)

  # 10-fold CV over the lasso path (glmnet's default alpha = 1); the seed
  # fixes the random fold assignment
  set.seed(14)
  cvLfit <- cv.glmnet(x = bigXtrain, y, family = "binomial", nfolds = 10)

  par(mfrow = c(2, 2))
  plot(cvLfit)             # cross-validation curve
  plot(cvLfit$glmnet.fit)  # coefficient paths

  # grid probabilities at the CV-minimising lambda and at the 1-SE lambda
  probL <- predict(cvLfit$glmnet.fit, bigXtest, type = "response",
                   s = cvLfit$lambda.min)
  probL1 <- predict(cvLfit$glmnet.fit, bigXtest, type = "response",
                    s = cvLfit$lambda.1se)
  #plot(probL,probL1)
  #abline(0,1,col="red",lwd=2)
  #title(main="Lasso: min lambda phat vs. min1se",cex.main=1.5)

  cplot(probL, px1, px2)
  points(x, col = ifelse(y == 1, "coral", "cornflowerblue"), pch = 16, cex = 1.5)
  title(main = "Lasso: min lambda decision boundary", cex.main = 1.5)

  cplot(probL1, px1, px2)
  points(x, col = ifelse(y == 1, "coral", "cornflowerblue"), pch = 16, cex = 1.5)
  title(main = "Lasso: min lambda 1se decision boundary", cex.main = 1.5)

  if (dpl) dev.off()

  # Lasso coefficients; [, 1] turns the one-column sparse matrix into a named
  # vector. Only the non-zero (selected) terms are reported.
  bhatL <- coef(cvLfit$glmnet.fit, s = cvLfit$lambda.min)[, 1]
  nzL <- bhatL[bhatL != 0]
  cat("Lasso coefficients, min lambda\n")
  print(nzL)
  bhatL1 <- coef(cvLfit$glmnet.fit, s = cvLfit$lambda.1se)[, 1]
  nzL1 <- bhatL1[bhatL1 != 0]
  cat("Lasso coefficients, 1se lambda\n")
  print(nzL1)

  printfl(nzL, dpl, "lasso-coefs.rtxt")
  printfl(nzL1, dpl, "lasso1-coefs.rtxt")
}

############################################################
if (1) {
  cat("### do Ridge\n")
  # cv.glmnet() comes from glmnet; load it here as well so this section does
  # not depend on the Lasso section having run first (no-op if already loaded).
  library(glmnet)

  # degree-15 raw polynomial features in (x1, x2) for training data and the grid
  degR <- 15
  dftrain <- data.frame(x = poly(x[, 1], x[, 2], degree = degR, raw = TRUE), y)
  dftest <- data.frame(x = poly(xnew[, 1], xnew[, 2], degree = degR, raw = TRUE))
  bigXtrain <- as.matrix(dftrain[, -ncol(dftrain)])  # drop y (last column)
  bigXtest <- as.matrix(dftest)

  # Record the grid design dimensions. capture.output() always restores the
  # output stream, unlike a bare sink()/sink() pair if print() were to fail.
  writeLines(capture.output(print(dim(bigXtest))), "dimofx.txt")

  # 10-fold CV over the ridge path (alpha = 0); the seed fixes the folds
  set.seed(14)
  cvRfit <- cv.glmnet(x = bigXtrain, y, family = "binomial", nfolds = 10,
                      alpha = 0)

  if (dpl) pdf(file = "ridge-fit.pdf", height = 10, width = 12)
  par(mfrow = c(2, 2))
  plot(cvRfit)             # cross-validation curve
  plot(cvRfit$glmnet.fit)  # coefficient paths

  # grid probabilities at the CV-minimising lambda and at the 1-SE lambda
  probR <- predict(cvRfit$glmnet.fit, bigXtest, type = "response",
                   s = cvRfit$lambda.min)
  probR1 <- predict(cvRfit$glmnet.fit, bigXtest, type = "response",
                    s = cvRfit$lambda.1se)
  #plot(probR,probR1)
  #abline(0,1,col="red",lwd=2)
  #title(main="Ridge: min lambda phat vs. min1se",cex.main=1.5)

  cplot(probR, px1, px2)
  points(x, col = ifelse(y == 1, "coral", "cornflowerblue"), pch = 16, cex = 1.5)
  title(main = "Ridge: min lambda decision boundary", cex.main = 1.5)

  cplot(probR1, px1, px2)
  points(x, col = ifelse(y == 1, "coral", "cornflowerblue"), pch = 16, cex = 1.5)
  title(main = "Ridge: min lambda 1se decision boundary", cex.main = 1.5)
  if (dpl) dev.off()

  # Coefficients at lambda.min ranked by absolute size (signs are dropped);
  # the 20 largest are written out.
  coefR <- coef(cvRfit$glmnet.fit, s = cvRfit$lambda.min)[, 1]
  coefR <- sort(abs(coefR), decreasing = TRUE)
  print(coefR)
  printfl(head(coefR, 20), dpl, "ridge-coef-simex.rtxt")
}

##################################################
# When dpl is TRUE, remove every object from the environment the script runs
# in. NOTE(review): when that is the global workspace, this also deletes
# objects the user created before running the script.
if (dpl) {
  rm(list = ls())
}
