Hey guys, first of all thank you for your amazing work on pymc!
I’ve just migrated a somewhat complex model from pymc3 to pymc.
I tried fitting the model with the experimental JAX sampler via:
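```python
import arviz as az
import pymc as pm
import pymc.sampling_jax

# `model` is the pm.Model migrated from pymc3
with model:
    trace: az.InferenceData = pm.sampling_jax.sample_numpyro_nuts(
        draws=100, tune=100,  # small values, just to test
        idata_kwargs={"log_likelihood": False},
    )
```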
It seems to run as expected for quite a while but throws an exception at the end of sampling:
```
[...]
Running chain 3: 100%|██████████| 200/200 [01:52<00:00, 1.78it/s]
Sampling time = 0:01:56.322437
Transforming variables...
Traceback (most recent call last):
[...]
File "/var/folders/sc/xxx/T/tmpp_zzzzz", line 51, in jax_funcified_fgraph
return i_first_value, [....], age_below_40_mu, first_value_sigma, auto_403375, auto_403356, auto_403337, final_Y_sigma, auto_403213, auto_403236, auto_403259, auto_403282, auto_403305
NameError: name 'auto_403213' is not defined
```
Does anyone have advice on what may be wrong here and how I could fix it?
I’m not sure what the `auto_N` nodes are for or why they are generated.
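The traceback shows that the JAX backend writes the graph out as generated Python source (`jax_funcified_fgraph` in a temp file). My working assumption is that `auto_N` is simply the fallback identifier given to graph variables that have no user-facing name. The toy sketch below is hypothetical and is *not* PyTensor's code: it names unnamed nodes `auto_N`, emits a function with `exec`, and shows that if an output is never assigned in the body (simulated here by a made-up `skip_constant_outputs` flag, not a claim about the actual cause), the `return` line fails with the same kind of `NameError`.

```python
import itertools
import operator

_counter = itertools.count()


class Node:
    """Toy graph node (hypothetical; not PyTensor's Variable class)."""

    def __init__(self, op=None, inputs=(), name=None, value=None):
        self.op = op              # function name in NAMESPACE, or None for leaves
        self.inputs = list(inputs)
        self.name = name          # user-given name, may be None
        self.value = value        # only used by constant leaves
        # Fallback identifier for unnamed nodes, in the spirit of auto_N
        self.auto_name = f"auto_{next(_counter)}"

    @property
    def ident(self):
        return self.name or self.auto_name


NAMESPACE = {"add": operator.add, "mul": operator.mul}


def funcify(inputs, outputs, skip_constant_outputs=False):
    """Emit Python source for the toy graph, exec it, return (function, source)."""
    body, defined = [], {n.ident for n in inputs}

    def visit(node):
        if node.ident in defined:
            return
        if node.op is None:  # constant leaf
            if skip_constant_outputs and node in outputs:
                return  # hypothetical bug: this output is never assigned
            body.append(f"    {node.ident} = {node.value!r}")
        else:
            for inp in node.inputs:
                visit(inp)
            args = ", ".join(i.ident for i in node.inputs)
            body.append(f"    {node.ident} = {node.op}({args})")
        defined.add(node.ident)

    for out in outputs:
        visit(out)
    params = ", ".join(n.ident for n in inputs)
    returns = ", ".join(o.ident for o in outputs)
    source = "\n".join(
        [f"def funcified_fgraph({params}):", *body, f"    return {returns}"]
    )
    scope = dict(NAMESPACE)
    exec(source, scope)
    return scope["funcified_fgraph"], source


x = Node(name="x")
y = Node("mul", [x, Node(value=2.0)])  # unnamed, so it gets an auto_N identifier
c = Node(value=3.0)                    # unnamed constant used directly as an output

f, src = funcify([x], [y, c])
print(src)
print(f(1.5))  # (3.0, 3.0)

g, _ = funcify([x], [y, c], skip_constant_outputs=True)
try:
    g(1.5)
except NameError as err:
    message = str(err)
    print(message)  # names c's auto_N identifier as undefined
```

In this toy, the `NameError` only says which generated identifier was never assigned; mapping it back to a variable in the real model would still require inspecting the actual generated graph.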