Using pm.sampling_jax to sample a MvNormal()

Hello! I have some code in PyMC3 that uses Theano tensors, and I’m trying to port it to PyMC v4, which uses Aesara. In my problem I want to sample the posterior distribution of a multivariate normal: I have 50 observations, and each observation is a vector of 4 values (one for each variable of the multivariate distribution). If I use pm.sample(), I can sample the posterior and get results similar to my PyMC3 code. However, if I use pm.sampling_jax.sample_numpyro_nuts() or pm.sampling_jax.sample_blackjax_nuts(), I get the following error message:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [50].
If using jit, try using static_argnums or applying jit to smaller subfunctions.

Does anyone know what my problem could be? Can I use sampling_jax to sample a MvNormal()? I’m new to Aesara and JAX, so any tips would be helpful. This is my model:

with pm.Model() as pos_model:
    # Sizes used to build the likelihood. n_obs is hard-coded exactly as in the
    # original model; TODO(review): derive it from ρ_hat (e.g. ρ_hat.shape[0])
    # if ρ_hat is a NumPy array, so the model cannot silently disagree with it.
    n_anchors = 4
    n_obs = 50

    # Unknown position: x in [0, L], y in [0, B], combined into one vector ϕ.
    x = pm.Uniform('x', 0, L, shape=1)
    y = pm.Uniform('y', 0, B, shape=1)
    ϕ = at.stack([x, y], axis=1)[0]

    β = pm.Normal('β', mu=0, sigma=10, shape=1)

    # Distance from ϕ to each anchor, recorded in the trace as d__0 .. d__3.
    d = [
        pm.Deterministic('d__' + str(ii), t_euclidean_distance(anchor_vt[ii], ϕ))
        for ii in range(n_anchors)
    ]
    # Per-anchor mean: β * log10(distance), repeated n_obs times along axis 0
    # so there is one mean entry per observation.
    μ = [
        β * at.log10(at.repeat(at.stacklists([d_ii]), n_obs, axis=0))
        for d_ii in d
    ]
    # Independent scale parameter per anchor (σ__0 .. σ__3).
    σ = [pm.HalfNormal('σ__' + str(ii), sigma=10) for ii in range(n_anchors)]

    σ_ = at.stack(σ)
    # Diagonal covariance whose diagonal entries are the σ values themselves.
    # NOTE(review): this treats σ as a variance; if σ is meant as a standard
    # deviation the diagonal should be σ_**2 — confirm with the model author.
    cov = at.eye(n_anchors) * σ_
    μ_hat = at.stack(μ).T

    ## Define the multivariate normal distrib.
    joint_obs = pm.MvNormal('joint', mu=μ_hat, cov=cov, observed=ρ_hat)

I have a suspicion of the problem but can you post the whole error message to confirm?

This is the full error message.

Compiling…
Compilation time = 0:00:03.496124
Sampling…

TypeError Traceback (most recent call last)
C:\Users\XXX\AppData\Local\Temp/ipykernel_84224/3518610046.py in
1 n_draws, n_jobs, n_tunes = 2500, 2, 1500
2 with pos_model:
----> 3 idata2 = pm.sampling_jax.sample_numpyro_nuts(draws = n_draws, chains = n_jobs, tune = n_tunes)
4 #idata2 = pm.sampling_jax.sample_blackjax_nuts(draws = n_draws, chains = n_jobs, tune = n_tunes)
5 #idata2 = pm.sample(draws = n_draws, step = pm.NUTS(target_accept = 0.8), chains = n_jobs, cores = 4, tune = n_tunes)

~\Anaconda3\lib\site-packages\pymc\sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
512 map_seed = jax.random.split(map_seed, chains)
513
→ 514 pmap_numpyro.run(
515 map_seed,
516 init_params=init_params,

~\Anaconda3\lib\site-packages\numpyro\infer\mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
→ 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"

[... skipping hidden 17 frame]

~\Anaconda3\lib\site-packages\numpyro\infer\mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
→ 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,

~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
744 )
745 if rng_key.ndim == 1:
→ 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some

