Hey y’all,
I am writing tests for my code and need to know the best way to assert the equality of tensor objects using pytensor.
A simple comparison like assert pt.as_tensor_variable(0) == pt.as_tensor_variable(0)
doesn’t work.
I can do a comparison like assert str(pt.as_tensor_variable(0)) == str(pt.as_tensor_variable(0))
which works, but isn’t totally ideal.
But do y’all have any utility functions (like, for example, numpy.testing’s assert_array_equal)?
Welcome!
For debugging purposes, you are probably looking for the evaluated tensors, not the tensors themselves.
So this:
pt.as_tensor_variable(0).eval() == pt.as_tensor_variable(0).eval()
returns True. So does:
pt.eq(pt.as_tensor_variable(0), pt.as_tensor_variable(0)).eval()
This is also an option to deal with floating point values:
pt.isclose(pt.as_tensor_variable(0), pt.as_tensor_variable(0+10e-20)).eval()
pt.allclose() may get you something similar to assert_array_equal(). @ricardoV94 may have other suggestions.
Can you provide more info on what you are trying to test exactly?
Sorry for never answering this, I have a function that’s calling a generic pymc distribution that looks something like the following:
import pymc as pm

def get_distribution(dist_name, **kwargs):
    # Look up the distribution class by name, e.g. "Normal" -> pm.Normal
    Distribution = getattr(pm, dist_name)
    return Distribution.dist(size=(1,), **kwargs)
And I wanted to make sure the kwargs were being passed in the right way, so for example if we were doing:
pm.Normal.dist(size=(1,), mu=0.0, sigma=0.01)
Then asserting that mu and sigma got passed directly. After originally posting this, I was able to get things working by just casting to strings, but due to some recent changes in pymc those expected strings needed to be updated, so I returned to this post to implement the test using .eval.
I was able to make a working test that looks something like this:
def test():
    expected_args = [[0.], [0.01]]
    n_args = len(expected_args)
    output = get_distribution("Normal", mu=0.0, sigma=0.01)
    dist_args = output.owner.inputs[-n_args:]
    dist_args_raw = [x.eval() for x in dist_args]
    assert dist_args_raw == expected_args
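One caveat with the final assert: comparing a list of NumPy arrays to a list of lists with == only works while every array has a single element. With larger arrays it raises a ValueError about ambiguous truth values. A small helper along these lines might be more robust. Note that `assert_params_close` is a hypothetical name, not something from pymc or pytensor, and the `evaluated` list is a stand-in for the evaluated parameters in the test above.

```python
import numpy as np

def assert_params_close(actual, expected, rtol=1e-7):
    """Compare two sequences of array-likes element by element."""
    assert len(actual) == len(expected), "different number of parameters"
    for got, want in zip(actual, expected):
        np.testing.assert_allclose(np.asarray(got), np.asarray(want), rtol=rtol)

# Stand-in for what `[x.eval() for x in dist_args]` might return.
evaluated = [np.array([0.0]), np.array([0.01])]
assert_params_close(evaluated, [[0.0], [0.01]])
```

Using a tolerance-based comparison also avoids surprises if the parameters end up being cast to a lower precision float type.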
If you only care about the parameters (and not size/rng) you can do dist.owner.op.dist_params(dist.owner) to get those. Then the eval approach you did is fine.