@gBokiau, @ferrine
Ok, I did build a new example without the pymc3 import *, and that fixed the previous issue.
I had to modify the NormalMixture class to get rid of the call to get_tau_sd(), as it was causing an error, and shouldn’t be necessary since I’m providng std, not tau.
Here’s the code:
class NormalMixture(pm.Mixture):
def __init__(self, w, mu, comp_shape=(), *args, **kwargs):
sd=kwargs.pop('sd',None)
self.mu = mu = tt.as_tensor_variable(mu)
self.sd = sd = tt.as_tensor_variable(sd)
super(NormalMixture, self).__init__(w, pm.Normal.dist(mu, sd=sd, shape=comp_shape),
*args, **kwargs)
def _repr_latex_(self, name=None, dist=None):
if dist is None:
dist = self
mu = dist.mu
w = dist.w
sd = dist.sd
name = r'\text{%s}' % name
return r'${} \sim \text{{NormalMixture}}(\mathit{{w}}={},~\mathit{{mu}}={},~\mathit{{sigma}}={})$'.format(name,
get_variable_name(w),
get_variable_name(mu),
get_variable_name(sd))
@Group.register
class SymmetricMeanFieldGroup(MeanFieldGroup):
"""Symmetric MeanField Group for Symmetrized VI"""
__param_spec__ = dict(smu=('d', ), rho=('d', ))
short_name = 'sym_mean_field'
alias_names = frozenset(['smf'])
@node_property
def mean(self):
return self.params_dict['smu']
def create_shared_params(self, start=None):
if start is None:
start = self.model.test_point
else:
start_ = start.copy()
update_start_vals(start_, self.model.test_point, self.model)
start = start_
if self.batched:
start = start[self.group[0].name][0]
else:
start = self.bij.map(start)
rho = np.zeros((self.ddim,))
if self.batched:
start = np.tile(start, (self.bdim, 1))
rho = np.tile(rho, (self.bdim, 1))
return {'smu': theano.shared(
pm.floatX(start), 'smu'),
'rho': theano.shared(
pm.floatX(rho), 'rho')}
@node_property
def symbolic_logq_not_scaled(self):
z = self.symbolic_random
logq = NormalMixture.dist([.5,.5],
mu=[self.mean, -self.mean],
sd=[self.std, self.std],
).logp(z)
return logq.sum(range(1, logq.ndim))
class SADVI(ADVI):
def __init__(self, *args, **kwargs):
super(ADVI, self).__init__(SymmetricMeanField(*args, **kwargs))
class SymmetricMeanField(SingleGroupApproximation):
__doc__ = """**Single Group Mean Field Approximation**
""" + str(SymmetricMeanFieldGroup.__doc__)
_group_class = SymmetricMeanFieldGroup
Now .fit() fails on the first iteration with the error below.
It seems like the Join operation produces a (2,510), while the Composite produces a (1,510).
To me it seems like the Join operation joins a (1,510) with a copy multiplied by -1? I don’t really know how to debug this…
ValueError: Input dimension mis-match. (input[0].shape[0] = 1, input[1].shape[0] = 2)
Apply node that caused the error: Elemwise{sub,no_inplace}(Elemwise{add,no_inplace}.0, Join.0)
Toposort index: 157
Inputs types: [TensorType(float64, matrix), TensorType(float64, matrix)]
Inputs shapes: [(1, 510), (2, 510)]
Inputs strides: [(4080, 8), (4080, 8)]
Inputs values: [‘not shown’, ‘not shown’]
Inputs type_num: [12, 12]
Outputs clients: [[Elemwise{pow,no_inplace}(Elemwise{sub,no_inplace}.0, InplaceDimShuffle{x,x}.0), Elemwise{pow}(Elemwise{sub,no_inplace}.0, Elemwise{sub}.0)]]
Backtrace when the node is created(use Theano flag traceback.limit=N to make it longer):
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\theano\configparser.py”, line 117, in res
return f(*args, **kwargs)
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\pymc3\variational\opvi.py”, line 1174, in symbolic_logq
return self.symbolic_logq_not_scaled
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\pymc3\memoize.py”, line 31, in memoizer
cache[key] = obj(*args, **kwargs)
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\theano\configparser.py”, line 117, in res
return f(*args, **kwargs)
File “”, line 42, in symbolic_logq_not_scaled
).logp(z)
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\pymc3\distributions\mixture.py”, line 146, in logp
return bound(logsumexp(tt.log(w) + self._comp_logp(value), axis=-1),
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\pymc3\distributions\mixture.py”, line 112, in comp_logp
return comp_dists.logp(value)
File “C:\Users\dycontri\AppData\Local\conda\conda\envs\analytics\lib\site-packages\pymc3\distributions\continuous.py”, line 480, in logp
return bound((-tau * (value - mu)**2 + tt.log(tau / np.pi / 2.)) / 2.,