Profiling CustomDist

Hi,

I’m trying to profile a model that uses a CustomDist, and I’m getting an AttributeError: 'function' object has no attribute 'owner' error.

The code is the following:

import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pytensor import scan

# update_belief (the per-trial step used by scan) is defined in the full notebook
def rl_logp(observed, alpha_rew, alpha_pun, beta_rew, beta_pun, sens_rew, sens_pun, lapse, decay, perseverance):
    # unpack variables
    choices = observed[:,0].astype('int32')
    outcomes = observed[:,1:].astype('int32')
    n_subj = choices.shape[1]

    choices_ = pt.as_tensor_variable(choices, dtype='int32')
    outcomes_ = pt.as_tensor_variable(outcomes, dtype='int32')

    beliefs = 0.5 * pt.ones((n_subj,4), dtype='float64')       # [n_subj x 4]
    choice_probs_ = 0.5 * pt.ones((n_subj,1), dtype='float64') # [n_subj x 1]
    choice_trace_ = pt.zeros((n_subj,2), dtype='float64')

    [beliefs_pymc, choice_probs_pymc, choice_trace_pymc], updates = scan(
        fn=update_belief, 
        sequences=[choices_, outcomes_],
        non_sequences=[n_subj, alpha_rew, alpha_pun, beta_rew, beta_pun, sens_rew, sens_pun, lapse, decay, perseverance],
        outputs_info=[beliefs, choice_probs_, choice_trace_]
    )

    # pymc expects the log-likelihood
    ll = pt.sum(pt.log(choice_probs_pymc))
    return ll

with pm.Model() as m:
    choices_ = pm.ConstantData('choices_', choices)
    outcomes_ = pm.ConstantData('outcomes_', outcomes)

    # priors on only some parameters
    alpha_rew= pm.Beta(name="alpha_rew", alpha=1, beta=1, shape=n_subj)
    alpha_pun = pm.Beta(name="alpha_pun", alpha=1, beta=1, shape=n_subj)
    beta_rew = pm.HalfNormal(name="beta_rew", sigma=10, shape=n_subj)
    beta_pun = pm.HalfNormal(name="beta_pun", sigma=10, shape=n_subj)
    sens_rew = pm.Deterministic('sens_rew', pt.ones(shape=n_subj))
    sens_pun = pm.Deterministic('sens_pun', pt.ones(shape=n_subj))
    lapse = pm.Deterministic('lapse', pt.zeros(shape=n_subj))
    decay = pm.Deterministic('decay', pt.zeros(shape=n_subj))
    perseverance = pm.Deterministic('perseverance', 1 * pt.zeros(shape=n_subj))

    observed = np.concatenate([choices[:,:,None], outcomes], axis=2)

    logp = pm.CustomDist('logp', alpha_rew, alpha_pun, beta_rew, beta_pun, sens_rew, sens_pun, lapse, decay, perseverance, 
                         logp=rl_logp, observed=observed)

The profiler then gets called using m.profile(m.logp).summary().
What’s the correct way to profile a function like this?
Thanks!

I couldn’t run the code as posted, so I can’t give a precise answer, but m.logp is a method that returns the logp graph. So maybe m.profile(m.logp()).summary()?
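For example, with a toy model (an untested sketch, assuming PyMC v5’s Model.logp() / Model.profile() API):

import numpy as np
import pymc as pm

with pm.Model() as toy:
    mu = pm.Normal("mu", 0, 1)
    pm.Normal("y", mu=mu, sigma=1, observed=np.random.normal(size=100))

# Model.logp() builds the logp graph; Model.profile() compiles it, runs it
# repeatedly, and returns the profiling stats
toy.profile(toy.logp()).summary()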

Thanks Jesse, that works!
I think it may be a change in the API from v3 (the version used in the profiling example) to v5; I’ve seen that it’s already been reported in the GitHub issues.

Unsurprisingly, 99+% of the time is spent in the function (update_belief) called by scan. I’ve tried passing the profile=True flag when calling it (roughly as in the sketch below), but the profiling output gives very limited information about what’s happening in there.
Do you know if there is any way to get more profiling information for functions called by scan?
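By profile=True I mean scan’s profile argument, roughly as in this toy sketch (not the actual update_belief step):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor import scan

xs = pt.dvector("xs")

# dummy cumulative-sum step, standing in for update_belief
outputs, _ = scan(
    fn=lambda x, acc: acc + x,
    sequences=[xs],
    outputs_info=[np.float64(0.0)],
    profile=True,   # ask scan to collect a profile for its inner function
)

f = pytensor.function([xs], outputs[-1], profile=True)
f(np.arange(10.0))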
I’ve included all the code and the profiling output in this notebook.

Thanks!
Filippo

To get the profiler to see the inner scan function you have to do some extra gymnastics. Basically you have to tell the inner function to share its profile manually:

from pytensor.scan.op import Scan

# walk the compiled function's graph, find any Scan nodes, and print the
# profile collected by each Scan's inner (per-step) function
for node in shared_scan_fn.maker.fgraph.apply_nodes:
    if isinstance(node.op, Scan) and hasattr(node.op.fn, "profile"):
        node.op.fn.profile.summary()

Where shared_scan_fn is the compiled function you’re profiling, so the logp function in your case. Compile it with f_logp = model.compile_logp(), run it a bunch of times with %timeit on the test point, then use the loop above to get the outer and inner profiles.
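Putting it together, something like this (an untested sketch; it assumes Model.compile_fn forwards profile=True through to pytensor.function, and that the scan inside rl_logp was built with profile=True):

from pytensor.scan.op import Scan

# compile the logp graph with profiling enabled; by default compile_fn returns
# a wrapper that can be called on a point dict such as the initial point
f_logp = m.compile_fn(m.logp(), profile=True)

point = m.initial_point()
for _ in range(100):    # run enough times to accumulate timings
    f_logp(point)

# outer profile (f_logp.f should be the underlying compiled pytensor function)
f_logp.f.profile.summary()

# inner (per-step) profile of any Scan in the graph
for node in f_logp.f.maker.fgraph.apply_nodes:
    if isinstance(node.op, Scan) and hasattr(node.op.fn, "profile"):
        node.op.fn.profile.summary()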

Also remember that scan graphs have scan gradients, which can actually be the true cause of the problem, so it is useful to profile the dlogp graph as well.
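For example (same caveats as above, a sketch only; Model.dlogp() returns the gradient of the model logp):

m.profile(m.dlogp()).summary()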

Also, if you have a scan-based model, you should only be using the JAX samplers. I have seen up to 1000x speedups on scan models by switching to JAX.
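For example (a sketch; the nuts_sampler argument exists in recent PyMC releases, older versions expose pm.sampling.jax.sample_numpyro_nuts instead):

with m:
    # run NUTS through numpyro's JAX implementation instead of the default backend
    idata = pm.sample(nuts_sampler="numpyro")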

I quickly changed the sampler to numpyro, and sampling went from around 10-20 minutes to about 10 seconds.
Is it also possible to use the numpyro ADVI implementation in a similar way?

Do you think it’s worth reimplementing the model in JAX and then wrapping it for use in PyMC (basically following the “How to wrap a JAX function for use in PyMC” example from the PyMC example gallery), or would the gains be limited?

Thanks,
Filippo