For loop does not work for CustomDist distribution

I’ve made some updates: I managed to rewrite the for-loop part using scan, as shown below. However, I’ve now run into a performance issue. This code involves dynamic slicing, so I cannot use the JAX/NumPyro sampler (it throws NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.). When I tested this scan version, feeding the resulting covariance matrix to pm.MvNormal as in the first snippet but sampling with pymc.sample(), it took 60–120 minutes, whereas the for-loop version with NumPyro took only 3 minutes.

# Size and row offset of each group (rows of X are assumed to be
# sorted so that each group occupies a contiguous block).
offset = 0
n_list = np.zeros(C)
offset_list = np.zeros(C)
for i in range(C):
    n = np.sum(Z[:, i] == 1)  # number of observations in group i
    n_list[i] = n
    offset_list[i] = offset
    offset += n
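As an aside, the sizes and offsets can also be computed without the Python loop, assuming Z is a 0/1 membership matrix with one column per group (as the loop implies). A minimal sketch with hypothetical toy data:

```python
import numpy as np

# Hypothetical toy indicator matrix: 10 observations, C = 3 groups.
rng = np.random.default_rng(0)
Z = np.eye(3, dtype=int)[rng.integers(0, 3, size=10)]

# Per-group counts: number of 1s in each column of Z.
n_list = (Z == 1).sum(axis=0)

# Offsets: cumulative sum shifted right by one, starting at 0.
offset_list = np.concatenate(([0], np.cumsum(n_list)[:-1]))

print(n_list, offset_list)
```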

def oneStep(n, offset, Λ_tm, X, a, b):
    # Rows of group i (a dynamic-length slice -- this is what JAX rejects)
    X_i = X[offset:(offset + n), :]
    # Inverse of (a*I + b*J), J = all-ones, via the Sherman-Morrison formula
    iV_i = 1/a * pt.eye(n) - b/(a*(a + b*n)) * pt.ones(shape=[n, n])
    # Accumulate X_i.T @ iV_i @ X_i into the running matrix
    Λ = Λ_tm + pt.linalg.matrix_dot(X_i.T, iV_i, X_i)
    return Λ

Λ_ini = pt.as_tensor_variable(np.zeros([d, d]))
n_tv = pt.as_tensor_variable(n_list.astype('int32'))
offset_tv = pt.as_tensor_variable(offset_list.astype('int32'))
Λ, _ = pytensor.scan(fn=oneStep, outputs_info=Λ_ini,
                     sequences=[n_tv, offset_tv],
                     non_sequences=[pt.as_tensor_variable(X), a, b],
                     strict=True)
Λ = Λ[-1]  # scan returns all intermediate states; keep only the final one
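For what it’s worth, if Z really is a 0/1 indicator with one column per group (which this sketch assumes), the per-group term has rank-one structure and Λ may be computable with no scan or dynamic slicing at all. Since iV_i = (1/a)·I − c_i·J with c_i = b/(a(a + b·n_i)) and J the all-ones matrix, each term is X_iᵀ iV_i X_i = (1/a)·X_iᵀX_i − c_i·s_i s_iᵀ, where s_i = Xᵀ Z[:, i] is the column sum of group i. Summing over groups gives Λ = (1/a)·XᵀX − (XᵀZ) diag(c) (ZᵀX). A NumPy sketch (variable names hypothetical) checking this against the loop:

```python
import numpy as np

# Hypothetical toy data: 12 observations, d = 4 features, C = 3 groups,
# sorted so each group occupies a contiguous block of rows.
rng = np.random.default_rng(1)
d, C = 4, 3
Z = np.eye(C, dtype=int)[np.sort(rng.integers(0, C, size=12))]
X = rng.normal(size=(12, d))
a, b = 1.5, 0.7

n = Z.sum(axis=0)          # group sizes
c = b / (a * (a + b * n))  # per-group rank-one coefficient

# Vectorized form: Λ = (1/a) XᵀX − (XᵀZ) diag(c) (ZᵀX)
S = X.T @ Z                # d × C matrix of per-group column sums
Lam_vec = X.T @ X / a - S @ np.diag(c) @ S.T

# Reference: the original per-group accumulation
Lam_loop = np.zeros((d, d))
offset = 0
for i in range(C):
    X_i = X[offset:offset + n[i], :]
    iV_i = np.eye(n[i]) / a - c[i] * np.ones((n[i], n[i]))
    Lam_loop += X_i.T @ iV_i @ X_i
    offset += n[i]

print(np.allclose(Lam_vec, Lam_loop))  # → True
```

Because this closed form contains no dynamic-length slicing, the same expression written with pt operations should be compatible with the JAX/NumPyro sampler; whether it applies depends on Z being a plain indicator matrix.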

So, let me summarize my questions. I want to define a custom distribution with CustomDist, and constructing it requires feeding in a covariance matrix built with a for-loop. I’d also like to use the NumPyro sampler instead of pymc.sample(), because it’s much faster.

  1. The for-loop causes an AttributeError: 'Scratchpad' object has no attribute 'ufunc' when combined with CustomDist (see the first post). How should I resolve this?
  2. I tried rewriting the for-loop with scan (see this post). It runs without errors, but it’s very slow and doesn’t allow me to use JAX/NumPyro. Is there a more efficient way to write it?

Thank you in advance for your advice!