I’m trying to run the same model with PyMC v4 and JAX as the backend.
I have made small changes to the original code (to remove trials on which a shock was administered), but the model works perfectly well with PyMC3.
This is the code I am currently trying to run:
# Fix for "TypeError: Unsupported dtype for TensorType: object":
# pm.sampling_jax replaces every shared variable in the graph with a constant
# (at.constant(var.get_value())), which fails when the backing array has
# dtype=object (typically a pandas column that was never cast to a numeric
# dtype). Cast all data arrays to explicit numeric numpy dtypes *before*
# they enter the graph. PyMC v4 is also built on aesara, not theano, so use
# `aesara.tensor as at` instead of theano.tensor's `tt` throughout —
# including inside update_Q, which must use at.* ops to stay JAX-compilable.
import aesara
import aesara.tensor as at
import numpy as np

# NOTE(review): assumes stim/shock/subj/trials/shockVec are integer codes and
# scrClean holds the SCR measurements — confirm against how they are built.
stim = np.asarray(stim, dtype="int64")
shock = np.asarray(shock, dtype="int64")
subj = np.asarray(subj, dtype="int64")
trials = np.asarray(trials, dtype="int64")
shockVec = np.asarray(shockVec, dtype="int64")
scrClean = np.asarray(scrClean, dtype="float64")

with pm.Model() as m5:
    # α — hierarchical Beta prior on the per-subject learning rates,
    # parameterized by mean (phi) and concentration (kappa).
    phi = pm.Uniform("phi", lower=0.0, upper=1.0)
    kappa_log = pm.Exponential("kappa_log", lam=1.5)
    kappa = pm.Deterministic("kappa", at.exp(kappa_log))
    alpha = pm.Beta(
        "alpha", alpha=phi * kappa, beta=(1.0 - phi) * kappa, shape=n_subj
    )

    # β — non-centered (reparameterized) per-subject slopes:
    # beta = 0 + beta_h * beta_sd avoids the funnel geometry.
    beta_h = pm.Normal("beta_h", 0, 1, shape=n_subj)
    beta_sd = pm.HalfNormal("beta_sd", 5)
    beta = pm.Deterministic("beta", 0 + beta_h * beta_sd)

    # Observation noise of the SCR likelihood.
    eps = pm.HalfNormal("eps", 5)

    # Initial values: Q-values for both stimuli (CS+, CS-) and a vector
    # holding the relevant stimulus's expectation on each trial.
    Qs = 0.5 * at.ones((n_subj, 2), dtype="float64")
    vec = 0.5 * at.ones((n_subj, 1), dtype="float64")

    # Trial-by-trial learning update; pe is the per-trial prediction error.
    [Qs, vec, pe], updates = aesara.scan(
        fn=update_Q,
        sequences=[stim, shock],
        outputs_info=[Qs, vec, None],
        non_sequences=[alpha, n_subj],
    )

    # Expected value per (trial, subject), scaled by that subject's slope.
    vec_ = vec[trials, subj, 0] * beta[subj]
    ev = pm.Deterministic("expected_value", vec_)
    pe = pm.Deterministic("pe", pe)

    # Flatten (trials x subjects) to one vector of length n_subj * n_trials.
    v = at.reshape(vec_.T, (n_subj * n_trials,))

    # Drop trials on which a shock was administered.
    vec_clean = v[shockVec == 0]

    # Likelihood of the cleaned SCR observations.
    scrs = pm.Normal("scrs", mu=vec_clean, sigma=eps, observed=scrClean)

    trH_phi = pm.sampling_jax.sample_numpyro_nuts(
        target_accept=0.95, chains=4, draws=5000
    )
And this is the error I receive:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:276, in TensorType.dtype_specs(self)
275 try:
--> 276 return self.dtype_specs_map[self.dtype]
277 except KeyError:
KeyError: 'object'
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
Input In [76], in <cell line: 1>()
1 with m5:
----> 2 trH_phi = pm.sampling_jax.sample_numpyro_nuts(target_accept=.95, chains=4, draws = 5000)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:474, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, idata_kwargs, nuts_kwargs)
465 print("Compiling...", file=sys.stdout)
467 init_params = _get_batched_jittered_initial_points(
468 model=model,
469 chains=chains,
470 initvals=initvals,
471 random_seed=random_seed,
472 )
--> 474 logp_fn = get_jaxified_logp(model, negative_logp=False)
476 if nuts_kwargs is None:
477 nuts_kwargs = {}
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
104 if not negative_logp:
105 model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
108 def logp_fn_wrap(x):
109 return logp_fn(*x)[0]
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
75 def get_jaxified_graph(
76 inputs: Optional[List[TensorVariable]] = None,
77 outputs: Optional[List[TensorVariable]] = None,
78 ) -> List[TensorVariable]:
79 """Compile an Aesara graph into an optimized JAX function"""
---> 81 graph = _replace_shared_variables(outputs)
83 fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
84 # We need to add a Supervisor to the fgraph to be able to run the
85 # JAX sequential optimizer without warnings. We made sure there
86 # are no mutable input variables, so we only need to check for
87 # "destroyers". This should be automatically handled by Aesara
88 # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in _replace_shared_variables(graph)
63 if any(hasattr(var, "default_update") for var in shared_variables):
64 raise ValueError(
65 "Graph contains shared variables with default_update which cannot "
66 "be safely replaced."
67 )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
71 new_graph = clone_replace(graph, replace=replacements)
72 return new_graph
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:69, in <dictcomp>(.0)
63 if any(hasattr(var, "default_update") for var in shared_variables):
64 raise ValueError(
65 "Graph contains shared variables with default_update which cannot "
66 "be safely replaced."
67 )
---> 69 replacements = {var: at.constant(var.get_value(borrow=True)) for var in shared_variables}
71 new_graph = clone_replace(graph, replace=replacements)
72 return new_graph
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/basic.py:219, in constant(x, name, ndim, dtype)
213 raise ValueError(
214 f"ndarray could not be cast to constant with {int(ndim)} dimensions"
215 )
217 assert x_.ndim == ndim
--> 219 ttype = TensorType(dtype=x_.dtype, shape=x_.shape)
221 return TensorConstant(ttype, x_, name=name)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:107, in TensorType.__init__(self, dtype, shape, name, broadcastable)
104 return s
106 self.shape = tuple(parse_bcast_and_shape(s) for s in shape)
--> 107 self.dtype_specs() # error checking is done there
108 self.name = name
109 self.numpy_dtype = np.dtype(self.dtype)
File /gpfs/ysm/project/joormann/oad4/conda_envs/pymc4/lib/python3.9/site-packages/aesara/tensor/type.py:278, in TensorType.dtype_specs(self)
276 return self.dtype_specs_map[self.dtype]
277 except KeyError:
--> 278 raise TypeError(
279 f"Unsupported dtype for {self.__class__.__name__}: {self.dtype}"
280 )
TypeError: Unsupported dtype for TensorType: object
Any ideas?