---
title: 'Bayesian Latent Variable Regression example: MNDR data'
author: "Vitense et al. (2017)"
output: html_document
---

```{r set-options, echo=FALSE, cache=FALSE}
# Wide console output and 3 significant digits for all printed results
options(width = 150, digits = 3)
```

#### This document provides R code to analyze a single year (2009, 2010, or 2011) of Minnesota shallow lake data contained in DNR_Data.csv. Both Bayesian latent variable regression (BLR) and a linear model (LM) are used to analyze the data, and the two fits are compared with a Pareto-smoothed importance sampling approximation to leave-one-out cross-validation, computed with the R package 'loo' (Vehtari et al. 2016). R session info and package versions can be found at the end of this document.


```{r, message=FALSE, warning=FALSE}
library(R2jags)
library(R2WinBUGS)
library(mcmcplots)
library(ggplot2)
library(KernSmooth)
library(coda)
library(gridExtra)
library(devtools)
# jagstools is installed from GitHub. Only install it when it is missing so
# that knitting the document does not hit the network and reinstall the
# package on every run.
if (!requireNamespace("jagstools", quietly = TRUE)) {
  install_github(repo = "johnbaums/jagstools")
}
library(jagstools)
library(loo)
```


#### Read in MDNR data:
```{r}
# Read the DNR data; TRUE is spelled out because T is an ordinary variable
# that can be reassigned.
State.dat <- read.csv("DNR_Data.csv", header = TRUE) # DNR data

# Resave variables
lake <- as.factor(State.dat$lakeID)
Year <- as.factor(State.dat$Year) # 2009, 2010, 2011
TPug <- State.dat$TPug  # ug/L
lnTPug <- log(TPug)
Chla <- State.dat$Chla  # ug/L
lnChla <- log(Chla)
SAVms <- State.dat$SAVms      # avg SAV (kg) in each lake
lnSAVms <- log(SAVms + .02)   # small offset added before taking the log
```


#### Pull out one year of data (2009, 2010, or 2011)
```{r}
# Select one year of data and drop the lakes whose log(TP) is missing.
Yr <- 2010 # can be changed to 2009 or 2011 to analyze a different year
in.year <- Year == Yr
lnP <- lnTPug[in.year]
ind.rm <- which(is.na(lnP)) # which observations missing

# Remove missing observations with a logical mask. Negative indexing with
# -ind.rm returns zero-length vectors whenever nothing is missing
# (x[-integer(0)] is empty), which would silently discard the whole year.
keep <- !is.na(lnP)
lnP <- lnP[keep]
P <- TPug[in.year][keep]
lnC <- lnChla[in.year][keep]
Chl <- Chla[in.year][keep]
SAV <- SAVms[in.year][keep]
lnSAV <- lnSAVms[in.year][keep]

# Fail early instead of passing empty vectors on to the models and plots
if (length(lnP) == 0) {
  stop("No observations with non-missing TP for year ", Yr, call. = FALSE)
}

SAV.mean <- mean(SAV)
SAVc <- SAV - SAV.mean # centered SAV
```


#### Visualize data
```{r, message=FALSE, fig.height=4, fig.width=10}
# Chla against TP on the raw scale (left) and the log scale (right);
# point size is mapped to SAV in both panels.
pl1 <- ggplot(NULL, aes(x = P, y = Chl)) +
  geom_point(aes(size = SAV)) +
  scale_size_continuous(range = c(2, 5))

pl2 <- ggplot(NULL, aes(x = lnP, y = lnC)) +
  geom_point(aes(size = SAV)) +
  scale_size_continuous(range = c(2, 5))

grid.arrange(pl1, pl2, ncol = 2)
```



