I’m trying to sample values x1 and x2 defined as:
0 <= x1 <= x2 <= 1
x1 + x2 <= 1
The second condition can be fulfilled by doing:
import pymc as pm

x_3d = pm.Dirichlet("x_3d", a=[1, 1, 1])  # inside a pm.Model() context
x_2d = x_3d[:2]
The values in x_2d add up to less than or equal to 1 and are both positive. But how do I sort the values? Is it possible to apply the ordered transform to x_2d?
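For context, here is a quick numerical check (using scipy's Dirichlet sampler as a stand-in for the PyMC model) that the first two components of a Dirichlet([1, 1, 1]) draw are positive and sum to at most 1:

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(0)
draws = scipy.stats.dirichlet.rvs([1, 1, 1], size=10_000, random_state=rng)
x_2d = draws[:, :2]  # drop the third component

print(np.all(x_2d > 0))                 # both components positive
print(np.all(x_2d.sum(axis=1) <= 1.0))  # x1 + x2 = 1 - x3 <= 1
```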
Coincidentally, I am also trying to work this out. My understanding is that right now this is not possible because of the bug in the ordered transform (Ordered transform incompatible with constrained space transforms (ZeroSum, Simplex, etc.) · Issue #6975 · pymc-devs/pymc · GitHub).
I am playing around with doing it “manually”, but I am very early in the process. All I have so far is that the following transformation can give us an ordered simplex from an unconstrained space:
import pytensor.tensor as pt

def ordered_simplex(vals):
    """Map an unconstrained vector to an increasing vector on the simplex."""
    n = vals.shape[-1]
    v_exp = pt.exp(vals)
    # Weights n, n-1, ..., 1 make the cumulative sums total exactly 1
    weights = pt.arange(start=n, stop=0, step=-1)
    return pt.cumsum(v_exp) / pt.dot(v_exp, weights)
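As a sanity check (my own sketch, in plain numpy rather than pytensor), the same map applied to arbitrary unconstrained input does produce a vector that is positive, increasing, and sums to 1:

```python
import numpy as np

def ordered_simplex_np(vals):
    # numpy version of the pytensor ordered_simplex above:
    # cumulative sums of exp(vals), normalized so the components sum to 1
    n = vals.shape[-1]
    v_exp = np.exp(vals)
    weights = np.arange(n, 0, -1)  # n, n-1, ..., 1
    return np.cumsum(v_exp) / np.dot(v_exp, weights)

rng = np.random.default_rng(1)
y = ordered_simplex_np(rng.normal(size=5))
print(np.all(y > 0), np.all(np.diff(y) > 0), np.isclose(y.sum(), 1.0))
```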
But I don’t think this is enough on its own… I am not sure right now:
- Whether to define a PyMC Transform subclass using this (which would also require defining an inverse and a log determinant of the Jacobian) and then pass it to Dirichlet in order to override the default Simplex transform that the distribution uses.
- Even if I go this way, the transform would only work for MCMC; it’s not going to produce ordered simplices using forward sampling, and I am not sure yet how to do that.
- Whether to eschew Dirichlet entirely and just sample some, say, Normals and pipe them directly through ordered_simplex. But (a) I am not sure how “advisable” that is, and (b) I think that wouldn’t allow it to be used for an observed random variable, for which you’d end up needing to define a custom distribution. And at that stage I am not sure whether it’s any different/better than overriding the default transform for Dirichlet (but again, not sure how to deal with defining the rv operation for it).
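To illustrate the last option: a forward-sampling sketch (my own, in numpy as a stand-in for the actual model) that pipes Normal draws through ordered_simplex and keeps the two smallest components, which then satisfy the original constraints:

```python
import numpy as np

def ordered_simplex_np(vals):
    # numpy version of the pytensor ordered_simplex above, batched
    n = vals.shape[-1]
    v_exp = np.exp(vals)
    weights = np.arange(n, 0, -1)
    return np.cumsum(v_exp, axis=-1) / np.dot(v_exp, weights)[..., None]

rng = np.random.default_rng(2)
z = rng.normal(size=(1000, 3))        # unconstrained Normal draws
s = ordered_simplex_np(z)             # ordered 3-simplex samples
x1, x2 = s[:, 0], s[:, 1]             # keep the two smallest components

print(np.all((0 < x1) & (x1 <= x2)))  # 0 <= x1 <= x2
print(np.all(x1 + x2 <= 1))           # x1 + x2 = 1 - x3 <= 1
```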
Maybe @ricardoV94 has some thoughts?
Just to clarify, it’s not a bug; those transforms are not built to be composable. Namely, the ordered transform is a distorting transform: an ordered variable will in general not correspond to any simple distribution.
Also, observed variables can’t be transformed, so I’m a bit confused by your text.
You can apply a ChainedTransform (ordered + simplex) to x_3d, then x_3d[:2] will follow both properties
But if you order a 3-vector and then discard one component, the result is different from discarding one component and then sorting the two remaining values.
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
rng = np.random.default_rng()
n = 1000
# Sample, drop one coordinate, then sort
sx3 = scipy.stats.dirichlet.rvs([1, 1, 1], size=n, random_state=rng)
sx_sorted = np.sort(sx3[:, :2], axis=1)
# Sample, sort, then drop one coordinate
sy3 = scipy.stats.dirichlet.rvs([1, 1, 1], size=n, random_state=rng)
sy_sorted = np.sort(sy3, axis=1)[:, :2]
plt.scatter(sx_sorted[:, 0], sx_sorted[:, 1], marker='+')
plt.scatter(sy_sorted[:, 0], sy_sorted[:, 1], marker='.')
plt.show()
The points marked with ‘+’ fulfill both conditions. The orange points (‘.’) are biased.
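The bias can be made concrete (my own derivation, not from the plot): if y1 <= y2 <= y3 sum to 1, then y2 <= y3 = 1 - y1 - y2 implies y1 + 2*y2 <= 1, so sort-then-drop samples are confined to a strictly smaller region than drop-then-sort samples:

```python
import numpy as np
import scipy.stats

rng = np.random.default_rng(3)
draws = scipy.stats.dirichlet.rvs([1, 1, 1], size=10_000, random_state=rng)

drop_then_sort = np.sort(draws[:, :2], axis=1)
sort_then_drop = np.sort(draws, axis=1)[:, :2]

def in_small_region(p):
    return p[:, 0] + 2 * p[:, 1] <= 1 + 1e-12

print(np.all(in_small_region(sort_then_drop)))   # always inside y1 + 2*y2 <= 1
print(np.mean(in_small_region(drop_then_sort)))  # only a fraction of these are
```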
Yes, that makes sense. Doesn’t the ChainedTransform followed by the drop give you the less biased case?