Memory allocation limit for NUTS with custom `logp` function (but not with VI methods)

My main question regards how memory is allocated when performing NUTS sampling on a model with a user-defined logp(). This model (which I admit is quite onerous to compute) can be successfully fit using the Variational Inference methods, but it chokes when sampled with NUTS. The model is a mixture of several distributions, one of which is user-defined and requires numerical integration for normalization (the other three are built-in PyMC distributions: ExGaussian(), Normal(), and Uniform()).

The data being fit are an array of large integer values (hence the int64 data type in the error below).

Sampling with NUTS:

with model:
    res = pm.sample()

Produces the following memory error:

MemoryError                               Traceback (most recent call last)
MemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64

The full traceback is as follows:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mL, sL0, tI0, mT0, sT0, w, mL_a, sL_a0, tI_a0, w_a]

 0.04% [3/8000 00:00<29:53 Sampling 4 chains, 0 divergences]
/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:529: UserWarning: <class 'numpy.core._exceptions._ArrayMemoryError'> error does not allow us to add an extra error message
  warnings.warn(
/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:529: UserWarning: <class 'numpy.core._exceptions._ArrayMemoryError'> error does not allow us to add an extra error message
  warnings.warn(
---------------------------------------------------------------------------
RemoteTraceback                           Traceback (most recent call last)
RemoteTraceback: 
"""
Traceback (most recent call last):
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 129, in run
    self._start_loop()
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 182, in _start_loop
    point, stats = self._compute_point()
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py", line 207, in _compute_point
    point, stats = self._step_method.step(self._point)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/arraystep.py", line 286, in step
    return super().step(point)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/arraystep.py", line 208, in step
    step_res = self.astep(q)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/base_hmc.py", line 186, in astep
    hmc_step = self._hamiltonian_step(start, p0.data, step_size)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 194, in _hamiltonian_step
    divergence_info, turning = tree.extend(direction)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 295, in extend
    tree, diverging, turning = self._build_subtree(
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 373, in _build_subtree
    return self._single_step(left, epsilon)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py", line 333, in _single_step
    right = self.integrator.step(epsilon, left)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/integration.py", line 73, in step
    return self._step(epsilon, state)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/hmc/integration.py", line 109, in _step
    logp = self._logp_dlogp_func(q_new, grad_out=q_new_grad)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py", line 410, in __call__
    cost, *grads = self._aesara_function(*grad_vars)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py", line 984, in __call__
    raise_with_op(
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py", line 534, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py", line 971, in __call__
    self.vm()
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/graph/op.py", line 543, in rval
    r = p(n, [x[0] for x in i], o)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/tensor/basic.py", line 2923, in perform
    out[0] = np.arange(start, stop, step, dtype=self.dtype)
numpy.core._exceptions._ArrayMemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64
"""

The above exception was the direct cause of the following exception:

MemoryError                               Traceback (most recent call last)
MemoryError: Unable to allocate 1.47 TiB for an array with shape (201417847152,) and data type int64

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Input In [7], in <cell line: 2>()
      1 import pymc as pm
      2 with model:
----> 3     res = pm.sample()

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:609, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    607 _print_step_hierarchy(step)
    608 try:
--> 609     mtrace = _mp_sample(**sample_args, **parallel_args)
    610 except pickle.PickleError:
    611     _log.warning("Could not pickle model, sampling singlethreaded.")

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:1521, in _mp_sample(draws, tune, step, chains, cores, random_seed, start, progressbar, trace, model, callback, discard_tuned_samples, mp_ctx, **kwargs)
   1519 try:
   1520     with sampler:
-> 1521         for draw in sampler:
   1522             strace = traces[draw.chain]
   1523             if strace.supports_sampler_stats and draw.stats is not None:

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py:463, in ParallelSampler.__iter__(self)
    460     self._progress.update(self._total_draws)
    462 while self._active:
--> 463     draw = ProcessAdapter.recv_draw(self._active)
    464     proc, is_last, draw, tuning, stats, warns = draw
    465     self._total_draws += 1

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/parallel_sampling.py:353, in ProcessAdapter.recv_draw(processes, timeout)
    351     else:
    352         error = RuntimeError("Chain %s failed." % proc.chain)
--> 353     raise error from old_error
    354 elif msg[0] == "writing_done":
    355     proc._readable = True

RuntimeError: Chain 3 failed.

However, when using ADVI, fitting works (albeit slowly):

vi = pm.ADVI(model=model)
approx = vi.fit(10000)

Output (intentionally interrupted mid-fit):

 30.88% [3088/10000 06:25<14:22 Average Loss = 1.0235e+05]
Interrupted at 3,088 [30%]: Average Loss = 1.0395e+05

And the final fit result using VI is generally what I would expect (the model fits the data and the resulting best fit parameter values make sense). However, my suspicion is the posterior distributions are actually a bit more complicated (e.g. possibly bimodal) which is why I’m trying to figure out how to get the NUTS sampler to work.

Though the full mixture is too much to report here, the following is the logp function (and the functions on which it depends) for the custom component of the full model:

import numpy as np
import aesara.tensor as tt
import pymc as pm

# CDF/logCDF components
def _emg_cdf(x, mu, sigma, tau):
    rv = pm.ExGaussian.dist(mu=mu,sigma=sigma, nu=tau)
    lcdf = pm.logcdf(rv, x)
    return tt.exp(lcdf)

def _log_emg_cdf(x, mu, sigma, tau):
    rv = pm.ExGaussian.dist(mu=mu,sigma=sigma, nu=tau)
    lcdf = pm.logcdf(rv, x)
    return lcdf

def _norm_sf(x, mu, sigma):
    arg = (x - mu) / (sigma * tt.sqrt(2.0))
    return 0.5 * tt.erfc(arg)

def _log_norm_sf(x, mu, sigma):
    return pm.distributions.dist_math.normal_lccdf(mu, sigma, x)

# Custom log pdf
def e_logp(x, mL, sL, tI, mT, sT):
    # Compute norm factor by numeric integrating over entire distribution
    _n = 10 #number of stdevs for numerical normalization
    _min = tt.floor(tt.min([mL-_n*sL, mT-_n*sT]))
    _max = tt.ceil(tt.max([mL+_n*np.sqrt(sL**2+tI**2), mT+_n*sT]))

    _x = tt.arange(_min, _max, dtype="int64")

    _norm_array = (
        _emg_cdf(_x, mu=mL, sigma=sL, tau=tI) 
        *_norm_sf(_x, mu=mT, sigma=sT)
    )
    _log_norm_factor = tt.log(tt.sum(_norm_array))

    # Unnormalized dist values (log(CDF*SF) = log(CDF) + log(SF))
    _log_unscaled = (
        _log_emg_cdf(x, mu=mL, sigma=sL, tau=tI)
        +_log_norm_sf(x, mu=mT, sigma=sT)
    )

    # Normalize distribution in logscale
    log_pdf = _log_unscaled - _log_norm_factor

    return log_pdf

I then use DensityDist() to generate a component RV for the mixture:

e_pdf = pm.DensityDist.dist(mL, sL, tI, mT, sT, logp=e_logp, class_name='e_pdf')

(all parameters mL … sT have normal or exponential priors with reasonable length scales on the order of 1e4)

I guess what I’m asking is: what is it about this model that results in such a large memory allocation when fitting with NUTS but not with the VI methods? Any insight y’all could provide would be greatly appreciated!

An aside question: Why does the .dist() method now require a class_name string to identify it when in PyMC3 this was not required?

This looks like some kind of bug to me, either in pymc, aesara or in your model. Can you post something that I can reproduce locally?


Because we have to create a Python class dynamically, and before we didn’t.


@aseyboldt and @ricardoV94 thank you for the willingness to help—your insights are always appreciated.

The code below is the most condensed version of my pipeline that still produces the aforementioned bug. The model consists of four components: two conventional distributions (ExGaussian and Normal), the custom distribution described earlier in the thread, and a uniform background. This first block generates a mock dataset (dependencies: scipy and numpy) to which I apply the PyMC model.

Mock Data

import numpy as np
import scipy as sp

# Mock Parameters
m0 = 0
s0 = 1000
t0 = 500
m2 = 15000
s2 = 2500
w = [0.15, 0.55, 0.15, 0.15]
N = 10000
x = np.array(range(-10000, 25000))

# Model components
def rvs_0(m, s, t, size=1000, seed=42):
    samples = sp.stats.exponnorm.rvs(t/s, m, s, size=size, random_state=seed)
    samples = np.int0(np.round(samples))
    return samples

def pdf_1(x, m0, s0, t0, m1, s1):

    cdf = sp.stats.exponnorm.cdf(x, t0/s0, m0, s0)
    sf = sp.stats.norm.sf(x, m1, s1)
    unscaled = np.nan_to_num(cdf * sf)

    xmin = int(min(m0 - 10*s0, m1 - 10*s1))
    xmax = int(max(m0 + 10*np.sqrt(s0**2 + t0**2), m1 + 10*s1))
    xfull = np.array(range(xmin, xmax))

    cdf = sp.stats.exponnorm.cdf(xfull, t0/s0, m0, s0)
    sf = sp.stats.norm.sf(xfull, m1, s1)
    norm_factor = sum(np.nan_to_num(cdf * sf))

    pdf = np.nan_to_num(unscaled/norm_factor)
    return pdf

def rvs_1(m0, s0, t0, m1, s1, size=1000, seed=42):

    # Adjust <x> so that it encompasses full normalization range (xfull)
    xmin = int(min(m0 - 10*s0, m1 - 10*s1))
    xmax = int(max(m0 + 10*np.sqrt(s0**2 + t0**2), m1 + 10*s1))
    xfull = np.array(range(xmin, xmax))

    pdf = pdf_1(xfull, m0=m0, s0=s0, t0=t0, m1=m1, s1=s1)

    # Adjust pdf to sum to 1.0 (residue from finite normalization integration)
    residue = 1.0 - sum(pdf)
    pdf[-1] = pdf[-1] + abs(residue)

    np.random.seed(seed=seed)
    samples = np.random.choice(xfull, size=size, replace=True, p=pdf)
    return samples


def rvs_2(m, s, size=1000, seed=42):
    samples = sp.stats.norm.rvs(m, s, size=size, random_state=seed)
    samples = np.int0(np.round(samples))
    return samples


def rvs_b(x, size=10, seed=42):
    np.random.seed(seed=seed)
    samples = np.random.choice(x, size=size, replace=True)
    return samples

# Data point distributions
w0, w1, w2, wb = w
N0 = int(w0 * N)
N1 = int(w1 * N)
N2 = int(w2 * N)
Nb = int(wb * N)

# Full dataset 
data = np.concatenate((
    rvs_0(m0, s0, t0, size=N0), 
    rvs_1(m0, s0, t0, m2, s2, size=N1), 
    rvs_2(m2, s2, size=N2), 
    rvs_b(x, size=Nb)), 
    axis=0
)

PyMC model build

The dataset generated above is then used as the observed data (in full_model) in the following code block, in which I create the PyMC mixture model.

import numpy as np
import pymc as pm
import aesara.tensor as at

with pm.Model() as mod:

    # CDF/logCDF components
    def _emg_cdf(x, mu, sigma, tau):
        rv = pm.ExGaussian.dist(mu=mu,sigma=sigma, nu=tau)
        lcdf = pm.logcdf(rv, x)
        return at.exp(lcdf)

    def _log_emg_cdf(x, mu, sigma, tau):
        rv = pm.ExGaussian.dist(mu=mu,sigma=sigma, nu=tau)
        lcdf = pm.logcdf(rv, x)
        return lcdf

    def _norm_sf(x, mu, sigma):
        arg = (x - mu) / (sigma * at.sqrt(2.0))
        return 0.5 * at.erfc(arg)

    def _log_norm_sf(x, mu, sigma):
        return pm.distributions.dist_math.normal_lccdf(mu, sigma, x)

    # Custom log pdf
    def comp1_logp(x, m0, s0, t0, m2, s2):
        # Compute norm factor by numeric integrating over entire distribution
        _n = 10 #number of stdevs for numerical normalization
        _min = at.floor(at.min([m0-_n*s0, m2-_n*s2]))
        _max = at.ceil(at.max([m0+_n*np.sqrt(s0**2+t0**2), m2+_n*s2]))

        _x = at.arange(_min, _max, dtype="int64")

        _norm_array = (
            _emg_cdf(_x, mu=m0, sigma=s0, tau=t0) 
            *_norm_sf(_x, mu=m2, sigma=s2)
        )
        _log_norm_factor = at.log(at.sum(_norm_array))

        # Unnormalized dist values (log(CDF*SF) = log(CDF) + log(SF))
        _log_unscaled = (
            _log_emg_cdf(x, mu=m0, sigma=s0, tau=t0)
            +_log_norm_sf(x, mu=m2, sigma=s2)
        )

        # Normalize distribution in logscale
        log_pdf = _log_unscaled - _log_norm_factor

        return log_pdf

    # Define parameter priors
    m0 = pm.Normal('m0', mu=0, sigma=2000)
    
    s0_0 = pm.Exponential('s0_0', lam=1/5000)
    s0 = pm.Deterministic('s0', s0_0 + 10)
    
    t0_0 = pm.Exponential('t0_0', lam=1/500)
    t0 = pm.Deterministic('t0', t0_0 + 10)
    
    m2_0 = pm.Exponential('m2_0', lam=1/10000)
    m2 = pm.Deterministic('m2', m2_0 + 10000)
    
    s2_0 = pm.Exponential('s2_0', lam=1/5000)
    s2 = pm.Deterministic('s2', s2_0 + 10)
    
    w = pm.Dirichlet('w', a=np.array([1, 1, 1, 1]))
    
    # Define distributions for mixture model component
    comp0 = pm.ExGaussian.dist(mu=m0, sigma=s0, nu=t0)
    comp1 = pm.DensityDist.dist(m0, s0, t0, m2, s2, logp=comp1_logp, class_name='comp1')
    comp2 = pm.Normal.dist(mu=m2, sigma=s2)
    compb = pm.Uniform.dist(lower=data.min(), upper=data.max())
    
    components = [comp0, comp1, comp2, compb]
    
    # Define observed model
    full_model = pm.Mixture('full_model', w=w, comp_dists=components, observed=data)

ADVI optimization

Using VI, the model seems to run without error (albeit slowly):

vi = pm.ADVI(model=mod)
approx = vi.fit(10000)

Output (terminated prematurely):

 2.20% [220/10000 00:29<21:33 Average Loss = 1.0545e+05]
Interrupted at 220 [2%]: Average Loss = 1.0542e+05

NUTS sampling error

with mod:
    trace = pm.sample()

Output (truncated traceback):

.
.
.
The above exception was the direct cause of the following exception:

MemoryError                               Traceback (most recent call last)
MemoryError: Unable to allocate 511. GiB for an array with shape (68551114498,) and data type int64
.
.
.

Dependencies

print(f"numpy: {np.__version__}")
print(f"scipy: {sp.__version__}")
print(f"pymc: {pm.__version__}")
numpy: 1.23.4
scipy: 1.9.3
pymc: 4.2.2

You should be able to run both the mock data simulation and the model build exactly as above to reproduce the error. I’m hoping it’s just a dumb mistake on my part, but if it is in fact a pymc bug I’d appreciate any suggestions you might have for temporary workarounds. I’m applying this model to tens of thousands of data sets and would like to scale it up to hundreds of thousands, but given the current speed of a single model-fitting instance this is prohibitive in compute time, even on our local compute cluster. Your sage advice is greatly appreciated. Much obliged!


I haven’t had time to check the code yet, but the issue might come from the _x = at.arange(_min, _max, ...) line. What are the ranges you’d expect for _min and _max?

You need to keep in mind that pm.sample runs MCMC and will explore all of the parameter space, including the tails, not only around the mean/mode like an optimizer. From the priors it looks like the allowed range is huge. Does the error always happen with the same shape independently of the seed?


Hey @OriolAbril, thanks for taking a look at my issue. Yes, that was my initial suspicion as well. Since that model component (the user defined component comp1) is proportional to the product of the EMG cumulative distribution function and the Normal survival function (i.e. the _log_unscaled definition which = log(CDF*SF)) it must be numerically normalized in order to serve as a well behaved probability distribution (I can’t find a way to analytically solve that product to find a closed-form normalization factor). To numerically normalize it, I am numerically integrating it over a large range to approximate the integration over (-\infty, \infty). The _x array is that numerical integration range, the length of which will of course depend on the values of all the parameters at each step in the fitting. The _min and _max values of that integration range are also defined to be a large number of standard deviations (for a given choice of parameter values) from each end of the probability distribution, set by _n (in order to make sure the resulting probability value is of high precision and, thus, well-behaved). In this case _n = 10 standard deviations. So it seems like it would be a large numeric integration range, as you say.
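In other words, using the functions from the custom logp above, the component density being approximated is roughly

p(x) \propto F_{EMG}(x) \, S_N(x), \qquad \log p(x) \approx \log F_{EMG}(x) + \log S_N(x) - \log \sum_{x' = x_{min}}^{x_{max}} F_{EMG}(x') \, S_N(x')

where F_{EMG} is the ExGaussian CDF, S_N is the Normal survival function, and the sum over _x = arange(_min, _max) (with x_{min} and x_{max} placed _n standard deviations beyond either end) stands in for the integral over (-\infty, \infty).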

However, I calculated an extreme case for this range and it doesn’t seem like it would be the culprit. Say I sampled 10\sigma out on each of the prior distributions (pretty far out into the tails) for the given values of the hyperparameters. In that case, the length of _x is still only on the order of 10^6. See below:

Nsig = 10

_n = 10
m0 = -Nsig*2000
s0 = Nsig*5000
t0 = Nsig*500
m2 = Nsig*10000 + 10000
s2 = Nsig*5000

_min = np.floor(min([m0-_n*s0, m2-_n*s2]))
_max = np.ceil(max([m0+_n*np.sqrt(s0**2+t0**2), m2+_n*s2]))
print(f"Length of _x (min:{_min}, max:{_max}): {abs(_min-_max):e}")
print(f"_x approx size in bytes: len*64/8 = {abs(_min-_max)*64/8:e}")

Output:

Length of _x (min:-520000.0, max:610000.0): 1.130000e+06
_x approx size in bytes: len*64/8 = 9.040000e+06

And 9e6 is still way smaller than the GiB scale of the memory error. This leads me to believe that the size of _x is not actually the memory bottleneck. Does that make sense? Am I missing something?

Yeah, the error always happens with the same priors, independent of the seed. However, the size of the array allocation leading to the memory error does vary. For example, I just ran it again and got the following:

MemoryError                               Traceback (most recent call last)
MemoryError: Unable to allocate 484. TiB for an array with shape (133056876572664,) and data type int64

This is much larger than the previously mentioned memory error. The stated memory value is random from run to run, but always in that same GiB to TiB range.

Also, along these same lines, I have decreased _n to 3 and changed to dtype="int32" in order to decrease the length of _x and the memory footprint of each element to see if that solved the issue, but I’m still getting the same error.

As an aside: Is the best way of determining the length of a tensor object in aesara to use the .eval() method and then check the length? i.e.

x = at.arange(0, 100)
len(x.eval())

Return:

100

Or is there a more streamlined way that doesn’t require you to evaluate the actual tensor array?

I don’t know the aesara and pymc API’s/backend all that well yet, but am trying to learn.

You can do x.shape.eval(), which in most cases will not require computing the actual tensor outputs.
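For example, mirroring the arange snippet above:

import aesara.tensor as at

x = at.arange(0, 100)
print(x.shape.eval())  # array([100]), usually without evaluating x itself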


I don’t know how ADVI works, but it seems to produce a much more compact graph.

You may need to dig a bit and profile the d/logp functions for your model: Profiling Aesara function — Aesara 2.8.7+37.geadc6e33e.dirty documentation

You can obtain them via model.compile_fn(model.logp(), point_fn=False) and model.compile_fn(model.dlogp(), point_fn=False), and the value and gradient combined via model.compile_fn([model.logp(), model.dlogp()], point_fn=False)
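As a rough sketch (assuming your model object is named model; with the default point_fn=True the compiled functions instead accept a point dictionary, such as the one returned by model.initial_point()):

# Sketch: compile the logp/dlogp graphs and evaluate them at a single point
logp_fn = model.compile_fn(model.logp())                   # value only
dlogp_fn = model.compile_fn(model.dlogp())                 # gradient only
both_fn = model.compile_fn([model.logp(), model.dlogp()])  # value and gradient together

point = model.initial_point()
print(logp_fn(point))
print(dlogp_fn(point))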


@ricardoV94 thanks for the guidance. I read through the tutorial at the link you provided; however, I believe I’m missing some basics on how to perform the profiling. I think I understand the profiling example at the bottom of the page (here: aesara/profiling_example.py at main · aesara-devs/aesara · GitHub), but I don’t fully understand how that relates to the compiled model functions you mention.

Is the idea to use d/logp functions obtained using model.compile_fn() as the aesara function in that linked example, and then provide the function a bunch of random inputs across the input parameter space so the profiler can get estimates on where in the model time and memory are being used?

For example I can create the compiled logp function as follows:

comp_logp = mod.compile_fn(mod.logp(), point_fn=False)
comp_logp

Returns:

<aesara.compile.function.types.Function at 0x2abdabf3c790>

Which appears to have the following inputs:

list(comp_logp.inv_finder.values())
[In(m0),
 In(s0_0_log__),
 In(t0_0_log__),
 In(m2_0_log__),
 In(s2_0_log__),
 In(w_simplex__)]

I should be able to evaluate this aesara function by providing inputs for all model parameters, like so:

comp_logp(
    m0=0, 
    s0_0_log__=np.log(100), 
    t0_0_log__=np.log(100), 
    m2_0_log__=np.log(500),
    s2_0_log__=np.log(300),
    w_simplex__=np.array([0.25, 0.25, 0.25])
)

Returns:

array(-110815.51280221)

So, in order to profile comp_logp I should just evaluate it some 10k times with random input values and with:

aesara.config.profile = True
aesara.config.profile_memory = True

in order to produce a profile report. Does that all seem correct? I’d just like to make sure I’m understanding your suggestion.

Assuming I’ve got the above right, I have a few other questions regarding the inputs to the compiled logp function:

  • Are all parameters ending in _log__ simply representing the log of the underlying parameter, i.e. is s0_0_log__ equal to np.log(s0_0)?
  • Why is m0 the only non-logged parameter in the inputs?
  • What is the form of w_simplex__? Is it the weights of the first three components of the model, with the last weight inferred as the remainder so they sum to 1?

I appreciate any feedback, as I’ve still got a lot to learn about the inner workings of aesara. Thanks!

Yes, your intuition is mostly correct, except that you don’t necessarily need to test the function with random inputs (although sometimes the bottlenecks can be coordinate-specific).

You can get a useable point via model.initial_point(). Evaluating the function a couple hundred times on this point should give you a stable profile.

The underscore names are for transformed parameters. Positive distributions will by default have a log transform and the variable name will be f"{name}_log__"; other distributions have different default transforms (e.g., simplex for the Dirichlet), and some, like the Normal, have none, since any value has nonzero density.
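A minimal sketch with a toy model (hypothetical names) illustrating those naming conventions:

import numpy as np
import pymc as pm

with pm.Model() as demo:
    mu = pm.Normal("mu", 0, 1)                # unconstrained -> no transform
    sigma = pm.Exponential("sigma", lam=1.0)  # positive -> log transform
    w = pm.Dirichlet("w", a=np.ones(4))       # simplex transform (length 3 in the point)

print(demo.initial_point().keys())
# dict_keys(['mu', 'sigma_log__', 'w_simplex__'])
# sigma_log__ holds log(sigma), so e.g. s0_0_log__ in your model is log(s0_0)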

If you want to generate multiple points, you can set initval="prior" for each variable; then every time you call model.initial_point() you’ll get a new point. But again, a single initial point is usually good enough for benchmarking.


I still seem to be doing something incorrectly. Based on your last message, I ran the following code along with the full model build described above:

import aesara as ae

ae.config.profile = True
ae.config.profile_memory = True

for _ in range(100):
    model.initial_point()

However, it returns nothing to stdout (no error either). What am I missing? Thanks!

@ricardoV94 also:

I assume I have to include this kwarg in the call to pm.Mixture? I did this by changing the last line of the model to:

full_model = pm.Mixture('full_model', w=w, comp_dists=components, observed=data, initval="prior")

However, when I call model.initial_point() multiple times, it returns the same dictionary:

{'m0': array(0.),
 's0_0_log__': array(8.51719319),
 't0_0_log__': array(6.2146081),
 'm2_0_log__': array(9.21034037),
 's2_0_log__': array(8.51719319),
 'w_simplex__': array([0., 0., 0.])}

Am I including the initval="prior" kwarg in the wrong place?

@ricardoV94
Ah, it seems I can’t even get the basic profiling example (at the bottom of the page: Profiling Aesara function — Aesara 2.8.7+37.geadc6e33e.dirty documentation) to run correctly.

The file profiling_example.py contains the following:

import numpy as np

import aesara

x, y, z = aesara.tensor.vectors("xyz")
f = aesara.function([x, y, z], [(x + y + z) * 2])
xv = np.random.random((10,)).astype(aesara.config.floatX)
yv = np.random.random((10,)).astype(aesara.config.floatX)
zv = np.random.random((10,)).astype(aesara.config.floatX)
f(xv, yv, zv)

If I run it as follows (as stated in the example):

(pymc_env) $ AESARA_FLAGS=optimizer_excluding=fusion:inplace,profile=True python profiling_example.py

I get the following error:

Exception ignored in atexit callback: <function _atexit_print_fn at 0x7f4d4f7ed120>
Traceback (most recent call last):
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/profiling.py", line 78, in _atexit_print_fn
    ps.summary(
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/profiling.py", line 1452, in summary
    self.summary_function(file)
  File "/Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/profiling.py", line 786, in summary_function
    print("Function profiling", file=file)
AttributeError: 'str' object has no attribute 'write'

Any thoughts?

Yes. The observed variable is the only one that is not part of the initial point :). You should instead set initval="prior" for the other, unobserved variables in the model if you want their values to change every time you call model.initial_point().

You don’t want to profile the initial point function; that’s just to get values you can use to profile the d/logp functions:

import pymc as pm

with pm.Model() as m:
  x = pm.Normal("x", initval="prior")
  y = pm.Normal("y", x, observed=[0])

f = m.compile_fn(m.logp(), profile=True)
for i in range(1000):
  ip = m.initial_point()
  f(ip)
print(f.f.profile.summary())

f.f accesses the actual aesara function

By the way, once you figure out how to profile PyMC models, contributing that knowledge to our docs would be super valuable :slight_smile:


Also, I forgot: there is a model.profile that you can use, which does the calling with the initial point (just one) for you automatically :man_facepalming: :

model.profile(m.logp())
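The returned profile object has a summary() method, e.g.:

profile = model.profile(model.logp())
profile.summary()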

We do have a notebook on profiling but it is still using 3.9: Profiling — PyMC example gallery


I would certainly be happy to contribute to the documentation once I can make more sense of the profiling process! :smiley:

I’m currently trying to make sense of the overarching design of the aesara library (mostly by reading the primary documentation) as I would eventually like to meaningfully contribute to pymc.

I’m finding it pretty difficult to get a handle on the design, though, so if you have any suggested resources I’d much appreciate it.

@ricardoV94 @aseyboldt @OriolAbril I had to set this aside for a couple weeks but am working on it again.

I’ve profiled the model using the mod.profile() method and, as far as I can tell, there’s nothing obviously wrong in the mod.logp() or mod.dlogp() functions. Though, to be honest, I’m still not sure what I’m looking at. The vast majority of the time is consumed by the Elemwise class, which I assume is responsible for applying element-wise operations to (and broadcasting) the intermediate tensors?

I’ve included the .summary() for both logp and dlogp below. Is there anything y’all see that could be the problem? How do I interpret the Class, Ops, and Apply summaries?

After running the “Mock Data” and “PyMC model build” code in my earlier post (from Nov. 5), I computed the logp summary as follows (those code blocks should be complete so that you can reproduce it locally):

mod_profile_logp = mod.profile(mod.logp())
mod_profile_logp.summary()

Return:

Function profiling
==================
  Message: /Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/aesaraf.py:970
  Time in 1000 calls to Function.__call__: 6.009044e+01s
  Time in Function.vm.__call__: 60.014588832855225s (99.874%)
  Time in thunks: 59.87600064277649s (99.643%)
  Total compilation time: 3.742687e+00s
    Number of Apply nodes: 109
    Aesara rewrite time: 2.879735e+00s
       Aesara validate time: 3.279018e-02s
    Aesara Linker time (includes C, CUDA code generation/compiling): 0.5069496631622314s
       Import time 3.829298e-01s
       Node make_thunk time 5.027940e-01s
           Node Elemwise{Composite{(Switch(OR(i0, i1), i2, (i3 + i4)) + ((log1p(i5) + ((i6 + i5) * i7)) - ((i6 + i5) * (i8 + log(i9)))))}}[(0, 4)](Any{0}.0, Any{0}.0, TensorConstant{-inf}, TensorConstant{1.791759469228055}, Sum{acc_dtype=float64}.0, Shape_i{0}.0, TensorConstant{1}, Sum{acc_dtype=float64}.0, max, Sum{acc_dtype=float64}.0) time 2.737141e-02s
           Node MakeVector{dtype='float64'}(m0_logprob, s0_0_log___logprob, t0_0_log___logprob, m2_0_log___logprob, s2_0_log___logprob, w_simplex___logprob, Sum{acc_dtype=float64}.0) time 2.539825e-02s
           Node Elemwise{Log}[(0, 0)](InplaceDimShuffle{x,0}.0) time 2.488804e-02s
           Node Elemwise{Composite{(Switch(GE(i0, i1), (i2 - (i3 * i0)), i4) + i5)}}[(0, 0)](m2_0_log___log, TensorConstant{0.0}, TensorConstant{-9.210340371976184}, TensorConstant{0.0001}, TensorConstant{-inf}, m2_0_log__) time 2.408910e-02s
           Node Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)](Elemwise{gt,no_inplace}.0, TensorConstant{(1, 1) of ..5332046727}, TensorConstant{(1, 1) of -0.5}, Elemwise{Composite{((i0 - i1) / i2)}}.0, Elemwise{Log}[(0, 0)].0, TensorConstant{(1, 1) of -inf}) time 2.279091e-02s

Time in all call to aesara.grad() 2.374974e+00s
Time since aesara import 219.557s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.0%    99.0%      59.272s       1.04e-03s     C    57000      57   aesara.tensor.elemwise.Elemwise
   0.3%    99.3%       0.185s       2.65e-05s     C     7000       7   aesara.tensor.math.Sum
   0.2%    99.5%       0.114s       4.94e-06s     C    23000      23   aesara.tensor.elemwise.DimShuffle
   0.2%    99.7%       0.112s       2.80e-05s     C     4000       4   aesara.tensor.math.Max
   0.2%    99.8%       0.102s       3.39e-05s     C     3000       3   aesara.tensor.basic.Join
   0.1%   100.0%       0.071s       7.05e-05s     Py    1000       1   aesara.tensor.basic.ARange
   0.0%   100.0%       0.007s       1.45e-06s     C     5000       5   aesara.tensor.basic.MakeVector
   0.0%   100.0%       0.005s       1.21e-06s     C     4000       4   aesara.tensor.math.All
   0.0%   100.0%       0.004s       4.34e-06s     C     1000       1   aesara.tensor.nnet.basic.Softmax
   0.0%   100.0%       0.002s       1.08e-06s     C     2000       2   aesara.tensor.math.Any
   0.0%   100.0%       0.001s       1.28e-06s     C     1000       1   aesara.tensor.shape.Shape_i
   0.0%   100.0%       0.001s       1.18e-06s     C     1000       1   aesara.tensor.math.Min
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  85.9%    85.9%      51.428s       5.14e-02s     C     1000        1   Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) - Composite{
   7.0%    92.9%       4.186s       4.19e-03s     C     1000        1   Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8)))), Composite{Switch(LT(i0, i1), (
   2.7%    95.6%       1.627s       1.63e-03s     C     1000        1   Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}
   1.7%    97.3%       1.025s       5.12e-04s     C     2000        2   Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)]
   0.7%    98.0%       0.421s       4.21e-04s     C     1000        1   Elemwise{Composite{Switch(i0, (i1 + log(i2)), i3)}}[(0, 1)]
   0.4%    98.4%       0.235s       3.92e-05s     C     6000        6   Elemwise{exp,no_inplace}
   0.3%    98.7%       0.151s       2.51e-05s     C     6000        6   Sum{acc_dtype=float64}
   0.2%    98.8%       0.108s       1.08e-04s     C     1000        1   Max{maximum}{1}
   0.2%    99.0%       0.108s       5.38e-05s     C     2000        2   Elemwise{Composite{((i0 - i1) / i2)}}
   0.2%    99.2%       0.102s       3.39e-05s     C     3000        3   Join
   0.1%    99.3%       0.071s       7.05e-05s     Py    1000        1   ARange{dtype='int32'}
   0.1%    99.4%       0.062s       6.18e-05s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}
   0.1%    99.5%       0.058s       5.30e-06s     C     11000       11   InplaceDimShuffle{x}
   0.1%    99.6%       0.046s       1.16e-05s     C     4000        4   Elemwise{Add}[(0, 1)]
   0.1%    99.7%       0.040s       2.02e-05s     C     2000        2   Elemwise{isinf,no_inplace}
   0.1%    99.7%       0.035s       3.45e-05s     C     1000        1   Sum{axis=[1], acc_dtype=float64}
   0.0%    99.8%       0.029s       2.89e-05s     C     1000        1   Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)]
   0.0%    99.8%       0.027s       3.87e-06s     C     7000        7   InplaceDimShuffle{x,x}
   0.0%    99.8%       0.013s       4.32e-06s     C     3000        3   InplaceDimShuffle{0,x}
   0.0%    99.8%       0.011s       1.09e-05s     C     1000        1   Elemwise{sub,no_inplace}
   ... (remaining 37 Ops account for   0.16%(0.09s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  85.9%    85.9%      51.428s       5.14e-02s   1000    77   Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) - Composite{Switch(LT(i
   7.0%    92.9%       4.186s       4.19e-03s   1000    95   Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8)))), Composite{Switch(LT(i0, i1), (log((i2 * e
   2.7%    95.6%       1.627s       1.63e-03s   1000    64   Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}(TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{0,x}.0, InplaceDimShuffle{x,x}.0, TensorConstant{(1, 1) of -1.0}, TensorConstant{(1, 1) of 0.5}, TensorConstant{(1, 1) of ..7932881648}, TensorConstant{(1, 1) of -0.5}, TensorConstant{(1, 1) of ..7932881648})
   1.7%    97.3%       1.023s       1.02e-03s   1000   103   Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)](Elemwise{isinf,no_inplace}.0, Elemwise{exp,no_inplace}.0, Elemwise{Add}[(0, 1)].0, InplaceDimShuffle{0,x}.0)
   0.7%    98.0%       0.421s       4.21e-04s   1000   105   Elemwise{Composite{Switch(i0, (i1 + log(i2)), i3)}}[(0, 1)](InplaceDimShuffle{x}.0, max, Sum{axis=[1], acc_dtype=float64}.0, TensorConstant{(1,) of -inf})
   0.4%    98.4%       0.229s       2.29e-04s   1000   101   Elemwise{exp,no_inplace}(InplaceDimShuffle{0,x}.0)
   0.2%    98.6%       0.133s       1.33e-04s   1000    87   Sum{acc_dtype=float64}(Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) 
   0.2%    98.8%       0.108s       1.08e-04s   1000    99   Max{maximum}{1}(Elemwise{Add}[(0, 1)].0)
   0.1%    98.9%       0.086s       8.63e-05s   1000    97   Join(TensorConstant{1}, Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}.0, Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0)
   0.1%    99.1%       0.071s       7.05e-05s   1000    69   ARange{dtype='int32'}(Elemwise{Floor}[(0, 0)].0, Elemwise{Ceil}[(0, 0)].0, TensorConstant{1})
   0.1%    99.2%       0.062s       6.18e-05s   1000    90   Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}(InplaceDimShuffle{x,x}.0, InplaceDimShuffle{0,x}.0, Elemwise{Composite{(-log(i0))}}[(0, 0)].0, Elemwise{Composite{((i0 - i1) / i2)}}.0, InplaceDimShuffle{0,x}.0, Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}.0, Elemwise{Com
   0.1%    99.3%       0.054s       5.39e-05s   1000    40   Elemwise{Composite{((i0 - i1) / i2)}}(InplaceDimShuffle{x,x}.0, TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{x,x}.0)
   0.1%    99.3%       0.054s       5.38e-05s   1000    43   Elemwise{Composite{((i0 - i1) / i2)}}(TensorConstant{[[-12265.]..[ -2342.]]}, InplaceDimShuffle{x,x}.0, InplaceDimShuffle{x,x}.0)
   0.1%    99.4%       0.044s       4.36e-05s   1000    98   Elemwise{Add}[(0, 1)](Elemwise{Log}[(0, 0)].0, Join.0)
   0.1%    99.5%       0.039s       3.87e-05s   1000   102   Elemwise{isinf,no_inplace}(InplaceDimShuffle{0,x}.0)
   0.1%    99.5%       0.035s       3.45e-05s   1000   104   Sum{axis=[1], acc_dtype=float64}(Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)].0)
   0.0%    99.6%       0.029s       2.89e-05s   1000    96   Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)](Elemwise{gt,no_inplace}.0, TensorConstant{(1, 1) of ..5332046727}, TensorConstant{(1, 1) of -0.5}, Elemwise{Composite{((i0 - i1) / i2)}}.0, Elemwise{Log}[(0, 0)].0, TensorConstant{(1, 1) of -inf})
   0.0%    99.6%       0.012s       1.25e-05s   1000   106   Sum{acc_dtype=float64}(0 <= weights <= 1, sum(weights) == 1)
   0.0%    99.6%       0.011s       1.09e-05s   1000    10   Elemwise{sub,no_inplace}(TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{x,x}.0)
   0.0%    99.6%       0.011s       1.07e-05s   1000    32   Join(TensorConstant{0}, w_simplex__, Elemwise{neg,no_inplace}.0)
   ... (remaining 89 Apply instances account for 0.36%(0.21s) of the runtime)

Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the Aesara flag floatX=float32
  - Try installing amdlibm and set the Aesara flag lib__amblibm=True. This speeds up only some Elemwise operation.

The dlogp profile was run as follows:

mod_profile_dlogp = mod.profile(mod.dlogp())
mod_profile_dlogp.summary()

Returns:

Function profiling
==================
  Message: /Users/jast1849/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/aesaraf.py:970
  Time in 1000 calls to Function.__call__: 1.049568e+02s
  Time in Function.vm.__call__: 104.82313346862793s (99.873%)
  Time in thunks: 103.50367903709412s (98.615%)
  Total compilation time: 1.252786e+01s
    Number of Apply nodes: 265
    Aesara rewrite time: 1.154807e+01s
       Aesara validate time: 2.012794e-01s
    Aesara Linker time (includes C, CUDA code generation/compiling): 0.543797492980957s
       Import time 2.388318e-01s
       Node make_thunk time 5.282750e-01s
           Node Elemwise{Composite{(((Switch(GE(i0, i1), i2, i3) + ((i4 * i5) / i6) + ((i7 * i8 * i9 * i10) / i6) + i11 + (((-i12) / i13) * sgn(i10)) + (i14 * i15 * i10) + i16 + ((i17 * i5) / i6) + ((i18 * i8 * i19 * i10) / i6) + i20 + i21 + i22 + ((i23 * i5) / i6) + ((i24 * i8 * i25 * i10) / i6) + i26 + i27) * i0) + i28)}}[(0, 0)](s0_0_log___log, TensorConstant{0.0}, TensorConstant{-0.0002}, TensorConstant{0}, Sum{acc_dtype=float64}.0, Elemwise{true_div,no_inplace}.0, Elemwise{add,no_inplace}.0, TensorConstant{-1.0}, TensorConstant{2.0}, Sum{acc_dtype=float64}.0, Elemwise{add,no_inplace}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Elemwise{abs,no_inplace}.0, TensorConstant{4.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, (d__logp/dt0_0_log___logprob){1.0}) time 5.202770e-03s
           Node Elemwise{exp,no_inplace}(s2_0_log__) time 5.120993e-03s
           Node Elemwise{Composite{(((Switch(GE(i0, i1), i2, i3) + ((-i4) / i5) + i6 + ((-((i4 * i7 * i8) / i5)) / i5) + ((-(i9 * i10 * i11)) / sqr(i5)) + i12 + ((-((i13 * i7 * i8) / i5)) / i5) + ((-(i14 * i15 * i11)) / sqr(i5)) + i16 + ((-((i17 * i7 * i8) / i5)) / i5) + ((-(i18 * i19 * i11)) / sqr(i5))) * i0) + i20)}}[(0, 0)](t0_0_log___log, TensorConstant{0.0}, TensorConstant{-0.002}, TensorConstant{0}, Sum{acc_dtype=float64}.0, Elemwise{add,no_inplace}.0, Sum{acc_dtype=float64}.0, Elemwise{true_div,no_inplace}.0, Elemwise{add,no_inplace}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Elemwise{sqr,no_inplace}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, (d__logp/dt0_0_log___logprob){1.0}) time 4.855871e-03s
           Node Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}(Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)].0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)].0, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...3966440824}, TensorConstant{(1,) of -1..1670955126}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, TensorConstant{(1,) of -1..5865763297}, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...2872290391}, Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}.0, Elemwise{Composite{(i0 + (-i1))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0) time 4.809141e-03s
           Node Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)].0, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...3966440824}, Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...2872290391}, Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}.0, Elemwise{Composite{(i0 + (-i1))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0) time 4.699945e-03s

Time in all call to aesara.grad() 2.374974e+00s
Time since aesara import 154.726s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  96.5%    96.5%      99.895s       6.57e-04s     C   152000     152   aesara.tensor.elemwise.Elemwise
   1.7%    98.3%       1.799s       4.73e-05s     C    38000      38   aesara.tensor.math.Sum
   1.0%    99.3%       1.079s       2.70e-04s     C     4000       4   aesara.tensor.nnet.basic.Softmax
   0.2%    99.5%       0.234s       6.16e-06s     C    38000      38   aesara.tensor.elemwise.DimShuffle
   0.2%    99.7%       0.187s       1.87e-04s     Py    1000       1   aesara.tensor.basic.ARange
   0.1%    99.8%       0.140s       3.49e-05s     C     4000       4   aesara.tensor.basic.Join
   0.1%    99.9%       0.087s       2.90e-05s     C     3000       3   aesara.tensor.basic.Split
   0.0%   100.0%       0.040s       8.07e-06s     C     5000       5   aesara.tensor.shape.Reshape
   0.0%   100.0%       0.014s       1.43e-06s     C    10000      10   aesara.tensor.shape.SpecifyShape
   0.0%   100.0%       0.010s       2.45e-06s     C     4000       4   aesara.tensor.basic.MakeVector
   0.0%   100.0%       0.007s       7.32e-06s     C     1000       1   aesara.tensor.basic.Alloc
   0.0%   100.0%       0.004s       2.13e-06s     C     2000       2   aesara.tensor.math.Max
   0.0%   100.0%       0.003s       2.82e-06s     C     1000       1   aesara.tensor.shape.Shape_i
   0.0%   100.0%       0.003s       2.66e-06s     C     1000       1   aesara.tensor.math.All
   0.0%   100.0%       0.002s       2.09e-06s     C     1000       1   aesara.tensor.math.Min
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  18.2%    18.2%      18.848s       3.77e-03s     C     5000        5   Elemwise{Composite{erfc(((i0 * i1) / i2))}}
  11.7%    29.9%      12.096s       2.02e-03s     C     6000        6   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}
   8.2%    38.1%       8.475s       8.47e-03s     C     1000        1   Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)]
   8.0%    46.0%       8.235s       1.65e-03s     C     5000        5   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}
   6.8%    52.8%       7.034s       2.34e-03s     C     3000        3   Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)]
   5.7%    58.5%       5.908s       5.91e-03s     C     1000        1   Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)]
   4.5%    63.1%       4.696s       4.70e-03s     C     1000        1   Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}
   4.1%    67.2%       4.256s       2.13e-03s     C     2000        2   Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)]
   4.1%    71.3%       4.254s       4.25e-03s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * ((i13 / (i14 * i15)) + ((i16 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i17), Switch(i1, (((i18 * i19 * i20 * i5 * i6 * i7 * i8) / i9) / (i21 * i15)), i17))}}[(0, 4)]
   2.9%    74.2%       3.039s       3.04e-03s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, (((i14 * i15 * i16 * i6 * i7 * i8) / i9) / i17)))}}[(0, 13)]
   2.7%    76.9%       2.779s       2.78e-03s     C     1000        1   Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)]
   2.2%    79.2%       2.317s       3.31e-04s     C     7000        7   Elemwise{true_div,no_inplace}
   2.0%    81.2%       2.120s       2.65e-04s     C     8000        8   Elemwise{sub,no_inplace}
   1.8%    83.1%       1.912s       9.56e-04s     C     2000        2   Elemwise{Composite{((-Switch(i0, Switch(i1, (((i2 * i3 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, (i9 * i10 * i3 * (i4 / i11)), i8))) / i12)}}[(0, 4)]
   1.8%    84.9%       1.865s       6.22e-04s     C     3000        3   Elemwise{Composite{((i0 / (i1 * i2)) + ((i3 * i4) / i5))}}
   1.8%    86.6%       1.847s       6.16e-04s     C     3000        3   Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}
   1.7%    88.3%       1.731s       4.68e-05s     C     37000       37   Sum{acc_dtype=float64}
   1.4%    89.7%       1.409s       4.70e-04s     C     3000        3   Elemwise{Composite{Switch(i0, Switch(i1, (((i2 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, ((i9 * i10 * i4) / i11), i8))}}
   1.2%    90.9%       1.228s       1.23e-03s     C     1000        1   Elemwise{Composite{Switch(i0, (((i1 * i2 * i3 * i4 * i5 * i6 * i7) / i8) / i9), i10)}}[(0, 3)]
   1.1%    91.9%       1.088s       1.09e-03s     C     1000        1   Elemwise{Composite{((Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4) + Switch(i5, (log((i6 * i7)) - (i6 * sqr(i8))), log1p((i6 * i9)))) - i10)}}[(0, 2)]
   ... (remaining 83 Ops account for   8.08%(8.37s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
   8.2%     8.2%       8.475s       8.47e-03s   1000   146   Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)](InplaceDimShuffle{x}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)].0, Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)].0, TensorConstant{(1,) of -inf})
   7.1%    15.2%       7.297s       7.30e-03s   1000   127   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   6.0%    21.2%       6.184s       6.18e-03s   1000   120   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{Composite{((i0 + i1) - i2)}}.0, InplaceDimShuffle{x}.0)
   5.7%    26.9%       5.908s       5.91e-03s   1000   142   Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)](Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{(i0 * sqr(i1))}}.0, Elemwise{lt,no_inplace}.0, TensorConstant{(1,) of 0.5}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, Elemwise{true_div,no_inplace}.0, TensorConstant{(1,) of -0.5}, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{Switch(i0,
   5.6%    32.5%       5.775s       5.77e-03s   1000   138   Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)](Elemwise{lt,no_inplace}.0, TensorConstant{(1,) of 0.5}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, Elemwise{true_div,no_inplace}.0, TensorConstant{(1,) of -0.5}, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0)
   4.5%    37.0%       4.696s       4.70e-03s   1000   202   Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}(Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Com
   4.3%    41.3%       4.425s       4.43e-03s   1000   123   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   4.1%    45.4%       4.254s       4.25e-03s   1000   206   Elemwise{Composite{Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * ((i13 / (i14 * i15)) + ((i16 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i17), Switch(i1, (((i18 * i19 * i20 * i5 * i6 * i7 * i8) / i9) / (i21 * i15)), i17))}}[(0, 4)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Sw
   3.8%    49.2%       3.910s       3.91e-03s   1000   147   Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)](TensorConstant{(1,) of -1.0}, Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)].0, TensorConstant{(1,) of -inf})
   3.8%    53.0%       3.892s       3.89e-03s   1000   126   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}(TensorConstant{(1,) of -0..7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   3.6%    56.6%       3.722s       3.72e-03s   1000   121   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{Composite{((i0 + i1) - i2)}}.0, Elemwise{sqr,no_inplace}.0)
   3.6%    60.1%       3.717s       3.72e-03s   1000   129   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   3.6%    63.7%       3.714s       3.71e-03s   1000   125   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   2.9%    66.7%       3.039s       3.04e-03s   1000   225   Elemwise{Composite{Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, (((i14 * i15 * i16 * i6 * i7 * i8) / i9) / i17)))}}[(0, 13)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1
   2.8%    69.4%       2.871s       2.87e-03s   1000   124   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}(TensorConstant{(1,) of -0..7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   2.7%    72.1%       2.779s       2.78e-03s   1000   230   Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Comp
   1.7%    73.8%       1.752s       1.75e-03s   1000   226   Elemwise{Composite{((-Switch(i0, Switch(i1, (((i2 * i3 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, (i9 * i10 * i3 * (i4 / i11)), i8))) / i12)}}[(0, 4)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{((i0 * i1 * i2 * i3) + ((i4 * i5 * i6 * i7 * i2 * i3) / i8))}}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of 0...7932881648}, Elemwise{Composit
   1.6%    75.4%       1.655s       1.65e-03s   1000   134   Elemwise{Composite{((i0 / (i1 * i2)) + ((i3 * i4) / i5))}}(TensorConstant{(1,) of -1..1670955126}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1..5865763297}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   1.6%    77.0%       1.638s       1.64e-03s   1000   133   Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}(TensorConstant{(1,) of -1..1670955126}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, TensorConstant{(1,) of -1..5865763297}, InplaceDimShuffle{x}.0)
   1.2%    78.2%       1.228s       1.23e-03s   1000   203   Elemwise{Composite{Switch(i0, (((i1 * i2 * i3 * i4 * i5 * i6 * i7) / i8) / i9), i10)}}[(0, 3)](Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)].0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Swit
   ... (remaining 245 Apply instances account for 21.81%(22.57s) of the runtime)

Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the Aesara flag floatX=float32
  - Try installing amdlibm and set the Aesara flag lib__amblibm=True. This speeds up only some Elemwise operation.

Any advice on where to go from here would be much appreciated. Thank you :pray: