What is the correct way to assert equality with tensor objects?

Hey y’all,

I'm writing tests for my code and need to know the best way to assert equality of tensor objects in PyTensor.

A simple comparison like assert pt.as_tensor_variable(0) == pt.as_tensor_variable(0) doesn’t work.

I can compare the string representations instead, as in assert str(pt.as_tensor_variable(0)) == str(pt.as_tensor_variable(0)). That works, but it isn't ideal.

But do y'all have any utility functions for this, like numpy.testing's assert_array_equal?


Welcome!

For testing purposes, you probably want to compare the evaluated tensors, not the symbolic tensors themselves.
So this:

pt.as_tensor_variable(0).eval() == pt.as_tensor_variable(0).eval()

returns True. So does:

pt.eq(pt.as_tensor_variable(0), pt.as_tensor_variable(0)).eval()

For floating-point values, pt.isclose is another option:

pt.isclose(pt.as_tensor_variable(0), pt.as_tensor_variable(0+10e-20)).eval()

pt.allclose() may get you something similar to numpy.testing's assert_array_equal(). @ricardoV94 may have other suggestions.

Can you provide more info on what are you trying to test exactly?

Sorry for never answering this. I have a function that calls a generic PyMC distribution, something like the following:

import pymc as pm

def get_distribution(dist_name, **kwargs):
    # Look up the distribution class by name, e.g. "Normal" -> pm.Normal
    Distribution = getattr(pm, dist_name)
    return Distribution.dist(size=(1,), **kwargs)

I wanted to make sure the kwargs were passed through correctly. For example, get_distribution("Normal", mu=0.0, sigma=0.01) should build:

pm.Normal.dist(size=(1,), mu=0.0, sigma=0.01)

I then want to assert that mu and sigma were passed through unchanged. After originally posting this, I got things working by casting to strings. Some recent changes in PyMC meant those expected strings needed updating, so I came back to this post to implement the test using .eval() instead.

I was able to make a working test that looks something like this:

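import numpy as np

def test():
    expected_args = [[0.0], [0.01]]
    n_args = len(expected_args)
    output = get_distribution("Normal", mu=0.0, sigma=0.01)
    # The distribution parameters (mu, sigma) are the last inputs of the node
    dist_args = output.owner.inputs[-n_args:]
    for arg, expected in zip(dist_args, expected_args):
        # Evaluate each symbolic parameter and compare it numerically
        np.testing.assert_allclose(arg.eval(), expected)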

If you only care about the parameters (and not size/rng), you can use dist.owner.op.dist_params(dist.owner) to get them. After that, the eval approach you used is fine.