Complex model samples very slowly - wondering about possible improvements

I think you have the right intuition with your attempt to create a vectorized function. I am not sure why switch would break the gradients. Did you write something like x = T.switch(T.isnan(x), 0., x)? Do you think you could flatten along the ragged dimension? I know it can be a headache to keep track of the start/stop indices of the different pieces of that vector, but it could speed up the computations considerably.
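To make the flattening idea concrete, here is a minimal NumPy sketch. The data, the per-group parameter `mu`, and all variable names are made up for illustration and are not taken from your model. The idea is to concatenate the ragged pieces into one vector and carry an integer group index alongside it, so per-group parameters can be broadcast with a single indexing operation instead of a Python loop.

```python
import numpy as np

# Hypothetical ragged data: one array of observations per group.
groups = [np.array([1.0, 2.0, 3.0]), np.array([4.0]), np.array([5.0, 6.0])]

lengths = np.array([len(g) for g in groups])

# One flat vector holding every observation.
flat = np.concatenate(groups)

# Group label for each element of the flat vector: [0, 0, 0, 1, 2, 2].
group_idx = np.repeat(np.arange(len(groups)), lengths)

# Start/stop offsets, in case a single group needs to be sliced back out.
stops = np.cumsum(lengths)
starts = stops - lengths

# Hypothetical per-group parameter, broadcast to the flat vector by indexing.
mu = np.array([0.5, 1.5, 2.5])
resid = flat - mu[group_idx]

# Per-group reduction without a Python loop.
group_sums = np.bincount(group_idx, weights=resid, minlength=len(groups))
```

Inside the model, the same `mu[group_idx]` indexing should work on a tensor-valued `mu`, which often removes the need to track the start/stop offsets at all.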

Also, a more basic thing to check is where the divergences occur. I notice that you have set quite a high target acceptance rate. If you relax that and look at where in parameter space the divergences occur, you may uncover the source of the troublesome geometry that forces NUTS to take lots of leapfrog steps to get accepted samples. The scatter plots in this notebook may be a useful example.
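As a rough sketch of that check, the snippet below compares where divergent and non-divergent draws sit for a single parameter. The arrays are synthetic stand-ins so that it runs on its own: the parameter name `log_tau` and the helper `summarize_divergences` are hypothetical, and with a real PyMC3 trace I would expect the inputs to come from something like `trace["log_tau"]` and `trace["diverging"]`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic stand-ins for sampler output (assumptions for illustration only).
log_tau = rng.normal(size=1000)

# Pretend the divergences cluster where log_tau is small.
diverging = log_tau < -1.5


def summarize_divergences(values, diverging):
    """Compare divergent and non-divergent draws for one parameter."""
    values = np.asarray(values, dtype=float)
    diverging = np.asarray(diverging, dtype=bool)
    nan = float("nan")
    return {
        "n_divergent": int(diverging.sum()),
        "frac_divergent": float(diverging.mean()),
        "mean_divergent": float(values[diverging].mean()) if diverging.any() else nan,
        "mean_other": float(values[~diverging].mean()) if (~diverging).any() else nan,
    }


summary = summarize_divergences(log_tau, diverging)
```

If the divergent draws sit in a clearly different region from the rest (here, a much lower `mean_divergent` than `mean_other`), that region is a good candidate for reparameterization. A scatter plot of two parameters colored by the `diverging` mask shows the same thing visually.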