Jax_sampling model

I have only been using PyMC3 for a month or so, but I wanted to try out the new JAX sampling backend on a model I could understand, and also gauge the speedup. The model is a multinomial model using baseball data.

import pymc3 as pm

# data_mlb: DataFrame with per-row outcome counts plus plate appearances (pa)
N = data_mlb.shape[0]
results = data_mlb[['single', 'double', 'triple', 'home_run', 'tw', 'strikeout', 'bo']]
results = results.to_numpy()

K = results.shape[1]

with pm.Model() as hitting:
    a = pm.Normal('a', mu=0, sigma=1.5, shape=K)

    # Positive Dirichlet concentrations, one per outcome
    ev0 = pm.math.exp(a[0])
    ev1 = pm.math.exp(a[1])
    ev2 = pm.math.exp(a[2])
    ev3 = pm.math.exp(a[3])
    ev4 = pm.math.exp(a[4])
    ev5 = pm.math.exp(a[5])
    ev6 = pm.math.exp(a[6])
    ev = pm.math.stack([ev0, ev1, ev2, ev3, ev4, ev5, ev6]).T

    p_ = pm.Dirichlet('p_', a=ev, shape=(N, K))
    y = pm.Multinomial('y', n=data_mlb.pa, p=p_, shape=(N, K), observed=results)
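The seven `ev` lines exponentiate each element of `a` and stack the results. Because `.T` does nothing to a 1-D vector, this should be the same as one elementwise exponential, which in the model would read `ev = pm.math.exp(a)`. A quick NumPy check of that equivalence, using a random stand-in for `a` (not the model's actual draws):

```python
import numpy as np

K = 7
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.5, size=K)  # stand-in draw for the prior on `a`

# Original construction: exponentiate each element, stack, transpose.
ev_stacked = np.stack([np.exp(a[i]) for i in range(K)]).T

# Vectorized form: one elementwise exp over the whole vector.
ev_vectorized = np.exp(a)

# .T is a no-op on a 1-D array, so both have shape (K,) and equal values.
print(ev_stacked.shape, ev_vectorized.shape)  # (7,) (7,)
print(np.allclose(ev_stacked, ev_vectorized))  # True
```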

With the default (non-JAX) PyMC3 sampler the results look as expected. However, when I sample with `sampling_jax.sample_tfp_nuts` I get errors and all results are zero, and with `sampling_jax.sample_numpyro_nuts` every probability comes out as 0.143, i.e. 1/7 (one over the number of outcomes). I'm not sure whether I'm missing something entirely or whether this is specific to JAX sampling.
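One possible reading of the 0.143 value, offered as an assumption rather than a diagnosis: it is exactly the mean of a symmetric Dirichlet over K = 7 outcomes. Assuming, as in PyMC3 3.x, that `a` starts at its prior mean (0) and `p_` starts at its Dirichlet mean, a chain that never moves from its starting point would report 1/7 for every entry:

```python
import numpy as np

K = 7  # single, double, triple, home_run, tw, strikeout, bo

# Assumption: the chain never leaves its starting point. A Normal(0, 1.5)
# started at its mean gives a = 0, so every concentration is exp(0) = 1.
a_start = np.zeros(K)
concentration = np.exp(a_start)

# Assumption: the Dirichlet starts at its mean, a_k / sum(a), i.e. uniform here.
p_start = concentration / concentration.sum()

print(np.round(p_start, 3))  # [0.143 0.143 0.143 0.143 0.143 0.143 0.143]
```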


I am having the same problem today. I am running on a Google Colab instance, with PyMC3 3.11.2 and JAX 0.2.11. This is my model:

import numpy as np
import pymc3 as pm

# df: count matrix with p rows (categories) and n columns (samples)
p, n = df.shape
k = 5

with pm.Model() as non_hierarchical_model:
    exposures = pm.Dirichlet("W*", np.ones((n, k)), testval=np.ones((n, k))/k)
    signatures = pm.Dirichlet("H", np.ones((k, p)), testval=np.ones((k, p))/p)
    exp_catalogue = pm.Deterministic("WH", pm.math.dot(exposures, signatures))
    pm.Multinomial("X", df.sum().values, exp_catalogue, observed=df.values.T, shape=(n,p), testval=df.values.T)
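The `WH` product can be passed straight to the Multinomial because each row of `exposures @ signatures` is a convex combination of rows of `signatures`, so it is again a probability vector. A NumPy sketch of that property with hypothetical sizes (not the poster's data):

```python
import numpy as np

rng = np.random.default_rng(1)
n, k, p = 4, 5, 10  # hypothetical sizes: samples, signatures, categories

# Each row of W and H is a draw from a flat Dirichlet, so rows sum to 1.
W = rng.dirichlet(np.ones(k), size=n)  # shape (n, k)
H = rng.dirichlet(np.ones(p), size=k)  # shape (k, p)

# Each row of WH mixes rows of H with weights that sum to 1,
# so it is also non-negative and sums to 1.
WH = W @ H  # shape (n, p)

print(WH.shape)  # (4, 10)
print(np.allclose(WH.sum(axis=1), 1.0))  # True
```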

And this is how I sample from it:

import pymc3.sampling_jax  # the JAX samplers live in a separate module

with non_hierarchical_model:
    # trace = pm.sample()
    trace = pm.sampling_jax.sample_tfp_nuts()

I also get the same value (1/N) for every entry of my Dirichlet variables.

Does the model sample fine (no divergences) with the PyMC3 default NUTS sampler?

My model sampled fine with the default NUTS sampler: the results were what I anticipated, and as far as I recall all the trace plots looked good. Sorry, I have since thrown away the environment and the data.

In my case, default NUTS was very slow. I did not see any divergences, but that may be because I drew only a few samples. I switched to a logit-normal prior, which is working fine so far. I'm still looking for a solution, though, since I think a Dirichlet is a better fit for my case.
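A logit-normal (logistic-normal) prior places a Normal on unconstrained logits and maps them onto the simplex. The poster did not say which variant they used; the NumPy sketch below assumes the additive-logistic form, where the last logit is pinned to 0 so the map is one-to-one. `additive_logistic` is a hypothetical helper name:

```python
import numpy as np

def additive_logistic(z):
    """Map unconstrained logits of shape (..., k-1) onto the k-simplex.

    The last logit is fixed at 0: adding a constant to all k logits would
    leave the softmax unchanged, so pinning one removes that redundancy.
    """
    zeros = np.zeros(z.shape[:-1] + (1,))
    logits = np.concatenate([z, zeros], axis=-1)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    expz = np.exp(logits)
    return expz / expz.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
n, k = 4, 5  # hypothetical sizes
z = rng.normal(0.0, 1.0, size=(n, k - 1))  # Normal prior on the free logits
w = additive_logistic(z)

print(w.shape)  # (4, 5)
print(np.allclose(w.sum(axis=1), 1.0))  # True
```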

Most likely your model isn't parameterized appropriately. I would first get it to sample reasonably, with high ESS (effective sample size), using the PyMC3 NUTS sampler, and only if that is still too slow, experiment with JAX.
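ESS estimates how many independent draws a correlated chain is worth. For a PyMC3 trace the usual tool is ArviZ (`az.ess`); the NumPy sketch below is a deliberately simplified single-chain estimator that truncates the autocorrelation sum at the first non-positive lag. It only illustrates the idea and is not ArviZ's multi-chain method. `ess_simple` is a hypothetical name:

```python
import numpy as np

def ess_simple(x):
    """Rough effective sample size for a single 1-D chain.

    ESS = N / (1 + 2 * sum_t rho_t), summing lag autocorrelations rho_t
    until the first non-positive one.
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    x = x - x.mean()
    var = x.var()
    rho_sum = 0.0
    for lag in range(1, n):
        rho = np.dot(x[:-lag], x[lag:]) / (n * var)
        if rho <= 0:
            break
        rho_sum += rho
    return n / (1.0 + 2.0 * rho_sum)

rng = np.random.default_rng(3)
n = 20_000
iid = rng.normal(size=n)  # independent draws: ESS should be close to n

# AR(1) chain with strong autocorrelation, mimicking a poorly mixing sampler.
phi = 0.9
ar = np.empty(n)
ar[0] = rng.normal()
for t in range(1, n):
    ar[t] = phi * ar[t - 1] + rng.normal()

print(round(ess_simple(iid)), round(ess_simple(ar)))
```

For the AR(1) chain the theoretical ESS is about n(1 - phi)/(1 + phi), roughly 1,050 of the 20,000 draws, which is the kind of gap a low-ESS warning is pointing at.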

Most likely, the JAX sampler simply tunes (adapts its step size and mass matrix during warm-up) worse than the PyMC3 one, so it has an even harder time with a poorly specified model and may not even get off the ground.
