I am trying to write some code with tensor-valued variables where I need to know which dimension to apply an operation to, e.g. `x.mean(axis=??)`. I can guess, but I'd like a simple way to do the equivalent of `print(x.shape)`, which seems to provide only symbolic information. `aesara.printing.debugprint(x)` prints a large graph without clear shape info. Here's an example:
```python
import pymc as pm

def make_test_model(dt=0.1, s2=0.1, k=0.1, nt=60, obs=None, N=10):
    import aesara
    model = pm.Model()

    def sde_fn(x, a):
        return x-x**3/3+a + k*x.mean(), s2

    with model:
        a = pm.distributions.Normal("a", mu=0, sigma=1, shape=N)
        xi = pm.distributions.Normal.dist(0, 1, shape=N)
        x = pm.distributions.timeseries.EulerMaruyama("x", dt, sde_fn, (a,), steps=nt, init_dist=xi)
        r = (x+2)**2/16
        # not sure what shape r is, so I don't know what ndim should be
        rw = r[...,:nt//10*10].reshape(r.shape[:-1] + (nt//10, 10,), ndim=3).mean(axis=-1)
        y_kwargs = {}
        if obs is not None:
            y_kwargs['observed'] = obs
        y = pm.distributions.Normal("y", mu=rw, sigma=0.1, **y_kwargs)
    return model

pp_model = make_test_model(s2=1.5, dt=0.01, k=3, nt=1000)
with pp_model as model:
    idata = pm.sample_prior_predictive(samples=100)
```
which fails with:

```
ERROR (aesara.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (aesara.graph.rewriting.basic): node: Subtensor{int64}(TensorConstant{(1,) of 20}, ScalarConstant{1})
ERROR (aesara.graph.rewriting.basic): TRACEBACK:
ERROR (aesara.graph.rewriting.basic): Traceback (most recent call last):
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1933, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/graph/rewriting/basic.py", line 1092, in transform
    return self.fn(fgraph, node)
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/tensor/rewriting/basic.py", line 1142, in constant_folding
    required = thunk()
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/link/c/op.py", line 103, in rval
    thunk()
  File "/home/duke/miniconda3/envs/muse/lib/python3.10/site-packages/aesara/link/c/basic.py", line 1788, in __call__
    raise exc_value.with_traceback(exc_trace)
IndexError: index out of bounds
```
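The `IndexError` above may not be about `ndim` at all (a hedged guess, not checked against this exact model): in Aesara, `r.shape` is a symbolic integer vector rather than a Python tuple, so `r.shape[:-1] + (nt//10, 10,)` should broadcast and add elementwise instead of concatenating, which yields a target shape with the wrong entries. NumPy arrays follow the same rule, so the sketch below uses NumPy to show both the pitfall and a concatenation-based reshape. The helper `windowed_mean` and the `(N, nt + 1)` stand-in shape for `r` are hypothetical, chosen only for illustration.

```python
import numpy as np


def windowed_mean(r, window):
    """Mean over non-overlapping windows along the last axis; any remainder is dropped."""
    n_win = r.shape[-1] // window
    # Leading dims as an integer array, mimicking how a symbolic shape behaves
    lead = np.asarray(r.shape[:-1], dtype=int)
    # Concatenate (not add) the leading dims with (n_win, window)
    new_shape = np.concatenate([lead, [n_win, window]])
    trimmed = r[..., : n_win * window]
    return trimmed.reshape(tuple(int(d) for d in new_shape)).mean(axis=-1)


# Hypothetical stand-in for r: N series of nt + 1 steps, time on the last axis
N, nt = 10, 1000
r = np.random.default_rng(0).normal(size=(N, nt + 1))

# The pitfall: array + tuple broadcasts and adds elementwise,
# giving [N + nt // 10, 10 + 10] instead of [N, nt // 10, 10]
print(np.asarray(r.shape[:-1]) + (nt // 10, 10))

print(windowed_mean(r, 10).shape)  # (10, 100)
```

In Aesara the analogous change would presumably be to build the target shape with `aesara.tensor.concatenate([r.shape[:-1], [nt // 10, 10]])`, or to pass a plain tuple of Python ints when the sizes are known, with `ndim` equal to the length of that shape.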
No value of `ndim` in `.reshape(..., ndim=ndim)` seems to work.
Any tips would be appreciated! Thanks in advance.