Sure, I’ll try and see how it goes, thanks. Although I am more worried about how to retrieve that NxR matrix V.
That said, even though I couldn’t assess how good this approach to missing values is without the full matrix V, the ELBO does decrease by a few orders of magnitude and eventually converges, which at least suggests the model is numerically stable.
On the issue with the Indian Buffet Process, I have the stick-breaking approach as follows:
def stick_breaking_BBP(beta):
    return tt.extra_ops.cumprod(beta)

def create_BetaBernoulliProcess(alpha, K):
    beta = pm.Beta('beta', alpha, 1, shape=K)
    return pm.Deterministic('v', stick_breaking_BBP(beta))
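As a quick sanity check on the stick-breaking construction (a plain NumPy sketch with toy values I picked, not part of the model above), the cumulative product of Beta(alpha, 1) draws gives a non-increasing sequence of feature-inclusion probabilities, which is exactly the IBP stick-breaking behaviour:

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, K = 2.0, 5  # toy hyperparameters

# IBP stick breaking: pi_k = prod_{j<=k} beta_j, with beta_j ~ Beta(alpha, 1)
beta = rng.beta(alpha, 1.0, size=K)
pis = np.cumprod(beta)

# Each beta_j lies in (0, 1), so the pi_k can only shrink with k
assert np.all(np.diff(pis) <= 0)
print(pis)
```

Larger alpha makes the sticks shrink more slowly, so more features stay active on average.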
class BetaBernoulliLatents(pm.Continuous):
    def __init__(self, pis, normal_prior, variance_prior, *args, **kwargs):
        super(BetaBernoulliLatents, self).__init__(*args, **kwargs)
        self.pis = pis
        self.normal_prior = normal_prior
        self.variance_prior = variance_prior
        self.mean = pm.floatX(np.zeros([batch_size, R]))

    def logp(self, x):
        # Marginalise the Bernoulli mask: a Normal "slab" when the feature is
        # active, a point mass at zero when it is not
        chosen_likelihood = tt.log(self.pis) + pm.Normal.dist(self.normal_prior, self.variance_prior).logp(x)
        ignore_likelihood = tt.log(1 - self.pis) + pm.Constant.dist(0.).logp(x)
        logps = [chosen_likelihood, ignore_likelihood]
        return tt.sum(logsumexp(tt.stacklists(logps)[:, :x.shape[0]], axis=0))
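To make the marginalisation in `logp` concrete, here is a standalone NumPy/SciPy sketch of the same spike-and-slab density (my own toy function and values, not the Theano code above): the Bernoulli mask is summed out via logsumexp, mixing a Normal slab with a point mass at zero.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def spike_slab_logp(x, pi, mu=0.0, sd=1.0):
    # mask == 1: Normal slab, weighted by log(pi)
    slab = np.log(pi) + norm.logpdf(x, mu, sd)
    # mask == 0: point mass at zero, weighted by log(1 - pi);
    # log-density is 0 at x == 0 and -inf elsewhere
    spike = np.log1p(-pi) + np.where(x == 0.0, 0.0, -np.inf)
    # Marginalise the binary mask
    return logsumexp(np.stack([slab, spike]), axis=0)

x = np.array([0.0, 0.5, -1.2])
print(spike_slab_logp(x, pi=0.7))
```

For any x != 0 the spike term is -inf and only the slab contributes, which is why the marginalised density is so sharply bimodal in the latent, the property that makes a mean-field fit awkward.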
def fit_IBP_FA():
    with pm.Model() as model:
        ard_prior = pm.InverseGamma('ard', 1., 1.)
        uv_prior = pm.Normal('uv_prior', 0., 20.)
        U = pm.Normal('factors', uv_prior, ard_prior, shape=(D, R))

        alpha = pm.Gamma('alpha', 1.0, 1.0)
        pis = create_BetaBernoulliProcess(alpha, R)
        V = BetaBernoulliLatents('bb_latents', pis, uv_prior, ard_prior,
                                 shape=(batch_size, R))

        tau = pm.Gamma('tau', 1.0, 1.0, shape=batch_size * D)
        lambda_ = pm.Uniform('lambda', 0, 5, shape=batch_size * D)

        # True distribution
        mu = tt.dot(V * choices, U.T)
        precision = tt.reshape(tau * lambda_, (batch_size, D))
        beta_mask = pm.Beta('mask', 1, 1, shape=D)
        Y = NormalMissingValMixture('observations', beta_mask, mu, precision,
                                    MISSING_PLACEHOLDER, D,
                                    observed=df_multi_strength[:batch_size])
However, the Bernoulli mask matrix will have to be integrated out for ADVI to work. This again creates a binary mixture of a Dirac delta centered at zero and a Normal centered wherever the latents are supposed to be, which results in a spike-and-slab latent that a mean-field variational approximation is not designed to handle.
Sampling should in theory work though.
Best,
Hugo