Reinforcement learning - help building a model

Yes. I have just changed the import statement from:
import theano.tensor as tt
to:
import aesara.tensor as at

For some reason, after restarting the kernel, I get a different error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [36], in <cell line: 1>()
      1 with m5:
----> 2     trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:506, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    503 if chains > 1:
    504     map_seed = jax.random.split(map_seed, chains)
--> 506 pmap_numpyro.run(
    507     map_seed,
    508     init_params=init_params,
    509     extra_fields=(
    510         "num_steps",
    511         "potential_energy",
    512         "energy",
    513         "adapt_state.step_size",
    514         "accept_prob",
    515         "diverging",
    516     ),
    517 )
    519 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    521 tic3 = datetime.now()

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 13 frame]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
   (...)
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
--> 746     init_state = hmc_init_fn(init_params, rng_key)
    747 else:
    748     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    749     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    750     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    751     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    723         dense_mass = [tuple(sorted(z))] if dense_mass else []
    724     assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    729     step_size=self._step_size,
    730     num_steps=self._num_steps,
    731     inverse_mass_matrix=inverse_mass_matrix,
    732     adapt_step_size=self._adapt_step_size,
    733     adapt_mass_matrix=self._adapt_mass_matrix,
    734     dense_mass=dense_mass,
    735     target_accept_prob=self._target_accept_prob,
    736     trajectory_length=self._trajectory_length,
    737     max_tree_depth=self._max_tree_depth,
    738     find_heuristic_step_size=self._find_heuristic_step_size,
    739     forward_mode_differentiation=self._forward_mode_differentiation,
    740     regularize_mass_matrix=self._regularize_mass_matrix,
    741     model_args=model_args,
    742     model_kwargs=model_kwargs,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
    746     init_state = hmc_init_fn(init_params, rng_key)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    323 energy = vv_state.potential_energy + kinetic_fn(
    324     wa_state.inverse_mass_matrix, vv_state.r
    325 )
    326 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    270 """
    271 :param z: Position of the particle.
    272 :param r: Momentum of the particle.
   (...)
    275 :return: initial state for the integrator.
    276 """
    277 if potential_energy is None or z_grad is None:
--> 278     potential_energy, z_grad = _value_and_grad(
    279         potential_fn, z, forward_mode_differentiation
    280     )
    281 return IntegratorState(z, r, potential_energy, z_grad)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 7 frame]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:109, in get_jaxified_logp.<locals>.logp_fn_wrap(x)
    108 def logp_fn_wrap(x):
--> 109     return logp_fn(*x)[0]

File /tmp/tmpwurs7sxl:44, in jax_funcified_fgraph(phi_interval_, kappa_log_log_, alpha_logodds_, beta_h, beta_sd_log_, eps_log_)
     42 auto_194865 = log(auto_194201)
     43 # forall_inplace,cpu,scan_fn}(TensorConstant{69}, TensorConstant{[[ True  T..se False]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, TensorConstant{[[1 1 1 ...... 0 0 0]]}, IncSubtensor{InplaceSet;:int64:}.0, IncSubtensor{InplaceSet;:int64:}.0, Elemwise{sigmoid,no_inplace}.0)
---> 44 auto_198126, auto_198127 = scan(auto_191975, auto_194927, auto_193120, auto_192695, auto_197868, auto_197866, auto_194208)
     45 # Elemwise{mul,no_inplace}(beta_h, InplaceDimShuffle{x}.0)
     46 auto_194205 = elemwise(beta_h, auto_194204)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:420, in jax_funcify_Scan.<locals>.scan(*outer_inputs)
    419 def scan(*outer_inputs):
--> 420     scan_args = ScanArgs(
    421         list(outer_inputs), [None] * op.n_outs, op.inputs, op.outputs, op.info
    422     )
    424     # `outer_inputs` is a list with the following composite form:
    425     # [n_steps]
    426     # + outer_in_seqs
   (...)
    431     # + outer_in_nit_sot
    432     # + outer_in_non_seqs
    433     n_steps = scan_args.n_steps

TypeError: __init__() missing 1 required positional argument: 'as_while'