I am trying to use float32 for sampling on the GPU, but I get the following error, which shows a float32/float64 dtype mismatch inside the sampler:
File ".../lib/python3.11/site-packages/pymc/sampling/jax.py", line 611, in sample_jax_nuts
raw_mcmc_samples, sample_stats, library = sampler_fn(
^^^^^^^^^^^
File ".../lib/python3.11/site-packages/pymc/sampling/jax.py", line 384, in _sample_blackjax_nuts
raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/pymc/sampling/jax.py", line 260, in _blackjax_inference_loop
(last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 340, in run
last_state, info = scan_fn(
^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/progress_bar.py", line 105, in scan_wrap
(last_state, _), output = lax.scan(func, carry, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/progress_bar.py", line 90, in wrapper_progress_bar
subcarry, y = func(subcarry, x)
^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/adaptation/window_adaptation.py", line 310, in one_step
new_state, info = mcmc_kernel(
^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/nuts.py", line 141, in kernel
proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/nuts.py", line 305, in propose
expansion_state, info = expand(
^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/trajectory.py", line 611, in expand
expansion_state, (is_diverging, is_turning) = jax.lax.while_loop(
^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/trajectory.py", line 554, in expand_once
) = trajectory_integrator(
^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/trajectory.py", line 262, in integrate
new_integration_state, (is_diverging, has_terminated) = jax.lax.while_loop(
^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/trajectory.py", line 222, in add_one_state
(new_trajectory, sampled_proposal) = jax.lax.cond(
^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/trajectory.py", line 230, in <lambda>
sample_proposal(proposal_key, proposal, new_proposal),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File ".../lib/python3.11/site-packages/blackjax/mcmc/proposal.py", line 129, in progressive_uniform_sampling
return jax.lax.cond(
^^^^^^^^^^^^^
TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6,99]) vs. ShapedArray(float32[6,99])'], momentum=['ShapedArray(float64[6])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[6,99])'], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6,99]) vs. ShapedArray(float32[6,99])']), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
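The failure happens inside `jax.lax.cond` in blackjax's proposal sampling: both branches must return pytrees with identical types, but here the two branch outputs disagree on the dtype of `position` and `logdensity_grad` (float64 vs. float32), while `momentum` and `energy` are float64 in both. As a standalone illustration (a hedged sketch in plain JAX, not the actual PyMC/blackjax code path), the same TypeError can be triggered like this:

import jax
import jax.numpy as jnp

# Sketch: with 64-bit mode enabled, jax.lax.cond raises the same TypeError
# when its two branches return arrays of different dtypes.
jax.config.update("jax_enable_x64", True)

def pick(flag):
    return jax.lax.cond(
        flag,
        lambda: jnp.zeros(3, dtype=jnp.float32),  # one branch stays float32 ...
        lambda: jnp.zeros(3, dtype=jnp.float64),  # ... the other is float64
    )

pick(True)  # TypeError: true_fun and false_fun output must have identical types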
A minimal reproducing example for the error:
import time

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling.jax
import pytensor

pytensor.config.floatX = "float32"
pytensor.config.warn_float64 = "warn"

BETA = [0.0008, 0.0011, -0.0012, 0.0012, -0.0010, -0.0040]
TAU = [0.0078, 0.0099, 0.0079, 0.0062, 0.0098, 0.0190]
SIGMA = 0.07

def generate_data(num_groups: int = 5, num_samples_per_group: int = 20) -> pd.DataFrame:
    """Generate synthetic grouped regression data."""
    rng = np.random.default_rng(seed=42)
    groups = list(range(num_groups))
    ids = [f"id_{i+1}" for i in range(num_samples_per_group)]
    beta = BETA
    tau = TAU
    sigma = SIGMA
    num_predictors = len(beta)
    X = []
    y = []
    for _ in range(num_groups):
        # Each group gets its own coefficient vector drawn around BETA.
        Xg = rng.normal(loc=0.0, scale=1.0, size=(num_samples_per_group, num_predictors))
        beta_g = rng.normal(loc=beta, scale=tau)
        yg = Xg.dot(beta_g) + sigma * rng.normal(size=num_samples_per_group)
        X.append(Xg)
        y.append(yg)
    X_ = pd.DataFrame(np.concatenate(X, axis=0), columns=[f"x_{i+1}" for i in range(num_predictors)])
    y_ = pd.DataFrame(np.concatenate(y, axis=0), columns=["y"])
    frame = pd.concat([X_, y_], axis=1)
    frame["group"] = np.repeat(groups, num_samples_per_group)
    frame["id"] = ids * num_groups
    return frame.set_index(["group", "id"])

def make_model(frame: pd.DataFrame) -> pm.Model:
    """Build the hierarchical linear model."""
    unique_groups = frame.index.unique(level="group")
    predictors = [col for col in frame.columns if col.startswith("x")]
    group_idx = frame.index.get_level_values("group")
    coords = {"group": unique_groups, "predictor": predictors}
    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("X", frame[predictors])
        g = pm.Data("g", group_idx)
        y = pm.Data("y", frame["y"])
        # Panel/population level
        beta = pm.Normal("beta", sigma=0.01, dims="predictor")
        tau = pm.Gamma("tau", mu=0.01, sigma=0.005, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=0.15)
        # Group level
        # ZeroSumNormal enforces a sum of zero over the last dimension (axis)
        # of the generated array. We want the sum over groups to be zero, so
        # we put the "group" dimension last and then transpose to make it the
        # leading dimension.
        epsilon_beta_g = pm.ZeroSumNormal("epsilon_beta_g", dims=["predictor", "group"]).T
        beta_g = beta + tau * epsilon_beta_g
        # Linear model; is this selection correct?
        mu = (beta_g[g] * x).sum(axis=-1)
        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)
    return model

def _add_truth(summary: pd.DataFrame) -> None:
    """Insert the true parameter values as the first column of the summary."""
    ind = summary.index.str.startswith("beta")
    summary.loc[ind, "true"] = BETA
    ind = summary.index.str.startswith("tau")
    summary.loc[ind, "true"] = TAU
    summary.loc["sigma", "true"] = SIGMA
    tmp = summary.pop("true")
    summary.insert(0, "true", tmp)

if __name__ == "__main__":
    frame = generate_data(num_groups=100, num_samples_per_group=270)
    model = make_model(frame)
    with model:
        t0 = time.time()
        trace = pymc.sampling.jax.sample_blackjax_nuts(
            tune=300,
            draws=1000,
            chains=6,
            chain_method="vectorized",
            postprocessing_backend="cpu",
        )
        t = time.time() - t0
    print(f"Time for sampling is {t:.3f}s.")
    summary = az.summary(
        trace, var_names=["beta", "tau", "sigma"], round_to=4
    ).loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
    _add_truth(summary)
    print(summary)
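One mitigation I can think of (an untested sketch; whether it helps likely depends on the pymc/blackjax versions in use, since the JAX backend may switch JAX into 64-bit mode itself) is to force JAX back into 32-bit mode before building and sampling, so that the arrays blackjax creates internally match the float32 graph:

import jax

# Untested sketch: keep JAX in 32-bit mode so internally created arrays
# (momentum, energy, ...) stay float32 like the pytensor graph. This must
# run before any JAX computation happens.
jax.config.update("jax_enable_x64", False)

Has anyone gotten float32 sampling to work with sample_blackjax_nuts on the GPU?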