Iterating & printing the prior parameters for each variable in a PyMC3 model

I am trying to create a table (in LaTeX) for an upcoming paper detailing all the priors, prior parameters, posteriors, etc. for each of the variables in my PyMC3 model.

To do that, I would like to loop over each variable in the model and access the relevant prior parameters (for example, the upper and lower bounds for Uniform variables, the mu and std for Normal variables, the variable shape, etc.). Simply printing the model in Python gives output close to what I want, but I would prefer to access specific model variables and reformat the output myself. However, I cannot find any pm.model.Model function, or any documentation, showing how to directly access:

  1. An iterable list of model variables
  2. The variable type and prior parameters, given some variable.

This seems like it should be something easy that I’m missing…

OK, I should include an MWE, so let’s take the basic_model from "Getting started with PyMC3" (PyMC3 3.11.4 documentation):

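import numpy as np
import pymc3 as pm

# Placeholder synthetic data (assumed here so the snippet runs on its own;
# the tutorial generates something similar)
size = 100
X1 = np.random.randn(size)
X2 = np.random.randn(size) * 0.2
Y = 1 + 1 * X1 + 2.5 * X2 + np.random.randn(size)

basic_model = pm.Model()

with basic_model:
    # Priors for unknown model parameters
    alpha = pm.Normal("alpha", mu=0, sigma=10)
    beta = pm.Normal("beta", mu=0, sigma=10, shape=2)
    sigma = pm.HalfNormal("sigma", sigma=1)

    # Expected value of outcome
    mu = alpha + beta[0] * X1 + beta[1] * X2

    # Likelihood (sampling distribution) of observations
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)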

What I would like to be able to do (pseudocode; the attribute names are placeholders):

for var in basic_model.some_iterable_list_of_vars:
    if var.some_way_to_access_prior_type=="Uniform":
        print(var.lower,var.upper)
    elif var.some_way_to_access_prior_type=="Normal":
        print(var.mu,var.std)

I don’t have a full answer, but this might be a starting point:

for var in basic_model.free_RVs:
    # Get the distribution name from its string representation
    distribution = str(var.distribution).split("~")[1].strip()
    print(distribution)

    # Get the parameters. Not sure how to access them, but can display them
    # (display() is available in Jupyter/IPython; use print(var) elsewhere)
    display(var)
    print("\n")

You can get the raw LaTeX using this, then do some funky stuff with the strings.

for var in basic_model.free_RVs:
    print(var.__latex__())

Gives you

$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$
$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$
$\text{sigma_log__} \sim \text{TransformedDistribution}()$

But maybe someone knows an even easier way?

There is some variable called params_dist_repr or something similar that might be helpful as well.

for var in basic_model.free_RVs:
    distribution = str(var.distribution).split("~")[1].strip()
    if distribution == "Normal":
        # mu and sigma are stored as Theano constants; .value gives the number
        mu, sigma = var.distribution.mu.value, var.distribution.sigma.value
        print(var.name, "$\\mathcal{{N}}(\\mu={0:.6g},\\sigma={1:.6g})$".format(mu, sigma))

This does indeed work to loop over variables. And for Normally distributed variables I’m able to get what I want out.

However, every other type of variable appears to be a TransformedDistribution, even Uniform variables (which are the majority of the rest of my model parameters). For those, I cannot seem to access any meta information (i.e. lower/upper limits)… Any ideas?
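One possible way around this, offered as a sketch rather than a definitive answer: in PyMC3 3.x, as far as I know, a TransformedDistribution keeps the distribution you originally specified in its `dist` attribute, so `var.distribution.dist.lower` should reach the bounds of a Uniform. The helper names below (`PRIOR_PARAMS`, `base_distribution`, `as_number`, `prior_rows`) are made up for illustration, the lookup table has to be extended by hand, and `.eval()` assumes the parameters are constants rather than other random variables.

```python
# Hypothetical lookup table: which attributes to report for each prior type.
# Extend it with whatever distributions your model uses.
PRIOR_PARAMS = {
    "Normal": ("mu", "sigma"),
    "HalfNormal": ("sigma",),
    "Uniform": ("lower", "upper"),
}


def base_distribution(dist):
    """Return the user-specified distribution behind a TransformedDistribution."""
    # Check the class name instead of hasattr(dist, "dist"): every PyMC3
    # distribution class also has a `dist` classmethod.
    while type(dist).__name__ == "TransformedDistribution":
        dist = dist.dist
    return dist


def as_number(value):
    """Evaluate a Theano tensor to a value; pass plain numbers through."""
    return value.eval() if hasattr(value, "eval") else value


def prior_rows(model):
    """Yield (variable name, prior type, {parameter: value}) per free RV."""
    for var in model.free_RVs:
        dist = base_distribution(var.distribution)
        kind = type(dist).__name__
        names = PRIOR_PARAMS.get(kind, ())  # unknown types give no parameters
        params = {name: as_number(getattr(dist, name)) for name in names}
        yield var.name, kind, params
```

Usage would be along the lines of `for name, kind, params in prior_rows(basic_model): print(name, kind, params)`. Note that `var.name` is still the transformed name (for example `sigma_log__`), so you may want to strip the suffix before putting it in a table.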