Currently, a simple random walk example fails for me:
```python
import numpy as np
import pymc as pm
import matplotlib.pyplot as plt
import aesara
import aesara.tensor as at
import aeppl
num_timesteps = 100
data = np.random.normal(0, 2.5, size=num_timesteps).cumsum()
plt.plot(data);
with pm.Model() as m:
    sigma = pm.HalfNormal("sigma", 5.)
    mu = pm.Normal("mu", 0., 1.)
    X_rv, updates = aesara.scan(
        fn=lambda x_tm1: at.random.normal(x_tm1, sigma),
        outputs_info=[{"initial": mu}],
        n_steps=num_timesteps
    )
    m.register_rv(X_rv, name="X_rv", data=data)
    # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
    idata = pm.sample()
```
It fails with the following traceback:
```python
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/var/folders/7p/srk5qjp563l5f9mrjtp44bh800jqsw/T/ipykernel_33625/1955773877.py in <module>
9 m.register_rv(X_rv, name="X_rv", data=data)
10 # X_rv = pm.GaussianRandomWalk("X_rv", mu, sigma, observed=data)
---> 11 idata = pm.sample()
~/Documents/OSS/pymc/pymc/sampling.py in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
528
529 initial_points = None
--> 530 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
531
532 if isinstance(step, list):
~/Documents/OSS/pymc/pymc/sampling.py in assign_step_methods(model, step, methods, step_kwargs)
204 # variables
205 selected_steps = defaultdict(list)
--> 206 model_logp = model.logp()
207
208 for var in model.value_vars:
~/Documents/OSS/pymc/pymc/model.py in logp(self, vars, jacobian, sum)
756 rv_logps: List[TensorVariable] = []
757 if rv_values:
--> 758 rv_logps = joint_logp(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
759 assert isinstance(rv_logps, list)
760
~/Documents/OSS/pymc/pymc/distributions/logprob.py in joint_logp(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
269 logp_var_dict = {}
270 for value_var in rv_values.values():
--> 271 logp_var_dict[value_var] = temp_logp_var_dict[value_var]
272
273 if scaling:
KeyError: X_rv{[-1.502085..27194e+01]}
```
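Inspecting the model with `aeppl` directly suggests that it does not recognize the random variable returned by `aesara.scan`: the resulting log-probability dictionary contains entries for `mu` and `sigma`, but none for `X_rv`: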
```python
x_vv = at.constant(data)
mu_vv = mu.clone()
sigma_vv = sigma.clone()
logp_dict = aeppl.factorized_joint_logprob({X_rv: x_vv, mu: mu_vv, sigma: sigma_vv})
logp_dict
# ==> {mu: mu_logprob, sigma: sigma_logprob}
```
cc @ricardoV94