Jax sampling - NameError: name 'auto_403213' is not defined

Hey guys, first of all thank you for your amazing work on pymc!

I’ve just migrated a somewhat complex model from pymc3 to pymc.

I tried fitting the model with the experimental JAX sampler via

import arviz as az
import pymc as pm
import pymc.sampling_jax

with model:
    trace: az.InferenceData = pm.sampling_jax.sample_numpyro_nuts(
        draws=100, tune=100,  # small values, just to test
        idata_kwargs={"log_likelihood": False},
    )

It seems to run as expected for quite a while, but throws an exception at the end of sampling, right after the "Transforming variables..." step:

Running chain 3: 100%|██████████| 200/200 [01:52<00:00,  1.78it/s]
Sampling time =  0:01:56.322437
Transforming variables...
Traceback (most recent call last):
  File "/var/folders/sc/xxx/T/tmpp_zzzzz", line 51, in jax_funcified_fgraph
    return i_first_value, [....], age_below_40_mu, first_value_sigma, auto_403375, auto_403356, auto_403337, final_Y_sigma, auto_403213, auto_403236, auto_403259, auto_403282, auto_403305
NameError: name 'auto_403213' is not defined

Does anyone have advice on what may be wrong here and how I could fix it?

I’m not sure what the auto_N nodes are for and why they are generated.


Those are just dummy names used when creating the JAX functions. I'm afraid you will have to share more details about your model for us to be able to help, ideally a small reproducible example.
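For readers wondering how such placeholder names can turn into this exact error, here is a toy sketch in plain Python. It is not PyTensor's actual JAX code generator; the helpers placeholder_name and build_function and the function name toy_funcified_fgraph are hypothetical. It only assumes what the traceback shows: a function is generated as Python source, unnamed nodes get auto_<N> names, and a final return statement lists them. If an output appears in the return but its assignment is never emitted, building the function still succeeds and the NameError only surfaces when the function is called.

```python
import itertools

# Hypothetical counter that hands out placeholder names for unnamed graph
# nodes, in the spirit of the auto_<N> names seen in the traceback.
_ids = itertools.count()


def placeholder_name():
    return f"auto_{next(_ids)}"


def build_function(assignments, outputs, func_name="toy_funcified_fgraph"):
    """Generate Python source for a function of `x` that performs the given
    (name, expression) assignments and returns `outputs`, then exec it."""
    lines = [f"def {func_name}(x):"]
    for name, expr in assignments:
        lines.append(f"    {name} = {expr}")
    lines.append(f"    return {', '.join(outputs)}")
    source = "\n".join(lines)
    namespace = {}
    exec(compile(source, "<generated>", "exec"), namespace)
    return namespace[func_name], source


y = placeholder_name()  # "auto_0"
z = placeholder_name()  # "auto_1"

# Correct generator: every returned name is assigned first.
ok_fn, _ = build_function([(y, "x * 2"), (z, f"{y} + 1")], [y, z])
print(ok_fn(3))  # (6, 7)

# Buggy generator: z is returned, but its assignment was never emitted.
# Building the function succeeds; the error only appears when it is called.
bad_fn, bad_src = build_function([(y, "x * 2")], [y, z])
print(bad_src)
try:
    bad_fn(3)
except NameError as err:
    message = str(err)
    print(message)  # name 'auto_1' is not defined
```

One hint the toy version gives: a name that is assigned anywhere inside a function body is local, and reading it too early raises UnboundLocalError instead. Since the real traceback reports a plain NameError, the generated jax_funcified_fgraph body most likely never assigns auto_403213 at all.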

Thank you for taking the time to look at my question and reply, @ricardoV94.

I can use the standard pm.sample method (which I assume uses NUTS) without issues.

I'll see if I can create a small(ish) reproducible example.
