
##################################################
### second diabetes

## what do the variables mean?
## http://statweb.lsu.edu/faculty/li/teach/exst7142/diabetes.html

##  Regression Analysis on Diabetes Dataset
##  Bin Li
##  EXST4142/7142
##  Introduction to the diabetes dataset
##  In this data set, ten baseline variables (age, sex, body mass index, average blood pressure,
##  and six blood serum measurements) were obtained for each of 442 diabetes patients,
##  as well as the response of interest, a quantitative measure of disease progression one year after baseline.
##  This dataset was originally used in ``Least Angle Regression'' by Efron et al., 2004, in Annals of Statistics.

##################################################
### load code
## project-local helper; presumably defines monbart(), which is used below
## but not defined in this file
source("load-monbart.R")

## packages: BART supplies wbart(); ggplot2 and gridExtra build the violin plots
library(BART)
library(ggplot2)
library(gridExtra)

## when TRUE, the lm summary is diverted to a text file with sink()
dobatch = TRUE

dfForViolins = function(x) {
   ## Reshape a matrix into long format for violin / box plots.
   ##
   ## x: numeric matrix; each column is one group (a variable, a model, ...).
   ## Returns a data.frame with nrow(x) * ncol(x) rows:
   ##   x   - the matrix entries, stacked column by column (column-major order)
   ##   vnm - factor naming the column each entry came from. Its levels follow
   ##         the column order of x, not a locale-dependent alphabetical sort,
   ##         so plots show groups in the same order as the matrix columns.
   ## Unnamed columns are labelled V1, V2, ...

   ## pull basic info from x
   n = nrow(x)
   p = ncol(x)
   vnms = colnames(x)
   if (is.null(vnms)) vnms = paste0('V', seq_len(p))

   ## factor for columns: column j's name repeated n times
   lbs = rep(vnms, each = n)

   ## unique() because factor() refuses duplicated levels
   return(data.frame(x = as.double(x), vnm = factor(lbs, levels = unique(vnms))))
}

getPerVarUse = function(nvc) {
   ## Convert per-draw variable counts into per-draw proportions.
   ##
   ## nvc: variable count matrix, rows are draws, columns are variables.
   ## Returns a numeric matrix with the same dimensions whose row i is
   ## nvc[i,] / sum(nvc[i,]), i.e. the share of rules in draw i that use each
   ## variable. Column names are copied from nvc; row names are not.
   ## A draw whose counts sum to zero gives NaN in that row.
   ## Zero-row input returns a 0 x ncol(nvc) matrix.
   nvc = as.matrix(nvc)
   perMat = matrix(0, nrow(nvc), ncol(nvc))
   colnames(perMat) = colnames(nvc)
   ## vectorised row normalisation: rowSums(nvc) (length nrow) is recycled
   ## down each column, so every entry is divided by its own row total
   if (nrow(nvc) > 0) perMat[] = nvc / rowSums(nvc)
   return(perMat)
}

##################################################
### get data

## prefix shared by every output file name written below
pnm = 'DoDiabetes'

dbd = read.csv('diabetes.csv')
## response
y = dbd$y
## presumably columns 2-11 are the 10 baseline predictors; any further
## columns (quadratic transformations) are deliberately not used
x = as.matrix(dbd[, 2:11])

##################################################
## run lm
## Fitted first, because the sign flips below are chosen from its coefficients.
## data.frame(x, y) adds y as a column next to the predictors
## (as.data.frame(x, y) would instead pass y as the row.names argument).
checklm = lm(y ~ ., data.frame(x, y))
lmshat = summary(checklm)$sigma
cat('lmshat: ',lmshat,'\n')

fnm = paste0(pnm,'_lm-summary_10var.txt')
if(dobatch) sink(fnm)
## explicit print(): top-level values are not auto-printed when the script is
## run with source(), which would leave the sink file empty
print(summary(checklm))
if(dobatch) sink()

## flip using lm
## flip to have all coefficients positive: negate every column whose lm
## coefficient is negative (which() skips NA coefficients, e.g. aliased columns)
bhat = coef(checklm)[-1]
xf = x
negcols = which(bhat < 0)
xf[,negcols] = -xf[,negcols]
checklmf = lm(y ~ ., data.frame(xf, y))
print(summary(checklmf))
## negating a predictor only flips the sign of its OLS coefficient,
## so these differences should all be (numerically) zero
print(summary(coef(checklmf)[-1] - abs(bhat)))


##################################################
### run BART

