Memory allocation limit for NUTS with custom `logp` function (but not with VI methods)

@ricardoV94 @aseyboldt @OriolAbril I had to set this aside for a couple of weeks, but I’m working on it again.

I’ve profiled the model using the mod.profile() method, and as far as I can tell there’s nothing obviously wrong in the graphs returned by mod.logp() or mod.dlogp(), though to be honest I’m still not sure what I’m looking at. The vast majority of the time is spent in the Elemwise class, which I assume is responsible for broadcasting all the intermediate tensors to the correct dimensions?
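To check that assumption, here is a rough NumPy analogue of what I understand an Elemwise node to do: apply one scalar expression to every element of its inputs after NumPy-style broadcasting, with a Composite being several scalar ops fused into a single loop. The function name, shapes, and values below are hypothetical, not taken from my model.

```python
import numpy as np

def normal_kernel(x, mu, sigma):
    # One scalar expression applied element by element, similar in spirit to
    # the body of an Elemwise{Composite{...}} node (here a normal log-density kernel).
    z = (x - mu) / sigma
    return -0.5 * z * z - np.log(sigma)

# Hypothetical shapes and values: N = 3 observations as an (N, 1) column,
# K = 2 parameter sets as (1, K) rows.
x = np.array([[-2000.0], [1000.0], [7000.0]])
mu = np.array([[0.0, 2500.0]])
sigma = np.array([[1000.0, 2000.0]])

# Broadcasting (N, 1) against (1, K) yields an (N, K) result.
out = normal_kernel(x, mu, sigma)
print(out.shape)  # (3, 2)
```

If that reading is right, the work (and the memory for any intermediate arrays) grows with the size of the broadcast result, N × K here.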

I’ve included the .summary() output for both logp and dlogp below. Do y’all see anything that could be the problem? And how should I interpret the Class, Ops, and Apply summaries?
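My current reading of the three tables, which I haven’t confirmed in the Aesara docs: Class aggregates time over all nodes whose Op has the same class (e.g. every Elemwise), Ops aggregates over nodes sharing the same Op, and Apply lists individual nodes of the compiled graph (the <id> column). In the Class and Ops tables, <time per call> seems to be <apply time> / <#call>, and <#call> is <#apply> × 1000 because the function was called 1000 times. A small sketch that checks this arithmetic on the Elemwise row of the logp Class table (parse_summary_row is my own hypothetical helper, not part of Aesara):

```python
def parse_summary_row(line):
    """Split one row of the Class or Ops table into its columns.

    Assumes the column order printed in the header:
    <% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <name>
    (Apply-table rows use a different layout and are not handled here.)
    """
    parts = line.split()
    return {
        "pct_time": float(parts[0].rstrip("%")),
        "cum_pct": float(parts[1].rstrip("%")),
        "apply_time_s": float(parts[2].rstrip("s")),
        "time_per_call_s": float(parts[3].rstrip("s")),
        "impl": parts[4],                 # "C" or "Py" implementation
        "n_calls": int(parts[5]),
        "n_apply": int(parts[6]),
        "name": " ".join(parts[7:]),      # Op names may contain spaces
    }

# The Elemwise row of the logp Class table, copied verbatim
row = parse_summary_row(
    "  99.0%    99.0%      59.272s       1.04e-03s     C    57000      57   aesara.tensor.elemwise.Elemwise"
)

per_node_call = row["apply_time_s"] / row["n_calls"]   # ~1.04e-03 s, matches <time per call>
calls_per_node = row["n_calls"] // row["n_apply"]       # number of function evaluations
per_function_call = row["apply_time_s"] / calls_per_node  # Elemwise seconds per logp call
```

Read this way, the 57 Elemwise nodes together cost about 59 ms per logp evaluation.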

After running the “Mock Data” and “PyMC model build” code from my earlier post (Nov. 5), which should be complete enough for you to reproduce this locally, I computed the logp summary as follows:

mod_profile_logp = mod.profile(mod.logp())
mod_profile_logp.summary()

Returns:

Function profiling
==================
  Message: /Users/jast/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/aesaraf.py:970
  Time in 1000 calls to Function.__call__: 6.009044e+01s
  Time in Function.vm.__call__: 60.014588832855225s (99.874%)
  Time in thunks: 59.87600064277649s (99.643%)
  Total compilation time: 3.742687e+00s
    Number of Apply nodes: 109
    Aesara rewrite time: 2.879735e+00s
       Aesara validate time: 3.279018e-02s
    Aesara Linker time (includes C, CUDA code generation/compiling): 0.5069496631622314s
       Import time 3.829298e-01s
       Node make_thunk time 5.027940e-01s
           Node Elemwise{Composite{(Switch(OR(i0, i1), i2, (i3 + i4)) + ((log1p(i5) + ((i6 + i5) * i7)) - ((i6 + i5) * (i8 + log(i9)))))}}[(0, 4)](Any{0}.0, Any{0}.0, TensorConstant{-inf}, TensorConstant{1.791759469228055}, Sum{acc_dtype=float64}.0, Shape_i{0}.0, TensorConstant{1}, Sum{acc_dtype=float64}.0, max, Sum{acc_dtype=float64}.0) time 2.737141e-02s
           Node MakeVector{dtype='float64'}(m0_logprob, s0_0_log___logprob, t0_0_log___logprob, m2_0_log___logprob, s2_0_log___logprob, w_simplex___logprob, Sum{acc_dtype=float64}.0) time 2.539825e-02s
           Node Elemwise{Log}[(0, 0)](InplaceDimShuffle{x,0}.0) time 2.488804e-02s
           Node Elemwise{Composite{(Switch(GE(i0, i1), (i2 - (i3 * i0)), i4) + i5)}}[(0, 0)](m2_0_log___log, TensorConstant{0.0}, TensorConstant{-9.210340371976184}, TensorConstant{0.0001}, TensorConstant{-inf}, m2_0_log__) time 2.408910e-02s
           Node Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)](Elemwise{gt,no_inplace}.0, TensorConstant{(1, 1) of ..5332046727}, TensorConstant{(1, 1) of -0.5}, Elemwise{Composite{((i0 - i1) / i2)}}.0, Elemwise{Log}[(0, 0)].0, TensorConstant{(1, 1) of -inf}) time 2.279091e-02s

Time in all call to aesara.grad() 2.374974e+00s
Time since aesara import 219.557s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  99.0%    99.0%      59.272s       1.04e-03s     C    57000      57   aesara.tensor.elemwise.Elemwise
   0.3%    99.3%       0.185s       2.65e-05s     C     7000       7   aesara.tensor.math.Sum
   0.2%    99.5%       0.114s       4.94e-06s     C    23000      23   aesara.tensor.elemwise.DimShuffle
   0.2%    99.7%       0.112s       2.80e-05s     C     4000       4   aesara.tensor.math.Max
   0.2%    99.8%       0.102s       3.39e-05s     C     3000       3   aesara.tensor.basic.Join
   0.1%   100.0%       0.071s       7.05e-05s     Py    1000       1   aesara.tensor.basic.ARange
   0.0%   100.0%       0.007s       1.45e-06s     C     5000       5   aesara.tensor.basic.MakeVector
   0.0%   100.0%       0.005s       1.21e-06s     C     4000       4   aesara.tensor.math.All
   0.0%   100.0%       0.004s       4.34e-06s     C     1000       1   aesara.tensor.nnet.basic.Softmax
   0.0%   100.0%       0.002s       1.08e-06s     C     2000       2   aesara.tensor.math.Any
   0.0%   100.0%       0.001s       1.28e-06s     C     1000       1   aesara.tensor.shape.Shape_i
   0.0%   100.0%       0.001s       1.18e-06s     C     1000       1   aesara.tensor.math.Min
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  85.9%    85.9%      51.428s       5.14e-02s     C     1000        1   Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) - Composite{
   7.0%    92.9%       4.186s       4.19e-03s     C     1000        1   Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8)))), Composite{Switch(LT(i0, i1), (
   2.7%    95.6%       1.627s       1.63e-03s     C     1000        1   Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}
   1.7%    97.3%       1.025s       5.12e-04s     C     2000        2   Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)]
   0.7%    98.0%       0.421s       4.21e-04s     C     1000        1   Elemwise{Composite{Switch(i0, (i1 + log(i2)), i3)}}[(0, 1)]
   0.4%    98.4%       0.235s       3.92e-05s     C     6000        6   Elemwise{exp,no_inplace}
   0.3%    98.7%       0.151s       2.51e-05s     C     6000        6   Sum{acc_dtype=float64}
   0.2%    98.8%       0.108s       1.08e-04s     C     1000        1   Max{maximum}{1}
   0.2%    99.0%       0.108s       5.38e-05s     C     2000        2   Elemwise{Composite{((i0 - i1) / i2)}}
   0.2%    99.2%       0.102s       3.39e-05s     C     3000        3   Join
   0.1%    99.3%       0.071s       7.05e-05s     Py    1000        1   ARange{dtype='int32'}
   0.1%    99.4%       0.062s       6.18e-05s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}
   0.1%    99.5%       0.058s       5.30e-06s     C     11000       11   InplaceDimShuffle{x}
   0.1%    99.6%       0.046s       1.16e-05s     C     4000        4   Elemwise{Add}[(0, 1)]
   0.1%    99.7%       0.040s       2.02e-05s     C     2000        2   Elemwise{isinf,no_inplace}
   0.1%    99.7%       0.035s       3.45e-05s     C     1000        1   Sum{axis=[1], acc_dtype=float64}
   0.0%    99.8%       0.029s       2.89e-05s     C     1000        1   Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)]
   0.0%    99.8%       0.027s       3.87e-06s     C     7000        7   InplaceDimShuffle{x,x}
   0.0%    99.8%       0.013s       4.32e-06s     C     3000        3   InplaceDimShuffle{0,x}
   0.0%    99.8%       0.011s       1.09e-05s     C     1000        1   Elemwise{sub,no_inplace}
   ... (remaining 37 Ops account for   0.16%(0.09s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
  85.9%    85.9%      51.428s       5.14e-02s   1000    77   Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) - Composite{Switch(LT(i
   7.0%    92.9%       4.186s       4.19e-03s   1000    95   Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8)))), Composite{Switch(LT(i0, i1), (log((i2 * e
   2.7%    95.6%       1.627s       1.63e-03s   1000    64   Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}(TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{0,x}.0, InplaceDimShuffle{x,x}.0, TensorConstant{(1, 1) of -1.0}, TensorConstant{(1, 1) of 0.5}, TensorConstant{(1, 1) of ..7932881648}, TensorConstant{(1, 1) of -0.5}, TensorConstant{(1, 1) of ..7932881648})
   1.7%    97.3%       1.023s       1.02e-03s   1000   103   Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)](Elemwise{isinf,no_inplace}.0, Elemwise{exp,no_inplace}.0, Elemwise{Add}[(0, 1)].0, InplaceDimShuffle{0,x}.0)
   0.7%    98.0%       0.421s       4.21e-04s   1000   105   Elemwise{Composite{Switch(i0, (i1 + log(i2)), i3)}}[(0, 1)](InplaceDimShuffle{x}.0, max, Sum{axis=[1], acc_dtype=float64}.0, TensorConstant{(1,) of -inf})
   0.4%    98.4%       0.229s       2.29e-04s   1000   101   Elemwise{exp,no_inplace}(InplaceDimShuffle{0,x}.0)
   0.2%    98.6%       0.133s       1.33e-04s   1000    87   Sum{acc_dtype=float64}(Elemwise{Composite{(exp(Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}(((i2 - i3) / i4), i5, i6, i7, (i2 - i3), i4, i8, i9) + scalar_log1mexp(((((i3 - i2) / i10) + i11 + Switch(LT(((i2 - i12) / i4), i5), (log((i6 * erfcx(((i7 * (i2 - i12)) / i4)))) - (i6 * sqr(((i2 - i12) / i4)))), log1p((i8 * erfc(((i9 * (i2 - i12)) / i4)))))) 
   0.2%    98.8%       0.108s       1.08e-04s   1000    99   Max{maximum}{1}(Elemwise{Add}[(0, 1)].0)
   0.1%    98.9%       0.086s       8.63e-05s   1000    97   Join(TensorConstant{1}, Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}.0, Elemwise{Composite{((Switch(i0, Switch(i1, (Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0))), log1p((i6 * erfc(((i7 * i4) / i5)))))}((i2 / i3), i4, i5, i6, i2, i3, i7, i8) + scalar_log1mexp(((i9 + i10 + i11) - Composite{Switch(LT(i0, i1), (log((i2 * erfcx(((i3 * i4) / i5)))) - (i2 * sqr(i0)
   0.1%    99.1%       0.071s       7.05e-05s   1000    69   ARange{dtype='int32'}(Elemwise{Floor}[(0, 0)].0, Elemwise{Ceil}[(0, 0)].0, TensorConstant{1})
   0.1%    99.2%       0.062s       6.18e-05s   1000    90   Elemwise{Composite{Switch(i0, Switch(i1, (i2 + i3 + i4 + i5), (i6 - ((i7 * sqr(i8)) / i9))), i10)}}(InplaceDimShuffle{x,x}.0, InplaceDimShuffle{0,x}.0, Elemwise{Composite{(-log(i0))}}[(0, 0)].0, Elemwise{Composite{((i0 - i1) / i2)}}.0, InplaceDimShuffle{0,x}.0, Elemwise{Composite{Switch(LT(((i0 - i1) / i2), i3), (log((i4 * erfcx(((i5 * (i0 - i1)) / i2)))) - (i4 * sqr(((i0 - i1) / i2)))), log1p((i6 * erfc(((i7 * (i0 - i1)) / i2)))))}}.0, Elemwise{Com
   0.1%    99.3%       0.054s       5.39e-05s   1000    40   Elemwise{Composite{((i0 - i1) / i2)}}(InplaceDimShuffle{x,x}.0, TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{x,x}.0)
   0.1%    99.3%       0.054s       5.38e-05s   1000    43   Elemwise{Composite{((i0 - i1) / i2)}}(TensorConstant{[[-12265.]..[ -2342.]]}, InplaceDimShuffle{x,x}.0, InplaceDimShuffle{x,x}.0)
   0.1%    99.4%       0.044s       4.36e-05s   1000    98   Elemwise{Add}[(0, 1)](Elemwise{Log}[(0, 0)].0, Join.0)
   0.1%    99.5%       0.039s       3.87e-05s   1000   102   Elemwise{isinf,no_inplace}(InplaceDimShuffle{0,x}.0)
   0.1%    99.5%       0.035s       3.45e-05s   1000   104   Sum{axis=[1], acc_dtype=float64}(Elemwise{Composite{Switch(i0, i1, exp((i2 - i3)))}}[(0, 2)].0)
   0.0%    99.6%       0.029s       2.89e-05s   1000    96   Elemwise{Composite{Switch(i0, ((i1 + (i2 * sqr(i3))) - i4), i5)}}[(0, 3)](Elemwise{gt,no_inplace}.0, TensorConstant{(1, 1) of ..5332046727}, TensorConstant{(1, 1) of -0.5}, Elemwise{Composite{((i0 - i1) / i2)}}.0, Elemwise{Log}[(0, 0)].0, TensorConstant{(1, 1) of -inf})
   0.0%    99.6%       0.012s       1.25e-05s   1000   106   Sum{acc_dtype=float64}(0 <= weights <= 1, sum(weights) == 1)
   0.0%    99.6%       0.011s       1.09e-05s   1000    10   Elemwise{sub,no_inplace}(TensorConstant{[[-2265.]
.. [ 7658.]]}, InplaceDimShuffle{x,x}.0)
   0.0%    99.6%       0.011s       1.07e-05s   1000    32   Join(TensorConstant{0}, w_simplex__, Elemwise{neg,no_inplace}.0)
   ... (remaining 89 Apply instances account for 0.36%(0.21s) of the runtime)

Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the Aesara flag floatX=float32
  - Try installing amdlibm and set the Aesara flag lib__amblibm=True. This speeds up only some Elemwise operation.
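The single node that takes 85.9% of the logp time (Apply id 77) is one fused Composite. The sub-expression it repeats, Switch(LT(z, -1), log(0.5 * erfcx(c * (x - mu) / sigma)) - 0.5 * sqr(z), log1p(-0.5 * erfc(c' * (x - mu) / sigma))), looks to me like a numerically stable normal log-CDF; Apply id 64 (2.7%) shows it on its own with the threshold -1.0 and the factors 0.5 and -0.5 visible. The erfc/erfcx multipliers are truncated in the printout, so I can’t confirm they are ±1/√2 as in the standard form. A SciPy sketch of the standard form, assuming ±1/√2:

```python
import numpy as np
from scipy.special import erfc, erfcx

def normal_logcdf(x, mu, sigma):
    """log Phi((x - mu) / sigma) via the two-branch Switch pattern (standard form assumed)."""
    z = (np.asarray(x, dtype=float) - mu) / sigma
    with np.errstate(over="ignore", divide="ignore", invalid="ignore"):
        # Lower tail (z < -1): log(0.5 * erfcx(-z/sqrt(2))) - z**2/2 stays finite
        # where Phi(z) itself would underflow to 0.
        lower = np.log(0.5 * erfcx(-z / np.sqrt(2.0))) - 0.5 * z * z
        # Otherwise: log1p(-0.5 * erfc(z/sqrt(2))) stays accurate as Phi(z) -> 1.
        upper = np.log1p(-0.5 * erfc(z / np.sqrt(2.0)))
    # np.where evaluates both branches for every element and keeps one per element.
    return np.where(z < -1.0, lower, upper)
```

For comparison, np.log(scipy.stats.norm.cdf(-40.0)) is -inf because the CDF underflows in float64, which is presumably why the graph carries the erfcx branch at all.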

The dlogp profile was run as follows:

mod_profile_dlogp = mod.profile(mod.dlogp())
mod_profile_dlogp.summary()

Returns:

Function profiling
==================
  Message: /Users/jast1849/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/aesaraf.py:970
  Time in 1000 calls to Function.__call__: 1.049568e+02s
  Time in Function.vm.__call__: 104.82313346862793s (99.873%)
  Time in thunks: 103.50367903709412s (98.615%)
  Total compilation time: 1.252786e+01s
    Number of Apply nodes: 265
    Aesara rewrite time: 1.154807e+01s
       Aesara validate time: 2.012794e-01s
    Aesara Linker time (includes C, CUDA code generation/compiling): 0.543797492980957s
       Import time 2.388318e-01s
       Node make_thunk time 5.282750e-01s
           Node Elemwise{Composite{(((Switch(GE(i0, i1), i2, i3) + ((i4 * i5) / i6) + ((i7 * i8 * i9 * i10) / i6) + i11 + (((-i12) / i13) * sgn(i10)) + (i14 * i15 * i10) + i16 + ((i17 * i5) / i6) + ((i18 * i8 * i19 * i10) / i6) + i20 + i21 + i22 + ((i23 * i5) / i6) + ((i24 * i8 * i25 * i10) / i6) + i26 + i27) * i0) + i28)}}[(0, 0)](s0_0_log___log, TensorConstant{0.0}, TensorConstant{-0.0002}, TensorConstant{0}, Sum{acc_dtype=float64}.0, Elemwise{true_div,no_inplace}.0, Elemwise{add,no_inplace}.0, TensorConstant{-1.0}, TensorConstant{2.0}, Sum{acc_dtype=float64}.0, Elemwise{add,no_inplace}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Elemwise{abs,no_inplace}.0, TensorConstant{4.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, (d__logp/dt0_0_log___logprob){1.0}) time 5.202770e-03s
           Node Elemwise{exp,no_inplace}(s2_0_log__) time 5.120993e-03s
           Node Elemwise{Composite{(((Switch(GE(i0, i1), i2, i3) + ((-i4) / i5) + i6 + ((-((i4 * i7 * i8) / i5)) / i5) + ((-(i9 * i10 * i11)) / sqr(i5)) + i12 + ((-((i13 * i7 * i8) / i5)) / i5) + ((-(i14 * i15 * i11)) / sqr(i5)) + i16 + ((-((i17 * i7 * i8) / i5)) / i5) + ((-(i18 * i19 * i11)) / sqr(i5))) * i0) + i20)}}[(0, 0)](t0_0_log___log, TensorConstant{0.0}, TensorConstant{-0.002}, TensorConstant{0}, Sum{acc_dtype=float64}.0, Elemwise{add,no_inplace}.0, Sum{acc_dtype=float64}.0, Elemwise{true_div,no_inplace}.0, Elemwise{add,no_inplace}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Elemwise{sqr,no_inplace}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, Sum{acc_dtype=float64}.0, TensorConstant{-1.0}, Sum{acc_dtype=float64}.0, (d__logp/dt0_0_log___logprob){1.0}) time 4.855871e-03s
           Node Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}(Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)].0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)].0, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...3966440824}, TensorConstant{(1,) of -1..1670955126}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, TensorConstant{(1,) of -1..5865763297}, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...2872290391}, Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}.0, Elemwise{Composite{(i0 + (-i1))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0) time 4.809141e-03s
           Node Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)].0, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...3966440824}, Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of 0...2872290391}, Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}.0, Elemwise{Composite{(i0 + (-i1))}}[(0, 1)].0, Elemwise{sqr,no_inplace}.0) time 4.699945e-03s

Time in all call to aesara.grad() 2.374974e+00s
Time since aesara import 154.726s
Class
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Class name>
  96.5%    96.5%      99.895s       6.57e-04s     C   152000     152   aesara.tensor.elemwise.Elemwise
   1.7%    98.3%       1.799s       4.73e-05s     C    38000      38   aesara.tensor.math.Sum
   1.0%    99.3%       1.079s       2.70e-04s     C     4000       4   aesara.tensor.nnet.basic.Softmax
   0.2%    99.5%       0.234s       6.16e-06s     C    38000      38   aesara.tensor.elemwise.DimShuffle
   0.2%    99.7%       0.187s       1.87e-04s     Py    1000       1   aesara.tensor.basic.ARange
   0.1%    99.8%       0.140s       3.49e-05s     C     4000       4   aesara.tensor.basic.Join
   0.1%    99.9%       0.087s       2.90e-05s     C     3000       3   aesara.tensor.basic.Split
   0.0%   100.0%       0.040s       8.07e-06s     C     5000       5   aesara.tensor.shape.Reshape
   0.0%   100.0%       0.014s       1.43e-06s     C    10000      10   aesara.tensor.shape.SpecifyShape
   0.0%   100.0%       0.010s       2.45e-06s     C     4000       4   aesara.tensor.basic.MakeVector
   0.0%   100.0%       0.007s       7.32e-06s     C     1000       1   aesara.tensor.basic.Alloc
   0.0%   100.0%       0.004s       2.13e-06s     C     2000       2   aesara.tensor.math.Max
   0.0%   100.0%       0.003s       2.82e-06s     C     1000       1   aesara.tensor.shape.Shape_i
   0.0%   100.0%       0.003s       2.66e-06s     C     1000       1   aesara.tensor.math.All
   0.0%   100.0%       0.002s       2.09e-06s     C     1000       1   aesara.tensor.math.Min
   ... (remaining 0 Classes account for   0.00%(0.00s) of the runtime)

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  18.2%    18.2%      18.848s       3.77e-03s     C     5000        5   Elemwise{Composite{erfc(((i0 * i1) / i2))}}
  11.7%    29.9%      12.096s       2.02e-03s     C     6000        6   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}
   8.2%    38.1%       8.475s       8.47e-03s     C     1000        1   Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)]
   8.0%    46.0%       8.235s       1.65e-03s     C     5000        5   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}
   6.8%    52.8%       7.034s       2.34e-03s     C     3000        3   Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)]
   5.7%    58.5%       5.908s       5.91e-03s     C     1000        1   Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)]
   4.5%    63.1%       4.696s       4.70e-03s     C     1000        1   Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}
   4.1%    67.2%       4.256s       2.13e-03s     C     2000        2   Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)]
   4.1%    71.3%       4.254s       4.25e-03s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * ((i13 / (i14 * i15)) + ((i16 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i17), Switch(i1, (((i18 * i19 * i20 * i5 * i6 * i7 * i8) / i9) / (i21 * i15)), i17))}}[(0, 4)]
   2.9%    74.2%       3.039s       3.04e-03s     C     1000        1   Elemwise{Composite{Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, (((i14 * i15 * i16 * i6 * i7 * i8) / i9) / i17)))}}[(0, 13)]
   2.7%    76.9%       2.779s       2.78e-03s     C     1000        1   Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)]
   2.2%    79.2%       2.317s       3.31e-04s     C     7000        7   Elemwise{true_div,no_inplace}
   2.0%    81.2%       2.120s       2.65e-04s     C     8000        8   Elemwise{sub,no_inplace}
   1.8%    83.1%       1.912s       9.56e-04s     C     2000        2   Elemwise{Composite{((-Switch(i0, Switch(i1, (((i2 * i3 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, (i9 * i10 * i3 * (i4 / i11)), i8))) / i12)}}[(0, 4)]
   1.8%    84.9%       1.865s       6.22e-04s     C     3000        3   Elemwise{Composite{((i0 / (i1 * i2)) + ((i3 * i4) / i5))}}
   1.8%    86.6%       1.847s       6.16e-04s     C     3000        3   Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}
   1.7%    88.3%       1.731s       4.68e-05s     C     37000       37   Sum{acc_dtype=float64}
   1.4%    89.7%       1.409s       4.70e-04s     C     3000        3   Elemwise{Composite{Switch(i0, Switch(i1, (((i2 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, ((i9 * i10 * i4) / i11), i8))}}
   1.2%    90.9%       1.228s       1.23e-03s     C     1000        1   Elemwise{Composite{Switch(i0, (((i1 * i2 * i3 * i4 * i5 * i6 * i7) / i8) / i9), i10)}}[(0, 3)]
   1.1%    91.9%       1.088s       1.09e-03s     C     1000        1   Elemwise{Composite{((Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4) + Switch(i5, (log((i6 * i7)) - (i6 * sqr(i8))), log1p((i6 * i9)))) - i10)}}[(0, 2)]
   ... (remaining 83 Ops account for   8.08%(8.37s) of the runtime)

Apply
------
<% time> <sum %> <apply time> <time per call> <#call> <id> <Apply name>
   8.2%     8.2%       8.475s       8.47e-03s   1000   146   Elemwise{Composite{exp(Switch(i0, Switch(i1, (i2 + scalar_log1mexp(i3)), i2), i4))}}[(0, 2)](InplaceDimShuffle{x}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)].0, Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)].0, TensorConstant{(1,) of -inf})
   7.1%    15.2%       7.297s       7.30e-03s   1000   127   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   6.0%    21.2%       6.184s       6.18e-03s   1000   120   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{Composite{((i0 + i1) - i2)}}.0, InplaceDimShuffle{x}.0)
   5.7%    26.9%       5.908s       5.91e-03s   1000   142   Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)](Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{(i0 * sqr(i1))}}.0, Elemwise{lt,no_inplace}.0, TensorConstant{(1,) of 0.5}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, Elemwise{true_div,no_inplace}.0, TensorConstant{(1,) of -0.5}, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{Switch(i0,
   5.6%    32.5%       5.775s       5.77e-03s   1000   138   Elemwise{Composite{Switch(i0, (log((i1 * i2)) - (i1 * sqr(i3))), log1p((i4 * i5)))}}[(0, 2)](Elemwise{lt,no_inplace}.0, TensorConstant{(1,) of 0.5}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, Elemwise{true_div,no_inplace}.0, TensorConstant{(1,) of -0.5}, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0)
   4.5%    37.0%       4.696s       4.70e-03s   1000   202   Elemwise{Composite{((-Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * (((i13 * i4) / i14) + ((i15 * i4 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i16), Switch(i1, ((i17 * i18 * i19 * i4 * i5 * i6 * i7 * i8) / (i9 * i20)), i16))) / i21)}}(Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Com
   4.3%    41.3%       4.425s       4.43e-03s   1000   123   Elemwise{Composite{erfc(((i0 * i1) / i2))}}(TensorConstant{(1,) of 0...7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   4.1%    45.4%       4.254s       4.25e-03s   1000   206   Elemwise{Composite{Switch(i0, Switch(i1, ((((i2 * i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * ((i13 / (i14 * i15)) + ((i16 * i4) / i10)) * i5 * i6 * i7 * i8) / i9)), i17), Switch(i1, (((i18 * i19 * i20 * i5 * i6 * i7 * i8) / i9) / (i21 * i15)), i17))}}[(0, 4)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Sw
   3.8%    49.2%       3.910s       3.91e-03s   1000   147   Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)](TensorConstant{(1,) of -1.0}, Elemwise{Composite{(((i0 / i1) + i2 + Switch(i3, (log((i4 * i5)) - (i4 * sqr(i6))), log1p((i7 * i8)))) - i9)}}[(0, 6)].0, TensorConstant{(1,) of -inf})
   3.8%    53.0%       3.892s       3.89e-03s   1000   126   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}(TensorConstant{(1,) of -0..7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   3.6%    56.6%       3.722s       3.72e-03s   1000   121   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{Composite{((i0 + i1) - i2)}}.0, Elemwise{sqr,no_inplace}.0)
   3.6%    60.1%       3.717s       3.72e-03s   1000   129   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   3.6%    63.7%       3.714s       3.71e-03s   1000   125   Elemwise{Composite{exp(((i0 * i1 * i1) / i2))}}(TensorConstant{(1,) of -0..0171142714}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   2.9%    66.7%       3.039s       3.04e-03s   1000   225   Elemwise{Composite{Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, (((i14 * i15 * i16 * i6 * i7 * i8) / i9) / i17)))}}[(0, 13)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1
   2.8%    69.4%       2.871s       2.87e-03s   1000   124   Elemwise{Composite{erfcx(((i0 * i1) / i2))}}(TensorConstant{(1,) of -0..7932881648}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0)
   2.7%    72.1%       2.779s       2.78e-03s   1000   230   Elemwise{Composite{((-Switch(i0, Switch(i1, i2, ((((i3 * i4 * i5 * i5 * i6 * i7 * i8) / i9) / i10) - ((i11 * i12 * i13 * i6 * i7 * i8) / i9))), Switch(i1, i2, ((i14 * i15 * i16 * i5 * i6 * i7 * i8) / (i9 * i17))))) / i18)}}[(0, 5)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of 0}, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x}.0, Elemwise{Comp
   1.7%    73.8%       1.752s       1.75e-03s   1000   226   Elemwise{Composite{((-Switch(i0, Switch(i1, (((i2 * i3 * i3 * i4) / i5) - (i6 * i7 * i4)), i8), Switch(i1, (i9 * i10 * i3 * (i4 / i11)), i8))) / i12)}}[(0, 4)](Elemwise{lt,no_inplace}.0, Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{((i0 * i1 * i2 * i3) + ((i4 * i5 * i6 * i7 * i2 * i3) / i8))}}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of 0...7932881648}, Elemwise{Composit
   1.6%    75.4%       1.655s       1.65e-03s   1000   134   Elemwise{Composite{((i0 / (i1 * i2)) + ((i3 * i4) / i5))}}(TensorConstant{(1,) of -1..1670955126}, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, InplaceDimShuffle{x}.0, TensorConstant{(1,) of -1..5865763297}, Elemwise{sub,no_inplace}.0, Elemwise{sqr,no_inplace}.0)
   1.6%    77.0%       1.638s       1.64e-03s   1000   133   Elemwise{Composite{(((i0 * i1) / i2) + ((i3 * i1 * i1) / i4))}}(TensorConstant{(1,) of -1..1670955126}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{erfcx(((i0 * i1) / i2))}}.0, TensorConstant{(1,) of -1..5865763297}, InplaceDimShuffle{x}.0)
   1.2%    78.2%       1.228s       1.23e-03s   1000   203   Elemwise{Composite{Switch(i0, (((i1 * i2 * i3 * i4 * i5 * i6 * i7) / i8) / i9), i10)}}[(0, 3)](Elemwise{Composite{GT(i0, (i1 * i2))}}.0, TensorConstant{(1,) of -1.0}, TensorConstant{(1,) of -0.5}, Elemwise{sub,no_inplace}.0, Elemwise{Composite{Switch(IsInf(Composite{(i0 / expm1((-i1)))}(i0, i1)), i2, Composite{(i0 / expm1((-i1)))}(i0, i1))}}[(0, 1)].0, InplaceDimShuffle{x}.0, Elemwise{Composite{erfc(((i0 * i1) / i2))}}.0, Elemwise{Composite{exp(Swit
   ... (remaining 245 Apply instances account for 21.81%(22.57s) of the runtime)

Here are tips to potentially make your code run faster
                 (if you think of new ones, suggest them on the mailing list).
                 Test them first, as they are not guaranteed to always provide a speedup.
  - Try the Aesara flag floatX=float32
  - Try installing amdlibm and set the Aesara flag lib__amblibm=True. This speeds up only some Elemwise operation.
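In the dlogp profile the top Ops are erfc, erfcx and exp(((i0 * i1 * i1) / i2)) nodes, and one constant prints as -1..1670955126, which looks like -2/√π = -1.1283791670955126, the factor in d/du erfc(u). That seems consistent with Aesara differentiating both branches of a normal log-CDF-style expression symbolically. A hand-derived SciPy sketch of that gradient for the standard form (my own derivation for illustration, not code extracted from the graph):

```python
import numpy as np
from scipy.special import erfc, erfcx

def dlogcdf_dx(x, mu, sigma):
    """d/dx log Phi((x - mu)/sigma) = phi(z) / (sigma * Phi(z)), computed per branch."""
    z = (np.asarray(x, dtype=float) - mu) / sigma
    with np.errstate(over="ignore", under="ignore", divide="ignore", invalid="ignore"):
        # Derivative of the lower branch log(0.5*erfcx(-z/sqrt(2))) - z**2/2,
        # which simplifies to sqrt(2/pi) / erfcx(-z/sqrt(2)).
        lower = np.sqrt(2.0 / np.pi) / erfcx(-z / np.sqrt(2.0))
        # Derivative of the upper branch log1p(-0.5*erfc(z/sqrt(2))), using
        # d/du erfc(u) = -(2/sqrt(pi)) * exp(-u**2): hence the exp and erfc terms.
        upper = (np.exp(-0.5 * z * z) / np.sqrt(2.0 * np.pi)) / (1.0 - 0.5 * erfc(z / np.sqrt(2.0)))
    # Chain rule for z = (x - mu)/sigma contributes the factor 1/sigma.
    return np.where(z < -1.0, lower, upper) / sigma
```

From the summaries, one logp call costs roughly 60 ms (60.09 s / 1000 calls) and one dlogp call roughly 105 ms (104.96 s / 1000 calls); the extra erfc/erfcx/exp passes in the gradient graph are presumably part of that difference.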

Any advice on where to go from here would be much appreciated. Thank you :pray: