Dynamic shaping, "round" function, JAX, and a "few" more questions

Hi, I’m trying to reproduce this example (LKJ Cholesky Covariance Priors for Multivariate Normal Models, from the PyMC example gallery) for one of my projects. To test the code, I passed the dictionary returned by model.initial_point() as the initvals argument of pm.sample, and got the following error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Traceback (most recent call last):

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/spyder_kernels/py3compat.py:356 in compat_exec
    exec(code, globals, locals)

  File ~/Bureau/detector_work/IPC/pymc_ipc.py:304
    idata = pm.sample(draws=1000, tune=1000, chains=2, cores=2,

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:682 in sample
    initial_points, step = init_nuts(

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:1327 in init_nuts
    initial_points = _init_jitter(

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/sampling/mcmc.py:1204 in _init_jitter
    ipfns = make_initial_point_fns_per_chain(

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/initial_point.py:86 in make_initial_point_fns_per_chain
    make_initial_point_fn(

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/initial_point.py:134 in make_initial_point_fn
    sdict_overrides = convert_str_to_rv_dict(model, overrides or {})

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/initial_point.py:47 in convert_str_to_rv_dict
    initvals[rv] = model.rvs_to_transforms[rv].backward(initval, *rv.owner.inputs)

  File ~/InstProg/anaconda3/envs/myenv/lib/python3.10/site-packages/pymc/distributions/transforms.py:157 in backward
    return pt.set_subtensor(value[..., self.diag_idxs], pt.exp(value[..., self.diag_idxs]))

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

The error does not occur if I pass “chol” directly to initvals instead of the transformed parameter “chol_cholesky-cov-packed__”. Can you reproduce the error on your side, and could you explain where it comes from?
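For reference, the last frame of the traceback (the backward method of the packed Cholesky transform) maps the unconstrained packed vector back to the Cholesky factor by exponentiating only its diagonal entries. The plain-Python sketch below reproduces that index arithmetic. It assumes the lower triangle is packed row by row, and diag_indices and packed_backward are hypothetical helpers for illustration, not PyMC functions. One plausible reading of the error, not confirmed here: in PyMC these diagonal indices may be held as a symbolic PyTensor variable, and NumPy cannot index the plain array returned by model.initial_point() with a symbolic variable, which would raise this kind of IndexError. That would also be consistent with “chol” working, since presumably only the transformed name goes through the backward call at initial_point.py:47.

```python
import math


def diag_indices(n):
    """Positions of the diagonal entries in a row-major packed lower triangle.

    Row i (1-based) contributes i entries, so its diagonal entry sits at
    (1 + 2 + ... + i) - 1, i.e. cumsum(1..n) - 1.
    """
    idxs = []
    total = 0
    for i in range(1, n + 1):
        total += i
        idxs.append(total - 1)
    return idxs


def packed_backward(value, n):
    """Map an unconstrained packed vector to the packed Cholesky factor.

    Hypothetical pure-Python analogue of
    set_subtensor(value[..., diag_idxs], exp(value[..., diag_idxs])):
    diagonal entries are exponentiated (so they become positive),
    off-diagonal entries pass through unchanged.
    """
    if len(value) != n * (n + 1) // 2:
        raise ValueError("expected n * (n + 1) / 2 packed entries")
    out = list(value)
    for k in diag_indices(n):
        out[k] = math.exp(out[k])
    return out


# A 3x3 factor packs into 3 * 4 / 2 = 6 entries.
print(diag_indices(3))  # [0, 2, 5]
print(packed_backward([0.0, 0.5, 0.0, -1.0, 2.0, 0.0], 3))
# [1.0, 0.5, 1.0, -1.0, 2.0, 1.0]
```

Here the indices are ordinary integers, so the lookup succeeds; the hypothesis above is that the failing call differs precisely in that the indices are symbolic while the values are a NumPy array.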

Thank you!