#####################################################################
### Compare Lasso, Ridge, and Elastic Net on the Diabetes Data
#####################################################################
# Output switch for the whole script:
#   FALSE -> draw every figure on the current (screen) graphics device;
#   TRUE  -> write each figure to its own PDF file and clear the workspace
#            at the end of the script.
dpl <- FALSE
library(glmnet)
source("do-stepcv.R") #has getfolds function
################################################
if (TRUE) { # set to FALSE to skip this section
  cat("### read in data\n")
  # Assumes diabetes.csv has a response column named "y"; every other column
  # is used as a predictor.
  ddf <- read.csv("diabetes.csv")
  if (!("y" %in% names(ddf))) {
    stop("diabetes.csv must contain a column named 'y'", call. = FALSE)
  }
  # [[ ]] matches the name exactly; `$` on a data.frame can partially match.
  y <- ddf[["y"]]
  # Select predictors by name rather than by position (previously columns
  # 2:ncol), so the response can never end up in x whatever its column index.
  # drop = FALSE keeps a matrix even when only one predictor remains.
  x <- as.matrix(ddf[, names(ddf) != "y", drop = FALSE])
}
################################################
if (TRUE) { # set to FALSE to skip this section
  cat("### make folds and run and plot\n")
  n <- length(y)
  nfold <- 10
  set.seed(99)
  # getfolds() is defined in do-stepcv.R; presumably it returns one fold id in
  # 1..nfold per observation (TODO confirm). The seed presumably makes that
  # assignment reproducible.
  fid1 <- getfolds(nfold, n)

  if (dpl) pdf(file = "plot-fold-id.pdf", height = 10, width = 14)
  plot(fid1, xlab = "obs number", ylab = "fold id",
       cex.axis = 1.5, cex.lab = 1.5)
  if (dpl) dev.off()

  # Fit lasso (alpha = 1), ridge (alpha = 0) and elastic net (alpha = 0.5) on
  # the same fold ids so the three CV curves are directly comparable.
  # nfolds follows nfold instead of repeating a literal 10.
  # NOTE(review): standardize = FALSE / intercept = FALSE presumably rely on
  # x being pre-standardized and y centred in diabetes.csv -- confirm.
  cvdgnL <- cv.glmnet(x, y, nfolds = nfold, foldid = fid1, family = "gaussian",
                      standardize = FALSE, alpha = 1, intercept = FALSE)
  cvdgnR <- cv.glmnet(x, y, nfolds = nfold, foldid = fid1, family = "gaussian",
                      standardize = FALSE, alpha = 0, intercept = FALSE)
  cvdgnE <- cv.glmnet(x, y, nfolds = nfold, foldid = fid1, family = "gaussian",
                      standardize = FALSE, alpha = 0.5, intercept = FALSE)

  # Top row: coefficient paths; bottom row: CV error curves.
  if (dpl) pdf(file = "cv-all-three-sep.pdf", height = 5, width = 10)
  par(mfrow = c(2, 3))
  plot(cvdgnL$glmnet.fit); plot(cvdgnR$glmnet.fit); plot(cvdgnE$glmnet.fit)
  plot(cvdgnL); plot(cvdgnR); plot(cvdgnE)
  if (dpl) dev.off()

  # Overlay the three CV curves on common axes so they are easy to compare.
  if (dpl) pdf(file = "cv-all-three-together.pdf", height = 8, width = 12)
  par(mfrow = c(1, 1))
  cvmL <- cvdgnL$cvm; lmL <- cvdgnL$lambda
  cvmR <- cvdgnR$cvm; lmR <- cvdgnR$lambda
  cvmE <- cvdgnE$cvm; lmE <- cvdgnE$lambda
  # local() keeps the plotting helpers out of the global workspace.
  local({
    fits <- list(lasso = cvdgnL, ridge = cvdgnR, enet = cvdgnE)
    cols <- c(lasso = "red", ridge = "blue", enet = "magenta")
    # cvm is the CV mean squared error for this gaussian fit, so sqrt() puts
    # it on the RMSE scale. cvlo = cvm - cvsd can dip below zero; clamp at 0
    # so sqrt() never yields NaN.
    to_rmse <- function(v) sqrt(pmax(v, 0))
    # Size the y-axis from the +/-1 SE bands (cvlo <= cvm <= cvup) so the
    # dotted band lines are not clipped. type = "n" only sets up the axes
    # instead of drawing two stray points at the range corners.
    band <- unlist(lapply(fits, function(f) c(f$cvlo, f$cvup)))
    plot(range(log(c(lmL, lmR, lmE))), range(to_rmse(band)), type = "n",
         xlab = "log(lambda)", ylab = "CV root mean squared error",
         cex.axis = 1.5, cex.lab = 1.5)
    for (nm in names(fits)) {
      f <- fits[[nm]]
      loglam <- log(f$lambda)
      # Solid line with points: CV estimate. Dotted lines: +/-1 SE band.
      lines(loglam, to_rmse(f$cvm), col = cols[[nm]], type = "b", lwd = 2)
      lines(loglam, to_rmse(f$cvlo), col = cols[[nm]], lty = 3)
      lines(loglam, to_rmse(f$cvup), col = cols[[nm]], lty = 3)
    }
    legend("topleft", legend = names(fits), col = cols, lwd = 2, cex = 2)
  })
  if (dpl) dev.off()
}
###############################################
# In PDF (batch) mode, wipe every object from the global workspace once all
# figures have been written. Interactive (screen) mode keeps the objects for
# inspection.
if (dpl) {
  rm(list = ls())
}
