Yes. I have just changed the line
import theano.tensor as tt
to:
import aesara.tensor as at
For some reason, after restarting the kernel, I now get a different error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [36], in <cell line: 1>()
1 with m5:
----> 2 trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4, draws = 5000)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:506, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
503 if chains > 1:
504 map_seed = jax.random.split(map_seed, chains)
--> 506 pmap_numpyro.run(
507 map_seed,
508 init_params=init_params,
509 extra_fields=(
510 "num_steps",
511 "potential_energy",
512 "energy",
513 "adapt_state.step_size",
514 "accept_prob",
515 "diverging",
516 ),
517 )
519 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
521 tic3 = datetime.now()
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 13 frame]
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
(...)
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
--> 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
749 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
750 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
751 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
723 dense_mass = [tuple(sorted(z))] if dense_mass else []
724 assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
729 step_size=self._step_size,
730 num_steps=self._num_steps,
731 inverse_mass_matrix=inverse_mass_matrix,
732 adapt_step_size=self._adapt_step_size,
733 adapt_mass_matrix=self._adapt_mass_matrix,
734 dense_mass=dense_mass,
735 target_accept_prob=self._target_accept_prob,
736 trajectory_length=self._trajectory_length,
737 max_tree_depth=self._max_tree_depth,
738 find_heuristic_step_size=self._find_heuristic_step_size,
739 forward_mode_differentiation=self._forward_mode_differentiation,
740 regularize_mass_matrix=self._regularize_mass_matrix,
741 model_args=model_args,
742 model_kwargs=model_kwargs,
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
746 init_state = hmc_init_fn(init_params, rng_key)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r
325 )
326 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
270 """
271 :param z: Position of the particle.
272 :param r: Momentum of the particle.
(...)
275 :return: initial state for the integrator.
276 """
277 if potential_energy is None or z_grad is None:
--> 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )
281 return IntegratorState(z, r, potential_energy, z_grad)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
--> 246 return value_and_grad(f)(x)
[... skipping hidden 7 frame]
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
108 def logp_fn_wrap(x):
--> 109 return logp_fn(*x)[0]
File /tmp/tmpwurs7sxl:44, in jax_funcified_fgraph(phi_interval_, kappa_log_log_, alpha_logodds_, beta_h, beta_sd_log_, eps_log_)
42 auto_194865 = log(auto_194201)
43 # forall_inplace,cpu,scan_fn}(TensorConstant{69}, TensorConstant{[[ True T..se False]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, Elemwise{sigmoid,no_inplace}.0)
---> 44 auto_198126, auto_198127 = scan(auto_191975, auto_194927, auto_193120, auto_192695, auto_197868, auto_197866, auto_194208)
45 # Elemwise{mul,no_inplace}(beta_h, InplaceDimShuffle{x}.0)
46 auto_194205 = elemwise(beta_h, auto_194204)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:420, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
419 def scan(*outer_inputs):
--> 420 scan_args = ScanArgs(
421 list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
422 )
424 # `outer_inputs` is a list with the following composite form:
425 # [n_steps]
426 # + outer_in_seqs
(...)
431 # + outer_in_nit_sot
432 # + outer_in_non_seqs
433 n_steps = scan_args.n_steps
TypeError: __init__() missing 1 required positional argument: 'as_while'