Hierarchical Model with Multiple Predictors

The `dims` parameter handles the (array) dimensionality of your parameter (in this case `b_2`). If you have, for instance, a discrete variable with two levels (e.g. "high", "low"), then your `b_2` distribution will take on those two levels. You can verify this by calling `b_2.eval().shape` after you have defined your model. Below is a quick example with toy data, where you can see how PyMC manages variable dimensionality throughout the whole modelling process (from defining the model to sampling).

import pymc as pm
import numpy as np
import pandas as pd

y = np.random.normal(0,1,12)
x = list(np.repeat("high",6)) + list(np.repeat("low",6))
z = list(np.repeat("a",4)) + list(np.repeat("b",4)) + list(np.repeat("c",4))

df = pd.DataFrame({"obs":y, "var":x, "cov":z})

df
Out[1]: 
         obs   var cov
0   1.407061  high   a
1   0.686457  high   a
2  -0.670788  high   a
3   0.576915  high   a
4   1.935287  high   b
5  -2.331960  high   b
6   0.438080   low   b
7  -1.190885   low   b
8   0.144347   low   c
9   0.584575   low   c
10 -0.953830   low   c
11  0.442864   low   c

var_idx = pd.Categorical(df['var']).codes
cov_idx = pd.Categorical(df['cov']).codes

coords = {'loc':df.index.values, 
          'var':df['var'].unique(),
          'cov':df['cov'].unique()}
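One caveat worth knowing (a pandas detail, not specific to this model): `pd.Categorical` sorts its category labels alphabetically, while `df['var'].unique()` preserves order of appearance. Here the two happen to agree ("high" precedes "low", and "a", "b", "c" are already sorted), but a safer habit is to take the labels from the `Categorical` itself so codes and coords always line up. A minimal sketch:

```python
import pandas as pd

x = ["high"] * 6 + ["low"] * 6
cat = pd.Categorical(x)

# Integer codes index into the *sorted* category labels
print(list(cat.codes))        # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
print(list(cat.categories))   # ['high', 'low']

# Pair the codes with cat.categories (not unique()) so the
# coords labels are guaranteed to match the integer codes
coords_var = list(cat.categories)
assert all(coords_var[c] == lab for c, lab in zip(cat.codes, x))
```

With appearance order equal to sorted order, as in this toy data, both approaches give identical coords.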

with pm.Model(coords=coords) as model:
    v_idx = pm.ConstantData("var_idx", var_idx, dims="loc")
    c_idx = pm.ConstantData("cov_idx", cov_idx, dims="loc")

    a = pm.Normal("a", 0, 1)
    b = pm.Normal("b", 0, 1, dims="var")
    c = pm.Normal("c", 0, 1, dims="cov")

    m = a + b[v_idx] + c[c_idx]

    s = pm.HalfNormal("s", 1)

    y = pm.Normal("y", m, s, observed=df['obs'].values)

    idata = pm.sample()

b.eval().shape
Out[2]: (2,)
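The line `m = a + b[v_idx] + c[c_idx]` works via NumPy-style fancy indexing: indexing a length-2 parameter vector with the length-12 code array expands it to one value per observation. A pure-NumPy sketch of that mechanism (toy numbers, not actual model output):

```python
import numpy as np

b_draw = np.array([0.5, -0.5])        # one value per "var" level (high, low)
v_idx = np.array([0] * 6 + [1] * 6)   # per-observation level codes

per_obs = b_draw[v_idx]               # fancy indexing: one group effect per row
print(per_obs.shape)                  # (12,)
print(per_obs[:3], per_obs[-3:])      # [0.5 0.5 0.5] [-0.5 -0.5 -0.5]
```

PyMC does the same thing symbolically, which is why `m` ends up with one entry per observation even though `b` only has two.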

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b, c, s]
Sampling 4 chains, 0 divergences
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 42 seconds.

idata
Out[4]: 
Inference data with groups:
	> posterior
	> sample_stats
	> observed_data
	> constant_data

idata.posterior['b']
Out[5]: 
<xarray.DataArray 'b' (chain: 4, draw: 1000, var: 2)>
array([[[-0.65472295, -0.31560767],
        [-0.49080176, -0.19496418],
        [-0.43480596,  0.2989079 ],
        ...,
        [ 1.03582084, -0.73988372],
        [-0.92628691,  0.33085689],
        [ 1.67374922,  0.47049486]],

       [[-0.33019963, -0.52679078],
        [ 0.3327297 , -0.23758491],
        [ 0.75950194, -0.78424727],
        ...,
        [-0.65862662, -1.49969723],
        [-0.36586425, -0.01656974],
        [-1.1275673 ,  0.99075037]],

       [[-1.26850014, -1.55347173],
        [-2.00023507, -0.00816708],
        [ 0.79022025, -0.60316954],
        ...,
        [ 1.07199527,  0.45203872],
        [ 0.65674821, -0.23802447],
        [ 0.83036322, -0.48556644]],

       [[ 1.72620799, -2.0443458 ],
        [ 1.88393521, -0.75462455],
        [ 1.28261164, -1.16342085],
        ...,
        [ 0.42779301,  0.64523998],
        [ 0.04914379, -0.33293152],
        [ 0.41417636,  1.01508946]]])
Coordinates:
  * chain    (chain) int32 0 1 2 3
  * draw     (draw) int32 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
  * var      (var) <U4 'high' 'low'
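Because the coords were attached, the posterior can also be sliced by label rather than by position, e.g. `idata.posterior['b'].sel(var='high')`. A standalone sketch of that selection, using a hand-built xarray DataArray as a stand-in for the real posterior:

```python
import numpy as np
import xarray as xr

# Stand-in for idata.posterior['b']: dims (chain, draw, var) with labelled levels
b = xr.DataArray(
    np.random.normal(size=(4, 1000, 2)),
    dims=("chain", "draw", "var"),
    coords={"var": ["high", "low"]},
)

high = b.sel(var="high")   # all chains and draws for the "high" level
print(high.shape)          # (4, 1000)
```

This label-based indexing is one of the main payoffs of setting up `coords` in the first place.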