Best Way to Handle a for Loop in pymc4

Hey Jesse, thanks for your help! I finally had time to get back to work on this. Your code sample is really helpful, and your explanation makes a lot of sense. Here’s where I landed:

```python
import pandas as pd
import matplotlib.pyplot as plt
import pymc as pm
import numpy as np
import aesara
import aesara.tensor as at


def build_tweet_model(data, k, x0=0):
    tweet_model = pm.Model()

    with tweet_model:
        # Priors
        alpha = pm.Uniform("alpha", lower=-30, upper=700)
        beta = pm.LogNormal("beta", mu=0, sigma=2)
        delta = pm.Beta("delta", alpha=5, beta=1)

        # Initial virality state for the recursion
        outputs_info = at.as_tensor_variable(np.asarray(x0, dtype='float64'))

        # One step of the loop: x_t = delta * x_{t-1} + data.gamma().
        # scan calls inner_func once to build the loop body, so data.gamma()
        # runs at graph-construction time, not once per step.
        def inner_func(prior_result):
            return delta * prior_result + data.gamma()

        # Symbolic replacement for a Python for loop over k steps
        virality, updates = aesara.scan(fn=inner_func,
                                        outputs_info=outputs_info,
                                        n_steps=k)

        shares = alpha*(pm.math.exp(beta*virality) - 1/(virality+1))
        # Likelihood: use the function argument, not the global time_series
        Y_obs = pm.Normal("Y_obs", mu=shares, sigma=1, observed=data)

    return tweet_model


tweet_model = build_tweet_model(time_series, len(time_series))
with tweet_model:
    trace_g = pm.sample(2000, tune=1000, cores=2)
```
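To connect this back to the for loop in the thread title: `aesara.scan` builds a symbolic version of the plain-Python loop below. This is only a sketch for checking the recurrence by hand, with fixed numbers standing in for the random variables `delta`, `alpha`, and `beta`; the helper names `virality_path` and `expected_shares` are hypothetical, and `g` stands in for whatever `data.gamma()` returns.

```python
import math


def virality_path(delta, x0, g, n_steps):
    """Plain-Python version of the recurrence that aesara.scan builds.

    Each step computes x_t = delta * x_{t-1} + g. Like scan's output,
    the returned list holds x_1..x_{n_steps} and omits the initial state x0.
    """
    path = []
    prev = x0
    for _ in range(n_steps):
        prev = delta * prev + g
        path.append(prev)
    return path


def expected_shares(alpha, beta, path):
    # Mirrors: shares = alpha*(pm.math.exp(beta*virality) - 1/(virality+1))
    return [alpha * (math.exp(beta * x) - 1.0 / (x + 1.0)) for x in path]


# Made-up values: delta=0.5, x0=0, g=1, four steps
print(virality_path(0.5, 0.0, 1.0, 4))  # [1.0, 1.5, 1.75, 1.875]
```

Comparing a loop like this against the scan output (for example via `pm.sample_prior_predictive` with the same fixed inputs) is one way to confirm the symbolic version does what the loop would.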

The symbolic stuff is definitely making more sense; I did a lot more reading on Theano, Aesara, and other frameworks. In the last three lines, is this the proper way to call this function? In my mind it just returns a PyMC model, so I should be able to use it like a normal model when I sample. I’m also curious whether adding the observed variable inside build_tweet_model() is the right idea.
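Written out, the generative model that `build_tweet_model()` assembles is shown below, with $g$ standing for whatever `data.gamma()` returns and $y_t$ for the observed series. The likelihood line belongs to the same joint model as the priors; in PyMC, every variable, observed or not, has to be created inside a model context, which here is the builder's `with tweet_model:` block.

$$
\begin{aligned}
\alpha &\sim \mathrm{Uniform}(-30,\ 700), \qquad \beta \sim \mathrm{LogNormal}(0,\ 2), \qquad \delta \sim \mathrm{Beta}(5,\ 1) \\
x_t &= \delta\, x_{t-1} + g, \qquad t = 1, \dots, k, \qquad x_0 = \texttt{x0} \\
\mathrm{shares}_t &= \alpha\left(e^{\beta x_t} - \frac{1}{x_t + 1}\right) \\
y_t &\sim \mathcal{N}\!\left(\mathrm{shares}_t,\ 1\right)
\end{aligned}
$$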

However, I’m getting a different error that might need its own thread: “TypeError: <class ‘numpy.typing._dtype_like._SupportsDType’> is not a generic class”. The traceback shows that it originates in the line import pymc as pm. Searching Google led me to a Stack Overflow question (tagged python 3.x), “<class 'numpy.typing._dtype_like._SupportsDType'> is not a generic class when importing the plotly.express library”, but upgrading numpy on my end didn’t help.
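Since the error is raised while `import pymc` runs, a separate thread will likely need the exact versions involved. A minimal stdlib-only sketch for collecting them is below; the helper name `environment_report` is hypothetical and the default package list is an assumption (add plotly or anything else relevant). It reads installed distribution metadata, so it works even when importing pymc itself fails.

```python
import platform
from importlib import metadata


def environment_report(packages=("numpy", "pymc", "aesara")):
    """Collect the Python version and installed package versions for a bug report."""
    report = {"python": platform.python_version()}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the active environment
            report[name] = "not installed"
    return report


for key, value in environment_report().items():
    print(f"{key}: {value}")
```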

Thanks,
Trevor Crupi