Markov Chain Monte Carlo With Turing

Overview

This tutorial will give some examples of using Turing.jl and Markov Chain Monte Carlo to sample from posterior distributions.

Setup

using Turing
using Distributions
using Plots
default(fmt = :png) # the tide gauge data is long, this keeps images a manageable size
using LaTeXStrings
using StatsPlots
using Measures
using StatsBase
using Optim
using Random
using DataFrames
using DataFramesMeta
using Dates
using CSV

As this tutorial involves random number generation, we will set a random seed to ensure reproducibility.


Random.seed!(1);

Fitting A Linear Regression Model

Let’s start with a simple example: fitting a linear regression model to simulated data.

Positive Control Tests

Simulating data with a known data-generating process and then trying to obtain the parameters for that process is an important step in any workflow.

Simulating Data

The data-generating process for this example will be: \[ \begin{gather} y = 5 + 2x + \varepsilon \\ \varepsilon \sim \text{Normal}(0, 3), \end{gather} \] where \(\varepsilon\) is so-called “white noise”, which adds stochasticity to the data set. The generated dataset is shown in Figure 1.

Figure 1: Scatterplot of our generated data.
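The code that generated this dataset isn’t shown here. A minimal sketch of how such data could be simulated follows. The 20 x-values drawn uniformly on [0, 20] are an assumption for illustration only, since the original x-values aren’t given.

```julia
using Distributions
using Random

Random.seed!(1)

# hypothetical design (assumption): 20 x-values drawn uniformly between 0 and 20
n = 20
x = rand(Uniform(0, 20), n)

# data-generating process from the text: y = 5 + 2x + ε, with ε ~ Normal(0, 3)
ε = rand(Normal(0, 3), n)
y = 5 .+ 2 .* x .+ ε
```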

Model Specification

The statistical model for a standard linear regression problem is \[ \begin{gather} y = a + bx + \varepsilon \\ \varepsilon \sim \text{Normal}(0, \sigma). \end{gather} \]

Rearranging, we can rewrite the likelihood function as: \[y \sim \text{Normal}(\mu, \sigma),\] where \(\mu = a + bx\). This means that we have three parameters to fit: \(a\), \(b\), and \(\sigma\).
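Because the data points are modeled as independent, the likelihood is a product of normal densities, so the log-likelihood is a sum over data points. A small sketch of that sum with Distributions.jl (the helper name `loglik` is ours, not part of any package):

```julia
using Distributions

# log-likelihood of the linear regression model:
# sum over data points of log Normal(y_i | a + b * x_i, σ)
function loglik(a, b, σ, x, y)
    μ = a .+ b .* x
    return sum(logpdf.(Normal.(μ, σ), y))
end

# a single point sitting exactly at its mean, with σ = 1,
# gives the standard normal log-density at zero, -log(2π)/2
ll = loglik(0.0, 1.0, 1.0, [2.0], [2.0])
```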

Next, we need to select priors on our parameters. We’ll use relatively generic distributions to avoid using the information we have (since we generated the data ourselves), but in practice, we’d want to use any relevant information that we had from our knowledge of the problem. Let’s use relatively diffuse normal distributions for the trend parameters \(a\) and \(b\) and a half-normal distribution (a normal distribution truncated at 0, to only allow positive values) for the standard deviation \(\sigma\), as recommended by Gelman (2006).

Gelman, A. (2006). Prior distributions for variance parameters in hierarchical models (comment on article by Browne and Draper). Bayesian Anal., 1(3), 515–533. https://doi.org/10.1214/06-BA117A

\[ \begin{gather} a \sim \text{Normal}(0, 10) \\ b \sim \text{Normal}(0, 10) \\ \sigma \sim \text{Half-Normal}(0, 25) \end{gather} \]
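In Distributions.jl, the half-normal prior can be written as a normal distribution truncated at zero, which is how the model below expresses it. A quick sketch of what these priors imply on their own:

```julia
using Distributions

prior_a = Normal(0, 10)
prior_b = Normal(0, 10)
prior_σ = truncated(Normal(0, 25); lower=0)   # Half-Normal(0, 25)

# the half-normal puts no mass below zero; its mean is 25 * sqrt(2/π) ≈ 19.9
lo_σ = minimum(prior_σ)
mean_σ = mean(prior_σ)

# central 95% prior interval for the intercept (the slope prior is identical)
interval_a = quantile.(prior_a, [0.025, 0.975])
```

The 95% prior interval of roughly (-19.6, 19.6) for \(a\) and \(b\) is wide relative to the data-generating values, which is the sense in which these priors are diffuse.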

Using Turing

Coding the Model

Turing.jl uses the @model macro to specify the model function. We’ll follow the setup in the Turing documentation.

To specify distributions on parameters (and the data, which can be thought of as uncertain parameters in Bayesian statistics), use a tilde ~, and use equals = for transformations (which we don’t have in this case).


@model function linear_regression(x, y)
    # set priors
    # standard deviations must be positive, so we use a normal distribution truncated at zero
    σ ~ truncated(Normal(0, 25); lower=0)
    # keep both trend priors relatively uninformative to reflect a more "realistic" modeling scenario
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)

    # compute the likelihood
    # we specify the likelihood with a loop; we could also write it as a joint likelihood over
    # all of the data using linear algebra, which might be more efficient for large and/or
    # complex models or datasets, but the loop is more readable in this simple case
    for i = 1:length(y)
        # compute the mean value for the data point
        μ = a + b * x[i]
        y[i] ~ Normal(μ, σ)
    end
end
linear_regression (generic function with 2 methods)
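As noted above, the loop could be replaced by a single joint likelihood. A sketch of that alternative, assuming independent errors so the multivariate normal has diagonal covariance \(\sigma^2 I\); the model name `linear_regression_vec` and the toy inputs are hypothetical:

```julia
using Turing
using LinearAlgebra

@model function linear_regression_vec(x, y)
    # same priors as linear_regression
    σ ~ truncated(Normal(0, 25); lower=0)
    a ~ Normal(0, 10)
    b ~ Normal(0, 10)

    # joint likelihood: independent normal errors give a diagonal covariance σ² I
    μ = a .+ b .* x
    y ~ MvNormal(μ, σ^2 * I)
end

# instantiate the model with a few toy points (no sampling here)
toy_model = linear_regression_vec([1.0, 2.0, 3.0], [7.1, 8.8, 11.2])
```

The posterior is the same as for the loop version; only the bookkeeping differs, which can matter for speed with automatic differentiation on larger datasets.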

Fitting The Model

Now we can call the sampler to draw from the posterior. We’ll use the No-U-Turn sampler (Hoffman & Gelman, 2014), which is a Hamiltonian Monte Carlo algorithm (a different category of MCMC sampler than the Metropolis-Hastings algorithm discussed in class). We’ll also use 4 chains so we can test that the chains are well-mixed, and each chain will be run for 5,000 iterations.[1]

Hoffman, M. D., & Gelman, A. (2014). The No-U-Turn sampler: Adaptively setting path lengths in Hamiltonian Monte Carlo. J. Mach. Learn. Res., 15(47), 1593–1623.

[1] Hamiltonian Monte Carlo samplers often need to be run for fewer iterations than Metropolis-Hastings samplers, as each proposal uses information about the gradient of the log-posterior density rather than a random walk. The disadvantage is that this gradient information must be available, which is not always the case for external simulation models. However, simulation models coded in Julia can usually be differentiated automatically by Turing’s tools.
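To make the role of the gradient concrete, here is a sketch of the leapfrog integrator that Hamiltonian Monte Carlo samplers use to propose moves. It is illustrative only (not Turing’s implementation) and uses a standard normal target, whose negative log-density \(U(q) = q^2/2\) has gradient \(q\):

```julia
# negative log-density of a standard normal target, and its gradient
U(q) = q^2 / 2
∇U(q) = q

# leapfrog integration of Hamiltonian dynamics with unit mass: a half momentum step,
# alternating full position and momentum steps, then a final half momentum step
function leapfrog(q, p, ϵ, L)
    p -= ϵ / 2 * ∇U(q)
    for i in 1:L
        q += ϵ * p
        if i < L
            p -= ϵ * ∇U(q)
        end
    end
    p -= ϵ / 2 * ∇U(q)
    return q, p
end

# total "energy": potential plus kinetic
H(q, p) = U(q) + p^2 / 2

q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, 0.1, 20)
energy_error = abs(H(q1, p1) - H(q0, p0))
```

Because the gradient steers the trajectory along the target’s contours, the energy (and hence the acceptance probability) stays nearly constant even over long moves, which is why far fewer iterations are usually needed than with a random walk.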

# set up the sampler
model = linear_regression(x, y)   # initialize the model with the data
n_chains = 4                      # use multiple chains to help diagnose convergence
n_per_chain = 5000                # number of iterations for each chain
# sample from the posterior using NUTS and drop the iterations used to warm up the sampler;
# MCMCThreads() runs the chains on the available processor threads (serially if only one exists)
chain = sample(model, NUTS(), MCMCThreads(), n_per_chain, n_chains; drop_warmup=true)
@show chain                       # @show makes the display of the output a bit cleaner
Warning: Only a single thread available: MCMC chains are not sampled in parallel
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/Es490/src/sample.jl:307
chain = MCMC chain (5000×15×4 Array{Float64, 3})
Chains MCMC chain (5000×15×4 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 4
Samples per chain = 5000
Wall duration     = 13.95 seconds
Compute duration  = 12.23 seconds
parameters        = σ, a, b
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

           σ    5.3995    0.9907    0.0106   8907.8053   8357.6551    1.0008   ⋯
           a    7.3413    2.1497    0.0228   8980.2158   9558.6500    1.0006   ⋯
           b    1.7991    0.1916    0.0020   8980.6008   9477.6142    1.0006   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

           σ    3.8939    4.6953    5.2650    5.9393    7.7752
           a    2.9639    5.9700    7.3826    8.7456   11.5116
           b    1.4242    1.6757    1.7948    1.9202    2.1879

How can we interpret the output? The first parts of the summary statistics are straightforward: we get the mean, standard deviation, and Monte Carlo standard error (mcse) of each parameter. We also get information about the effective sample size (ESS)[2] and \(\hat{R}\), which compares the variance between chains with the variance within each chain as a check for convergence.[3]

[2] The ESS reflects the efficiency of the sampler: it estimates the equivalent number of independent samples, so the more correlated the samples, the lower the ESS.

[3] The closer \(\hat{R}\) is to 1, the better; values above about 1.01 are commonly treated as a sign that the chains have not yet mixed.
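For intuition, the classic (non-split) Gelman-Rubin \(\hat{R}\) can be computed from a matrix of draws with one column per chain. This is a sketch of the textbook formula; recent versions of MCMCChains likely report a more robust rank-normalized, split-chain variant, so the summary’s values will differ slightly. The function name `gelman_rubin` and the synthetic draws are ours:

```julia
using Statistics
using Random

# classic Gelman-Rubin potential scale reduction factor
# draws: n iterations × m chains
function gelman_rubin(draws::AbstractMatrix)
    n, m = size(draws)
    chain_means = vec(mean(draws; dims=1))
    W = mean(var(draws; dims=1))          # mean within-chain variance
    B = n * var(chain_means)              # between-chain variance
    var_plus = (n - 1) / n * W + B / n    # pooled estimate of the posterior variance
    return sqrt(var_plus / W)
end

rng = MersenneTwister(1)
mixed = randn(rng, 1000, 4)               # four chains sampling the same distribution
stuck = mixed .+ [0 0 0 3]                # the fourth chain is stuck in a different region

rhat_mixed = gelman_rubin(mixed)          # close to 1
rhat_stuck = gelman_rubin(stuck)          # well above 1
```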

In this case, we can see that we were generally able to recover the “true” data-generating values of \(\sigma = 4\) and \(b = 2\), but \(a\) is noticeably off (the posterior mean is 7.3, rather than the data-generating value of 5). In fact, there is substantial uncertainty about \(a\), with a 95% credible interval of \((3.0, 11.5)\) (compared to \((1.4, 2.2)\) for \(b\)). This isn’t surprising: given the variance of the noise \(\sigma^2\), there are many different intercepts which could fit within that spread.

Let’s now plot the chains for visual inspection.

plot(chain)
Figure 2: Output from the MCMC sampler. Each row corresponds to a different parameter: \(\sigma\), \(a\), and \(b\). Each chain is shown in a different color. The left column shows the sampler traceplots, and the right column the resulting posterior distributions.

We can see from Figure 2 that our chains mixed well and seem to have converged to similar distributions! The traceplots have a “hairy caterpillar” appearance, suggesting relatively little autocorrelation. We can also see how much more uncertainty there is with the intercept \(a\), while the slope \(b\) is much more constrained.
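The “hairy caterpillar” look corresponds to low autocorrelation between successive draws. A sketch comparing the lag-1 autocorrelation of independent draws with that of a sticky AR(1) series, which stands in for a poorly mixing chain (both series are synthetic):

```julia
using StatsBase
using Random

rng = MersenneTwister(42)
n = 5000

# independent draws: what a well-mixing sampler approximates
iid = randn(rng, n)

# AR(1) series with strong persistence: what a poorly mixing chain looks like
ρ = 0.95
sticky = zeros(n)
for t in 2:n
    sticky[t] = ρ * sticky[t-1] + sqrt(1 - ρ^2) * randn(rng)
end

lag1_iid = autocor(iid, [1])[1]        # near 0
lag1_sticky = autocor(sticky, [1])[1]  # near 0.95
```

The higher the autocorrelation, the fewer effectively independent draws a chain of fixed length contains, which is what the ESS column summarizes.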

Another interesting comparison we can make is with the maximum-likelihood estimate (MLE), which we can obtain through optimization.

mle_model = linear_regression(x, y)
mle = optimize(mle_model, MLE())   # this is where we use the Optim.jl package in this tutorial
coef(mle)
3-element Named Vector{Float64}
A  │ 
───┼────────
σ  │ 4.75545
a  │ 7.65636
b  │ 1.77736

We could also get the maximum a posteriori (MAP) estimate, which includes the prior density, by replacing MLE() with MAP().
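For this particular model the MLE also has a closed form, which makes a useful check on the optimizer: the MLEs of \(a\) and \(b\) are the ordinary least-squares estimates, and the MLE of \(\sigma\) is the root-mean-square residual (dividing by \(n\), not \(n - 2\)). A sketch with a small made-up dataset; `linreg_mle` is a hypothetical helper:

```julia
using Statistics

# closed-form maximum-likelihood estimates for y = a + b x + ε, ε ~ Normal(0, σ)
function linreg_mle(x, y)
    b = cov(x, y) / var(x)            # least-squares slope
    a = mean(y) - b * mean(x)         # least-squares intercept
    resid = y .- (a .+ b .* x)
    σ = sqrt(mean(resid .^ 2))        # the MLE divides by n
    return (a=a, b=b, σ=σ)
end

# toy data on y = 5 + 2x plus noise that has zero mean and is uncorrelated with x
x_toy = [1.0, 2.0, 3.0, 4.0]
noise = [1.0, -1.0, -1.0, 1.0]
y_toy = 5 .+ 2 .* x_toy .+ noise
est = linreg_mle(x_toy, y_toy)        # (a = 5.0, b = 2.0, σ = 1.0), up to rounding
```

Dividing by \(n\) is also part of why the MLE of \(\sigma\) sits below the posterior mean reported by the sampler.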

Model Diagnostics and Posterior Predictive Checks

One advantage of the Bayesian modeling approach here is that we have access to a generative model, or a model which we can use to generate datasets. This means that we can now use Monte Carlo simulation, sampling from our posteriors, to look at how uncertainty in the parameter estimates propagates through the model. Let’s write a function which gets samples from the MCMC chains and generates datasets.

function mc_predict_regression(x, chain)
    # get the posterior samples, pooled across chains
    # (Array(group(...)) is more general than we need here, but is useful if multiple variables
    # were sampled as a group, e.g. several regression coefficients; otherwise, something like
    # vec(chain[:a]) would also work)
    a = Array(group(chain, :a))
    b = Array(group(chain, :b))
    σ = Array(group(chain, :σ))

    # compute the mean for every x and every posterior draw (one column per draw)
    μ = a' .+ x * b'
    # loop over the draws and generate alternative realizations with normal noise
    y = zeros((length(x), length(a)))
    for i = 1:length(a)
        y[:, i] = rand.(Normal.(μ[:, i], σ[i]))
    end
    return y
end
mc_predict_regression (generic function with 1 method)

Now we can generate a predictive interval and median and compare to the data.

x_pred = 0:20
y_pred = mc_predict_regression(x_pred, chain)
21×20000 Matrix{Float64}:
 -2.40865    1.9357    10.5589    15.7446   …   6.26661    0.222214  10.1702
 10.7976    18.8545     0.296641   3.38924     10.5971    10.8189    15.9384
 -0.417529  -0.885769   5.56482    7.39414     -0.981201   7.27796    6.9013
  4.33488   11.1663     9.51384    9.95352      5.88792   10.313     11.4574
  5.26926    5.42713   20.2392    13.7574      17.4582    10.5193     8.7948
 15.825     16.9226    19.3498    28.6916   …  13.6763    15.6275     6.24528
 16.504     14.0514    13.6398    18.3671      14.349     18.3797    14.9837
 23.6586    26.4983    29.1236    21.162       20.1668    20.2031    28.3504
 16.5461    23.2524    20.7667    22.3589      12.5181     9.4516    12.7805
 32.7533    13.6189    17.0692    26.5378      26.2872    16.5962    26.4759
  ⋮                                         ⋱                        
 30.6918    18.5685    30.0918    34.6757      23.5247    27.7565    28.9372
 27.8009    39.4466    34.1512    34.7717      30.2555    36.9157    26.2247
 26.9089    34.0929    36.0757    39.6863      28.6581    33.8906    35.4109
 44.642     31.9187    37.0507    25.4562   …  34.3312    27.4952    23.2894
 47.9252    36.3149    34.642     37.9717      41.1715    34.784     36.6823
 39.9288    27.0537    33.4583    42.8428      41.3582    32.8539    31.7867
 44.4605    34.1545    46.2137    35.4133      42.2067    39.4086    36.1819
 43.712     42.4212    41.7049    50.7652      51.3167    33.0422    49.135
 45.9144    43.2963    52.2372    50.891    …  39.7291    52.5323    46.89

Notice the dimension of y_pred: we have 20,000 columns, because we have 4 chains with 5,000 samples each. If we had wanted to subsample (which might be necessary if we had hundreds of thousands or millions of samples), we could have done that within mc_predict_regression before simulation.
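One way to subsample is to draw a random subset of posterior draws before simulating. A sketch, assuming the draws for each parameter are stored in arrays of equal length; `subsample_draws` is a hypothetical helper, and it reuses the same indices for every parameter so that each joint draw \((a_i, b_i, \sigma_i)\) stays together:

```julia
using StatsBase
using Random

# pick k of the posterior draws at random (without replacement), using the same
# indices for every parameter so that joint draws stay together
function subsample_draws(rng, k, draws...)
    idx = sample(rng, 1:length(first(draws)), k; replace=false)
    return map(d -> d[idx], draws)
end

rng = MersenneTwister(1)
a = randn(rng, 20_000)          # synthetic stand-ins for 20,000 posterior draws
b = randn(rng, 20_000)
σ = abs.(randn(rng, 20_000))

a_sub, b_sub, σ_sub = subsample_draws(rng, 1_000, a, b, σ)
```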

# get the boundaries for the 95% prediction interval and the median
y_ci_low = quantile.(eachrow(y_pred), 0.025)
y_ci_hi = quantile.(eachrow(y_pred), 0.975)
y_med = quantile.(eachrow(y_pred), 0.5)

Now, let’s plot the prediction interval and median, and compare to the original data.

# plot the 95% posterior prediction interval as a shaded blue ribbon
plot(x_pred, y_ci_low, fillrange=y_ci_hi, xlabel=L"$x$", ylabel=L"$y$", fillalpha=0.3, fillcolor=:blue, label="95% Prediction Interval", legend=:topleft, linealpha=0)
# plot the posterior prediction median as a blue line
plot!(x_pred, y_med, color=:blue, label="Prediction Median")
# plot the data as discrete red points
scatter!(x, y, color=:red, label="Data")
Figure 3: Posterior 95% predictive interval and median for the linear regression model. The data is plotted in red for comparison.

From Figure 3, it looks like our model might be slightly under-confident: all 20 data points appear to fall inside the 95% prediction interval, whereas we would expect about 5% of them (or 1 data point) to fall outside it. It’s hard to tell with only 20 data points, though! We could resolve this by tightening our priors, but this depends on how much information we used to specify them in the first place. The goal shouldn’t be to hit a specific level of uncertainty, but if there is a sound reason to tighten the priors, we could do so.
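Two quick numbers help frame this judgment. The empirical coverage is the fraction of observations inside their intervals, and if the intervals were perfectly calibrated (treating points as independent), the chance that all 20 points land inside 95% intervals is \(0.95^{20} \approx 0.36\), so a sample with no points outside would not by itself be strong evidence of miscalibration. A sketch; `coverage` is our name and the toy inputs are made up:

```julia
using Distributions
using Statistics

# fraction of observations that fall inside their [lower, upper] interval
coverage(y, lower, upper) = mean(lower .<= y .<= upper)

# toy example: 3 of 4 points inside
cov_toy = coverage([1.0, 2.0, 3.0, 10.0], zeros(4), fill(5.0, 4))

# probability that none of 20 independent points falls outside a calibrated 95% interval
p_none_outside = pdf(Binomial(20, 0.05), 0)
```

In the tutorial’s setting, the same helper could be called as `coverage(y, y_ci_low_data, y_ci_hi_data)` after computing interval bounds at the observed x-values (those two names are hypothetical).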

Now let’s look at the residuals between the data and the posterior median. The partial autocorrelations plotted in Figure 4 are not fully convincing, as some of the coefficients at longer lags are large, but the dataset is quite small, so it’s hard to draw strong conclusions. We won’t go further down this rabbit hole, as we know our data-generating process involved independent noise, but for a real dataset, we might want to try a model specification with autocorrelated errors to compare.

# calculate the median predictions and residuals
y_pred_data = mc_predict_regression(x, chain)
y_med_data = quantile.(eachrow(y_pred_data), 0.5)
residuals = y_med_data .- y

# plot the partial autocorrelations of the residuals, with a dotted reference line at zero
plot(pacf(residuals, 1:4), line=:stem, marker=:circle, legend=false, grid=false, linewidth=2, xlabel="Lag", ylabel="Partial Autocorrelation", markersize=8, tickfontsize=14, guidefontsize=16, legendfontsize=16)
hline!([0], linestyle=:dot, color=:red)
Figure 4: Partial autocorrelation function of model residuals, relative to the predictive median.
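A common rough reference when reading such a plot is the approximate 95% band for the sample (partial) autocorrelations of white noise, \(\pm 1.96/\sqrt{n}\). With only 20 residuals the band is about \(\pm 0.44\), which is one reason strong conclusions are hard here. A sketch using synthetic white-noise residuals of the same length (`white_noise_band` is our name):

```julia
using StatsBase
using Random

# approximate 95% band for sample (partial) autocorrelations of white noise
white_noise_band(n) = 1.96 / sqrt(n)

band20 = white_noise_band(20)      # ≈ 0.438

# synthetic white-noise "residuals" of the same length as our dataset
rng = MersenneTwister(3)
resid_demo = randn(rng, 20)
pacf_demo = pacf(resid_demo, 1:4)

# how many of the first four partial autocorrelations fall outside the band
n_outside = count(abs.(pacf_demo) .> band20)
```

On the plot itself, the band could be drawn with another call such as `hline!([-band20, band20], linestyle=:dash)`.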

Fitting Extreme Value Models to Tide Gauge Data

Let’s now look at an example of fitting an extreme value distribution (namely, a generalized extreme value distribution, or GEV) to tide gauge data. GEV distributions have three parameters:

  • \(\mu\), the location parameter, which reflects the positioning of the bulk of the GEV distribution;
  • \(\sigma\), the scale parameter, which reflects the width of the bulk;
  • \(\xi\), the shape parameter, which reflects the thickness and boundedness of the tail.

The shape parameter \(\xi\) is often of interest, as there are three classes of GEV distributions corresponding to its sign:

  • \(\xi < 0\) (Weibull type) means that the distribution is bounded above, at \(\mu - \sigma/\xi\);
  • \(\xi = 0\) (Gumbel type) means that the distribution is unbounded but has a thin, exponentially decaying tail, so the “extreme extremes” are less likely;
  • \(\xi > 0\) (Fréchet type) means that the distribution has a thicker, power-law tail.
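A short sketch with Distributions.jl’s `GeneralizedExtremeValue` makes the three cases concrete; the location 0, scale 1, and shape values of ±0.5 are arbitrary choices for illustration:

```julia
using Distributions

# one GEV from each class, sharing location 0 and scale 1
bounded = GeneralizedExtremeValue(0.0, 1.0, -0.5)   # ξ < 0
thin    = GeneralizedExtremeValue(0.0, 1.0, 0.0)    # ξ = 0
thick   = GeneralizedExtremeValue(0.0, 1.0, 0.5)    # ξ > 0

# upper end of the support: finite (μ - σ/ξ = 2) only when ξ < 0
upper = maximum.([bounded, thin, thick])

# exceedance probabilities P(Y > 5): zero, small, and noticeably larger
tail = ccdf.([bounded, thin, thick], 5.0)
```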

Load Data

First, let’s load the data. We’ll use data from the University of Hawaii Sea Level Center (Caldwell et al., 2015) for San Francisco, from 1897-2023. If you don’t have this data and are working with the notebook, download it here. We’ll assume it’s in a data/ subdirectory, but change the path as needed.

Caldwell, P. C., Merrifield, M. A., & Thompson, P. R. (2015). Sea level measured by tide gauges from global oceans — the joint archive for sea level holdings (NCEI accession 0019568). NOAA National Centers for Environmental Information (NCEI). https://doi.org/10.7289/V5V40S7W

The dataset consists of dates and hours and the tide-gauge measurement, in mm. We’ll load the dataset into a DataFrame.

function load_data(fname)
    # DataFramesMeta.jl's @chain makes it easy to string together commands to load and process data
    df = @chain fname begin
        # load the file, assuming there is no header
        CSV.File(; delim=',', header=false)
        # convert to a DataFrame
        DataFrame
        # rename columns for ease of access
        rename("Column1" => "year",
                "Column2" => "month",
                "Column3" => "day",
                "Column4" => "hour",
                "Column5" => "gauge")
        # combine the separate year, month, day, and hour columns into a Julia DateTime
        @transform :datetime = DateTime.(:year, :month, :day, :hour)
        # replace the -99999 missing-data flag with missing
        @transform :gauge = ifelse.(abs.(:gauge) .>= 9999, missing, :gauge)
        # keep only the :datetime and :gauge columns
        select(:datetime, :gauge)
    end
    return df
end
load_data (generic function with 1 method)
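To see what each step does without the full file, here is a sketch of the same transformations applied to three made-up rows in the year, month, day, hour, gauge layout, with -99999 marking a missing reading. It uses plain DataFrames.jl calls rather than the `@chain` pipeline:

```julia
using CSV
using DataFrames
using Dates

# three hypothetical rows; the second has the -99999 missing-data flag
raw_csv = """
1897,8,1,8,3292
1897,8,1,9,-99999
1897,8,1,10,3139
"""

df = DataFrame(CSV.File(IOBuffer(raw_csv); header=false))
rename!(df, [:year, :month, :day, :hour, :gauge])
df.datetime = DateTime.(df.year, df.month, df.day, df.hour)
df.gauge = ifelse.(abs.(df.gauge) .>= 9999, missing, df.gauge)
select!(df, :datetime, :gauge)
```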
dat = load_data("data/h551a.csv")
first(dat, 6)
Table 1: Processed hourly tide gauge data from San Francisco, from 8/1/1897-1/31/2023.
6×2 DataFrame
 Row │ datetime             gauge
     │ DateTime             Int64?
─────┼─────────────────────────────
   1 │ 1897-08-01T08:00:00    3292
   2 │ 1897-08-01T09:00:00    3322
   3 │ 1897-08-01T10:00:00    3139
   4 │ 1897-08-01T11:00:00    2835
   5 │ 1897-08-01T12:00:00    2377
   6 │ 1897-08-01T13:00:00    2012
# the @df macro from StatsPlots.jl applies the DataFrame plotting recipe; it isn't required
# (e.g. :datetime could be replaced with dat.datetime), but it cleans things up slightly
@df dat plot(:datetime, :gauge, label="Observations", bottom_margin=9mm)
xaxis!("Date", xrot=30)
yaxis!("Mean Water Level (mm)")
Figure 5: Hourly mean water level at the San Francisco tide gauge from 1897-2023.

Next, we need to detrend the data to remove the impacts of sea-level rise. We do this by removing a one-year moving average, centered on the data point, per the recommendation of Arns et al. (2013).

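# calculate the moving average and subtract it off
ma_length = 366 * 24   # the data are hourly, so a one-year window spans 366 × 24 observations
ma_offset = ma_length ÷ 2
# centered moving average over 2n + 1 points; skipmissing keeps gaps in the record from
# turning entire windows into missing values
moving_average(series, n) = [mean(skipmissing(@view series[i-n:i+n])) for i in n+1:length(series)-n]
dat_ma = DataFrame(datetime=dat.datetime[ma_offset+1:end-ma_offset], residual=dat.gauge[ma_offset+1:end-ma_offset] .- moving_average(dat.gauge, ma_offset))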

# plot
@df dat_ma plot(:datetime, :residual, label="Detrended Observations", bottom_margin=9mm)
xaxis!("Date", xrot=30)
yaxis!("Mean Water Level")
Figure 6: Mean water level from the San Francisco tide gauge, detrended using a 1-year moving average centered on the data point, per the recommendation of Arns et al. (2013).
Arns, A., Wahl, T., Haigh, I. D., Jensen, J., & Pattiaratchi, C. (2013). Estimating extreme water level probabilities: A comparison of the direct methods and recommendations for best practise. Coast. Eng., 81, 51–66. https://doi.org/10.1016/j.coastaleng.2013.07.003
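The comprehension-based moving average recomputes every window mean from scratch, which gets slow for long hourly records and wide windows. A sketch of an equivalent running-sum version, assuming a series without missing values (drop or fill gaps first); `naive_ma` and `cumsum_ma` are our names, and the check uses synthetic data:

```julia
using Statistics
using Random

# naive centered moving average over a window of 2n + 1 points
naive_ma(series, n) = [mean(@view series[i-n:i+n]) for i in n+1:length(series)-n]

# same result via prefix sums: each window sum is a difference of two cumulative sums,
# so the cost is O(length(series)) instead of O(length(series) * n)
function cumsum_ma(series, n)
    c = cumsum(vcat(zero(eltype(series)), series))   # c[k+1] == sum(series[1:k])
    w = 2n + 1
    return [(c[i+n+1] - c[i-n]) / w for i in n+1:length(series)-n]
end

rng = MersenneTwister(7)
series = randn(rng, 1_000)
ma_naive = naive_ma(series, 10)
ma_fast = cumsum_ma(series, 10)
```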

The last step in preparing the data is to find the annual maxima. We can do this using the groupby, transform, and combine functions from DataFrames.jl, as below.

# calculate the annual maxima
# drop missing values first, otherwise they will affect the call to argmax
dat_ma = dropmissing(dat_ma)
# group the data by year (DataFrames.transform with Dates.year adds a :datetime_function column
# holding the year; sort=true keeps the groups in chronological order), then keep the row
# with the maximum residual in each year
dat_annmax = combine(df -> df[argmax(df.residual), :],
                groupby(DataFrames.transform(dat_ma, :datetime => x -> year.(x)), :datetime_function; sort=true))
# delete the last year (2023): the dataset only goes into early 2023, so that partial-year
# maximum is unlikely to be representative of a full year
deleteat!(dat_annmax, nrow(dat_annmax))

# make a histogram of the maxima to see the distribution
histogram(dat_annmax.residual, label=false)
ylabel!("Count")
xlabel!("Mean Water Level (mm)")
Figure 7: Histogram of annual block maxima from 1898-2022 from the San Francisco tide gauge dataset.
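The grouping step above is fairly dense. A sketch of the same idea on a few made-up rows, using an explicitly named `year` column instead of the automatically generated `:datetime_function` name:

```julia
using DataFrames
using Dates

# hypothetical detrended series: four readings spread over two years
df = DataFrame(
    datetime = DateTime.(["2001-03-01T00:00:00", "2001-07-01T12:00:00",
                          "2002-01-15T06:00:00", "2002-11-30T18:00:00"]),
    residual = [120.0, 340.0, 410.0, 95.0],
)

# add a year column, then keep the row holding each year's maximum residual
df.year = year.(df.datetime)
annmax = combine(groupby(df, :year; sort=true)) do g
    g[argmax(g.residual), [:datetime, :residual]]
end
```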

Fit The Model

@model function gev_annmax(y)
    # location parameter prior: roughly on the order of 1000 mm, but kept relatively broad
    μ ~ Normal(1000, 100)
    # scale parameter prior: must be positive, so we use a normal truncated at zero
    σ ~ truncated(Normal(0, 100); lower=0)
    # shape parameter prior: shape parameters are usually small and hard to constrain,
    # so we use a more informative prior
    ξ ~ Normal(0, 0.5)

    # we treat the annual maxima as independent GEV draws, since we've removed the
    # long-term trend and are using long (annual) blocks
    y ~ GeneralizedExtremeValue(μ, σ, ξ)
end

gev_model = gev_annmax(dat_annmax.residual)   # initialize the model
n_chains = 4                                  # use multiple chains to help diagnose convergence
n_per_chain = 5000                            # number of iterations for each chain
# sample from the posterior using NUTS and drop the iterations used to warm up the sampler
gev_chain = sample(gev_model, NUTS(), MCMCThreads(), n_per_chain, n_chains; drop_warmup=true)
@show gev_chain
Warning: Only a single thread available: MCMC chains are not sampled in parallel
@ AbstractMCMC ~/.julia/packages/AbstractMCMC/Es490/src/sample.jl:307
gev_chain = MCMC chain (5000×15×4 Array{Float64, 3})
Chains MCMC chain (5000×15×4 Array{Float64, 3}):

Iterations        = 1001:1:6000
Number of chains  = 4
Samples per chain = 5000
Wall duration     = 5.68 seconds
Compute duration  = 5.13 seconds
parameters        = μ, σ, ξ
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters        mean       std      mcse     ess_bulk     ess_tail      rh ⋯
      Symbol     Float64   Float64   Float64      Float64      Float64   Float ⋯

           μ   1257.8434    5.6421    0.0489   13375.4301   11521.3394    1.00 ⋯
           σ     57.2113    4.2214    0.0363   13619.4235   13686.7664    1.00 ⋯
           ξ      0.0295    0.0625    0.0005   14332.1783   11774.3672    1.00 ⋯
                                                               2 columns omitted

Quantiles
  parameters        2.5%       25.0%       50.0%       75.0%       97.5% 
      Symbol     Float64     Float64     Float64     Float64     Float64 

           μ   1246.9789   1254.0379   1257.7738   1261.5762   1269.0232
           σ     49.5961     54.2510     57.0043     59.8866     66.0566
           ξ     -0.0814     -0.0150      0.0258      0.0693      0.1629
plot(gev_chain)
Figure 8: Traceplots (left) and marginal distributions (right) from the MCMC sampler for the GEV model.

From Figure 8, it looks like all of the chains have converged to the same distribution; the Gelman-Rubin diagnostic is also close to 1 for all parameters. Next, we can look at a corner plot to see how the parameters are correlated.

corner(gev_chain)
Figure 9: Corner plot for the GEV model.

Figure 9 suggests that the location and scale parameters \(\mu\) and \(\sigma\) are positively correlated. This makes some intuitive sense: increasing the location parameter shifts the bulk of the distribution upward, and increasing the scale parameter at the same time widens the bulk, so that lower observations remain plausible. When these parameters increase, however, the shape parameter \(\xi\) tends to decrease: with the bulk shifted upward and widened, the largest observations sit closer to it, so the tail of the GEV does not need to be as thick.
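A complementary check on the priors annotated in gev_annmax is a prior predictive simulation: draw parameters from the priors alone, simulate annual maxima, and see whether the implied range is plausible for this gauge before any data are used. A sketch that copies the priors from the model (the seed and number of draws are arbitrary):

```julia
using Distributions
using Random
using Statistics

rng = MersenneTwister(11)
n_draws = 1_000

# draw parameters from the same priors as gev_annmax
μ = rand(rng, Normal(1000, 100), n_draws)
σ = rand(rng, truncated(Normal(0, 100); lower=0), n_draws)
ξ = rand(rng, Normal(0, 0.5), n_draws)

# simulate one annual maximum per parameter draw
y_prior = [rand(rng, GeneralizedExtremeValue(μ[i], σ[i], ξ[i])) for i in 1:n_draws]

# summarize the spread of annual maxima implied by the priors alone
prior_quantiles = quantile(y_prior, [0.05, 0.5, 0.95])
```

Comparing these quantiles with the histogram in Figure 7 shows whether the priors leave room for the observed maxima without being implausibly wide.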