Setting prior for one of the 'dims'

Hello, PyMC community!

Suppose I am modelling the impact of X on Y for different countries, using the following code:

beta_country = pm.Normal("beta_country", mu=0, sigma=3, dims="country")

As I understand it, this gives every country the same Normal prior with the specified parameters.
But suppose that for one specific country, “C”, I have a different belief: for all other countries I would set mu=0, but for that country I would set mu=10.

What is the best way to express this in code?
One option I have in mind is a for-loop, something like:

(sketch; must run inside a with pm.Model(): block)
betas = {}
for country, country_prior in zip(countries, country_priors):
    betas[country] = pm.Normal(f"beta_{country}", mu=country_prior, sigma=3)

Please let me know if there are better approaches.
Thank you!

You can pass an array to mu, sigma, or both, and PyMC will “broadcast” the parameters element-wise across the random variables it creates. For example, here are 9 countries with mu of 0 and one with mu of 10:

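import numpy as np
import pymc as pm

n_countries = 10
mu_country = np.zeros(n_countries)
mu_country[3] = 10  # the 4th country (index 3) gets mu=10
with pm.Model(coords={"country": list(range(n_countries))}):
    beta_country = pm.Normal("beta_country", mu=mu_country, sigma=3, dims="country")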

Look at some draws:

with np.printoptions(suppress=True, precision=3):
    print(pm.draw(beta_country, 10))
Output:
[[ 5.973 -5.269  3.147  7.184 -1.387  0.398 -2.57   0.23   0.08   0.291]
 [ 0.042  1.767  4.018  8.41  -3.219  0.786  3.871  0.99   1.003  1.992]
 [ 3.58  -3.631  4.387 13.197  4.132  1.609 -3.308  1.881 -0.717 -0.658]
 [ 0.057 -3.223 -1.516 11.424 -6.257  3.19  -2.073  0.682  2.305  1.862]
 [ 4.132  0.885  0.783 11.652 -1.128  0.276 -1.531 -6.94   1.327  3.237]
 [-3.827 -5.264  0.452  9.152 -1.022  1.621  0.407  1.508  6.123 -1.106]
 [ 1.843  0.676 -1.493  7.441 -0.922  3.346 -3.009 -0.124 -3.32  -3.799]
 [ 1.251 -2.288  0.307  7.046  1.591 -1.67  -2.838  1.829 -2.188  1.75 ]
 [-0.052  2.99   1.198  7.059 -0.276  1.99   2.898 -4.114  1.749  0.47 ]
 [-3.075  0.637 -2.235 11.917  2.998 -0.811 -3.973  1.682  2.09  -2.479]]

You can see that the 4th column (index 3) is centred around 10 rather than 0, as expected.


Thanks a lot, @jessegrabowski!
That sounds like a solution!