The “dims” argument controls the (array) dimensionality of your parameter (in this case b_2). If you have, for instance, a categorical variable with 2 levels (e.g. “high”, “low”) and pass its coordinate name as dims, then b_2 becomes a vector with one entry per level. You can check this by calling b_2.eval().shape after you have defined your model. Below is a quick example with toy data, where you can see how PyMC tracks the variable dimensionality through the whole modelling process (from defining the model to sampling).
import pymc as pm
import numpy as np
import pandas as pd
y = np.random.normal(0,1,12)
x = list(np.repeat("high",6)) + list(np.repeat("low",6))
z = list(np.repeat("a",4)) + list(np.repeat("b",4)) + list(np.repeat("c",4))
df = pd.DataFrame({"obs":y, "var":x, "cov":z})
df
Out[1]:
         obs   var cov
0   1.407061  high   a
1   0.686457  high   a
2  -0.670788  high   a
3   0.576915  high   a
4   1.935287  high   b
5  -2.331960  high   b
6   0.438080   low   b
7  -1.190885   low   b
8   0.144347   low   c
9   0.584575   low   c
10 -0.953830   low   c
11  0.442864   low   c
var_idx = pd.Categorical(df['var']).codes
cov_idx = pd.Categorical(df['cov']).codes
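coords = {'loc': df.index.values,
          # use the sorted categories so the labels line up with the integer codes above
          'var': pd.Categorical(df['var']).categories,
          'cov': pd.Categorical(df['cov']).categories}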
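One thing worth checking at this step is that the labels stored in coords are in the same order as the integer codes. As far as I know, pd.Categorical sorts its categories, while unique() returns labels in order of first appearance; the two happen to agree for the toy data here, but not in general. A minimal sketch with made-up labels (the `levels` series is hypothetical, not part of the example above):

```python
import pandas as pd

# Hypothetical labels whose order of appearance differs from their sorted order.
levels = pd.Series(["low", "high", "low", "high"])

cat = pd.Categorical(levels)
codes = cat.codes                           # one integer code per row
sorted_labels = list(cat.categories)        # labels in the order the codes refer to
appearance_labels = list(levels.unique())   # labels in order of first appearance

# The codes index into the sorted categories, so decoding with them
# recovers the original labels row by row.
decoded = [sorted_labels[i] for i in codes]

print(sorted_labels, appearance_labels, list(codes))
```

If the coordinate labels were taken from unique() in a case like this, the posterior for “high” would end up labelled “low” and vice versa.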
with pm.Model(coords=coords) as model:
    # integer level codes, one per observation
    v_idx = pm.ConstantData("var_idx", var_idx, dims="loc")
    c_idx = pm.ConstantData("cov_idx", cov_idx, dims="loc")
    a = pm.Normal("a", 0, 1)
    # one coefficient per level of each categorical variable
    b = pm.Normal("b", 0, 1, dims="var")
    c = pm.Normal("c", 0, 1, dims="cov")
    # indexing picks, for every row, the coefficient of that row's level
    m = a + b[v_idx] + c[c_idx]
    s = pm.HalfNormal("s", 1)
    y = pm.Normal("y", m, s, observed=df['obs'].values)
b.eval().shape
Out[2]: (2,)
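The shape (2,) is per level, not per observation. The indexing step in the linear predictor is what expands it to one value per row. A small NumPy-only sketch of that step, with made-up numbers standing in for the random variable (`b_values` is an assumption for illustration):

```python
import numpy as np

# Stand-in values for the two levels; in the model these are random variables.
b_values = np.array([0.5, -0.5])    # shape (2,): one entry per level
var_idx = np.repeat([0, 1], 6)      # shape (12,): one level code per row

# Integer-array indexing returns one value per row, taken from that row's level.
per_row = b_values[var_idx]

print(b_values.shape, per_row.shape)
```

So the parameter itself stays 2-dimensional along “var”, while the term that enters the likelihood has the length of the data.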
with model:
    idata = pm.sample()
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a, b, c, s]
|████████████| 100.00% [8000/8000 00:20<00:00 Sampling 4 chains, 0 divergences]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 42 seconds.
idata
Out[4]:
Inference data with groups:
> posterior
> sample_stats
> observed_data
> constant_data
idata.posterior['b']
Out[5]:
<xarray.DataArray 'b' (chain: 4, draw: 1000, var: 2)>
array([[[-0.65472295, -0.31560767],
[-0.49080176, -0.19496418],
[-0.43480596, 0.2989079 ],
...,
[ 1.03582084, -0.73988372],
[-0.92628691, 0.33085689],
[ 1.67374922, 0.47049486]],
[[-0.33019963, -0.52679078],
[ 0.3327297 , -0.23758491],
[ 0.75950194, -0.78424727],
...,
[-0.65862662, -1.49969723],
[-0.36586425, -0.01656974],
[-1.1275673 , 0.99075037]],
[[-1.26850014, -1.55347173],
[-2.00023507, -0.00816708],
[ 0.79022025, -0.60316954],
...,
[ 1.07199527, 0.45203872],
[ 0.65674821, -0.23802447],
[ 0.83036322, -0.48556644]],
[[ 1.72620799, -2.0443458 ],
[ 1.88393521, -0.75462455],
[ 1.28261164, -1.16342085],
...,
[ 0.42779301, 0.64523998],
[ 0.04914379, -0.33293152],
[ 0.41417636, 1.01508946]]])
Coordinates:
* chain (chain) int32 0 1 2 3
* draw (draw) int32 0 1 2 3 4 5 6 7 8 ... 992 993 994 995 996 997 998 999
* var (var) <U4 'high' 'low'
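Because the “var” dimension carries its labels, you can pull out a level by name instead of by position. A rough sketch, using a random array as a stand-in for idata.posterior['b'] (the `b_post` object below is fabricated for illustration and only mimics the shape and coordinates shown above):

```python
import numpy as np
import xarray as xr

# Toy stand-in for idata.posterior["b"]: 4 chains, 1000 draws, 2 levels.
rng = np.random.default_rng(0)
b_post = xr.DataArray(
    rng.normal(size=(4, 1000, 2)),
    dims=("chain", "draw", "var"),
    coords={"chain": np.arange(4), "draw": np.arange(1000), "var": ["high", "low"]},
)

# Select one level by its label; the "var" dimension is dropped.
b_high = b_post.sel({"var": "high"})

# Average over chains and draws, leaving one value per level.
b_mean = b_post.mean(dim=("chain", "draw"))

print(b_high.dims, b_mean.shape)
```

The same two calls should work on the real posterior, since it is an xarray.DataArray with the same dimension names.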