JAX sampler raising a `NameError`

Hello,

I’m running into an error I’ve never seen before. It raises a `NameError` saying that a name is not defined. See the traceback below.

---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
/tmp/ipykernel_3935/1149294247.py in <module>
----> 1 trace_month = pymc.sampling_jax.sample_numpyro_nuts(model=monthly_model, tune=1000, chain_method='vectorized', draws = 1000, idata_kwargs=dict(log_likelihood = False))

/opt/conda/lib/python3.7/site-packages/pymc/sampling_jax.py in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    534     jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample)
    535     result = jax.vmap(jax.vmap(jax_fn))(
--> 536         *jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
    537     )
    538     mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}

    [... skipping hidden 6 frame]

/opt/conda/lib/python3.7/site-packages/aesara/link/utils.py in jax_funcified_fgraph(mu_slope, sigma_item_slope_log_, offset_item_slope, mu_intercept, sigma_item_intercept_log_, offset_item_intercept, mu_delta, sigma_item_delta_log_, offset_item_delta, yearly_mu, yearly_sigma_log_, yearly_beta, beta_month)
     29     # Elemwise{Composite{(i0 + i1 + i2 + ((i3 + i4) * i5) + i6)}}[(0, 2)](TensorConstant{(1,) of 5.0}, InplaceDimShuffle{x}.0, AdvancedSubtensor1.0, InplaceDimShuffle{x}.0, AdvancedSubtensor1.0, TensorConstant{[0. 0. 0. .. 1. 1. 1.]}, AdvancedSubtensor1.0)
     30     auto_49500 = composite(auto_49120, auto_49108, auto_49293, auto_48885, auto_49294, auto_48239, auto_49312)
---> 31     return mu_slope, offset_item_slope, mu_intercept, offset_item_intercept, mu_delta, offset_item_delta, yearly_mu, yearly_beta, beta_month, sigma_item_slope, sigma_item_intercept, sigma_item_delta, yearly_sigma, auto_48897, auto_49500
     32 

NameError: name 'auto_48897' is not defined

What is weird is that `auto_48897` is not in my code. See my model code below.

