Hi, I am trying to reshape a random variable, but when I use `reshape` I get an error:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [10].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.
This happens even though I reshape it to the shape it already has:
import numpy as np
import pymc as pm

Y = np.random.binomial(1, 0.5, (100, 10))
with pm.Model() as m:
    theta = pm.Normal("theta", 0, 1, shape=(100, 1))
    b = pm.Normal("b", 0, 10, shape=(1, 10))
    b_2 = pm.Deterministic("b2", b.reshape((1, 10)))
    pm.Bernoulli("Y", logit_p=theta - b_2, observed=Y)
How can I change the shape of a random variable when using the JAX sampler?
Any help is appreciated!
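If the goal is only to get `theta - b_2` into the (100, 10) shape of `Y`, the reshape itself may be unnecessary: `b` is already declared with `shape=(1, 10)`, so reshaping it to (1, 10) is a no-op, and broadcasting a (100, 1) array against a (1, 10) array already yields (100, 10). The NumPy sketch below only checks that shape arithmetic, with plain arrays standing in for the random variables; whether dropping the `Deterministic` reshape avoids the JAX error in 4.0.0b6 is an untested assumption.

```python
import numpy as np

rng = np.random.default_rng(0)
Y = rng.binomial(1, 0.5, size=(100, 10))   # observed data, shape (100, 10)

theta = rng.normal(0, 1, size=(100, 1))    # stands in for "theta": one value per row
b = rng.normal(0, 10, size=(1, 10))        # stands in for "b": one value per column

# Reshaping b to the shape it already has changes nothing
b_2 = b.reshape((1, 10))
print(b_2.shape == b.shape)  # True

# Broadcasting (100, 1) - (1, 10) already produces the (100, 10) shape of Y
logit_p = theta - b
print(logit_p.shape)  # (100, 10)
```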
What version of v4 are you running? I installed (at some point) straight from the repo and your code works fine for me.
Hi, my PyMC version is 4.0.0b6. Sampling works with `pm.sample`, but both JAX samplers raise the error:
from pymc import sampling_jax

with m:
    # works
    tr = pm.sample(draws=10)
    # error
    tr = sampling_jax.sample_blackjax_nuts(10)
    tr = sampling_jax.sample_numpyro_nuts(10)
This error seems to be related to the JAX issue "Array Shape as Random Variable" (google/jax issue #5100 on GitHub).
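The message itself comes from JAX: inside `jax.jit`, `jnp.reshape` needs a shape made of concrete integers, not a traced array. A pure-JAX sketch with a hypothetical `reshape_to` helper reproduces that class of `TypeError` and the `static_argnums` fix the message suggests. My assumption is that the PyMC/Aesara JAX backend in 4.0.0b6 ends up handing JAX such a traced shape; the thread does not confirm which op triggers it.

```python
import jax
import jax.numpy as jnp


def reshape_to(x, shape):
    # jnp.reshape needs `shape` to be known at trace time
    return jnp.reshape(x, shape)


x = jnp.arange(10.0)

# 1) Shape passed as an array: under jit it becomes a tracer, so reshape fails
shape_as_array = jnp.array([1, 10])
try:
    jax.jit(reshape_to)(x, shape_as_array)
    traced_shape_failed = False
except TypeError as err:  # JAX's concretization errors subclass TypeError
    traced_shape_failed = True
    print("jit with array shape failed:", type(err).__name__)

# 2) Shape marked static: it stays a concrete Python tuple, so reshape works
reshape_static = jax.jit(reshape_to, static_argnums=1)
y = reshape_static(x, (1, 10))
print(y.shape)  # (1, 10)
```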
Hi @qipengchen, I’m having the same issue. Did you happen to find a solution to this problem?
Hi there, I found a solution to this problem in my case. I realized that the following line was causing the issue:
betas = at.reshape(at.tile(alphas.T,num_respondents), (num_respondents, num_items)) + at.dot(betas_init,L_sigma)
I needed this calculation because I was drawing samples from a standard multivariate normal and then transforming them. To avoid it, I drew samples directly from a multivariate normal with mean "alphas" and covariance "L_sigma.dot(L_sigma.T)", after which the line above was no longer necessary.
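This swap relies on the affine property of the multivariate normal: if z ~ N(0, I), then alphas + L z is multivariate normal with mean alphas and covariance L L^T. The NumPy sketch below checks this empirically with made-up `alphas` and `L_sigma` values (both hypothetical). It also shows one detail worth double-checking: when draws are stacked as rows and multiplied on the right, as in `at.dot(betas_init, L_sigma)`, each row is z^T L, whose covariance is L^T L rather than L L^T, so a transpose of `L_sigma` may be needed for the two parameterizations to agree.

```python
import numpy as np

rng = np.random.default_rng(42)
k, n = 3, 200_000
alphas = np.array([0.5, -1.0, 2.0])            # hypothetical mean vector
L_sigma = np.array([[1.0, 0.0, 0.0],
                    [0.5, 0.8, 0.0],
                    [-0.3, 0.2, 0.6]])         # hypothetical lower-triangular factor

z = rng.standard_normal((n, k))                # one standard-normal draw per row

# Row-wise transform, as in at.dot(betas_init, L_sigma): each row is z_i @ L_sigma
betas_rows = alphas + z @ L_sigma
# Column-vector convention L_sigma @ z_i, written row-wise as z @ L_sigma.T
betas_cols = alphas + z @ L_sigma.T

cov_rows = np.cov(betas_rows, rowvar=False)
cov_cols = np.cov(betas_cols, rowvar=False)

# Compare empirical covariances with the two candidate formulas
print(np.allclose(cov_rows, L_sigma.T @ L_sigma, atol=0.02))
print(np.allclose(cov_cols, L_sigma @ L_sigma.T, atol=0.02))
```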
As another update, replacing:
betas = at.reshape(at.tile(alphas.T,num_respondents), (num_respondents, num_items)) + at.dot(betas_init,L_sigma)
with:
betas = alphas.reshape(num_respondents, 1) + at.dot(betas_init,L_sigma)
allowed me to keep sampling from a standard normal, as in my original approach.
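This second fix swaps an explicit `tile` + `reshape` for broadcasting. Below is a NumPy analogue with hypothetical sizes, assuming `at.tile` and `at.reshape` behave like their NumPy counterparts here and that `alphas` has shape (num_items, 1), as the `tile(alphas.T, ...)` call suggests; `offsets` is a hypothetical stand-in for `at.dot(betas_init, L_sigma)`.

```python
import numpy as np

num_respondents, num_items = 4, 3                          # hypothetical sizes
rng = np.random.default_rng(1)
alphas = rng.normal(size=(num_items, 1))                   # assumed: one value per item
offsets = rng.normal(size=(num_respondents, num_items))    # stands in for at.dot(betas_init, L_sigma)

# Original approach: tile alphas.T, then reshape into a (respondents, items) grid
tiled = np.reshape(np.tile(alphas.T, num_respondents), (num_respondents, num_items)) + offsets

# Broadcasting approach: a (1, items) row is stretched over all respondents
broadcast = alphas.reshape(1, num_items) + offsets

print(np.allclose(tiled, broadcast))  # True
```

Which axis gets the length-1 dimension depends on what `alphas` indexes: a (1, num_items) row repeats per-item values across respondents, while `alphas.reshape(num_respondents, 1)`, as in the post, adds one value per respondent across all items.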
Hey @qipengchen, were you able to solve your problem? I have a similar one.