When I sample a model with sample_numpyro_nuts and pass the var_names argument to track only selected variables, sampling itself finishes but the call then fails with "AttributeError: 'str' object has no attribute 'owner'". The MWE and the full traceback are below. Any idea how to fix this, or is it a bug?
MWE:
import numpy as np
import pymc as pm
import pymc.sampling_jax

# Simulated data: 100 draws from a normal distribution with mean 1 and sd 1
size = 100
Y = 1 + np.random.normal(size=size, scale = 1)

basic_model = pm.Model()
with basic_model:
    scale = pm.HalfNormal("scale", sigma=1)  # prior on the observation sd
    loc = pm.Normal("loc", mu=0, sigma=10)  # prior on the observation mean
    Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)  # likelihood

with basic_model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        chains = 1,
        tune = 1000,
        draws = 1000,
        var_names = ["loc"]  # variable given by name (a str); this call raises the AttributeError
    )
Error:
AttributeError Traceback (most recent call last)
15 with basic_model:
----> 16 trace = pymc.sampling_jax.sample_numpyro_nuts(
17 chains = 1,
18 tune = 1000,
19 draws = 1000,
20 var_names = ["loc"]
21 )
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:533, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
530 print("Sampling time = ", tic3 - tic2, file=sys.stdout)
532 print("Transforming variables...", file=sys.stdout)
--> 533 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
534 result = jax.vmap(jax.vmap(jax_fn))(
535 *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
536 )
537 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
75 def get_jaxified_graph(
76 inputs: Optional[List[TensorVariable]] = None,
77 outputs: Optional[List[TensorVariable]] = None,
78 ) -> List[TensorVariable]:
...
850 def expand(r: Variable) -> Optional[Iterator[Variable]]:
--> 851 if r.owner and (not blockers or r not in blockers):
852 return reversed(r.owner.inputs)
AttributeError: 'str' object has no attribute 'owner'
PyMC version: 4.1.3
Aesara version: 2.7.7
Python version: 3.10.5
Operating system: macOS (but the error also happens on Linux)
How you installed PyMC: conda