#### JAGS code for Bayesian latent variable regression
```{r}
# JAGS model for the Bayesian latent variable regression (BLR), written as an
# R function so it can be handed to jags.parallel() as the model.
#
# Each lake i has a latent state S[i] (0 = clear, 1 = turbid). log(Chla) is
# linear in log(TP) with a state-specific intercept, slope and residual
# standard deviation. The probability of being turbid is 0 below the lower TP
# threshold (logp1), 1 above the upper threshold (logp2), and follows a
# logistic regression on centered SAV in between.
#
# Data expected: lnC, lnP, SAVc (vectors of length nsamp), nsamp, and ForceNeg
# (a vector of zeros used to impose the sign constraint on b2).
# Note: the second argument of dnorm() in this model is a precision
# (tau = 1 / sigma^2), not a standard deviation.
latent_reg <-function(){
  
  # Priors for TP/Chla regression
  a0 ~  dnorm(0,0.01) # clear intercept (precision 0.01, i.e. sd 10)
  a1adj ~  dnorm(0,0.1)  # adjustment to intercept to get turbid intercept
  a1 <- a1adj+a0 # turbid intercept

  b0 ~ dunif(0,10)   # clear slope
  b1 ~ dunif(0,10)  # turbid slope
  b1adj <- b1-b0 # difference in slopes
  
  # Priors for logistic regression relating SAV to P(turbid)
  alpha ~ dnorm(0,.01) # intercept for logistic regression
  beta <- -1*negbeta  # slope for logistic regression; negative because negbeta is lognormal (positive)
  negbeta ~ dlnorm( .5, 1 ) # prior for -1*beta

  sigma1 ~ dunif(0, 20)  # Residual standard deviation of clear lakes
  sigma2 ~ dunif(0, 20)  # Residual standard deviation of turbid lakes
  # Convert standard deviations to the precisions used by dnorm()
  tau1 <- 1 / ( sigma1 * sigma1)
  tau2 <- 1 / ( sigma2 * sigma2)

  # Slope of the line joining the turbid line at logp1 to the clear line at logp2
  b2 <- (a1+b1*logp1-a0-b0*logp2) /(logp1-logp2) # slope of unstable line
  # Indicator of a non-negative unstable slope; used with ForceNeg in the likelihood loop
  constraint <- step(b2)
  
  # Priors for TP thresholds
  logp1 ~ dunif(0,6.5)   # left bifurcation point (tip down) on log scale
  logp2 ~ dunif(0,6.5)   # right bifurcation point (tip up) on log scale
  p1 <- exp(logp1)
  p2 <- exp(logp2)
  
  
  
  # Likelihood
  for(i in 1:nsamp){
    # Residual precision is tau1 when S[i] = 0 (clear) and tau2 when S[i] = 1 (turbid)
    lnC[i] ~ dnorm(mu[i], tau1*(1-S[i]) + tau2*S[i])
    mu[i] <- a0+a1adj*S[i] + b0*lnP[i]*(1-S[i])+b1*lnP[i]*S[i] # state-dependent linear relationship between log(TP) and log(Chla). 

    S[i] ~ dbern(p[i]) # latent state (clear/turbid) for lake i
    logit(ptmp[i])<-alpha+beta*SAVc[i] # logistic regression relating SAV to the probability lake i is turbid        
    p[i] <- ptmp[i]*step(lnP[i]-logp1)*step(logp2-lnP[i])+step(lnP[i]-logp2) # sets p to 0 if log(TP) less than logp1, p depends on SAV if between logp1 and logp2, and p is 1 if log(TP) greater than logp2

    # ForceNeg is supplied as all zeros, so a draw with constraint = 1 (b2 >= 0) has zero likelihood
    ForceNeg[i] ~ dbern(constraint) # Force the slope of the unstable line connecting the turbid line at logp1 to the clear line at logp2 to be negative so the fit resembles a bifurcation diagram
  
    log_lik[i] <- log( dnorm(lnC[i], mu[i], tau1*(1-S[i]) + tau2*S[i]) ) # contribution of lake i to log-likelihood for PSIS-LOO computations
    
    }
}


# Data passed to JAGS by name: responses, predictors, sample size, and the
# all-zero vector used by the model's sign constraint.
nsamp <- length(lnC)
ForceNeg <- numeric(nsamp)

win.data <- list("lnC", "lnP", "nsamp", "SAVc", "ForceNeg")


# Initial values: regression and threshold parameters are fixed (informed by a
# previous run for 2010); the logistic and residual-sd parameters are drawn at
# random for each chain.
inits <- function() {
  list(
    a0 = -.70,
    a1adj = .81,
    logp1 = 3.91,
    logp2 = 6.02,
    b0 = .60,
    b1 = .79,
    alpha = runif(1, -7, -3),
    negbeta = runif(1, 10, 18),
    sigma1 = runif(1, .5, .7),
    sigma2 = runif(1, .4, .8)
  )
}


# Nodes whose posterior samples are saved
params <- c(
  "a0", "a1", "a1adj", "b0", "b1", "b1adj", "b2",
  "alpha", "beta",
  "p", "p1", "p2", "logp1", "logp2",
  "sigma1", "sigma2",
  "S", "log_lik"
)

# MCMC settings
nc <- 3         # chains
ni <- 10000000  # iterations per chain
nb <- 2000000   # burn-in iterations
nt <- 2400      # thinning interval
```



#### Start MCMC sampler for BLR
WARNING: For 10 million iterations (as specified above), JAGS takes several hours to run, even using parallel computing. The model has already been run and saved as "DNR_JAGSout2010_10mil_2400thin.rds", which can be loaded using the function 'readRDS' (see Markdown file for code).
```{r, eval=FALSE}
# NOTE(review): jags.parallel() has its own jags.seed argument; confirm whether
# this set.seed() call has any effect on the sampled chains.
set.seed(27)

# Fit the BLR model; do.call() passes the already-evaluated settings to
# jags.parallel(), here matched by argument name rather than by position.
out.BLR <- do.call(
  jags.parallel,
  list(
    data = win.data,
    inits = inits,
    parameters.to.save = params,
    model.file = latent_reg,
    n.chains = nc,
    n.iter = ni,
    n.burnin = nb,
    n.thin = nt
  )
)

# If desired, save MCMC run to be loaded later
#saveRDS(out.BLR, "DNR_JAGSout2010_10mil_2400thin.rds")
```

 
```{r, echo=FALSE}
# Restore the saved BLR fit instead of re-running the sampler
out.BLR <- readRDS(file = "DNR_JAGSout2010_10mil_2400thin.rds")
```

#### Print summary for BLR
```{r}
# Posterior summaries for the thresholds, regression, logistic and sd parameters
jagsresults(
  x = out.BLR,
  params = c(
    "logp1", "logp2", "p1", "p2",
    "a0", "a1", "b0", "b1",
    "alpha", "beta",
    "sigma1", "sigma2"
  )
)

```

#### Trace plots for BLR
```{r}
# Trace plots for the regression, logistic and threshold parameters
traplot(
  out.BLR,
  c("a0", "a1", "b0", "b1", "alpha", "beta", "logp1", "logp2")
)
```

#### Posterior density plots
```{r}
# Posterior densities, with thresholds shown on the TP scale (p1, p2)
denplot(
  out.BLR,
  c("a0", "a1", "b0", "b1", "alpha", "beta", "p1", "p2", "sigma1", "sigma2")
)
```


#### JAGS code for linear model (LM):
```{r, cache=TRUE}

# JAGS model for the linear model (LM) benchmark: a single straight line
# relating log(Chla) to log(TP), with one residual standard deviation.
# NOTE: this reuses the name latent_reg, so it replaces the BLR model function
# defined earlier in this document.
# Data expected: lnC, lnP (vectors of length nsamp) and nsamp.
latent_reg <-function(){
  
  # Priors 
  a0 ~  dnorm(0,0.01) # N(0,10^2); the second argument of dnorm is a precision
  b0 ~ dunif(0,10)   # slopes should be positive
  sigma1 ~ dunif(0, 20)  # Residual standard deviation
  tau1 <- 1 / ( sigma1 * sigma1) # precision corresponding to sigma1

  for(i in 1:nsamp){
    lnC[i] ~ dnorm(mu[i], tau1)
    mu[i] <- a0 + b0*lnP[i] # linear relationship between log(TP) and log(Chla). 

    # contribution of lake i to the log-likelihood, saved for the PSIS-LOO comparison
    log_lik[i] <- log( dnorm(lnC[i], mu[i], tau1) )
  }
}


# Data passed to JAGS by name for the linear model
nsamp <- length(lnC)

win.data <- list("lnC", "lnP", "nsamp")


# Random starting values for each chain
inits <- function() {
  list(
    a0 = rnorm(1, 0, 1),
    b0 = runif(1, 0, 2),
    sigma1 = runif(1, .2, 1)
  )
}


# Nodes whose posterior samples are saved
params <- c("a0", "b0", "sigma1", "log_lik")

# MCMC settings (one tenth of the BLR run lengths, shown in the comments)
nc <- 3
ni <- 1000000 # 10000000
nb <- 200000  # 2000000
nt <- 240     # 2400
```


