Custom Categorical Distribution - ensure bounded candidates

Hi all,

I am new to PyMC and I would like to create a custom distribution that outputs a Markov chain:

(initial probs) x (transition matrix) x (chain length) => Markov chain

For this, I used the following tutorial: Implementing a RandomVariable Distribution — PyMC 5.11.0 documentation

Most of the implementation works: logp returns the right values, I can generate data, etc. However, I now want to use this distribution as a latent variable in an HMM:

import pymc as pm
import pytensor.tensor as pt

with pm.Model() as hmm_observed_states:
    transition_mat = pm.Dirichlet(
        "p_transition",
        a=pt.ones((n_states, n_states)) + pt.eye(n_states),
        shape=(n_states, n_states))

    init_probs = pm.Dirichlet("init_probs", a=pt.ones((n_states,)), shape=n_states)

    
    # Prior for mu and sigma
    μ = pm.Normal("mu", mu=[5, 0., -5], sigma=10, shape=(n_states,))
    σ = pm.Exponential("sigma", lam=2, shape=(n_states,))

    # for i in range(n_trajectories):
    for i in range(1):

        states_i = HMM(
            f"HMM_{i}",
            init_probs,
            transition_mat,
            n_timesteps,
            n_states,
            shape=(1,n_timesteps),
        )

        # states_i = pt.clip(states_i, 0, n_states - 1)

        emissions_i = pm.Normal(
            f'emissions_{i}',
            mu=μ[states_i[0]],
            sigma=σ[states_i[0]],
            observed=trajectories[i].astype("float")
        )

By doing so, an error is raised whenever the proposed new node q (proposed by the Metropolis step) has a value outside the range of states, e.g. q = [1,4,1,1] when there are only 3 states. The issue is simply an IndexError for sigma=σ[states_i[0]] whenever states_i contains values outside [0, 1, 2].

Of course, the logp corresponding to these values is set to -inf.

The problem is “solved” by artificially inserting a pt.clip between the HMM states and the Normal distribution. However, this is inelegant and, I suppose, quite inefficient.

=> Is there a way to force the proposed values of Metropolis to be in a certain range?

Many thanks already,
Chris

Hi Chris,

You could have a look at the DiscreteMarkovChain distribution in pymc-experimental. We also have an example showing how it can be used in an HMM. As of a few weeks ago, DiscreteMarkovChain is compatible with MarginalModel, so you can automatically marginalize out the Markov chain in an HMM. There are some limitations: you can only have an order-1 HMM (we haven't added support for marginalizing more lags yet), and you can't automatically recover the hidden states after marginalization.
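Construction looks roughly like this (a minimal sketch; check the linked docs for the exact import path and signature in your version of pymc-experimental):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

with pm.Model():
    P = pm.Dirichlet("P", a=pt.ones((3, 3)) + pt.eye(3))
    # init_dist must be an unnamed distribution created with .dist()
    init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
    # steps controls how many transitions follow the initial state
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=14)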

To actually answer your question as posed, clip is indeed the best way. See this discussion for details.
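Concretely, a minimal sketch of the workaround in your model above: index with the clipped states so out-of-range Metropolis proposals don't raise an IndexError, while the chain's logp still assigns them -inf so they are always rejected.

        # Clamp latent states before indexing; logp still rejects
        # out-of-range proposals via -inf
        states_safe = pt.clip(states_i, 0, n_states - 1)
        emissions_i = pm.Normal(
            f"emissions_{i}",
            mu=μ[states_safe[0]],
            sigma=σ[states_safe[0]],
            observed=trajectories[i].astype("float"),
        )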


Many thanks for your help!

I looked at what you sent and some new questions arose.

First, it seems that the DiscreteMarkovChain distribution you shared works only with the binary sampler (according to the tutorial itself, and on my data too). Is that a hard limitation, or is there a workaround you are aware of?

Second, I would greatly appreciate it if you could explain the benefits of marginalization. I see that the hidden states in the tutorial you sent don't use any marginalization, and yet are meaningful.

Of course, I am also currently searching for this information on my own, but any help is more than welcome.

Thanks again,
Chris

The DiscreteMarkovChain uses a CategoricalGibbs sampler, so it's not binary; it only falls back to a binary sampler if you have exactly two states. The workaround is to marginalize the hidden states. The reason to do this is that it makes the logp computation fully continuous and thus amenable to NUTS. In addition, marginalization reduces the variance of the MCMC estimates. You can see this tutorial by @zaxtax for a theoretical derivation showing the benefits of marginalization.
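To make that concrete, here is a minimal NumPy sketch of the idea (the forward algorithm; illustrative, not MarginalModel's actual implementation). It computes log P(y) = log sum_S P(S) * P(y | S) in O(T * K^2) time instead of enumerating all K**T possible chains:

import numpy as np

def hmm_marginal_loglike(log_emission, p_init, P):
    # log_emission: (T, K) array of log P(y_t | state k)
    # p_init: (K,) initial state probabilities; P: (K, K) transition matrix
    alpha = np.log(p_init) + log_emission[0]
    for t in range(1, len(log_emission)):
        # logsumexp over the previous state, for each current state
        step = alpha[:, None] + np.log(P)  # (K_prev, K_curr)
        alpha = log_emission[t] + np.logaddexp.reduce(step, axis=0)
    return np.logaddexp.reduce(alpha)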

I can’t really comment on what is going wrong with your implementation unless you share the actual code for your HMM distribution. I can say that a loop is not what you want to be doing though; you should be using a scan (this is what DiscreteMarkovChain does internally).
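For reference, a chain built with scan looks roughly like this (a sketch using pytensor's RandomStream; illustrative, not DiscreteMarkovChain's actual internals):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

srng = RandomStream(seed=42)

P = pt.matrix("P")               # (n_states, n_states) transition matrix
init_state = pt.lscalar("init")  # initial state index

def step(prev_state, P):
    # Draw the next state from the row of P selected by the previous state
    return srng.categorical(p=P[prev_state])

states, updates = pytensor.scan(
    fn=step,
    outputs_info=[init_state],
    non_sequences=[P],
    n_steps=14,  # chain length minus the initial state
)

sample_chain = pytensor.function([P, init_state], states, updates=updates)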

Alright, I will look into the marginalizing tutorial tomorrow. Thank you!

Regarding the implementation, my goal is to port this PyMC3 model (Hidden-Markov-Models/HMMs for MDP and POMDP.ipynb at master · giarcieri/Hidden-Markov-Models · GitHub) to PyMC 5, and later to adapt it for different tasks. The first step is to build efficient trajectories of latent states. For that I have looked, so far without success, into:

1 - Building a class for a MarkovChain of states from scratch (implementing the RV and the distribution class).
2 - Building a class for a MarkovChain of states with the CustomDist function.
3 - Augmenting the Categorical distribution class (as done in PyMC3).
4 - Using the DiscreteMarkovChain class.

So my first question would simply be: which method for a custom distribution would you recommend in my case?

(I understand that this might be beyond the usual scope of support, and if so, please feel free to let me know. Any pointers or suggestions would be greatly appreciated.)

This seems very close to what is shown in the HMM notebook I linked. You could use the code from that PyMC3 notebook in a CustomDist pretty much as-is by passing in its logp function, but you won't have access to prior or posterior predictive sampling in that case (unless you also write a function to sample from the HMM). But it seems like reinventing the wheel?
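A rough sketch of that CustomDist route, where markov_logp is a stand-in for the notebook's logp function:

import pymc as pm
import pytensor.tensor as pt

def markov_logp(value, p_init, p_trans):
    # Initial-state term plus the sum of transition terms
    init_term = pm.logp(pm.Categorical.dist(p=p_init), value[0])
    trans_term = pm.logp(pm.Categorical.dist(p=p_trans[value[:-1]]), value[1:]).sum()
    return init_term + trans_term

with pm.Model():
    p_init = pm.Dirichlet("p_init", a=pt.ones(3))
    p_trans = pm.Dirichlet("p_trans", a=pt.ones((3, 3)))
    states = pm.CustomDist(
        "states",
        p_init,
        p_trans,
        logp=markov_logp,
        ndim_supp=1,
        ndims_params=[1, 2],
        dtype="int64",
        shape=(15,),
    )

Without a random (or dist) argument, prior and posterior predictive sampling won't work, as noted above.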

If you want to get your hands dirty working on the problem, help is welcome on automatically unmarginalizing HMMs and on support for higher-order lags. I'd say the first one is more important: you'd get the best of all worlds, fast and efficient sampling plus hidden-state inference.

I believe using my own distribution will let me make later adjustments depending on my needs. Currently, I have the following HMM states class:

from typing import List, Tuple

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor.tensor import TensorVariable
from pytensor.tensor.random.op import RandomVariable


class HMMRV(RandomVariable):
    name: str = "hmm"

    ndim_supp: int = 1
    # One entry per parameter: p_init (1d), p_mat (2d), len_traj (0d), n_states (0d)
    ndims_params: List[int] = [1, 2, 0, 0]
    dtype: str = "int64"
    _print_name: Tuple[str, str] = ("hmm", "\\operatorname{hmm}")

    @classmethod
    def _supp_shape_from_params(cls, dist_params, param_shapes) -> Tuple[int, ...]:
        # The support shape is the trajectory length
        return (dist_params[2],)

    def __call__(self, p_init=None, p_mat=None, len_traj=None, n_states=None, size=None, **kwargs) -> TensorVariable:
        return super().__call__(p_init, p_mat, len_traj, n_states, size=size, **kwargs)

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        p_init: np.ndarray,
        p_mat: np.ndarray,
        len_traj: int,
        n_states: int,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        samples = np.zeros(size + (len_traj,), dtype=np.int64)
        for i in np.ndindex(size):
            current_state = rng.choice(np.arange(len(p_init)), p=p_init)
            samples[i + (0,)] = current_state  # Store initial state

            # Generate the trajectory using the transition matrix
            for j in range(1, len_traj):
                next_state = rng.choice(np.arange(len(p_init)), p=p_mat[current_state])
                samples[i + (j,)] = next_state
                current_state = next_state

        return samples
    
hmm = HMMRV()

class HMM(pm.Discrete):
    rv_op = hmm

    @classmethod
    def dist(cls, p_init, p_trans_mat, traj_length, n_states, *args, **kwargs):
        return super().dist([p_init, p_trans_mat, traj_length, n_states], **kwargs)
    
    def moment(rv, size, p_init, p_trans_mat, traj_length, n_states):  # TO BE IMPROVED !!!

        p_init_ = p_init.eval()
        p_trans_mat_ = p_trans_mat.eval()
        traj_length_ = traj_length.eval()
        size_ = size.eval()

        mode = np.zeros(shape=[size_[0], traj_length_])  # (batch size, trajectory length)
        current_state = np.argmax(p_init_, axis=-1)  # Initial state with maximum probability
        mode[:, 0] = current_state  # Store the initial state

        for i in range(1, traj_length_):
            next_state = np.argmax(p_trans_mat_[current_state], axis=-1)
            mode[:, i] = next_state
            current_state = next_state

        mode = pt.as_tensor(mode)

        return mode

    def logp(value, p_init, p_trans_mat, traj_length, n_states):
        k = n_states
        value_clip = pt.clip(value, 0, k - 1)

        if value_clip.ndim == 1:
            value_clip = value_clip[None, :]

        value_init = value_clip[:, 0]
        value_trans1 = value_clip[:, 0:(traj_length - 1)]
        value_trans2 = value_clip[:, 1:traj_length]

        logp_init = pm.logp(pm.Categorical.dist(p=p_init), value_init)

        logp_trans = pm.logp(pm.Categorical.dist(p=p_trans_mat[value_trans1]), value_trans2)[0]
        logp_trans = pt.sum(logp_trans, axis=1)

        total_logp = logp_init + logp_trans

        # Out-of-range values get -inf instead of raising an IndexError
        res = pt.switch(
            pt.any(pt.or_(pt.lt(value, 0), pt.gt(value, k - 1))),
            -np.inf,
            total_logp,
        )

        return res

And a dummy class for Emissions (a simple Gaussian emission that returns logp = -inf when the index is out of range):

class EmissionsRV(RandomVariable):
    name: str = "emissions"

    ndim_supp: int = 0
    # One entry per parameter: mu (1d), sigma (1d), index (0d), n_states (0d)
    ndims_params: List[int] = [1, 1, 0, 0]
    dtype: str = "floatX"
    _print_name: Tuple[str, str] = ("emissions", "\\operatorname{emissions}")
    
    def __call__(self, mu=None, sigma=None, index=None, n_states = None, size=None, **kwargs) -> TensorVariable:
        return super().__call__(mu, sigma, index, n_states, size=size, **kwargs)

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        mu: np.ndarray,
        sigma: np.ndarray,
        index: int,
        n_states: int,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        
        samples = np.zeros(size, dtype=np.float64)
        for i in np.ndindex(size):
            samples[i] = rng.normal(loc=mu[index], scale=sigma[index])
        return samples
    
emissions = EmissionsRV()

class Emissions(pm.Continuous):
    rv_op = emissions

    @classmethod
    def dist(cls, mu, sigma, index, n_states, *args, **kwargs):
        return super().dist([mu, sigma, index, n_states], **kwargs)
    
    def moment(rv, size, mu, sigma, index, n_states): # TO BE CHANGED !!!

        mode = mu[index]

        return mode

    def logp(value, mu, sigma, index, n_states):
        k = n_states
        index_clipped = pt.clip(index, 0, k - 1)

        # Out-of-range indices get -inf instead of raising an IndexError
        res = pt.switch(
            pt.any(pt.or_(pt.lt(index, 0), pt.gt(index, k - 1))),
            -np.inf,
            pm.logp(pm.Normal.dist(mu=mu[index_clipped], sigma=sigma[index_clipped]), value),
        )

        return res

Now, even with simplistic generated data, the inference is not good: the model won't properly recover the state transitions and instead increases the variance of the emissions to compensate for the wrong states.

Regarding support for new features, I doubt I have the required skills, and I lack the time at the moment, so I'm afraid I won't be able to help.

I don't think you're implementing a hidden Markov chain with those code snippets, but an observed Markov chain? In an HMM the logp should account for all possible chain states, but your logp seems to assume a specific chain is being sampled?

If I am not wrong, your code is equivalent to a DiscreteMarkovChain plus a Normal RV based on the chain draws. The fact that you have two separate RVs is also indicative of it not being an HMM. In an HMM you wouldn't sample the states, but marginalize over them.

Did I miss something?

So, what you are saying is that I am currently doing:

for states S and emissions Y,
P[theta | Y, S] = 1/Z * P[theta] * P[Y, S | theta] = 1/Z * P[theta] * P[S | theta] * P[Y | S, theta]

But as I don't observe the states S, I cannot do this, and I should instead do:
P[theta | Y] = 1/Z * P[theta] * sum over S of ( P[S | theta] * P[Y | S, theta] )

By not marginalizing the latent variable (here the chain of hidden states), what are the drawbacks? Is it fundamentally wrong, or just a loss of efficiency?
If I am not mistaken, both the notebook shared before (pymc-experimental/notebooks/discrete_markov_chain.ipynb at main · pymc-devs/pymc-experimental · GitHub) and the notebook I am reproducing (Hidden-Markov-Models/HMMs for MDP and POMDP.ipynb at master · giarcieri/Hidden-Markov-Models · GitHub) use latent variables without marginalization.

What would then be your approach to the problem, i.e. building a POMDP (an HMM with action-dependent transitions and emissions) with more complex emission models to come (e.g. autoregressive emissions, Gaussian processes, …)? Using the DiscreteMarkovChain class together with marginalization?

Thank you a lot for taking the time to look at my issue! It is much appreciated :slight_smile:

Apologies, I was mixing up hidden with marginalized. Your model is an HMM if you don't observe the states but infer them. Whether you marginalize it or not doesn't change the kind of model.


The problem with not marginalizing is that you have a discrete variable that can’t be sampled with NUTS.

It can be hard to sample HMM models without a specialized sampler, because the chain is (usually) a very long multivariate variable.

PS: that notebook is very slick. The only issue I see with newer PyMC is that CategoricalGibbs won't be picked automatically for your HMM states RV. You can try to specify it manually in the step argument of pm.sample, but I am not sure it will be happy with it. We can tweak it if it fails.
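For example, something like this (a sketch; model and hmm_states stand in for your model context and states RV, and the sampler may still reject the variable, as discussed below):

with model:
    # Manually assign the categorical sampler to the latent chain
    step = pm.CategoricalGibbsMetropolis(vars=[hmm_states])
    idata = pm.sample(step=step)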

In either case I don’t think you need custom RVs? You can just use the DiscreteMarkovChain and whatever likelihood on the observations.
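Something like this sketch, where n_states, n_timesteps, and the observed array y are assumed from your setup (check the pymc-experimental docs for the exact DiscreteMarkovChain signature):

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

with pm.Model():
    P = pm.Dirichlet("P", a=pt.ones((n_states, n_states)) + pt.eye(n_states))
    init_dist = pm.Categorical.dist(p=np.full(n_states, 1 / n_states))
    states = DiscreteMarkovChain(
        "states", P=P, init_dist=init_dist, steps=n_timesteps - 1
    )

    mu = pm.Normal("mu", mu=0, sigma=10, shape=n_states)
    sigma = pm.Exponential("sigma", lam=2, shape=n_states)
    # Emission likelihood indexed by the latent chain
    pm.Normal("obs", mu=mu[states], sigma=sigma[states], observed=y)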

This Reinforcement Learning example may be instructive; in that case the “chain” of actions is observed, so to speak, but could also be sampled: Fitting a Reinforcement Learning Model to Behavioral Data with PyMC — PyMC example gallery

UPDATE: I noticed that not a single state in the posterior had value 1 (only 0 or 2); this is clear evidence of a bug bigger than simply low inference power. To be confirmed.

The issue when using DiscreteMarkovChain is that, afaik, I won't be able to incorporate the actions from the POMDP, but can only use fixed transition matrices.

And indeed, CategoricalGibbs won't work, raising a clear ValueError: All variables must be categorical or binary for CategoricalGibbsMetropolis

Do you think the sampler alone can explain the poor inference on simplistic generated data? It seems unlikely that the results would be so much poorer than those of the much more complex model from the notebook.

For completeness, here is the data and the model:

The data:

import numpy as np

n_trajectories = 50
n_timesteps = 15
n_states = 3

# Initial probs
p0 = np.array([0.5, 0.25, 0.25])

# Transition matrix
P = np.array([[0.5, 0.25, 0.25],
              [0., 0.5, 0.5],
              [0.25, 0.25, 0.5]])

# Emission parameters (Normal; a t-distribution variant is commented out below)
mu1 = np.array([10, 0, -10])
sigma1 = np.array([0.1, 0.1, 0.1])

# Generate the HMM
for i in range(n_trajectories):
    # Generate the trajectory
    z = np.zeros(n_timesteps, dtype=int)
    x = np.zeros(n_timesteps)
    for t in range(n_timesteps):
        if t == 0:
            z[t] = np.random.choice(n_states, p=p0)
        else:
            z[t] = np.random.choice(n_states, p=P[z[t - 1]])

        # x[t] = np.random.standard_t(df=nu1[z[t]]) * sigma1[z[t]] + mu1[z[t]]
        x[t] = np.random.normal(loc=mu1[z[t]], scale=sigma1[z[t]])

    # Save the trajectory
    if i == 0:
        trajectories = np.array([x])
        states = np.array([z])
    else:
        trajectories = np.vstack([trajectories, x])
        states = np.vstack([states, z])

And the model:

with pm.Model() as model: 
    # Prior for transition matrix
    transition_mat = pm.Dirichlet(
        "p_transition",
        a=pt.ones((n_states, n_states)) + pt.eye(n_states),
        shape=(n_states, n_states))
    
    init_probs = pm.Dirichlet("init_probs", a=pt.ones((n_states,)), shape=n_states)
    
    # Prior for mu and sigma
    μ = pm.Normal("mu", mu=[5, 0., -5], sigma=1, shape=(n_states,))
    σ = pm.Exponential("sigma", lam=2, shape=(n_states,))

    states_all = HMM(
        "HMM",
        init_probs,
        transition_mat,
        n_timesteps,
        n_states,
        shape=(n_trajectories,n_timesteps),
    )

    emissions_all = Emissions(
        "Emissions",
        μ,
        σ,
        states_all,
        n_states,
        observed=trajectories,
    )

I spent the last few days looking at the issue some more and here are some concluding notes on this problem (for the curiosity of interested readers):

  • When I define a moment for the HMM class, the whole posterior is stuck at the first moment sampled (e.g. if, given the initial p_init and p_transition, the mode is [0,1,1,1,1,1], this sequence never changes in any iteration)

  • When I deactivate the moment of HMMStates, the posterior consists of the same trajectory repeated over and over

I looked into the sampler and noticed that, in Metropolis:

  • At first, the proposed q is systematically out of bounds (e.g. negative, or more than the number of states), which automatically leads to a logp of -inf and rejection of the proposed value
  • Then, the sampler keeps proposing a smaller delta, until the proposed q is identical to the current q0, so the posterior is always the same trajectory

I changed the sampler to CategoricalGibbsMetropolis and avoided the error by editing these lines in metropolis.py:

if isinstance(distr, CategoricalRV):
    k_graph = rv_var.owner.inputs[3].shape[-1]
    (k_graph,) = model.replace_rvs_by_values((k_graph,))
    k = model.compile_fn(k_graph, inputs=model.value_vars, on_unused_input="ignore")(
        initial_point
    )
elif isinstance(distr, BernoulliRV):
    k = 2
else:
    k = 3  # <======= number of states I use
# else:
#     raise ValueError(
#         "All variables must be categorical or binary" + "for CategoricalGibbsMetropolis"
#     )
start = len(dimcats)
dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]

This seems to solve the sampling issues and leads to good inference on generated data.


Yes, you definitely need a specialized sampler for Markov chains.