Version-dependent slowdown of Gaussian Mixture sampling in Ubuntu 20.04

No, only pymc=5.9.1 is slow with pytensor=2.17.3. Conda installs pytensor that way when you require pymc=5.9.1 and nothing else:

(pymc_env_5_9_1_test4) avicenna@avicenna:~/Desktop$ conda create -n pymc_env_5_9_1_test6 pymc=5.9.1  
Collecting package metadata (current_repodata.json): done
Solving environment: done

## Package Plan ##

  environment location: /home/avicenna/miniconda3/envs/pymc_env_5_9_1_test6

  added / updated specs:
    - pymc=5.9.1


The following NEW packages will be INSTALLED:

  _libgcc_mutex      conda-forge/linux-64::_libgcc_mutex-0.1-conda_forge 
  _openmp_mutex      conda-forge/linux-64::_openmp_mutex-4.5-2_gnu 
  arviz              conda-forge/noarch::arviz-0.16.1-pyhd8ed1ab_1 
  atk-1.0            conda-forge/linux-64::atk-1.0-2.38.0-hd4edc92_1 
  brotli             conda-forge/linux-64::brotli-1.1.0-hd590300_1 
  brotli-bin         conda-forge/linux-64::brotli-bin-1.1.0-hd590300_1 
  bzip2              conda-forge/linux-64::bzip2-1.0.8-h7f98852_4 
  c-ares             conda-forge/linux-64::c-ares-1.20.1-hd590300_1 
  ca-certificates    conda-forge/linux-64::ca-certificates-2023.7.22-hbcca054_0 
  cached-property    conda-forge/noarch::cached-property-1.5.2-hd8ed1ab_1 
  cached_property    conda-forge/noarch::cached_property-1.5.2-pyha770c72_1 
  cachetools         conda-forge/noarch::cachetools-5.3.2-pyhd8ed1ab_0 
  cairo              conda-forge/linux-64::cairo-1.18.0-h3faef2a_0 
  certifi            conda-forge/noarch::certifi-2023.7.22-pyhd8ed1ab_0 
  cloudpickle        conda-forge/noarch::cloudpickle-3.0.0-pyhd8ed1ab_0 
  cons               conda-forge/noarch::cons-0.4.6-pyhd8ed1ab_0 
  contourpy          conda-forge/linux-64::contourpy-1.1.1-py311h9547e67_1 
  cycler             conda-forge/noarch::cycler-0.12.1-pyhd8ed1ab_0 
  etuples            conda-forge/noarch::etuples-0.3.9-pyhd8ed1ab_0 
  expat              conda-forge/linux-64::expat-2.5.0-hcb278e6_1 
  fastprogress       conda-forge/noarch::fastprogress-1.0.3-pyhd8ed1ab_0 
  filelock           conda-forge/noarch::filelock-3.13.1-pyhd8ed1ab_0 
  font-ttf-dejavu-s~ conda-forge/noarch::font-ttf-dejavu-sans-mono-2.37-hab24e00_0 
  font-ttf-inconsol~ conda-forge/noarch::font-ttf-inconsolata-3.000-h77eed37_0 
  font-ttf-source-c~ conda-forge/noarch::font-ttf-source-code-pro-2.038-h77eed37_0 
  font-ttf-ubuntu    conda-forge/noarch::font-ttf-ubuntu-0.83-hab24e00_0 
  fontconfig         conda-forge/linux-64::fontconfig-2.14.2-h14ed4e7_0 
  fonts-conda-ecosy~ conda-forge/noarch::fonts-conda-ecosystem-1-0 
  fonts-conda-forge  conda-forge/noarch::fonts-conda-forge-1-0 
  fonttools          conda-forge/linux-64::fonttools-4.43.1-py311h459d7ec_0 
  freetype           conda-forge/linux-64::freetype-2.12.1-h267a509_2 
  fribidi            conda-forge/linux-64::fribidi-1.0.10-h36c2ea0_0 
  gdk-pixbuf         conda-forge/linux-64::gdk-pixbuf-2.42.10-h829c605_4 
  gettext            conda-forge/linux-64::gettext-0.21.1-h27087fc_0 
  giflib             conda-forge/linux-64::giflib-5.2.1-h0b41bf4_3 
  graphite2          conda-forge/linux-64::graphite2-1.3.13-h58526e2_1001 
  graphviz           conda-forge/linux-64::graphviz-8.1.0-h28d9a01_0 
  gtk2               conda-forge/linux-64::gtk2-2.24.33-h90689f9_2 
  gts                conda-forge/linux-64::gts-0.7.6-h977cf35_4 
  h5netcdf           conda-forge/noarch::h5netcdf-1.2.0-pyhd8ed1ab_0 
  h5py               conda-forge/linux-64::h5py-3.10.0-nompi_py311h3839ddf_100 
  harfbuzz           conda-forge/linux-64::harfbuzz-8.2.1-h3d44ed6_0 
  hdf5               conda-forge/linux-64::hdf5-1.14.2-nompi_h4f84152_100 
  icu                conda-forge/linux-64::icu-73.2-h59595ed_0 
  keyutils           conda-forge/linux-64::keyutils-1.6.1-h166bdaf_0 
  kiwisolver         conda-forge/linux-64::kiwisolver-1.4.5-py311h9547e67_1 
  krb5               conda-forge/linux-64::krb5-1.21.2-h659d440_0 
  lcms2              conda-forge/linux-64::lcms2-2.15-hb7c19ff_3 
  ld_impl_linux-64   conda-forge/linux-64::ld_impl_linux-64-2.40-h41732ed_0 
  lerc               conda-forge/linux-64::lerc-4.0.0-h27087fc_0 
  libaec             conda-forge/linux-64::libaec-1.1.2-h59595ed_1 
  libblas            conda-forge/linux-64::libblas-3.9.0-19_linux64_openblas 
  libbrotlicommon    conda-forge/linux-64::libbrotlicommon-1.1.0-hd590300_1 
  libbrotlidec       conda-forge/linux-64::libbrotlidec-1.1.0-hd590300_1 
  libbrotlienc       conda-forge/linux-64::libbrotlienc-1.1.0-hd590300_1 
  libcblas           conda-forge/linux-64::libcblas-3.9.0-19_linux64_openblas 
  libcurl            conda-forge/linux-64::libcurl-8.4.0-hca28451_0 
  libdeflate         conda-forge/linux-64::libdeflate-1.19-hd590300_0 
  libedit            conda-forge/linux-64::libedit-3.1.20191231-he28a2e2_2 
  libev              conda-forge/linux-64::libev-4.33-h516909a_1 
  libexpat           conda-forge/linux-64::libexpat-2.5.0-hcb278e6_1 
  libffi             conda-forge/linux-64::libffi-3.4.2-h7f98852_5 
  libgcc-ng          conda-forge/linux-64::libgcc-ng-13.2.0-h807b86a_2 
  libgd              conda-forge/linux-64::libgd-2.3.3-h119a65a_9 
  libgfortran-ng     conda-forge/linux-64::libgfortran-ng-13.2.0-h69a702a_2 
  libgfortran5       conda-forge/linux-64::libgfortran5-13.2.0-ha4646dd_2 
  libglib            conda-forge/linux-64::libglib-2.78.0-hebfc3b9_0 
  libgomp            conda-forge/linux-64::libgomp-13.2.0-h807b86a_2 
  libiconv           conda-forge/linux-64::libiconv-1.17-h166bdaf_0 
  libjpeg-turbo      conda-forge/linux-64::libjpeg-turbo-3.0.0-hd590300_1 
  liblapack          conda-forge/linux-64::liblapack-3.9.0-19_linux64_openblas 
  libnghttp2         conda-forge/linux-64::libnghttp2-1.55.1-h47da74e_0 
  libnsl             conda-forge/linux-64::libnsl-2.0.1-hd590300_0 
  libopenblas        conda-forge/linux-64::libopenblas-0.3.24-pthreads_h413a1c8_0 
  libpng             conda-forge/linux-64::libpng-1.6.39-h753d276_0 
  librsvg            conda-forge/linux-64::librsvg-2.56.3-h98fae49_0 
  libsqlite          conda-forge/linux-64::libsqlite-3.43.2-h2797004_0 
  libssh2            conda-forge/linux-64::libssh2-1.11.0-h0841786_0 
  libstdcxx-ng       conda-forge/linux-64::libstdcxx-ng-13.2.0-h7e041cc_2 
  libtiff            conda-forge/linux-64::libtiff-4.6.0-ha9c0a0a_2 
  libtool            conda-forge/linux-64::libtool-2.4.7-h27087fc_0 
  libuuid            conda-forge/linux-64::libuuid-2.38.1-h0b41bf4_0 
  libwebp            conda-forge/linux-64::libwebp-1.3.2-h658648e_1 
  libwebp-base       conda-forge/linux-64::libwebp-base-1.3.2-hd590300_0 
  libxcb             conda-forge/linux-64::libxcb-1.15-h0b41bf4_0 
  libxml2            conda-forge/linux-64::libxml2-2.11.5-h232c23b_1 
  libzlib            conda-forge/linux-64::libzlib-1.2.13-hd590300_5 
  logical-unificati~ conda-forge/noarch::logical-unification-0.4.6-pyhd8ed1ab_0 
  matplotlib-base    conda-forge/linux-64::matplotlib-base-3.8.0-py311h54ef318_2 
  minikanren         conda-forge/noarch::minikanren-1.0.3-pyhd8ed1ab_0 
  multipledispatch   conda-forge/noarch::multipledispatch-0.6.0-py_0 
  munkres            conda-forge/noarch::munkres-1.1.4-pyh9f0ad1d_0 
  ncurses            conda-forge/linux-64::ncurses-6.4-h59595ed_2 
  numpy              conda-forge/linux-64::numpy-1.25.2-py311h64a7726_0 
  openjpeg           conda-forge/linux-64::openjpeg-2.5.0-h488ebb8_3 
  openssl            conda-forge/linux-64::openssl-3.1.4-hd590300_0 
  packaging          conda-forge/noarch::packaging-23.2-pyhd8ed1ab_0 
  pandas             conda-forge/linux-64::pandas-2.1.2-py311h320fe9a_0 
  pango              conda-forge/linux-64::pango-1.50.14-ha41ecd1_2 
  pcre2              conda-forge/linux-64::pcre2-10.40-hc3806b6_0 
  pillow             conda-forge/linux-64::pillow-10.1.0-py311ha6c5da5_0 
  pip                conda-forge/noarch::pip-23.3.1-pyhd8ed1ab_0 
  pixman             conda-forge/linux-64::pixman-0.42.2-h59595ed_0 
  pthread-stubs      conda-forge/linux-64::pthread-stubs-0.4-h36c2ea0_1001 
  pymc               conda-forge/noarch::pymc-5.9.1-hd8ed1ab_0 
  pymc-base          conda-forge/noarch::pymc-base-5.9.1-pyhd8ed1ab_0 
  pyparsing          conda-forge/noarch::pyparsing-3.1.1-pyhd8ed1ab_0 
  pytensor           pkgs/main/linux-64::pytensor-2.13.1-py311ha02d727_0 
  pytensor-base      conda-forge/linux-64::pytensor-base-2.17.3-py311h320fe9a_0 
  python             conda-forge/linux-64::python-3.11.6-hab00c5b_0_cpython 
  python-dateutil    conda-forge/noarch::python-dateutil-2.8.2-pyhd8ed1ab_0 
  python-graphviz    conda-forge/noarch::python-graphviz-0.20.1-pyh22cad53_0 
  python-tzdata      conda-forge/noarch::python-tzdata-2023.3-pyhd8ed1ab_0 
  python_abi         conda-forge/linux-64::python_abi-3.11-4_cp311 
  pytz               conda-forge/noarch::pytz-2023.3.post1-pyhd8ed1ab_0 
  readline           conda-forge/linux-64::readline-8.2-h8228510_1 
  scipy              conda-forge/linux-64::scipy-1.11.3-py311h64a7726_1 
  setuptools         conda-forge/noarch::setuptools-68.2.2-pyhd8ed1ab_0 
  six                conda-forge/noarch::six-1.16.0-pyh6c4a22f_0 
  tk                 conda-forge/linux-64::tk-8.6.13-h2797004_0 
  toolz              conda-forge/noarch::toolz-0.12.0-pyhd8ed1ab_0 
  typing-extensions  conda-forge/noarch::typing-extensions-4.8.0-hd8ed1ab_0 
  typing_extensions  conda-forge/noarch::typing_extensions-4.8.0-pyha770c72_0 
  tzdata             conda-forge/noarch::tzdata-2023c-h71feb2d_0 
  wheel              conda-forge/noarch::wheel-0.41.3-pyhd8ed1ab_0 
  xarray             conda-forge/noarch::xarray-2023.10.1-pyhd8ed1ab_0 
  xarray-einstats    conda-forge/noarch::xarray-einstats-0.6.0-pyhd8ed1ab_0 
  xorg-kbproto       conda-forge/linux-64::xorg-kbproto-1.0.7-h7f98852_1002 
  xorg-libice        conda-forge/linux-64::xorg-libice-1.1.1-hd590300_0 
  xorg-libsm         conda-forge/linux-64::xorg-libsm-1.2.4-h7391055_0 
  xorg-libx11        conda-forge/linux-64::xorg-libx11-1.8.7-h8ee46fc_0 
  xorg-libxau        conda-forge/linux-64::xorg-libxau-1.0.11-hd590300_0 
  xorg-libxdmcp      conda-forge/linux-64::xorg-libxdmcp-1.1.3-h7f98852_0 
  xorg-libxext       conda-forge/linux-64::xorg-libxext-1.3.4-h0b41bf4_2 
  xorg-libxrender    conda-forge/linux-64::xorg-libxrender-0.9.11-hd590300_0 
  xorg-renderproto   conda-forge/linux-64::xorg-renderproto-0.11.1-h7f98852_1002 
  xorg-xextproto     conda-forge/linux-64::xorg-xextproto-7.3.0-h0b41bf4_1003 
  xorg-xproto        conda-forge/linux-64::xorg-xproto-7.0.31-h7f98852_1007 
  xz                 conda-forge/linux-64::xz-5.2.6-h166bdaf_0 
  zlib               conda-forge/linux-64::zlib-1.2.13-hd590300_5 
  zstd               conda-forge/linux-64::zstd-1.5.5-hfc55251_0 


Proceed ([y]/n)? 

Here are the debug prints for 5.9.1 and 5.9 (in the 5.9.1 case, the environment was aligned with the 5.9 one by pinning some package versions).

The diff for the trace is empty, but for logp-dlogp it is:

diff logp-dlogp_5_9.txt logp-dlogp_5_9_1.txt

2257c2257
<     └─ exp [id BNS] 't4'
---
>     └─ exp [id BNS] 't1'
2263c2263
<     └─ exp [id BNS] 't4'
---
>     └─ exp [id BNS] 't1'
2267c2267
<     └─ exp [id BNS] 't4'
---
>     └─ exp [id BNS] 't1'
2290c2290
<     ├─ GE [id BOP] 't8'
---
>     ├─ GE [id BOP] 't14'
2300c2300
<     │  ├─ GE [id BOP] 't8'
---
>     │  ├─ GE [id BOP] 't14'

trace_5_9_1.txt (1.4 KB)
logp-dlogp_5_9_1.txt (203.0 KB)
trace_5_9.txt (1.4 KB)
logp-dlogp_5_9.txt (203.0 KB)
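Since the only differences in this diff are the quoted 't<N>' labels at the end of each line (the [id ...] identifiers match), a small normalizing diff can confirm that nothing else changed. This is a minimal sketch that assumes those labels are auto-generated names with no semantic meaning; the helper names are hypothetical, and it uses only the standard-library re and difflib modules.

```python
import difflib
import re

# Quoted labels such as 't4' or 't14' at the end of debugprint lines.
# Assumption: they are auto-generated names, so they are masked before diffing.
_AUTO_LABEL = re.compile(r"'t\d+'")


def normalize_debugprint(text):
    """Return the lines of a debugprint dump with 't<N>' labels masked."""
    return [_AUTO_LABEL.sub("'t#'", line) for line in text.splitlines()]


def semantic_diff(text_a, text_b, name_a="5.9", name_b="5.9.1"):
    """Return the unified-diff lines that remain after masking the labels."""
    return list(
        difflib.unified_diff(
            normalize_debugprint(text_a),
            normalize_debugprint(text_b),
            fromfile=name_a,
            tofile=name_b,
            lineterm="",
        )
    )
```

Calling semantic_diff on the contents of logp-dlogp_5_9.txt and logp-dlogp_5_9_1.txt should return an empty list if the labels are the only difference.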

It seems my conda configuration was the culprit behind this weird pytensor installation. My conda channel settings were:

channels:
  - conda-forge
  - bioconda
  - r
  - defaults
auto_activate_base: false
channel_priority: flexible

Changing flexible to strict makes conda install pytensor 2.17 from conda-forge (rather than pytensor 2.13 from defaults and pytensor-base 2.17 from conda-forge). I will try reinstalling the environment now and running the code.

Done: I reinstalled with strict, and the environments are now more uniform by default, but there is still a speed difference when running pm.sample(random_seed=10, init='advi+adapt_diag').

At least now I know the answer to this. It somehow happened because channel priority was set to flexible:

channels:
  - conda-forge
  - bioconda
  - r
  - defaults
auto_activate_base: false
channel_priority: flexible

which also resulted in a weird installation:

pytensor                  2.13.1          py311ha02d727_0  
pytensor-base             2.17.3          py311h320fe9a_0    conda-forge
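To spot this kind of split install without reading the whole package list, one option is to parse `conda list --json` and flag metapackage/base pairs whose version or channel disagree. A minimal sketch, assuming each JSON record has name, version and channel keys (as in recent conda releases); conda_packages and split_pairs are hypothetical names. The config-level fix described in this thread corresponds to `conda config --set channel_priority strict`.

```python
import json
import subprocess


def conda_packages(env_name):
    """Return {name: (version, channel)} for the packages in a conda env.

    Relies on `conda list -n ENV --json`, which prints a JSON list of records.
    """
    out = subprocess.run(
        ["conda", "list", "-n", env_name, "--json"],
        check=True,
        capture_output=True,
        text=True,
    ).stdout
    return {p["name"]: (p["version"], p.get("channel", "")) for p in json.loads(out)}


def split_pairs(packages, pairs=(("pytensor", "pytensor-base"), ("pymc", "pymc-base"))):
    """List metapackage/base pairs whose versions or channels disagree."""
    problems = []
    for meta, base in pairs:
        if meta in packages and base in packages:
            (v_meta, ch_meta), (v_base, ch_base) = packages[meta], packages[base]
            if v_meta != v_base or ch_meta != ch_base:
                problems.append((meta, v_meta, ch_meta, base, v_base, ch_base))
    return problems
```

For example, split_pairs(conda_packages("pymc_env_5_9_1_test6")) on the flexible-priority install would be expected to report the pytensor 2.13.1 / pytensor-base 2.17.3 pair.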

After changing that to strict, conda installs the environment as follows:

# packages in environment at /home/avicenna/miniconda3/envs/pymc_env_5_9_1_test6:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                 conda_forge    conda-forge
_openmp_mutex             4.5                  2_kmp_llvm    conda-forge
arviz                     0.16.1             pyhd8ed1ab_1    conda-forge
asttokens                 2.4.1              pyhd8ed1ab_0    conda-forge
atk-1.0                   2.38.0               hd4edc92_1    conda-forge
backports                 1.0                pyhd8ed1ab_3    conda-forge
backports.functools_lru_cache 1.6.5              pyhd8ed1ab_0    conda-forge
binutils_impl_linux-64    2.40                 hf600244_0    conda-forge
binutils_linux-64         2.40                 hbdbef99_2    conda-forge
blas                      2.119                  openblas    conda-forge
blas-devel                3.9.0           19_linux64_openblas    conda-forge
brotli                    1.1.0                hd590300_1    conda-forge
brotli-bin                1.1.0                hd590300_1    conda-forge
bzip2                     1.0.8                h7f98852_4    conda-forge
c-ares                    1.20.1               hd590300_1    conda-forge
ca-certificates           2023.7.22            hbcca054_0    conda-forge
cached-property           1.5.2                hd8ed1ab_1    conda-forge
cached_property           1.5.2              pyha770c72_1    conda-forge
cachetools                5.3.2              pyhd8ed1ab_0    conda-forge
cairo                     1.18.0               h3faef2a_0    conda-forge
certifi                   2023.7.22          pyhd8ed1ab_0    conda-forge
cloudpickle               3.0.0              pyhd8ed1ab_0    conda-forge
comm                      0.1.4              pyhd8ed1ab_0    conda-forge
cons                      0.4.6              pyhd8ed1ab_0    conda-forge
contourpy                 1.1.1           py311h9547e67_1    conda-forge
cycler                    0.12.1             pyhd8ed1ab_0    conda-forge
debugpy                   1.8.0           py311hb755f60_1    conda-forge
decorator                 5.1.1              pyhd8ed1ab_0    conda-forge
etuples                   0.3.9              pyhd8ed1ab_0    conda-forge
exceptiongroup            1.1.3              pyhd8ed1ab_0    conda-forge
executing                 2.0.1              pyhd8ed1ab_0    conda-forge
expat                     2.5.0                hcb278e6_1    conda-forge
fastprogress              1.0.3              pyhd8ed1ab_0    conda-forge
filelock                  3.13.1             pyhd8ed1ab_0    conda-forge
font-ttf-dejavu-sans-mono 2.37                 hab24e00_0    conda-forge
font-ttf-inconsolata      3.000                h77eed37_0    conda-forge
font-ttf-source-code-pro  2.038                h77eed37_0    conda-forge
font-ttf-ubuntu           0.83                 hab24e00_0    conda-forge
fontconfig                2.14.2               h14ed4e7_0    conda-forge
fonts-conda-ecosystem     1                             0    conda-forge
fonts-conda-forge         1                             0    conda-forge
fonttools                 4.43.1          py311h459d7ec_0    conda-forge
freetype                  2.12.1               h267a509_2    conda-forge
fribidi                   1.0.10               h36c2ea0_0    conda-forge
gcc                       12.3.0               h8d2909c_2    conda-forge
gcc_impl_linux-64         12.3.0               he2b93b0_2    conda-forge
gcc_linux-64              12.3.0               h76fc315_2    conda-forge
gdk-pixbuf                2.42.10              h829c605_4    conda-forge
gettext                   0.21.1               h27087fc_0    conda-forge
giflib                    5.2.1                h0b41bf4_3    conda-forge
graphite2                 1.3.13            h58526e2_1001    conda-forge
graphviz                  8.1.0                h28d9a01_0    conda-forge
gtk2                      2.24.33              h90689f9_2    conda-forge
gts                       0.7.6                h977cf35_4    conda-forge
gxx                       12.3.0               h8d2909c_2    conda-forge
gxx_impl_linux-64         12.3.0               he2b93b0_2    conda-forge
gxx_linux-64              12.3.0               h8a814eb_2    conda-forge
h5netcdf                  1.2.0              pyhd8ed1ab_0    conda-forge
h5py                      3.10.0          nompi_py311h3839ddf_100    conda-forge
harfbuzz                  8.2.1                h3d44ed6_0    conda-forge
hdf5                      1.14.2          nompi_h4f84152_100    conda-forge
icu                       73.2                 h59595ed_0    conda-forge
importlib-metadata        6.8.0              pyha770c72_0    conda-forge
importlib_metadata        6.8.0                hd8ed1ab_0    conda-forge
ipykernel                 6.26.0             pyhf8b6a83_0    conda-forge
ipython                   8.17.2             pyh41d4057_0    conda-forge
jedi                      0.19.1             pyhd8ed1ab_0    conda-forge
joblib                    1.3.2              pyhd8ed1ab_0    conda-forge
jupyter_client            8.5.0              pyhd8ed1ab_0    conda-forge
jupyter_core              5.5.0           py311h38be061_0    conda-forge
kernel-headers_linux-64   2.6.32              he073ed8_16    conda-forge
keyutils                  1.6.1                h166bdaf_0    conda-forge
kiwisolver                1.4.5           py311h9547e67_1    conda-forge
krb5                      1.21.2               h659d440_0    conda-forge
lcms2                     2.15                 hb7c19ff_3    conda-forge
ld_impl_linux-64          2.40                 h41732ed_0    conda-forge
lerc                      4.0.0                h27087fc_0    conda-forge
libaec                    1.1.2                h59595ed_1    conda-forge
libblas                   3.9.0           19_linux64_openblas    conda-forge
libbrotlicommon           1.1.0                hd590300_1    conda-forge
libbrotlidec              1.1.0                hd590300_1    conda-forge
libbrotlienc              1.1.0                hd590300_1    conda-forge
libcblas                  3.9.0           19_linux64_openblas    conda-forge
libcurl                   8.4.0                hca28451_0    conda-forge
libdeflate                1.19                 hd590300_0    conda-forge
libedit                   3.1.20191231         he28a2e2_2    conda-forge
libev                     4.33                 h516909a_1    conda-forge
libexpat                  2.5.0                hcb278e6_1    conda-forge
libffi                    3.4.2                h7f98852_5    conda-forge
libgcc-devel_linux-64     12.3.0               h8bca6fd_2    conda-forge
libgcc-ng                 13.2.0               h807b86a_2    conda-forge
libgd                     2.3.3                h119a65a_9    conda-forge
libgfortran-ng            13.2.0               h69a702a_2    conda-forge
libgfortran5              13.2.0               ha4646dd_2    conda-forge
libglib                   2.78.0               hebfc3b9_0    conda-forge
libgomp                   13.2.0               h807b86a_2    conda-forge
libhwloc                  2.9.3           default_h554bfaf_1009    conda-forge
libiconv                  1.17                 h166bdaf_0    conda-forge
libjpeg-turbo             3.0.0                hd590300_1    conda-forge
liblapack                 3.9.0           19_linux64_openblas    conda-forge
liblapacke                3.9.0           19_linux64_openblas    conda-forge
libnghttp2                1.55.1               h47da74e_0    conda-forge
libnsl                    2.0.1                hd590300_0    conda-forge
libopenblas               0.3.24          pthreads_h413a1c8_0    conda-forge
libpng                    1.6.39               h753d276_0    conda-forge
librsvg                   2.56.3               h98fae49_0    conda-forge
libsanitizer              12.3.0               h0f45ef3_2    conda-forge
libsodium                 1.0.18               h36c2ea0_1    conda-forge
libsqlite                 3.43.2               h2797004_0    conda-forge
libssh2                   1.11.0               h0841786_0    conda-forge
libstdcxx-devel_linux-64  12.3.0               h8bca6fd_2    conda-forge
libstdcxx-ng              13.2.0               h7e041cc_2    conda-forge
libtiff                   4.6.0                ha9c0a0a_2    conda-forge
libtool                   2.4.7                h27087fc_0    conda-forge
libuuid                   2.38.1               h0b41bf4_0    conda-forge
libwebp                   1.3.2                h658648e_1    conda-forge
libwebp-base              1.3.2                hd590300_0    conda-forge
libxcb                    1.15                 h0b41bf4_0    conda-forge
libxml2                   2.11.5               h232c23b_1    conda-forge
libzlib                   1.2.13               hd590300_5    conda-forge
llvm-openmp               17.0.4               h4dfa4b3_0    conda-forge
logical-unification       0.4.6              pyhd8ed1ab_0    conda-forge
matplotlib-base           3.8.0           py311h54ef318_2    conda-forge
matplotlib-inline         0.1.6              pyhd8ed1ab_0    conda-forge
minikanren                1.0.3              pyhd8ed1ab_0    conda-forge
mkl                       2022.2.1         h84fe81f_16997    conda-forge
mkl-service               2.4.0           py311hb711fc7_0    conda-forge
multipledispatch          0.6.0                      py_0    conda-forge
munkres                   1.1.4              pyh9f0ad1d_0    conda-forge
ncurses                   6.4                  h59595ed_2    conda-forge
nest-asyncio              1.5.8              pyhd8ed1ab_0    conda-forge
numpy                     1.25.2          py311h64a7726_0    conda-forge
openblas                  0.3.24          pthreads_h7a3da1a_0    conda-forge
openjpeg                  2.5.0                h488ebb8_3    conda-forge
openssl                   3.1.4                hd590300_0    conda-forge
packaging                 23.2               pyhd8ed1ab_0    conda-forge
pandas                    2.1.2           py311h320fe9a_0    conda-forge
pango                     1.50.14              ha41ecd1_2    conda-forge
parso                     0.8.3              pyhd8ed1ab_0    conda-forge
pcre2                     10.40                hc3806b6_0    conda-forge
pexpect                   4.8.0              pyh1a96a4e_2    conda-forge
pickleshare               0.7.5                   py_1003    conda-forge
pillow                    10.1.0          py311ha6c5da5_0    conda-forge
pip                       23.3.1             pyhd8ed1ab_0    conda-forge
pixman                    0.42.2               h59595ed_0    conda-forge
platformdirs              3.11.0             pyhd8ed1ab_0    conda-forge
prompt-toolkit            3.0.39             pyha770c72_0    conda-forge
prompt_toolkit            3.0.39               hd8ed1ab_0    conda-forge
psutil                    5.9.5           py311h459d7ec_1    conda-forge
pthread-stubs             0.4               h36c2ea0_1001    conda-forge
ptyprocess                0.7.0              pyhd3deb0d_0    conda-forge
pure_eval                 0.2.2              pyhd8ed1ab_0    conda-forge
pygments                  2.16.1             pyhd8ed1ab_0    conda-forge
pymc                      5.9.1                hd8ed1ab_0    conda-forge
pymc-base                 5.9.1              pyhd8ed1ab_0    conda-forge
pyparsing                 3.1.1              pyhd8ed1ab_0    conda-forge
pytensor                  2.17.3          py311hb755f60_0    conda-forge
pytensor-base             2.17.3          py311h320fe9a_0    conda-forge
python                    3.11.6          hab00c5b_0_cpython    conda-forge
python-dateutil           2.8.2              pyhd8ed1ab_0    conda-forge
python-graphviz           0.20.1             pyh22cad53_0    conda-forge
python-tzdata             2023.3             pyhd8ed1ab_0    conda-forge
python_abi                3.11                    4_cp311    conda-forge
pytz                      2023.3.post1       pyhd8ed1ab_0    conda-forge
pyzmq                     25.1.1          py311h34ded2d_2    conda-forge
readline                  8.2                  h8228510_1    conda-forge
scikit-learn              1.3.2           py311hc009520_1    conda-forge
scipy                     1.11.3          py311h64a7726_1    conda-forge
setuptools                68.2.2             pyhd8ed1ab_0    conda-forge
six                       1.16.0             pyh6c4a22f_0    conda-forge
spyder-kernels            2.4.4           unix_pyh707e725_0    conda-forge
stack_data                0.6.2              pyhd8ed1ab_0    conda-forge
sysroot_linux-64          2.12                he073ed8_16    conda-forge
tbb                       2021.10.0            h00ab1b0_2    conda-forge
threadpoolctl             3.2.0              pyha21a80b_0    conda-forge
tk                        8.6.13               h2797004_0    conda-forge
toolz                     0.12.0             pyhd8ed1ab_0    conda-forge
tornado                   6.3.3           py311h459d7ec_1    conda-forge
traitlets                 5.13.0             pyhd8ed1ab_0    conda-forge
typing-extensions         4.8.0                hd8ed1ab_0    conda-forge
typing_extensions         4.8.0              pyha770c72_0    conda-forge
tzdata                    2023c                h71feb2d_0    conda-forge
wcwidth                   0.2.9              pyhd8ed1ab_0    conda-forge
wheel                     0.41.3             pyhd8ed1ab_0    conda-forge
wurlitzer                 3.0.3              pyhd8ed1ab_0    conda-forge
xarray                    2023.10.1          pyhd8ed1ab_0    conda-forge
xarray-einstats           0.6.0              pyhd8ed1ab_0    conda-forge
xorg-kbproto              1.0.7             h7f98852_1002    conda-forge
xorg-libice               1.1.1                hd590300_0    conda-forge
xorg-libsm                1.2.4                h7391055_0    conda-forge
xorg-libx11               1.8.7                h8ee46fc_0    conda-forge
xorg-libxau               1.0.11               hd590300_0    conda-forge
xorg-libxdmcp             1.1.3                h7f98852_0    conda-forge
xorg-libxext              1.3.4                h0b41bf4_2    conda-forge
xorg-libxrender           0.9.11               hd590300_0    conda-forge
xorg-renderproto          0.11.1            h7f98852_1002    conda-forge
xorg-xextproto            7.3.0             h0b41bf4_1003    conda-forge
xorg-xproto               7.0.31            h7f98852_1007    conda-forge
xz                        5.2.6                h166bdaf_0    conda-forge
zeromq                    4.3.5                h59595ed_0    conda-forge
zipp                      3.17.0             pyhd8ed1ab_0    conda-forge
zlib                      1.2.13               hd590300_5    conda-forge
zstd                      1.5.5                hfc55251_0    conda-forge

Can you check if the following returns the same values in both versions?

model.compile_logp()(model.initial_point())
model.compile_dlogp()(model.initial_point())

It’s surprising that the compiled functions are so similar, apart from some input ordering, and yet you get such different results.
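One way to run this check across the two environments is to dump the values to a file in each one and compare the files afterwards. A minimal sketch, assuming `model` is the PyMC model from the script; snapshot_logp and compare_snapshots are hypothetical helpers built on Model.initial_point, Model.compile_logp and Model.compile_dlogp. The elementwise dlogp comparison assumes both versions order the value variables the same way.

```python
import json

import numpy as np


def snapshot_logp(model, path, random_seed=None):
    """Evaluate logp and dlogp at the model's initial point and save them as JSON."""
    point = model.initial_point(random_seed=random_seed)
    logp = model.compile_logp()(point)
    dlogp = model.compile_dlogp()(point)
    record = {
        "point": {name: np.asarray(v).tolist() for name, v in point.items()},
        "logp": float(logp),
        "dlogp": np.asarray(dlogp, dtype=float).ravel().tolist(),
    }
    with open(path, "w") as f:
        json.dump(record, f)
    return record


def compare_snapshots(path_a, path_b, rtol=1e-10, atol=0.0):
    """Return (same_point, same_logp, same_dlogp) for two saved snapshots."""
    with open(path_a) as f:
        a = json.load(f)
    with open(path_b) as f:
        b = json.load(f)
    same_point = a["point"].keys() == b["point"].keys() and all(
        np.allclose(a["point"][k], b["point"][k], rtol=rtol, atol=atol) for k in a["point"]
    )
    same_logp = bool(np.isclose(a["logp"], b["logp"], rtol=rtol, atol=atol))
    # Check lengths first so allclose never sees arrays of mismatched shape.
    same_dlogp = len(a["dlogp"]) == len(b["dlogp"]) and bool(
        np.allclose(a["dlogp"], b["dlogp"], rtol=rtol, atol=atol)
    )
    return same_point, same_logp, same_dlogp
```

For example, call snapshot_logp(model, "logp_5_9_0.json") in the 5.9.0 environment and snapshot_logp(model, "logp_5_9_1.json") in the 5.9.1 one (hypothetical file names), then run compare_snapshots on the two files from either environment.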

Ok. If you find this surprising, perhaps I should first nuke these environments and recreate them (given the channel priority issue) to test this more cleanly. I will get back to you shortly.

Thanks a lot @iavicenna for the thorough debugging work! It’s really helpful to know that the channel priority was responsible for the creation of this broken environment.

I managed to reproduce it locally. For everyone’s reference, passing the flags -c conda-forge --override-channels works around the problem for me without modifying the config. (This may help speed up debugging in the future.)
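For scripted reinstalls, this workaround can be wrapped in a small helper that builds the conda create command. This is a sketch; conda_create_cmd is a hypothetical name, and only the standard conda flags -n, -c, --override-channels and --dry-run are used. With --override-channels, conda ignores the channels listed in .condarc, so the channel_priority setting should no longer affect where pytensor comes from.

```python
def conda_create_cmd(env_name, specs, channel="conda-forge", dry_run=False):
    """Build a `conda create` argument list restricted to a single channel.

    -c CHANNEL plus --override-channels makes conda search only CHANNEL,
    ignoring the channel list (and thus the priority mix) from .condarc.
    """
    cmd = ["conda", "create", "-n", env_name, "-c", channel, "--override-channels"]
    if dry_run:
        cmd.append("--dry-run")  # report the solve without installing anything
    cmd.extend(specs)
    return cmd
```

For example, subprocess.run(conda_create_cmd("pymc_env_5_9_1_clean", ["pymc=5.9.1"]), check=True) would create a hypothetical pymc_env_5_9_1_clean environment; passing dry_run=True first shows which channel each package would come from.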

Ok, so I nuked the environments and reinstalled them. Package info for the pymc=5.8.2 environment and its comparison to 5.9.0 and 5.9.1 are attached below. The comparison is now very clean: only the pytensor and pymc versions differ between environments (and within each environment, pytensor and pytensor-base have the same version!).

Next, I ran a set of tests in which I call the script (added at the end) from each of these three environments in the terminal (I did not use Spyder this time, and I also printed pm.__version__ to make sure I am using the version I think I am using). I ran 5.9.1 with three different seeds to rule out a seed issue, and the others with one seed. As before, 5.8.2 and 5.9.0 run to completion (in similar times this round, so the difference I saw between them previously may just have been a seed issue). 5.9.1 still gets stuck indefinitely after a while, so I stopped it after some time. Here are the logs (note that in 5.8.2 and 5.9.0, ADVI terminates before 100% when it converges):

pymc version: 5.9.1, random_seed: 5
using advi: True
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
^CInterrupted at 327 [0%]: Average Loss = 3,869.3-----------------------------------------------------| 0.16% [327/200000 01:11<12:09:34 Average Loss = 3,870.8]

pymc version: 5.9.1, random_seed: 111
using advi: True
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
^CInterrupted at 420 [0%]: Average Loss = 3,738.8-----------------------------------------------------| 0.21% [420/200000 01:23<11:04:16 Average Loss = 3,748.2]

pymc version: 5.9.1, random_seed: 1111
using advi: True
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
^CInterrupted at 98 [0%]: Average Loss = 3,830.1------------------------------------------------------| 0.05% [98/200000 00:21<11:59:44 Average Loss = 3,868.2]


pymc version: 5.8.2, random_seed: 1111
using advi: True
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
Convergence achieved at 59400██-----------------------------------------------------------------------| 29.67% [59346/200000 03:01<07:10 Average Loss = 1,038.5]
Interrupted at 59,399 [29%]: Average Loss = 1,824.4
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 665 seconds.[8000/8000 11:04<00:00 Sampling 4 chains, 0 divergences]

pymc version: 5.9.0, random_seed: 1111
using advi: True
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
Convergence achieved at 59400██-----------------------------------------------------------------------| 29.68% [59354/200000 02:55<06:56 Average Loss = 1,038.5]
Interrupted at 59,399 [29%]: Average Loss = 1,824.4
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 657 seconds.[8000/8000 10:57<00:00 Sampling 4 chains, 0 divergences]

Next, to make sure this was not an ADVI issue, I ran the same code with plain pm.sample().
With 5.8.2 and 5.9.0, sampling initially starts with an estimate of ~1 hr, but it speeds up considerably after about 10 minutes and finishes in about 26 minutes according to the logs below (this could be worth checking for people who complain that Gaussian mixtures are too slow: try ADVI initialization, and also let the sampler run to the end).
However, 5.9.1 again gets stuck indefinitely:

pymc version: 5.8.2, random_seed: 1111
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1563 seconds.000/8000 26:02<00:00 Sampling 4 chains, 49 divergences]
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 49 divergences after tuning. Increase `target_accept` or reparameterize.

pymc version: 5.9.1, random_seed: 1111
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
^CTraceback (most recent call last):-----------------------------------------------------------| 0.95% [76/8000 06:01<10:27:32 Sampling 4 chains, 0 divergences]

pymc version: 5.9.0, random_seed: 1111
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 1570 seconds.000/8000 26:09<00:00 Sampling 4 chains, 49 divergences]
The rhat statistic is larger than 1.01 for some parameters. This indicates problems during sampling. See https://arxiv.org/abs/1903.08008 for details
The effective sample size per chain is smaller than 100 for some parameters.  A higher number is needed for reliable rhat and ess computation. See https://arxiv.org/abs/1903.08008 for details
There were 49 divergences after tuning. Increase `target_accept` or reparameterize.

Now, going back to what @ricardoV94 asked, I did some debug prints to compare 5.9.0 and 5.9.1. The traces are still identical, but this time the logp-dlogp graphs are quite different, so I attached the diff as logp_diff.txt. In both cases the model.compile_logp() values are identical (using the same seed):

pymc version: 5.9.0, random_seed: 1111
-3115.2765648137506
[   1.5396188   108.35295371  144.87835984  112.74718936   -1.31499805
 -125.21506787 -199.13245673 -205.10570647 -125.49660322   10.42051968
   25.78704799   20.52551608  -17.55547322  -39.17761053    7.88794009
    3.48514039  -42.51867365   -1.21457135  -29.03688925   -9.1239996
  -49.67507295    0.39696196   13.55134652    6.4564686 ]

pymc version: 5.9.1, random_seed: 1111
-3115.2765648137506
[   1.5396188   108.35295371  144.87835984  112.74718936   -1.31499805
 -125.21506787 -199.13245673 -205.10570647 -125.49660322   10.42051968
   25.78704799   20.52551608  -17.55547322  -39.17761053    7.88794009
    3.48514039  -42.51867365   -1.21457135  -29.03688925   -9.1239996
  -49.67507295    0.39696196   13.55134652    6.4564686 ]

Happy to run more diagnostics should you have any in mind (I am going to try setting the threading environment variables and see if that makes a difference too).
Finally, the code:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
import pytensor.tensor as ptt
from sklearn.datasets import make_blobs
from sklearn.preprocessing import StandardScaler

n_clusters = 5
data, labels = make_blobs(n_samples=1000, centers=n_clusters, random_state=10)

scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
plt.scatter(*scaled_data.T, c=labels)
coords={"cluster": np.arange(n_clusters),
        "obs_id": np.arange(data.shape[0]),
        "coord":['x', 'y']}

# When you use the ordered transform, the initial values need to be
# monotonically increasing
sorted_initvals = np.linspace(-2, 2, n_clusters)

pymc_version = pm.__version__

if pymc_version in ["5.8.2", "5.9.0", "5.9.1"]:
  trans = pm.distributions.transforms.ordered
else:
  raise ValueError(f"Unknown pymc version {pymc_version}")

random_seed = 1111
use_advi = False
print(f"pymc version: {pymc_version}, random_seed: {random_seed}")
print(f"using advi: {use_advi}")


with pm.Model(coords=coords) as m:
    # Use alpha > 1 to prevent the model from finding sparse solutions -- we know
    # all 5 clusters should be represented in the posterior
    w = pm.Dirichlet("w", np.full(n_clusters, 10), dims=['cluster'])

    # Mean component
    x_coord = pm.Normal("x_coord", sigma=1, dims=["cluster"],
                        transform=trans, initval=sorted_initvals)

    y_coord = pm.Normal('y_coord', sigma=1, dims=['cluster'])
    centroids = pm.Deterministic('centroids', ptt.concatenate([x_coord[None], y_coord[None]]).T,
                                dims=['cluster', 'coord'])

    # Diagonal covariances. Could also model the full covariance matrix, but I didn't try.
    sigma = pm.HalfNormal('sigma', sigma=1, dims=['cluster', 'coord'])
    covs = [ptt.diag(sigma[i]) for i in range(n_clusters)]

    # Define the mixture
    components = [pm.MvNormal.dist(mu=centroids[i], cov=covs[i]) for i in range(n_clusters)]
    y_hat = pm.Mixture("y_hat",
                       w,
                       components,
                       observed=scaled_data,
                       dims=["obs_id", 'coord'])

    if use_advi:
      idata = pm.sample(init='advi+adapt_diag', random_seed=random_seed)
    else:
      idata = pm.sample(random_seed=random_seed)

packages.txt (13.0 KB)
logp_diff.txt (342.7 KB)
trace_5_9.txt (1.4 KB)
logp-dlogp_5.9.0.txt (200.8 KB)
trace_5_9.txt (1.4 KB)
logp-dlogp_5.9.1.txt (203.0 KB)

I can't claim to understand them, but I was eyeballing the logp-dlogp graphs to see where the essential-looking differences occur. For instance, I have seen things like this:
5.9.0

 │     │     │  └─ Join [id BX] 229
 │     │     │     ├─ 1 [id BY]
 │     │     │     ├─ Composite{switch(i2, ((-1.8378770664093453 - (0.5 * i0)) - i1), -inf)} [id BZ] 227
 │     │     │     │  ├─ ExpandDims{axis=1} [id CA] 221
 │     │     │     │  │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id CB] 213
 │     │     │     │  │     └─ Transpose{axes=[1, 0]} [id CC] 205
 │     │     │     │  │        └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id CD] 197
 │     │     │     │  │           ├─ Switch [id CE] 156

vs 5.9.1

 │     │     │  └─ Join [id BX] 238
 │     │     │     ├─ 1 [id BY]
 │     │     │     ├─ Composite{switch(i2, ((-1.8378770664093453 - (0.5 * i0)) - i1), -inf)} [id BZ] 236
 │     │     │     │  ├─ ExpandDims{axis=1} [id CA] 228
 │     │     │     │  │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id CB] 219
 │     │     │     │  │     └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id CC] 207
 │     │     │     │  │        ├─ Switch [id CD] 166
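The Composite at the top of both snippets can be read as follows. The constant -1.8378770664093453 equals -log(2π), which is the -(k/2) log(2π) term of a multivariate normal log-density when there are k = 2 coordinates. That suggests i0 is the squared norm of the triangular solve (the CAReduce{i0 + sqr(i1)}) and i1 is the sum of the log-diagonal of the Cholesky factor. This interpretation is an assumption based on the node structure.

The NumPy sketch below computes the log-density that way and checks it against the dense-covariance formula. A plain forward-substitution loop stands in for SolveTriangular. The function name and test values are hypothetical.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)  # 1.8378770664093453, the constant in the Composite node


def mvnormal_logp_chol(x, mu, chol):
    """Log-density of N(mu, chol @ chol.T) for each row of x, via a triangular solve.

    x: (n, k) observations; mu: (k,) mean; chol: (k, k) lower-triangular factor.
    """
    k = chol.shape[0]
    resid = (x - mu).T  # (k, n): every observation is a column of one right-hand side
    z = np.zeros_like(resid)
    for i in range(k):  # forward substitution, i.e. solve chol @ z = resid
        z[i] = (resid[i] - chol[i, :i] @ z[:i]) / chol[i, i]
    quad = np.sum(z**2, axis=0)  # i0: sum of squares over the k coordinates
    half_logdet = np.sum(np.log(np.diag(chol)))  # i1: 0.5 * log det(cov)
    return -0.5 * k * LOG_2PI - 0.5 * quad - half_logdet


rng = np.random.default_rng(1111)
x = rng.normal(size=(5, 2))
mu = np.array([0.5, -1.0])
chol = np.array([[1.2, 0.0], [0.3, 0.8]])

# Reference value from the dense covariance, without any triangular solve
cov = chol @ chol.T
diff = x - mu
quad_ref = np.einsum("ij,jk,ik->i", diff, np.linalg.inv(cov), diff)
ref = -0.5 * (2 * LOG_2PI + np.linalg.slogdet(cov)[1] + quad_ref)

print(np.allclose(mvnormal_logp_chol(x, mu, chol), ref))  # True
```

This follows the b_ndim=2 layout of the 5.9.0 graph: one solve whose right-hand side holds all observations.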

I would just focus on 5.9 vs 5.9.1 because those are pretty similar and, importantly, actually rely on the same versions of PyTensor.

Apologies, my mistake: I did mean 5.9.0 vs 5.9.1. You can look at the attached logp-dlogp files (the file names are derived automatically from pm.__version__ to prevent any confusion). The difference I displayed above is between 5.9.0 and 5.9.1. Most of the differences are just ID differences, but in some cases there seems to be a structural difference, as above.

I have also seen the following:

5.9.0
 │     │     │     │  │           │  ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id EE] 73

5.9.1
│     │     │     │  │        │  ├─ ExpandDims{axis=0} [id EJ] 91
│     │     │     │  │        │  │  └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id EF] 73

but most of the differences are of the form

5.9.0
 │  │  │  │  │     │  │     │  │     │     └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id KR] 295

5.9.1
 │  │  │  │  │     │  │     │  │     │     └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id KS] 294
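The pattern above replaces two forms of the same computation:

- **5.9.0:** SolveTriangular with b_ndim=2, one solve whose right-hand side holds all observations as columns.
- **5.9.1:** Blockwise{SolveTriangular} with b_ndim=1, one solve per observation, batched over the 1000 rows.

Both compute the same numbers. My understanding, which I have not verified for this PyTensor version, is that a Blockwise op without a specialized implementation evaluates the core op once per batch entry in a Python-level loop. If so, each logp/gradient evaluation could become far more expensive.

The NumPy sketch below only illustrates the difference. np.linalg.solve stands in for the triangular solve, and the sizes mirror this model: 2 coordinates and 1000 observations.

```python
import time

import numpy as np

rng = np.random.default_rng(0)
k, n = 2, 1000  # k coordinates, n observations (the sizes used in this thread)
chol = np.array([[1.5, 0.0], [0.4, 0.7]])  # an arbitrary lower-triangular factor
resid = rng.normal(size=(n, k))  # stands in for (x - mu), one row per observation

# b_ndim=2 style: a single solve with all observations as right-hand-side columns
t0 = time.perf_counter()
batched = np.linalg.solve(chol, resid.T).T  # shape (n, k)
t_batched = time.perf_counter() - t0

# b_ndim=1 style with an assumed Python-level loop: one small solve per observation
t0 = time.perf_counter()
looped = np.stack([np.linalg.solve(chol, r) for r in resid])  # shape (n, k)
t_looped = time.perf_counter() - t0

print(np.allclose(batched, looped))  # same numbers either way
print(f"one solve: {t_batched:.2e} s, {n} solves: {t_looped:.2e} s")
```

The timings depend on the machine. The point is only that the looped form does n separate calls where the batched form does one.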

Another big difference seems to be here:

5.9.0

│  │  │  │  │  │  │  │  │        └─ Composite{((i0 * i1) - i2)} [id MA] 580
│  │  │  │  │  │  │  │  │           ├─ Transpose{axes=[1, 0]} [id MB] 388
│  │  │  │  │  │  │  │  │           │  └─ dot [id MC] 380
│  │  │  │  │  │  │  │  │           │     ├─ Switch [id LS] 176
│  │  │  │  │  │  │  │  │           │     │  └─ ···
│  │  │  │  │  │  │  │  │           │     └─ Composite{switch(i4, (i3 + switch(i2, (-1.0 * i0 * i1), 0.0)), 1)} [id MD] 373
│  │  │  │  │  │  │  │  │           │        ├─ dot [id ME] 298
│  │  │  │  │  │  │  │  │           │        │  ├─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id JL] 287
│  │  │  │  │  │  │  │  │           │        │  │  └─ ···
│  │  │  │  │  │  │  │  │           │        │  └─ Transpose{axes=[1, 0]} [id CC] 205
│  │  │  │  │  │  │  │  │           │        │     └─ ···
│  │  │  │  │  │  │  │  │           │        ├─ Tri{dtype='float64'} [id MF] 366
│  │  │  │  │  │  │  │  │           │        │  ├─ Subtensor{i} [id MG] 354
│  │  │  │  │  │  │  │  │           │        │  │  ├─ Subtensor{start:} [id MH] 342
│  │  │  │  │  │  │  │  │           │        │  │  │  ├─ Shape [id MI] 330
│  │  │  │  │  │  │  │  │           │        │  │  │  │  └─ Neg [id MJ] 314
│  │  │  │  │  │  │  │  │           │        │  │  │  │     └─ dot [id ME] 298
│  │  │  │  │  │  │  │  │           │        │  │  │  │        └─ ···
│  │  │  │  │  │  │  │  │           │        │  │  │  └─ ScalarFromTensor [id MK] 12
│  │  │  │  │  │  │  │  │           │        │  │  │     └─ -2 [id ML]
│  │  │  │  │  │  │  │  │           │        │  │  └─ ScalarFromTensor [id MM] 13
│  │  │  │  │  │  │  │  │           │        │  │     └─ 0 [id MN]
│  │  │  │  │  │  │  │  │           │        │  ├─ Subtensor{i} [id MO] 353
│  │  │  │  │  │  │  │  │           │        │  │  ├─ Subtensor{start:} [id MH] 342
│  │  │  │  │  │  │  │  │           │        │  │  │  └─ ···
│  │  │  │  │  │  │  │  │           │        │  │  └─ ScalarFromTensor [id MP] 11
│  │  │  │  │  │  │  │  │           │        │  │     └─ 1 [id MQ]
│  │  │  │  │  │  │  │  │           │        │  └─ 0 [id E]

5.9.1
 └─ Composite{((i0 * i1) - i2)} [id MB] 604
 │  │  │  │  │  │  │  │  │           ├─ Transpose{axes=[1, 0]} [id MC] 413
 │  │  │  │  │  │  │  │  │           │  └─ dot [id MD] 404
 │  │  │  │  │  │  │  │  │           │     ├─ Switch [id LS] 167
 │  │  │  │  │  │  │  │  │           │     │  └─ ···
 │  │  │  │  │  │  │  │  │           │     └─ Composite{switch(i3, (i2 + switch(i1, (-i0), 0.0)), 1)} [id ME] 397
 │  │  │  │  │  │  │  │  │           │        ├─ Sum{axis=0} [id MF] 391
 │  │  │  │  │  │  │  │  │           │        │  └─ Mul [id MG] 384
 │  │  │  │  │  │  │  │  │           │        │     ├─ Blockwise{dot, (i00,i01),(i10,i11)->(o00,o01)} [id MH] 313
 │  │  │  │  │  │  │  │  │           │        │     │  ├─ ExpandDims{axis=2} [id MI] 297
 │  │  │  │  │  │  │  │  │           │        │     │  │  └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id JQ] 286
 │  │  │  │  │  │  │  │  │           │        │     │  │     └─ ···
 │  │  │  │  │  │  │  │  │           │        │     │  └─ ExpandDims{axis=1} [id MJ] 220
 │  │  │  │  │  │  │  │  │           │        │     │     └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id CC] 207
 │  │  │  │  │  │  │  │  │           │        │     │        └─ ···
 │  │  │  │  │  │  │  │  │           │        │     └─ Blockwise{Tri{dtype='float64'}, (),(),()->(o00,o01)} [id MK] 377
 │  │  │  │  │  │  │  │  │           │        │        ├─ Blockwise{Subtensor{i}, (i00),()->()} [id ML] 365
 │  │  │  │  │  │  │  │  │           │        │        │  ├─ Blockwise{Subtensor{start:}, (i00),()->(o00)} [id MM] 353
 │  │  │  │  │  │  │  │  │           │        │        │  │  ├─ Blockwise{Shape, (i00,i01)->(o00)} [id MN] 341
 │  │  │  │  │  │  │  │  │           │        │        │  │  │  └─ Neg [id MO] 329
 │  │  │  │  │  │  │  │  │           │        │        │  │  │     └─ Blockwise{dot, (i00,i01),(i10,i11)->(o00,o01)} [id MH] 313
 │  │  │  │  │  │  │  │  │           │        │        │  │  │        └─ ···
 │  │  │  │  │  │  │  │  │           │        │        │  │  └─ [-2] [id MP]
 │  │  │  │  │  │  │  │  │           │        │        │  └─ [0] [id MQ]
 │  │  │  │  │  │  │  │  │           │        │        ├─ Blockwise{Subtensor{i}, (i00),()->()} [id MR] 364
 │  │  │  │  │  │  │  │  │           │        │        │  ├─ Blockwise{Subtensor{start:}, (i00),()->(o00)} [id MM] 353
 │  │  │  │  │  │  │  │  │           │        │        │  │  └─ ···
 │  │  │  │  │  │  │  │  │           │        │        │  └─ [1] [id MS]
 │  │  │  │  │  │  │  │  │           │        │        └─ [0] [id CO]

This is also different:

5.9.0
 │  │  │  │  │  │  │  └─ Subtensor{i} [id YM] 822
 │  │  │  │  │  │  │     ├─ Shape [id YN] 796
 │  │  │  │  │  │  │     │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id YO] 764
 │  │  │  │  │  │  │     │     └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id YP] 725

5.9.1
├─ ARange{dtype='int64'} [id ZY] 865
│  │  │  │  │  │  │  ├─ 0 [id E]
│  │  │  │  │  │  │  ├─ Subtensor{i} [id ZZ] 839
│  │  │  │  │  │  │  │  ├─ Shape [id BAA] 813
│  │  │  │  │  │  │  │  │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id BAB] 781
│  │  │  │  │  │  │  │  │     └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False,

There is also one part where dot was replaced with Sum and Mul (I guess in the end they compute the same thing, but still…), and it is followed by something called Blockwise{dot}, which seemed strange. The inputs of the switch Composite also look somewhat different (i4, i3, i2, i1 vs i3, i2, i1).
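Assume the two operands are the triangular-solve results, each of shape (n, m), with n observations and m = 2 coordinates. Then the two versions differ as follows:

- **5.9.0:** takes a single matrix product and multiplies it by a lower-triangular mask, the Tri input of the Composite.
- **5.9.1:** uses Blockwise{dot} to build one m×m outer product per observation, multiplies by the mask (itself a Blockwise{Tri} there), and sums over observations.

The NumPy sketch below checks that the two forms agree. The array values are arbitrary and np.tri stands in for the Tri op. This would also explain why the 5.9.1 Composite has one input fewer: the mask factor (i1 in -1.0 * i0 * i1) has moved out into the Mul node.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 1000, 2  # observations, coordinates
a = rng.normal(size=(n, m))  # stands in for one triangular-solve result
b = rng.normal(size=(n, m))  # stands in for the other one
mask = np.tri(m)  # lower-triangular ones, like Tri{dtype='float64'}

# 5.9.0 style: one (m, n) @ (n, m) product, then the mask (i0 * i1 in the Composite)
old = (a.T @ b) * mask

# 5.9.1 style: Blockwise{dot} of (n, m, 1) and (n, 1, m) gives n outer products,
# then Mul by the mask and Sum{axis=0} over observations
outer = np.matmul(a[:, :, None], b[:, None, :])  # shape (n, m, m)
new = (outer * mask).sum(axis=0)

print(outer.shape, np.allclose(old, new))  # (1000, 2, 2) True
```

The results match, but the second form materializes an (n, m, m) intermediate instead of contracting over the observations in one product.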

5.9.0
└─ Composite{switch(i4, (i3 + switch(i2, (-1.0 * i0 * i1), 0.0)), 1)} [id BAN] 377
 │  │  │  │  │           │        ├─ dot [id BAO] 307
 │  │  │  │  │           │        │  ├─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id KK] 293
 │  │  │  │  │           │        │  │  └─ ···
 │  │  │  │  │           │        │  └─ Transpose{axes=[1, 0]} [id EQ] 202
 │  │  │  │  │           │        │     └─ ···

5.9.1
└─ Composite{switch(i3, (i2 + switch(i1, (-i0), 0.0)), 1)} [id BBD] 402
 │  │  │  │  │           │        ├─ Sum{axis=0} [id BBE] 395
 │  │  │  │  │           │        │  └─ Mul [id BBF] 388
 │  │  │  │  │           │        │     ├─ Blockwise{dot, (i00,i01),(i10,i11)->(o00,o01)} [id BBG] 322
 │  │  │  │  │           │        │     │  ├─ ExpandDims{axis=2} [id BBH] 306
 │  │  │  │  │           │        │     │  │  └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id KM] 292

Possibly another related difference (note All{axis=0} vs All{axes=None}):

5.9.0
│     │     │     │  │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id CB] 213
│     │     │     │  │     └─ Transpose{axes=[1, 0]} [id CC] 205
│     │     │     │  │        └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id CD] 197
│     │     │     │  │           ├─ Switch [id CE] 156
│     │     │     │  │           │  ├─ ExpandDims{axes=[0, 1]} [id CF] 138
│     │     │     │  │           │  │  └─ All{axes=None} [id CG] 120
│     │     │     │  │           │  │     └─ Gt [id CH] 101

5.9.1
│     │     │     │  │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id CB] 219
│     │     │     │  │     └─ Blockwise{SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=1}, (m,m),(m)->(m)} [id CC] 207
│     │     │     │  │        ├─ Switch [id CD] 166
│     │     │     │  │        │  ├─ ExpandDims{axes=[0, 1, 2]} [id CE] 145
│     │     │     │  │        │  │  └─ All{axis=0} [id CF] 125
│     │     │     │  │        │  │     └─ Gt [id CG] 106

Perhaps these patterns are more obvious in the inner graphs. When I look at those, I see this kind of difference in multiple places (and these seem to be the only kinds of differences):

5.9.0
Composite{switch(i4, (i3 + switch(i2, (-1.0 * i0 * i1), 0.0)), 1)} [id MD]
 ← Switch [id BWN] 'o0'
    ├─ i4 [id BWO]
    ├─ add [id BWP]
    │  ├─ i3 [id BWQ]
    │  └─ Switch [id BWR]
    │     ├─ i2 [id BWS]
    │     ├─ mul [id BWT]
    │     │  ├─ -1.0 [id BWU]
    │     │  ├─ i0 [id BWV]
    │     │  └─ i1 [id BWW]
    │     └─ 0.0 [id BWX]
    └─ 1 [id BWY]
5.9.1
Composite{switch(i3, (i2 + switch(i1, (-i0), 0.0)), 1)} [id ME]
 ← Switch [id BXL] 'o0'
    ├─ i3 [id BXM]
    ├─ add [id BXN]
    │  ├─ i2 [id BXO]
    │  └─ Switch [id BXP]
    │     ├─ i1 [id BXQ]
    │     ├─ neg [id BXR]
    │     │  └─ i0 [id BXS]
    │     └─ 0.0 [id BXT]
    └─ 1 [id BXU]
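The two scalar Composites can be compared directly. Assume the 5.9.1 i0 is the already-masked product (5.9.0's i0 * i1) and the remaining inputs shift down by one. Then both Composites compute the same thing; -i0 and -1.0 * i0 * i1 differ only in where the mask multiplication happens.

The NumPy sketch below writes both as np.where expressions (switch maps to np.where) and checks this on random inputs. The function names are hypothetical.

```python
import numpy as np


def composite_590(i0, i1, i2, i3, i4):
    # switch(i4, (i3 + switch(i2, (-1.0 * i0 * i1), 0.0)), 1)
    return np.where(i4, i3 + np.where(i2, -1.0 * i0 * i1, 0.0), 1)


def composite_591(i0, i1, i2, i3):
    # switch(i3, (i2 + switch(i1, (-i0), 0.0)), 1)
    return np.where(i3, i2 + np.where(i1, -i0, 0.0), 1)


rng = np.random.default_rng(0)
shape = (4, 2, 2)
dot_part = rng.normal(size=shape)  # 5.9.0's i0
mask = rng.integers(0, 2, size=shape).astype(float)  # 5.9.0's i1, a 0/1 mask
cond_inner = rng.random(shape) > 0.3  # i2 in 5.9.0, i1 in 5.9.1
addend = rng.normal(size=shape)  # i3 in 5.9.0, i2 in 5.9.1
cond_outer = rng.random(shape) > 0.3  # i4 in 5.9.0, i3 in 5.9.1

old = composite_590(dot_part, mask, cond_inner, addend, cond_outer)
new = composite_591(dot_part * mask, cond_inner, addend, cond_outer)
print(np.allclose(old, new))  # True
```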

For your info, I also did a diff after cleaning out the IDs, which might be more useful, so I am adding it here. It is between 5.9.0 and 5.9.1:

clean_dif.txt (64.1 KB)
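The thread doesn't show the script used for the ID-cleaned diff. A hypothetical way to do it with only the standard library is:

1. Strip the `[id XX]` tags and the index number printed after them.
2. Compare the cleaned lines with difflib.

The regex assumes IDs are runs of capital letters, as in the dprint output above. The function names are illustrative.

```python
import difflib
import re

# "[id BX]" optionally followed by the index number, e.g. "Join [id BX] 229" -> "Join"
ID_RE = re.compile(r"\s*\[id [A-Z]+\](\s+\d+)?")


def clean_line(line):
    """Drop node IDs and the trailing index so that only the graph structure remains."""
    return ID_RE.sub("", line).rstrip()


def clean_diff(old_lines, new_lines, old_name="5.9.0", new_name="5.9.1"):
    """Unified diff of two dprint outputs after removing IDs and indices."""
    old = [clean_line(line) for line in old_lines]
    new = [clean_line(line) for line in new_lines]
    return list(difflib.unified_diff(old, new, old_name, new_name, lineterm=""))


old = ["Join [id BX] 229", "├─ 1 [id BY]", "Transpose{axes=[1, 0]} [id CC] 205"]
new = ["Join [id BX] 238", "├─ 1 [id BY]", "Blockwise{dot, (i00,i01),(i10,i11)->(o00,o01)} [id MH] 313"]
print("\n".join(clean_diff(old, new)))
```

Lines that differ only by ID or index drop out of the diff. Constants such as `1` survive because the index is only removed when it directly follows an `[id ...]` tag.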

I have done some more tests (mostly in 5.9.1, one in 5.9.0) where I changed some parameters in the code:

With ADVI (I did not run these to completion):

1- nclusters=2 and nsamples=100: the ADVI init progress bar gives an initial estimate of ~40 min.
2- nclusters=2 and nsamples=1000: the initial ADVI estimate is about 6 hours.
3- nclusters=5 and nsamples=1000: ADVI starts with an initial estimate of 15 hours.

So at least the increase is consistent!
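A quick check of the ADVI estimates listed above:

- Going from 100 to 1000 observations multiplies the estimate by 9.
- Going from 2 to 5 clusters multiplies it by 2.5.

So the cost looks roughly proportional to the number of observations times the number of clusters. Three progress-bar estimates are thin evidence, so treat this only as a rough consistency check.

```python
# Initial ADVI ETAs from the runs above, in minutes, keyed by (n_clusters, n_samples)
eta = {(2, 100): 40, (2, 1000): 6 * 60, (5, 1000): 15 * 60}

samples_factor = eta[(2, 1000)] / eta[(2, 100)]  # 10x the observations
clusters_factor = eta[(5, 1000)] / eta[(2, 1000)]  # 2.5x the clusters

print(f"10x observations -> {samples_factor:.1f}x ETA")  # 9.0x
print(f"2.5x clusters -> {clusters_factor:.1f}x ETA")  # 2.5x
```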

Without ADVI:
1- nclusters=2 and nsamples=100: sampling starts with an initial estimate of 20 minutes; however, after about 1 minute it speeds up considerably and finishes in around 6 minutes.
2- nclusters=5 and nsamples=100: the estimate gets progressively worse, up to about 2 hours, and then sampling speeds up and finishes in around 33 minutes.
3- nclusters=5 and nsamples=1000: this one starts with an initial estimate of around 10 hours, gets progressively worse and then gets stuck at some point. This is what I had before I stopped it:

pymc version: 5.9.1, random_seed: 1111, nclusters 5, n_samples: 1000
using advi: False
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
^CTraceback (most recent call last):----------------------------------------------------------| 2.75% [220/8000 42:21<24:57:59 Sampling 4 chains, 0 divergences]

As a test, I ran the third configuration again in PyMC 5.9.0. It starts off with an estimate of an hour and a half, but picks up speed after some time and finishes in around 22 min. As before, the only difference between these environments is the PyMC version. Then I added the following at the top of the script (as suggested on the GitHub issue page):

import os
# Set these before numpy/pytensor/pymc are imported: the BLAS libraries read them when they load
os.environ['OPENBLAS_NUM_THREADS'] = '1'
os.environ['MKL_NUM_THREADS'] = '1'

It still gets stuck. I also tried exporting these variables from bash, just in case, and setting them as conda env variables too. It doesn't seem like it will ever finish:

 |█-------------------------------------------------------------------------------------------| 1.26% [101/8000 08:40<11:18:24 Sampling 4 chains, 0 divergences]
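PyMC runs the four chains in separate worker processes, so one thing worth ruling out is whether the variables actually reach child processes. Values assigned to os.environ are inherited by child processes started afterwards. The sketch below checks this with a plain subprocess, not PyMC's own workers. The helper name is hypothetical.

```python
import os
import subprocess
import sys

# Set these before numpy/pytensor/pymc are imported, as in the snippet above
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"


def value_in_child(var):
    """Return the value of `var` as seen by a freshly started Python process."""
    result = subprocess.run(
        [sys.executable, "-c", f"import os; print(os.environ.get({var!r}))"],
        capture_output=True,
        text=True,
        check=True,
    )
    return result.stdout.strip()


print(value_in_child("OPENBLAS_NUM_THREADS"), value_in_child("MKL_NUM_THREADS"))  # 1 1
```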

So, at least on my end, it does not seem to be a threading issue. Let me know if there is anything else I can try.
Can anyone replicate this speed issue on Ubuntu 20.04? Meanwhile I will stick to 5.9.0.


One more thing: I tried replacing the covariance with a Cholesky factor, as suggested on the MvNormal docs page, like so:

ndims = 2  # number of coordinates (x and y)
sigmas = [pm.HalfNormal.dist(sigma=1, size=ndims) for i in range(n_clusters)]
chols = [pm.LKJCholeskyCov(f'chol_cov{i}', n=ndims, eta=2,
                           sd_dist=sigmas[i], compute_corr=True)[0] for
         i in range(n_clusters)]
# Define the mixture
components = [pm.MvNormal.dist(mu=centroids[i], chol=chols[i]) for i in range(n_clusters)]

In all of 5.8.2, 5.9.0 and 5.9.1, I get the following warning:

/home/avicenna/miniconda3/envs/pymc_env_5_8/lib/python3.11/site-packages/pytensor/compile/function/types.py:970: RuntimeWarning: invalid value encountered in accumulate
  self.vm()

However, the computation is much faster on 5.8.2 and 5.9.0 (1-2 min as opposed to 10-20 min), while 5.9.1 still gets stuck… Still, this might help other people having speed issues with their MvNormal mixtures. (I do realize this is a more flexible model, since the previous one was diagonal only, but even with eta increased to 20, which should mostly produce near-diagonal covariances, I still get the speed boost.)
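The traceback that follows points at a CumOp{None, add} node, a cumulative sum inside the gradient graph, whose input is [-inf, 0., inf]. That input alone explains the warning: the running sum reaches -inf + inf, which is undefined in floating point and gives NaN. This does not say where the infinities come from.

Below is a minimal NumPy reproduction. It assumes CumOp{None, add} behaves like np.cumsum on this 1-D input.

```python
import warnings

import numpy as np

values = np.array([-np.inf, 0.0, np.inf])  # input reported for the failing CumOp node

with np.errstate(invalid="ignore"):
    result = np.cumsum(values)  # the last partial sum is -inf + inf, i.e. nan
print(result)

# Promoting the RuntimeWarning to an error, as done for the traceback, makes it raise
message = None
with warnings.catch_warnings():
    warnings.simplefilter("error", RuntimeWarning)
    try:
        np.cumsum(values)
    except RuntimeWarning as err:
        message = str(err)
print(message)
```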

For your information, here is the traceback I get when I filter the warning as an error:

Traceback (most recent call last):
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
RuntimeWarning: invalid value encountered in accumulate

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 122, in run
    self._start_loop()
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 174, in _start_loop
    point, stats = self._step_method.step(self._point)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 174, in step
    return super().step(point)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/arraystep.py", line 100, in step
    apoint, stats = self.astep(q)
                    ^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/base_hmc.py", line 198, in astep
    hmc_step = self._hamiltonian_step(start, p0.data, step_size)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py", line 197, in _hamiltonian_step
    divergence_info, turning = tree.extend(direction)
                               ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py", line 290, in extend
    tree, diverging, turning = self._build_subtree(
                               ^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py", line 371, in _build_subtree
    return self._single_step(left, epsilon)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/nuts.py", line 330, in _single_step
    right = self.integrator.step(epsilon, left)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/integration.py", line 82, in step
    return self._step(epsilon, state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/step_methods/hmc/integration.py", line 118, in _step
    logp = self._logp_dlogp_func(q_new, grad_out=q_new_grad)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/model/core.py", line 378, in __call__
    cost, *grads = self._pytensor_function(*grad_vars)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 983, in __call__
    raise_with_op(
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/link/utils.py", line 535, in raise_with_op
    raise exc_value.with_traceback(exc_trace)
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 970, in __call__
    self.vm()
RuntimeWarning: invalid value encountered in accumulate
Apply node that caused the error: CumOp{None, add}(Subtensor{::step}.0)
Toposort index: 324
Inputs types: [TensorType(float64, shape=(None,))]
Inputs shapes: [(3,)]
Inputs strides: [(-8,)]
Inputs values: [array([-inf,   0.,  inf])]
Outputs clients: [[Subtensor{::step}(CumOp{None, add}.0, -1)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1204, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
"""

The above exception was the direct cause of the following exception:

RuntimeWarning: invalid value encountered in accumulate
Apply node that caused the error: CumOp{None, add}(Subtensor{::step}.0)
Toposort index: 324
Inputs types: [TensorType(float64, shape=(None,))]
Inputs shapes: [(3,)]
Inputs strides: [(-8,)]
Inputs values: [array([-inf,   0.,  inf])]
Outputs clients: [[Subtensor{::step}(CumOp{None, add}.0, -1)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1204, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/avicenna/Dropbox/data_analysis/MODELING/bayesian_clustering/debug/pymc_debug.py", line 80, in <module>
    idata = pm.sample(random_seed=random_seed)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 764, in sample
    _mp_sample(**sample_args, **parallel_args)
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 1153, in _mp_sample
    for draw in sampler:
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 448, in __iter__
    draw = ProcessAdapter.recv_draw(self._active)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pymc/sampling/parallel.py", line 330, in recv_draw
    raise error from old_error
pymc.sampling.parallel.ParallelSamplingError: Chain 3 failed with: invalid value encountered in accumulate
Apply node that caused the error: CumOp{None, add}(Subtensor{::step}.0)
Toposort index: 324
Inputs types: [TensorType(float64, shape=(None,))]
Inputs shapes: [(3,)]
Inputs strides: [(-8,)]
Inputs values: [array([-inf,   0.,  inf])]
Outputs clients: [[Subtensor{::step}(CumOp{None, add}.0, -1)]]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1049, in <listcomp>
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1374, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/avicenna/miniconda3/envs/pymc_env_5_9_0/lib/python3.11/site-packages/pytensor/gradient.py", line 1204, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
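The failing node is a `CumOp` (a cumulative sum) inside the gradient graph, and its printed input is `[-inf, 0., inf]`. The minimal NumPy sketch below reproduces the arithmetic. It assumes the op behaves like `np.cumsum` with floating-point "invalid" errors promoted to exceptions, which is what the "invalid value encountered in accumulate" message suggests. The running sum eventually has to compute `-inf + inf`, which is NaN. `cumsum_strict` is a hypothetical helper written for this illustration.

```python
import numpy as np

# Input values exactly as printed in the error message.
x = np.array([-np.inf, 0.0, np.inf])

def cumsum_strict(values):
    """Hypothetical helper: cumulative sum that raises instead of returning NaN."""
    # errstate(invalid="raise") turns the IEEE "invalid operation" flag
    # (set by -inf + inf) into a FloatingPointError.
    with np.errstate(invalid="raise"):
        return np.cumsum(values)

# With the flag ignored, the last partial sum (-inf + 0 + inf) is NaN.
with np.errstate(invalid="ignore"):
    lenient = np.cumsum(x)
print(lenient)

try:
    cumsum_strict(x)
except FloatingPointError as err:
    print("FloatingPointError:", err)
```

In other words, the sampler reached a point where some log-density term was infinite in both directions. The cumulative sum is only where that first surfaced as an error.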

The Cholesky factor of a diagonal matrix is just the elementwise square root of that matrix, so you may be able to get the speedup in the diagonal model by using pm.MvNormal.dist(mu=centroids[i], chol=pt.sqrt(covs[i]))
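A quick NumPy check of that identity follows. The diagonal covariance here is a made-up stand-in for one `covs[i]`, and it assumes `covs[i]` is an explicit diagonal matrix rather than a vector of variances. Off-diagonal zeros stay zero because sqrt(0) == 0, so the elementwise square root is already lower-triangular.

```python
import numpy as np

# Made-up diagonal covariance standing in for one covs[i] from the model.
cov = np.diag([4.0, 9.0, 0.25])

# General-purpose Cholesky factorization: cov = L @ L.T with L lower-triangular.
chol_general = np.linalg.cholesky(cov)

# For a diagonal matrix the factor is simply the elementwise square root.
chol_diag = np.sqrt(cov)

print(np.allclose(chol_general, chol_diag))
```

Passing `chol=` instead of `cov=` presumably lets PyMC skip its own Cholesky factorization of the covariance inside the graph.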


Thanks, I can confirm: this also gives the same speedup for 5.9.0. But 5.9.1 is still stuck.

That Blockwise{SolveTriangular{ is suspicious. It's also behind BUG: Regression in JAX model ops · Issue #6993 · pymc-devs/pymc · GitHub
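As far as we can tell, a SolveTriangular appears in this graph because the MvNormal log-density is evaluated from the Cholesky factor L of the covariance. It solves L z = x - mu for the whitened residual z instead of inverting the covariance. The NumPy sketch below writes out that formula and checks it against the textbook inverse/determinant version. The function names are hypothetical, and `np.linalg.solve` stands in for a dedicated triangular solver.

```python
import numpy as np

def mvnormal_logp_chol(x, mu, chol):
    """Hypothetical helper: MvNormal log-density from a lower Cholesky factor.

    np.linalg.solve stands in for a triangular solve; the result is the same.
    """
    k = mu.shape[-1]
    delta = x - mu
    # Whitened residual: solve chol @ z = delta (a triangular system).
    z = np.linalg.solve(chol, delta)
    quad = np.sum(z**2)
    # log|cov| = 2 * sum(log(diag(L))) because cov = L @ L.T.
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))
    return -0.5 * (quad + log_det + k * np.log(2.0 * np.pi))

def mvnormal_logp_direct(x, mu, cov):
    """Hypothetical reference: same density via explicit inverse and log-determinant."""
    k = mu.shape[-1]
    delta = x - mu
    quad = delta @ np.linalg.inv(cov) @ delta
    _, log_det = np.linalg.slogdet(cov)
    return -0.5 * (quad + log_det + k * np.log(2.0 * np.pi))

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 3))
cov = a @ a.T + 3.0 * np.eye(3)  # symmetric positive definite
chol = np.linalg.cholesky(cov)
mu = np.zeros(3)
x = np.array([0.5, -1.0, 2.0])
print(mvnormal_logp_chol(x, mu, chol), mvnormal_logp_direct(x, mu, cov))
```

When x has a batch dimension (one row per data point), this solve is repeated for every row. That is where a Blockwise wrapper around the solve comes in.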

@iavicenna, could you test whether performance in the latest version improves if you include this snippet before importing pymc?

from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.slinalg import SolveBase

@register_canonicalize
@node_rewriter([Blockwise])
def batched_1d_solve_to_2d_solve(fgraph, node):
    """Replace a batched Solve(a, b, b_ndim=1) by Solve(a, b.T, b_ndim=2).T

    This works when `a` is a matrix and `b` has an arbitrary number of dimensions.
    Only the last two dimensions are swapped.
    """
    from pytensor.tensor.rewriting.linalg import _T
    
    core_op = node.op.core_op

    if not isinstance(core_op, SolveBase):
        return None

    if core_op.b_ndim != 1:
        return None

    [a, b] = node.inputs

    # Check `b` is actually batched
    if b.type.ndim == 1:
        return None

    # Check `a` is a matrix (possibly with degenerate dims on the left)
    a_batch_dims = a.type.broadcastable[:-2]
    if not all(a_batch_dims):
        return None
    # We squeeze degenerate dims, as they will be introduced by the new_solve again
    elif len(a_batch_dims):
        a = a.squeeze(axis=tuple(range(len(a_batch_dims))))

    # Recreate solve Op with b_ndim=2
    props = core_op._props_dict()
    props["b_ndim"] = 2
    new_core_op = type(core_op)(**props)
    matrix_b_solve = Blockwise(new_core_op)

    # Apply the rewrite
    new_solve = _T(matrix_b_solve(a, _T(b)))

    return [new_solve]

Note that in an interactive environment that snippet can only be run once.
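The rewrite relies on a simple identity. When `a` is a single matrix, solving a x_i = b_i separately for every row b_i of `b` gives the same result as one matrix solve against `b.T`, transposed back. The NumPy sketch below illustrates that equivalence. NumPy stands in for the PyTensor ops, and the Python loop mimics the Blockwise batched 1-D solve; it is not how PyTensor executes it.

```python
import numpy as np

rng = np.random.default_rng(42)

# A single (non-batched), well-conditioned coefficient matrix `a`.
a = rng.normal(size=(4, 4)) + 4.0 * np.eye(4)
# A batch of right-hand sides: 5 vectors of length 4, one per row.
b = rng.normal(size=(5, 4))

# Batched form: one 1-D solve (b_ndim=1) per batch entry.
batched = np.stack([np.linalg.solve(a, b_i) for b_i in b])

# Rewritten form: Solve(a, b.T, b_ndim=2).T, a single 2-D solve.
single = np.linalg.solve(a, b.T).T

print(np.allclose(batched, single))
```

The single 2-D solve lets the backend factorize `a` once instead of once per batch entry, which is presumably where the hoped-for speedup would come from.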

I can confirm that logp_dlogp is about 10x slower in 5.9.1 than in 5.9.0, even with the “fix” above. And it's definitely due to Allow batched parameters in MvNormal and MvStudentT distributions by ricardoV94 · Pull Request #6897 · pymc-devs/pymc · GitHub