Welcome to Practical 2, an introduction to using R, JAGS and Stan for fitting Generalised Linear Models (GLMs). In this practical we’ll:
- Fit a simple linear regression in both JAGS and Stan
- Learn how to manipulate the output of fitted models
- Extend the models to fit GLMs
- Learn how to check convergence
Throughout this document you will see code in gray boxes. You should try to understand this code, and you can usually copy it directly from this document into your R script window. At various points you will see a horizontal line in the text which indicates a question you should try to answer, like this:
Exercise X
What words does the following command print to the console?
print("Hello World")
The two main documents I use when writing Stan/JAGS code are the manuals: the JAGS user manual and the Stan user’s guide and reference manual.
Here is a quick warm-up exercise:
Exercise 1
Use read.csv to load in the earnings data set (the URL is given in the code below), then use lm to perform a linear regression with log(earnings) as the response and height in cm as the explanatory variable.

The first thing to know about JAGS is that it is a separate language from R which has its own code and structure. It is unlike Stan (which we run through an R package) in that JAGS is separate software that needs to be installed by itself. Hopefully you already did this when going through the pre-requisites.
To fit a model in JAGS we use this procedure:
1. Load an R package that talks to JAGS (here R2jags, but there are others)
2. Write the JAGS model code and store it in an R character string
3. Call the jags function, where we additionally provide data, parameters to watch, and (optionally) initial values and other algorithm details

Here’s a standard workflow. We’re going to fit a simple linear regression using the earnings data as in the initial exercise above. We’re then going to plot the output. I suggest first just copying and pasting this code to check that it works, then going back through the details line-by-line to understand what’s happening.
First load in the data:
dat = read.csv('https://raw.githubusercontent.com/andrewcparnell/bhm_course/master/data/earnings.csv')
Now load in the package:
library(R2jags)
Type in our JAGS code and store it in an object. The code is as explained in Class 1.
jags_code = '
model{
  # Likelihood
  for(i in 1:N) {
    y[i] ~ dnorm(intercept + slope * x[i], residual_sd^-2)
  }
  # Priors
  intercept ~ dnorm(0, 100^-2)
  slope ~ dnorm(0, 100^-2)
  residual_sd ~ dunif(0, 100)
}
'
Run the code by calling jags:
jags_run = jags(data = list(N = nrow(dat),
y = log(dat$earn),
x = dat$height_cm),
parameters.to.save = c('intercept',
'slope',
'residual_sd'),
model.file = textConnection(jags_code))
Above I am giving R a list of data which matches the objects in the jags_code
, and I’m also telling it which parameters I am interested in via parameters.to.save
. If you don’t provide a specific parameter here, JAGS will not store it in the output object. The textConnection
function is used to tell JAGS that the model code is stored in an R object.
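As an aside, instead of textConnection you can write the model code to a file and give jags the file path, which is the more traditional workflow. A minimal sketch (the file name linreg.jags is just an example):

# Write the model string to disk (file name is arbitrary)
writeLines(jags_code, 'linreg.jags')
# Then pass the path directly as model.file
jags_run = jags(data = list(N = nrow(dat),
                            y = log(dat$earn),
                            x = dat$height_cm),
                parameters.to.save = c('intercept',
                                       'slope',
                                       'residual_sd'),
                model.file = 'linreg.jags')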
Once run we can use the print or plot commands:
print(jags_run)
## Inference for Bugs model at "4", fit using jags,
## 3 chains, each with 2000 iterations (first 1000 discarded)
## n.sims = 3000 iterations saved
## mu.vect sd.vect 2.5% 25% 50% 75% 97.5%
## intercept 5.896 0.482 4.961 5.569 5.898 6.222 6.849
## residual_sd 0.907 0.020 0.869 0.893 0.906 0.920 0.948
## slope 0.023 0.003 0.017 0.021 0.023 0.024 0.028
## deviance 2798.530 2.480 2795.731 2796.752 2797.899 2799.650 2804.846
## Rhat n.eff
## intercept 1.001 2200
## residual_sd 1.002 1200
## slope 1.001 2200
## deviance 1.002 3000
##
## For each parameter, n.eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor (at convergence, Rhat=1).
##
## DIC info (using the rule, pD = var(deviance)/2)
## pD = 3.1 and DIC = 2801.6
## DIC is an estimate of expected predictive error (lower deviance is better).
The important things to look at here are the estimated posterior means and standard deviations, the quantiles, and also the Rhat values. You can get a slightly useless plot with:
plot(jags_run)
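If you did the warm-up exercise, you can sanity-check these estimates against the classical least-squares fit; with priors this vague the two should agree closely. A minimal sketch:

# Classical fit for comparison (cf. Exercise 1)
summary(lm(log(earn) ~ height_cm, data = dat))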
Exercise 2
Once you are happy that the above runs for you, try the following:
- Change the x variable in the data list from height_cm to height. How does the model change?
- Remove one of the parameters from the parameters.to.save argument of the JAGS call. What happens to the resulting print/plot commands? Do any of the remaining values change? (They shouldn’t change much, but remember that these are stochastic algorithms and you won’t get exactly the same values every time, though they should be close.)
- Experiment with the values in the prior distributions, making the ranges smaller or larger. What happens to the posterior estimates?

The object created from the jags call is a huge list with lots of things in it, most of them useless. A good command to explore the structure of a list is str:
str(jags_run)
There are two useful objects to extract out of the model. The first is the list of all the parameter values. You can get this with:
pars = jags_run$BUGSoutput$sims.list
str(pars)
## List of 4
## $ deviance : num [1:3000, 1] 2796 2798 2796 2797 2800 ...
## $ intercept : num [1:3000, 1] 5.9 5.62 6.09 5.71 5.71 ...
## $ residual_sd: num [1:3000, 1] 0.9 0.887 0.914 0.89 0.921 ...
## $ slope : num [1:3000, 1] 0.0224 0.0243 0.0214 0.0235 0.024 ...
After this run we can get at, e.g. the first 10 sample values of the intercept
parameter with:
pars$intercept[1:10]
## [1] 5.901483 5.621074 6.092228 5.712966 5.710167 5.544540 6.074314
## [8] 5.416764 6.122312 4.958255
From this we can calculate the posterior mean with, e.g. mean(pars$intercept)
but we can also get this more directly with:
par_means = jags_run$BUGSoutput$mean
str(par_means)
## List of 4
## $ deviance : num [1(1d)] 2799
## $ intercept : num [1(1d)] 5.9
## $ residual_sd: num [1(1d)] 0.907
## $ slope : num [1(1d)] 0.0225
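Because pars contains the full set of posterior samples, you can compute any summary you like, not just the mean. For example, a 95% credible interval for the slope (a minimal sketch using base R):

# 95% credible interval from the posterior samples of the slope
quantile(pars$slope, probs = c(0.025, 0.975))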
The other things you might want to get out are also stored in the BUGSoutput
bit:
str(jags_run$BUGSoutput)
Exercise 3
Have a dig around inside the BUGSoutput object:
- Find the table of posterior summary statistics (hint: look at the $summary element)
- Compute your own posterior summaries, e.g. medians or credible intervals (using the pars object created above)

If you can fit models in JAGS then it’s pretty easy to move them over to Stan. The structure of the workflow is very similar, as are the commands. Stan code tends to be slightly longer as you have to declare all of the data and parameters, which has its disadvantages (longer code) and advantages (easier to follow).
When you load up Stan, it’s best not just to load the package, but also to run two additional lines:
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
The first option stops Stan re-compiling a model that hasn’t changed; the second allows Stan to run its chains in parallel on a multi-core machine, which can really speed up the workflow. We now specify our model code:
stan_code = '
data {
  int<lower=0> N;
  vector[N] x;
  vector[N] y;
}
parameters {
  real intercept;
  real slope;
  real<lower=0> residual_sd;
}
model {
  y ~ normal(intercept + slope * x, residual_sd);
}
'
It’s a good idea to compare and contrast the code between JAGS and Stan. The key differences are:
- Stan requires named blocks of code, here data, parameters, and model. There are others we will come to later
- Stan parameterises the normal distribution as (mean, standard deviation). By contrast, JAGS uses (mean, precision). Remember precision is the reciprocal of the variance, so precision = 1 / sd^2 (see the short sketch after this list)
- Stan uses = to assign variables, whereas JAGS uses <-
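To see the precision/standard deviation relationship in numbers, here is a tiny R sketch:

# The JAGS prior dnorm(0, 100^-2) has standard deviation 100
sd_val = 100
precision = 1 / sd_val^2
precision           # 1e-04, written as 100^-2 in the JAGS code
1 / sqrt(precision) # back-transforms to the standard deviation, 100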
Finally we run the model with:
stan_run = stan(data = list(N = nrow(dat),
y = log(dat$earn),
x = dat$height_cm),
model_code = stan_code)
In contrast to JAGS, we do not need to specify the parameters to watch. Stan watches all of them by default. We also don’t need a textConnection
command to provide the stan
function with the stan code. The data
call is identical to JAGS.
I often find with Stan that I get lots of warnings when I run a model, especially with more complicated models (which we haven’t met yet). Most of them tend to be ignorable and are unhelpful. Usually Stan gives you a count of how many ‘bad things’ happened in each chain. However, there doesn’t seem to be much guidance on what to change, if anything, in response. If you start to see lots of warnings in your Stan output, let me know!
Just like JAGS, we can print or plot the output. This is also a pretty useless plot:
plot(stan_run)
## ci_level: 0.8 (80% intervals)
## outer_level: 0.95 (95% intervals)
Exercise 4
Let’s do the same exercises again for Stan:
- Try changing the x variable in the data list to height. How does the model change?
- Try adding in some prior distributions to the Stan code. For example, add intercept ~ normal(0, 100); to the model block. (If you get stuck there’s an example below with priors for a different model.) Again, experiment with making these ranges small or large.
Stan by default saves all the parameters, but they are not easily accessible directly from the stan_run
object we just created. Instead there is a special function called extract
to get at the posterior samples.
pars = extract(stan_run)
str(pars)
## List of 4
## $ intercept : num [1:4000(1d)] 5.91 4.61 6.47 5.12 5.46 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ slope : num [1:4000(1d)] 0.0225 0.0298 0.0191 0.0272 0.0252 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ residual_sd: num [1:4000(1d)] 0.94 0.941 0.922 0.927 0.906 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
## $ lp__ : num [1:4000(1d)] -426 -431 -426 -427 -425 ...
## ..- attr(*, "dimnames")=List of 1
## .. ..$ iterations: NULL
If we want the first 10 samples of the intercept we can run:
pars$intercept[1:10]
## [1] 5.905431 4.614097 6.474814 5.119850 5.460803 6.914850 5.827420
## [8] 6.206904 5.906729 5.967672
Alternatively if you want to get at the summary statistics you can run:
pars_summ = summary(stan_run)$summary
pars_summ['intercept','mean']
## [1] 5.890555
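The same table holds the quantiles and convergence diagnostics, so you can pull out, say, a 95% posterior interval for the slope directly (column names as produced by rstan):

pars_summ['slope', c('2.5%', '97.5%')]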
In addition, rstan has some useful plotting commands, e.g.
stan_dens(stan_run)
You can find a bigger list with ?stan_plot
Exercise 5
As with JAGS:
- Use the pars object to compute your own summary statistics for each parameter
- Explore the full table from summary(stan_run)$summary and find the posterior quantiles and Rhat values
Once you have the basic workflow in place, and the ability to manipulate output, the only hard thing that remains is the ability to write your own JAGS/Stan code to fit the model you want.
Recall that for a generalised linear model (GLM) we don’t have a normally distributed likelihood. Instead we have a probability distribution that matches the data (e.g. binomial for restricted counts, Poisson for unrestricted counts), together with a link function which transforms the key parameter (usually the mean) from its restricted range onto a scale where we can use a linear regression-type model.
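To make this concrete, the Binomial-logit model we are about to fit can be written as:
\[
y_i \sim \mbox{Binomial}(1, p_i), \qquad \mbox{logit}(p_i) = \log\left(\frac{p_i}{1-p_i}\right) = \alpha + \beta x_i
\]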
Here is some JAGS code to fit a Binomial-logit model to the Swiss willow tit data:
swt = read.csv('https://raw.githubusercontent.com/andrewcparnell/bhm_course/master/data/swt.csv')
jags_code = '
model{
  # Likelihood
  for(i in 1:N) {
    y[i] ~ dbin(p[i], 1)
    logit(p[i]) <- alpha + beta*x[i]
  }
  # Priors
  alpha ~ dnorm(0, 20^-2)
  beta ~ dnorm(0, 20^-2)
}
'
jags_run = jags(data = list(N = nrow(swt),
y = swt$rep.1,
x = swt$forest),
parameters.to.save = c('alpha',
'beta'),
model.file = textConnection(jags_code))
Most of the above steps you have already seen. The key things to note are:
- We use y[i] ~ dbin(p[i], 1) to represent the binomial likelihood with probability \(p_i\) and number of trials set to 1. (If you’ve forgotten the details of the binomial model please ask!)
- The link function is logit. Both JAGS and Stan are slightly different from other programming languages in that you can write the link function on the left hand side. In most programming languages you would have to write something like p[i] <- logit_inverse(alpha + beta*x[i])
- I have called the parameters alpha and beta here, rather than intercept and slope
Once run, you can manipulate the object in any way you like. Let’s create a simple plot of the data with the fitted line going through it. A useful package to call here is the boot
package which contains the logit
and inv.logit
functions. You can try them out with:
library(boot)
logit(0.4) # Convert from 0-to-1 space to -Infinity to +Infinity
## [1] -0.4054651
inv.logit(3) # ... and the opposite
## [1] 0.9525741
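The two functions are exact inverses of each other, which you can verify directly:

inv.logit(logit(0.4)) # returns 0.4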
To create the plot we will first extract the posterior means of the parameters:
post_means = jags_run$BUGSoutput$mean
alpha_mean = c(post_means$alpha) # c() drops the 1-d array structure
beta_mean = c(post_means$beta)   # so the arithmetic below doesn't warn
Now plot the data and add in the predicted values of the probability by creating new explanatory variables on a grid:
par(mar=c(3,3,2,1), mgp=c(2,.7,0), tck=-.01, las=1) # Prettify plots a bit
with(swt, plot(forest, rep.1,
xlab = 'Forest cover (%)',
ylab = 'Probability of finding swt'))
forest_grid = pretty(swt$forest, n = 100)
lines(forest_grid, inv.logit(alpha_mean + beta_mean * forest_grid), col = 'red')
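The single red line hides the posterior uncertainty. Here is a minimal sketch that adds 95% credible bands by pushing every posterior sample through the inverse link (re-using the jags_run object from above):

# Fitted probability at each grid point for every posterior sample
pars = jags_run$BUGSoutput$sims.list
p_grid = sapply(forest_grid, function(x)
  inv.logit(pars$alpha + pars$beta * x))
# Add the 2.5% and 97.5% posterior quantiles as dashed lines
lines(forest_grid, apply(p_grid, 2, quantile, probs = 0.025), col = 'red', lty = 2)
lines(forest_grid, apply(p_grid, 2, quantile, probs = 0.975), col = 'red', lty = 2)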
At this point, if you’re feeling really brave, you could try to code the above model in Stan without looking any further down the page. I’d suggest looking in the manual for the binomial_logit function (unfortunately there aren’t currently any examples in the Stan manual of Binomial-logit regression).
If you’re not so confident (or you’ve finished and want to check), here is the code for a Stan Binomial logit model.
stan_code = '
data {
  int<lower=0> N;
  vector[N] x;
  int y[N];
}
parameters {
  real alpha;
  real beta;
}
model {
  y ~ binomial_logit(1, alpha + beta * x);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
}
'
stan_run = stan(data = list(N = nrow(swt),
y = swt$rep.1,
x = swt$forest),
model_code = stan_code)
plot(stan_run)
The things to note about this code are:
- We declare x to be a vector and y to be an int. This is because x is continuous (it’s our percentage of forest cover) but y must be a whole number to be a valid binomial random variable. If you try to set y as a vector too, Stan will complain
- Notice that we write vector[N] x but int y[N], so the [N] appears in different places for vectors than for integers. I can never remember which way round is correct!
- The binomial_logit function does in one go what took JAGS two lines: it calculates the binomial probability and creates the logit of alpha + beta * x for you
- Stan is vectorised, so we don’t need a for loop like JAGS does

Exercise 6
Pick your favourite so far (either JAGS or Stan) and try these:
- Change the explanatory variable from forest to elev. Re-create the plots and interpret your findings.
- Fit a model with both forest and elev as covariates. How do the results change?

Now swap programmes and try the above again! Then compare the results of JAGS vs Stan on the GLM. They should look similar, though the algorithm they use is stochastic so the results won’t be identical.
Both JAGS and Stan provide the Rhat value with their default print output. Remember the rule of thumb is that the model run is satisfactory if all values are less than 1.1. If the Rhat values are above 1.1 then you have two possibilities: either the model needs to be run for longer, or there is something wrong with the model itself and it needs re-specifying.
If the Rhat values are above 1.1 then the first thing I do is create a trace plot of some of the parameters. This is the plot of the parameter values at each iteration. You can get it in Stan from:
rstan::traceplot(stan_run)
It’s a bit more fiddly in JAGS:
coda::traceplot(as.mcmc(jags_run$BUGSoutput$sims.matrix))
The reason for the rstan:: and coda:: prefixes is that JAGS and Stan each have a function called traceplot, so you have to specify the package to get the correct one. The trace plots should look like hairy caterpillars and shouldn’t be wandering around or getting stuck in one location.
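Alongside eyeballing the trace plots, you can check the Rhat values programmatically, since both packages store them in their summary tables. A minimal sketch:

# TRUE if every monitored parameter passes the 1.1 rule of thumb
all(jags_run$BUGSoutput$summary[, 'Rhat'] < 1.1)
all(summary(stan_run)$summary[, 'Rhat'] < 1.1)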
If the trace plots look bad there are five things you can change in a JAGS/Stan run to help. These were all covered in Class 4 so go back to the slides if you have forgotten. The three demonstrated below are the number of iterations, the length of the burn-in period, and the amount of thinning.
Here’s an example of Stan code with the iterations/burn-in/thinning values doubled (Stan calls the burn-in the warmup
):
stan_run = stan(data = list(N = nrow(swt),
y = swt$rep.1,
x = swt$forest),
iter = 4000,
warmup = 2000,
thin = 2,
model_code = stan_code)
and here’s the same example in JAGS:
jags_run = jags(data = list(N = nrow(swt),
y = swt$rep.1,
x = swt$forest),
parameters.to.save = c('alpha',
'beta'),
n.iter = 4000,
n.burnin = 2000,
n.thin = 2,
model.file = textConnection(jags_code))
For the vast majority of problems we will cover (up until the last few classes) you shouldn’t need to change any of the defaults but it is useful to know.
Exercise 7
To finish, try fitting a Poisson GLM:
- Fit a Poisson GLM to the whitefly.csv data set, using the imm variable as the response and the trt variable as a covariate. Don’t worry about interpreting the output, just get the model to run and converge.
- If you’re working in Stan, look up the poisson_log function in the Stan manual
- The n variable in this data set is a potential offset. Use the notes on offsets in Class 2 to fit a model that uses the n variable as an offset. Again, don’t worry about interpretation. Just get the models to run!

The answer script is here.