On my AWS SageMaker system, when I set N=5 I get the following warning and traceback:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
/opt/conda/lib/python3.10/site-packages/pymc/aesaraf.py:1005: UserWarning: The parameter 'updates' of aesara.function() expects an OrderedDict, got <class 'dict'>. Using a standard dictionary here results in non-deterministic behavior. You should use an OrderedDict if you are using Python 2.7 (collections.OrderedDict for older python), or use a list of (shared, update) pairs. Do not just convert your dictionary to this type before the call as the conversion will still be non-deterministic.
aesara_function = aesara.function(
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
306 for thunk, node, old_storage in zip(
307 self.thunks, self.nodes, self.post_thunk_clear
308 ):
--> 309 thunk()
310 for old_s in old_storage:
File /opt/conda/lib/python3.10/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
File /opt/conda/lib/python3.10/site-packages/aesara/tensor/elemwise.py:718, in Elemwise.perform(self, node, inputs, output_storage)
717 if len(set(dim_shapes) - {1}) > 1:
--> 718 raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
720 # Determine the shape of outputs
ValueError: Shapes on dimension 0 do not match: (2, 5, 1, 2, 2)
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
Cell In[18], line 59
55 pm.model_to_graphviz(model1)
57 with model1:
58 #trace1 = pm.sample(draws=10000,chains=4)
---> 59 trace1 = pm.sample()
62 with model1:
63 ppc = pm.sample_posterior_predictive(trace1)
File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:481, in sample(draws, step, init, n_init, initvals, trace, chain_idx, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
479 [kwargs.setdefault(k, v) for k, v in nuts_kwargs.items()]
480 _log.info("Auto-assigning NUTS sampler...")
--> 481 initial_points, step = init_nuts(
482 init=init,
483 chains=chains,
484 n_init=n_init,
485 model=model,
486 seeds=random_seed,
487 progressbar=progressbar,
488 jitter_max_retries=jitter_max_retries,
489 tune=tune,
490 initvals=initvals,
491 **kwargs,
492 )
494 if initial_points is None:
495 # Time to draw/evaluate numeric start points for each chain.
496 ipfns = make_initial_point_fns_per_chain(
497 model=model,
498 overrides=initvals,
499 jitter_rvs=filter_rvs_to_jitter(step),
500 chains=chains,
501 )
File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:2307, in init_nuts(init, chains, n_init, model, seeds, progressbar, jitter_max_retries, tune, initvals, **kwargs)
2300 _log.info(f"Initializing NUTS using {init}...")
2302 cb = [
2303 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="absolute"),
2304 pm.callbacks.CheckParametersConvergence(tolerance=1e-2, diff="relative"),
2305 ]
--> 2307 initial_points = _init_jitter(
2308 model,
2309 initvals,
2310 seeds=seeds,
2311 jitter="jitter" in init,
2312 jitter_max_retries=jitter_max_retries,
2313 )
2315 apoints = [DictToArrayBijection.map(point) for point in initial_points]
2316 apoints_data = [apoint.data for apoint in apoints]
File /opt/conda/lib/python3.10/site-packages/pymc/sampling.py:2194, in _init_jitter(model, initvals, seeds, jitter, jitter_max_retries)
2192 if i < jitter_max_retries:
2193 try:
-> 2194 model.check_start_vals(point)
2195 except SamplingError:
2196 # Retry with a new seed
2197 seed = rng.randint(2**30, dtype=np.int64)
File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1695, in Model.check_start_vals(self, start)
1689 valid_keys = ", ".join(self.named_vars.keys())
1690 raise KeyError(
1691 "Some start parameters do not appear in the model!\n"
1692 f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
1693 )
-> 1695 initial_eval = self.point_logps(point=elem)
1697 if not all(np.isfinite(v) for v in initial_eval.values()):
1698 raise SamplingError(
1699 "Initial evaluation of model at starting point failed!\n"
1700 f"Starting values:\n{elem}\n\n"
1701 f"Initial evaluation results:\n{initial_eval}"
1702 )
File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1736, in Model.point_logps(self, point, round_vals)
1730 factors = self.basic_RVs + self.potentials
1731 factor_logps_fn = [at.sum(factor) for factor in self.logpt(factors, sum=False)]
1732 return {
1733 factor.name: np.round(np.asarray(factor_logp), round_vals)
1734 for factor, factor_logp in zip(
1735 factors,
-> 1736 self.compile_fn(factor_logps_fn)(point),
1737 )
1738 }
File /opt/conda/lib/python3.10/site-packages/pymc/model.py:1835, in PointFunc.__call__(self, state)
1834 def __call__(self, state):
-> 1835 return self.f(**state)
File /opt/conda/lib/python3.10/site-packages/aesara/compile/function/types.py:964, in Function.__call__(self, *args, **kwargs)
961 t0_fn = time.time()
962 try:
963 outputs = (
--> 964 self.fn()
965 if output_subset is None
966 else self.fn(output_subset=output_subset)
967 )
968 except Exception:
969 restore_defaults()
File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:313, in LoopGC.__call__(self)
311 old_s[0] = None
312 except Exception:
--> 313 raise_with_op(self.fgraph, node, thunk)
File /opt/conda/lib/python3.10/site-packages/aesara/link/utils.py:538, in raise_with_op(fgraph, node, thunk, exc_info, storage_map)
533 warnings.warn(
534 f"{exc_type} error does not allow us to add an extra error message"
535 )
536 # Some exception need extra parameter in inputs. So forget the
537 # extra long error message in that case.
--> 538 raise exc_value.with_traceback(exc_trace)
File /opt/conda/lib/python3.10/site-packages/aesara/link/vm.py:309, in LoopGC.__call__(self)
305 try:
306 for thunk, node, old_storage in zip(
307 self.thunks, self.nodes, self.post_thunk_clear
308 ):
--> 309 thunk()
310 for old_s in old_storage:
311 old_s[0] = None
File /opt/conda/lib/python3.10/site-packages/aesara/graph/op.py:522, in Op.make_py_thunk.<locals>.rval(p, i, o, n, params)
518 @is_thunk_type
519 def rval(
520 p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
521 ):
--> 522 r = p(n, [x[0] for x in i], o)
523 for o in node.outputs:
524 compute_map[o][0] = True
File /opt/conda/lib/python3.10/site-packages/aesara/tensor/elemwise.py:718, in Elemwise.perform(self, node, inputs, output_storage)
716 for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
717 if len(set(dim_shapes) - {1}) > 1:
--> 718 raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
720 # Determine the shape of outputs
721 out_shape = []
ValueError: Shapes on dimension 0 do not match: (2, 5, 1, 2, 2)
Apply node that caused the error: Elemwise{Composite{(i0 + Switch(EQ(i1, i2), i3, (i4 * log(i1))))}}[(0, 1)](TensorConstant{[-0.693147...79175947]}, DiffOp{n=1, axis=-1}.0, TensorConstant{(1,) of 0}, TensorConstant{(2,) of -inf}, TensorConstant{[2. 3.]})
Toposort index: 12
Inputs types: [TensorType(float64, (2,)), TensorType(float64, (None,)), TensorType(int8, (1,)), TensorType(float32, (2,)), TensorType(float64, (2,))]
Inputs shapes: [(2,), (5,), (1,), (2,), (2,)]
Inputs strides: [(8,), (8,), (1,), (4,), (8,)]
Inputs values: [array([-0.69314718, -1.79175947]), array([0.22805058, 0.23926506, 0.25214614, 0.17388155, 0.10665667]), array([0], dtype=int8), array([-inf, -inf], dtype=float32), array([2., 3.])]
Outputs clients: [[Sum{acc_dtype=float64}(Elemwise{Composite{(i0 + Switch(EQ(i1, i2), i3, (i4 * log(i1))))}}[(0, 1)].0)]]
HINT: Re-running with most Aesara optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the Aesara flag 'optimizer=fast_compile'. If that does not work, Aesara optimizations can be disabled with 'optimizer=None'.
HINT: Use the Aesara flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.
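Following the two HINTs at the end of the traceback, both Aesara flags can be set through the AESARA_FLAGS environment variable before PyMC is imported. A minimal sketch (the flag values are the ones named in the hints; everything else is just illustrative scaffolding):

```python
import os

# Per the HINTs above: 'optimizer=fast_compile' keeps a back-trace showing
# where the failing node was created, and 'exception_verbosity=high' prints
# a debug dump and storage-map footprint of the failing Apply node.
# This must run before aesara/pymc are imported.
os.environ["AESARA_FLAGS"] = "optimizer=fast_compile,exception_verbosity=high"

import pymc as pm  # imported deliberately after setting the flags
```

If that does not surface the origin of the node, the hints suggest disabling optimizations entirely with optimizer=None instead of fast_compile.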
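Separately, the ValueError reports that dimension 0 mixes lengths 2 and 5 (the 5 presumably coming from N=5), which usually means one model variable or its observed data has a different length than the rest. A quick way to see the shapes PyMC actually builds, assuming PyMC v4 (where pm.draw and Model.basic_RVs are available) and the model object model1 from the cell in the traceback:

```python
import pymc as pm

# Print the shape of one prior draw from each random variable in model1;
# whichever variable's length disagrees (2 vs. 5 in the error above)
# is the likely source of the broadcasting mismatch.
for rv in model1.basic_RVs:
    print(rv.name, pm.draw(rv).shape)
```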