Ah OK, that makes sense, thanks. I’ll try that samples=(1,n) workaround. I’ve seen this shape confusion come up in a number of other posts, so it would be helpful to have a short discussion or warning about it somewhere in the documentation.
As for the first workaround, what’s the proper import for broadcast_distribution_samples? It’s not in numpy. I found an import for it in the pymc3 source code here, but then I don’t see the corresponding function in distributions/distribution.py.
EDIT: It looks like broadcast_distribution_samples is new in pymc3 3.7, which is why I couldn’t find it in my local copy of the pymc3 source code.
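For anyone who hits the same missing-function confusion, here is a rough sketch of checking the installed pymc3 version before relying on broadcast_distribution_samples. The 3.7 threshold is taken from the EDIT above. The helper names (version_tuple, has_broadcast_distribution_samples, installed_pymc3_version) are hypothetical and not part of pymc3. The sketch only reads package metadata via importlib.metadata (Python 3.8+) and does not assume any particular import path for the function itself.

```python
from importlib import metadata


def version_tuple(version: str) -> tuple:
    """Turn '3.7' or '3.6.0rc1' into a comparable tuple of leading integers."""
    parts = []
    for piece in version.split(".")[:3]:
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at suffixes like 'rc1'
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def has_broadcast_distribution_samples(installed: str) -> bool:
    # Assumption (from the post): the helper first appears in pymc3 3.7.
    return version_tuple(installed) >= (3, 7)


def installed_pymc3_version():
    """Return the installed pymc3 version string, or None if it isn't installed."""
    try:
        return metadata.version("pymc3")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    v = installed_pymc3_version()
    if v is None:
        print("pymc3 is not installed")
    elif has_broadcast_distribution_samples(v):
        print(f"pymc3 {v}: broadcast_distribution_samples should be available")
    else:
        print(f"pymc3 {v}: upgrade to 3.7+ for broadcast_distribution_samples")
```

Comparing tuples rather than raw strings avoids the classic bug where "3.10" sorts before "3.7".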