PyMC3 changepoint model

I'm having a little trouble translating the following JAGS model to PyMC3:

    model {
        for (i in 1:N) {
            tau[i] <- ifelse(i < cp, tau1, tau2)
            N_t_inf[i] ~ dnorm(fcast[i], tau[i])
            N_t_T[i] ~ dbinom(pstar[i], round(N_t_inf[i]))
        }

        cp ~ dunif(1, N)

        invTau1 ~ dgamma(100, 1)
        tau1 <- 1/invTau1

        invTau2 ~ dgamma(1, 1)
        tau2 <- 1/invTau2
    }

The issue is that I need an index-specific changepoint, but the PyMC3 likelihoods I've seen are all written over the whole set of observations at once, not indexed by observation. Is there some way I can do this in PyMC3?

Are fcast, pstar, and N_t_T observations?
If so, you don't need to index the observations explicitly. In

           N_t_inf[i] ~ dnorm(fcast[i],tau[i])
           N_t_T[i] ~ dbinom(pstar[i],round(N_t_inf[i]))

every term carries the same index [i], which means you can treat them as vectors, something like:

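    import pymc3 as pm
    import theano.tensor as tt

    # N, fcast and pstar are your own data (an int and two length-N arrays)
    with pm.Model() as m:
        invTau1 = pm.Gamma('invTau1', alpha=100, beta=1)  # JAGS dgamma(shape, rate)
        tau1 = 1 / invTau1
        invTau2 = pm.Gamma('invTau2', alpha=1, beta=1)
        tau2 = 1 / invTau2
        cp = pm.Uniform('cp', 1, N)
        # indices 1..N as in JAGS (tt.arange(1, N) would stop at N-1)
        idx = tt.arange(1, N + 1)
        tau = tt.switch(tt.lt(idx, cp), tau1, tau2)  # tau1 before cp, tau2 after
        # JAGS dnorm uses precision, so pass tau= by keyword
        N_t_inf = pm.Normal('N_t_inf', mu=fcast, tau=tau, shape=N)
        N_t_T = pm.Binomial('N_t_T', p=pstar, n=tt.round(N_t_inf), observed=...)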
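To see why the per-index ifelse needs no explicit loop, here is a small NumPy-only sketch (no PyMC3 required) that compares a loop mirroring the JAGS for block with the vectorized form; np.where plays the role of tt.switch. The function names and the values of N, cp, tau1 and tau2 are purely illustrative, not taken from the question.

```python
import numpy as np

def tau_loop(N, cp, tau1, tau2):
    """Mirror the JAGS loop: tau[i] <- ifelse(i < cp, tau1, tau2) for i in 1:N."""
    tau = np.empty(N)
    for i in range(1, N + 1):          # JAGS indices run 1..N
        tau[i - 1] = tau1 if i < cp else tau2
    return tau

def tau_vectorized(N, cp, tau1, tau2):
    """Same result without a loop; np.where stands in for tt.switch."""
    idx = np.arange(1, N + 1)          # 1..N inclusive; np.arange(1, N) would stop at N-1
    return np.where(idx < cp, tau1, tau2)

# Illustrative values only (hypothetical)
N, cp, tau1, tau2 = 10, 4.3, 0.01, 1.0
print(tau_vectorized(N, cp, tau1, tau2))
print(np.array_equal(tau_loop(N, cp, tau1, tau2), tau_vectorized(N, cp, tau1, tau2)))  # True
```

Because cp is continuous, a fractional value such as 4.3 just means the switch happens between indices 4 and 5, exactly as with ifelse(i < cp, ...) in JAGS.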

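Do be careful with the parameterizations, though: JAGS's dnorm takes a precision tau (1/variance) as its second argument, whereas the second positional argument of pm.Normal is the standard deviation (sigma, or sd in older versions), so pass tau= explicitly.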
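As a quick check on the precision parameterization, the sketch below (standard library only) writes the normal log-density once in JAGS's precision form and once in the standard-deviation form with sigma = 1/sqrt(tau); the two agree. The helper names are hypothetical, not part of JAGS or PyMC3.

```python
import math

def normal_logpdf_precision(x, mu, tau):
    """Log-density of N(mu, 1/tau), parameterized by precision tau as in JAGS dnorm."""
    return 0.5 * math.log(tau / (2 * math.pi)) - 0.5 * tau * (x - mu) ** 2

def normal_logpdf_sd(x, mu, sigma):
    """Log-density of N(mu, sigma^2), parameterized by standard deviation sigma."""
    return -math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * ((x - mu) / sigma) ** 2

tau = 0.04                  # precision, i.e. variance 25
sigma = 1 / math.sqrt(tau)  # standard deviation 5
x, mu = 3.0, 1.0
print(math.isclose(normal_logpdf_precision(x, mu, tau), normal_logpdf_sd(x, mu, sigma)))  # True
```

The gamma priors need no conversion: JAGS dgamma(shape, rate) and PyMC3's pm.Gamma(alpha, beta) both use shape and rate.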

Awesome, thanks!