Pytesting "pymc3 functions"

I have written a number of functions that return PyMC3 variables, for example a prior.
I am finding them difficult to test because the result is typically a Theano tensor of some sort. Do you have any recommendations on how to test such functions?

I end up writing a lot of tests that check types and len(model.free_RVs), which feels tedious.

Thank you

Can you provide more details about the type of functions you are trying to test?

You can always test the output of a PyMC function by compiling your expressions with theano.function([inputs], output), but maybe you can do something simpler by checking the nodes returned by your function…

Hi, thank you for getting back to me @ricardoV94.
My apologies, I should have given an example.
My functions generate a simple inference case from a dictionary input.
So let's say we have a dict like

```python
{"distribution":"normal","mean":5,"sd":2}
```

as the input to my generate_normal_prior(specs: dict) function.
This function returns a prior based on the input, and the return value is of type theano.var.TensorVariable. If the function returned, e.g., a float, an array or a dict, I could have written an assert statement verifying that I got the results I wanted. How to do this is less obvious to me for a function that returns a Theano variable.

Hope this clarifies things.
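One way to keep most of these tests free of Theano is to split the dictionary handling from the model construction. The sketch below is hypothetical: `parse_prior_spec` and the key mapping are illustrative names (not part of PyMC3), and it assumes spec dicts shaped like the example above. The parsing step is pure Python, so it can be checked with ordinary `assert` statements.

```python
from typing import Any, Dict, Tuple

# Hypothetical mapping: spec "distribution" value -> (PyMC3 class name,
# {spec key: PyMC3 keyword}). Adapt it to the keys your specs actually use.
_DISTRIBUTIONS: Dict[str, Tuple[str, Dict[str, str]]] = {
    "normal": ("Normal", {"mean": "mu", "sd": "sigma"}),
}


def parse_prior_spec(specs: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
    """Translate a spec dict into (PyMC3 class name, keyword arguments)."""
    key = specs["distribution"].lower()
    if key not in _DISTRIBUTIONS:
        raise ValueError(f"unsupported distribution: {specs['distribution']!r}")
    class_name, renames = _DISTRIBUTIONS[key]
    # Rename spec keys to the keyword names the distribution expects.
    kwargs = {pm_name: specs[spec_name] for spec_name, pm_name in renames.items()}
    return class_name, kwargs


def test_parse_normal_spec():
    spec = {"distribution": "normal", "mean": 5, "sd": 2}
    assert parse_prior_spec(spec) == ("Normal", {"mu": 5, "sigma": 2})


def test_unknown_distribution_raises():
    try:
        parse_prior_spec({"distribution": "cauchy"})
    except ValueError:
        pass
    else:
        raise AssertionError("expected ValueError for an unsupported distribution")
```

The PyMC3-facing part can then stay a thin wrapper, for example `getattr(pm, class_name)(name, **kwargs)` inside a `pm.Model()` context, so only a few tests (such as a logp comparison) need to build a model at all.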

Did that make sense, @ricardoV94?

Thanks

I am still not sure what your function is returning. I imagined it was something like:

```python
import pymc3 as pm

def generate_normal(params):
    return pm.Normal(**params)  # params must include the variable 'name'

with pm.Model() as m:
    x = generate_normal({'name': 'x', 'mu': 0, 'sigma': 2})
```

In that case x is not a Theano TensorVariable but a pymc3.model.FreeRV, no?

You are exactly right, that is basically my function. However, it seems to be a Theano type?


Yeah, you are right, it was just IPython pretty printing that confused me.

So you can evaluate the logp of your distributions and compare it against SciPy:

```python
import pymc3 as pm
import scipy.stats

with pm.Model() as m:
    x = generate_normal({'name': 'x', 'mu': 0, 'sigma': 2})

x.logp({'x': 1})  # array(-1.73708571)
scipy.stats.norm(0, 2).logpdf(1)  # -1.737085713764618
```
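For reference, the value both calls return is the closed-form log-density of a normal distribution with mean 0 and standard deviation 2, evaluated at 1, which is why the two numbers agree:

$$
\log \mathcal{N}(1 \mid \mu = 0, \sigma = 2)
= -\tfrac{1}{2}\log(2\pi) - \log 2 - \frac{(1 - 0)^2}{2 \cdot 2^2}
\approx -0.91894 - 0.69315 - 0.12500 = -1.73709
$$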

Ah, that's a neat trick for checking that you created the right kind of variable. Awesome, thank you so much!