# Hierarchical model of item counts per month. It has four parts:
#   * a piecewise-linear trend with changepoints (item-level slope, intercept
#     and changepoint offsets, partially pooled around global means);
#   * a yearly Fourier seasonality;
#   * a per-month effect;
#   * a Poisson likelihood on exp(mu), i.e. a log link.
#
# Relies on names defined earlier in the notebook (not visible here): pm, np,
# pymc, items, months, df_train, A, s, t, yearly_fourier, item_idxs,
# time_idxs, month_idxs and y.
with pm.Model() as monthly_model:

    # ---- Coordinates -------------------------------------------------------
    monthly_model.add_coord('items', items, mutable=True)
    monthly_model.add_coord('months', months, mutable=True)
    # Changepoint labels: values of the first index level of df_train, taken
    # at the rows where the changepoint matrix A differs between consecutive
    # rows.
    monthly_model.add_coord(
        'changepoints',
        df_train.index.get_level_values(0)[np.argwhere(np.diff(A, axis=0) != 0)[:, 0]],
        mutable=True,
    )
    # One cos and one sin label per Fourier order (columns = 2 * orders).
    monthly_model.add_coord(
        'yearly_components',
        [f'yearly_{f}_{i+1}' for f in ['cos', 'sin'] for i in range(yearly_fourier.shape[1] // 2)],
        mutable=True,
    )
    monthly_model.add_coord(
        'obs_id',
        [f'{time.year}_month_{time.month}_item_{item}' for time, item in df_train.index.values],
        mutable=True,
    )

    # ---- Data containers ---------------------------------------------------
    A_ = pm.Data('A', A, mutable=True, dims=['time', 'changepoints'])
    s_ = pm.Data('s', s, mutable=True, dims=['changepoints'])
    t_ = pm.Data('t', t, mutable=True, dims=['time'])
    # NOTE(review): dims declare one row per obs_id, but the rows are indexed
    # with time_idxs below. Confirm whether yearly_fourier is per time step
    # or per observation.
    yearly = pm.Data('yearly_season', yearly_fourier, mutable=True, dims=['obs_id', 'yearly_components'])

    # ---- Slope (non-centred: global mean + scaled item offset) -------------
    mu_slope = pm.Normal('mu_slope', mu=0, sigma=0.1)
    sigma_item_slope = pm.HalfNormal('sigma_item_slope', sigma=0.1)
    offset_item_slope = pm.Normal('offset_item_slope', mu=0, sigma=0.1, dims=['items'])

    item_slope = sigma_item_slope * offset_item_slope
    initial_slope = mu_slope + item_slope[item_idxs]

    # ---- Intercept ---------------------------------------------------------
    mu_intercept = pm.Normal('mu_intercept', mu=0, sigma=0.1)
    sigma_item_intercept = pm.HalfNormal('sigma_item_intercept', sigma=0.1)
    offset_item_intercept = pm.Normal('offset_item_intercept', mu=0, sigma=0.1, dims=['items'])

    item_intercept = sigma_item_intercept * offset_item_intercept
    initial_intercept = mu_intercept + item_intercept[item_idxs]

    # ---- Changepoint rate adjustments (per item, per changepoint) ----------
    mu_delta = pm.Normal('mu_delta', 0, 0.1)
    sigma_item_delta = pm.HalfNormal('sigma_item_delta', sigma=0.1)
    offset_item_delta = pm.Normal('offset_item_delta', mu=0, sigma=0.25, dims=['items', 'changepoints'])

    item_delta = sigma_item_delta * offset_item_delta
    delta = mu_delta + item_delta[item_idxs, :]

    # ---- Yearly seasonality ------------------------------------------------
    yearly_mu = pm.Normal('yearly_mu', 0, 0.1)
    yearly_sigma = pm.HalfNormal('yearly_sigma', sigma=0.1)
    yearly_beta = pm.Normal('yearly_beta', yearly_mu, yearly_sigma, dims=['yearly_components'])
    # Weight each Fourier feature by its coefficient before summing.
    # Previously the features were summed without yearly_beta. That left
    # yearly_beta unused, and it made this Deterministic a data-only constant
    # of the graph. The JAX post-processing in sample_numpyro_nuts appears to
    # fail on such an output ("name 'auto_48897' is not defined"), while
    # pm.sample does not.
    yearly_seasonality = pm.Deterministic(
        'yearly_seasonality',
        (yearly[time_idxs] * yearly_beta).sum(axis=1),
        dims=['obs_id'],
    )

    # ---- Month effects and piecewise-linear trend --------------------------
    beta_month = pm.Normal('beta_month', mu=0, sigma=0.1, dims=['months'])
    # Changepoints shift the intercept by -s * delta and the slope by delta
    # for every changepoint that A marks as active at each time index.
    intercept = initial_intercept + ((-s_ * A_)[time_idxs, :] * delta).sum(axis=1)
    slope = initial_slope + (A_[time_idxs, :] * delta).sum(axis=1)

    # NOTE(review): slope is per observation, but t_ (dims ['time']) is not
    # indexed by time_idxs here. Confirm that t is already per observation,
    # or use t_[time_idxs].
    mu = pm.Deterministic(
        'mu',
        intercept + slope * t_ + yearly_seasonality + beta_month[month_idxs],
        dims=['obs_id'],
    )

    # Counts are Poisson with a log link on mu.
    likelihood = pm.Poisson(
        'predicted_eaches',
        mu=pm.math.exp(mu),
        observed=y,
        dims=['obs_id'],
    )

trace_month = pymc.sampling_jax.sample_numpyro_nuts(
    model=monthly_model,
    tune=1000,
    chain_method='vectorized',
    draws=1000,
    idata_kwargs=dict(log_likelihood=False),
)

This model fit before. The only thing I changed was moving the coords inside the model itself with `monthly_model.add_coord(...)`.

I’m not sure why it is throwing this type of error. Has anyone else seen this before?

UPDATE: The model does sample with `pm.sample` without raising a `NameError`.