## MCMC settings reused by every BART / mBART fit below
nburn = 200
nd = 2000
theseed = 99 # Gretzky
set.seed(theseed)
bf = wbart(x, y, ndpost = nd, nskip = nburn)
## 1:nd + nburn skips the first nburn (burn-in) entries of bf$sigma
cat('\n\n@@@@@ bf mean sigma: ', mean(bf$sigma[1:nd + nburn]), '\n\n')

## disabled diagnostic: trace of all sigma draws with the end of burn-in (grey),
## the lm estimate (red) and the post-burn-in mean (green)
if (FALSE) {
   plot(bf$sigma, col = 'green', pch = 16, cex = .7)
   abline(v = nburn, col = 'grey', lwd = 2, lty = 2)
   abline(h = lmshat, col = 'red', lty = 2, lwd = 3)
   abline(h = mean(bf$sigma[1:nd + nburn]), col = 'green', lty = 1, lwd = 3)
}

##################################################
### run BART with fewer trees for variable selection
bfvsntree = 40
set.seed(theseed)   # same seed as the default BART fit
bfvs = wbart(x, y, ndpost = nd, nskip = nburn, ntree = bfvsntree)
cat('\n\n@@@@@ bfvs mean sigma: ', mean(bfvs$sigma[1:nd + nburn]), '\n\n')

## compare default BART with vs BART: QQ plot of post-burn-in sigma draws
## against the identity line
qqplot(bfvs$sigma[1:nd + nburn], bf$sigma[1:nd + nburn])
abline(0, 1)

##################################################
### plot BART variable selection

## per-draw share of rules using each variable, reshaped to long format
bfvsP = getPerVarUse(bfvs$varcount)
bfvsV = dfForViolins(bfvsP)
names(bfvsV) = c('pUse','var')

## boxplots
fnm = paste0(pnm,'_BART-vs_boxplots.pdf')
pdf(file=fnm,width=12,height=5)
boxplot(pUse~var,data = bfvsV,col='lightblue',cex.axis=1.5,ylab='percent of rules using variable',cex.lab=1.2)
tnm = paste0('BART variable selection, number of trees = ',bfvsntree)
title(main=tnm,cex.main=1.5)
dev.off()

## violins
fnm = paste0(pnm,'_BART-vs_violins.pdf')
pdf(file=fnm,width=12,height=5)
bfvsVp = ggplot(data=bfvsV) + geom_violin(mapping = aes(y=pUse,x=var),fill='lightblue') +
      ggtitle(tnm) +
      theme(axis.text=element_text(size=18), axis.title=element_text(size=15),plot.title = element_text(size=15,face="bold")) +
      xlab('') + ylab('percent of rules using variable')
## a ggplot is only drawn when printed; under source() top-level values are
## not auto-printed, which would leave this PDF blank
print(bfvsVp)
dev.off()

##################################################
### fit discovery

## make xm with nice column names: each predictor appears twice,
## as-is ('Up-' prefix) and negated ('Dn-' prefix)
xm = cbind(x, -x)
colnames(xm) = c(paste0('Up-', colnames(x)), paste0('Dn-', colnames(x)))

## default mBART,  discovery
dbf = monbart(xm, y, ndpost = nd, nskip = nburn)

## sigma trace: grey = end of burn-in, red = lm estimate,
## green = post-burn-in mean of the default BART fit
plot(dbf$sigma, col = 'cyan')
abline(v = nburn, col = 'grey', lwd = 2, lty = 2)
abline(h = lmshat, col = 'red', lty = 2, lwd = 3)
abline(h = mean(bf$sigma[1:nd + nburn]), col = 'green', lty = 1, lwd = 3)

## mBART with prior at BART sigma mean, discovery
## (presumably sigdf = 5000 makes the sigma prior very tight around sigest -- confirm in monbart)
## NOTE(review): mean(bf$sigma) averages every draw, burn-in included, whereas the
## reference lines use only draws 1:nd + nburn -- confirm which is intended
dbfPri = monbart(xm, y, ndpost = nd, nskip = nburn,
                 sigest = mean(bf$sigma), sigdf = 5000, sigquant = .5)

## same trace for the prior-centred fit; cyan adds the post-burn-in mean of dbf
ylm = range(c(dbfPri$sigma, dbf$sigma, lmshat))
plot(dbfPri$sigma, col = 'blue', ylim = ylm)
abline(v = nburn, col = 'grey', lwd = 2, lty = 2)
abline(h = lmshat, col = 'red', lty = 2, lwd = 3)
abline(h = mean(bf$sigma[1:nd + nburn]), col = 'green', lty = 1, lwd = 3)
abline(h = mean(dbf$sigma[1:nd + nburn]), col = 'cyan', lty = 1, lwd = 3)

##################################################
### fit discovery for variable selection

