Hi,
I’m trying to profile a model which uses a `CustomDist`, and I am getting an `AttributeError: 'function' object has no attribute 'owner'` error.
The code is the following:
def rl_logp(observed, alpha_rew, alpha_pun, beta_rew, beta_pun, sens_rew, sens_pun, lapse, decay, perseverance):
    """Log-likelihood of the RL model.

    Unpacks choices and outcomes from the packed ``observed`` array, runs the
    per-trial belief-update recursion with ``scan``, and returns the summed
    log of the choice probabilities (the scalar log-likelihood PyMC expects).

    NOTE(review): assumes ``observed`` stores choices in column 0 and outcomes
    in the remaining columns — confirm this matches how ``observed`` is
    constructed by the caller.
    """
    # Split the packed observation array back into its components.
    subj_choices = observed[:, 0].astype('int32')
    subj_outcomes = observed[:, 1:].astype('int32')
    n_subj = subj_choices.shape[1]

    choice_seq = pt.as_tensor_variable(subj_choices, dtype='int32')
    outcome_seq = pt.as_tensor_variable(subj_outcomes, dtype='int32')

    # Initial carried state for the scan: flat beliefs, chance-level choice
    # probabilities, and an all-zero choice trace.
    init_beliefs = 0.5 * pt.ones((n_subj, 4), dtype='float64')       # [n_subj x 4]
    init_choice_probs = 0.5 * pt.ones((n_subj, 1), dtype='float64')  # [n_subj x 1]
    init_choice_trace = pt.zeros((n_subj, 2), dtype='float64')

    (beliefs_pymc, choice_probs_pymc, choice_trace_pymc), updates = scan(
        fn=update_belief,
        sequences=[choice_seq, outcome_seq],
        non_sequences=[n_subj, alpha_rew, alpha_pun, beta_rew, beta_pun,
                       sens_rew, sens_pun, lapse, decay, perseverance],
        outputs_info=[init_beliefs, init_choice_probs, init_choice_trace],
    )

    # PyMC expects the (summed) log-likelihood as a scalar.
    return pt.sum(pt.log(choice_probs_pymc))
with pm.Model() as m:
    # Register the raw data with the model.
    choices_ = pm.ConstantData('choices_', choices)
    outcomes_ = pm.ConstantData('outcomes_', outcomes)

    # Free parameters: per-subject learning rates and inverse temperatures.
    alpha_rew = pm.Beta(name="alpha_rew", alpha=1, beta=1, shape=n_subj)
    alpha_pun = pm.Beta(name="alpha_pun", alpha=1, beta=1, shape=n_subj)
    beta_rew = pm.HalfNormal(name="beta_rew", sigma=10, shape=n_subj)
    beta_pun = pm.HalfNormal(name="beta_pun", sigma=10, shape=n_subj)

    # The remaining parameters are pinned to constants, wrapped in
    # Deterministics so they are still recorded under their names.
    sens_rew = pm.Deterministic('sens_rew', pt.ones(shape=n_subj))
    sens_pun = pm.Deterministic('sens_pun', pt.ones(shape=n_subj))
    lapse = pm.Deterministic('lapse', pt.zeros(shape=n_subj))
    decay = pm.Deterministic('decay', pt.zeros(shape=n_subj))
    perseverance = pm.Deterministic('perseverance', 1 * pt.zeros(shape=n_subj))

    # Pack choices and outcomes into a single array so CustomDist receives
    # one observed value; rl_logp unpacks it again.
    observed = np.concatenate([choices[:, :, None], outcomes], axis=2)
    logp = pm.CustomDist(
        'logp',
        alpha_rew, alpha_pun, beta_rew, beta_pun,
        sens_rew, sens_pun, lapse, decay, perseverance,
        logp=rl_logp,
        observed=observed,
    )
and it gets called using `m.profile(m.logp).summary()`.
What’s the correct way to profile a function like this?
Thanks!