Using `float32` for `pymc.sampling.jax.sample_blackjax_nuts`

I am trying to sample in `float32` on the GPU, but I get the following error, which suggests that `float32` cannot be used for sampling:

File ".../lib/python3.11/site-packages/pymc/sampling/", line 611, in sample_jax_nuts
    raw_mcmc_samples, sample_stats, library = sampler_fn(
  File ".../lib/python3.11/site-packages/pymc/sampling/", line 384, in _sample_blackjax_nuts
    raw_mcmc_samples, sample_stats = map_fn(get_posterior_samples)(keys, initial_points)
  File ".../lib/python3.11/site-packages/pymc/sampling/", line 260, in _blackjax_inference_loop
    (last_state, tuned_params), _ =, init_position, num_steps=tune)
  File ".../lib/python3.11/site-packages/blackjax/adaptation/", line 340, in run
    last_state, info = scan_fn(
  File ".../lib/python3.11/site-packages/blackjax/", line 105, in scan_wrap
    (last_state, _), output = lax.scan(func, carry, *args, **kwargs)
  File ".../lib/python3.11/site-packages/blackjax/", line 90, in wrapper_progress_bar
    subcarry, y = func(subcarry, x)
  File ".../lib/python3.11/site-packages/blackjax/adaptation/", line 310, in one_step
    new_state, info = mcmc_kernel(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 141, in kernel
    proposal, info = proposal_generator(key_integrator, integrator_state, step_size)
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 305, in propose
    expansion_state, info = expand(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 611, in expand
    expansion_state, (is_diverging, is_turning) = jax.lax.while_loop(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 554, in expand_once
    ) = trajectory_integrator(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 262, in integrate
    new_integration_state, (is_diverging, has_terminated) = jax.lax.while_loop(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 222, in add_one_state
    (new_trajectory, sampled_proposal) = jax.lax.cond(
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 230, in <lambda>
    sample_proposal(proposal_key, proposal, new_proposal),
  File ".../lib/python3.11/site-packages/blackjax/mcmc/", line 129, in progressive_uniform_sampling
    return jax.lax.cond(
TypeError: true_fun and false_fun output must have identical types, got
Proposal(state=IntegratorState(position=['DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6,99]) vs. ShapedArray(float32[6,99])'], momentum=['ShapedArray(float64[6])', 'ShapedArray(float64[6])', 'ShapedArray(float64[])', 'ShapedArray(float64[6,99])'], logdensity='ShapedArray(float64[])', logdensity_grad=['DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[6]) vs. ShapedArray(float32[6])', 'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])', 'DIFFERENT ShapedArray(float64[6,99]) vs. ShapedArray(float32[6,99])']), energy='ShapedArray(float64[])', weight='ShapedArray(float64[])', sum_log_p_accept='ShapedArray(float64[])').
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

A minimal example that reproduces the error:

import os
import time
from typing import Literal

import arviz
import arviz as az
import numpy as np
import pandas as pd
import pytensor
import pymc as pm
import pymc.sampling.jax


BETA = [0.0008, 0.0011, -0.0012, 0.0012, -0.0010, -0.0040]
TAU = [0.0078, 0.0099, 0.0079, 0.0062, 0.0098, 0.0190]
SIGMA = 0.07


def generate_data(
    num_groups: int = 5, num_samples_per_group: int = 20, *, seed: int = 42
) -> pd.DataFrame:
    """Simulate grouped linear-regression data from the true parameters.

    For every group a coefficient vector ``beta_g ~ Normal(BETA, TAU)`` is drawn,
    then ``num_samples_per_group`` rows are generated as
    ``y = X @ beta_g + SIGMA * eps`` with ``X`` and ``eps`` standard normal.

    Args:
        num_groups: Number of groups to simulate, labelled ``0 .. num_groups - 1``.
        num_samples_per_group: Rows per group, with ids ``id_1 .. id_<n>``.
        seed: Seed for ``np.random.default_rng``; the default (42) matches the
            previously hard-coded seed.

    Returns:
        DataFrame with predictor columns ``x_1 .. x_P`` (``P = len(BETA)``) and a
        ``y`` column, indexed by a ``(group, id)`` MultiIndex.
    """
    rng = np.random.default_rng(seed=seed)
    num_predictors = len(BETA)

    x_blocks = []
    y_blocks = []
    for _ in range(num_groups):
        Xg = rng.normal(loc=0.0, scale=1.0, size=(num_samples_per_group, num_predictors))
        beta_g = rng.normal(loc=BETA, scale=TAU)
        # Linear predictor for this group plus Gaussian observation noise.
        yg = Xg @ beta_g + SIGMA * rng.normal(size=num_samples_per_group)
        x_blocks.append(Xg)
        y_blocks.append(yg)

    # np.concatenate rejects an empty list, so num_groups == 0 is handled explicitly.
    if x_blocks:
        X_all = np.concatenate(x_blocks, axis=0)
        y_all = np.concatenate(y_blocks, axis=0)
    else:
        X_all = np.empty((0, num_predictors))
        y_all = np.empty(0)

    frame = pd.DataFrame(X_all, columns=[f"x_{i + 1}" for i in range(num_predictors)])
    frame["y"] = y_all
    frame["group"] = np.repeat(np.arange(num_groups), num_samples_per_group)
    frame["id"] = [f"id_{i + 1}" for i in range(num_samples_per_group)] * num_groups
    return frame.set_index(["group", "id"])

def make_model(frame: pd.DataFrame) -> pm.Model:
    """Build the hierarchical linear regression model.

    Each group ``g`` gets coefficients ``beta_g = beta + tau * epsilon_g``, where
    ``epsilon`` is a ZeroSumNormal across groups, so the group offsets sum to
    zero for every predictor. The observed ``y`` is Normal around
    ``sum_p beta_g[p] * x[p]`` for the row's group.

    Args:
        frame: Data indexed by a ``(group, id)`` MultiIndex, with predictor
            columns whose names start with ``"x"`` and a ``"y"`` column.

    Returns:
        The (unsampled) PyMC model.
    """
    unique_groups = frame.index.unique(level="group")
    predictors = [col for col in frame.columns if col.startswith("x")]
    # Convert each row's group *label* into its position in ``unique_groups``
    # (the order used for the "group" coord). ``beta_g[g]`` indexes by position,
    # so using raw labels is only correct when they happen to be 0..N-1 in
    # first-appearance order; for such labels this mapping is the identity.
    group_idx = unique_groups.get_indexer(frame.index.get_level_values("group"))

    coords = {"group": unique_groups, "predictor": predictors}
    with pm.Model(coords=coords) as model:
        # Data
        x = pm.Data("X", frame[predictors])
        g = pm.Data("g", group_idx)
        y = pm.Data("y", frame["y"])

        # Panel/population level
        beta = pm.Normal("beta", sigma=0.01, dims="predictor")

        tau = pm.Gamma("tau", mu=0.01, sigma=0.005, dims="predictor")
        sigma = pm.HalfNormal("sigma", sigma=0.15)

        # Group level
        # ZeroSumNormal enforces a sum of zero over the last dimension (axis) of
        # the generated array. We want zero when summing over all groups, so the
        # "group" dimension goes last; transposing then gives (group, predictor).
        epsilon_beta_g = pm.ZeroSumNormal("epsilon_beta_g", dims=["predictor", "group"]).T
        beta_g = beta + tau * epsilon_beta_g

        # Linear model: beta_g[g] picks each row's group coefficients, giving
        # shape (n_rows, predictor); the row-wise sum is the dot product with x.
        mu = (beta_g[g] * x).sum(axis=-1)

        # Likelihood
        pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y, shape=mu.shape)

    return model

def _add_truth(summary):
    """Prepend a ``true`` column holding the simulation's ground-truth values.

    Rows whose label starts with ``"beta"`` receive ``BETA`` (in order), rows
    starting with ``"tau"`` receive ``TAU``, and the ``"sigma"`` row receives
    ``SIGMA``. ``summary`` is modified in place; nothing is returned.
    """
    for prefix, truth in (("beta", BETA), ("tau", TAU)):
        rows = summary.index.str.startswith(prefix)
        summary.loc[rows, "true"] = truth

    summary.loc["sigma", "true"] = SIGMA

    # Move the freshly added column to the front.
    summary.insert(0, "true", summary.pop("true"))

if __name__ == "__main__":
    # NOTE: to sample in float32, ``pytensor.config.floatX`` must be set to
    # "float32" *before* ``pymc`` / ``pymc.sampling.jax`` are imported; changing
    # it here, after the module-level imports, is too late.
    frame = generate_data(num_groups=100, num_samples_per_group=270)

    model = make_model(frame)

    with model:
        t0 = time.perf_counter()
        trace = pymc.sampling.jax.sample_blackjax_nuts()
        t = time.perf_counter() - t0
    print(f"Time for sampling is {t=:.3f}s.")

    # .copy() so that _add_truth mutates an independent frame, not a slice.
    summary = (
        az.summary(trace, var_names=["beta", "tau", "sigma"], round_to=4)
        .loc[:, ["mean", "hdi_3%", "hdi_97%", "ess_bulk", "ess_tail", "r_hat"]]
        .copy()
    )
    _add_truth(summary)
    print(summary)

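I found the solution: any changes to `pytensor.config` (e.g. `pytensor.config.floatX = "float32"`) must be made *before* `pymc` (including `pymc.sampling.jax`) is imported. For example, set it right after `import pytensor`, or export `PYTENSOR_FLAGS="floatX=float32"` before starting Python.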