## mBART with prior at BART sigma mean, using fewer trees
## NOTE(review): sigest again averages all bf sigma draws, burn-in included
dbfvsntree = 20
dbfPrivs = monbart(xm, y, ndpost = nd, nskip = nburn,
                   sigest = mean(bf$sigma), sigdf = 5000, sigquant = .5,
                   ntree = dbfvsntree)

## compare sigma draws (all draws, burn-in included) against the identity line
qqplot(dbfPrivs$sigma, dbfPri$sigma)
abline(0, 1)

##################################################
### plot sigma draws

## y-range covers the dbfPri, dbf and dbfPrivs draws plus the lm estimate.
## NOTE(review): bf$sigma is not part of ylm, so some green points may be clipped
ylm = range(c(dbfPri$sigma, dbf$sigma, lmshat, dbfPrivs$sigma))
plot(dbfPrivs$sigma, col = 'blue', ylim = ylm)
points(bf$sigma, col = 'green', pch = '.', cex = 2)
## grey: end of burn-in; red: lm estimate; the remaining lines are post-burn-in
## means (green = BART, cyan = default discovery, blue = discovery with fewer trees)
abline(v = nburn, col = 'grey', lwd = 2, lty = 2)
abline(h = lmshat, col = 'red', lty = 2, lwd = 3)
abline(h = mean(bf$sigma[1:nd + nburn]), col = 'green', lty = 4, lwd = 3)
abline(h = mean(dbf$sigma[1:nd + nburn]), col = 'cyan', lty = 5, lwd = 3)
abline(h = mean(dbfPrivs$sigma[1:nd + nburn]), col = 'blue', lty = 1, lwd = 3)

##################################################
### plot discovery variable selection

## which discovery fit to summarise: 'default', 'prior' or 'prior-vs'
userun = 'default'
#userun = 'prior'
#userun = 'prior-vs'

## pick the per-draw variable-count matrix (doD) and plot title (tnm) for that fit.
## An unrecognised value stops here, instead of silently reusing whatever
## doD/tnm held earlier in the script.
if(userun == 'prior-vs') {
   doD = dbfPrivs$nvcount
   tnm = paste0('Discovery variable selection, BART prior,  number of trees = ',dbfvsntree)
} else if(userun == 'prior') {
   doD = dbfPri$nvcount
   tnm = paste0('Discovery variable selection, BART prior,  number of trees = ',200)
} else if(userun == 'default') {
   doD = dbf$nvcount
   tnm = paste0('Discovery variable selection, default prior,  number of trees = ',200)
} else {
   stop("unknown userun '", userun, "'; expected 'default', 'prior' or 'prior-vs'")
}
 
colnames(doD) = colnames(xm)
dbfvsP = getPerVarUse(doD)
dbfvsV = dfForViolins(dbfvsP)
names(dbfvsV) = c('pUse','var')

## 2 pane boxplots, make 2 separate data frames for up and down.
## xm = cbind(x, -x): the first ncol(x) columns are the 'Up-' copies and the
## next ncol(x) are the 'Dn-' copies. drop = FALSE keeps a matrix (which
## dfForViolins needs) even when x has a single column.
dbfvsPU = dbfvsP[, seq_len(ncol(x)), drop = FALSE]
dbfvsVU = dfForViolins(dbfvsPU)
names(dbfvsVU) = c('pUse','var')
dbfvsPD = dbfvsP[, ncol(x) + seq_len(ncol(x)), drop = FALSE]
dbfvsVD = dfForViolins(dbfvsPD)
names(dbfvsVD) = c('pUse','var')

## boxplots up and down together
fnm = paste0(pnm,'_',userun,'_DBART-vs_boxplots.pdf')
pdf(file=fnm,width=12,height=5)
boxplot(pUse~var,data = dbfvsV,col='lightblue',cex.axis=0.7,ylab='percent of rules using variable',cex.lab=1.2)
## dashed separator between the two groups of ncol(x) boxes ('Up-' and 'Dn-');
## this is 10.5 for the 10 diabetes predictors
abline(v = ncol(x) + 0.5,col='grey',lty=2)
title(main=tnm,cex.main=1.5)
dev.off()

## plot two data frames together with boxplots:
## the 'Up-' panel (with the title) sits above the 'Dn-' panel
fnm = paste0(pnm,'_',userun,'_DBART-vs_boxplots_2-pane.pdf')
pdf(file = fnm, width = 12, height = 5)
par(mfrow = c(2, 1), mai = c(.4, .5, .5, .5), oma = c(.4, .5, .5, .5))
boxplot(pUse ~ var, data = dbfvsVU, col = 'lightblue', cex.axis = 0.9,
        ylab = 'percent of rules using variable', cex.lab = 1.2, main = tnm)