#### Start MCMC sampler for LM
```{r, eval=FALSE}
# NOTE(review): jags.parallel() has its own jags.seed argument; confirm whether
# this set.seed() call has any effect on the sampled chains.
set.seed(27)

# Fit the linear model; arguments are matched by name rather than by position
out.lin <- do.call(
  jags.parallel,
  list(
    data = win.data,
    inits = inits,
    parameters.to.save = params,
    model.file = latent_reg,
    n.chains = nc,
    n.iter = ni,
    n.burnin = nb,
    n.thin = nt
  )
)

# If desired, save MCMC run to be loaded later
#saveRDS(out.lin, "DNR_JAGSout2010_1mil_240thin_LOO_LINEAR.rds")
```


```{r, echo=FALSE}
# Restore the saved linear-model fit instead of re-running the sampler
out.lin <- readRDS(file = "DNR_JAGSout2010_1mil_240thin_LOO_LINEAR.rds")

```

#### Print summary for LM
```{r}
# Posterior summaries for the intercept, slope and residual sd
jagsresults(
  x = out.lin,
  params = c("a0", "b0", "sigma1")
)
```

#### Trace plots for LM
```{r}
# Trace plots for all three linear-model parameters
traplot(
  out.lin,
  c("a0", "b0", "sigma1")
)
```


#### Compare BLR and LM using PSIS-LOO (see Vehtari et al. (2016))
```{r, message=FALSE, warning=FALSE}
 # PSIS-LOO for the BLR fit. The result is stored as loo.BLR rather than `loo`
 # so that it does not mask the loo() function, which is called again below.
 llik <- out.BLR$BUGSoutput$sims.list$log_lik
 loo.BLR <- loo(llik) # pareto-smoothed importance sampling leave-one-out cross validation (PSIS-LOO)
 print(loo.BLR) # For 2010, ~7 pareto k estimates are not good

 # PSIS-LOO for the linear model fit
 llik.lin <- out.lin$BUGSoutput$sims.list$log_lik
 loo.lin <- loo(llik.lin)
 print(loo.lin) # All pareto k estimates are good

 # NOTE(review): the extraction by position below assumes compare() returns
 # the elpd difference first and its standard error second; confirm against
 # the installed loo version.
 loo_diff <- compare(loo.BLR, loo.lin)
 print(loo_diff) # The difference will be negative if the expected predictive accuracy for the first (BLR) model is higher.

 elpd_diff_loo <- loo_diff[[1]]     # estimated difference in expected predictive accuracy
 se_elpd_diff_loo <- loo_diff[[2]]  # estimated standard error of difference

 # Bounds for 95% confidence interval (normal approximation)
 low.bound.loo <- elpd_diff_loo - 1.96 * se_elpd_diff_loo
 up.bound.loo <- elpd_diff_loo + 1.96 * se_elpd_diff_loo

 # 95% CI Contains 0?
 low.bound.loo < 0 & up.bound.loo > 0
 # FALSE = SIGNIFICANT AT ALPHA=.05

```


### Remaining code pertains only to BLR

#### BLR threshold posterior modes and medians
```{r}
# Kernel density estimate of each threshold posterior (log scale); the mode is
# taken as the grid point at which the estimated density is largest.
dKS.logp1 <- bkde(out.BLR$BUGSoutput$sims.list$logp1)
logp1.modeKS <- dKS.logp1$x[which.max(dKS.logp1$y)]
logp1.modeKS # lower threshold mode on log scale

dKS.logp2 <- bkde(out.BLR$BUGSoutput$sims.list$logp2)
logp2.modeKS <- dKS.logp2$x[which.max(dKS.logp2$y)]
logp2.modeKS # upper threshold mode on log scale

# Tallest density value across both posteriors
ylim.logp <- max(dKS.logp1$y, dKS.logp2$y)

# Back-transform the modes to the TP scale
exp(logp1.modeKS) # p1 mode
exp(logp2.modeKS) # p2 mode

# Posterior medians of the thresholds on the log and TP scales
thr.med <- out.BLR$BUGSoutput$median
logp1med <- thr.med[["logp1"]]
logp2med <- thr.med[["logp2"]]
p1med <- thr.med[["p1"]]
p2med <- thr.med[["p2"]]
```


