CLV - Pareto NBD questions for new user

Hello! I’m using pymc-marketing’s CLV related models and had some questions as a new user.
I feel kinda dumb, but there doesn’t seem to be a lot of discussion around CLV usage online.

I’m using the Pareto NBD model vs the BetaGeoModel ('cause I don’t want to assume customer activeness for new users), and am looking for faster sampling/fitting methods (that aren’t MAP). I’ve been going back and forth between fit_method="demz" and fit_method="mcmc", but I’ve run into issues with both.

For these issues, I’m working with about 690,000 customers in a terminal-based environment (Spyder), using these versions:

  • pymc - 5.20.1
  • pymc_marketing - 0.11.1
  • nutpie - 0.13.4

For the demz fitting, I find that I constantly need to increase the draw and tune parameters, and oftentimes a chain will get stuck and show an estimated 6+ hours to completion (I’ve read that this could mean my posterior geometry is problematic, but I’ve no idea how to remedy this).
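For context on why demz keeps needing more draws: as far as I know, fit_method="demz" hands sampling to PyMC’s DEMetropolisZ step method, a gradient-free sampler that proposes moves from the scaled difference of two states taken from the chain’s own history. The toy below is a simplified, single-chain sketch of that idea in plain Python, not PyMC’s code; the function `demz_sample`, the fixed jitter `eps`, and the ever-growing history list are illustrative assumptions. Because each move is essentially a random walk whose scale comes from the history, the history has to fill up with representative states before proposals are well scaled, which is one reason this kind of sampler needs long tune and draw settings.

```python
import math
import random


def demz_sample(logp, x0, draws, tune, seed=0, eps=1e-3):
    """Hypothetical toy single-chain DE-MC-Z sampler (simplified; not PyMC's implementation)."""
    rng = random.Random(seed)
    d = len(x0)
    lam = 2.38 / math.sqrt(2 * d)  # the usual ter Braak jump scale
    x = list(x0)
    lp = logp(x)
    # Seed the history with jittered copies of x0 so early differences are non-zero.
    history = [[xi + rng.gauss(0.0, 1.0) for xi in x0] for _ in range(10 * d)]
    samples, accepted = [], 0
    for i in range(tune + draws):
        # Proposal: current state + scaled difference of two distinct past states + small jitter.
        za, zb = rng.sample(history, 2)
        prop = [xi + lam * (a - b) + rng.gauss(0.0, eps) for xi, a, b in zip(x, za, zb)]
        lp_prop = logp(prop)
        # Metropolis accept/reject; the proposal is symmetric given the current history.
        if rng.random() < math.exp(min(0.0, lp_prop - lp)):
            x, lp = prop, lp_prop
            if i >= tune:
                accepted += 1
        history.append(list(x))  # the history keeps growing as the chain runs
        if i >= tune:
            samples.append(list(x))
    return samples, accepted / draws


def std_normal_logp(x):
    # Unnormalized log-density of a standard normal in len(x) dimensions.
    return -0.5 * sum(v * v for v in x)


# Start far from the mode so tuning has to pull the chain (and its history) in.
samples, acc_rate = demz_sample(std_normal_logp, [3.0, 3.0], draws=20000, tune=2000)
mean_x = sum(p[0] for p in samples) / len(samples)
var_x = sum((p[0] - mean_x) ** 2 for p in samples) / (len(samples) - 1)
print(f"acceptance={acc_rate:.2f}, mean={mean_x:.3f}, var={var_x:.3f}  (target: 0, 1)")
```

On a harder, higher-dimensional posterior, the same random-walk behaviour generally needs many more iterations than this 2-D example.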

For the mcmc fitting, I find that the default sampler can take a really, really long time. Earlier, when I was experimenting with the BetaGeoModel, I tried the nutpie sampler and found it to be wonderful! But applying nutpie to the mcmc fitting still takes a really long time, and no progress bar appears (even if I set progress_bar=True).

So, my main issues are: re-runs, runtime, progress bars, and chains getting stuck.
Also, running plot_expected_purchases_ppc on the prior takes an extremely long time (10-30 minutes, even if I reduce the samples to 50)??

Any guidance on this would be amazing and I appreciate any time given to this in advance.
(I reiterate “long time” a lot; by that I typically mean about 4 to 12 hours.)

(No issues with using Gamma-gamma and doing the final CLV related portions)
-Sarah

Welcome!

@ColtAllen might be able to make some suggestions.

Hey @swayward,

Thanks for sharing your Pareto/NBD findings. It’s a powerful model, and we still have functionality to add like expected lifetime transactions and marginal transaction & dropout rates per customer, but it has some nuances as you’ve discovered.

Pareto/NBD has a very complex likelihood function, so I’m afraid it will always be slow for gradient-based samplers. demz is the fastest sampler available for this model, and the issues you’ve had while using it are interesting. Have you tried specifying Prior("HalfFlat") for all parameters in the model config? It’ll run slower, but remove any assumptions that may be impacting posterior geometry. You’ll always need to increase the tune & draw parameters with demz.
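To make the “very complex likelihood” concrete, here is a plain-Python sketch of the standard closed-form per-customer Pareto/NBD log-likelihood, where x is the number of repeat purchases, t_x the recency, T the customer’s age, and r, alpha, s, beta the population parameters. It branches on whether alpha >= beta and evaluates the Gauss hypergeometric function 2F1 twice per customer. The `hyp2f1` helper is a naive power series (valid for 0 <= z < 1, which both branches guarantee) and both function names are hypothetical; this is not pymc-marketing’s implementation. A gradient-based sampler additionally needs derivatives of 2F1 with respect to its parameters, for every customer at every step, which is plausibly where much of the cost comes from.

```python
import math


def hyp2f1(a, b, c, z, tol=1e-15, max_terms=100_000):
    """Hypothetical helper: Gauss 2F1(a, b; c; z) by its power series; assumes 0 <= z < 1."""
    term, total = 1.0, 1.0
    for n in range(max_terms):
        term *= (a + n) * (b + n) / ((c + n) * (n + 1)) * z
        total += term
        if abs(term) < tol * abs(total):
            break
    return total


def pareto_nbd_loglik(r, alpha, s, beta, x, t_x, T):
    """Hypothetical helper: per-customer Pareto/NBD log-likelihood (x repeats, recency t_x, age T)."""
    a = r + s + x
    # Branch so the 2F1 argument stays in [0, 1).
    if alpha >= beta:
        b = s + 1
        l2 = math.log(hyp2f1(a, b, a + 1, (alpha - beta) / (alpha + t_x))) - a * math.log(alpha + t_x)
        l3 = math.log(hyp2f1(a, b, a + 1, (alpha - beta) / (alpha + T))) - a * math.log(alpha + T)
    else:
        b = r + x
        l2 = math.log(hyp2f1(a, b, a + 1, (beta - alpha) / (beta + t_x))) - a * math.log(beta + t_x)
        l3 = math.log(hyp2f1(a, b, a + 1, (beta - alpha) / (beta + T))) - a * math.log(beta + T)
    # Contribution from customers still alive at T: 1 / ((alpha+T)^(r+x) (beta+T)^s), in log space.
    l1 = -(r + x) * math.log(alpha + T) - s * math.log(beta + T)
    # Combine the three terms stably by factoring out the largest exponent.
    m = max(l1, l2, l3)
    bracket = math.exp(l1 - m) + (s / a) * (math.exp(l2 - m) - math.exp(l3 - m))
    log_prefactor = math.lgamma(r + x) - math.lgamma(r) + r * math.log(alpha) + s * math.log(beta)
    return log_prefactor + m + math.log(bracket)


# Illustrative parameter values only (not fitted to any data).
ll_no_repeat = pareto_nbd_loglik(0.55, 10.58, 0.61, 11.67, x=0, t_x=0.0, T=38.86)
print(f"P(no repeat purchase within T) = {math.exp(ll_no_repeat):.3f}")
```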

I’m looking for faster sampling/fitting methods (that aren’t MAP).

I’m curious why. With this many customers, the posterior means and MAP point estimates should be practically identical, though you’ll lose credibility intervals. If you’re also having issues with MAP, please let me know.
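The intuition can be checked on a model where both quantities have closed forms. Using a conjugate Gamma-Poisson rate model as a stand-in (an assumption; it is far simpler than Pareto/NBD), a Gamma(a, b) prior (shape, rate) plus n observed counts gives a Gamma(a + sum of counts, b + n) posterior, whose mean and mode (the MAP) differ by exactly 1/(b + n). The helper name and the toy counts below are hypothetical.

```python
def gamma_poisson_mean_and_map(counts, a=1.0, b=1.0):
    """Hypothetical helper: posterior mean and MAP of a Poisson rate under a Gamma(a, b) prior."""
    shape = a + sum(counts)
    rate = b + len(counts)
    mean = shape / rate
    mode = (shape - 1.0) / rate  # mode of Gamma(shape, rate); valid for shape >= 1
    return mean, mode


for n in (10, 1_000, 690_000):
    counts = [i % 4 for i in range(n)]  # deterministic toy purchase counts
    mean, mode = gamma_poisson_mean_and_map(counts)
    print(f"n={n:>7}: mean={mean:.6f}  MAP={mode:.6f}  gap={mean - mode:.2e}")
```

At n = 690,000 the gap is about 1.4e-6 on a mean of about 1.5.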

The Pareto model is probably slow with nutpie because it uses object mode for the gradient, no? Do you see any warnings like that when you try to use it?

Thanks for the prompt response!

I haven’t tried the HalfFlat priors. I will look into that and let you know how it goes.

For the MAP concerns, the trace plots and stats you get from doing a full run are something I wanted to be able to show my boss to help “validate” my modeling process. That said, I could try to make my case, since I have such a large number of customers.
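On the “stats” side, the diagnostic that most directly flags a stuck chain is R-hat, which compares between-chain and within-chain variance. Below is a plain-Python sketch of the classic split R-hat on synthetic chains; `split_rhat` is a hypothetical helper, and ArviZ reports a rank-normalized variant by default, so its numbers will differ somewhat. Values near 1.0 mean the chains agree; a chain frozen away from the others pushes R-hat well above 1.

```python
import random
import statistics


def split_rhat(chains):
    """Hypothetical helper: classic split R-hat; chains is a list of equal-length lists of draws."""
    halves = []
    for chain in chains:
        h = len(chain) // 2
        halves += [chain[:h], chain[h:2 * h]]  # split each chain into two halves
    n = len(halves[0])
    means = [statistics.fmean(h) for h in halves]
    within = statistics.fmean([statistics.variance(h) for h in halves])  # W
    between = n * statistics.variance(means)  # B
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5


rng = random.Random(42)
healthy = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(4)]
stuck = healthy[:3] + [[5.0] * 1000]  # one chain frozen far from the others
print(f"healthy: {split_rhat(healthy):.3f}, one stuck chain: {split_rhat(stuck):.3f}")
```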

I get a "FutureWarning: compile_pymc was renamed to compile. Old name will be removed in future release of PyMC". I even have a hard time interrupting the run and end up having to kill the whole script.
I also can’t use numpyro or blackjax; both fail with an error before the model even compiles.
The error for both is: AttributeError: module 'jax.scipy.special' has no attribute 'hyp2f1'

But since I don’t completely understand how the samplers interact with the model type, what you said could explain it.

The error for both is: AttributeError: module 'jax.scipy.special' has no attribute 'hyp2f1'

To support numpyro & blackjax, a Hyp2F1 function needs to be written in JAX, and a PyTensor Op in turn. This would be challenging to implement, and the benefits may be marginal given the complex likelihood.

Do you know how much faster nutpie ran than the default MCMC sampler?

For the MAP concerns, ... I could try to make my case, since I have such a large number of customers.

There is also rfm_train_test_split, which you can use to validate MAP fits against a holdout period.
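For anyone unfamiliar with the idea: a calibration/holdout split fits the model on transactions up to a cutoff and then compares predicted with actual purchases after it. The sketch below shows that split in plain Python; `split_rfm`, the (customer_id, day) input format, and the numeric day timestamps are hypothetical, and this is not rfm_train_test_split’s implementation, though it follows the usual RFM conventions (frequency = number of repeat purchase days, recency = last minus first purchase, T = cutoff minus first purchase).

```python
from collections import defaultdict


def split_rfm(transactions, cutoff, end):
    """Hypothetical sketch: transactions are (customer_id, day) pairs with numeric days.

    Returns {customer_id: (frequency, recency, T, holdout_purchases)} for customers
    whose first purchase falls on or before `cutoff`.
    """
    days = defaultdict(set)
    for cust, day in transactions:
        days[cust].add(day)  # count at most one purchase per day
    out = {}
    for cust, ds in days.items():
        calib = sorted(d for d in ds if d <= cutoff)
        if not calib:
            continue  # first purchase is in the holdout period: excluded from calibration
        first, last = calib[0], calib[-1]
        holdout = sum(1 for d in ds if cutoff < d <= end)
        out[cust] = (len(calib) - 1, last - first, cutoff - first, holdout)
    return out


txns = [("a", 0), ("a", 5), ("a", 5), ("a", 12), ("a", 40),
        ("b", 20), ("c", 35), ("c", 45)]
rfm = split_rfm(txns, cutoff=30, end=60)
print(rfm)
```

Customers whose first purchase lands after the cutoff (customer "c" here) are dropped from the calibration set, since the model has nothing to condition on for them.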