Error in Euler Maruyama

Dear all,

I am facing a shape mismatch when using EulerMaruyama with a multivariate distribution.
Can anyone please point out my mistake?
Thank you very much.

def sde(x, theta, mu, sigma):
    """Return the drift and diffusion terms ``(theta * (mu - x), sigma)``.

    Args:
        x: Current state of the process.
        theta: Factor multiplying ``mu - x`` elementwise.
        mu: Level that ``x`` is compared against in the drift term.
        sigma: Diffusion coefficient; returned unchanged.

    Returns:
        A ``(drift, diffusion)`` tuple.

    Note:
        Every operation here is elementwise, so ``theta`` and ``mu`` must
        broadcast against ``x``. PyMC3's ``EulerMaruyama.logp`` calls this
        function with ``x[:-1]``, which has one row fewer than the variable's
        declared shape; parameters whose leading axis equals the full length
        therefore fail with "Input dimension mis-match".
    """
    drift = theta * (mu - x)
    return drift, sigma

import pymc3.distributions.timeseries as ts
import pymc3 as pm
with pm.Model() as wiggins_model:
    # Priors for the parameters passed to `sde`, in the order (theta, mu, sigma).
    #
    # NOTE(review): the traceback shows EulerMaruyama.logp evaluating
    # `sde_fn(x[:-1], *sde_pars)`, i.e. on one row fewer than `data.shape[0]`.
    # The reported mismatch (3200 vs 3199) suggests n == data.shape[0], so a
    # leading axis of n cannot broadcast against x[:-1]. The parameters need
    # either no time axis (e.g. shape=(3,)) or a leading axis of n - 1.
    # NOTE(review): `sde` multiplies elementwise, so the trailing (3, 3) axes of
    # theta/sigma presumably will not broadcast against an (n - 1, 3) state
    # either -- confirm the intended parameter shapes.
    volatility_theta = pm.Uniform('volatility_theta', lower=0., upper=1., shape=(n, 3, 3))
    volatility_mu = pm.MvNormal('volatility_mu', mu=avg_mean, cov=cov_matrix, shape=(n, 3))
    volatility_sigma = pm.Uniform('volatility_sigma', lower=0.001, upper=0.2, shape=(n, 3, 3))

    # Latent path with time step dt=1.0; the test value is an all-ones array
    # shaped like `data`.
    volatility = ts.EulerMaruyama('volatility',
                                  1.0,
                                  sde,
                                  (volatility_theta, volatility_mu, volatility_sigma),
                                  shape=data.shape,
                                  testval=np.ones_like(data))

    # NOTE(review): exp(volatility) has the shape of `data`; verify that this
    # is a valid `cov` argument for MvNormal (presumably a square matrix is
    # expected) -- TODO confirm.
    pm.MvNormal('obs', mu=[0., 0., 0.], cov=pm.math.exp(volatility), observed=data)

    # Sampling stays inside the `with` block so that it uses this model.
    trace = pm.sample(4000, cores=8, chains=2, tune=3000, random_seed=42)

ValueError Traceback (most recent call last)
in
13 (volatility_theta, volatility_mu, volatility_sigma),
14 shape=data.shape,
—> 15 testval=np.ones_like(data))
16
17 pm.MvNormal('obs', mu=[0., 0., 0.], cov=pm.math.exp(volatility), observed=data)

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/distribution.py in __new__(cls, name, *args, **kwargs)
120 else:
121 dist = cls.dist(*args, **kwargs)
→ 122 return model.Var(name, dist, data, total_size, dims=dims)
123
124 def __getnewargs__(self):

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/model.py in Var(self, name, dist, data, total_size, dims)
1136 if getattr(dist, "transform", None) is None:
1137 with self:
→ 1138 var = FreeRV(name=name, distribution=dist, total_size=total_size, model=self)
1139 self.free_RVs.append(var)
1140 else:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/model.py in __init__(self, type, owner, index, name, distribution, total_size, model)
1669 np.ones(distribution.shape, distribution.dtype) * distribution.default()
1670 )
→ 1671 self.logp_elemwiset = distribution.logp(self)
1672 # The logp might need scaling in minibatches.
1673 # This is done in Factor.

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/pymc3/distributions/timeseries.py in logp(self, x)
398 """
399 xt = x[:-1]
→ 400 f, g = self.sde_fn(x[:-1], *self.sde_pars)
401 mu = xt + self.dt * f
402 sd = tt.sqrt(self.dt) * g

in sde(x, theta, mu, sigma)
1 def sde(x, theta, mu, sigma):
----> 2 return theta * (mu - x), sigma

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/tensor/var.py in __sub__(self, other)
118 # and the return value in that case
119 try:
→ 120 return theano.tensor.basic.sub(self, other)
121 except (NotImplementedError, TypeError):
122 return NotImplemented

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/graph/op.py in __call__(self, *inputs, **kwargs)
251
252 if config.compute_test_value != "off":
→ 253 compute_test_value(node)
254
255 if self.default_output is not None:

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/graph/op.py in compute_test_value(node)
128 thunk.outputs = [storage_map[v] for v in node.outputs]
129
→ 130 required = thunk()
131 assert not required # We provided all inputs
132

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/graph/op.py in rval()
604
605 def rval():
→ 606 thunk()
607 for o in node.outputs:
608 compute_map[o][0] = True

~/miniconda3/envs/pymc3/lib/python3.7/site-packages/theano/link/c/basic.py in __call__(self)
1769 print(self.error_storage, file=sys.stderr)
1770 raise
→ 1771 raise exc_value.with_traceback(exc_trace)
1772
1773

ValueError: Input dimension mis-match. (input[0].shape[0] = 3200, input[1].shape[0] = 3199)