I am trying to model a weighted sum of truncated normal random variables. The weights are known, vary by date, and sum to 1 at each date (i.e., for each sample). The observed data are the weighted sums at various dates. I am trying to recover the parameters (mu and sigma) of each component. Given a sufficient number of samples with known, varying weights, I expect this problem to be solvable.
In code, what I’d like to do is use an observed deterministic; however, I understand this is not yet possible as written. I experimented with CustomDist and a grid-based convolution approximation of the pmf, but I hit a roadblock trying to resolve a MissingInputError. Ultimately, I’m not sure whether it would be prohibitively slow even if I got it working properly.
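For reference, here is a minimal NumPy sketch of what a grid-based convolution approximation could look like for just two components. It is not the CustomDist attempt itself: the grid spacing, the parameter values, and the helper `truncnorm_pdf` are all assumptions for illustration, and it only handles a lower bound with no upper bound.

```python
import math
import numpy as np

def truncnorm_pdf(x, mu, sigma, lower=0.0):
    """Density of Normal(mu, sigma) truncated below at `lower` (no upper bound)."""
    z = (x - mu) / sigma
    phi = np.exp(-0.5 * z**2) / (sigma * math.sqrt(2.0 * math.pi))
    # Probability mass the untruncated normal places above `lower`
    mass = 0.5 * (1.0 - math.erf((lower - mu) / (sigma * math.sqrt(2.0))))
    return np.where(x >= lower, phi / mass, 0.0)

dy = 0.01                         # grid spacing (assumed)
grid = np.arange(0.0, 30.0, dy)   # grid starting at the lower bound 0
w1, w2 = 0.3, 0.7                 # weights for a single date (assumed values)

# If X has density f_X, then w * X has density f_X(y / w) / w
f1 = truncnorm_pdf(grid / w1, mu=5.0, sigma=2.0) / w1
f2 = truncnorm_pdf(grid / w2, mu=8.0, sigma=1.5) / w2

# The density of the sum of independent variables is the convolution of their
# densities; on a regular grid this is a discrete convolution times dy
f_sum = np.convolve(f1, f2) * dy
grid_sum = np.arange(f_sum.size) * dy

total = f_sum.sum() * dy               # should be close to 1
mean = (grid_sum * f_sum).sum() * dy   # should be close to w1*E[X1] + w2*E[X2]
print(total, mean)
```

With more components the convolution has to be repeated once per component and once per distinct weight vector, which is where the speed concern comes from.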
After doing some digging, I learned that AePPL (now logprob) may be able at some point to simplify this problem for me, so I’m curious what the “best” current approach would be to solve this problem. I’m wondering if CustomDist is the answer if properly implemented, or if there’s possibly a transformation that could help me.
Greatly appreciate any and all feedback!
Essentially, I’m trying to recover the component parameters from the following equation:
obs[i] = sum_over_j ( component[j] * w[i,j] )
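To make the equation concrete, here is a minimal NumPy sketch of the forward (data-generating) process. It assumes each date gets its own independent draw of every component; the parameter values, the Dirichlet-generated weights, and the helper `sample_truncnorm` are made up for illustration.

```python
import numpy as np

def sample_truncnorm(rng, mu, sigma, lower, size):
    """Draw from Normal(mu, sigma) truncated below at `lower` by simple rejection."""
    mu = np.broadcast_to(mu, size).astype(float)
    sigma = np.broadcast_to(sigma, size).astype(float)
    draws = rng.normal(mu, sigma)
    bad = draws < lower
    while bad.any():
        # Redraw only the entries that fell below the truncation point
        draws[bad] = rng.normal(mu[bad], sigma[bad])
        bad = draws < lower
    return draws

rng = np.random.default_rng(0)
n_dates, n_components = 500, 3

# Hypothetical "true" component parameters to be recovered
true_mu = np.array([2.0, 5.0, 9.0])
true_sigma = np.array([1.0, 2.0, 1.5])

# Known weights: one row per date, each row sums to 1
w = rng.dirichlet(np.ones(n_components), size=n_dates)

# Assumption: an independent draw of each component at every date
component = sample_truncnorm(rng, true_mu, true_sigma, 0.0, (n_dates, n_components))

# obs[i] = sum_over_j ( component[i, j] * w[i, j] )
obs = (component * w).sum(axis=1)
```

Only `w` and `obs` would be available to the model; `component` is latent.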
Am I correct in thinking that sometime soon, AePPL may be capable of determining the logp based solely on the following custom RV?
def combined_truncnorm_rv(mus, sigmas, weights, lower, upper, size):
    component = pm.TruncatedNormal.dist(mu=mus, sigma=sigmas, lower=lower, upper=upper, size=size)
    return pm.math.dot(component, weights.T)
Here’s a non-working attempt (since I can’t observe the deterministic):
import arviz as az
import numpy as np
import pymc as pm

RANDOM_SEED = 42  # any fixed seed (value assumed here)

w_ = dataframe_of_known_weights  # (rows: date, cols: component)
samples = vector_of_observed_samples_by_date
llimit = 0.0
ulimit = np.inf

model = pm.Model()
with model:
    model.add_coord('date', w_.index, mutable=True)
    model.add_coord('component', w_.columns, mutable=False)
    w = pm.MutableData('w', w_, dims=['date', 'component'])

    # Priors on the per-component parameters
    mu = pm.HalfNormal('mu', sigma=100, dims='component')
    sigma = pm.HalfNormal('sigma', sigma=10, dims='component')
    component = pm.TruncatedNormal('component', mu=mu, sigma=sigma, lower=llimit, upper=ulimit, dims='component')

    # Does not work: pm.Deterministic has no `observed` argument
    obs = pm.Deterministic('obs', pm.math.dot(component, w.T), dims='date', observed=samples)

    trace_prior = pm.sample_prior_predictive(samples=100, random_seed=RANDOM_SEED)
    ax = az.plot_ppc(trace_prior, var_names='obs', group='prior')
    trace = pm.sample(2000, tune=2000, chains=4)