var_names not working with sample_numpyro_nuts


When sampling a model with sample_numpyro_nuts and passing the var_names argument to track only selected variables, I get an AttributeError ('str' object has no attribute 'owner'). A minimal working example (MWE) and the error are below. Any idea how to fix this, or is it a bug?

MWE:

import numpy as np
import pymc as pm
import pymc.sampling_jax

# Simulated data: true loc = 1, true scale = 1
size = 100
Y = 1 + np.random.normal(size=size, scale=1)

basic_model = pm.Model()
with basic_model:
    scale = pm.HalfNormal("scale", sigma=1)
    loc = pm.Normal("loc", mu=0, sigma=10)
    Y_obs = pm.Normal("Y_obs", mu=loc, sigma=scale, observed=Y)

with basic_model:
    trace = pymc.sampling_jax.sample_numpyro_nuts(
        chains = 1, 
        tune = 1000,
        draws = 1000,
        var_names = ["loc"]
    )

Error:

AttributeError                            Traceback (most recent call last)
      15 with basic_model:
----> 16     trace = pymc.sampling_jax.sample_numpyro_nuts(
      17         chains = 1, 
      18         tune = 1000,
      19         draws = 1000,
      20         var_names = ["loc"]
      21     )

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:533, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    530 print("Sampling time = ", tic3 - tic2, file=sys.stdout)
    532 print("Transforming variables...", file=sys.stdout)
--> 533 jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    534 result = jax.vmap(jax.vmap(jax_fn))(
    535     *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    536 )
    537 mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling_jax.py:81, in get_jaxified_graph(inputs, outputs)
     75 def get_jaxified_graph(
     76     inputs: Optional[List[TensorVariable]] = None,
     77     outputs: Optional[List[TensorVariable]] = None,
     78 ) -> List[TensorVariable]:
...
    850 def expand(r: Variable) -> Optional[Iterator[Variable]]:
--> 851     if r.owner and (not blockers or r not in blockers):
    852         return reversed(r.owner.inputs)

AttributeError: 'str' object has no attribute 'owner'
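The last frame of the traceback hints at what goes wrong: the graph-walking helper reads `r.owner` on every requested output, so it looks like the strings from `var_names` are passed through as if they were graph variables. The sketch below reproduces that failure mode in plain Python. `Variable`, `Apply`, `expand` and `try_expand` are simplified, hypothetical stand-ins for the Aesara objects in the traceback, not the real library code.

```python
from typing import Iterator, List, Optional


class Variable:
    """Hypothetical stand-in for a graph variable: a name plus an owner node."""

    def __init__(self, name: str, owner: Optional["Apply"] = None) -> None:
        self.name = name
        self.owner = owner  # None for root (input) variables


class Apply:
    """Hypothetical stand-in for the node that produced a variable."""

    def __init__(self, inputs: List[Variable]) -> None:
        self.inputs = inputs


def expand(r: Variable, blockers=None) -> Optional[Iterator[Variable]]:
    # Same shape as the failing frame: it reads r.owner unconditionally.
    if r.owner and (not blockers or r not in blockers):
        return reversed(r.owner.inputs)
    return None


def try_expand(output) -> str:
    """Return 'ok', or the AttributeError message, for one requested output."""
    try:
        expand(output)
        return "ok"
    except AttributeError as err:
        return str(err)


x = Variable("x")
y = Variable("y", owner=Apply([x]))

print(try_expand(y))      # a variable object expands fine: ok
print(try_expand("loc"))  # a plain name string: 'str' object has no attribute 'owner'
```

In other words, the names would need to be turned into the model's variable objects somewhere before this traversal runs.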

PyMC version: 4.1.3
Aesara version: 2.7.7
Python version: 3.10.5
Operating system: macOS (the error also occurs on Linux)
How you installed PyMC: conda

I didn’t know about the var_names option for sampling. You can open an issue on GitHub; this looks like a clear bug.

I don’t think you can use var_names in pm.sample(...); it is not an API we currently support.

Ah, so it just gets swallowed by **kwargs and never raises? I assumed that since it doesn’t fail, it was supposed to do something.

It raises an error with pm.sample(...) as well.

But it is definitely mentioned in the docstring of sample_numpyro_nuts: pymc/sampling_jax.py at commit 6804c96cf2039780e167328d7476dc92f043554f in pymc-devs/pymc on GitHub.

Whether that is intentional or accidental, I don’t know.
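Until this is fixed upstream, one possible workaround is to resolve the names to the model’s own variables before calling the sampler, so that graph variables rather than strings reach the transformation step. The helper below, `select_vars_by_name`, is a hypothetical sketch of that lookup: it works on any objects with a `.name` attribute and raises on unknown names instead of silently dropping them. Assumption, not verified here: with PyMC 4.1.x, passing `select_vars_by_name(basic_model.unobserved_value_vars, ["loc"])` as `var_names` should avoid the `AttributeError`, because `unobserved_value_vars` appears to be what the sampler transforms when `var_names` is left as None.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List, Sequence


def select_vars_by_name(variables: Iterable[Any], var_names: Sequence[str]) -> List[Any]:
    """Return the variables whose .name is in var_names, in var_names order.

    Hypothetical helper: raises KeyError for names that match no variable,
    so a typo is reported instead of being silently ignored.
    """
    by_name = {v.name: v for v in variables}
    missing = [n for n in var_names if n not in by_name]
    if missing:
        raise KeyError(f"unknown variable names: {missing}")
    return [by_name[n] for n in var_names]


# Stand-ins for model variables; anything with a .name attribute works.
model_vars = [SimpleNamespace(name="scale"), SimpleNamespace(name="loc")]

selected = select_vars_by_name(model_vars, ["loc"])
print([v.name for v in selected])  # ['loc']
```

Failing loudly on unknown names is a deliberate choice: silently filtering them out would reproduce the “it doesn’t fail, so it must do something” confusion discussed above.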