Fitting CensoredDistributions.jl modified distributions with Turing.jl

Introduction

What are we going to do in this exercise

We'll demonstrate how to use CensoredDistributions.jl in conjunction with Turing.jl for Bayesian inference of epidemiological delay distributions. We'll cover the following key points:

  1. Defining a simple delay distribution model without observation processes.

  2. Exploring the prior distribution of this model.

  3. Defining a Bayesian model that incorporates double censoring and right truncation.

  4. Generating synthetic data from the model using fixed parameters.

  5. Fitting a naive model that ignores censoring.

  6. Fitting a model that accounts for secondary event censoring and truncation but not primary event censoring.

  7. Fitting the full model that accounts for double censoring and right truncation.

  8. Conditioning on aggregated (weighted) observations using the joint (values, weights) pattern together with fix().

  9. Demonstrating StatsBase.AbstractWeights integration patterns.

What might I need to know before starting

This tutorial builds on the concepts introduced in Getting Started with CensoredDistributions.jl.

Packages used

We use CairoMakie and PairPlots for plotting, Turing and DynamicPPL for probabilistic programming, DataFramesMeta (which re-exports Chain.jl's @chain for data pipeline workflows), Distributions, Random, StatsBase, and CensoredDistributions itself.

begin
    using DataFramesMeta
    using Turing
    using DynamicPPL
    using Distributions
    using Random
    using CairoMakie, PairPlots
    using StatsBase
    using CensoredDistributions
end

Generate synthetic data using Turing model simulation

We'll generate synthetic data by simulating from our Turing model with known true parameters. This approach ensures consistency between the data generation process and the model we'll use for inference, demonstrating how Turing models can be used for both simulation and fitting.

The proper Turing simulation approach:

  1. Define a Turing model that incorporates double censoring and right truncation

  2. Create a model instance with missing observations for simulation

  3. Use DynamicPPL's fix function to set parameters to their true values

  4. Sample from the prior predictive distribution by calling the model as a function
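As a minimal sketch of this pattern (using a hypothetical toy model rather than the one defined below, and mirroring the fix() call style used later in this tutorial), fixing a parameter and then drawing from the model simulates everything else conditional on that value:

@model function toy_delay_model()
    mu ~ Normal(0, 1)
    y ~ LogNormal(mu, 0.5)
end

# Fix mu at a known value; drawing from the model then simulates only the
# remaining variables (here y) from the prior predictive.
toy_sim = fix(toy_delay_model(), (@varname(mu) => 1.5,))
rand(toy_sim)   # returns a NamedTuple such as (y = 4.2,)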

Define the true parameters for generating synthetic data

We start by defining the number of samples and the true parameters of the lognormal.

n = 2000;
meanlog = 1.5;
sdlog = 0.75;

Now we can define a lognormal distribution using Distributions.jl.

true_dist = LogNormal(meanlog, sdlog);
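For later reference (the CDF comparison plot below marks the theoretical mean), the mean of this lognormal is exp(meanlog + sdlog^2 / 2):

mean(true_dist)   # exp(1.5 + 0.75^2 / 2) ≈ 5.94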

For each individual we will sample a primary and secondary event window as well as a relative observation time (relative to their censored primary event); we define the bounds for these below when we simulate data.

Define a reusable submodel for the latent delay distribution

To avoid code duplication across our models, we define a submodel that encapsulates the latent delay distribution parameters. This pattern allows us to reuse the same prior structure across all our models:

@model function latent_delay_dist()
    mu ~ Normal(1.0, 2.0);
    sigma ~ truncated(Normal(0.5, 1); lower = 0.0)
    return LogNormal(mu, sigma)
end
latent_delay_dist (generic function with 2 methods)

and define a helper function to standardize our pairplot visualizations across all model fits:

function plot_fit_with_truth(chain, truth_dict)
    f = pairplot(
        chain,
        PairPlots.Truth(
            truth_dict,
            label = "True Values"
        )
    )
    return f
end
plot_fit_with_truth (generic function with 1 method)

Prior predictive checks using pairplot

First, let's visualise the prior predictive distribution by sampling from the instantiated model with its weakly informative priors and comparing against our true parameters. This shows what the model believes before seeing any data.

begin
    Random.seed!(123);

    # Sample from the latent delay distribution prior
    latent_prior_samples = sample(latent_delay_dist(), Prior(), 1000)

    # Visualize the prior distribution
    plot_fit_with_truth(latent_prior_samples, (; mu = meanlog, sigma = sdlog))
end
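The pairplot shows the prior in parameter space. To see what these priors imply on the delay scale, one option (a sketch, assuming DynamicPPL's returned function, called generated_quantities in older versions) is to extract the LogNormal returned by each prior draw and summarise it:

begin
    # Re-evaluate the model at each prior draw and collect the returned LogNormal
    prior_dists = returned(latent_delay_dist(), latent_prior_samples)
    # Median delay implied by each prior draw (exp(mu) for a lognormal)
    prior_median_delays = median.(vec(prior_dists))
    hist(prior_median_delays; bins = 50)
end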

Define the double censored model for simulation and fitting

Now we define our full model that incorporates double censoring and right truncation. This model uses the latent_delay_dist() submodel via to_submodel() to include the delay distribution parameters. It also uses our double_interval_censored() function to define each double censored and right truncated delay:

@model function CensoredDistributions_model(pwindow_bounds, swindow_bounds, obs_time_bounds)
    pwindows ~ arraydist([DiscreteUniform(pw[1], pw[2]) for pw in pwindow_bounds])
    swindows ~ arraydist([DiscreteUniform(sw[1], sw[2]) for sw in swindow_bounds])
    obs_times ~ arraydist([DiscreteUniform(ot[1], ot[2]) for ot in obs_time_bounds])

    dist ~ to_submodel(latent_delay_dist())

    pcens_dists = map(pwindows, obs_times, swindows) do pw, D, sw
        pe = Uniform(0, pw)
        double_interval_censored(
            dist; primary_event = pe, upper = D, interval = sw)
    end

    obs ~ weight(pcens_dists)
end
CensoredDistributions_model (generic function with 2 methods)

We also need to define the simulated observation window bounds for each observed delay, as well as bounds on the relative observation times over which events could have been observed (required to adjust for right truncation). We then combine these bounds with the model in order to simulate data.

bounds_df = DataFrame(
    pwindow_bounds = fill((1, 3), n),  # Each observation can have pwindow 1-3
    swindow_bounds = fill((1, 3), n),  # Each observation can have swindow 1-3
    obs_time_bounds = fill((8, 12), n)  # Each observation can have obs_time 8-12
)
 Row   pwindow_bounds   swindow_bounds   obs_time_bounds
   1   (1, 3)           (1, 3)           (8, 12)
   2   (1, 3)           (1, 3)           (8, 12)
   3   (1, 3)           (1, 3)           (8, 12)
 ...
2000   (1, 3)           (1, 3)           (8, 12)

Simulate from the double censored distribution for each individual

Using the double censored model, we simulate data by sampling from the model using known true parameters. We use Turing's simulation approach with DynamicPPL's fix function to set parameters to their true values and sample from the prior predictive distribution. This means we can use the same model for simulation and inference.

We first create the base model, which only specifies the bounds to sample for the observation processes - we'll use this for both simulation and fitting:

base_model = @with bounds_df begin
    CensoredDistributions_model(:pwindow_bounds, :swindow_bounds, :obs_time_bounds)
end
Model{typeof(CensoredDistributions_model), …} (2000 sets of pwindow, swindow and obs_time bounds; DefaultContext)

For simulation, fix the distribution parameters to known true values:

simulation_model = fix(
    base_model,
    (
        @varname(dist.mu) => meanlog,
        @varname(dist.sigma) => sdlog
    )
)
Model{typeof(CensoredDistributions_model), …} with FixedContext(dist.mu => 1.5, dist.sigma => 0.75)

Now we can sample from the model using rand to get simulated observations along with their observation windows and relative observation times:

simulated_data = @chain simulation_model begin
    rand
    DataFrame
end
 Row   pwindows   swindows   obs_times   obs
   1   2          1          11          8.0
   2   1          1          10          3.0
   3   3          1          8           6.0
   4   1          1          8           3.0
   5   3          3          10          9.0
   6   2          3          11          0.0
   7   3          2          8           4.0
   8   1          1          11          7.0
   9   2          3          10          6.0
  10   2          3          10          6.0
 ...
2000   1          2          11          2.0

Visualise the simulated data

To make the data easier to handle, and to speed up our models later, we aggregate the simulated data to unique combinations and count their occurrences.

simulated_counts = @chain simulated_data begin
    @transform :obs_upper = :obs .+ :swindows
    @groupby All()
    @combine :n = length(:pwindows)
end
 Row   pwindows   swindows   obs_times   obs   obs_upper    n
   1   2          1          11          8.0    9.0         3
   2   1          1          10          3.0    4.0         7
   3   3          1          8           6.0    7.0         6
   4   1          1          8           3.0    4.0        13
   5   3          3          10          9.0   12.0         4
   6   2          3          11          0.0    3.0        11
   7   3          2          8           4.0    6.0        16
   8   1          1          11          7.0    8.0         2
   9   2          3          10          6.0    9.0        18
  10   1          1          9           4.0    5.0         3
 ...
 258   3          1          8           0.0    1.0         1

Now let's compare the samples with and without double interval censoring to the true distribution. First let's calculate the empirical CDF:

empirical_cdf_obs = @with(simulated_counts, ecdf(:obs, weights = :n));
# Create a sequence of x values for the theoretical CDF
x_seq = @with simulated_counts begin
    range(minimum(:obs), stop = maximum(:obs) + 2, length = 100);
end;
begin
    # Calculate theoretical CDF using true log-normal distribution
    theoretical_cdf = @chain x_seq begin
        cdf.(true_dist, _)
    end;

    # Generate uncensored samples from the true distribution for comparison
    uncensored_samples = rand(true_dist, n);
    empirical_cdf_uncensored = ecdf(uncensored_samples);
end
ECDF{Vector{Float64}, Weights{…}} (empirical CDF object for the 2000 uncensored samples)
let
    f = Figure()
    ax = Axis(f[1, 1],
        title = "Comparison of Censored vs Uncensored vs Theoretical CDF",
        ylabel = "Cumulative Probability",
        xlabel = "Delay"
    )
    scatter!(
        ax,
        x_seq,
        empirical_cdf_obs.(x_seq),
        label = "Empirical CDF (Censored)",
        color = :blue
    )
    scatter!(
        ax,
        x_seq,
        empirical_cdf_uncensored.(x_seq),
        label = "Empirical CDF (Uncensored)",
        color = :red,
        marker = :cross
    )
    lines!(ax, x_seq, theoretical_cdf, label = "Theoretical CDF",
        color = :black, linewidth = 2)
    vlines!(
        ax, [mean(simulated_data.obs)], color = :blue, linestyle = :dash,
        label = "Censored mean", linewidth = 2)
    vlines!(ax, [mean(uncensored_samples)], color = :red, linestyle = :dash,
        label = "Uncensored mean", linewidth = 2)
    vlines!(ax, [mean(true_dist)], linestyle = :dash,
        label = "Theoretical mean", color = :black, linewidth = 2)
    axislegend(position = :rb)

    f
end

Fitting a naive model using Turing

We'll now fit a naive model that ignores the censoring process. This model treats the observed delay data as if it came directly from the uncensored delay distribution, providing a baseline for comparison.

@model function naive_model()
    dist ~ to_submodel(latent_delay_dist())
    obs ~ weight(dist)
end
naive_model (generic function with 2 methods)

Now let's instantiate and condition this model on the weighted observations. We add a small constant to the observations to avoid zero delays, which have zero density under the lognormal (a hint that this model is misspecified), and condition directly using the NamedTuple format (values = values, weights = counts), which enables joint observation conditioning.

naive_mdl = @with simulated_counts begin
    condition(naive_model(), obs = (values = :obs .+ 1e-6, weights = :n))
end
Model{typeof(naive_model), …} with ConditionContext(obs = (values = [8.000001, 3.000001, …], weights = [3, 7, …]))

Now let's fit the conditioned model using the joint observation pattern (values = values, weights = counts).

naive_fit = sample(naive_mdl, NUTS(), MCMCThreads(), 500, 4);
summarize(naive_fit)
 Row   parameters   mean        std         mcse         ess_bulk   ess_tail   rhat      ess_per_sec
   1   dist.mu      0.0340606   0.100273    0.00233815   1840.59    1523.68    1.00023   300.407
   2   dist.sigma   4.32734     0.0669993   0.00147371   2065.46    1611.07    1.00223   337.108
plot_fit_with_truth(naive_fit, Dict("dist.mu" => meanlog, "dist.sigma" => sdlog))

We see that the model has converged and the diagnostics look good. However, the posterior summary alone suggests a poor fit: mu is well below the target of 1.5 and sigma is well above the target of 0.75.
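As a rough illustration on the natural scale (plugging in the posterior means from the summary above), compare the implied mean delays:

mean(LogNormal(1.5, 0.75))     # true mean delay, ≈ 5.9
mean(LogNormal(0.034, 4.33))   # naive plug-in estimate: enormous, driven by the inflated sigma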

Fitting a truncation-adjusted interval model

Now let's fit an intermediate model that accounts for interval censoring and right truncation but ignores the primary censoring process. This provides a comparison point between the naive model and the full model.

@model function interval_only_model(swindow_bounds, obs_time_bounds)
    swindows ~ arraydist([Uniform(sw[1], sw[2]) for sw in swindow_bounds])
    obs_times ~ arraydist([Uniform(ot[1], ot[2]) for ot in obs_time_bounds])

    dist ~ to_submodel(latent_delay_dist())

    icens_dists = map(obs_times, swindows) do D, sw
        truncated(interval_censored(dist, sw), upper = D)
    end
    obs ~ weight(icens_dists)
    return obs
end
interval_only_model (generic function with 2 methods)

We create the interval-only model with the bounds, fix the window parameters to their observed values, and condition on the weighted observations:

interval_only_mdl = @with simulated_counts begin
    @chain interval_only_model(bounds_df.swindow_bounds, bounds_df.obs_time_bounds) begin
        fix((
            @varname(swindows) => :swindows,
            @varname(obs_times) => :obs_times
        ))
        condition(obs = (values = :obs, weights = :n))
    end
end;

Fit the interval-only model (note: Turing.jl supports a wide range of inference methods, but here we use the No-U-Turn Sampler, NUTS):

interval_only_fit = sample(interval_only_mdl, NUTS(), MCMCThreads(), 500, 4);
summarize(interval_only_fit)
 Row   parameters   mean       std         mcse         ess_bulk   ess_tail   rhat      ess_per_sec
   1   dist.mu      1.82271    0.0409126   0.00198909   428.632    511.41     1.01478   16.2042
   2   dist.sigma   0.672637   0.0252599   0.00119014   454.916    688.293    1.0107    17.1978

Accounting for the interval censoring and truncation improves on the naive fit, but mu is still overestimated because primary event censoring is ignored. Let's plot the posterior compared to the true values again. Note: one awkward feature of to_submodel() is that it automatically prefixes the left-hand-side name to all variable names in the submodel, meaning we either need to customise our postprocessing or turn this feature off (sketched after the plot below).

plot_fit_with_truth(interval_only_fit, Dict("dist.mu" => meanlog, "dist.sigma" => sdlog))
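If we preferred unprefixed parameter names, a minimal sketch (assuming a DynamicPPL version in which to_submodel accepts an auto_prefix flag) is to disable prefixing so the latent parameters appear as plain mu and sigma in the chain:

# Hypothetical variant of the naive model with automatic prefixing disabled
@model function naive_model_unprefixed()
    dist ~ to_submodel(latent_delay_dist(), false)
    obs ~ weight(dist)
end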

Fitting the double censored model

Now we'll fit the full model that accounts for the complete censoring process. Since CensoredDistributions_model was defined earlier and used for simulation, we'll reuse it for fitting. Here we fix the censoring windows and observation times to the observed data and then condition on the weighted observations.

CensoredDistributions_mdl = @with simulated_counts begin
    @chain base_model begin
        fix((
            @varname(pwindows) => :pwindows,
            @varname(swindows) => :swindows,
            @varname(obs_times) => :obs_times
        ))
        condition(obs = (values = :obs, weights = :n))
    end
end;
CensoredDistributions_mdl()
(values = [8.0, 3.0, 6.0, 3.0, 9.0, 0.0, 4.0, 7.0, 6.0, 4.0  …  8.0, 9.0, 1.0, 0.0, 1.0, 1.0, 11.0, 0.0, 9.0, 0.0], weights = [3, 7, 6, 13, 4, 11, 16, 2, 18, 3  …  1, 1, 2, 1, 1, 2, 2, 1, 3, 1])

Now we fit the model to recover the true parameters from the synthetic data we generated earlier. This demonstrates the package's ability to perform accurate parameter recovery when the censoring process is properly modelled.

CensoredDistributions_fit = sample(
    CensoredDistributions_mdl, NUTS(), MCMCThreads(), 1000, 4);
summarize(CensoredDistributions_fit)
 Row   parameters   mean      std         mcse          ess_bulk   ess_tail   rhat      ess_per_sec
   1   dist.mu      1.50646   0.0396906   0.00122084    1082.81    1480.66    1.0026    21.7983
   2   dist.sigma   0.76384   0.0318833   0.000959003   1118.72    1524.82    1.00422   22.5212
plot_fit_with_truth(
    CensoredDistributions_fit, Dict("dist.mu" => meanlog, "dist.sigma" => sdlog))

We see that the model has converged and the diagnostics look good. We also see that the posterior means are near the true parameters and the 90% credible intervals include the true parameters.
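One quick way to check the credible intervals (assuming MCMCChains' quantile method for chains, available here because Turing loads MCMCChains) is to compute the 5% and 95% posterior quantiles directly and compare them with meanlog = 1.5 and sdlog = 0.75:

quantile(CensoredDistributions_fit; q = [0.05, 0.95])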