Estimating standard deviation of very coarsely binned data

Hi everyone,

I am a physicist working in industry; I design and build electronic hardware. I have run up against an issue with PyMC which is somewhere between a bug and a gap in my Bayesian understanding.

I am trying to illustrate the limitations of analogue-to-digital converters (ADCs) at low bit depths. To do this I simulate Gaussian noise into the ADC, which effectively bins the data with predefined bin edges. I then try to estimate the mean and standard deviation of the original distribution.
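Roughly speaking, the setup looks like this (an ideal ADC is just a rounding quantiser; this is only a sketch of the idea, the full listing below uses np.digitize with explicit cutpoints instead):

import numpy as np

bits, V0 = 8, 5.0
lsb = V0 / (2**bits - 1)                  # one ADC code step, ~0.0196 V here
noise = np.random.normal(loc=2.5, scale=0.01, size=5)
codes = np.clip(np.round(noise / lsb), 0, 2**bits - 1)  # quantised readings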

In real life, people using equipment take a small number of readings (let's say 5 observed data points).

What I am finding is that the model below crashes when N is made small.

It will run with a bigger sigma, but that rather defeats the point of the exercise.

I want to get to sigma = 0.01, bits = 8, and N = 5.

Intuitively, I know the posterior for sigma should be essentially flat for sigma < binwidth/2, and mu should be somewhere in the middle of the central bin (we can't really say much better than the width of the bin).
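For reference, the bin width with these settings works out as:

binwidth = 5 / (2**8 - 1)   # V0 / (2**bits - 1) ≈ 0.0196 V
print(binwidth / 2)         # ≈ 0.0098, so sigma = 0.01 sits right at half a bin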

I have tweaked the example on binned data, https://www.pymc.io/projects/examples/en/latest/case_studies/binning.html, so I am sure I am not doing anything very clever. (Probably the exact opposite.)

Thank you in advance

# This is based on https://www.pymc.io/projects/examples/en/latest/case_studies/binning.html
import numpy as np
import pymc as pm
import aesara.tensor as pt
import arviz as az
import pandas as pd
import matplotlib.pyplot as plt

true_mu = 2.5
true_sigma = 0.01

np.random.seed(0)
N = 10000
bits = 8
V0 = 5

# simulated Gaussian noise going into the ADC
x1 = np.random.normal(loc=true_mu, scale=true_sigma, size=N)

# ADC code voltages over the full 0..V0 range
Vrange = V0 * np.arange(0, 2**bits, 1) / (2**bits - 1)

# keep only the cutpoints near the signal to limit the number of bins
d1 = Vrange[np.where((Vrange > 2.47) & (Vrange < 2.53))]


def data_to_bincounts(data, cutpoints):
    # categorise each datum into the correct bin
    bins = np.digitize(data, bins=cutpoints)
    # bin counts
    counts = pd.DataFrame({"bins": bins}).groupby(by="bins")["bins"].agg("count")
    return counts


c1 = data_to_bincounts(x1, d1)

with pm.Model() as model1:
    mu = pm.Uniform('mu', lower=2.45, upper=2.55)
    sigma = pm.Uniform('sigma', lower=0.005, upper=0.05)
    # sigma = pm.HalfNormal("sigma")
    # mu = pm.Normal("mu")

    # bin probabilities: normal CDF at each cutpoint, padded with 0 and 1, then differenced
    probs1 = pm.math.exp(pm.logcdf(pm.Normal.dist(mu=mu, sigma=sigma), d1))
    probs1 = pt.extra_ops.diff(pm.math.concatenate([[0], probs1, [1]]))
    pm.Multinomial("counts1", p=probs1, n=c1.sum(), observed=c1.values)

pm.model_to_graphviz(model1)

with model1:
    # trace1 = pm.sample(draws=10000, chains=4)
    trace1 = pm.sample()

with model1:
    ppc = pm.sample_posterior_predictive(trace1)

az.plot_posterior(trace1, var_names=["mu", "sigma"], ref_val=[true_mu, true_sigma]);

How does it crash?

On my AWS SageMaker system, when I set N = 5 I get this:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/opt/conda/lib/python3.10/site-packages/pymc/aesaraf.py:1005: UserWarning: The parameter 'updates' of aesara.function() expects an OrderedDict, got <class 'dict'>. Using a standard dictionary here results in non-deterministic behavior. You should use an OrderedDict if you are using Python 2.7 (collections.OrderedDict for older python), or use a list of (shared, update) pairs. Do not just convert your dictionary to this type before the call as the conversion will still be non-deterministic.
  aesara_function = aesara.function(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
    306 for thunk, node, old_storage in zip(
    307     self.thunks, self.nodes, self.post_thunk_clear
    308 ):
--> 309     thunk()
    310     for old_s in old_storage:

File /opt/conda/lib/python3.10/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:

File /opt/conda/lib/python3.10/site-packages/aesara/tensor/elemwise.py:718, in Elemwise.perform(self, node, inputs, output_storage)
    717     if len(set(dim_shapes) - {1}) > 1:
--> 718         raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
    720 # Determine the shape of outputs

ValueError: Shapes on dimension 0 do not match: (2, 5, 1, 2, 2)

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[18], line 59
     55 pm.model_to_graphviz(model1)
     57 with model1:
     58     #trace1 = pm.sample(draws=10000,chains=4)
---> 59     trace1 = pm.sample()
     62 with model1:
     63     ppc = pm.sample_posterior_predictive(trace1)

File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:481, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    479         [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
    480     _log.info("Auto-assigning NUTS sampler...")
--> 481     initial_points, step = init_nuts(
    482         init=init,
    483         chains=chains,
    484         n_init=n_init,
    485         model=model,
    486         seeds=random_seed,
    487         progressbar=progressbar,
    488         jitter_max_retries=jitter_max_retries,
    489         tune=tune,
    490         initvals=initvals,
    491         **kwargs,
    492     )
    494 if initial_points is None:
    495     # Time to draw/evaluate numeric start points for each chain.
    496     ipfns = make_initial_point_fns_per_chain(
    497         model=model,
    498         overrides=initvals,
    499         jitter_rvs=filter_rvs_to_jitter(step),
    500         chains=chains,
    501     )

File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:2307, in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
   2300 _log.info(f"Initializing NUTS using {init}...")
   2302 cb = [
   2303     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
   2304     pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
   2305 ]
-> 2307 initial_points = _init_jitter(
   2308     model,
   2309     initvals,
   2310     seeds=seeds,
   2311     jitter="jitter" in init,
   2312     jitter_max_retries=jitter_max_retries,
   2313 )
   2315 apoints = [DictToArrayBijection.map(point) for point in initial_points]
   2316 apoints_data = [apoint.data for apoint in apoints]

File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:2194, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
   2192 if i < jitter_max_retries:
   2193     try:
-> 2194         model.check_start_vals(point)
   2195     except SamplingError:
   2196         # Retry with a new seed
   2197         seed = rng.randint(2**30, dtype=np.int64)

File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1695, in Model.check_start_vals(self, start)
   1689     valid_keys = ", ".join(self.named_vars.keys())
   1690     raise KeyError(
   1691         "Some start parameters do not appear in the model!\n"
   1692         f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
   1693     )
-> 1695 initial_eval = self.point_logps(point=elem)
   1697 if not all(np.isfinite(v) for v in initial_eval.values()):
   1698     raise SamplingError(
   1699         "Initial evaluation of model at starting point failed!\n"
   1700         f"Starting values:\n{elem}\n\n"
   1701         f"Initial evaluation results:\n{initial_eval}"
   1702     )

File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1736, in Model.point_logps(self, point, round_vals)
   1730 factors = self.basic_RVs + self.potentials
   1731 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
   1732 return {
   1733     factor.name: np.round(np.asarray(factor_logp), round_vals)
   1734     for factor, factor_logp in zip(
   1735         factors,
-> 1736         self.compile_fn(factor_logps_fn)(point),
   1737     )
   1738 }

File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1835, in PointFunc.__call__(self, state)
   1834 def __call__(self, state):
-> 1835     return self.f(**state)

File /opt/conda/lib/python3.10/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
    961 t0_fn = time.time()
    962 try:
    963     outputs = (
--> 964         self.fn()
    965         if output_subset is None
    966         else self.fn(output_subset=output_subset)
    967     )
    968 except Exception:
    969     restore_defaults()

File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:313, in LoopGC.__call__(self)
    311             old_s[0] = None
    312 except Exception:
--> 313     raise_with_op(self.fgraph, node, thunk)

File /opt/conda/lib/python3.10/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
    533     warnings.warn(
    534         f"{exc_type} error does not allow us to add an extra error message"
    535     )
    536     # Some exception need extra parameter in inputs. So forget the
    537     # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)

File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
    305 try:
    306     for thunk, node, old_storage in zip(
    307         self.thunks, self.nodes, self.post_thunk_clear
    308     ):
--> 309         thunk()
    310         for old_s in old_storage:
    311             old_s[0] = None

File /opt/conda/lib/python3.10/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
    518 @is_thunk_type
    519 def rval(
    520     p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
    521 ):
--> 522     r = p(n, [x[0] for x in i], o)
    523     for o in node.outputs:
    524         compute_map[o][0] = True

File /opt/conda/lib/python3.10/site-packages/aesara/tensor/elemwise.py:718, in Elemwise.perform(self, node, inputs, output_storage)
    716 for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
    717     if len(set(dim_shapes) - {1}) > 1:
--> 718         raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
    720 # Determine the shape of outputs
    721 out_shape = []

ValueError: Shapes on dimension 0 do not match: (2, 5, 1, 2, 2)
Apply node that caused the error: Elemwise{Composite{(i0 + Switch(EQ(i1, i2), i3, (i4 * log(i1))))}}[(0, 1)](TensorConstant{[-0.693147...79175947]}, DiffOp{n=1, axis=-1}.0, TensorConstant{(1,) of 0}, TensorConstant{(2,) of -inf}, TensorConstant{[2. 3.]})
Toposort index: 12
Inputs types: [TensorType(float64, (2,)), TensorType(float64, (None,)), TensorType(int8, (1,)), TensorType(float32, (2,)), TensorType(float64, (2,))]
Inputs shapes: [(2,), (5,), (1,), (2,), (2,)]
Inputs strides: [(8,), (8,), (1,), (4,), (8,)]
Inputs values: [array([-0.69314718, -1.79175947]), array([0.22805058, 0.23926506, 0.25214614, 0.17388155, 0.10665667]), array([0], dtype=int8), array([-inf, -inf], dtype=float32), array([2., 3.])]
Outputs clients: [[Sum{acc_dtype=float64}(Elemwise{Composite{(i0 + Switch(EQ(i1, i2), i3, (i4 * log(i1))))}}[(0, 1)].0)]]

HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

By contrast, if I use N = 10000, I get a few warnings, but it does finish and plots something sensible:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/opt/conda/lib/python3.10/site-packages/pymc/aesaraf.py:1005: UserWarning: The parameter 'updates' of aesara.function() expects an OrderedDict, got <class 'dict'>. Using a standard dictionary here results in non-deterministic behavior. You should use an OrderedDict if you are using Python 2.7 (collections.OrderedDict for older python), or use a list of (shared, update) pairs. Do not just convert your dictionary to this type before the call as the conversion will still be non-deterministic.
  aesara_function = aesara.function(
Sequential sampling (2 chains in 1 job)
NUTS: [mu, sigma]

 100.00% [2000/2000 00:27<00:00 Sampling chain 0, 0 divergences]
/opt/conda/lib/python3.10/site-packages/aesara/scalar/basic.py:3070: RuntimeWarning: divide by zero encountered in log1p
  return np.log1p(x)
/opt/conda/lib/python3.10/site-packages/aesara/scalar/basic.py:2001: RuntimeWarning: divide by zero encountered in divide
  return x / y

 100.00% [2000/2000 00:26<00:00 Sampling chain 1, 0 divergences]
/opt/conda/lib/python3.10/site-packages/aesara/scalar/basic.py:3070: RuntimeWarning: divide by zero encountered in log1p
  return np.log1p(x)
/opt/conda/lib/python3.10/site-packages/aesara/scalar/basic.py:2001: RuntimeWarning: divide by zero encountered in divide
  return x / y
Sampling 2 chains for 1_000 tune and 1_000 draw iterations (2_000 + 2_000 draws total) took 54 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics

 100.00% [2000/2000 00:01<00:00]

You are getting shape issues, not stability problems. Is c1 returning fewer elements when N is 5? You should have as many c1 values as you have probs1 entries.
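A quick check with your own function should show it; groupby only creates rows for bins that actually received data, so something like this:

x_small = np.random.normal(loc=true_mu, scale=true_sigma, size=5)
c_small = data_to_bincounts(x_small, d1)
print(len(c_small), len(d1) + 1)  # fewer rows than probs1 entries -> shape mismatch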

Yes, I think that's the issue: empty bins when the counts are low or the standard deviation is small. I am forcing it to have counts = 0 instead of non-existent rows in the data_to_bincounts function. Hopefully this fixes it.
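One option would be to reindex the counts after the groupby so every possible bin gets a row, something like this (a sketch, not what I ended up running):

# inside data_to_bincounts, after the groupby:
counts = counts.reindex(range(len(cutpoints) + 1), fill_value=0)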

For confirmation: that has fixed it.

In the end I did a rather hacky fix instead: add one dummy datum per bin to the raw data prior to binning, then reduce each count by one.

There are rather a lot of divergences and warnings, but I think I can live with that for the time being.

def data_to_bincounts(data, cutpoints):
    # extend the cutpoints to cover the full 0..V0 range (V0 is the global full-scale voltage)
    d2 = np.append(cutpoints, V0)
    d3 = np.append(0.0, d2)

    # one dummy datum at the midpoint of every bin, so no bin can be empty
    xdummy = (d3[:-1] + d3[1:]) / 2
    xf = np.append(data, xdummy)

    # categorise each datum into the correct bin
    bins = np.digitize(xf, bins=cutpoints)
    # bin counts
    counts = pd.DataFrame({"bins": bins}).groupby(by="bins")["bins"].agg("count")

    # remove the one dummy datum from each bin's count
    counts = counts - 1

    return counts
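If the divergences get worse I may try raising target_accept, which I understand makes NUTS take smaller steps (untested on this model):

with model1:
    trace1 = pm.sample(target_accept=0.95)  # often reduces divergences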