Shape debug information with aesara

I am trying to write some code with tensor-valued variables where I need to know which axis to apply operations like x.mean(axis=??) along. I can guess, but I'd like a simple way to do the equivalent of print(x.shape); x.shape seems to provide only symbolic information, and aesara.printing.debugprint(x) prints a large graph without clear shape info. Here's an example:

import pymc as pm


def make_test_model(dt=0.1, s2=0.1, k=0.1, nt=60, obs=None, N=10):
    model = pm.Model()

    def sde_fn(x, a):
        return x - x**3 / 3 + a + k * x.mean(), s2

    with model:
        a = pm.distributions.Normal("a", mu=0, sigma=1, shape=N)
        xi = pm.distributions.Normal.dist(0, 1, shape=N)
        x = pm.distributions.timeseries.EulerMaruyama("x", dt, sde_fn, (a,), steps=nt, init_dist=xi)
        r = (x+2)**2/16
        # not sure what shape r is, so I don't know what ndim should be
        rw = r[...,:nt//10*10].reshape(r.shape[:-1] + (nt//10, 10,), ndim=3).mean(axis=-1)
        y_kwargs = {}
        if obs is not None:
            y_kwargs['observed'] = obs
        y = pm.distributions.Normal("y", mu=rw, sigma=0.1, **y_kwargs)
        
    return model

pp_model = make_test_model(s2=1.5, dt=0.01, k=3, nt=1000)

with pp_model as model:
    idata = pm.sample_prior_predictive(samples=100)

which fails with:

ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (aesara.graph.rewriting.basic): node: Subtensor{int64}(TensorConstant{(1,) of 20}, ScalarConstant{1})
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
    return self.fn(fgraph, node)
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
    required = thunk()
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
    thunk()
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
    raise exc_value.with_traceback(exc_trace)
IndexError: index out of bounds

No values of ndim in .reshape(..., ndim=ndim) seem to work.

Any tips would be appreciated! Thanks in advance.

You can usually call var.shape.eval() for debugging shapes if your graph is built only of RandomVariables (and operations on those); otherwise you will need to provide input values.
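
For example, with a graph that has a plain (non-random) input, you can pass concrete values to eval (a small standalone sketch, not taken from the model above):

import numpy as np
import aesara.tensor as at

a = at.matrix("a")  # symbolic input whose shape is unknown until values are supplied
b = a.mean(axis=0)
print(b.shape.eval({a: np.zeros((4, 3))}))  # [3]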

You can also call pm.draw(var) and take the shape of the output.

This guide uses the latter a lot: Distribution Dimensionality — PyMC documentation
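
Applied to the model in the question, one option is to print the shape of a draw right after r is defined inside make_test_model. Here is a minimal sketch with a plain Normal standing in for the EulerMaruyama series, so the (N, nt) layout is only an assumption for illustration:

import pymc as pm

with pm.Model():
    x = pm.Normal("x", shape=(10, 1000))  # stand-in for the series; (N, nt) layout assumed
    r = (x + 2) ** 2 / 16
    print(pm.draw(r).shape)  # (10, 1000): concrete shape of one forward draw
    print(r.shape.eval())    # [  10 1000]: symbolic shape, evaluated (works here because the graph is only RVs and ops on them)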

Perfect, thanks!

By the way, this Discourse doesn't seem to let me mark a reply as the solution. Is that expected, or am I just too new a user?

Example:

import pymc as pm

with pm.Model() as m:
  x = pm.Normal("x", shape=(7, 2))
  y = x[:3]
  print(y.shape.eval(), pm.draw(y).shape)  # [3 2] (3, 2)
  z = y.reshape((2, 3))
  print(z.shape.eval(), pm.draw(z).shape)  # [2 3] (2, 3)
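
Once the concrete shape is known, the reshape from the original post can usually take a plain tuple, in which case ndim is not needed at all. A sketch, again assuming the (N, nt) stand-in layout rather than the real EulerMaruyama output:

import pymc as pm

nt, N = 1000, 10
with pm.Model():
    x = pm.Normal("x", shape=(N, nt))  # stand-in; layout assumed for illustration
    r = (x + 2) ** 2 / 16
    # group the time axis into blocks of 10 and average within each block
    rw = r[..., : nt // 10 * 10].reshape((N, nt // 10, 10)).mean(axis=-1)
    print(pm.draw(rw).shape)  # (10, 100)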

@cluhmann @junpenglao?

Only questions can be marked with solutions. I have now changed the original post to be a question.

Ah didn’t know, thanks!
