Combining two components in a single model, with hierarchy

Hello, I have a model with two distinct components that I’m having trouble knitting together.


Component 1 is a measurement model that infers a value of interest from raw data:

  • y_i - the data; normally distributed with a mean given by a function of a_i, the predictors x_i, and other (scalar) parameters (...).
  • a_i - the value of interest; inferred as part of a basic hierarchical model, with hyperpriors \mu_a and \sigma_a.
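
In symbols, component 1 is roughly this (my shorthand; f is the unspecified mean function from the code below):

$$y_i \sim \mathcal{N}\big(f(a_i, x_i, \ldots),\ \sigma\big), \qquad a_i \sim \mathcal{N}(\mu_a, \sigma_a)$$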


Component 2 is a process model which tries to explain how values of a are generated:

  • a_t is related to its previous value in time, a_{t-1}, and is a function of e_t as well as other (scalar) parameters (...).
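
In symbols, roughly (writing g for the unspecified transition function fun):

$$a_t = g(a_{t-1}, e_t, \ldots), \qquad e_t \sim \mathrm{Bernoulli}(p)$$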

Initially I had set this up as two separate models:

import pymc as pm

# component 1 (simplified code)
with pm.Model() as model1:

    # hyperpriors
    mu_a = pm.Normal("mu_a", 0, 1)
    sigma_a = pm.Exponential("sigma_a", 1)

    # priors
    a = pm.Normal("a", mu_a, sigma_a, shape=shape)
    sigma = pm.Exponential("sigma", 1)

    # other parameters
    ...

    # likelihood
    pm.Normal("lik1", f(x, a, ...), sigma=sigma, observed=observed)

    trace1 = pm.sample()

Then, I would compute the posterior mean of a:

a_observed = trace1.posterior["a"].mean(["chain", "draw"])

and use that as the observed in the second component:

import aesara
import aesara.tensor as at

# component 2 (simplified code)
with pm.Model() as model2:

    # priors
    p = pm.Beta("p", 1, 1)
    e = pm.Bernoulli("e", p, shape=shape_e)

    # other parameters
    ...

    # aesara.scan returns (outputs, updates); keep only the outputs
    a, _ = aesara.scan(
        fun,
        sequences=e,
        outputs_info=at.zeros(n),
        non_sequences=[...],
    )
    sigma = pm.Exponential("sigma", 1)

    # likelihood
    pm.Normal("lik2", a, sigma=sigma, observed=a_observed)
    trace2 = pm.sample()

This works fine, but ideally the uncertainty in a should flow from component 1 into component 2, rather than being collapsed to a point estimate: the posterior distributions of a are quite variable, sometimes tight and symmetrical, sometimes wide and skewed.
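
For instance, the per-element posterior spread (which taking the mean throws away entirely) is easy to pull from the same trace; in my case it differs widely across elements:

a_sd = trace1.posterior["a"].std(["chain", "draw"])
print(a_sd.min().item(), a_sd.max().item())  # very different scales across elements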

However, I don’t see a way of combining these two components without losing the hierarchical inference of a.

E.g. it could be something like:

# combined
with pm.Model() as model3:

    # (from component 2)
    p = pm.Beta("p", 1, 1)
    e = pm.Bernoulli("e", p, shape=shape_e)

    # aesara.scan returns (outputs, updates); keep only the outputs
    a_process, _ = aesara.scan(
        fun,
        sequences=e,
        outputs_info=at.zeros(n),
        non_sequences=[...],
    )

    # (from component 1)
    mu_a = pm.Normal("mu_a", 0, 1)  # mu_a is now superfluous!
    sigma_a = pm.Exponential("sigma_a", 1)

    # priors: the scan output now plays the role mu_a used to play
    a = pm.Normal("a", a_process, sigma_a)
    sigma = pm.Exponential("sigma", 1)

    # likelihood
    pm.Normal("lik1", f(x, a, ...), sigma=sigma, observed=observed)

But mu_a has become superfluous: we've lost the hierarchical sampling of a.

Even if this is the 'correct' way to set it up, in practice the chains don't mix at all: the traces contain the same value repeated for every draw.


I've tried various other things (e.g. different formulations, compound steps in sampling, using pm.Potential, avoiding aesara.scan), but none of them sample properly.
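
For example, the compound-step attempt looked roughly like this (a sketch: BinaryGibbsMetropolis assigned to the discrete e, with pm.sample giving NUTS to the remaining continuous parameters automatically):

with model3:
    # explicit discrete sampler for e; the continuous parameters get NUTS
    step_e = pm.BinaryGibbsMetropolis([e])
    trace3 = pm.sample(step=[step_e])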

Does anyone have any suggestions? Perhaps I’m thinking about this problem the wrong way?

Thanks in advance
David