#### Density plots of threshold posteriors
```{r}
# Posterior densities of the two log(TP) thresholds, with the prior density
# and the posterior modes and medians overlaid.
par(mar = c(5, 5, 2, 2), cex.lab = 1.8, cex.axis = 1.6)
plot(dKS.logp1, xlab = "Threshold", las = 1, lwd = 3, xlim = c(0, 6.5), main = "",
     col = "#009a9a", ylim = c(0, ylim.logp + 1), type = "l", ylab = "Density")
lines(dKS.logp2, lwd = 3, col = "#9a0000")
abline(h = 1/6.5, lwd = 2, col = "black") # density of the dunif(0, 6.5) prior
abline(v = c(logp1.modeKS, logp2.modeKS), col = "darkorchid", lty = 3, lwd = 3)
abline(v = c(logp1med, logp2med), col = "dodgerblue", lty = 3, lwd = 3)
# One line type per legend entry, matching how each line is drawn above:
# solid for the three densities, dotted (lty = 3) for both mode and median.
legend("topleft",
       c("Lower Threshold Posterior", "Upper Threshold Posterior", "Prior", "Mode", "Median"),
       lwd = 3,
       col = c("#009a9a", "#9a0000", "black", "darkorchid", "dodgerblue"),
       bty = "n", cex = 1.3, lty = c(1, 1, 1, 3, 3))
```


#### Plot of prior and posterior distributions for $\beta$ (this is $\gamma_1$ in the paper; i.e., the slope parameter for the logistic regression relating SAV to the probability a lake is turbid)
```{r, fig.height=5, fig.width=6}
# Posterior density of beta against its prior. The prior is approximated by
# negating 10000 lognormal(meanlog = .5, sdlog = 1) draws, since beta is
# defined in the model as -1 times a lognormal variable.
par(mar = c(5, 5, 2, 2), cex.lab = 1.7, cex.axis = 1.6)

plot(
  density(out.BLR$BUGSoutput$sims.list$beta),
  xlim = c(-40, 0), ylim = c(0, .4), las = 1,
  xlab = expression(beta), main = "",
  lwd = 3, col = "firebrick3"
)
lines(
  density(-1 * rlnorm(10000, meanlog = .5, sdlog = 1)),
  col = "dodgerblue3", lwd = 3
)
legend(
  "topleft", c("Prior", "Posterior"),
  cex = 1.3, lty = 1, lwd = 3,
  col = c("dodgerblue3", "firebrick3")
)
```



#### Pull out posterior means/medians and classify discrete states
```{r}
# Shorthand for the pieces of the fitted BLR output used below
post.mean <- out.BLR$BUGSoutput$mean
post.med <- out.BLR$BUGSoutput$median
post.summ <- out.BLR$BUGSoutput$summary

# Posterior means of the regression parameters
a0m <- post.mean[["a0"]]
a1m <- post.mean[["a1"]]
a1adjm <- post.mean[["a1adj"]]
b0m <- post.mean[["b0"]]
b1m <- post.mean[["b1"]]
b1adjm <- post.mean[["b1adj"]]
alpham <- post.mean[["alpha"]]
betam <- post.mean[["beta"]]

# Posterior medians of the regression parameters
a0med <- post.med[["a0"]]
a1med <- post.med[["a1"]]
a1adjmed <- post.med[["a1adj"]]
b0med <- post.med[["b0"]]
b1med <- post.med[["b1"]]
b1adjmed <- post.med[["b1adj"]]
alphamed <- post.med[["alpha"]]
betamed <- post.med[["beta"]]

# 95% credible interval bounds for the TP thresholds on the log scale
# (summary columns 3 and 7 hold the 2.5 and 97.5 percentiles; the medians
# were saved earlier)
logp1.row <- which(rownames(post.summ) == "logp1")
logp2.row <- which(rownames(post.summ) == "logp2")
logp1.lowbound <- post.summ[logp1.row, 3]
logp1.upbound <- post.summ[logp1.row, 7]
logp2.lowbound <- post.summ[logp2.row, 3]
logp2.upbound <- post.summ[logp2.row, 7]

# One row per state: intercepts, slopes, and the matching threshold summaries
coefs <- data.frame(
  Int.Mean = c(a0m, a1m),
  Int.Med = c(a0med, a1med),
  Slope.Mean = c(b0m, b1m),
  Slope.Med = c(b0med, b1med),
  State = c("Clear", "Turbid"),
  Bifur.LogMed = c(logp1med, logp2med),
  Bifur.LogMode = c(logp1.modeKS, logp2.modeKS),
  Bifur.Log.LowBound = c(logp1.lowbound, logp2.lowbound),
  Bifur.Log.UpBound = c(logp1.upbound, logp2.upbound)
)

coefs

# Discrete state per lake: 1 when the posterior mean of S exceeds .5, else 0
Latent.State <- as.factor(as.numeric(post.mean[["S"]] > .5))
```