boxplot(pUse ~ var, data = dbfvsVD, col = 'lightblue', cex.axis = 0.9,
        ylab = 'percent of rules using variable', cex.lab = 1.2)
dev.off()

## plot two data frames together with violins.
## A small N(0, .005^2) jitter is added to the shares before plotting,
## presumably so that a variable whose share is constant (e.g. always 0)
## still gets a drawable violin. Both panels reuse seed 34.

# make up plot
dbfvsVUJ = dbfvsVU
set.seed(34)
dbfvsVUJ$pUse = dbfvsVUJ$pUse + rnorm(nrow(dbfvsVUJ), mean = 0, sd = .005)
dvU = ggplot(data = dbfvsVUJ) +
      geom_violin(mapping = aes(y = pUse, x = var), fill = 'lightblue') +
      ggtitle(tnm) +
      theme(axis.text = element_text(size = 12),
            axis.title = element_text(size = 8),
            plot.title = element_text(size = 15, face = "bold")) +
      xlab('') + ylab('percent of rules using variable')

# make down plot (empty title)
dbfvsVDJ = dbfvsVD
set.seed(34)
dbfvsVDJ$pUse = dbfvsVDJ$pUse + rnorm(nrow(dbfvsVDJ), mean = 0, sd = .005)
dvD = ggplot(data = dbfvsVDJ) +
      geom_violin(mapping = aes(y = pUse, x = var), fill = 'lightblue') +
      ggtitle('') +
      theme(axis.text = element_text(size = 12),
            axis.title = element_text(size = 8),
            plot.title = element_text(size = 12, face = "bold")) +
      xlab('') + ylab('percent of rules using variable')

## stack the two panels vertically in one PDF
pList = list(dvU, dvD)
fnm = paste0(pnm,'_',userun,'_DBART-vs_violins_2-pane.pdf')
pdf(file = fnm, width = 12, height = 5)
grid.arrange(grobs = pList, nrow = 2, ncol = 1)
dev.off()

##################################################
### flip according to DBART
## NOTE that lm and DBART disagree on sign of hdl

#bhat = checklm$coef[-1]
## +1 keeps a column, -1 negates it; one entry per column of x, in column order
flipv = c(-1,-1,1,1,-1,1,-1,1,1,1)
stopifnot(length(flipv) == ncol(x))
names(flipv) = colnames(x)

## multiply column j of x by flipv[j] (sweep over margin 2 = columns)
xf = sweep(x, 2, flipv, '*')
checkfdf = data.frame(y,xf)
checkflm = lm(y~.,checkfdf)
## explicit print(): summary() is not auto-printed when the script is source()d
print(summary(checkflm))

## run monbart with flipped x: default prior, prior centred at the BART sigma
## mean, and that same prior with fewer trees.
## NOTE(review): mean(bf$sigma) includes the burn-in draws -- confirm intended
bfmfDef = monbart(xf, y, ndpost = nd, nskip = nburn)
bfmfPri = monbart(xf, y, ndpost = nd, nskip = nburn,
                  sigest = mean(bf$sigma), sigdf = 5000, sigquant = .5)
bfmfPriVs = monbart(xf, y, ndpost = nd, nskip = nburn,
                    sigest = mean(bf$sigma), sigdf = 5000, sigquant = .5,
                    ntree = dbfvsntree)

##################################################
## plot sigma draws
## NOTE(review): the y-range comes from bfmfDef alone, so magenta/cyan points
## outside it are clipped
fnm = paste0(pnm,'_x-flipped_sigma-draws.pdf')
pdf(file = fnm, width = 12, height = 8)
plot(bfmfDef$sigma, col = 'blue', ylab = 'sigma draws', xlab = 'MCMC index', pch = 16, cex = .5)
points(bfmfPri$sigma, col = 'magenta')
points(bfmfPriVs$sigma, col = 'cyan')
abline(h = lmshat, col = 'red', lty = 2, lwd = 3)
## redraw the default-prior draws so they sit on top of the other two series
points(bfmfDef$sigma, col = 'blue', pch = 16, cex = .5)
## post-burn-in means: green = default BART, blue = default mBART on flipped x
abline(h = mean(bf$sigma[nburn + 1:nd]), lty = 4, lwd = 3, col = 'green')
abline(h = mean(bfmfDef$sigma[nburn + 1:nd]), lty = 1, lwd = 3, col = 'blue')
title(main = 'Monotone BART with flipped x', cex.main = 1.2)
dev.off()

## collect runs

#bf = wbart(x,y,ndpost=nd,nskip=nburn)

