Is there a version of this for PyMC 4 with Aesara? I tried to use your code and I get an "Input dimension mismatch" error.
def update_Q(action, reward, Qs, alpha):
    """Apply one Q-learning (delta-rule) step to the Q table.

    Used as the step function of ``aesara.scan``. Scan passes its
    arguments in this order:

    - the per-trial ``action`` and ``reward`` (the sequences);
    - the previous ``Qs`` (the recurrent output);
    - the learning rate ``alpha`` (the non-sequence).

    Only the entry of the chosen action changes. It moves by ``alpha``
    times the prediction error ``reward - Qs[action]``. An equivalent
    inline form for the ``fn`` argument of ``aesara.scan`` is::

        lambda action, reward, Qs, alpha: at.set_subtensor(
            Qs[action], Qs[action] + alpha * (reward - Qs[action]))

    Returns the updated Q table, which scan feeds back in as ``Qs`` on
    the next trial.
    """
    chosen = Qs[action]
    prediction_error = reward - chosen
    return at.set_subtensor(chosen, chosen + alpha * prediction_error)
def theano_llik_td(alpha, beta, actions, rewards):
    """Return the summed log-likelihood of observed choices under a
    two-option Q-learning model with a softmax choice rule.

    Parameters
    ----------
    alpha : scalar tensor
        Learning rate, passed through to ``update_Q``.
    beta : scalar tensor
        Inverse temperature that scales the Q values before the softmax.
    actions : array-like of int
        Observed choice on each trial. The values index the 2-entry Q
        table, so they are expected to be 0 or 1.
    rewards : array-like of int
        Observed reward on each trial.

    Returns
    -------
    Scalar tensor
        The sum over trials 1..n-1 of ``log P(actions[t])``, computed from
        the Q table learned through trial t-1. The first action is not
        scored. This is a log-likelihood (not its negative), suitable for
        ``pm.Potential``, which adds it to the model's log-probability.
    """
    rewards = aesara.shared(np.asarray(rewards, dtype='int16'))
    actions = aesara.shared(np.asarray(actions, dtype='int16'))

    # Both options start with a Q value of 0.5.
    Qs = 0.5 * at.ones((2,), dtype='float64')

    # Row t of the scan output is the Q table after learning from trial t,
    # so the result has shape (n, 2). Scan's update dict is not needed.
    Qs, _ = aesara.scan(
        fn=update_Q,
        sequences=[actions, rewards],
        outputs_info=[Qs],
        non_sequences=[alpha])

    # Softmax in log space. Rows 0..n-2 predict the choices on trials 1..n-1.
    # keepdims=True keeps the normaliser as an (n-1, 1) column so each row is
    # normalised separately. Aesara's logsumexp defaults to keepdims=False;
    # the resulting (n-1,) vector then broadcasts along the wrong axis and
    # raises "Input dimension mismatch".
    Qs_ = Qs[:-1] * beta
    log_prob_actions = Qs_ - pm.math.logsumexp(Qs_, axis=1, keepdims=True)

    # From row t, pick the log-probability of the action actually taken
    # on trial t+1.
    log_prob_actions = log_prob_actions[at.arange(actions.shape[0] - 1), actions[1:]]
    return at.sum(log_prob_actions)