~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in (init_params, rng_key)
724 assert isinstance(dense_mass, list)
725
→ 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,

~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
→ 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r

~\Anaconda3\lib\site-packages\numpyro\infer\hmc_util.py in init_fn(z, r, potential_energy, z_grad)
276 """
277 if potential_energy is None or z_grad is None:
→ 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )

~\Anaconda3\lib\site-packages\numpyro\infer\hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
→ 246 return value_and_grad(f)(x)
247
248

[... skipping hidden 8 frame]

~\Anaconda3\lib\site-packages\pymc\sampling_jax.py in logp_fn_wrap(x)
107
108 def logp_fn_wrap(x):
→ 109 return logp_fn(*x)[0]
110
111 return logp_fn_wrap

~\Anaconda3\lib\site-packages\aesara\link\utils.py in jax_funcified_fgraph(x_interval_, y_interval_, _, 0_log, 1_log, 2_log, 3_log)
120 auto_371486 = alloc3(auto_371485, auto_369963, auto_226526)
121 # Reshape{1}(Alloc.0, TensorConstant{(1,) of 10})
→ 122 auto_371499 = reshape(auto_371498, auto_370154)
123 # Reshape{1}(Alloc.0, TensorConstant{(1,) of 10})
124 auto_371510 = reshape1(auto_371509, auto_370154)

~\Anaconda3\lib\site-packages\aesara\link\jax\dispatch.py in reshape(x, shape)
731 def jax_funcify_Reshape(op, **kwargs):
732 def reshape(x, shape):
→ 733 return jnp.reshape(x, shape)
734
735 return reshape

~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in reshape(a, newshape, order)
738 _stackable(a) or _check_arraylike(“reshape”, a)
739 try:
→ 740 return a.reshape(newshape, order=order) # forward to method for ndarrays
741 except AttributeError:
742 return _reshape(a, newshape, order=order)

~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in _reshape(a, order, *args)
756
757 def _reshape(a, *args, order="C"):
→ 758 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
759 if order == “C”:
760 return lax.reshape(a, newshape, None)

~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in _compute_newshape(a, newshape)
749 except: iterable = False
750 else: iterable = True
→ 751 newshape = core.canonicalize_shape(newshape if iterable else [newshape])
752 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
753 if core.symbolic_equal_dim(d, -1) else d

~\Anaconda3\lib\site-packages\jax\core.py in canonicalize_shape(shape, context)
1720 except TypeError:
1721 pass
→ 1722 raise _invalid_shape_error(shape, context)
1723
1724 def canonicalize_dim(d: DimSize, context: str="") -> DimSize:

TypeError: Shapes must be 1D sequences of concrete values of integer type, got [10].
If using jit, try using static_argnums or applying jit to smaller subfunctions.

Try to add this snippet at the top of your script:

import jax
from aesara.graph import Constant
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.shape import Reshape

@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
    """Build the JAX implementation of Aesara's ``Reshape`` op for ``node``.

    The target shape is ``node.inputs[1]``. When it is an Aesara ``Constant``,
    its concrete ``.data`` is captured in a closure and the shape argument
    passed at call time is ignored. Otherwise the runtime shape is forwarded
    to ``jax.numpy.reshape`` unchanged.
    """
    target = node.inputs[1]

    if not isinstance(target, Constant):
        # Symbolic target shape: nothing concrete to capture, so pass the
        # runtime value straight through.
        def reshape(x, shape):
            return jax.numpy.reshape(x, shape)

        return reshape

    static_shape = target.data

    def reshape(x, _):
        # The traceback above shows jnp.reshape rejecting the shape it received
        # under jit ("concrete values" error); presumably it was a traced array.
        # Using the captured constant keeps the shape static.
        return jax.numpy.reshape(x, static_shape)

    return reshape

We still need to fix this upstream…

Thanks! This solved my problem!