Write theano.scan+theano.function with multinomial sampling inside as a pymc3 model


Hello everyone,

I have been trying to incorporate multinomial sampling within theano.scan. I have managed to get it to work as a Theano function, and I think it is possible to write it entirely without pymc3. However, I wonder if I can do it as part of a pymc3 model. Then I would be able to sample the parameters from different distributions in a nicer way.

For example, I have the following program:

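import numpy as np
import theano
import theano.tensor as tt
from theano.tensor.shared_randomstreams import RandomStreams

# theano_rng is not defined in the snippet; a RandomStreams instance is assumed here
theano_rng = RandomStreams(seed=42)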

K_sym = tt.iscalar("K_sym")
n_sym = tt.lmatrix("n_sym")
p_sym = tt.dscalar("p_sym")

results_step, updates_step = theano.scan(lambda i, n, p: 2*theano_rng.multinomial(n=n[i,:],pvals=[p,1-p])[:,0], 

compute_step = theano.function(inputs=[K_sym,n_sym,p_sym], outputs=results_step, updates=updates_step)

res = compute_step(2,np.transpose([[100,50]]*100),.6)

and I would like to re-write it as something like

import pymc3 as pm
import theano
import theano.tensor as tt

with pm.Model() as model:
    # some definition of n, p, and K
    n = pm.RV(shape=2) #or tt.as_tensor_variable()
    p = pm.RV()
    K = 2
    results_step, updates_step = theano.scan(lambda i, n, p: 2*theano_rng.multinomial(n=n[i,:],pvals=[p,1-p])[:,0], 
    compute_step = theano.function(inputs=[K,n,p], outputs=results_step, updates=updates_step)

    sampling = pm.sample_prior_predictive(100)

which is wrong, because theano.function should not be part of the model. But I seem to need it, since there is sampling from a multinomial inside theano.scan.

Maybe some of you have run into something similar and would not mind giving some hints or advice. Thank you in advance!

I have seen "Sample within theano.scan?", but that one looked a bit different to me.


In most cases you can just set compute_step = results_step (i.e., there is no need to compile it into a function; just use its tensor representation).
In your case, you might also want to replace 2*theano_rng.multinomial(n=n[i,:],pvals=[p,1-p])[:,0] with a PyMC3 random variable using pm.Multinomial.
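
As a reference for checking any rewrite: because pvals has only two entries, column 0 of each multinomial draw is a binomial count with total n[i, j] and success probability p. Below is a minimal, dependency-free sketch of what one call to compute_step should return. It assumes the truncated scan iterates i over 0..K-1; the helper names binomial_draw and compute_step_reference are hypothetical, and Python's random module stands in for theano_rng. In PyMC3 terms this would roughly correspond to a pm.Binomial variable multiplied by 2, which is a substitution for, not the same thing as, the pm.Multinomial suggested above.

```python
import random


def binomial_draw(rng, total, p):
    """Count successes in `total` independent Bernoulli(p) trials."""
    return sum(rng.random() < p for _ in range(total))


def compute_step_reference(K, n, p, seed=0):
    """Mimic the scan (assumed to loop i over 0..K-1).

    Row i of the result holds 2 * (first-category count) for each total in n[i].
    """
    rng = random.Random(seed)
    return [[2 * binomial_draw(rng, total, p) for total in n[i]] for i in range(K)]


# Same inputs as the Theano call: totals of 100 in row 0 and 50 in row 1.
n = [[100] * 100, [50] * 100]
res = compute_step_reference(2, n, 0.6)
```

The result has K rows and one entry per column of n, so a PyMC3 version can be compared against it by shape and by the range of values it produces.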


Thank you, @junpenglao

For pm.Multinomial I assume I would need to specify a name for the random variable. Would that not be a problem?