Error during sampling - aesara.tensor.concatenate

Hi
I’m new to PyMC but really enjoying using it - thank you very much to the PyMC team!
I’ve encountered a shape error that appears when running pm.sample() and also pm.sampling_jax.sample_numpyro_nuts() - but, oddly, not when the model is compiled, even when I check shape.eval() on all of the inputs and outputs.
The shape error comes from a concatenate operation in which I pad a 2d aesara tensor with a column of ones on the right. The most relevant part of the error message is below; the last line of the traceback is where I concatenate a column of ones onto the right of matrix S:

ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 1 and the array at index 1 has size 473

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
/home/richardianbarnes/Projects/ordered_weibull/fw_weibor.ipynb Cell 27 in <cell line: 1>()
      1 with mod8:
----> 2     idata_8 = pymc.sampling_jax.sample_numpyro_nuts(target_accept=.99)

File ~/miniconda3/envs/pymc4_env/lib/python3.9/site-packages/pymc/sampling/jax.py:569, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    566 tic1 = datetime.now()
    567 print("Compiling...", file=sys.stdout)
--> 569 init_params = _get_batched_jittered_initial_points(
    570     model=model,
...
    y = pm.Potential('y', logcens(ev_, t_, a, b, loc, w_))
  File "/tmp/ipykernel_12662/2482354477.py", line 15, in logcens
    Sj_ = at.concatenate([S[:, :], at.ones_like(S[:, :1])], axis=1)

Matrix S should have shape (473, 7), which I confirm with shape.eval() during compilation of the pm.Model. But (if I’m understanding the message correctly) it’s telling me that dimension 0 of S has size 1, whilst the column I’m concatenating has size 473 along dimension 0. I tried a couple of variations of the code on the last line, but in the end I deliberately wrote the arguments as S[:, :] and at.ones_like(S[:, :1]) so that, by construction, they must have the same size along the non-concatenated dimension - so this is really puzzling me!
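To illustrate what I expect, here is a minimal standalone version of the padding step (with the shapes I expect from my data); run on its own it should evaluate without complaint:

import numpy as np
from aesara import tensor as at

S = at.matrix('S')  # symbolic 2d tensor, shape fixed only at evaluation time
Sj_ = at.concatenate([S, at.ones_like(S[:, :1])], axis=1)  # pad a ones column on the right
print(Sj_.eval({S: np.ones((473, 7))}).shape)  # (473, 8)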

It’s also worth adding that an earlier version of the model worked fine and sampled really well. I have since made some changes to the main pm.Model, but none to the function logcens() called in pm.Potential(), which is where this error occurs.

I’m using PyMC v4 with numpyro as the JAX backend, in a Jupyter notebook running on WSL2. Just to confirm: the same error happens with both sampling_jax and the pm.sample() call.

I’d be really grateful for any help you can provide - I may have missed something really obvious but I just can’t see it at the moment!

The full code from the Jupyter notebook:

import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pymc as pm
import pymc.sampling_jax
from aesara import tensor as at
from aesara import shared
import aesara

%config InlineBackend.figure_format = 'retina'
rng = np.random.default_rng(12345)
az.style.use("arviz-darkgrid")

aesara.config.exception_verbosity = 'high'  # set via aesara.config; assigning a plain AESARA_FLAGS variable has no effect

# This contains time t of censoring for 
# the corresponding state (ordered as per "lcs" below)
fwdf = pd.read_csv('mod_data.csv') 
# This contains condition data with no corresponding age data - t needs to be a random variable
# our best guess is it's some time between 12 years and 80 years ago!
fw_nohist = pd.read_csv('quantities.csv')  
fw_nohist = fw_nohist.loc[(fw_nohist.t == -1) & (fw_nohist.lc_stage != 'nk'), :]
fwdf.f = fwdf.f + fwdf.s
Z_names = ['urban', 'pr']

loc_names = ['pr', 'pp']
b_names = ['f', 'c', 'pr']
fwgp = fwdf.groupby(list(set(Z_names + loc_names + b_names)) + ['t', 'lc_best', 'lc_worst'])['area_sqm'].sum().reset_index()
fwgp_nohist = fw_nohist.groupby(list(set(Z_names + loc_names + b_names)) + ['lc_stage'])['area_sqm'].sum().reset_index()
fwgp_nohist = pd.concat([fwgp_nohist] * 10)
fwgp_nohist['area_sqm'] /= 10.0
lcs = [
    "1 x100",
    "3 x_05",
    "3 x05_",
    "4_ x_05",
    "4_ x_05 3_ x60_",
    "4_ x05_",
    "4_ x05_ 3_ x60_",
    "4_ x40_"
]
lc_enum = np.arange(len(lcs))[np.newaxis, :]
lc_best_idx = np.array([lcs.index(lc) for lc in list(fwgp.lc_best)])[:, np.newaxis]
lc_worst_idx = np.array([lcs.index(lc) for lc in list(fwgp.lc_worst)])[:, np.newaxis]

ev = np.where((lc_enum >= lc_best_idx) & (lc_enum <= lc_worst_idx), 1.0, 0)
lc_nohist_idx = np.array([lcs.index(lc) for lc in list(fwgp_nohist.lc_stage)])[:, np.newaxis]
ev_nohist = np.where(lc_enum == lc_nohist_idx, 1.0, 0)
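# ev and ev_nohist are row-wise indicator matrices with one column per stage in lcs (J = 8)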
lateswitch = '4_ x_05' # This is the lc stage up to which we should provide the lc offset (because these stages are acceptable following treatment)
latemask = np.hstack([np.zeros(lcs.index(lateswitch) - 1), np.ones(len(lcs) - lcs.index(lateswitch))])
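# latemask has len(lcs) - 1 (= 7) entries: zeros before the lateswitch transition, ones from it onwards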

X_z = fwgp[Z_names].to_numpy(copy=True)
X_loc = fwgp[loc_names].to_numpy(copy=True)
X_b = fwgp[b_names].to_numpy(copy=True)

X_z_nohist = fwgp_nohist[Z_names].to_numpy(copy=True)
X_loc_nohist = fwgp_nohist[loc_names].to_numpy(copy=True)
X_b_nohist = fwgp_nohist[b_names].to_numpy(copy=True)

t = fwgp['t'].to_numpy()
w = fwgp['area_sqm'].to_numpy() / 500.0 

w_nohist = fwgp_nohist['area_sqm'].to_numpy(copy=True) / 500.0

def logcens(ev, t, a, b, loc, w):
    """
    ev is shape N x J; t, b and w are shape N x 1;
    a and loc are shape N x (J - 1)
    """
    print('logcens: t: ' + str(t.shape.eval()))
    print('logcens: loc: ' + str(loc.shape.eval()))
    print('logcens: a: ' + str(a.shape.eval()))
    print('logcens: b: ' + str(b.shape.eval()))
    print('logcens: w: ' + str(w.shape.eval()))
    S = (at.exp(-(((t + loc) / a) ** b)))  #.dimshuffle(0, 'x')
    print('S: ' + str(S.shape.eval()))
    S_j = at.concatenate([at.zeros_like(S[:, :1]), S], axis=1)
    print('S_j: ' + str(S_j.shape.eval()))
    Sj_ = at.concatenate([S[:, :], at.ones_like(S[:, :1])], axis=1)
    print('Sj_: ' + str(Sj_.shape.eval()))
    P = at.sum(ev * (Sj_ - S_j), axis=1).dimshuffle(0, 'x')
    return w * at.log(P)
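
# Throwaway sanity check (not part of the model): with concrete shared inputs
# of the same shapes as the real data, logcens should evaluate cleanly.
# N_chk, J_chk and chk are names used only for this check.
N_chk, J_chk = 473, 8
chk = logcens(
    shared(np.ones((N_chk, J_chk))),      # ev
    shared(np.ones((N_chk, 1))),          # t
    shared(np.ones((N_chk, J_chk - 1))),  # a
    shared(np.ones((N_chk, 1))),          # b
    shared(np.ones((N_chk, J_chk - 1))),  # loc
    shared(np.ones((N_chk, 1))),          # w
)
print('chk: ' + str(chk.shape.eval()))  # expect (473, 1)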


with pm.Model(coords={"lcs": lcs[2:], "b_names": b_names, "Z_names": Z_names, "loc_names": loc_names}) as mod8:
    a0_lc = pm.Normal("a0_lc", mu=np.log(10.0), sigma=1.0, shape=1)  # Can be negative
    #a_lc = pm.HalfNormal("a_lc", dims="lcs") # Strictly positive so cumulative is non-decreasing
    a_lc = pm.Gamma("a_lc", mu=1.0, sigma=1.0, dims='lcs')
    print('a_lc: ' + str(a_lc.shape.eval()))
    latemask_ = shared(latemask[np.newaxis, :])
    z_coef = pm.Normal('z_coef', 0, 1.0, dims='Z_names') #
    print('z_coef: ' + str(z_coef.shape.eval()))
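    # cumulative stage scales (1, J-1) plus covariate effects (N, 1) broadcast to a: (N, J-1)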
    a = pm.Deterministic(
        "a", 
        at.exp(
            at.cumsum(
                at.concatenate([a0_lc, a_lc]), 
                axis=0
                ).dimshuffle('x', 0) + 
                at.dot(X_z, z_coef).dimshuffle(0, 'x')
            )
        )
    print('a: ' + str(a.shape.eval()))
    loc_coef = pm.Normal('loc_coef', mu=np.log(10.0), sigma=1.0, dims="loc_names")
    print('loc_coef: ' + str(loc_coef.shape.eval()))
    loc_early = pm.Normal('loc_early', mu=at.log(10.0), sigma=1.0)
    loc_late_dec = pm.HalfNormal('loc_late_dec') # Strictly positive so late is definitely less than early
    loc_late = pm.Deterministic('loc_late', loc_early - loc_late_dec)
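    # early/late offsets (1, J-1) plus covariate effects (N, 1) broadcast to loc: (N, J-1)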
    loc = pm.Deterministic(
        'loc', 
        at.exp(
            (
                (1 - latemask_) * loc_early + latemask_ * loc_late
            ) + 
            at.dot(X_loc, loc_coef).dimshuffle(0, 'x'))
        )
    print('loc: ' + str(loc.shape.eval()))
    b0 = pm.Normal('b0')
    b_coef = pm.Normal('b_coef', dims='b_names')
    b = pm.Deterministic('b', at.exp(b0 + at.dot(X_b, b_coef)).dimshuffle(0, 'x'))  # at.exp rather than np.exp, to keep the graph symbolic
    ev_ = shared(ev)
    t_ = shared(t[:, np.newaxis])
    w_ = shared(w[:, np.newaxis])
    tmin = 12
    tmax = 80
    v = 2.0
    mu_p = pm.Uniform('mu_p')
    p_nohist = pm.Beta('p_nohist', mu_p * v, (1 - mu_p) * v, shape=w_nohist.shape)  
    print('p_nohist: ' + str(p_nohist.shape.eval()))
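    # rescale the Beta draws onto [tmin, tmax] years and reshape to (N_nohist, 1)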
    t_nohist = pm.Deterministic('t_nohist', (tmin + p_nohist * (tmax - tmin)).dimshuffle(0, 'x'))
    print('t_nohist: ' + str(t_nohist.shape.eval()))
    a_nohist = pm.Deterministic(
        "a_nohist", 
        at.exp(
            at.cumsum(
                at.concatenate([a0_lc, a_lc]), 
                axis=0
                ).dimshuffle('x', 0) + 
                at.dot(X_z_nohist, z_coef).dimshuffle(0, 'x')
            )
        )
    loc_nohist = pm.Deterministic(
        'loc_nohist', 
        at.exp(
            (
                (1 - latemask_) * loc_early + latemask_ * loc_late
            ) + 
            at.dot(X_loc_nohist, loc_coef).dimshuffle(0, 'x'))
        )
    b_nohist = pm.Deterministic('b_nohist', at.exp(b0 + at.dot(X_b_nohist, b_coef)).dimshuffle(0, 'x'))
    ev_nohist_ = shared(ev_nohist)
    
    w_nohist_ = shared(w_nohist[:, np.newaxis])
    y = pm.Potential('y', logcens(ev_, t_, a, b, loc, w_))
    y_nohist = pm.Potential('y_nohist', logcens(ev_nohist_, t_nohist, a_nohist, b_nohist, loc_nohist, w_nohist_))
    print('y: ' + str(y.shape.eval()))
    print('y_nohist: ' + str(y_nohist.shape.eval()))
with mod8:
    idata_mod8 = pymc.sampling_jax.sample_numpyro_nuts(target_accept=.99)

Data files:
[fw_model_data_grouped.csv|attachment](upload://w6NUBhxDtCbjlGITP8goqnKuP8l.csv) (30.7 KB)
[fw_quantities.csv|attachment](upload://1dphYzwwV9lETmJ6qmAfpOUALNl.csv) (3.8 KB)