My main question concerns how memory is allocated when performing NUTS sampling on a model with a user-defined `logp()`.

This model (which I admit is quite onerous to compute) can be successfully fit using the Variational Inference methods, but it chokes when sampled with NUTS. The model is a mixture of several different distributions, one of which is user-defined and requires numerical integration for normalization (the other three are built-in PyMC distributions: `ExGaussian()`, `Normal()`, and `Uniform()`).

The data being fit are an array of large integer values (hence the `int64` data type in the error below).

Sampling with NUTS:

```
with model:
    res = pm.sample()
```

produces the following memory error:

```
MemoryError Traceback (most recent call last)
MemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64
```
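As a quick sanity check on the number in that message (not part of the model, just arithmetic): an `int64` array with that many elements really does require about 1.47 TiB.

```python
n_elements = 201_417_847_152       # array shape from the error message
bytes_needed = n_elements * 8      # 8 bytes per int64 element
print(bytes_needed / 2**40)        # ~1.47 TiB, matching the error
```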

The full traceback is as follows:

```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mL, sL0, tI0, mT0, sT0, w, mL_a, sL_a0, tI_a0, w_a]
0.04% [3/8000 00:00<29:53 Sampling 4 chains, 0 divergences]
/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:529: UserWarning: <class 'numpy.core._exceptions._ArrayMemoryError'> error does not allow us to add an extra error message
warnings.warn(
/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:529: UserWarning: <class 'numpy.core._exceptions._ArrayMemoryError'> error does not allow us to add an extra error message
warnings.warn(
---------------------------------------------------------------------------
RemoteTraceback Traceback (most recent call last)
RemoteTraceback:
"""
Traceback (most recent call last):
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 129, in run
self._start_loop()
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 182, in _start_loop
point, stats = self._compute_point()
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 207, in _compute_point
point, stats = self._step_method.step(self._point)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/arraystep.py", line 286, in step
return super().step(point)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/arraystep.py", line 208, in step
step_res = self.astep(q)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/base_hmc.py", line 186, in astep
hmc_step = self._hamiltonian_step(start, p0.data, step_size)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 194, in _hamiltonian_step
divergence_info, turning = tree.extend(direction)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 295, in extend
tree, diverging, turning = self._build_subtree(
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 373, in _build_subtree
return self._single_step(left, epsilon)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 333, in _single_step
right = self.integrator.step(epsilon, left)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/integration.py", line 73, in step
return self._step(epsilon, state)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/integration.py", line 109, in _step
logp = self._logp_dlogp_func(q_new, grad_out=q_new_grad)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py", line 410, in __call__
cost, *grads = self._aesara_function(*grad_vars)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py", line 984, in __call__
raise_with_op(
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py", line 534, in raise_with_op
raise exc_value.with_traceback(exc_trace)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py", line 971, in __call__
self.vm()
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
r = p(n, [x[0] for x in i], o)
File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py", line 2923, in perform
out[0] = np.arange(start, stop, step, dtype=self.dtype)
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64
"""
The above exception was the direct cause of the following exception:
MemoryError Traceback (most recent call last)
MemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
Input In [7], in <cell line: 2>()
1 import pymc as pm
2 with model:
----> 3 res = pm.sample()
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:609, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
607 _print_step_hierarchy(step)
608 try:
--> 609 mtrace = _mp_sample(**sample_args, **parallel_args)
610 except pickle.PickleError:
611 _log.warning("Could not pickle model, sampling singlethreaded.")
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:1521, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
1519 try:
1520 with sampler:
-> 1521 for draw in sampler:
1522 strace = traces[draw.chain]
1523 if strace.supports_sampler_stats and draw.stats is not None:
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py:463, in ParallelSampler.__iter__(self)
460 self._progress.update(self._total_draws)
462 while self._active:
--> 463 draw = ProcessAdapter.recv_draw(self._active)
464 proc, is_last, draw, tuning, stats, warns = draw
465 self._total_draws += 1
File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py:353, in ProcessAdapter.recv_draw(processes, timeout)
351 else:
352 error = RuntimeError("Chain %s failed." % proc.chain)
--> 353 raise error from old_error
354 elif msg[0] == "writing_done":
355 proc._readable = True
RuntimeError: Chain 3 failed.
```

However, when using ADVI, fitting does work (albeit slowly):

```
vi = pm.ADVI(model=model)
approx = vi.fit(10000)
```

Output (intentionally interrupted mid fitting):

```
30.88% [3088/10000 06:25<14:22 Average Loss = 1.0235e+05]
Interrupted at 3,088 [30%]: Average Loss = 1.0395e+05
```

And the final fit result using VI is generally what I would expect (the model fits the data and the resulting best fit parameter values make sense). However, my suspicion is the posterior distributions are actually a bit more complicated (e.g. possibly bimodal) which is why I’m trying to figure out how to get the NUTS sampler to work.

Though the full mixture is too much to report here, the following is the `logp` function (and the functions on which it depends) for the custom component of the full model:

```
import aesara.tensor as tt
import pymc as pm

# CDF/logCDF components
def _emg_cdf(x, mu, sigma, tau):
    rv = pm.ExGaussian.dist(mu=mu, sigma=sigma, nu=tau)
    lcdf = pm.logcdf(rv, x)
    return tt.exp(lcdf)

def _log_emg_cdf(x, mu, sigma, tau):
    rv = pm.ExGaussian.dist(mu=mu, sigma=sigma, nu=tau)
    lcdf = pm.logcdf(rv, x)
    return lcdf

def _norm_sf(x, mu, sigma):
    arg = (x - mu) / (sigma * tt.sqrt(2.0))
    return 0.5 * tt.erfc(arg)

def _log_norm_sf(x, mu, sigma):
    return pm.distributions.dist_math.normal_lccdf(mu, sigma, x)

# Custom log pdf
def e_logp(x, mL, sL, tI, mT, sT):
    # Compute norm factor by numerically integrating over the entire distribution
    _n = 10  # number of stdevs for numerical normalization
    _min = tt.floor(tt.min([mL - _n * sL, mT - _n * sT]))
    _max = tt.ceil(tt.max([mL + _n * tt.sqrt(sL**2 + tI**2), mT + _n * sT]))
    _x = tt.arange(_min, _max, dtype="int64")
    _norm_array = (
        _emg_cdf(_x, mu=mL, sigma=sL, tau=tI)
        * _norm_sf(_x, mu=mT, sigma=sT)
    )
    _log_norm_factor = tt.log(tt.sum(_norm_array))
    # Unnormalized dist values (log(CDF*SF) = log(CDF) + log(SF))
    _log_unscaled = (
        _log_emg_cdf(x, mu=mL, sigma=sL, tau=tI)
        + _log_norm_sf(x, mu=mT, sigma=sT)
    )
    # Normalize distribution in log scale
    log_pdf = _log_unscaled - _log_norm_factor
    return log_pdf
```
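Note that the length of the `tt.arange` above depends on the current parameter values, so the size of the integration grid varies from step to step. To illustrate how it scales, here is a pure-NumPy mirror of those bounds (the parameter values below are hypothetical, chosen only to show the scaling):

```python
import numpy as np

def grid_size(mL, sL, tI, mT, sT, n=10):
    # Pure-NumPy mirror of the integration bounds inside e_logp
    lo = np.floor(min(mL - n * sL, mT - n * sT))
    hi = np.ceil(max(mL + n * np.sqrt(sL**2 + tI**2), mT + n * sT))
    return int(hi - lo)

# At the ~1e4 scales of the priors, the grid is modest:
print(grid_size(1e4, 1e3, 1e3, 2e4, 1e3))    # tens of thousands of points
# But a single extreme proposal for a scale parameter blows it up:
print(grid_size(1e4, 1e10, 1e3, 2e4, 1e3))   # ~2e11 points, i.e. terabytes as int64
```

This is consistent with the failing op in the traceback being `np.arange(start, stop, step)` inside `aesara/tensor/basic.py`.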

I then use `DensityDist()` to generate a component RV for the mixture:

```
e_pdf = pm.DensityDist.dist(mL, sL, tI, mT, sT, logp=e_logp, class_name='e_pdf')
```
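For context, a normalization grid with a *fixed* number of points would keep the memory footprint constant no matter what values the sampler proposes. A minimal NumPy sketch of that idea (toy Gaussian integrand and a hypothetical grid size; the same arithmetic works symbolically, since `tt.arange(n)` with a constant `n` has a fixed length):

```python
import numpy as np

N_GRID = 4096  # constant number of integration points, independent of parameters

def fixed_grid(lo, hi, n=N_GRID):
    # n evenly spaced points on [lo, hi); symbolically this would be
    # lo + (hi - lo) * tt.arange(n) / n, whose length never changes
    step = (hi - lo) / n
    return lo + step * np.arange(n), step

# Riemann-sum normalization of a toy Gaussian integrand:
xs, dx = fixed_grid(0.0, 3.0e4)
norm = np.sum(np.exp(-0.5 * ((xs - 1.5e4) / 2.0e3) ** 2)) * dx
# norm ≈ 2e3 * sqrt(2*pi), with memory cost fixed at N_GRID floats
```

(Unlike the unit-spaced `tt.arange` in `e_logp`, this is a quadrature over a continuous grid rather than a sum over integers, so it is only a sketch of the trade-off, not a drop-in replacement.)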

(All parameters `mL` … `sT` have normal or exponential priors with reasonable length scales, on the order of `1e4`.)

I guess what I’m asking is: **what is it about this model that leads to such an enormous memory allocation when fitting with NUTS but not with the VI methods?** Any insight y’all could provide would be greatly appreciated!

An aside question: why does the `.dist()` method now require a `class_name` string to identify it, when in PyMC3 this was not required?