#bfvsntree = 40 
#bfvs = wbart(x,y,ndpost=nd,nskip=nburn,ntree=bfvsntree)

#dbf = monbart(xm,y,ndpost=nd,nskip=nburn)
#dbfPri = monbart(xm,y,ndpost=nd,nskip=nburn,sigest = mean(bf$sigma),sigdf=5000,sigquant=.5)
#dbfPrivs = monbart(xm,y,ndpost=nd,nskip=nburn,sigest = mean(bf$sigma),sigdf=5000,sigquant=.5,ntree = dbfvsntree)

#bfmfDef = monbart(xf,y,ndpost=nd,nskip=nburn)
#bfmfPri = monbart(xf,y,ndpost=nd,nskip=nburn,sigest = mean(bf$sigma),sigdf=5000,sigquant=.5)
#bfmfPriVs = monbart(xf,y,ndpost=nd,nskip=nburn,sigest = mean(bf$sigma),sigdf=5000,sigquant=.5,ntree = dbfvsntree)

## boxplots
fnm = paste0(pnm,'_sigma-draws_boxplots.pdf')
pdf(file=fnm,width=14,height=5)
## post-burn-in draws only
iis = nburn + 1:nd
## one column per fit; the order must match the labels below
## (columns 3-4 are the discovery fits dbf / dbfPri, not the x-flipped fits)
sDRM = cbind(bf$sigma[iis], bfvs$sigma[iis], dbf$sigma[iis], dbfPri$sigma[iis], dbfPrivs$sigma[iis],
    bfmfDef$sigma[iis], bfmfPri$sigma[iis])
colnames(sDRM) = c('BART-default','BART-vs','mBART-Disc-default','mBART-Disc-Pri','mBART-Disc-Pri-vs','mBART-default-xflipped',
    'mBART-Pri-xflipped')
boxplot(sDRM,cex.axis=0.8,ylab='sigma draws')
abline(h = lmshat,col='red',lty=2)
abline(h = mean(bf$sigma[iis]),col = 'green', lty=4)
legend('topright',legend=c('linear least squares','BART default mean'),col = c('red','green'),lty=c(2,4),lwd=c(2,2),bty='n')
title(main='MCMC sigma draws from BART, mBART, and mBART discovery',cex.main=1.0)
dev.off()

## violins

## long format: one row per (draw, fit) pair
SDV = dfForViolins(sDRM)
names(SDV) = c('sigma','BARTmod')


fnm = paste0(pnm,'_sigma-draws_violins.pdf')
pdf(file=fnm,width=12,height=5)
sigDVio = ggplot(data=SDV) + geom_violin(mapping = aes(y=sigma,x=BARTmod),fill='lightblue') +
     ggtitle('MCMC sigma draws from BART, mBART, and mBART discovery') +
      theme(axis.text=element_text(size=10), axis.title=element_text(size=15),plot.title = element_text(size=15,face="bold")) +
      xlab('') + ylab('sigma draws') +
      geom_hline(yintercept = lmshat,color='red', linetype = 'dashed') + 
      geom_hline(yintercept = mean(bf$sigma[iis]), color = 'green', linetype = 'dotdash', linewidth = 1.25)
## a ggplot is only drawn when printed; under source() top-level values are
## not auto-printed, which would leave this PDF blank
print(sigDVio)
dev.off()


## ----------------------
## NOTE(review): this section repeats the BART variable-selection plots made
## earlier in the script, rebuilding them from the same bfvsP and bfvsntree
## and overwriting the same two PDF files.
bfvsV = dfForViolins(bfvsP)
names(bfvsV) = c('pUse','var')

## boxplots
fnm = paste0(pnm,'_BART-vs_boxplots.pdf')
pdf(file=fnm,width=12,height=5)
boxplot(pUse~var,data = bfvsV,col='lightblue',cex.axis=1.5,ylab='percent of rules using variable',cex.lab=1.2)
tnm = paste0('BART variable selection, number of trees = ',bfvsntree)
title(main=tnm,cex.main=1.5)
dev.off()

## violins
fnm = paste0(pnm,'_BART-vs_violins.pdf')
pdf(file=fnm,width=12,height=5)
bfvsVp = ggplot(data=bfvsV) + geom_violin(mapping = aes(y=pUse,x=var),fill='lightblue') +
     ggtitle(tnm) +
      theme(axis.text=element_text(size=18), axis.title=element_text(size=15),plot.title = element_text(size=15,face="bold")) +
      xlab('') + ylab('percent of rules using variable')
## explicit print(): otherwise the PDF is blank when the script is source()d
print(bfvsVp)
dev.off()



 


