How to define a model with an unobserved subsampling step

I have simulated data generated by the following steps (the code below also contains a function that generates the data exactly as described):

1- Generate real initial proportions p0 for m categories.
2- Generate initial observed counts from a multinomial with p=p0 and total count N0 (observed).
3- Generate counts from a multinomial with p=p0 and some N1 << N0 (unobserved subsampling, but N1 is known).
4- Transform the subsampled counts by scaling them with m scale values, which I call the real fs (in the code their log2 values are drawn from a uniform distribution between 1 and 20). Let p1 be the proportions of the transformed counts.
5- Generate final observed counts from a multinomial with p=p1 and total count N2 (observed).

Now, given the final observed counts (step 5) and the initial observed counts (step 2), I would like to estimate the fs using pymc3. Without step 3 this is easily modelled as a combination of two observed Multinomials (see the code below). However, when N1 is small compared to m, the subsampling introduces real uncertainty into the results. For instance, suppose one category has a high f and a non-zero value in the initial counts, but is left out at the subsampling stage (in all repeats) because its initial proportion was low; with an initial proportion of 0.001 and N1 = 100, for example, this happens with probability (1 - 0.001)^100 ≈ 0.9 per repeat. The estimate for that category's f will then clearly come out far below its real value, since the final counts will be 0 as well.

I understand that one cannot improve the estimates for such categories by tweaking the model when subsampling completely destroys the data associated with them. But I am wondering whether something can be incorporated into the model, such as an unobserved multinomial, that would widen the HDI estimates when N1 is low (I can observe N0, N1 and N2, so I know how severe the subsampling is). I tried several things, but none were satisfactory, so the code below contains a model that uses only two multinomials, as if step 3 did not exist; it also includes the code for generating the data. Please let me know whether it makes sense to try to incorporate this unobserved subsampling into the model. I am hoping it will be most helpful in cases where there is only one repeat of the observations.

For instance, if you run the fitting procedure with m=150, N0=N2=150000 and N1=100 vs N1=10000, the observed HDI ranges hardly change, yet the estimates get worse, which is to be expected since nothing in the fitting procedure uses any information about N1. I also understand that, given a very low initial proportion and a low N1, increasing a category's f will not change the quantified uncertainty for that category while its estimate keeps getting worse. Nevertheless, if there is a mathematically sound way of modelling this uncertainty, modulo pathological outliers like the one above, I would still be satisfied.
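One direction I have considered (but not validated) is a continuous relaxation of the unobserved subsample: since NUTS cannot sample a latent discrete Multinomial, replace it with a second Dirichlet whose concentration is N1 times the initial proportions. This moment-matches multinomial subsampling of size N1 (mean p0, per-component variance roughly p0*(1-p0)/N1), so the injected uncertainty grows as N1 shrinks. The sketch below is only an illustration of that idea: the function name, the 1e-3 jitter that keeps the concentration strictly positive, and the omission of the final_sum identifiability constraint are all my own choices, not tested code.

import numpy as np
import pymc3 as pm

def fit_with_subsample(initial_counts, final_counts, subsample_size):
    nrepeats, ncategories = initial_counts.shape

    with pm.Model() as model:
        log2_fs = pm.Uniform("log2_fs", lower=0, upper=21, shape=ncategories)
        fs = 2**log2_fs
        # initial proportions (p0 in the data-generation steps above)
        p0 = pm.Dirichlet("p0", a=np.ones(ncategories), shape=ncategories)

        for i in range(nrepeats):
            pm.Multinomial(f"input_sample{i}",
                           n=np.sum(initial_counts[i,:]),
                           p=p0, observed=initial_counts[i,:])

            # continuous stand-in for the unobserved Multinomial(N1, p0)
            # subsample: Dirichlet(N1*p0) has mean p0 and per-component
            # variance ~ p0*(1-p0)/(N1+1), so a small N1 injects a lot of
            # extra uncertainty. The 1e-3 jitter (an arbitrary choice)
            # keeps the concentration strictly positive.
            p_sub = pm.Dirichlet(f"p_sub{i}",
                                 a=subsample_size*p0 + 1e-3,
                                 shape=ncategories)

            p2 = fs*p_sub
            p2 = p2/pm.math.sum(p2)

            pm.Multinomial(f"final_sample{i}",
                           n=np.sum(final_counts[i,:]),
                           p=p2, observed=final_counts[i,:])

        # NOTE: the final_sum soft constraint from the model below is
        # omitted here for brevity, but it is still needed to pin down
        # the additive constant in log2_fs.
        trace = pm.sample(draws=300, tune=100, chains=6, cores=6,
                          return_inferencedata=True, target_accept=0.9)
    return trace

With this formulation the per-category HDIs should widen as subsample_size decreases, because the Dirichlet concentrations scale with it, which is exactly the behaviour I am after.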

Another thing that comes to mind, which is not incorporated into the fitting procedure but is a separate piece of information: given p0 and N1, I can compute the probability that a particular category comes up as 0 during subsampling, and create a visualization where the results are coloured on a gradient according to this probability, even though it won't be reflected in the HDI estimates coming from the fit.
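For completeness, under multinomial subsampling that probability has a simple closed form per category, (1 - p0_i)^N1 per repeat. A minimal sketch (in practice p0 would be estimated from the observed initial counts, since the real p0 is unknown):

import numpy as np

def prob_zero_in_subsample(p0, subsample_size, nrepeats=1):
    # P(a category is drawn zero times in every one of `nrepeats`
    # independent multinomial subsamples of size `subsample_size`)
    p0 = np.asarray(p0)
    return (1.0 - p0)**(subsample_size*nrepeats)

# e.g. estimate p0 from the observed initial counts and colour the
# scatter points in the plot below by dropout probability:
# p0_hat = observed_initial_counts.sum(axis=0)/observed_initial_counts.sum()
# probs = prob_zero_in_subsample(p0_hat, subsample_size, nrepeats)
# ax.scatter(real_log2_fs, fitted_log2_fs, c=probs, cmap='viridis')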

Thanks

import pymc3 as pm
import numpy as np
import arviz as az
import matplotlib.pyplot as plt

def fit(initial_counts, final_counts, change_in_total_counts):
    
    nrepeats,ncategories = initial_counts.shape

    with pm.Model() as model:
        
        log2_fs = pm.Uniform("log2_fs", lower=0, upper=21, shape=ncategories)
        fs = 2**log2_fs
        # note: "p1" here plays the role of the initial proportions
        # (p0 in the data-generation steps above)
        p1 = pm.Dirichlet("p1", a=np.ones(ncategories), shape=ncategories)
        
        
        # p2 (the transformed proportions) is the same for every repeat,
        # so it can be built once outside the loop
        p2 = fs*p1
        p2 = p2/pm.math.sum(p2)

        for i in range(nrepeats):

            input_sample = pm.Multinomial(f"input_sample{i}",
                                          n=np.sum(initial_counts[i,:]),
                                          p=p1, observed=initial_counts[i,:])

            log_initial_sum = pm.math.log(pm.math.sum(initial_counts[i,:]))
            log_final_sum = pm.math.log(pm.math.sum(input_sample*fs))
            expected_log_sum = log_initial_sum + np.log(change_in_total_counts)
            
            final_sample = pm.Multinomial(f"final_sample{i}", 
                                          n=np.sum(final_counts[i,:]), 
                                          p=p2,
                                          observed=final_counts[i,:])
            
            # log2_fs is only determined up to an additive constant, since we
            # work with proportions; this soft constraint on the total-count
            # change fixes that constant
            final_sum = pm.Normal(f"final_sum{i}", mu=log_final_sum, sd=0.01,
                                  observed=expected_log_sum)
          
        trace = pm.sample(draws=300, tune=100, chains=6, cores=6, 
                          return_inferencedata=True, target_accept=0.9,
                          progressbar=True)
        
        return trace


def generate_observations(ncategories, nobservations, subsample_size, rng,
                          nrepeats):
    # real proportions
    real_initial_proportions = rng.dirichlet(np.ones(ncategories))

    # observed initial counts
    observed_initial_counts = rng.multinomial(nobservations, 
                                              real_initial_proportions,
                                              size=nrepeats)
    
    # subsampling counts (unobserved)
    subsampled_counts = rng.multinomial(subsample_size, 
                                        real_initial_proportions,
                                        size=nrepeats)
    
    # the subsampled counts are scaled by the factors fs, which are the
    # quantities I want to estimate
    real_log2_fs = rng.uniform(1,20,ncategories)
    
    real_final_counts = 2.0**real_log2_fs*subsampled_counts
    real_final_proportions =\
        (real_final_counts.T/np.sum(real_final_counts, axis=1)).T
    
    # observed final counts 
    observed_final_counts = rng.multinomial(nobservations, 
                                            real_final_proportions, nrepeats)
    
    change_in_total_counts =\
        np.sum(real_final_counts)/np.sum(subsampled_counts)
    
    
    return (observed_initial_counts, observed_final_counts, real_log2_fs, 
            real_initial_proportions, change_in_total_counts)


def plot(trace, real_log2_fs):
    
    summary1 = az.summary(trace, var_names=['log2_fs'], hdi_prob=0.99, round_to=5)


    fitted_log2_fs = summary1.loc[:,'mean']
    # with hdi_prob=0.99 the HDI columns are 'hdi_0.5%' and 'hdi_99.5%'
    low = fitted_log2_fs - summary1.loc[:,'hdi_0.5%']
    high = summary1.loc[:,'hdi_99.5%'] - fitted_log2_fs

    fig,ax = plt.subplots(1,1,figsize=(5,5))

    ax.scatter(real_log2_fs, fitted_log2_fs, zorder=0)
    ax.errorbar(real_log2_fs, fitted_log2_fs, yerr=np.array([low, high]),
                fmt='none', ecolor='black', capsize=5,
                alpha=0.5, zorder=2)
    ax.plot(np.arange(0,20.1,0.1),np.arange(0,20.1,0.1),zorder=-2)    

    ax.set_xlim(0,20)
    ax.set_ylim(0,20)

    return fig,ax


seed = 0
rng = np.random.default_rng(seed)
ncategories = 150
nobservations = 150000
subsample_size = 1000
nrepeats = 2


(observed_initial_counts, observed_final_counts, real_log2_fs,
  real_initial_proportions, change_in_total_counts) =\
    generate_observations(ncategories, nobservations, subsample_size, rng,
                          nrepeats)
    

trace = fit(observed_initial_counts, observed_final_counts,
            change_in_total_counts)

fig, ax = plot(trace, real_log2_fs)