pm.Categorical with sample_numpyro_nuts


I have been using sample_numpyro_nuts because it is faster. It even works well with an observed pm.Categorical(); however, it fails when I try to make predictions on new data.

with pm.Model() as model1:
    sigma_ard = pm.Gamma("sigma_ard", alpha = 2, beta = 0.01, shape = (n_model_1, num_action-1))
    beta_others = pm.Normal("beta_others", mu = 0, sigma = sigma_ard, shape = (n_model_1, num_action-1))

    sigma_ard0 = pm.Gamma("sigma_ard0", alpha = 2, beta = 0.01, shape = (1, num_action-1))
    beta0_others = pm.Normal("beta0_others", mu = 0, sigma = sigma_ard0, shape = (1, num_action-1))
    beta0_zero = np.zeros((1,1))
    beta0_all = pm.math.concatenate((beta0_others, beta0_zero), axis = 1)

    zeros = np.zeros((n_model_1,1))
    betas = pm.math.concatenate((beta_others, zeros), axis = 1)
    y_given_x = pm.Categorical("y_given_x", p = pm.math.softmax(, betas) + beta0_all, axis = 1), observed = train_y)

    y_pred = pm.Categorical("y_pred", p = pm.math.softmax(, betas) + beta0_all, axis = 1))
    samples1 = pmjax.sample_numpyro_nuts(tune = n_tune, chains = n_chains, draws = n_samples, idata_kwargs={'log_likelihood':True});

When I run the above code, I get this error:

  TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

However, the code works if I omit the “y_pred” line, or if I switch the sampler to pm.sample.

Can anyone explain what the problem is?

Thanks for your time and consideration,


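NUTS, which numpyro uses, can’t sample discrete variables. The observed Categorical is fine because it is never sampled, but the unobserved y_pred is a discrete variable the sampler would have to draw. pm.sample works because it assigns a separate, non-gradient step method to discrete variables. To make predictions, use sample_posterior_predictive instead: once you’re done sampling, recreate the model with the extra variable and include it in var_names.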

This may help you understand how it is used: Out of model predictions with PyMC - PyMC Labs

Does this model need a separate predictive model? It looks like you can just use pm.set_data here.


I had preferred to include the prediction term directly in my model rather than use “sample_posterior_predictive”. In this situation, though, I need to use “sample_posterior_predictive” together with “pm.set_data”.