##################################################
## libraries
library(ggplot2)
library(viridis)

##################################################
## simulate data
n = 500          # sample size
beta = c(1, .5)  # true (intercept, slope)

## storage for simulation results:
##   y ~ Bern(py1), where py1 = F(beta[1] + beta[2]*x) and F is the logistic cdf
x = numeric(n)
y = numeric(n)
py1 = numeric(n)

## simulate one observation at a time; the draws alternate x, y, x, y, ...
## so this loop order (together with the seed) determines the data set
set.seed(34)
for (i in 1:n) {
  x[i] = rnorm(1)
  py1[i] = 1 / (1 + exp(-(beta[1] + beta[2] * x[i])))
  y[i] = rbinom(1, 1, py1[i])
}

## collect everything in a data frame, with the 0/1 response as a factor
ddf = data.frame(x, y = as.factor(y), py1)
head(ddf)
summary(ddf)

## write to csv file
write.csv(ddf, file = "rsimdata.csv", row.names = FALSE)
## read in with temp = read.csv("rsimdata.csv")

## logistic regression fit; data is named explicitly because the second
## positional argument of glm() is `family`, not `data`
lfit = glm(y ~ x, data = ddf, family = binomial())
summary(lfit)

## plot fit: true P(Y=1|x) in blue, glm fitted probabilities in red
phat = predict(lfit, type = "response")
oo = order(x) # sort by x so each curve is drawn left to right
plot(x[oo], py1[oo], col = "blue", xlab = "x", ylab = "P(Y=1|x)",
     ylim = range(c(phat, py1)), type = "l", lwd = 2)
lines(x[oo], phat[oo], col = "red", lwd = 2)
## legend line widths match the lwd = 2 used for both curves above
legend("topleft", legend = c("true", "estimated"), col = c("blue", "red"),
       lwd = c(2, 2), bty = "n")
title("x vs. P(Y=1|x)")

##################################################
### negative log likelihood of the logistic model, computed with an explicit loop
## Negative log-likelihood of the logistic model
##   P(Y = 1 | x) = F(beta[1] + beta[2]*x),  F = logistic cdf,
## accumulated one observation at a time (loop version).
##   x:    numeric covariate vector
##   y:    responses; y[i] == 1 is a success, anything else a failure
##   beta: c(intercept, slope)
## Returns -sum_i log P(Y = y[i] | x[i]); returns 0 when y is empty.
mLL1 = function(x, y, beta) {
  mll = 0.0
  ## seq_along() gives no iterations for empty y (1:length(y) would give c(1, 0))
  for (i in seq_along(y)) {
    eta = beta[1] + beta[2] * x[i]
    if (y[i] == 1) {
      ## log F(eta), evaluated on the log scale so it never underflows to log(0)
      mll = mll - plogis(eta, log.p = TRUE)
    } else {
      ## log(1 - F(eta)) taken directly from the upper tail; the naive
      ## log(1 - py1) is -Inf once py1 rounds to 1
      mll = mll - plogis(eta, lower.tail = FALSE, log.p = TRUE)
    }
  }
  mll
}

## check against R: for 0/1 responses the residual deviance is 2 * (-log lik)
## at the MLE, and AIC = deviance + 2 * (number of coefficients);
## print our values next to the ones glm reports
deviance = 2 * mLL1(x, y, coef(lfit))
aic = deviance + 2 * length(coef(lfit))
cat("deviance and aic should be ", deviance, ", ", aic, "\n")
cat("glm reports deviance and aic: ", lfit$deviance, ", ", lfit$aic, "\n")


##################################################
### get -LL on beta1 grid: the slope varies over [0, 1] while the intercept
### is held at its maximum likelihood estimate
nval = 1000
p = 2
## evaluate -LL at each row of bMat
bMat = matrix(0.0, nval, p)
bMat[, 1] = coef(lfit)[1]
bMat[, 2] = seq(from = 0, to = 1, length.out = nval)

## -LL evals
llv = rep(0, nval)
for (i in seq_len(nval)) {
  llv[i] = mLL1(x, y, bMat[i, ])
}

## plot llv
plot(bMat[, 2], llv, xlab = "beta1 values", ylab = "- log likelihood")
abline(v = beta[2], col = "blue")
ii = which.min(llv)
abline(v = bMat[ii, 2], col = "red")
title(main = "blue at true, red at mle", cex.main = 2.2)
## change TRUE to FALSE to skip writing the pdf copy of this plot
if (TRUE) dev.copy2pdf(file = "mLL-on-a-beta1-grid.pdf", width = 10, height = 8)

