
##################################################
### compute the unnormalized target f1, the proposal density g, and the envelope constant M, then plot
# Truncation point and an evaluation grid covering [a, a + 2].
a <- 3
yv <- seq(a, a + 2, length.out = 1000)

# Envelope constant. With this M, f1(y) / (M * g(y)) simplifies to
# exp(-(y - a)^2 / 2), which never exceeds 1, so M * g lies above f1.
M <- exp(-a^2 / 2) / a

# Target: unnormalized standard-normal kernel.
# Proposal: Exp(rate = a) density shifted to start at a.
fy1 <- exp(-0.5 * yv^2)
gy <- a * exp(-a * (yv - a))

# Blank frame sized to hold both curves, then the target (solid black)
# and the scaled proposal (dashed blue).
plot(range(yv), range(c(fy1, M * gy)), type = "n")
lines(yv, fy1, lwd = 2)
lines(yv, M * gy, col = "blue", lwd = 3, lty = 2)

##################################################
## compute h and plot
# Acceptance probability h(y) = f1(y) / (M * g(y)), in closed form and as
# the direct ratio. Their difference is expected to be ~0 (rounding only).
hy <- exp(-0.5 * (yv - a)^2)
hycheck <- fy1 / (M * gy)
summary(hy - hycheck)

plot(yv, hy)

##################################################
### exponential rejection sampling function

# Draw one value from a standard normal truncated to [a, Inf) by rejection
# sampling with a shifted Exp(rate = a) proposal.
#
# Args:
#   a: truncation point; values below 2 are refused with an error.
# Returns: a single numeric draw, always >= a.
trn <- function(a) {
  if (a < 2) stop("you have not thought this through")

  # Propose the excess over a, then accept with probability
  # exp(-excess^2 / 2), which is h(excess + a).
  repeat {
    excess <- rexp(1, rate = a)
    accept_prob <- exp(-0.5 * excess^2)
    if (runif(1) < accept_prob) {
      return(excess + a)
    }
  }
}


##################################################
###  Dumb rejection sampling
# Draw one value from a standard normal truncated to [a, Inf) by naive
# rejection: keep drawing N(0, 1) until a draw is >= a.
# The expected number of draws is 1 / P(Z >= a), which grows quickly with a.
#
# Args:
#   a: truncation point.
# Returns: the first standard-normal draw that is >= a.
trnD <- function(a) {
  repeat {
    draw <- rnorm(1)
    if (draw >= a) return(draw)
  }
}

a <- 3
nd <- 5000

### draw using exponential rejection
drv <- numeric(nd)
tm1 <- system.time({
  for (i in seq_len(nd)) {
    drv[i] <- trn(a)
  }
})

### draw using dumb rejection
drvD <- numeric(nd)
tm2 <- system.time({
  for (i in seq_len(nd)) {
    drvD[i] <- trnD(a)
  }
})

### draw using dumb rejection vectorized
# Oversample so that, on average, about nd draws land in [a, Inf). The
# number kept is random, so length(drv1) is only approximately nd.
# pnorm(a, lower.tail = FALSE) computes the upper-tail probability directly
# rather than as 1 - pnorm(a), which loses relative precision as a grows.
tm3 <- system.time({
  drv1 <- rnorm(nd / pnorm(a, lower.tail = FALSE))
  drv1 <- drv1[drv1 >= a]
})

### check draws: both loop samplers against the vectorized reference
par(mfrow = c(1, 2))
qqplot(drv1, drv)
abline(0, 1, col = "red", lwd = 3)
qqplot(drv1, drvD)
abline(0, 1, col = "red", lwd = 3)
print(tm1)
print(tm2)
print(tm3)


### draw using the R-package
library(truncnorm)

# One draw per call, inside a loop.
drvP <- rep(0, nd)
tm4 <- system.time(
  for (i in 1:nd) drvP[i] <- rtruncnorm(1, a = a)
)

# All nd draws in a single call.
tm5 <- system.time(drvP1 <- rtruncnorm(nd, a = a))

# Compare the package draws with the vectorized rejection reference.
par(mfrow = c(1, 2))
qqplot(drvP, drv1)
abline(0, 1, col = "red")
qqplot(drvP1, drv1)
abline(0, 1, col = "red")
print(tm4)
print(tm5)


### try inverse CDF
# Inverse-CDF sampler for a standard normal truncated to [a, Inf).
#
# Args:
#   n: number of draws.
#   a: lower truncation point.
# Returns: numeric vector of n draws, each >= a.
#
# With U ~ Uniform(0, 1) and S the standard-normal survival function,
# S^{-1}((1 - U) * S(a)) has the required distribution. This is the same
# transform as qnorm(U * (1 - pnorm(a)) + pnorm(a)), but evaluated on the
# log upper-tail scale. In the lower-tail form, pnorm(a) rounds to exactly 1
# for large a (roughly a > 8), and qnorm() then returns Inf for every draw.
drcdf <- function(n, a) {
  u <- runif(n)
  log_tail_a <- pnorm(a, lower.tail = FALSE, log.p = TRUE)
  # log1p(-u) = log(1 - u), accurate even when u is tiny.
  qnorm(log1p(-u) + log_tail_a, lower.tail = FALSE, log.p = TRUE)
}

# no loop: all nd draws in a single vectorized call
tm6 <- system.time(drC <- drcdf(nd, a))

# Inverse-CDF draws against the vectorized rejection reference.
par(mfrow = c(1, 1))
qqplot(drC, drv1)
abline(0, 1, col = "red")
print(tm6)

# loop: one draw per call, to time the per-call overhead
drC1 <- rep(0, nd)
tm7 <- system.time(
  for (i in 1:nd) drC1[i] <- drcdf(1, a)
)

par(mfrow = c(1, 1))
qqplot(drC1, drv1)
abline(0, 1, col = "red")
print(tm7)









