For loop does not work for CustomDist distribution

Hi,
I’m having trouble combining a for-loop with CustomDist; the same construction worked when I used pymc.MvNormal instead of CustomDist. I’d like to seek advice on how to tackle this issue.

I’m writing a custom likelihood for my analysis, which uses a linear mixed-models framework. In this analysis, a specific form of covariance (the inverse of the Fisher information matrix) is used for the prior on the regression coefficients. I first tested this covariance with pymc.MvNormal; it worked without issues and the posterior inference converged to the expected results. A snippet of the tested code is below.

import numpy as np
import pymc as pm
import pymc.sampling.jax as pmj
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng()

# test data
# X and Z are design matrices for fixed effects and random effects, respectively.
C = 64  # number of groups (random-effect levels)
N = 512
Z = rng.multinomial(1, [1/C]*C, size=N)
r, c = np.where(Z == 1)
sortorder = np.hstack([r[c == i] for i in range(C)])
Z = Z[sortorder, :]

X = np.stack([np.ones(N), np.zeros(N), rng.lognormal(mean=0, sigma=1.5, size=N) + 50], axis=1)
d = X.shape[1]

# constructing a specific form of covariance matrix for the prior on β (regression coefficients)
# β is a vector of length d
# (the following runs inside a pm.Model() context, omitted here)
m = np.zeros(d)
σ = pm.HalfStudentT('sgm', nu=2, sigma=1000)
s = pm.HalfStudentT('s', nu=2, sigma=1000)
a = σ**2
b = s**2
Λ = pytensor.shared(np.zeros([d, d]))
offset = 0
for i in range(C): # This for-loop!
    n = sum(Z[:, i] == 1)
    X_i = X[offset:(offset + n), :]
    iV_i = 1/a * np.eye(n) - b/(a*(a + b*n)) * np.ones([n, n]) 
    Λ += pt.linalg.matrix_dot(X_i.T, iV_i, X_i)
    offset += n
β = pm.MvNormal('be', mu=m, cov=Λ)
# ---omit unessential parts--- #
idata = pmj.sample_numpyro_nuts(draws=1000, chains=4, tune=9000, target_accept=0.8)
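
As a side note on the structure: the iV_i built in the loop is, as far as I can tell, the Sherman-Morrison inverse of V_i = a*I_n + b*ones((n, n)). A standalone NumPy sanity check of that identity (not part of the model):

# Standalone NumPy check: the iV_i used in the loop above is the
# Sherman-Morrison inverse of V_i = a*I_n + b*ones((n, n)).
import numpy as np

a_val, b_val, n = 2.5, 0.7, 6
V_i = a_val * np.eye(n) + b_val * np.ones((n, n))
iV_i = 1 / a_val * np.eye(n) - b_val / (a_val * (a_val + b_val * n)) * np.ones((n, n))
assert np.allclose(iV_i @ V_i, np.eye(n))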

However, when I switched to my custom distribution, the code threw an AttributeError: 'Scratchpad' object has no attribute 'ufunc'. The custom distribution itself worked without errors, and the posterior inference was fine, when I passed simpler covariance matrices as the parameter Λ, such as a random draw from scipy.stats.invwishart.rvs() or a pymc.LKJCholeskyCov prior.

β = pm.CustomDist('be', m, Λ, logp=logp_nonlocal,
                  random=random_nonlocal, support_point=support_point_nonlocal,
                  signature='()->()')
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/link/vm.py", line 407, in __call__
    thunk()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 747, in perform
    ufunc = node.tag.ufunc
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/graph/utils.py", line 286, in __getattribute__
    return super().__getattribute__(name)
AttributeError: 'Scratchpad' object has no attribute 'ufunc'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "/Users/Library/Application Support/JetBrains/PyCharmCE2023.2/scratches/scratch_LMM3.py", line 156, in run_analysis
    idata = pmj.sample_numpyro_nuts(draws=1000, chains=4, tune=9000, target_accept=0.8)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pymc/sampling/jax.py", line 551, in sample_jax_nuts
    initial_points = _get_batched_jittered_initial_points(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pymc/sampling/jax.py", line 218, in _get_batched_jittered_initial_points
    initial_points = _init_jitter(
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pymc/sampling/mcmc.py", line 1264, in _init_jitter
    point = ipfn(seed)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pymc/initial_point.py", line 169, in inner
    values = func(*args, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/link/vm.py", line 411, in __call__
    raise_with_op(self.fgraph, node, thunk)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/link/utils.py", line 523, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/link/vm.py", line 407, in __call__
    thunk()
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/graph/op.py", line 515, in rval
    r = p(n, [x[0] for x in i], o)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 747, in perform
    ufunc = node.tag.ufunc
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/graph/utils.py", line 286, in __getattribute__
    return super().__getattribute__(name)
AttributeError: 'Scratchpad' object has no attribute 'ufunc'
Apply node that caused the error: Add(<Matrix(float64, shape=(?, ?))>, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0, Dot22.0)
Toposort index: 227
Inputs types: [TensorType(float64, shape=(None, None)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3)), TensorType(float64, shape=(3, 3))]
Inputs shapes: [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
Inputs strides: [(24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8), (24, 8)]
Inputs values: ['not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown', 'not shown']
Outputs clients: [[True_div(Add.0, Add.0)]]
Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 2199, in <module>
    main()
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 2181, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 1493, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/pydevd.py", line 1500, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/Applications/PyCharm CE.app/Contents/plugins/python-ce/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/Library/Application Support/JetBrains/PyCharmCE2023.2/scratches/scratch_LMM3.py", line 165, in <module>
    idata = run_analysis(y, X, Z)
  File "/Users/Library/Application Support/JetBrains/PyCharmCE2023.2/scratches/scratch_LMM3.py", line 119, in run_analysis
    Λ += pt.linalg.matrix_dot(X_i.T, iV_i, X_i)
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

I guess I need to rewrite the for-loop section using scan, but so far I haven’t been able to figure out how to do it. The point where I’m stuck: when I try to pass a list of PyTensor matrices as sequences, scan (in basic.py) tries to convert the list to a NumPy array, which fails because the matrices have different shapes. Below is my attempt and the resulting error message. Any advice would be hugely appreciated!

offset = 0
X_list = [None]*C
n_list = np.zeros(C)
for i in range(C):
    n = sum(Z[:, i] == 1)
    X_list[i] = pt.as_tensor_variable(X[offset:(offset + n), :])
    n_list[i] = n
    offset += n
n_list = pt.as_tensor_variable(n_list)

def oneStep(X_i, n, Λ_tm, a, b):
    iV_i = 1/a * pt.eye(n) - b/(a*(a + b*n)) * pt.ones(shape=[n, n])
    Λ = Λ_tm + pt.linalg.matrix_dot(X_i.T, iV_i, X_i)
    return Λ

Λ_ini = pytensor.shared(np.zeros([d, d]))
Λ = pytensor.scan(fn=oneStep, outputs_info=Λ_ini,
                  sequences=[X_list, n_list],
                  non_sequences=[a, b])
Traceback (most recent call last):
  File "/Users/Library/Application Support/JetBrains/PyCharmCE2023.2/scratches/scratch_LMM3.py", line 141, in run_analysis
    Λ = pytensor.scan(fn=oneStep, outputs_info=Λ_ini,
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/scan/basic.py", line 595, in scan
    _seq_val = pt.as_tensor_variable(seq["input"])
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/tensor/__init__.py", line 50, in as_tensor_variable
    return _as_tensor_variable(x, name, ndim, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/functools.py", line 889, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/tensor/basic.py", line 177, in _as_tensor_Sequence
    return constant(x, name=name, ndim=ndim, dtype=dtype)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/tensor/basic.py", line 223, in constant
    x_ = ps.convert(x, dtype=dtype)
  File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/pytensor/scalar/basic.py", line 267, in convert
    x_ = np.asarray(x)
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (64,) + inhomogeneous part.
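
As far as I understand it, scan expects each entry in sequences to be a single tensor that it slices along the leading axis, so a Python list of matrices only works if they can be stacked into one array of uniform shape. A minimal illustration with hypothetical, equal-sized groups (this version does run):

# Minimal illustration (hypothetical equal group sizes): scan slices its
# sequences along the leading axis, so a stacked (C, n, d) array works where
# a list of differently shaped matrices does not.
import numpy as np
import pytensor
import pytensor.tensor as pt

X_groups = pt.as_tensor_variable(np.random.randn(4, 8, 3))  # C=4 groups, n=8, d=3

def one_step(X_i, acc):
    # accumulate X_i.T @ X_i over groups (all shapes are static here)
    return acc + pt.dot(X_i.T, X_i)

result, _ = pytensor.scan(fn=one_step, sequences=[X_groups], outputs_info=pt.zeros((3, 3)))
Λ_demo = result[-1]  # final accumulated (3, 3) matrix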

I’ve made some updates. I was able to rewrite the for-loop using scan, as shown below. However, I’ve now run into a performance issue. This code involves dynamic slicing, so I cannot use the JAX/NumPyro sampler (it throws NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.). When I tested this scan version of the covariance matrix with pm.MvNormal (as in the first snippet) using pymc.sample(), sampling took 60-120 minutes, whereas the for-loop version with NumPyro took only about 3 minutes.

offset = 0
n_list = np.zeros(C)
offset_list = np.zeros(C)
for i in range(C):
    n = sum(Z[:, i] == 1)
    n_list[i] = n
    offset_list[i] = offset
    offset += n

def oneStep(n, offset, Λ_tm, X, a, b):
    X_i = X[offset:(offset + n), :]
    iV_i = 1/a * pt.eye(n) - b/(a*(a + b*n)) * pt.ones(shape=[n, n])
    Λ = Λ_tm + pt.linalg.matrix_dot(X_i.T, iV_i, X_i)
    return Λ

Λ_ini = pt.as_tensor_variable(np.zeros([d, d]))
n_tv = pt.as_tensor_variable(n_list.astype('int32'))
offset_tv = pt.as_tensor_variable(offset_list.astype('int32'))
Λ, _ = pytensor.scan(fn=oneStep, outputs_info=Λ_ini,
                     sequences=[n_tv, offset_tv],
                     non_sequences=[pt.as_tensor_variable(X), a, b],
                     strict=True)
Λ = Λ[-1]
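
One idea I’ve been toying with but haven’t fully verified: since each group term has the Sherman-Morrison form above, the whole sum might collapse into plain matrix products, with no scan and no dynamic slicing at all. A rough sketch, assuming X, Z, a and b are the variables defined earlier:

# Unverified sketch: X_i.T @ iV_i @ X_i = (1/a) * X_i.T @ X_i - b/(a*(a + b*n_i)) * s_i @ s_i.T
# with s_i = X_i.T @ ones(n_i), so the sum over groups becomes static matrix algebra.
n_i = pt.as_tensor_variable(Z.sum(axis=0))   # (C,) group sizes
S = pt.as_tensor_variable(Z.T @ X)           # (C, d) per-group column sums s_i
XtX = pt.as_tensor_variable(X.T @ X)         # (d, d), equals the sum of X_i.T @ X_i
coef = b / (a * (a + b * n_i))               # (C,) symbolic coefficients
Λ_alt = XtX / a - pt.dot(coef * S.T, S)      # (d, d), candidate replacement for the scan result

If the algebra holds, this should also be JAX-compatible, since all shapes are static.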

So, let me summarize what I want to ask. I want to use my custom distribution via CustomDist, feeding it a covariance matrix whose construction involves a for-loop. I’d also like to use the NumPyro sampler instead of pymc.sample() because it’s much faster.

  1. The for-loop causes an AttributeError: 'Scratchpad' object has no attribute 'ufunc' when combined with CustomDist (see the first post). How should I resolve this?
  2. I tried rewriting the for-loop with scan (see this post). It runs without errors, but it’s very slow and doesn’t allow me to use JAX/NumPyro. Is there a more efficient way to write it?

Thank you in advance for your advice!

If you are using the latest PyMC, you can try calling pymc.model.transform.optimization.freeze_dims_and_data and then sampling with numpyro.

Otherwise you can try the nutpie sampler, which uses the numba backend.
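
Roughly, the two options look like this (just a sketch, assuming your model is built in a with pm.Model() as model: block):

import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

# Option 1: freeze dims and data to constants, then sample with the numpyro backend
frozen = freeze_dims_and_data(model)
with frozen:
    idata = pm.sample(nuts_sampler="numpyro")

# Option 2: the nutpie sampler, which compiles the model through numba
with model:
    idata = pm.sample(nuts_sampler="nutpie")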

Are you sure your implementations are the same? The NumPyro sampler can also be very fast when it is diverging, which is useless.

The error you were getting was probably because your unrolled loop (without scan) had too many nodes and PyTensor failed to compile the function. I’ve seen that elsewhere but haven’t been able to reproduce it.

If you could provide a reproducible snippet for the ufunc part, that would be useful for us to finally reproduce that error.

Thank you very much for your advice! Unfortunately, freeze_dims_and_data() with numpyro didn’t work, as it still produced the same error (NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.). However, the nutpie sampler worked and the speed was decent. :blush:

Yes, I’ve noticed that numpyro finishes very fast when diverging, but I don’t think that was the case here, as I see either no divergence warnings or only a few divergent samples (usually fewer than 10) even when it happens.

I have attached my snippet in case you want to take a look.
Non-local_LMM.py (6.8 KB)
Thank you again for your help!
