I am trying to implement SGMCMC with models built in PyMC. My understanding is that PyMC does not have SGMCMC samplers yet, and Blackjax seemed promising, so I started with it. Now I am stuck trying to combine these two examples from the Blackjax docs:
This example shows how to run SGLD with a simple custom likelihood function and no PyMC model, while this example shows how to sample a PyMC model using NUTS, but not SGMCMC.
Here is the code I used:
import blackjax
import blackjax.sgmcmc.gradients as gradients
import jax
import jax.numpy as jnp
import numpy as np
import pymc as pm
from fastprogress import progress_bar
from pymc.sampling.jax import get_jaxified_logp

J = 8
y = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])

with pm.Model() as m:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    tau = pm.HalfCauchy("tau", 5.0)
    theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
    theta_1 = mu + tau * theta
    obs = pm.Normal("obs", mu=theta_1, sigma=sigma, shape=J, observed=y)

# In the original example there is only 1 variable
# Here I broadcast it from 1 zero to a total of '# of total rvs' zeros, but not sure if it's correct
def logprior_fn(rvs):
    return [0] * len(rvs)

data_size = 8
rng_key = jax.random.PRNGKey(888)
rng_key, sample_key = jax.random.split(rng_key)
X_data = sample_fn(sample_key, data_size)

# Specify hyperparameters for SGLD
total_iter = 10_00
thinning_factor = 10
batch_size = 2
lr = 1e-3
temperature = 50.0

rvs = [rv.name for rv in m.value_vars]
logdensity_fn = get_jaxified_logp(m)
init_position_dict = m.initial_point()
init_position = [init_position_dict[rv] for rv in rvs]

grad_fn = gradients.grad_estimator(logprior_fn, logdensity_fn, data_size)
sgld = blackjax.sgld(grad_fn)

position = init_position
sgld_sample_list = jnp.array([])
pb = progress_bar(range(total_iter))
for iter_ in pb:
    rng_key, batch_key, sample_key = jax.random.split(rng_key, 3)
    data_batch = jax.random.shuffle(batch_key, X_data)[:batch_size, :]
    position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)
    if iter_ % thinning_factor == 0:
        sgld_sample_list = jnp.append(sgld_sample_list, position)
        pb.comment = f"| position: {position: .2f}"
And the error returned:
TypeError Traceback (most recent call last)
Cell In[86], line 8
6 rng_key, batch_key, sample_key = jax.random.split(rng_key, 3)
7 data_batch = jax.random.shuffle(batch_key, X_data)[:batch_size, :]
----> 8 position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)
9 if iter_ % thinning_factor == 0:
10 sgld_sample_list = jnp.append(sgld_sample_list, position)
[... skipping hidden 12 frame]
File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/sgld.py:122, in sgld.__new__.<locals>.step_fn(rng_key, state, minibatch, step_size, temperature)
115 def step_fn(
116 rng_key: PRNGKey,
117 state: ArrayLikeTree,
(...)
120 temperature: float = 1,
121 ) -> ArrayTree:
--> 122 return kernel(
123 rng_key, state, grad_estimator, minibatch, step_size, temperature
124 )
File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/sgld.py:40, in build_kernel.<locals>.kernel(rng_key, position, grad_estimator, minibatch, step_size, temperature)
32 def kernel(
33 rng_key: PRNGKey,
34 position: ArrayLikeTree,
(...)
38 temperature: float = 1.0,
39 ) -> ArrayTree:
---> 40 logdensity_grad = grad_estimator(position, minibatch)
41 new_position = integrator(
42 rng_key, position, logdensity_grad, step_size, temperature
43 )
45 return new_position
[... skipping hidden 10 frame]
File /opt/conda/envs/pm5/lib/python3.11/site-packages/blackjax/sgmcmc/gradients.py:68, in logdensity_estimator.<locals>.logdensity_estimator_fn(position, minibatch)
65 logprior = logprior_fn(position)
66 batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))
67 return logprior + data_size * jnp.mean(
---> 68 batch_loglikelihood(position, minibatch), axis=0
69 )
[... skipping hidden 2 frame]
File /opt/conda/envs/pm5/lib/python3.11/site-packages/jax/_src/linear_util.py:188, in WrappedFun.call_wrapped(self, *args, **kwargs)
185 gen = gen_static_args = out_store = None
187 try:
--> 188 ans = self.f(*args, **dict(self.params, **kwargs))
189 except:
190 # Some transformations yield from inside context managers, so we have to
191 # interrupt them before reraising the exception. Otherwise they will only
192 # get garbage-collected at some later time, running their cleanup tasks
193 # only after this exception is handled, which can corrupt the global
194 # state.
195 while stack:
TypeError: get_jaxified_logp.<locals>.logp_fn_wrap() takes 1 positional argument but 2 were given
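As far as I can tell, the failure reduces to the following. This is only a minimal sketch in plain JAX, not the actual Blackjax code: `one_arg_logp` is a hypothetical stand-in for the function returned by get_jaxified_logp, and the `jax.vmap` call copies the pattern from line 66 of gradients.py in the traceback.

```python
import jax
import jax.numpy as jnp


def one_arg_logp(x):
    # Hypothetical stand-in for logp_fn_wrap: it accepts a single argument.
    return -0.5 * jnp.sum(x**2)


# Same vmap pattern as in the traceback: the wrapped function is called
# with two arguments, (position, one row of the minibatch).
batched = jax.vmap(one_arg_logp, in_axes=(None, 0))

position = jnp.zeros(3)
minibatch = jnp.ones((2, 3))

try:
    batched(position, minibatch)
    raised = False
except TypeError as err:
    raised = True
    print(f"TypeError: {err}")
```

So the estimator seems to expect a function of (position, datum), while the jaxified PyMC logp only accepts the position.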
BTW, the first example mentioned above in the Blackjax docs has a typo in the line “position = jax.jit(sgld)(sample_key, position, data_batch, lr, temperature)”; it should be “position = jax.jit(sgld.step)(sample_key, position, data_batch, lr, temperature)”.
I guess my questions are:
1. What is causing this error? I tried my best to read the code, and my guess is that it has something to do with this block of code in jax.py:
def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
    model_logp = model.logp()
    if not negative_logp:
        model_logp = -model_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

    def logp_fn_wrap(x):
        return logp_fn(*x)[0]

    return logp_fn_wrap
which seems to accept only a single argument (the values of the model's value variables), while the likelihood function exemplified in the Blackjax docs takes both the state of the variables and the data as input. I am definitely not sure, though.
2. @junpenglao, please help take a look at this. I noticed Blackjax hasn’t been updated for a while. Are you guys still trying to make it compatible with the PPLs?
3. Generally speaking, what are the recommendations for implementing SGMCMC with a PyMC model, not limited to Blackjax? Any advice will be appreciated!
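To make question 1 more concrete, here is a minimal sketch of the interface I think the gradient estimator wants, written by hand in plain JAX rather than generated from the PyMC model. Several things here are my own assumptions: the position is a dict with keys `mu`, `log_tau` and `theta` (my naming), tau is sampled on the log scale, each data row is (school index, y_i, sigma_i), and `logdensity_estimator` simply copies lines 65-68 of gradients.py from the traceback. Note that the log-prior returns a single scalar, not a list with one entry per variable.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

y = jnp.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = jnp.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
data_size = y.shape[0]
# One row per observation: (school index, y_i, sigma_i).
X_data = jnp.stack([jnp.arange(data_size, dtype=y.dtype), y, sigma], axis=1)


def logprior_fn(position):
    # Scalar log-prior over all parameters.
    tau = jnp.exp(position["log_tau"])
    lp_mu = stats.norm.logpdf(position["mu"], 0.0, 10.0)
    # HalfCauchy(5) density on tau, plus the log-Jacobian of tau = exp(log_tau).
    lp_tau = jnp.log(2.0) + stats.cauchy.logpdf(tau, 0.0, 5.0) + position["log_tau"]
    lp_theta = jnp.sum(stats.norm.logpdf(position["theta"], 0.0, 1.0))
    return lp_mu + lp_tau + lp_theta


def loglikelihood_fn(position, datum):
    # Log-likelihood of ONE observation; datum = (school index, y_i, sigma_i).
    idx = datum[0].astype(jnp.int32)
    tau = jnp.exp(position["log_tau"])
    mean = position["mu"] + tau * position["theta"][idx]
    return stats.norm.logpdf(datum[1], mean, datum[2])


def logdensity_estimator(position, minibatch):
    # Copy of the estimator shown in the traceback (gradients.py, lines 65-68).
    logprior = logprior_fn(position)
    batch_loglikelihood = jax.vmap(loglikelihood_fn, in_axes=(None, 0))
    return logprior + data_size * jnp.mean(
        batch_loglikelihood(position, minibatch), axis=0
    )


grad_estimator = jax.grad(logdensity_estimator)

position = {
    "mu": jnp.array(0.0),
    "log_tau": jnp.array(0.0),
    "theta": jnp.zeros(data_size),
}
minibatch = X_data[:2]
grads = grad_estimator(position, minibatch)
print(grads)
```

If this is right, `logprior_fn` and `loglikelihood_fn` could presumably be passed to gradients.grad_estimator(logprior_fn, loglikelihood_fn, data_size) as in my code above. What I do not know is how to get this prior/likelihood split out of a PyMC model automatically.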