Hi,

I have been using sample_numpyro_nuts because it samples faster. It works well with pm.Categorical() when the variable is observed; however, I run into a problem when I try to make predictions on new data.

```
with pm.Model() as model1:
    # Priors on the coefficients for the first num_action-1 actions
    sigma_ard = pm.Gamma("sigma_ard", alpha = 2, beta = 0.01, shape = (n_model_1, num_action-1))
    beta_others = pm.Normal("beta_others", mu = 0, sigma = sigma_ard, shape = (n_model_1, num_action-1))
    # Priors on the intercepts for the first num_action-1 actions
    sigma_ard0 = pm.Gamma("sigma_ard0", alpha = 2, beta = 0.01, shape = (1, num_action-1))
    beta0_others = pm.Normal("beta0_others", mu = 0, sigma = sigma_ard0, shape = (1, num_action-1))
    # The last action is the reference class: its intercept and coefficients are fixed at zero
    beta0_zero = np.zeros((1,1))
    beta0_all = pm.math.concatenate((beta0_others, beta0_zero), axis = 1)
    zeros = np.zeros((n_model_1,1))
    betas = pm.math.concatenate((beta_others, zeros), axis = 1)
    # Likelihood on the training data
    y_given_x = pm.Categorical("y_given_x", p = pm.math.softmax(pm.math.dot(train_x, betas) + beta0_all, axis = 1), observed = train_y)
    # Unobserved categorical for predictions on the test data (the error goes away without this line)
    y_pred = pm.Categorical("y_pred", p = pm.math.softmax(pm.math.dot(test_x, betas) + beta0_all, axis = 1))
    # Sample with the NumPyro NUTS sampler
    samples1 = pmjax.sample_numpyro_nuts(tune = n_tune, chains = n_chains, draws = n_samples, idata_kwargs={'log_likelihood':True})
```

When I run the above code, I get this error:

```
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.
```

However, the code works if I omit the "y_pred" line or if I switch the sampler to pm.sample.
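
One way to sidestep the issue, sketched here rather than confirmed: since the model samples fine without `y_pred`, the predictions could be computed from the posterior draws after sampling. A plausible reason for the error is that `y_pred` is an unobserved integer-valued variable and a NUTS-only sampler needs gradients with respect to every free variable, whereas `pm.sample` can assign a non-gradient step method to discrete variables. The NumPy sketch below assumes the draws of `beta_others` and `beta0_others` have already been stacked into arrays of shape `(n_draws, n_features, K-1)` and `(n_draws, 1, K-1)`; the helper names `predict_probs` and `sample_categories` and the toy data are hypothetical, not part of PyMC.

```python
import numpy as np


def softmax(logits, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def predict_probs(test_x, beta_others, beta0_others):
    """Class probabilities for each posterior draw and each test row.

    test_x:       (n_obs, n_features)
    beta_others:  (n_draws, n_features, K-1)
    beta0_others: (n_draws, 1, K-1)
    returns:      (n_draws, n_obs, K)
    """
    n_draws, n_features, _ = beta_others.shape
    # Append the zero column for the reference class, as in the model.
    betas = np.concatenate([beta_others, np.zeros((n_draws, n_features, 1))], axis=2)
    beta0_all = np.concatenate([beta0_others, np.zeros((n_draws, 1, 1))], axis=2)
    # Matrix product broadcasts over the leading draw dimension.
    logits = test_x @ betas + beta0_all
    return softmax(logits, axis=-1)


def sample_categories(probs, rng):
    # Inverse-CDF sampling: count how many cumulative bins lie below u.
    cdf = probs.cumsum(axis=-1)
    u = rng.random(probs.shape[:-1] + (1,))
    labels = (u > cdf).sum(axis=-1)
    # Guard against round-off pushing a label past the last class.
    return np.minimum(labels, probs.shape[-1] - 1)


# Toy stand-ins for the posterior draws and test data (assumed shapes).
rng = np.random.default_rng(0)
n_draws, n_obs, n_features, n_actions = 4, 5, 3, 3
toy_test_x = rng.normal(size=(n_obs, n_features))
toy_beta_others = rng.normal(size=(n_draws, n_features, n_actions - 1))
toy_beta0_others = rng.normal(size=(n_draws, 1, n_actions - 1))

probs = predict_probs(toy_test_x, toy_beta_others, toy_beta0_others)
y_pred_draws = sample_categories(probs, rng)
print(probs.shape, y_pred_draws.shape)
```

With the real model, the arrays would presumably come from `samples1.posterior` (for example by merging the chain and draw dimensions), and `test_x` would replace the toy matrix.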

Can anyone explain what the problem is?

Thanks for your time and consideration,

Jay.