How do I manipulate shape when using sample_numpyro_nuts

Hi, I am trying to reshape a random variable. When I use reshape, I get an error:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [10].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

This happens even though I am reshaping it to the shape it already has:

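import numpy as np
import pymc as pm

Y = np.random.binomial(1, 0.5, (100, 10))

with pm.Model() as m:
    theta = pm.Normal("theta", 0, 1, shape=(100, 1))
    b = pm.Normal("b", 0, 10, shape=(1, 10))
    # reshaping b to the shape it already has raises the TypeError above with the JAX samplers
    b_2 = pm.Deterministic("b2", b.reshape((1, 10)))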


How can I change the shape of random variables when using the JAX sampler?
Any help is appreciated!


What version of v4 are you running? I installed (at some point) straight from the repo and your code works fine for me.


Hi, my PyMC version is 4.0.0b6.

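from pymc import sampling_jax

with m:
    # works with the default (non-JAX) sampler
    tr = pm.sample(draws=10)
    # both JAX-based samplers raise the TypeError above
    tr = sampling_jax.sample_blackjax_nuts(10)
    tr = sampling_jax.sample_numpyro_nuts(10)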

This error seems to be related to the GitHub issue "Array Shape as Random Variable" (google/jax issue #5100).
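The error text itself comes from JAX: `jnp.reshape` needs a shape made of concrete integers, and inside `jit` any value passed in as an ordinary argument is traced, not concrete. Below is a minimal pure-JAX sketch (not PyMC code; `reshape_to` is a hypothetical helper) that produces the same kind of error and the `static_argnums` fix the message suggests. It assumes, without checking PyMC's JAX backend, that the shape ends up traced in a similar way there.

```python
import jax
import jax.numpy as jnp


def reshape_to(x, shape):
    # jnp.reshape requires `shape` to be made of concrete Python integers
    return jnp.reshape(x, shape)


x = jnp.arange(10.0)

# 1) Shape passed as an array argument: under jit it becomes a tracer,
#    so JAX cannot use it as a concrete shape and raises a TypeError.
try:
    jax.jit(reshape_to)(x, jnp.array([1, 10]))
    traced_shape_failed = False
except TypeError as err:
    traced_shape_failed = True
    print(err)

# 2) Shape marked static: it stays a hashable Python tuple at trace time.
reshape_static = jax.jit(reshape_to, static_argnums=1)
y = reshape_static(x, (1, 10))
print(y.shape)  # (1, 10)
```

Static arguments must be hashable, which is why the shape is passed as a tuple; each distinct static shape triggers a separate compilation.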


Hi @qipengchen, I’m having the same issue. Did you happen to find a solution to this problem?

not yet. :joy:

Hi there, I found a solution for my case. I realized that the following line was causing the issue:

betas = at.reshape(at.tile(alphas.T,num_respondents), (num_respondents, num_items)) +,L_sigma)

I needed that calculation because I was drawing samples from a standard multivariate normal and then transforming them. To avoid it, I sampled directly from a multivariate normal with mean “alphas” and covariance “”, which made the line above unnecessary.
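The two approaches give the same distribution: if z ~ N(0, I) and L is the lower Cholesky factor of Σ, then μ + Lz ~ N(μ, Σ). A NumPy sketch that checks this empirically; `mu` and `Sigma` are made-up values, not the ones from the model above. In a PyMC model the direct version would be a `pm.MvNormal` given the mean and a covariance (or its Cholesky factor via `chol`).

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical mean vector and covariance matrix (not from the thread)
mu = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.6],
                  [0.6, 1.0]])
L = np.linalg.cholesky(Sigma)  # lower-triangular, Sigma == L @ L.T

n = 200_000

# (a) Standard normal draws, then shift by mu and scale by L
z = rng.standard_normal((n, 2))
draws_transformed = mu + z @ L.T  # each row is mu + L @ z_i

# (b) Draw from the multivariate normal directly
draws_direct = rng.multivariate_normal(mu, Sigma, size=n)

# Both empirical covariances should be close to Sigma
print(np.cov(draws_transformed, rowvar=False))
print(np.cov(draws_direct, rowvar=False))
```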

As another update, replacing:

betas = at.reshape(at.tile(alphas.T,num_respondents), (num_respondents, num_items)) +,L_sigma)

with:

betas = alphas.reshape(num_respondents, 1) +,L_sigma)

allowed me to sample from a standard normal.
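The replacement relies on broadcasting in place of an explicit tile-and-reshape. The post does not give the shape of `alphas`, so this NumPy sketch covers both layouts the two lines suggest; the sizes and the zero `offsets` matrix (a stand-in for the correlated term) are hypothetical. It only checks that the broadcast and tiled forms build the same matrix, not why the change fixed the JAX error.

```python
import numpy as np

# Hypothetical sizes; the thread does not state them.
num_respondents, num_items = 4, 3

# Stand-in for the correlated term that is added to the means in the post.
offsets = np.zeros((num_respondents, num_items))

# Case 1 (original line): assume alphas is a column of per-item means,
# shape (num_items, 1). tile + reshape repeats alphas.T on every row.
alphas = np.arange(1.0, num_items + 1).reshape(num_items, 1)
tiled = np.reshape(np.tile(alphas.T, num_respondents), (num_respondents, num_items))
row_broadcast = alphas.T + offsets  # (1, num_items) broadcasts over rows

# Case 2 (replacement line): one alpha per respondent, as a column of shape
# (num_respondents, 1) that broadcasts across the item columns.
alphas_resp = np.arange(1.0, num_respondents + 1)
col_broadcast = alphas_resp.reshape(num_respondents, 1) + offsets
col_tiled = np.tile(alphas_resp.reshape(num_respondents, 1), (1, num_items))

print(np.array_equal(tiled, row_broadcast))      # True
print(np.array_equal(col_tiled, col_broadcast))  # True
```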


Hey @qipengchen, were you able to solve your problem? I have a similar one.