How to do SGMCMC with PyMC models?

I am trying to run SGMCMC on models built in PyMC. As far as I know, PyMC does not have SGMCMC samplers yet, and Blackjax seemed promising, so I started with it. I am now stuck trying to combine these two examples from the Blackjax docs:
one example shows how to run SGLD with a simple custom likelihood function (without a PyMC model), while the other shows how to sample a PyMC model with NUTS (but not with SGMCMC).
Here is the code I used:

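import numpy as np
import pymc as pm
import jax
import jax.numpy as jnp
import blackjax
from blackjax.sgmcmc import gradients
from pymc.sampling.jax import get_jaxified_logp
from fastprogress import progress_bar

J = 8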
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
with pm.Model() as m:

    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)

    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)

# In the Blackjax SGLD example there is only one variable.
# Here I broadcast the zero prior to one zero per random variable, but I am not sure this is correct.
def logprior_fn(rvs):
    return [0] * len(rvs)

data_size = 8

rng_key = jax.random.PRNGKey(888)
rng_key, sample_key = jax.random.split(rng_key)
X_data = sample_fn(sample_key, data_size)
# Specify hyperparameters for SGLD
total_iter = 1_000
thinning_factor = 10

batch_size = 2
lr = 1e-3
temperature = 50.0

rvs = [rv.name for rv in m.value_vars]
logdensity_fn = get_jaxified_logp(m)

init_position_dict = m.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]

grad_fn = gradients.grad_estimator(logprior_fn, logdensity_fn, data_size)
sgld = blackjax.sgld(grad_fn)

position = init_position
sgld_sample_list = jnp.array([])

pb = progress_bar(range(total_iter))
for iter_ in pb:
    rng_key, batch_key, sample_key = jax.random.split(rng_key, 3)
    data_batch = jax.random.shuffle(batch_key, X_data)[:batch_size, :]
    position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)
    if iter_ % thinning_factor == 0:
        sgld_sample_list = jnp.append(sgld_sample_list, position)
        pb.comment = f"| position: {position: .2f}"

This raises the following error:

TypeError                                 Traceback (most recent call last)
Cell In[86], line 8
      6 rng_key, batch_key, sample_key = jax.random.split(rng_key, 3)
      7 data_batch = jax.random.shuffle(batch_key, X_data)[:batch_size, :]
----> 8 position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)
      9 if iter_ % thinning_factor == 0:
     10     sgld_sample_list = jnp.append(sgld_sample_list, position)

    [... skipping hidden 12 frame]

File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/sgld.py:122, in sgld.__new__.<locals>.step_fn(rng_key, state, minibatch, step_size, temperature)
    115 def step_fn(
    116     rng_key: PRNGKey,
    117     state: ArrayLikeTree,
   (...)
    120     temperature: float = 1,
    121 ) -> ArrayTree:
--> 122     return kernel(
    123         rng_key, state, grad_estimator, minibatch, step_size, temperature
    124     )

File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/sgld.py:40, in build_kernel.<locals>.kernel(rng_key, position, grad_estimator, minibatch, step_size, temperature)
     32 def kernel(
     33     rng_key: PRNGKey,
     34     position: ArrayLikeTree,
   (...)
     38     temperature: float = 1.0,
     39 ) -> ArrayTree:
---> 40     logdensity_grad = grad_estimator(position, minibatch)
     41     new_position = integrator(
     42         rng_key, position, logdensity_grad, step_size, temperature
     43     )
     45     return new_position

    [... skipping hidden 10 frame]

File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/gradients.py:68, in logdensity_estimator.<locals>.logdensity_estimator_fn(position, minibatch)
     65 logprior = logprior_fn(position)
     66 batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))
     67 return logprior + data_size * jnp.mean(
---> 68     batch_loglikelihood(position, minibatch), axis=0
     69 )

    [... skipping hidden 2 frame]

File /opt/conda/envs/pm5/lib/python3.11/site-packages/jax/_src/linear_util.py:188, in WrappedFun.call_wrapped(self, *args, **kwargs)
    185 gen = gen_static_args = out_store = None
    187 try:
--> 188   ans = self.f(*args, **dict(self.params, **kwargs))
    189 except:
    190   # Some transformations yield from inside context managers, so we have to
    191   # interrupt them before reraising the exception. Otherwise they will only
    192   # get garbage-collected at some later time, running their cleanup tasks
    193   # only after this exception is handled, which can corrupt the global
    194   # state.
    195   while stack:

TypeError: get_jaxified_logp.<locals>.logp_fn_wrap() takes 1 positional argument but 2 were given
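For reference, the gradients.py frame above (lines 65-69) computes a minibatch estimate of the log posterior. In the notation below (my own labels, not Blackjax's), N is data_size, n is the minibatch size, and B is the set of minibatch rows:

$$
\log \hat{p}(\theta) = \log p(\theta) + \frac{N}{n} \sum_{i \in B} \log p(x_i \mid \theta)
$$

Because jax.vmap(loglikelihood_fn, in_axes=(None, 0)) maps over the rows of the minibatch while sharing the position, loglikelihood_fn has to accept (position, x_i) and return the log likelihood of a single data point, and logprior_fn(position) has to return a scalar. The function returned by get_jaxified_logp(m) instead accepts only the list of parameter values and returns the whole model's log density (priors plus the likelihood of all of y), which is consistent with the "takes 1 positional argument but 2 were given" error.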

By the way, the first Blackjax example mentioned above has a typo in the line "position = jax.jit(sgld)(sample_key, position, data_batch, lr, temperature)"; it should be "position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)".

My questions are:

1. What is causing this error? I read through the code as best I could, and my guess is that it has something to do with this block in PyMC's jax.py:
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    model_logp = model.logp()
    if not negative_logp:
        model_logp = -model_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

    def logp_fn_wrap(x):
        return logp_fn(*x)[0]

    return logp_fn_wrap

which seems to take only a single argument (the list of variable values), while the likelihood function in the Blackjax example takes both the state of the variables and a batch of data. I am not at all sure about this, though.
2. @junpenglao, could you please take a look at this? I noticed that Blackjax hasn't been updated for a while. Are you still working on making it compatible with PPLs such as PyMC?
3. More generally, what would you recommend for running SGMCMC on a PyMC model, not necessarily with Blackjax? Any advice is appreciated!
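As a minimal sketch of the function signatures SGLD needs (not a maintainer-confirmed recipe), here is the same eight-schools model run with SGLD in plain JAX. Assumptions: the model is rewritten by hand in JAX rather than generated from the PyMC model; tau is sampled on the log scale (similar to PyMC's tau_log__ value variable), so the prior includes the log-Jacobian; each school is one data point that carries its own index so theta[j] can be looked up; and the update is a hand-written standard SGLD step (gradient step plus Gaussian noise scaled by sqrt(2 * step_size * temperature)), not Blackjax's kernel. The names logprior_fn, loglikelihood_fn, logdensity_estimator, sgld_step and run_sgld are illustrative.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import cauchy, norm

# Eight-schools data from the question.
y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
J = y.shape[0]
# One "data point" per school: (observation, known std. dev., school index).
data = (y, sigma, jnp.arange(J))


def logprior_fn(position):
    """Log prior of all parameters; returns a scalar."""
    mu, log_tau, theta = position
    tau = jnp.exp(log_tau)
    # HalfCauchy(5) on tau > 0 is 2 * Cauchy(0, 5); "+ log_tau" is the
    # log-Jacobian of the transform tau = exp(log_tau).
    lp_tau = jnp.log(2.0) + cauchy.logpdf(tau, 0.0, 5.0) + log_tau
    return norm.logpdf(mu, 0.0, 10.0) + lp_tau + jnp.sum(norm.logpdf(theta, 0.0, 1.0))


def loglikelihood_fn(position, datum):
    """Log likelihood of ONE school: (position, datum) -> scalar."""
    mu, log_tau, theta = position
    y_j, sigma_j, j = datum
    return norm.logpdf(y_j, mu + jnp.exp(log_tau) * theta[j], sigma_j)


def logdensity_estimator(position, minibatch, data_size):
    """Minibatch estimate: log prior + N * mean(per-point log likelihood)."""
    batch_loglik = jax.vmap(loglikelihood_fn, in_axes=(None, 0))(position, minibatch)
    return logprior_fn(position) + data_size * jnp.mean(batch_loglik)


# Gradient with respect to the first argument (the position pytree).
grad_estimator = jax.grad(logdensity_estimator)


@jax.jit
def sgld_step(rng_key, position, minibatch, step_size, temperature=1.0):
    """One hand-written SGLD step applied leaf-by-leaf to the position pytree."""
    grads = grad_estimator(position, minibatch, J)
    leaves, treedef = jax.tree_util.tree_flatten(position)
    keys = jax.random.split(rng_key, len(leaves))
    noise = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, jnp.shape(leaf)) for k, leaf in zip(keys, leaves)]
    )
    return jax.tree_util.tree_map(
        lambda p, g, n: p + step_size * g + jnp.sqrt(2.0 * temperature * step_size) * n,
        position,
        grads,
        noise,
    )


def run_sgld(rng_key, num_steps=1_000, batch_size=2, step_size=1e-3, thinning=10):
    """Run SGLD from a fixed start and keep every `thinning`-th position."""
    position = (jnp.zeros(()), jnp.zeros(()), jnp.zeros(J))  # (mu, log(tau), theta)
    samples = []
    for i in range(num_steps):
        rng_key, batch_key, step_key = jax.random.split(rng_key, 3)
        # Sample batch_size schools without replacement.
        idx = jax.random.permutation(batch_key, J)[:batch_size]
        minibatch = jax.tree_util.tree_map(lambda a: a[idx], data)
        position = sgld_step(step_key, position, minibatch, step_size)
        if i % thinning == 0:
            samples.append(position)
    return samples
```

Calling run_sgld(jax.random.PRNGKey(888)) returns a list of thinned (mu, log_tau, theta) tuples. Going by the vmap call in the traceback, logprior_fn and loglikelihood_fn have the shape that gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) expects, so they may also work with blackjax.sgld directly, but I have not checked that against a specific Blackjax version.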

Is this more helpful perhaps?

It does look very similar. Let me try it out. Thank you!