Using data with different "starting point"

Hey everyone,

I’m fairly new to pymc3, and right now I’m having trouble putting my own model to work. I’m not even sure pymc3 is suitable in this case, but hopefully the answers will give some hints about that as well. I will probably get some of the terminology wrong, but I hope it’ll be clear enough.

Basically, I have a model with 2 parameters that takes a stimulus array as input and generates a trajectory (xy coordinates) as output. I also have a dataset of a few trajectories, all coming from the same “sample”, but each trajectory has a different starting point and a different stimulus. My goal is to estimate the parameters that generated the data.

To simplify things, I took only one data trajectory (and its corresponding stimulus) and added noise to it to generate 20 relatively similar trajectories. I then ran pymc3 to estimate the model parameters (see code below). This was relatively successful, and the traces look good. My question is: how can I integrate the different data trajectories, given that each one has a different starting point and stimulus? It seems like I would have to call my model once per trajectory, each time with a different stimulus.
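To see why the current setup copes with 20 noisy copies of one trajectory, here is a minimal NumPy sketch. It assumes that PyMC3's `Normal` likelihood broadcasts `mu` against `observed` elementwise and sums the log-densities, the way NumPy broadcasting would; `mu`, `sd_noise` and the simulated data are made-up stand-ins, not the real `controlModel` output.

```python
import numpy as np


def normal_logpdf(x, mu, sd):
    """Elementwise log-density of Normal(mu, sd); broadcasts like any NumPy op."""
    return -0.5 * np.log(2 * np.pi * sd**2) - (x - mu) ** 2 / (2 * sd**2)


rng = np.random.default_rng(0)
n_steps, n_copies, sd_noise = 700, 20, 0.5

# Hypothetical stand-in for one model prediction, shape (700, 2)
mu = np.cumsum(rng.normal(size=(n_steps, 2)), axis=0)

# 20 noisy copies of that one trajectory, shape (20, 700, 2)
data = mu + rng.normal(scale=sd_noise, size=(n_copies, n_steps, 2))

# mu (700, 2) broadcasts against data (20, 700, 2): every copy is compared
# with the same prediction, then all log-densities are summed
loglike_broadcast = normal_logpdf(data, mu, sd_noise).sum()

# Equivalent explicit loop over the 20 copies
loglike_loop = sum(normal_logpdf(traj, mu, sd_noise).sum() for traj in data)

print(loglike_broadcast, loglike_loop)
```

Because a single prediction is reused for every copy, this shortcut only holds when all trajectories share the same stimulus and starting point.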
Thanks for any help!
I've attached the code (which works so far) below, since it might be helpful:

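```python
import numpy as np
import pymc3 as pm
import theano
import theano.tensor as tt

@theano.compile.ops.as_op(itypes=[tt.dvector, tt.dscalar, tt.dscalar], otypes=[tt.dmatrix])
def controlModel(stimulus, param1, param2):
    ## does all sort of stuff
    return np.vstack((np.array(xs), np.array(ys))).T # size is (700,2)

stimulus = np.array(...).ravel()  # flattened to shape (700,) to match tt.dvector
data = np.asarray(...)  # 20 trajectories stacked into shape (20, 700, 2)
sdnoise = ...  # observation noise sd, defined elsewhere in my script
ndraws = 2000
nburn = 1500
with pm.Model() as model:
    param1 = pm.Uniform('param1', lower=0.1, upper=3)
    param2 = pm.Uniform('param2', lower=0.1, upper=3)
    # as_op expects symbolic inputs, so wrap the NumPy stimulus
    mod = controlModel(tt.as_tensor_variable(stimulus), param1, param2)
    # mod (700, 2) broadcasts against data (20, 700, 2)
    y = pm.Normal('y', mu=mod, sd=sdnoise, observed=data)
    # as_op provides no gradient, so Metropolis is used instead of NUTS
    trace = pm.sample(draws=ndraws, tune=nburn,
                      discard_tuned_samples=True, step=pm.Metropolis([param1, param2]))
a = pm.traceplot(trace)
```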
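A minimal sketch of one way to pool trajectories that have different stimuli and starting points, assuming the trajectories are independent given the parameters and share the same Gaussian noise level. `toy_control_model`, `joint_loglike` and the simulated data are hypothetical stand-ins (the real `controlModel` is not shown), and the grid search is only a sanity check, not a replacement for sampling.

```python
import numpy as np


def normal_logpdf(x, mu, sd):
    """Elementwise log-density of Normal(mu, sd)."""
    return -0.5 * np.log(2 * np.pi * sd**2) - (x - mu) ** 2 / (2 * sd**2)


def toy_control_model(stimulus, start, param1, param2):
    """Hypothetical stand-in for controlModel: maps a 1-D stimulus and a
    starting point (x0, y0) to an (n, 2) trajectory."""
    t = np.arange(stimulus.shape[0])
    xs = start[0] + param1 * np.cumsum(stimulus)
    ys = start[1] + param2 * np.sin(0.01 * t)
    return np.column_stack((xs, ys))


def joint_loglike(param1, param2, stimuli, starts, observed, sd_noise):
    """Sum of per-trajectory log-likelihoods. The parameters are shared;
    each trajectory gets its own model call with its own stimulus and start."""
    total = 0.0
    for stim, start, obs in zip(stimuli, starts, observed):
        mu = toy_control_model(stim, start, param1, param2)
        total += normal_logpdf(obs, mu, sd_noise).sum()
    return total


# Simulate 5 trajectories with different stimuli and starting points
rng = np.random.default_rng(1)
true_p1, true_p2, sd_noise = 1.5, 0.8, 0.5
stimuli = [rng.normal(size=700) for _ in range(5)]
starts = [rng.uniform(-10, 10, size=2) for _ in range(5)]
observed = [
    toy_control_model(s, x0, true_p1, true_p2) + rng.normal(scale=sd_noise, size=(700, 2))
    for s, x0 in zip(stimuli, starts)
]

# Crude grid search over the prior range (0.1, 3) to check that the pooled
# likelihood peaks near the parameters used for the simulation
grid = np.linspace(0.1, 3.0, 30)
scores = np.array([[joint_loglike(p1, p2, stimuli, starts, observed, sd_noise)
                    for p2 in grid] for p1 in grid])
i, j = np.unravel_index(np.argmax(scores), scores.shape)
best_p1, best_p2 = grid[i], grid[j]
print(best_p1, best_p2)
```

In the PyMC3 model, the analogous structure would be a plain Python loop inside `with pm.Model()`: call `controlModel` once per trajectory with that trajectory's stimulus (and starting point, which assumes adding another `tt.dvector` to the `as_op` `itypes`), and give each call its own observed `pm.Normal` with a unique name such as `f"y_{i}"`, all sharing `param1` and `param2`. PyMC3 adds up the log-likelihoods of all observed variables, which is what `joint_loglike` does explicitly here. Since each trajectory is handled separately, they also do not need to have the same length.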