Possible bug with backwards-transforming a BoundedContinuous variable

Hi pymc devs and all,

I’d like to be able to perform the “inverse” transform on model variables that have been automatically transformed. (For example, sometimes I’d like to be able to calculate a model’s logp by passing in model parameters from the untransformed space.)

After some digging around, the approach I’ve found is to access the transform object via the transformed TensorVariable’s “tag” attribute. I’m not 100% sure this is the recommended way to do it, so I’d appreciate feedback if there is a less hacky solution. For example, we can do the following:

```python
import numpy as np
import pymc as pm

with pm.Model() as m:
    x = pm.Beta("x", alpha=1, beta=1)
m.value_vars[0]  # x_logodds__
m.value_vars[0].tag.transform.backward(0).eval()  # array(0.5, dtype=float32)

with pm.Model() as m:
    x = pm.Gamma("x", alpha=1, beta=1)
m.value_vars[0]  # x_log__
m.value_vars[0].tag.transform.backward(np.log(1.2)).eval()  # array(1.2)
```
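For intuition, here is a plain-Python sketch of the math behind these two transforms (my own re-derivation, not PyMC's code; the function names are made up). Neither mapping depends on the distribution's parameters, which would explain why backward() works here without any extra inputs.

```python
import math

# Log-odds transform (Beta, support (0, 1)); hypothetical names, not PyMC's API.
def logodds_forward(x):
    # (0, 1) -> real line
    return math.log(x) - math.log1p(-x)

def logodds_backward(value):
    # real line -> (0, 1), i.e. the sigmoid
    return 1.0 / (1.0 + math.exp(-value))

# Log transform (Gamma, support (0, inf)).
def log_forward(x):
    return math.log(x)

def log_backward(value):
    return math.exp(value)

print(logodds_backward(0.0))        # 0.5
print(log_backward(math.log(1.2)))  # 1.2, up to floating-point rounding
```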

However, this approach fails for variables from a BoundedContinuous distribution (e.g. Uniform, TruncatedNormal). It appears that the indices in bound_args_indices are out of range when the transform object tries to access the variable’s bound information:

```python
import pymc as pm

with pm.Model() as m:
    x = pm.Uniform("x", lower=1, upper=2)
m.value_vars[0]  # x_interval__
m.value_vars[0].tag.transform.backward(5).eval()
```

```
IndexError                                Traceback (most recent call last)
test.ipynb Cell 2 in <cell line: 14>()
     12     x = pm.Uniform("x", lower=1, upper=2)
     13 m.value_vars[0]  # x_interval__
---> 14 m.value_vars[0].tag.transform.backward(5).eval()

File .../lib/python3.10/site-packages/aeppl/transforms.py:446, in IntervalTransform.backward(self, value, *inputs)
    445 def backward(self, value, *inputs):
--> 446     a, b = self.args_fn(*inputs)
    448     if a is not None and b is not None:
    449         sigmoid_x = at.sigmoid(value)

File .../lib/python3.10/site-packages/pymc/distributions/continuous.py:175, in bounded_cont_transform.<locals>.transform_params(*args)
    173 lower, upper = None, None
    174 if bound_args_indices[0] is not None:
--> 175     lower = args[bound_args_indices[0]]
    176 if bound_args_indices[1] is not None:
    177     upper = args[bound_args_indices[1]]

IndexError: tuple index out of range
```

Because I’m not 100% sure whether or not I’m doing the correct thing, I wanted to ask this in the discourse. Any help would be much appreciated. Thanks!


Edit: I’ve also noticed that the same error appears when calling transform.forward() on an automatically transformed variable’s transform. I’m now curious: if this method errors out when performing the forward transform, how and where does PyMC compute the transformation?

You can check out the transform-related tests, probably starting with check_transform() for some relevant vignettes.

Thanks for the quick reply! In fact, it was the tests that helped me figure out a way to access the transforms within a model :). However, I’m not sure there’s a test that touches on this particular error. The main candidates I see are test_uniform() and test_triangular(), but both create their own Interval object and give it a bounds_fn, which seems to sidestep this error (the error stems from using the default bounds_fn, i.e. here).


Yeah, pointing you to the test(s) was more of a way to figure out how the transforms “should” be used, rather than finding some test that would detect your specific issue.

Not sure if this is “ideal”, but you might try something like this:

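```python
import pymc as pm

with pm.Model() as model:
    x = pm.Uniform("x", lower=1, upper=2)

x = model.free_RVs[0]                 # the Uniform random variable
x_val_transf = x.tag.value_var        # its value variable, x_interval__
test_val_transf = 5                   # a point in the transformed (unbounded) space
transform = x_val_transf.tag.transform
# Pass the RV's own inputs so the transform can look up lower/upper
test_val_untransf = transform.backward(test_val_transf, *x.owner.inputs).eval()
```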
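As a sanity check on the number that comes back, here is a plain-Python sketch of the interval mapping (equivalent to the sigmoid-based formula in the traceback above, but not PyMC's implementation; the function names are hypothetical). It shows why lower and upper have to be supplied: the sigmoid output must be rescaled into (lower, upper), in both directions.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def interval_backward(value, lower, upper):
    # unbounded value -> point inside (lower, upper)
    return lower + (upper - lower) * sigmoid(value)

def interval_forward(x, lower, upper):
    # point inside (lower, upper) -> unbounded value
    return math.log(x - lower) - math.log(upper - x)

x = interval_backward(5.0, lower=1.0, upper=2.0)
print(x)                              # about 1.9933
print(interval_forward(x, 1.0, 2.0))  # recovers 5.0 up to rounding
```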

We should add an option to get the model logp in the untransformed space.
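For reference, the gap between the two spaces is the log-Jacobian of the backward transform: if x = g(v), then logp_v(v) = logp_x(g(v)) + log|g'(v)|. Below is a plain-Python sketch of that relationship for the interval transform with a Uniform prior (hypothetical helper names, not a PyMC API); going from a transformed-space logp back to the untransformed one amounts to dropping the log-Jacobian term.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def interval_backward(value, lower, upper):
    return lower + (upper - lower) * sigmoid(value)

def interval_log_jacobian(value, lower, upper):
    # log |d backward / d value| = log(upper - lower) + log(s) + log(1 - s)
    s = sigmoid(value)
    return math.log(upper - lower) + math.log(s) + math.log1p(-s)

def uniform_logp(x, lower, upper):
    # log density of Uniform(lower, upper) in the untransformed space
    return -math.log(upper - lower) if lower < x < upper else -math.inf

def transformed_logp(value, lower, upper):
    # log density of the unbounded value: untransformed logp plus log-Jacobian
    x = interval_backward(value, lower, upper)
    return uniform_logp(x, lower, upper) + interval_log_jacobian(value, lower, upper)

print(transformed_logp(0.0, 1.0, 2.0))  # log(0.25), about -1.386
```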

Thanks again @cluhmann, I definitely missed the cue to try passing x.owner.inputs to the transform function. That worked great for my problem cases. I think I now understand that bound_args_indices holds indices into those inputs (x.owner.inputs), and that the aeppl transform functions expect the inputs to be passed.
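To make the bound_args_indices point concrete, here is a simplified, hypothetical stand-in modeled on the transform_params frame in the traceback (not the actual PyMC source). It assumes, as in the PyMC 4.x versions discussed here, that a RandomVariable's owner.inputs are (rng, size, dtype, *params), so Uniform's lower and upper sit at positions 3 and 4; placeholder values stand in for the real tensors.

```python
# Hypothetical stand-in for the closure built by bounded_cont_transform
def make_transform_params(bound_args_indices):
    def transform_params(*args):
        lower, upper = None, None
        if bound_args_indices[0] is not None:
            lower = args[bound_args_indices[0]]
        if bound_args_indices[1] is not None:
            upper = args[bound_args_indices[1]]
        return lower, upper
    return transform_params

# Assumption: lower/upper at positions 3 and 4, after rng, size, dtype
args_fn = make_transform_params((3, 4))
fake_inputs = ("rng", "size", "dtype", 1.0, 2.0)  # placeholders for x.owner.inputs

print(args_fn(*fake_inputs))  # (1.0, 2.0)

try:
    args_fn()  # what backward(5) effectively did: no inputs passed
except IndexError as err:
    print("IndexError:", err)  # IndexError: tuple index out of range
```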
