############################################################
# Output switch for the whole script.
#   dpl=TRUE  : every plot section opens a pdf() file before drawing and
#               closes it afterwards; printfl() is called with doIt=TRUE
#               (presumably writes to the named .rtxt file -- see robfuns.R);
#               the workspace is cleared at the very end of the script.
#   dpl=FALSE : plots are drawn on the current graphics device instead.
dpl=TRUE
#dpl=FALSE
# Shared helpers; liftf() and printfl(), used below but not defined in this
# file, presumably come from these -- verify the relative paths before running.
source('../../robfuns.R')
source('../../rob-utility-funs.R')

###################################################
if (1) {
  cat("### read in data, make y a factor, log nTab\n")

  # Read one whitespace-delimited data file that has a header row, and
  # return it as a data.frame prepared for modelling:
  #   * column 1 (purchase, the binary dependent variable) becomes a factor;
  #   * nTablog = log(nTab + 1) is added. log1p() computes this accurately
  #     and is also defined for nTab = 0.
  # Both the in-sample and the out-of-sample files go through the same path,
  # so their columns are guaranteed to be built identically.
  read_tab_data <- function(path) {
    d <- read.table(path, header = TRUE)
    d[[1]] <- as.factor(d[[1]])
    d$nTablog <- log1p(d$nTab)
    d
  }

  td1 <- read_tab_data('td1.dat') # in-sample (used to fit the models)
  td2 <- read_tab_data('td2.dat') # out-of-sample (used to evaluate them)

  # Summaries of both data sets; written to the .rtxt files when dpl is TRUE
  # (presumably -- printfl() is defined in the sourced helper files).
  printfl(summary(td1), doIt = dpl, fnm = 'tab-sum.rtxt')
  printfl(summary(td2), doIt = dpl, fnm = 'tab2-sum.rtxt')
}
###################################################
if (1) {
  cat("### fit logits and get predictions\n")

  # Three logistic regressions, all fit on the in-sample data td1:
  #   lgt   raw nTab + moCbook + iRecMer1 + llDol
  #   lgtl  same covariates, with nTablog in place of nTab
  #   lgt1  llDol alone (baseline)
  # With a factor response, glm(family = binomial) models the probability
  # of the non-first level.
  lgt <- glm(purchase ~ nTab + moCbook + iRecMer1 + llDol,
             data = td1, family = binomial)
  lgt1 <- glm(purchase ~ llDol, data = td1, family = binomial)
  lgtl <- glm(purchase ~ nTablog + moCbook + iRecMer1 + llDol,
              data = td1, family = binomial)

  # In-sample fitted probabilities (no newdata -> rows of td1).
  phat <- predict(lgt, type = "response")
  phatl <- predict(lgtl, type = "response")
  phat1 <- predict(lgt1, type = "response")

  # Out-of-sample predicted probabilities on the hold-out data td2.
  ophat <- predict(lgt, newdata = td2, type = "response")
  ophatl <- predict(lgtl, newdata = td2, type = "response")
  ophat1 <- predict(lgt1, newdata = td2, type = "response")
}

