This is the full error message.
Compiling…
Compilation time = 0:00:03.496124
Sampling…
TypeError Traceback (most recent call last)
C:\Users\XXX\AppData\Local\Temp/ipykernel_84224/3518610046.py in <module>
1 n_draws, n_jobs, n_tunes = 2500, 2, 1500
2 with pos_model:
----> 3 idata2 = pm.sampling_jax.sample_numpyro_nuts(draws = n_draws, chains = n_jobs, tune = n_tunes)
4 #idata2 = pm.sampling_jax.sample_blackjax_nuts(draws = n_draws, chains = n_jobs, tune = n_tunes)
5 #idata2 = pm.sample(draws = n_draws, step = pm.NUTS(target_accept = 0.8), chains = n_jobs, cores = 4, tune = n_tunes)
~\Anaconda3\lib\site-packages\pymc\sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
512 map_seed = jax.random.split(map_seed, chains)
513
→ 514 pmap_numpyro.run(
515 map_seed,
516 init_params=init_params,
~\Anaconda3\lib\site-packages\numpyro\infer\mcmc.py in run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
→ 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
~\Anaconda3\lib\site-packages\numpyro\infer\mcmc.py in _single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
→ 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
744 )
745 if rng_key.ndim == 1:
→ 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in <lambda>(init_params, rng_key)
724 assert isinstance(dense_mass, list)
725
→ 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
~\Anaconda3\lib\site-packages\numpyro\infer\hmc.py in init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
→ 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r
~\Anaconda3\lib\site-packages\numpyro\infer\hmc_util.py in init_fn(z, r, potential_energy, z_grad)
276 """
277 if potential_energy is None or z_grad is None:
→ 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )
~\Anaconda3\lib\site-packages\numpyro\infer\hmc_util.py in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
→ 246 return value_and_grad(f)(x)
247
248
[... skipping hidden 8 frame]
~\Anaconda3\lib\site-packages\pymc\sampling_jax.py in logp_fn_wrap(x)
107
108 def logp_fn_wrap(x):
→ 109 return logp_fn(*x)[0]
110
111 return logp_fn_wrap
~\Anaconda3\lib\site-packages\aesara\link\utils.py in jax_funcified_fgraph(x_interval_, y_interval_, _, 0_log, 1_log, 2_log, 3_log)
120 auto_371486 = alloc3(auto_371485, auto_369963, auto_226526)
121 # Reshape{1}(Alloc.0, TensorConstant{(1,) of 10})
→ 122 auto_371499 = reshape(auto_371498, auto_370154)
123 # Reshape{1}(Alloc.0, TensorConstant{(1,) of 10})
124 auto_371510 = reshape1(auto_371509, auto_370154)
~\Anaconda3\lib\site-packages\aesara\link\jax\dispatch.py in reshape(x, shape)
731 def jax_funcify_Reshape(op, **kwargs):
732 def reshape(x, shape):
→ 733 return jnp.reshape(x, shape)
734
735 return reshape
~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in reshape(a, newshape, order)
738 _stackable(a) or _check_arraylike("reshape", a)
739 try:
→ 740 return a.reshape(newshape, order=order) # forward to method for ndarrays
741 except AttributeError:
742 return _reshape(a, newshape, order=order)
~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in _reshape(a, order, *args)
756
757 def _reshape(a, *args, order="C"):
→ 758 newshape = _compute_newshape(a, args[0] if len(args) == 1 else args)
759 if order == “C”:
760 return lax.reshape(a, newshape, None)
~\Anaconda3\lib\site-packages\jax\_src\numpy\lax_numpy.py in _compute_newshape(a, newshape)
749 except: iterable = False
750 else: iterable = True
→ 751 newshape = core.canonicalize_shape(newshape if iterable else [newshape])
752 return tuple(- core.divide_shape_sizes(np.shape(a), newshape)
753 if core.symbolic_equal_dim(d, -1) else d
~\Anaconda3\lib\site-packages\jax\core.py in canonicalize_shape(shape, context)
1720 except TypeError:
1721 pass
→ 1722 raise _invalid_shape_error(shape, context)
1723
1724 def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [10].
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.