Hi,
Thanks a lot for the detailed analysis.
1- Which version are you running? I am on PyMC 5.9.1, and your code, run unmodified, is unreasonably slow from the start:
Auto-assigning NUTS sampler...
Initializing NUTS using advi+adapt_diag...
|--------------------| 0.02% [44/200000 00:09<12:28:29 Average Loss = 3,736.7]
I ran the code above unchanged, apart from adding the necessary imports. It does not seem to be an issue with ADVI, because if I try:
idata = pm.sample(1000, tune=1000, chains=3)
I get:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (3 chains in 4 jobs)
NUTS: [w, x_coord, y_coord, sigma]
|--------------| 0.35% [21/6000 00:15<1:12:28 Sampling 3 chains, 0 divergences]
which does not improve much (in fact it gets worse!). I thought this might be fixed by supplying correct initial values (as a test), but surprisingly that did not help at all. I also removed the scaling, with no luck.
It finally ran at a reasonable speed only when I changed MvNormal to Normal. So at the moment I am not able to experiment with your code, and whatever I write below is somewhat speculative.
2- You are using MvNormal with a diagonal covariance, which means that the components of the MvNormal are independent normal distributions (though they may have different variances, since those are left as parameters). So from a mathematical point of view, it is no different from using a set of independent Normals. I am confused as to why it would make a difference, unless the sampler treats a set of Normals and an MvNormal with diagonal covariance differently. Unfortunately, MvNormal is not an option for me, because later on I am going to use the model on data where censoring is present, and since censoring is not supported for MvNormal I cannot use it. There is also currently no reason for me to use MvNormal, because I am fairly confident that the features I am clustering on are independent of each other (i.e. the clusters are spherical from a geometric point of view). Can you or @ricardoV94 elaborate on why there would be a difference between a diagonal-covariance MvNormal and a set of Normals? Also, if MvNormal is the recommended practice, is there any way to combine an MvNormal with diagonal covariance and censoring?
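To make the "no mathematical difference" point concrete: with a diagonal covariance diag(sigma_1^2, ..., sigma_d^2), the MvNormal log-density is exactly the sum of the independent Normal log-densities. Below is a quick numpy-only sanity check I put together (a sketch, not PyMC's implementation; `mvnormal_logpdf` and `normal_logpdf` are hypothetical helpers, and the values of `mu` and `sigma` are arbitrary). It evaluates the general multivariate formula, without any diagonal shortcut, and compares it with the sum of univariate terms.

```python
import numpy as np


def mvnormal_logpdf(x, mu, cov):
    """General multivariate normal log-density (no diagonal shortcut)."""
    d = x.shape[-1]
    diff = x - mu
    _, logdet = np.linalg.slogdet(cov)           # log|cov|
    maha = diff @ np.linalg.solve(cov, diff)     # (x - mu)^T cov^{-1} (x - mu)
    return -0.5 * (d * np.log(2 * np.pi) + logdet + maha)


def normal_logpdf(x, mu, sigma):
    """Univariate normal log-density, evaluated elementwise."""
    z = (x - mu) / sigma
    return -0.5 * z**2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


rng = np.random.default_rng(0)
mu = np.array([1.0, -2.0, 0.5])
sigma = np.array([0.5, 2.0, 1.3])   # different scale per component
cov = np.diag(sigma**2)             # diagonal covariance matrix
x = rng.normal(mu, sigma)           # one 3-dimensional draw

lp_mv = mvnormal_logpdf(x, mu, cov)
lp_sum = normal_logpdf(x, mu, sigma).sum()
print(np.isclose(lp_mv, lp_sum))  # True
```

If this holds, the likelihood itself is identical, so any difference between the two models would have to come from somewhere else (for example how the model is set up or sampled).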
3- I did run the model above with only the MvNormals changed to Normals, and when I plot the trace, the results are nowhere near as good as yours. So, against my intuition (and yours), there does seem to be a difference between an MvNormal with diagonal covariance and a set of Normals. However, I cannot experiment with MvNormal because the optimization is extremely slow. I am going to leave it running for a while in the hope that it speeds up later.
ps: I have been running this code for a while and it shows no sign of speeding up, even though I am on a pretty decent computer.
ps2: I followed the installation instructions at the link below (including numpyro, blackjax and nutpie, though I will try uninstalling these later):
https://www.pymc.io/projects/docs/en/latest/installation.html
I don't get any warnings when I import pymc.