Implement rank correlation model

Hi everyone,

I am trying to implement the following model (from p. 2999 of this paper):

\begin{align}
\rho_{z^x z^y} &\sim \mathcal{U}(-1, 1) \tag{1} \\
\begin{bmatrix} Z^x_i \\ Z^y_i \end{bmatrix} &\sim \mathcal{N}\left( \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & \rho_{z^x z^y} \\ \rho_{z^x z^y} & 1 \end{bmatrix} \right) \tag{2} \\
r^x_i &= \text{Rank}(Z^x_i) \tag{3} \\
r^y_i &= \text{Rank}(Z^y_i) \tag{4}
\end{align}

Simulating data from the model is quite easy:

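```python
import numpy as np
from scipy.stats import rankdata

n = 20  # hypothetical number of categories; the original value was not preserved
rho_simulated = -0.8
cov_simulated = [[1, rho_simulated], [rho_simulated, 1]]
zs_simulated = np.random.multivariate_normal([0, 0], cov_simulated, size=n)  # shape (n, 2)
ranks_z = rankdata(zs_simulated, method='min', axis=0)  # rank each column separately
ranks_x, ranks_y = ranks_z.T
```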

And the results look as expected.
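A quantitative version of "as expected": for a bivariate normal with correlation ρ, the population Spearman rank correlation has the known closed form ρ_s = (6/π) arcsin(ρ/2), so ρ = -0.8 should give a rank correlation of about -0.786. The sketch below compares that value with `scipy.stats.spearmanr` on a large simulated sample; the sample size and seed are arbitrary assumptions, not values from the post.

```python
import numpy as np
from scipy.stats import spearmanr

rng = np.random.default_rng(0)  # arbitrary seed
rho = -0.8
cov = [[1, rho], [rho, 1]]
zs = rng.multivariate_normal([0, 0], cov, size=100_000)  # large sample, arbitrary size

# Spearman correlation = Pearson correlation of the two rank vectors
rs_empirical, _ = spearmanr(zs[:, 0], zs[:, 1])
# closed form for a bivariate normal: rho_s = (6 / pi) * arcsin(rho / 2)
rs_theory = 6 / np.pi * np.arcsin(rho / 2)

print(round(rs_theory, 3))  # -0.786
print(rs_empirical)         # should be close to rs_theory
```

Note that the rank correlation is slightly weaker in magnitude than ρ itself, so estimating ρ from the ranks is not the same as reporting the Spearman coefficient.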


However, I am having trouble with the PyMC3 model. Equations (1) and (2) above are easy to express. The problem is that the observed values (namely, the ranks) are a deterministic transformation of random variables, and therefore cannot be passed directly as `observed`, for reasons explained in various other posts on this Discourse.
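Spelling out what the model implies for the data may help. Assuming n categories (n is not named in the original) and writing z^x = (z^x_1, ..., z^x_n), z^y likewise, the probability of the observed ranks is the Gaussian probability of the region of z-space whose ordering matches them. Evaluating it means integrating over every z consistent with the ranks, which in general has no closed form; approaches like the Potential below instead sample z jointly with ρ, with the indicators acting as constraints.

\begin{align*}
p\left(r^x, r^y \mid \rho_{z^x z^y}\right) = \int \mathbb{1}\left[\text{Rank}(z^x) = r^x\right] \, \mathbb{1}\left[\text{Rank}(z^y) = r^y\right] \prod_{i=1}^{n} \mathcal{N}\left( \begin{bmatrix} z^x_i \\ z^y_i \end{bmatrix} ; \begin{bmatrix} 0 \\ 0 \end{bmatrix}, \begin{bmatrix} 1 & \rho_{z^x z^y} \\ \rho_{z^x z^y} & 1 \end{bmatrix} \right) dz^x \, dz^y
\end{align*}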

One option that came to mind is to enforce the ranks with a `pm.Potential`, as follows:

```python
from scipy import stats
import numpy as np
import pymc3 as pm
import theano.tensor as tt

n = zs_simulated.shape[0]  # number of categories
# indices that sort each simulated column (this encodes the observed ranks)
argsort_x, argsort_y = zs_simulated.argsort(axis=0).T

with pm.Model() as model:
    # single parameter encoding the strength
    # of the rank correlation
    rho = pm.Uniform('rho', -1, 1)
    # build covariance matrix
    cov = tt.stack([1., rho, rho, 1.]).reshape((2, 2))

    # for each i (i.e. each category),
    # get the two Zs (which determine the rank)
    zs = pm.MvNormal('zs', mu=np.zeros(2), cov=cov, shape=(n, 2))
    # column 0 holds the Z^x values, column 1 the Z^y values
    z_x_sorted = zs[:, 0][argsort_x]
    z_y_sorted = zs[:, 1][argsort_y]
    # make sure that the elements are ordered (this encodes the observed ranks)
    pm.Potential('x_ordered',
                 tt.switch(tt.all(z_x_sorted[:-1] <= z_x_sorted[1:]), 0, -np.inf))
    pm.Potential('y_ordered',
                 tt.switch(tt.all(z_y_sorted[:-1] <= z_y_sorted[1:]), 0, -np.inf))
    trace = pm.sample()
```

However, besides presumably being very inefficient, this raises a `Bad initial energy` error. The reason is not obvious to me: by default the test value assigns 0 to all `zs`, so both Potentials should return a logp of 0. Does this have something to do with the leapfrog steps?
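One thing that may be relevant, assuming `pm.sample` uses its default `init='jitter+adapt_diag'`: in PyMC3 3.x that initialization adds uniform(-1, 1) noise to the test point before sampling starts, so the chain does not actually start at all-zero `zs`. A randomly jittered vector of n values satisfies one fixed ordering with probability 1/n!, so the Potential is almost surely -inf at the start. In addition, a `tt.switch` between constants has zero gradient wherever it is finite, so the Potential gives NUTS no direction towards the constrained region. The NumPy sketch below only mimics the Potential's check; `n`, the seed and `ordering_logp` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)  # arbitrary seed
n = 20                          # hypothetical number of categories
true_z = rng.normal(size=n)
argsort_x = true_z.argsort()    # ordering implied by the observed ranks

def ordering_logp(z, order):
    """Mimic the Potential: 0 if z sorted by `order` is nondecreasing, else -inf."""
    z_sorted = z[order]
    return 0.0 if np.all(z_sorted[:-1] <= z_sorted[1:]) else -np.inf

# the all-zero test point satisfies the constraint, because ties pass `<=`
print(ordering_logp(np.zeros(n), argsort_x))  # 0.0

# a start jittered by uniform(-1, 1) noise, similar to init='jitter+adapt_diag'
jittered = np.zeros(n) + rng.uniform(-1, 1, size=n)
print(ordering_logp(jittered, argsort_x))     # almost surely -inf
```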

Another idea is to use the ordered transform somehow, but I am unsure how. Any help is greatly appreciated! Thanks in advance.
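On the ordered transform: as far as I understand, PyMC3's `ordered` transform maps an unconstrained vector y to a strictly increasing one via out[0] = y[0], out[k] = out[k-1] + exp(y[k]). Indexing that sorted vector with `ranks - 1` then yields latent values whose ranks match the observed ones by construction, so no hard constraint is needed. The NumPy sketch below only illustrates that mapping; `ordered_backward` and `z_from_ranks` are hypothetical helpers, and it assumes ranks 1..n without ties.

```python
import numpy as np
from scipy.stats import rankdata

def ordered_backward(y):
    """Map an unconstrained vector to a strictly increasing one:
    out[0] = y[0], out[k] = out[k-1] + exp(y[k])."""
    steps = np.concatenate([y[:1], np.exp(y[1:])])
    return np.cumsum(steps)

def z_from_ranks(y, ranks):
    """Hypothetical helper: give element i the ranks[i]-th smallest sorted value,
    so z has exactly the given ranks (1-based, no ties assumed)."""
    return ordered_backward(y)[ranks - 1]

rng = np.random.default_rng(2)          # arbitrary seed
ranks_x = rng.permutation(10) + 1       # hypothetical observed ranks 1..10
y = rng.normal(size=10)                 # any unconstrained point
z_x = z_from_ranks(y, ranks_x)
print(np.array_equal(rankdata(z_x, method='min'), ranks_x))  # True
```

Wiring this into a model (for example, one ordered free variable per margin, plus a `pm.Potential` holding the bivariate normal log-density of the resulting (z^x_i, z^y_i) pairs) is an untested idea, not a verified solution.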