############################################################
############################################################
## file: Hitters_var-sel_cv-bic.R
## Do variable selection using Hitters data 
## Using validation, cross-validation, and BIC
############################################################
############################################################
library(leaps) # all-subsets selection for linear regression
dpl <- TRUE    # when TRUE, plots are written to pdf files

##------------------------------------------------------------
## Load the data
library(ISLR) # data sets accompanying the ISLR book
print(names(Hitters))
print(dim(Hitters))
## Salary contains missing values; report the count, then drop those rows
cat('number missing for Salary: ',sum(is.na(Hitters$Salary)),'\n')
## Remove every row containing a missing value
Hitters <- na.omit(Hitters)
cat('number missing for Salary after drop: ',sum(is.na(Hitters$Salary)),'\n')
cat("dim after dropping missing:\n")
print(dim(Hitters))

##------------------------------------------------------------
## Exhaustive best-subset search using ALL of the data.
## nvmax: largest subset size considered; nbest: keep one model per size.
regfit.best <- regsubsets(Salary ~ ., data = Hitters, nvmax = 19,
                          nbest = 1, method = "exhaustive")
regsum <- summary(regfit.best)
## `which` is a logical matrix: row k marks the variables in the best size-k model
print(regsum$which[1:5, ])

if (dpl) pdf(file='Hitters_rsq_alldat.pdf', height = 10, width = 12)
plot(regsum$rsq, xlab='num var (k)', ylab='R-squared',
     cex.lab = 1.5, type='b', col='blue')
if (dpl) dev.off()

## coef() returns the fitted coefficients of the best model of a given size
print(coef(regfit.best, 2)) # best model with 2 predictors


##--------------------------------------------------
## Test-set RMSE for the best subset of each size k = 1..p.
##
## object:  regsubsets fit on the training data
## newdata: held-out data frame (same columns as the training data,
##          response included)
## ynm:     name (string) of the response column in newdata
##
## Returns a numeric vector rmsev with rmsev[k] = RMSE on newdata of
## the best k-variable model found by regsubsets.
dovalbest <- function(object, newdata, ynm)
{
  ## Recover the model formula from the regsubsets call so the test
  ## design matrix uses the same coding (factor dummies etc.) as training.
  form <- as.formula(object$call[[2]])
  p <- ncol(newdata) - 1 # number of candidate predictors (all cols but y)
  rmsev <- rep(0, p)
  test.mat <- model.matrix(form, newdata)
  for (k in seq_len(p)) { # seq_len is safe even if p == 0
    coefk <- coef(object, id = k)        # coefficients of best size-k model
    xvars <- names(coefk)                # matching columns of test.mat
    pred <- test.mat[, xvars] %*% coefk  # predictions on the held-out data
    rmsev[k] <- sqrt(mean((newdata[[ynm]] - pred)^2))
  }
  rmsev
}
##------------------------------------------------------------
##do validation approach several times
## Repeat the train/validation split many times and average the RMSE
## curves, since a single random split gives a noisy estimate.
ntry <- 100
p <- ncol(Hitters) - 1 # number of candidate predictors
## resmat[k, i] = test RMSE of the best k-variable model on split i
resmat <- matrix(0, p, ntry)
set.seed(14)
for (i in seq_len(ntry)) {
  ## random half of the rows for training, the remainder for validation
  train <- sample(1:nrow(Hitters), floor(nrow(Hitters) / 2))
  ## nvmax = p keeps the search size tied to the actual predictor count
  ## (previously hard-coded 19)
  regfit.best <- regsubsets(Salary ~ ., data = Hitters[train, ],
                            nvmax = p, nbest = 1, method = "exhaustive")
  resmat[, i] <- dovalbest(regfit.best, Hitters[-train, ], 'Salary')
}
## average each row (subset size) across the ntry splits;
## rowMeans is the idiomatic, faster form of apply(resmat, 1, mean)
mresmat <- rowMeans(resmat)
##--------------------------------------------------
## Plot average validation RMSE against the number of variables
if (dpl) pdf(file='Hitters_train-val.pdf', height = 10, width = 12)
plot(mresmat, xlab='num vars', ylab='rmse', type='b',
     col='blue', pch = 19, cex.lab = 1.5)
if (dpl) dev.off()

##--------------------------------------------------
## Refit on ALL the data with the subset size picked from the
## train/validation RMSE plot.
kopt <- 6 # number of variables, chosen by eye-balling the plot
regfit.best <- regsubsets(Salary ~ ., data = Hitters, nvmax = kopt,
                          nbest = 1, method = "exhaustive")
xmat <- model.matrix(Salary ~ ., Hitters)
## Rebuild a data frame from the design matrix (dropping the intercept
## column) together with the response
ddf <- data.frame(xmat[, -1], Salary = Hitters$Salary)
## Keep only the kopt selected predictors plus the response
nms <- c(names(coef(regfit.best, kopt))[-1], "Salary")
ddfsub <- ddf[, nms]
thereg <- lm(Salary ~ ., ddfsub)
print(summary(thereg))

##------------------------------------------------------------
## Model selection by BIC on the full data
regfit.best <- regsubsets(Salary ~ ., data = Hitters, nvmax = 19,
                          nbest = 1, method = "exhaustive")
if (dpl) pdf(file="bic-best-on-train_Hitters.pdf", height = 10, width = 12)
plot(regfit.best) # default display orders the models by BIC
## vertical guide lines make each variable's column easier to pick out
for (i in seq_len(ncol(Hitters))) abline(v = i, col="red", lty = 2)
if (dpl) dev.off()

sumreg <- summary(regfit.best)
if (dpl) pdf(file="Hitters-BIC.pdf", height = 10, width = 12)
plot(sumreg$bic, xlab="k", ylab="BIC", type='b',
     col="blue", lwd = 2, cex.lab = 1.5)
if (dpl) dev.off()


#--------------------------------------------------
## NOTE(review): the script originally ended with `if(dpl) rm(list=ls())`,
## which wiped the ENTIRE workspace (and, oddly, only when plots were
## being written to pdf). `rm(list = ls())` in a script is an
## anti-pattern -- it destroys the results just computed along with any
## other objects the user had loaded -- so it has been removed.