###################################################
if (1) {
  cat("### lift\n")

  # Turn the factor response back into 0/1. as.numeric() on a factor gives
  # the level codes 1, 2, so subtracting 1 yields 0/1. This assumes the
  # levels sort as "0" then "1". NOTE(review): confirm purchase is coded 0/1.
  iy <- as.numeric(td1$purchase) - 1 # in-sample y
  oy <- as.numeric(td2$purchase) - 1 # out-of-sample y

  # ---- out-of-sample lift curves for the three logits ----
  if (dpl) pdf(file = 'tab-oos-lift.pdf', height = 10, width = 12)
  par(mai = c(1, 1, .5, .5))

  # liftf() comes from the sourced helper files. With dopl = FALSE it
  # presumably just returns the curve; the plotting is done here.
  olift <- liftf(oy, ophat, dopl = FALSE)
  olift1 <- liftf(oy, ophat1, dopl = FALSE)
  oliftl <- liftf(oy, ophatl, dopl = FALSE)

  rgy <- range(c(olift, oliftl, olift1))
  ii <- seq_along(olift) / length(olift) # fraction of cases tried
  plot(ii, olift, type = 'n', lwd = 2, xlab = '% tried',
       ylab = '% of successes', cex.lab = 2, ylim = rgy)
  lines(ii, olift, col = 'red')
  lines(ii, oliftl, col = 'blue')
  lines(ii, olift1, col = 'green')
  abline(0, 1, lty = 2) # reference diagonal
  legend('bottomright', legend = c('no log', 'log', 'llDol'),
         col = c('red', 'blue', 'green'), lwd = 3)
  # Close the lift PDF before opening the next one. Without this, the
  # pdf() call below stacks a second device on top. The dev.off() at the
  # end then closes only that second device, and tab-oos-lift.pdf is
  # never finalized.
  if (dpl) dev.off()

  # ---- spread of the out-of-sample phats for each model ----
  if (dpl) pdf(file = 'phat-oos.pdf', height = 6, width = 12)
  boxplot(cbind(ophat, ophatl, ophat1))
  abline(h = .5, col = "red", lty = 3)
  if (dpl) dev.off()
}
##################################################
if (1) {
  cat("### make confusion matrix at s=.02\n")

  # Classify the out-of-sample cases using the log-nTab model's predicted
  # probabilities, then cross-tabulate the predictions against the truth.
  ddf <- data.frame(respond = oy, phatB = ophatl)
  phat.b <- ddf[["phatB"]]
  y <- oy

  # Predict 1 whenever phat reaches the cutoff s, otherwise predict 0.
  s <- .02
  yhat <- ifelse(phat.b >= s, 1, 0)
  tbl <- table(yhat, y) # rows: predicted, columns: actual

  printfl(tbl, dpl, "conf.rtxt")
}
###########################################################
if (0) {
  cat("###compute lift and roc\n")
  # Disabled section. It computes the lift and ROC curves by hand for the
  # predictions phat.b against the 0/1 outcome y (both set in the
  # confusion-matrix section).

  ns <- 1000
  sv <- seq(from = .0, to = .99, length.out = ns) # grid of cutoffs

  # For each cutoff thr, with yhat = 1 when phat > thr:
  #   N  - fraction of all cases classified as 1
  #   TP - fraction of the y = 1 cases classified as 1 (true-positive rate)
  #   FP - fraction of the y = 0 cases classified as 1 (false-positive rate)
  # The class masks and their sizes do not depend on the cutoff, so they
  # are computed once rather than on every iteration.
  is_pos <- y == 1
  is_neg <- y == 0
  n_pos <- sum(is_pos)
  n_neg <- sum(is_neg)
  n_all <- length(y)
  N <- vapply(sv, function(thr) sum(phat.b > thr) / n_all, numeric(1))
  TP <- vapply(sv, function(thr) sum(phat.b > thr & is_pos) / n_pos, numeric(1))
  FP <- vapply(sv, function(thr) sum(phat.b > thr & is_neg) / n_neg, numeric(1))

  # ---- N, TP and FP as functions of the cutoff ----
  if (dpl) pdf(file = 'N-TP-FP.pdf', height = 4, width = 16)

  par(mfrow = c(1, 3))
  par(mai = c(0.9, 0.9, .4, .4))
  plot(sv, N, xlab = 's', type = 'l', col = 'blue', cex.lab = 2.0)
  title(main = 'get fewer yhat=1 as s gets bigger', cex.main = 2)
  plot(sv, TP, xlab = 's', type = 'l', col = 'blue', cex.lab = 2.0)
  title(main = 'get fewer of the  y=1 as s gets bigger', cex.main = 2)
  plot(sv, FP, xlab = 's', type = 'l', col = 'blue', cex.lab = 2.0)
  title(main = 'fewer wrong yhat=1 as s gets bigger', cex.main = 2)

  if (dpl) dev.off()

  # ---- ROC (TP vs FP) and lift (TP vs N) ----
  if (dpl) pdf('roc-lift.pdf', height = 5, width = 14)

  par(mai = c(0.9, 0.9, .4, .4))
  par(mfrow = c(1, 2))
  plot(FP, TP, type = 'l', col = 'blue', cex.lab = 2.0)
  abline(0, 1, lty = 2)
  title(main = 'ROC', cex.main = 2)
  plot(N, TP, type = 'l', col = 'blue', cex.lab = 2.0)
  abline(0, 1, lty = 2)
  title(main = 'Lift', cex.main = 2)
  #temp = liftf(y,phat.b,dopl=FALSE)
  #lines((1:length(y))/length(y),temp,col='red',lty=3)

  if (dpl) dev.off()
}
###########################################################
if (1) {
  cat("###plot ROC and compute auc using R package pROC\n")

  library(pROC)

  if (dpl) pdf(file = "roc.pdf", height = 5, width = 12)
  # Two side-by-side panels. Keep the previous settings so this layout
  # does not leak into later plots on the same (screen) device.
  old_par <- par(mfrow = c(1, 2), mai = c(1, 1, .5, .5))

  # Out-of-sample predictions to compare, keyed by the panel-title prefix.
  roc_models <- list(
    "log model AUC= " = ophatl,
    "llDol model AUC= " = ophat1
  )
  for (lab in names(roc_models)) {
    # Pin the class coding (0 = control, 1 = case) and the direction
    # (higher phat => more likely a purchase). With the default
    # direction = "auto", pROC picks whichever direction gives AUC >= 0.5,
    # which would hide an out-of-sample model that is worse than chance.
    rocR <- roc(response = oy, predictor = roc_models[[lab]],
                levels = c(0, 1), direction = "<")
    AUC <- auc(rocR)
    plot(rocR)
    title(main = paste(lab, round(AUC, 2)))
  }

  par(old_par)
  if (dpl) dev.off()
}

###################################################
# When writing output to files, finish by clearing every object this script
# created in the workspace.
if (dpl) {
  rm(list = ls())
}