## row of bMat at min:
bMat[ii, ]
## check
coef(lfit)

##################################################
### get -LL on bivariate grid
nval = 100
b1g = seq(from = 0, to = 2, length.out = nval)  # intercept values
b2g = seq(from = 0, to = 1, length.out = nval)  # slope values
## bg will be bivariate grid, all vals of b1g x all vals of b2g
bg = expand.grid(b1g, b2g)
nn = nrow(bg)  # number of grid points (nval^2)

## get bivariate grid as a matrix, and try loop
bgM = as.matrix(bg)
llv2 = rep(0, nn)
tm2 = system.time({
  for (i in seq_len(nn)) {
    if ((i %% 100) == 0) cat("i: ", i, "\n")
    llv2[i] = mLL1(x, y, bgM[i, ])
  }
})

## check: grid minimizer vs glm's MLE vs the true parameters
ii = which.min(llv2)
cat("min -LL over grid: ", bgM[ii, ], "\n")
cat("mle: ", coef(lfit), "\n")
cat("true: ", beta, "\n")
temp = rbind(bgM[ii, ], coef(lfit), beta)
rownames(temp) = c("gmle", "rmle", "true")
colnames(temp) = c("intercept", "slope")
print(temp)

## repeat the loop on a plain copy of the grid built with cbind():
## bgM keeps the column names produced by expand.grid, bgMM has none
## (compare str(bgM) and str(bgMM))
bgMM = cbind(bgM[, 1], bgM[, 2])
llv3 = numeric(nn)
tm3 = system.time({
  for (i in 1:nn) {
    if (i %% 100 == 0) {
      cat("i: ", i, "\n")
    }
    llv3[i] = mLL1(x, y, bgMM[i, ])
  }
})

## check: the two grid representations should give the same -LL values
summary(llv2 - llv3)


##################################################
### write vectorized version of mLL, no loops!!
## Vectorized negative log-likelihood of the logistic model
##   P(Y = 1 | x) = F(beta[1] + beta[2]*x),  F = logistic cdf.
##   x:    numeric covariate vector
##   y:    responses; y == 1 is a success, anything else a failure
##   beta: c(intercept, slope)
## Returns -sum_i log P(Y = y[i] | x[i]).
mLL = function(x, y, beta) {
  eta = beta[1] + beta[2] * x
  ## log P(Y = y_i | x_i) is log F(eta_i) for successes and
  ## log(1 - F(eta_i)) = log F(-eta_i) otherwise; plogis(..., log.p = TRUE)
  ## evaluates both without the log(0) = -Inf that log(1 - py1) hits
  ## once py1 rounds to 1
  -sum(plogis(ifelse(y == 1, eta, -eta), log.p = TRUE))
}

## time the vectorized -LL over the same grid points (no progress printing)
llv4 = numeric(nn)
tm4 = system.time({
  for (i in 1:nn) llv4[i] = mLL(x, y, bgMM[i, ])
})

## check: loop and vectorized versions should agree
summary(llv3 - llv4)

## contour, image ...
## expand.grid varies b1g fastest, so filling column-major with
## nrow = length(b1g) gives mllM[j, k] = -LL at (b1g[j], b2g[k])
mllM = matrix(llv4, nrow = length(b1g))

contour(b1g, b2g, mllM, nlevels = 50, col = "blue", drawlabels = FALSE)
title(xlab = "intercept", ylab = "slope")
points(beta[1], beta[2], col = "blue", pch = 16)              # true parameters
points(coef(lfit)[1], coef(lfit)[2], col = "red", pch = 15)   # glm MLE

persp(b1g, b2g, mllM, xlab = "intercept", ylab = "slope",
      zlab = "- log likelihood")

image(b1g, b2g, mllM, xlab = "intercept", ylab = "slope")

## image with a 100-colour viridis palette, option "G" (mako)
nncol = 100
cvec = viridis(nncol, option = "G")
image(b1g, b2g, mllM, col = cvec)

## filled contours: 30 levels, then filled.contour's default levels
filled.contour(b1g, b2g, mllM, nlevels = 30)
filled.contour(b1g, b2g, mllM)

## filled contours with a 20-colour viridis palette, option "A" (magma)
nncol = 20
cvec = viridis(nncol, option = "A")
filled.contour(b1g, b2g, mllM, col = cvec)

## filled contours using the mako palette function directly
filled.contour(b1g, b2g, mllM, color.palette = mako)

