Multivariate truncated emissions and posterior inspection of a POMDP

Hi all,

I am currently building a POMDP model that consists of several correlated nodes:

  • each node has a trajectory of states (depending on actions)
  • each node has a trajectory of emissions (depending on its state and actions)
  • there is an underlying covariance between the nodes’ emissions that depends on their distance from each other along one axis (“distance_vector”)
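The HMM class itself is not shown, so purely as an illustration, here is a NumPy sketch of forward-sampling action-dependent state trajectories. simulate_states is a hypothetical helper, and the assumption that the action taken at t-1 selects the transition matrix used to move into state t is mine, not something taken from the HMM class:

```python
import numpy as np

def simulate_states(init_probs, transition_mats, actions, rng):
    """Forward-sample action-dependent Markov chains (illustrative sketch).

    init_probs:      (n_states,) initial state distribution
    transition_mats: (n_actions, n_states, n_states); row s of transition_mats[a]
                     is P(next state | current state s, action a)
    actions:         (n_nodes, n_timesteps) integer actions
    Returns integer states with the same shape as `actions`.
    """
    n_nodes, n_timesteps = actions.shape
    n_states = init_probs.shape[0]
    states = np.empty((n_nodes, n_timesteps), dtype=int)
    states[:, 0] = rng.choice(n_states, size=n_nodes, p=init_probs)
    for t in range(1, n_timesteps):
        for n in range(n_nodes):
            # Assumption: the action at t-1 drives the transition into t.
            probs = transition_mats[actions[n, t - 1], states[n, t - 1]]
            states[n, t] = rng.choice(n_states, p=probs)
    return states
```

With transition_mat stacked as in the model, transition_mats[a] would presumably play the role of p_transition{a+1}.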

The HMM class builds a trajectory of a discrete Markov decision process given some actions.
I use a GP covariance function to parameterise the covariance between the nodes’ emissions. The variance_factor terms allow for different variances depending on the action taken.
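A NumPy sketch of the per-timestep scale matrix as I read the model, i.e. Sigma * outer(variance_factor[a], variance_factor[a]) with Sigma = eta * ExpQuad(distance_vector). The helpers exp_quad and emission_scale are hypothetical names; the kernel formula k(x, x') = exp(-(x - x')**2 / (2 * l**2)) is the one PyMC documents for ExpQuad:

```python
import numpy as np

def exp_quad(x, lengthscale):
    """Squared-exponential kernel on 1-D inputs: exp(-(x - x')**2 / (2 * l**2))."""
    diff = x[:, None] - x[None, :]
    return np.exp(-0.5 * (diff / lengthscale) ** 2)

def emission_scale(distance_vector, lengthscale, eta, variance_factor, actions_t):
    """Scale matrix for one timestep, mirroring Sigma * outer(v[a], v[a]).

    variance_factor: (n_actions,) multiplier per action (the model fixes action 0 to 1)
    actions_t:       (n_nodes,) action of each node at this timestep
    """
    K = eta * exp_quad(distance_vector, lengthscale)
    v = variance_factor[actions_t]
    # Elementwise product: entry (n, m) is scaled by v[n] * v[m].
    return K * np.outer(v, v)
```

Multiplying elementwise by outer(v, v) is the same as diag(v) @ K @ diag(v), so the result remains a valid positive semi-definite scale matrix, and the diagonal entry for node n is eta * v[a_n]**2.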

Here is the model:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

used_traj_fractal = np.array(used_traj_fractal)
used_traj_action = np.array(used_traj_action)

n_timesteps = len(used_traj_fractal[0])
n_trajectories = len(used_traj_fractal)
n_states = 3

print("n timesteps: ", n_timesteps)
print("n trajectories: ", n_trajectories)
print("n states: ", n_states)

with pm.Model() as model: 

    ### PRIOR STATES ###

    transition_mat1 = pm.Dirichlet(
        "p_transition1",
        a=np.array([[6.  , 1.5 , 0.5 ], 
                    [0.05, 7.  , 0.5 ], 
                    [0.1 , 0.1 , 9.  ]]),
        shape=(n_states, n_states))
    transition_mat2 = pm.Dirichlet(
        "p_transition2",
        a=np.array([[7.  , 0.2 , 0.1 ], 
                    [5.  , 1.5 , 0.1 ], 
                    [3.  , 3.  , 1.  ]]),
        shape=(n_states, n_states))
    transition_mat3 = pm.Dirichlet(
        "p_transition3",
        a=np.array([[7.  , 0.2 , 0.1 ], 
                    [6.  , 1.  , 0.1 ], 
                    [4.  , 3.  , 0.5 ]]),
        shape=(n_states, n_states))
    transition_mat = pm.Deterministic("transition_mat", pt.stack([transition_mat1, transition_mat2, transition_mat3], axis=0))
    
    init_probs = pm.Dirichlet('init_probs', a = [1,1,1], shape=n_states)
    

    ### PRIOR EMISSIONS ###

    mu_0 = pm.TruncatedNormal("mu_0", mu=[-0.25,-0.5,-1], sigma=0.5, upper=0, shape=3)
    mu_d = pm.Normal("mu_d", mu=[-0.01, -0.05, -0.1], sigma=0.1, shape=n_states)
    mu_r = pm.Normal("mu_r", mu=0, sigma=0.5, shape=n_states)

    k_r = pm.Beta("k_r", alpha=4, beta=2, shape=2)


    ### STATES ###

    states_all = HMM("states",
                    init_probs,
                    transition_mat,
                    used_traj_action,
                    n_timesteps,
                    n_states,
                    shape=(n_trajectories,n_timesteps))
    

    ### EMISSIONS ###

    # GP

    # Covariance function
    l = pm.Gamma("l", alpha=2, beta=2)
    eta = pm.HalfCauchy("eta", beta=0.1)
    nu = pm.Exponential("nu", lam=1/30)
    
    cov = pm.gp.cov.ExpQuad(1, l)
    Sigma = cov(X=distance_vector[:, None]) * eta

    variance_factor_main1 = pm.TruncatedNormal("var_fact1", mu=3, sigma=1, lower=0, shape=1)
    variance_factor_main2 = pm.TruncatedNormal("var_fact2", mu=5, sigma=1, lower=0, shape=1)
    variance_factor_deter = pt.as_tensor([1.0])  # action 0 has a fixed variance factor of 1
    variance_factor = pm.Deterministic("variance_factor", pt.stack([variance_factor_deter, variance_factor_main1, variance_factor_main2]))

    emissions_all = []
    init_mean = pm.Deterministic("mean_0", mu_0[states_all[:,0]])
    sigma_init = Sigma * pt.outer(variance_factor[used_traj_action[:,0]], variance_factor[used_traj_action[:,0]])

    emissions = pm.MvStudentT("obs_0", mu=init_mean, scale=sigma_init, nu=nu, observed=used_traj_fractal[:,0])
    emissions_all.append(emissions)

    for i in range(1, n_timesteps):
        mean_next_emission = pm.Deterministic(f"mean_{i}", next_emission(mu_d, mu_r, k_r, states_all[:,i],used_traj_action[:,i-1], emissions_all[i-1]))
        sigma_next = Sigma * pt.outer(variance_factor[used_traj_action[:,i]], variance_factor[used_traj_action[:,i]])
        emissions = pm.MvStudentT(f"obs_{i}", mu=mean_next_emission, scale=sigma_next, nu=nu, observed=used_traj_fractal[:,i])
        emissions_all.append(emissions)

For completeness, here is the next_emission function:

def next_emission(mu_d, mu_r, k_r, states, actions, prev_emissions):
    
    emissions = pt.switch(
        pt.eq(actions, 0),
        mu_d[states] + prev_emissions,
        mu_r[states] + k_r[actions - 1] * prev_emissions,
    ) 
       
    return emissions
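For reference, here is my reading of the emission model defined by the code above, written out for node n (row n of used_traj_fractal, located at d_n = distance_vector[n], assuming len(distance_vector) == n_trajectories) and timestep i. The symbols are notation introduced here, not names from the code: s is the latent state, a the action, y the observed emission, v = variance_factor (so v_0 = 1) and ℓ = l. Σ_i is the MvStudentT scale matrix; the covariance is ν/(ν-2) Σ_i for ν > 2.

```latex
\begin{aligned}
m_{n,0} &= \mu_0[s_{n,0}] \\
m_{n,i} &= \begin{cases}
  \mu_d[s_{n,i}] + y_{n,i-1} & \text{if } a_{n,i-1} = 0 \\
  \mu_r[s_{n,i}] + k_r[a_{n,i-1} - 1]\, y_{n,i-1} & \text{if } a_{n,i-1} \in \{1, 2\}
\end{cases} \qquad i \ge 1 \\
K_{nm} &= \exp\!\left(-\frac{(d_n - d_m)^2}{2\ell^2}\right) \\
(\Sigma_i)_{nm} &= \eta\, K_{nm}\, v_{a_{n,i}}\, v_{a_{m,i}} \\
y_{\cdot,i} &\sim \operatorname{MvStudentT}\left(\nu,\; m_{\cdot,i},\; \Sigma_i\right)
\end{aligned}
```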

I have several questions:

1 - The emissions need to be strictly negative. I tried using Potential and Interval. Potential works well to force the mean to be < 0; for the emissions, however, Potential has no effect, because the emissions are observed and therefore always satisfy the condition during sampling. Do you have suggestions for an easy way to enforce this behaviour? I am considering simply filtering the posterior_predictive samples to keep only those that respect the constraint.

2 - I would like to make sure that my way of inspecting the posterior is sound. For that, I usually plot posterior_predictive samples against the true trajectories. Do I need to pass predictions=True? I couldn’t notice any difference in the output. Also, I have seen set_data used: in my case, would that be the same as building a similar model, say model_pred, identical to model but with different action trajectories (the only external predictors I use), and then running sample_posterior_predictive on it with the trace learned on model?

3 - Any other remarks on the model are of course welcome.
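On question 1, if the filtering route is taken, a minimal NumPy sketch is below (keep_negative_draws is a hypothetical helper). It assumes the obs_0 ... obs_{T-1} draws have already been stacked into one array shaped (n_samples, n_nodes, n_timesteps), for example with np.stack([idata.posterior_predictive[f"obs_{i}"].values for i in range(n_timesteps)], axis=-1).reshape(-1, n_trajectories, n_timesteps), where idata is assumed to hold the result of sample_posterior_predictive:

```python
import numpy as np

def keep_negative_draws(draws):
    """Keep only predictive samples whose emissions are all strictly negative.

    draws: (n_samples, n_nodes, n_timesteps) array of posterior predictive draws.
    Returns the retained draws and the fraction of samples that were kept.
    """
    ok = np.all(draws < 0, axis=(1, 2))  # one flag per sample
    return draws[ok], float(ok.mean())
```

Note that this conditions the predictive on the constraint after the fact: the parameters were still fitted under an untruncated likelihood, so it is not equivalent to a model whose likelihood only supports negative values. The returned fraction is a useful diagnostic; if it is small, the fitted likelihood puts substantial mass above zero, which argues for building the constraint into the likelihood instead.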

Many thanks,

Christophe

Regarding the negative emissions, you can use CustomDist to define a negated log-MvStudentT, i.e. obs = -exp(T) with T ~ MvStudentT:

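import pymc as pm

def dist(nu, mu, sigma, size):
    # T ~ MvStudentT; MvStudentT takes its scale matrix as scale= (or Sigma=)
    mvt = pm.MvStudentT.dist(nu=nu, mu=mu, scale=sigma, size=size)
    # -exp(T) is strictly negative
    return -pm.math.exp(mvt)

with pm.Model() as m:
    ...
    pm.CustomDist("obs", nu, mu, sigma, dist=dist, observed=...)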

PyMC will figure out the logp for you.

https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.CustomDist.html
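The reason the logp can be derived is the change-of-variables rule. Assuming T ~ MvStudentT(nu, mu, Sigma) and Y = -exp(T) elementwise, T = log(-Y) and each component contributes a Jacobian factor 1 / (-y_k), so log p_Y(y) = log p_T(log(-y)) - sum_k log(-y_k), valid only when every component of y is strictly negative (so the observed emissions must be strictly below zero, not merely <= 0). Below is a 1-D NumPy check of that formula; the helper names are hypothetical and the Student-t density is written by hand to avoid extra dependencies:

```python
import math
import numpy as np

def student_t_logpdf(x, nu, mu, sigma):
    """Log-density of a univariate Student-t with location mu and scale sigma."""
    z = (x - mu) / sigma
    return (math.lgamma((nu + 1) / 2) - math.lgamma(nu / 2)
            - 0.5 * math.log(nu * math.pi) - math.log(sigma)
            - (nu + 1) / 2 * np.log1p(z ** 2 / nu))

def neg_exp_t_logpdf(y, nu, mu, sigma):
    """Log-density of Y = -exp(T), T ~ StudentT(nu, mu, sigma); valid for y < 0.

    T = log(-y) and |dT/dy| = 1 / (-y), hence the -log(-y) Jacobian term.
    """
    return student_t_logpdf(np.log(-y), nu, mu, sigma) - np.log(-y)

def trapezoid(f, x):
    """Plain trapezoidal rule for integrating sampled values f over grid x."""
    return float(np.sum((f[1:] + f[:-1]) * np.diff(x)) / 2)

# Y in [-2, -0.5] corresponds to T in [log(0.5), log(2)], so both
# integrals should give the same probability.
nu, mu, sigma = 5.0, -0.3, 0.7
y = np.linspace(-2.0, -0.5, 20001)
t = np.linspace(np.log(0.5), np.log(2.0), 20001)
p_y = trapezoid(np.exp(neg_exp_t_logpdf(y, nu, mu, sigma)), y)
p_t = trapezoid(np.exp(student_t_logpdf(t, nu, mu, sigma)), t)
print(p_y, p_t)
```

In the multivariate case the same argument applies per component, which is where the sum of log(-y_k) terms comes from.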

You may also want to use a Scan for the HMM instead of a Python loop, as it will give a much more succinct computational graph.
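A related observation about the emission loop, offered tentatively: every prev_emissions argument passed to next_emission is an observed variable, so in the likelihood it is replaced by the data, and the mean at step i does not depend on the mean at step i-1. If that reading is right, the per-timestep means can be computed in one vectorized call over time. The NumPy sketch below (next_emission_np, means_loop and means_vectorized are hypothetical names; np.where stands in for pt.switch) checks that the loop and the vectorized form give identical means:

```python
import numpy as np

def next_emission_np(mu_d, mu_r, k_r, states, actions, prev_emissions):
    """NumPy port of next_emission; np.where evaluates both branches, like pt.switch."""
    return np.where(
        actions == 0,
        mu_d[states] + prev_emissions,
        # For action 0, actions - 1 == -1 indexes k_r[-1]; that branch is discarded.
        mu_r[states] + k_r[actions - 1] * prev_emissions,
    )

def means_loop(mu_d, mu_r, k_r, states, actions, obs):
    """Per-timestep loop, mirroring the model: conditions on the observed obs[:, i-1]."""
    cols = [next_emission_np(mu_d, mu_r, k_r, states[:, i], actions[:, i - 1], obs[:, i - 1])
            for i in range(1, obs.shape[1])]
    return np.stack(cols, axis=1)

def means_vectorized(mu_d, mu_r, k_r, states, actions, obs):
    """Same means for every timestep i >= 1 in a single call."""
    return next_emission_np(mu_d, mu_r, k_r, states[:, 1:], actions[:, :-1], obs[:, :-1])
```

One caveat, as far as I understand sample_posterior_predictive: with the loop, each obs_i is fed the freshly sampled obs_{i-1}, so the predictive is a free-running simulation, whereas a vectorized version that indexes used_traj_fractal directly gives one-step-ahead predictions. The likelihood used for fitting is the same either way; Scan (or the loop) is what keeps the free-running behaviour.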