#### Plot average estimated logistic curve
```{r}
  # Posterior mean and 95% credible band of the logistic curve relating SAV to
  # P(turbid), evaluated at each lake's observed SAV.
  mcmc.l <- out.BLR$BUGSoutput$sims.list
  Ep.PE <- Ep.LCL <- Ep.UCL <- rep(NA_real_, length(SAV))

  # seq_along() keeps the loop from running when SAV is empty
  for (i in seq_along(SAV)) {
    # alpha + beta * (SAV[i] - SAV.mean), i.e. the model's logistic regression
    # on centered SAV, evaluated for every MCMC draw
    jags.Ep.temp <- plogis(mcmc.l$alpha - SAV.mean * mcmc.l$beta + mcmc.l$beta * SAV[i])
    Ep.PE[i] <- mean(jags.Ep.temp)
    # `probs` is spelled out instead of relying on partial matching of `prob`
    Ep.ci <- quantile(jags.Ep.temp, probs = c(0.025, 0.975))
    Ep.LCL[i] <- Ep.ci[1]
    Ep.UCL[i] <- Ep.ci[2]
  }

  # Lakes classified as turbid (1) or clear (0) by their posterior mean state
  state.num <- as.numeric(out.BLR$BUGSoutput$mean$S > .5)

  ggplot(NULL, aes(x = SAV, y = state.num)) +
    geom_point(size = 3) +
    geom_line(aes(x = SAV, y = Ep.PE)) +
    geom_ribbon(aes(x = SAV, ymin = Ep.LCL, ymax = Ep.UCL), alpha = 0.3) +
    theme(axis.text = element_text(colour = "black", size = 15),
          axis.title = element_text(colour = "black", size = 16),
          legend.text = element_text(colour = "black", size = 15),
          legend.title = element_text(colour = "black", size = 15),
          axis.title.y = element_text(vjust = 1.5),
          axis.title.x = element_text(vjust = 0)) +
    coord_cartesian(ylim = c(-.05, 1.05)) +
    ylab("P(turbid)") +
    xlab("SAV (kg)")
```




#### Plot average estimated regression line for each state

Black solid (dashed) lines represent average (2.5th, 97.5th quantiles) estimated steady state relationships across all MCMC samples. Steady state lines end at the TP threshold point estimates (posterior modes), and gray bands represent 95% credible intervals for TP thresholds. Circular points represent lakes classified as clear (>50% of MCMC sampled states were clear), and triangular points represent lakes classified as turbid (>50% of MCMC sampled states were turbid). The average MCMC sampled state for each lake is shown on a blue to green color gradient (0=clear, 1=turbid). Point size is proportional to submerged aquatic vegetation (SAV, units: average kg/sample). 


