Hierarchical gaussian mixture model VI minibatch

Hi everyone,
I have a large dataset containing 180 million datapoints. Those datapoints could be assigned to 280 groups. A good way to describe the data is gaussian mixture model since it has several modes. A group could look like this:


As described here https://docs.pymc.io/notebooks/multilevel_modeling.html I want to compare the results of a hierachical model with a pooled and unpooled version.
Since the dataset is large I use ADVI with minibatch.
The pooled model worked pretty well on the first try but I got a lot of issues with the unpolled and hierarchical models. By issues I mean that the models work but deliver totally wrong results. After a lot of hacking I doubt that it is possible at all to model something like this in pymc3.
Simplyfied toy data could be produced with the following code:

def build_up_dataset(num,num_el,loc,scale,a,b):
  w = np.random.dirichlet((1,1,1),num)
  mu = np.asarray([np.random.normal(loc=loc[el] ,scale=scale[el], size=int(num)) for el in range(num_el)])
  mu = mu.transpose()
  sigma = np.asarray([np.random.gamma(a[el],b[el], int(num)) for el in range(num_el)])
  sigma = sigma.transpose()
  x = []
  for row in range(w.shape[0]):
      component = np.random.choice(mu[row].size, size=num_el, p=w[row])
      for com in component:
         x.append(np.random.normal(mu[row][com], sigma[row][com], size=1)[0])
  return x

samples = 4
num = np.random.randint(800,1000,samples)
num_el = np.full((samples,),3)
loc = [[1,3,5],[8,10,12],[2,4,8],[2,8,15]]
scale = np.full((samples,3),0.1)
a = [[2,2,2],[2,2,2],[4,5,1],[4,5,1]]
b = [[0.4,0.4,0.4],[0.4,0.4,0.4],[0.5,0.18,0.2],[0.4,0.2,0.1]]
da = np.array([])
ts = np.array([])
for i in range(samples):
    d = np.asarray(build_up_dataset(num[i],num_el[i],loc[i],scale[i],a[i],b[i]))
    t = np.full(d.shape,i)
    da = np.concatenate((da,d))
    ts = np.concatenate((ts,t))

My simplified unpooled version looks like this:

W = np.array([0.33, 0.33, 0.33])
X = pm.Minibatch(da, batch_size=100)
ts = ts.astype(int)
idx = pm.Minibatch(ts, batch_size=100)

with pm.Model() as model:
        w = pm.Dirichlet('w', np.ones(W.size), shape=(samples,W.size))
        mu = pm.Normal('mu', mu=0., sigma=10., shape=(samples,W.size), transform=tr.ordered, testval=np.array([0.1, 0.5, 0.9]))
        sigma = pm.HalfNormal('sigma', sigma=10., shape=(samples,W.size))
        mixture = pm.NormalMixture('x_obs', w[idx], mu[idx], sigma=sigma[idx], observed=X, total_size=da.size, comp_shape=(samples,W.size),shape=(samples,))
        advi = pm.ADVI()
        tracker = pm.callbacks.Tracker(mean=advi.approx.mean.eval,std=advi.approx.std.eval)
        approx = advi.fit(100000, callbacks=[tracker])

and the hierarchical:

W = np.array([0.33, 0.33, 0.33])
X = pm.Minibatch(da, batch_size=100)
ts = ts.astype(int)
idx = pm.Minibatch(ts, batch_size=100)

with pm.Model() as model:
        # global values
        mu_p = pm.Normal('mu_p', 0., 1e5, shape=W.size, transform=tr.ordered, testval=np.array([2, 5]))
        mu_s_p = pm.HalfNormal('mu_s_p', 1e5, shape=W.size, transform=tr.ordered, testval=np.array([2, 5]))
        sigma_p = pm.Gamma('sigma_p', 2., 2., shape=W.size)
        w_p = pm.Normal('w_p',0,1., shape=W.size, testval=np.ones_like(W))
        # local values
        w = pm.Dirichlet('w', w_p, shape=(samples,W.size))
        mu = pm.Normal('mu', mu=mu_p, sigma=mu_s_p, shape=(samples,W.size), transform=tr.ordered, testval=np.array([0.1, 0.5, 0.9]))
        sigma = pm.HalfNormal('sigma', sigma=sigma_p, shape=(samples,W.size))
        mixture = pm.NormalMixture('x_obs', w[idx], mu[idx], sigma=sigma[idx], observed=X, total_size=da.size, comp_shape=(samples,W.size),shape=(samples,))
        advi = pm.ADVI()
        tracker = pm.callbacks.Tracker(mean=advi.approx.mean.eval,std=advi.approx.std.eval)
        approx = advi.fit(100000, callbacks=[tracker])

Any suggestions?