with pm.Model() as m:
    # Learning rate: Beta(1, 1), i.e. uniform on [0, 1].
    alpha = pm.Beta('alpha', alpha=1, beta=1)
    # Inverse temperature: a positive, wide half-normal prior.
    beta = pm.HalfNormal('beta', sigma=10)
    # Add the reinforcement-learning log-likelihood of the observed
    # actions/rewards to the model's joint log-probability.
    like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
    tr = pm.sample()
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:962, in Function.__call__(self, *args, **kwargs)
960 try:
961 outputs = (
--> 962 self.fn()
963 if output_subset is None
964 else self.fn(output_subset=output_subset)
965 )
966 except Exception:
ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Input In [12], in <cell line: 7>()
9 beta = pm.HalfNormal('beta', 10)
10 like = pm.Potential('like', theano_llik_td(alpha, beta, actions, rewards))
---> 11 tr = pm.sample()
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:533, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
531 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
532 _log.info("Auto-assigning NUTS sampler...")
--> 533 initial_points, step = init_nuts(
534 init=init,
535 chains=chains,
536 n_init=n_init,
537 model=model,
538 random_seed=random_seed_list,
539 progressbar=progressbar,
540 jitter_max_retries=jitter_max_retries,
541 tune=tune,
542 initvals=initvals,
543 **kwargs,
544 )
546 if initial_points is None:
547 # Time to draw/evaluate numeric start points for each chain.
548 ipfns = make_initial_point_fns_per_chain(
549 model=model,
550 overrides=initvals,
551 jitter_rvs=filter_rvs_to_jitter(step),
552 chains=chains,
553 )
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:2487, in init_nuts(init, chains, n_init, model, random_seed, progressbar, jitter_max_retries, tune, initvals, **kwargs)
2480 _log.info(f"Initializing NUTS using {init}...")
2482 cb = [
2483 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
2484 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
2485 ]
-> 2487 initial_points = _init_jitter(
2488 model,
2489 initvals,
2490 seeds=random_seed_list,
2491 jitter="jitter" in init,
2492 jitter_max_retries=jitter_max_retries,
2493 )
2495 apoints = [DictToArrayBijection.map(point) for point in initial_points]
2496 apoints_data = [apoint.data for apoint in apoints]
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling.py:2381, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
2379 if i < jitter_max_retries:
2380 try:
-> 2381 model.check_start_vals(point)
2382 except SamplingError:
2383 # Retry with a new seed
2384 seed = rng.randint(2**30, dtype=np.int64)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1722, in Model.check_start_vals(self, start)
1716 valid_keys = ", ".join(self.named_vars.keys())
1717 raise KeyError(
1718 "Some start parameters do not appear in the model!\n"
1719 f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
1720 )
-> 1722 initial_eval = self.point_logps(point=elem)
1724 if not all(np.isfinite(v) for v in initial_eval.values()):
1725 raise SamplingError(
1726 "Initial evaluation of model at starting point failed!\n"
1727 f"Starting values:\n{elem}\n\n"
1728 f"Initial evaluation results:\n{initial_eval}"
1729 )
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1763, in Model.point_logps(self, point, round_vals)
1757 factors = self.basic_RVs + self.potentials
1758 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1759 return {
1760 factor.name: np.round(np.asarray(factor_logp), round_vals)
1761 for factor, factor_logp in zip(
1762 factors,
-> 1763 self.compile_fn(factor_logps_fn)(point),
1764 )
1765 }
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/pymc/model.py:1862, in PointFunc.__call__(self, state)
1861 def __call__(self, state):
-> 1862 return self.f(**state)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:975, in Function.__call__(self, *args, **kwargs)
973 if hasattr(self.fn, "thunks"):
974 thunk = self.fn.thunks[self.fn.position_of_error]
--> 975 raise_with_op(
976 self.maker.fgraph,
977 node=self.fn.nodes[self.fn.position_of_error],
978 thunk=thunk,
979 storage_map=getattr(self.fn, "storage_map", None),
980 )
981 else:
982 # old-style linkers raise their own exceptions
983 raise
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
533 warnings.warn(
534 f"{exc_type} error does not allow us to add an extra error message"
535 )
536 # Some exception need extra parameter in inputs. So forget the
537 # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)
File ~/miniforge3/envs/pymc_env/lib/python3.10/site-packages/aesara/compile/function/types.py:962, in Function.__call__(self, *args, **kwargs)
959 t0_fn = time.time()
960 try:
961 outputs = (
--> 962 self.fn()
963 if output_subset is None
964 else self.fn(output_subset=output_subset)
965 )
966 except Exception:
967 restore_defaults()
ValueError: Input dimension mismatch. One other input has shape[1] = 2, but input[1].shape[1] = 99.
Apply node that caused the error: Elemwise{Sub}[(0, 0)](Elemwise{Mul}[(0, 0)].0, Elemwise{Composite{(i0 + log(i1))}}[(0, 0)].0)
Toposort index: 68
Inputs types: [TensorType(float64, (None, None)), TensorType(float64, (1, None))]
Inputs shapes: [(99, 2), (1, 99)]
Inputs strides: [(16, 8), (792, 8)]
Inputs values: ['not shown', 'not shown']
Outputs clients: [[AdvancedSubtensor(Elemwise{Sub}[(0, 0)].0, ARange{dtype='int64'}.0, Subtensor{int64::}.0)]]
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.