```{r}
# Per-lake posterior mean of the latent state (0 = clear, 1 = turbid) and the
# corresponding discrete label. The label is derived directly from the
# posterior mean: converting a factor to numeric and subtracting 1 would label
# every lake "Clear" whenever all lakes fall in a single state.
Continuous.State <- out.BLR$BUGSoutput$mean$S
Discrete.State <- ifelse(as.vector(Continuous.State) > .5, "Turbid", "Clear")

log.TP <- lnP
log.Chla <- lnC
log.SAV <- lnSAV
cbPalette <- c("royalblue4","seagreen")

# MCMC draws of the state-specific intercepts and slopes
a0.samps <- as.vector(out.BLR$BUGSoutput$sims.list$a0)
a1.samps <- as.vector(out.BLR$BUGSoutput$sims.list$a1)
b0.samps <- as.vector(out.BLR$BUGSoutput$sims.list$b0)
b1.samps <- as.vector(out.BLR$BUGSoutput$sims.list$b1)

num.sims <- length(a0.samps)

# Clear-state line: drawn from log(TP) = 1.4 up to the upper threshold mode.
# Rows of xclear.pred are grid points, columns are MCMC draws.
logp2.modeKS <- round(logp2.modeKS, digits=2)
xclear <- seq(1.4, logp2.modeKS, .01)
xclear.pred <- t(a0.samps + outer(b0.samps, xclear))

clear.means <- rowMeans(xclear.pred)
clear.lower <- apply(xclear.pred, 1, quantile, probs = 0.025)
clear.upper <- apply(xclear.pred, 1, quantile, probs = 0.975)

# Turbid-state line: drawn from the lower threshold mode up to log(TP) = 7
logp1.modeKS <- round(logp1.modeKS, digits=2)
xturbid <- seq(logp1.modeKS, 7, .01)
xturbid.pred <- t(a1.samps + outer(b1.samps, xturbid))

turbid.means <- rowMeans(xturbid.pred)
turbid.lower <- apply(xturbid.pred, 1, quantile, probs = 0.025)
turbid.upper <- apply(xturbid.pred, 1, quantile, probs = 0.975)


ggplot(NULL, aes(x=log.TP, y=log.Chla , color=Continuous.State )) +

  # Gray bands: 95% credible intervals of the lower and upper TP thresholds
  annotate("rect", xmin=coefs$Bifur.Log.LowBound[1], xmax=coefs$Bifur.Log.UpBound[1], ymin=-Inf, ymax=Inf, alpha=0.15, fill="gray20") +
  annotate("rect", xmin=coefs$Bifur.Log.LowBound[2], xmax=coefs$Bifur.Log.UpBound[2], ymin=-Inf, ymax=Inf, alpha=0.15, fill="gray20") +

  geom_point( aes(shape=Discrete.State, size=SAV)) +

  # Solid lines: posterior mean fit; dashed lines: 2.5th and 97.5th quantiles
  geom_line(aes(x=xturbid, y=turbid.means), size=1.1, colour="gray20") +
  geom_line(aes(x=xclear, y=clear.means), size=1.1, colour="gray20")  +
  geom_line(aes(x=xturbid, y=turbid.lower), size=1.1, colour="gray20", linetype=2) +
  geom_line(aes(x=xturbid, y=turbid.upper), size=1.1, colour="gray20", linetype=2) +
  geom_line(aes(x=xclear, y=clear.lower), size=1.1, colour="gray20", linetype=2) +
  geom_line(aes(x=xclear, y=clear.upper), size=1.1, colour="gray20", linetype=2)+

  scale_size(range=c(2,6),  limits=range(SAV)) +

  theme(legend.position="right", panel.border=element_rect(fill=NA, colour = "black") , panel.background = element_rect(fill = "white"), axis.text=element_text(colour="black", size=15),axis.title=element_text(colour="black", size=16),legend.text=element_text(colour="black", size=13),legend.title=element_text(colour="black", size=14),axis.title.y=element_text(vjust=1.5),axis.title.x=element_text(vjust=0), plot.title = element_text(size = rel(1.5), hjust = 0.5),   legend.key.height=unit(1.2, 'lines')) +

  # Added once only; adding a second coordinate system just replaces the first
  coord_cartesian(ylim = c(-1,6)) +

  ylab("log(Chla)") + xlab("log(TP)") +

  scale_colour_gradient2( mid="royalblue3",high="seagreen3", limits=c(0,1)) +

  guides(colour = guide_colourbar(order = 1, label.vjust=1, title.vjust=.85, barwidth=unit(1, 'lines'), barheight=unit(7, 'lines')), shape = guide_legend(order = 2), size = guide_legend(order = 3) )
```


#### R session info
```{r}
# Print the R session info and package versions used to build this document
sessionInfo()
```