How do I manipulate shape when using sample_numpyro_nuts

Hi, I am trying to reshape a random variable. When I use `reshape`, I get the following error:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [10].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

This happens even though the reshape keeps the same shape, (1, 10):

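import numpy as np
import pymc as pm

Y = np.random.binomial(1, 0.5, (100, 10))

with pm.Model() as m:
    theta = pm.Normal("theta", 0, 1, shape=(100, 1))
    b = pm.Normal("b", 0, 10, shape=(1, 10))
    # reshape to the shape b already has; this alone triggers the JAX error
    b_2 = pm.Deterministic("b2", b.reshape((1, 10)))
    pm.Bernoulli("Y", logit_p=theta - b_2, observed=Y)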

How can I change the shape of a random variable when using the JAX samplers?
Any help is appreciated!
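In this particular model the `reshape` is a no-op: `b` is already declared with `shape=(1, 10)`, and broadcasting `theta - b` (`(100, 1)` against `(1, 10)`) already produces the `(100, 10)` shape of `Y`. The NumPy sketch below checks those shape rules, using random arrays as hypothetical stand-ins for draws of `theta` and `b`. It suggests a possible workaround, untested here with the JAX samplers and assuming the reshape is what trips up the JAX backend: write `logit_p=theta - b` directly, or declare `b` with `shape=10` and let broadcasting do the work.

```python
import numpy as np

rng = np.random.default_rng(0)
theta = rng.normal(size=(100, 1))  # hypothetical stand-in for a draw of `theta`
b = rng.normal(size=(1, 10))       # hypothetical stand-in for a draw of `b`

# Reshaping to the shape b already has changes nothing
assert b.reshape((1, 10)).shape == b.shape

# (100, 1) - (1, 10) broadcasts to (100, 10), the shape of Y
logit_p = theta - b
print(logit_p.shape)  # (100, 10)

# A 1-D b of shape (10,) broadcasts against (100, 1) the same way
b_flat = b[0]
print((theta - b_flat).shape)                  # (100, 10)
print(np.allclose(theta - b_flat, logit_p))    # True
```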


What version of v4 are you running? I installed (at some point) straight from the repo and your code works fine for me.


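Hi, my PyMC version is 4.0.0b6. The model samples with the default sampler but fails with both JAX samplers:

from pymc import sampling_jax

with m:
    # works: default PyMC NUTS sampler
    tr = pm.sample(draws=10)
    # both of these raise the TypeError above
    tr = sampling_jax.sample_blackjax_nuts(10)
    tr = sampling_jax.sample_numpyro_nuts(10)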

This error seems to be related to google/jax issue #5100, "Array Shape as Random Variable", on GitHub.
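The error message points at `jit`: inside a jitted function, `jnp.reshape` needs the target shape as concrete Python integers, and a shape passed in as a traced array fails with a `TypeError` of this kind. The sketch below is plain JAX, not PyMC, and `reshape_to`, `reshape_static` and `reshape_traced` are hypothetical names. It assumes (not confirmed in this thread) that the PyMC/Aesara JAX backend hands the `reshape` target shape to JAX as a traced value, which would explain why even a same-shape reshape fails.

```python
import jax
import jax.numpy as jnp


def reshape_to(x, shape):
    # Under jit, `shape` must be concrete Python ints, not a traced array
    return jnp.reshape(x, shape)


# static_argnums=1 makes jit treat `shape` as a compile-time constant
# (jit recompiles once per distinct shape tuple)
reshape_static = jax.jit(reshape_to, static_argnums=1)

# Without static_argnums, an array-valued shape is traced inside jit
reshape_traced = jax.jit(reshape_to)

x = jnp.arange(10.0)
y = reshape_static(x, (1, 10))
print(y.shape)  # (1, 10)

try:
    reshape_traced(x, jnp.array([1, 10]))
except TypeError as err:
    # JAX reports that the shape is not made of concrete integers
    print("TypeError:", err)
```

`static_argnums` is the fix the error message suggests for hand-written JAX code; in the PyMC case the jitted function is generated by the backend, so that argument is likely not something the user can set directly.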


Hi @qipengchen, I’m having the same issue. Did you happen to find a solution to this problem?

Not yet. 😂