Translating a Stan model to the most performant PyMC3 form

Hey all,

I am Rok, a Stan developer, and I am looking for a bit of your help. I am trying to do some performance comparisons between various MCMC frameworks and would like to do them fairly.

This is where I need help from someone who can read Stan and knows PyMC3 very well.

What would be the most performant way to write this logistic regression model and run it with 4 chains in PyMC3, using JAX and optionally also the previously used Theano backend:

data {
  int<lower=1> k;
  int<lower=0> n;
  matrix[n, k] X;
  int y[n];
}
parameters {
  vector[k] beta;
  real alpha;
}
model {
  target += std_normal_lpdf(beta);
  target += std_normal_lpdf(alpha);
  target += bernoulli_logit_glm_lpmf(y | X, alpha, beta);
}

While I think I could piece it together, I am in no way PyMC3-proficient enough to be sure I got the most performant version.

Thanks in advance!
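For a fair comparison, every framework should be fed exactly the same data. A minimal sketch, assuming synthetic data is acceptable for the benchmark: simulate a dataset from the model's own priors with NumPy and keep it in the layout of the Stan `data` block. The helper `simulate_logistic_data`, the sizes, and the seed are hypothetical choices, not part of the original question.

```python
import json
import numpy as np

def simulate_logistic_data(n, k, seed=0):
    """Hypothetical helper: draw (X, y) from the model's generative process."""
    rng = np.random.default_rng(seed)
    X = rng.standard_normal((n, k))        # matrix[n, k] X
    alpha = rng.standard_normal()          # alpha ~ std_normal()
    beta = rng.standard_normal(k)          # beta ~ std_normal()
    eta = alpha + X @ beta                 # linear predictor, as in bernoulli_logit_glm
    p = 1.0 / (1.0 + np.exp(-eta))         # inverse logit
    y = rng.binomial(1, p)                 # int y[n], each entry 0 or 1
    return X, y

X, y = simulate_logistic_data(n=1000, k=10, seed=1234)

# Same arrays in the dict layout of the Stan data block, so every framework sees identical data
stan_data = {"n": X.shape[0], "k": X.shape[1], "X": X.tolist(), "y": y.tolist()}
json_text = json.dumps(stan_data)
```

Writing `json_text` to a file once and loading it in each framework keeps data generation out of the timed region.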

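import pymc3 as pm

# X: (n, k) design matrix and y: 0/1 outcomes, the same data passed to Stan
k = X.shape[1]
with pm.Model(check_bounds=False) as model:  # skip parameter-bound checks in logp
    alpha = pm.Normal('alpha', 0, 1)
    beta = pm.Normal('beta', 0, 1, shape=k)
    like = pm.Bernoulli('like', logit_p=alpha + beta @ X.T, observed=y)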
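As a sanity check that this model matches the Stan one: multiplying `beta` by the transposed design matrix yields the same length-`n` linear predictor as Stan's `alpha + X * beta`, and a Bernoulli log-likelihood on the logit scale is algebraically the same as one on the probability scale. A small NumPy check with made-up toy arrays (this does not call PyMC3 itself):

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 5, 3
X = rng.standard_normal((n, k))   # toy design matrix (made-up values)
beta = rng.standard_normal(k)
alpha = 0.5
y = rng.integers(0, 2, size=n)    # toy 0/1 outcomes

# Linear predictor as written in the PyMC3 model vs. Stan's alpha + X * beta
eta_reply = alpha + beta @ X.T
eta_stan = alpha + X @ beta

# Bernoulli log-likelihood via probabilities vs. via the logit form
p = 1.0 / (1.0 + np.exp(-eta_stan))
loglik_prob = np.sum(y * np.log(p) + (1 - y) * np.log1p(-p))
loglik_logit = np.sum(y * eta_stan - np.logaddexp(0.0, eta_stan))
```

The logit form `y * eta - log(1 + exp(eta))` never takes `log(p)` of a probability that has rounded to 0 or 1, which is one reason to pass `logit_p` rather than a probability.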

Right now PyMC3 only has a Python-based sampler, which you can call with:

with model:
    posterior = pm.sample(chains=4)

In the next major release of PyMC you may also try:

import pymc3.sampling_jax  # experimental module, not loaded by `import pymc3` alone

with model:
    posterior = pm.sampling_jax.sample_numpyro_nuts(chains=4)

This uses the NumPyro NUTS sampler under the hood.

Now, there is not much going on in this model, so basically whichever library can evaluate it on a GPU will win on performance. Any gain from smarter tuning or a better sampler algorithm will be dwarfed by how quickly you can evaluate the logp and its gradient (dlogp), which basically amounts to the dot product in the likelihood.
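To make that concrete: for this model, the log density and its gradient each need one matrix-vector product with `X`, plus only O(n + k) elementwise work. A minimal NumPy sketch of what every backend has to compute at each leapfrog step (the function name `logp_and_grad` is hypothetical; normalizing constants are kept, as with Stan's `target += ..._lpdf`):

```python
import numpy as np

def logp_and_grad(alpha, beta, X, y):
    """Hypothetical helper: log density of the Stan model and its gradient."""
    eta = alpha + X @ beta                     # mat-vec product 1: O(n*k)
    mu = np.exp(-np.logaddexp(0.0, -eta))      # inverse logit, computed stably
    log_lik = np.sum(y * eta - np.logaddexp(0.0, eta))
    log_prior = (-0.5 * (alpha ** 2 + beta @ beta)
                 - 0.5 * (beta.size + 1) * np.log(2.0 * np.pi))
    resid = y - mu                             # d log_lik / d eta
    grad_alpha = np.sum(resid) - alpha         # likelihood term + std normal prior
    grad_beta = X.T @ resid - beta             # mat-vec product 2: O(n*k)
    return log_lik + log_prior, grad_alpha, grad_beta
```

Both products cost O(nk) while everything else is O(n + k), so for large `n` the speed of the backend's linear algebra (and whether it runs on a GPU) should dominate the timings.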