Reinforcement learning - help building a model

I’m trying to run the same model with PyMC v4, using JAX as the backend.
I have made small changes to the original code (to remove trials in which a shock was administered); the original model works perfectly well with PyMC3.
This is the code I am currently trying to run:

# Model m5: hierarchical priors for a per-subject alpha and beta, a scan over
# trials that calls update_Q, and a Normal likelihood on the non-shock
# observations. Sampling is done with the NumPyro NUTS sampler.
# NOTE(review): `tt` is used alongside `aesara.scan`; presumably `tt` aliases
# aesara.tensor in this session -- confirm.
with pm.Model() as m5:
    
    # α: one value per subject (shape=n_subj), Beta-distributed with shape
    # parameters phi * kappa and (1 - phi) * kappa.
    # presumably the learning rate consumed by update_Q -- confirm
    phi = pm.Uniform("phi", lower=0.0, upper=1.0)
    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", tt.exp(kappa_log))
    alpha = pm.Beta("alpha", alpha=phi * kappa, beta=(1.0 - phi) * kappa, shape=n_subj)
    
    # β (reparameterization): a standard-normal offset per subject (beta_h)
    # scaled by a shared HalfNormal standard deviation (beta_sd).
    beta_h = pm.Normal('beta_h', 0,1, shape=n_subj)
    beta_sd = pm.HalfNormal('beta_sd', 5)
    beta = pm.Deterministic('beta',0 + beta_h*beta_sd)
       
    # standard deviation of the observation noise used in the likelihood below
    eps = pm.HalfNormal('eps', 5)
    
    # Initial states for the scan, both filled with 0.5.
    Qs = 0.5 * tt.ones((n_subj,2), dtype='float64') # set values for both stimuli (CS+, CS-)
    vec = 0.5 * tt.ones((n_subj,1), dtype='float64') # vector to save the relevant stimulus's expectation
    
    # Iterate update_Q over the sequences stim and shock. Qs and vec are
    # carried from step to step (given in outputs_info); the third output
    # (pe) is not fed back (None). alpha and n_subj are passed unchanged to
    # every step as non-sequences.
    # update_Q is defined outside this snippet; presumably it returns the
    # updated (Qs, vec, pe) for one trial -- confirm.
    [Qs,vec, pe], updates = aesara.scan(
        fn=update_Q,
        sequences=[stim, shock],
        outputs_info=[Qs, vec, None],
        non_sequences=[alpha, n_subj])
   
     
    # Pick column 0 of the stacked vec output at the (trials, subj) indices
    # and scale each entry by the matching subject's beta.
    # trials and subj are defined outside this snippet -- presumably index
    # arrays; confirm their shapes.
    vec_ = vec[trials, subj,0] * beta[subj]
    # add matrix of expected values (trials X subjects)
    ev = pm.Deterministic('expected_value', vec_)
    # add PE
    pe = pm.Deterministic('pe', pe)
    
    # transform to vector: flatten the transpose into 1-D of length n_subj*n_trials
    v = tt.reshape(vec_.T, n_subj*n_trials, ndim=1)
    # clean shocks: keep only the entries where shockVec equals 0
    vec_clean = v[shockVec==0]
       
    # likelihood function: observed scrClean ~ Normal(mu=vec_clean, sigma=eps)
    scrs = pm.Normal('scrs', mu = vec_clean, sigma = eps, observed=scrClean) 
    # draw posterior samples with NumPyro's NUTS (4 chains, 5000 draws each)
    trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

And this is the error I receive:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:276, in TensorType.dtype_specs(self)
    275 try:
--> 276     return self.dtype_specs_map[self.dtype]
    277 except KeyError:

KeyError: 'object'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Input In [76], in <cell line: 1>()
      1 with m5:
----> 2     trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4,  draws = 5000)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:474, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
    465 print("Compiling...", file=sys.stdout)
    467 init_params = _get_batched_jittered_initial_points(
    468     model=model,
    469     chains=chains,
    470     initvals=initvals,
    471     random_seed=random_seed,
    472 )
--> 474 logp_fn = get_jaxified_logp(model, negative_logp=False)
    476 if nuts_kwargs is None:
    477     nuts_kwargs = {}

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
    104 if not negative_logp:
    105     model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
    108 def logp_fn_wrap(x):
    109     return logp_fn(*x)[0]

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
     75 def get_jaxified_graph(
     76     inputs: Optional[List[TensorVariable]] = None,
     77     outputs: Optional[List[TensorVariable]] = None,
     78 ) -> List[TensorVariable]:
     79     """Compile an Aesara graph into an optimized JAX function"""
---> 81     graph = _replace_shared_variables(outputs)
     83     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
     84     # We need to add a Supervisor to the fgraph to be able to run the
     85     # JAX sequential optimizer without warnings. We made sure there
     86     # are no mutable input variables, so we only need to check for
     87     # "destroyers". This should be automatically handled by Aesara
     88     # once https://github.com/aesara-devs/aesara/issues/637 is fixed.

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in _replace_shared_variables(graph)
     63 if any(hasattr(var, "default_update") for var in shared_variables):
     64     raise ValueError(
     65         "Graph contains shared variables with default_update which cannot "
     66         "be safely replaced."
     67     )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
     71 new_graph = clone_replace(graph, replace=replacements)
     72 return new_graph

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in <dictcomp>(.0)
     63 if any(hasattr(var, "default_update") for var in shared_variables):
     64     raise ValueError(
     65         "Graph contains shared variables with default_update which cannot "
     66         "be safely replaced."
     67     )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
     71 new_graph = clone_replace(graph, replace=replacements)
     72 return new_graph

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/basic.py:219, in constant(x, name, ndim, dtype)
    213             raise ValueError(
    214                 f"ndarray could not be cast to constant with {int(ndim)} dimensions"
    215             )
    217     assert x_.ndim == ndim
--> 219 ttype = TensorType(dtype=x_.dtype, shape=x_.shape)
    221 return TensorConstant(ttype, x_, name=name)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:107, in TensorType.__init__(self, dtype, shape, name, broadcastable)
    104         return s
    106 self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
--> 107 self.dtype_specs()  # error checking is done there
    108 self.name = name
    109 self.numpy_dtype = np.dtype(self.dtype)

File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:278, in TensorType.dtype_specs(self)
    276     return self.dtype_specs_map[self.dtype]
    277 except KeyError:
--> 278     raise TypeError(
    279         f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
    280     )

TypeError: Unsupported dtype for TensorType: object

Any ideas?