How to set up minibatches on one specific dimension when variables have multiple and different dimensions

Dear community,
In trying to use VI to speed up a model that I had previously gotten to work (first code block below) in the MCMC framework (it took several days to run and the results were good), I ran into the error “ValueError: Multiple update expressions found for the variable RandomGeneratorSharedVariable”, which I suspect has something to do with the way I set up minibatches (second code block below).

I have set up the total_size according to this thread https://discourse.pymc.io/t/possible-extensions-of-total-size/36.
When it comes to setting up pm.Minibatch on multi-dimensional variables, fearrine’s comments in this thread https://discourse.pymc.io/t/how-to-make-minibatch-for-multi-dimensional-data/5033/4 were inconclusive. I couldn’t find any other examples or discussion on this matter.

Details on the model and variables: the observed samples are non-negative integers (hence the ZIP likelihood) with 2 dimensions, one of which (dimA) has a size in the hundreds of thousands (the target of minibatching) and the other (dimB) has size 40. Three categorical covariates are also involved (c1 has 12 possible values, c2 has 2 and c3 has 2). The 40 samples in dimB share the same set of covariates, while samples in dimA are latently grouped into a finite number of covariate combinations (12 × 2 × 2 = 48 combos) but are also allowed to vary around these 48 centers, with that variation also modeled.
MCMC code that works:

with pm.Model() as m:

    m.add_coord('nC1', dataDict['dataDF']['nC1'].unique(), mutable=True)
    m.add_coord('nC2', dataDict['dataDF']['nC2'].unique(), mutable=True)
    m.add_coord('nC3', dataDict['dataDF']['nC3'].unique(), mutable=True)
    m.add_coord('dimA', ..., mutable=True)
    m.add_coord('dimB', ...), mutable=True)

    TD_obs = pm.MutableData('TD_obs', dataDict['TD'], dims=('dimB', 'dimA'))
    AD_obs = pm.MutableData('AD_obs', dataDict['AD'], dims=('dimB', 'dimA'))
    c1_idx = pm.MutableData('c1_idx', dataDict['dataDF']['context'].values, dims='dimA')
    c2_idx = pm.MutableData('c2_idx', dataDict['dataDF']['segDup'].values, dims='dimA')
    c3_idx = pm.MutableData('c3_idx', dataDict['dataDF']['mapa'].values, dims='dimA')

    mu_bc = pm.TruncatedNormal('mu_bc', mu=8, sigma=3, lower=4, shape=1)
    std_bc = pm.HalfNormal('std_bc', sigma=1, shape=1)
    mu_c1 = pm.TruncatedNormal('mu_c1', mu=mu_bc, sigma=std_bc, lower=4, dims='nC1')
    mu_c2 = pm.Normal('mu_c2', mu=0, sigma=1, dims='nC2')
    mu_c3 = pm.Normal('mu_c3', mu=0, sigma=1, dims='nC3')

    mu_dimA = pm.Deterministic('mu_dimA', mu_c1[c1_idx] + mu_c2[c2_idx] + mu_c3[c3_idx], dims='dimA')
    std_shared = pm.HalfNormal('std_dimA', sigma=2, shape=1)
    ER_dimA = pm.Gamma('ER_dimA', alpha=mu_dimA ** 2 / std_shared ** 2, beta=mu_dimA / std_shared ** 2, dims='dimA')

    psi_dimA = pm.Beta('psi_dimA', alpha=2, beta=5, dims='dimA')
    AD_predicted = pm.ZeroInflatedBinomial('AD_predicted', psi=psi_dimA, n=TD_obs,
                                           p=pm.invlogit(-ER_dimA), observed=AD_obs)

VI with minibatches:

with pm.Model() as m:
    posBatchSize = 100
    lenDimB = 40
    m.add_coord('nC1', dataDict['dataDF']['nC1'].unique(), mutable=True)
    m.add_coord('nC2', dataDict['dataDF']['nC2'].unique(), mutable=True)
    m.add_coord('nC3', dataDict['dataDF']['nC3'].unique(), mutable=True)
    
    # these two observed variables have 2 dimensions (lenDimB, lenPos)
    TD_obs_m = pm.Minibatch(dataDict['TD'], batch_size=[(lenDimB, posBatchSize)])
    AD_obs_m = pm.Minibatch(dataDict['AD'], batch_size=[(lenDimB, posBatchSize)])
    # these three covariates have 1 dimension (lenPos,)
    c1_idx_m = pm.Minibatch(dataDict['dataDF']['c1'].values, batch_size=posBatchSize)
    c2_idx_m = pm.Minibatch(dataDict['dataDF']['c2'].values, batch_size=posBatchSize)
    c3_idx_m = pm.Minibatch(dataDict['dataDF']['c3'].values, batch_size=posBatchSize) (nPos,)
    mu_bc = pm.TruncatedNormal('mu_bc', mu=8, sigma=3, lower=4, shape=1)
    std_bc = pm.HalfNormal('std_bc', sigma=1, shape=1)
    mu_c1 = pm.TruncatedNormal('mu_c1', mu=mu_bc, sigma=std_bc, lower=4, dims='nC1')
    mu_c2 = pm.Normal('mu_c2', mu=0, sigma=1, dims='nC2')
    mu_c3 = pm.Normal('mu_c3', mu=0, sigma=1, dims='nC3')
    
    mu_dimA = pm.Deterministic('mu_dimA', mu_c1[c1_idx] + mu_c2[c2_idx] + mu_c3[c3_idx], dims='dimA')
    std_shared = pm.HalfNormal('std_dimA', sigma=2, shape=1)
    ER_dimA = pm.Gamma('ER_dimA', alpha=mu_dimA ** 2 / std_shared ** 2, beta=mu_dimA / std_shared ** 2, dims='dimA')
    psi_dimA = pm.Beta('psi_dimA', alpha=2, beta=5, dims='dimA')
    AD_predicted = pm.ZeroInflatedBinomial(
      'AD_predicted', psi=psi_dimA, n=TD_obs_m, p=pm.invlogit(-ER_dimA), observed=AD_obs_m, total_size=dataDict['AD'].shape)
    
approx = pm.fit(1_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])

it gave an error:

/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [integers_rv{0, (0, 0), int64, False}.0, integers_rv{0, (0, 0), int64, False}.out]
  warnings.warn(
/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [integers_rv{0, (0, 0), int64, False}.0, integers_rv{0, (0, 0), int64, False}.out]
  warnings.warn(
/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [integers_rv{0, (0, 0), int64, False}.0, integers_rv{0, (0, 0), int64, False}.out]
  warnings.warn(
/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/logprob/joint_logprob.py:167: UserWarning: Found a random variable that was neither among the observations nor the conditioned variables: [integers_rv{0, (0, 0), int64, False}.0, integers_rv{0, (0, 0), int64, False}.out]
  warnings.warn(
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[87], line 2
      1 with m:
----> 2     approx = pm.fit(1_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
      4 # idata_advi = approx.sample(500)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/inference.py:747, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    745 else:
    746     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 747 return inference.fit(n, **kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/inference.py:138, in Inference.fit(self, n, score, callbacks, progressbar, **kwargs)
    136     callbacks = []
    137 score = self._maybe_score(score)
--> 138 step_func = self.objective.step_function(score=score, **kwargs)
    139 if progressbar:
    140     progress = progress_bar(range(n), display=progressbar)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/configparser.py:47, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     44 @wraps(f)
     45 def res(*args, **kwargs):
     46     with self:
---> 47         return f(*args, **kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/opvi.py:387, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, fn_kwargs)
    385 seed = self.approx.rng.randint(2**30, dtype=np.int64)
    386 if score:
--> 387     step_fn = compile_pymc([], updates.loss, updates=updates, random_seed=seed, **fn_kwargs)
    388 else:
    389     step_fn = compile_pymc([], [], updates=updates, random_seed=seed, **fn_kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/pytensorf.py:1083, in compile_pymc(inputs, outputs, random_seed, mode, **kwargs)
   1047 """Use ``pytensor.function`` with specialized pymc rewrites always enabled.
   1048 
   1049 This function also ensures shared RandomState/Generator used by RandomVariables
   (...)
   1079     is set to False.
   1080 """
   1081 # Create an update mapping of RandomVariable's RNG so that it is automatically
   1082 # updated after every function call
-> 1083 rng_updates = collect_default_updates(inputs, outputs)
   1085 # We always reseed random variables as this provides RNGs with no chances of collision
   1086 if rng_updates:

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/pytensorf.py:1036, in collect_default_updates(inputs, outputs)
   1032         # When a variable has multiple outputs, it will be called twice with the same
   1033         # update expression. We don't want to raise in that case, only if the update
   1034         # expression in different from the one already registered
   1035         elif rng_updates[rng] is not update:
-> 1036             raise ValueError(f"Multiple update expressions found for the variable {rng}")
   1037 return rng_updates

ValueError: Multiple update expressions found for the variable RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F0B95FC5000>)

I can provide further information if necessary. Thanks in advance for any advice.

The Minibatch API changed a bit in the latest release, and may be simpler to use: pymc.Minibatch — PyMC 5.0.2 documentation

You can pass multiple variables and the leading dimension will be batched consistently (so you only need and should only do one call to Minibatch). You can then transpose the returned variables if you need the batched dims at the end.

Do you mean that now multiple calls to the minibatch function will cause errors for sure?
Since 2 of my observed variables have dimension (A, B) and the other 3 covariates have dimension (B, ), my understanding is that I need to transpose the observed to (B, A) first, and minibatch and then transpose, am I correct? But it comes back to another question, what if I want to batch both dimensions, if by default it batches the leading dimension only? My guess is that internally it first determines whether the leading dimension (dim1) is present within all variables, and if it does it will determine dim1 as valid. This logic goes to the next dimension and until it finds one dimension that is not present in all variables, and it will determine this one and remaining dimensions as invalid to be batched. Am I correct? For example:
for 3 variables var1: (A,B,C), var2: (A,B,), var3: (A,), internally it will batch dimension A only.

Do you mean that now multiple calls to the minibatch function will cause errors for sure?

They will cause the minibatch slices to not match across variables, so you might pair c1_idx_m[0:10] with c2_idx_m[20:30], which you don’t want. That’s why we now accept multiple variables so that the slices are the same. Before this was ensured with some complicated logic, now it’s just by using the same random slice across multiple variables.

Since 2 of my observed variables have dimension (A, B) and the other 3 covariates have dimension (B, ), my understanding is that I need to transpose the observed to (B, A) first, and minibatch and then transpose, am I correct?

Yes, although you might not need the second transpose if you organize your code to match on the leading dimension.

But it comes back to another question, what if I want to batch both dimensions, if by default it batches the leading dimension only? My guess is that internally it first determines whether the leading dimension (dim1) is present within all variables, and if it does it will determine dim1 as valid. This logic goes to the next dimension and until it finds one dimension that is not present in all variables, and it will determine this one and remaining dimensions as invalid to be batched. Am I correct?

Actually I think the code only works for one batched leading dimension, as it checks that the first dimension of the batched dims matches, but not the following ones. I will open an issue to ensure that’s the case.

It will fail if you try to use a 2-value batch_size with a vector input.


I think you have two options for your specific use case:

  1. Use two calls to Minibatch:
ab = np.ones((1000, 500)
c = np.ones((1000,))
with pm.Model() as m:
  mab, mc = pm.Minibatch(ab, c, batch_size=(10,)  # mab.shape == (10, 500), c.shape == (10,)
  mmab = pm.Minibatch(mab.T, batch_size=(5,)).T  # mmab.shape == (10, 5)
  2. Use the lower level minibatch_index directly: pymc.data — PyMC 5.10.0 documentation
from pymc.data import minibatch_index

with pm.Model() as m:
  mb_idx_dim1 = minibatch_index(0, 1000, size=(10,))
  mb_idx_dim2 = minibatch_index(0, 500, size=(5,))
  mc = c[mb_idx_dim1]
  mmab = ab[mb_idx_dim1, mb_idx_dim2]

This way you don’t need to bother with transposes or multiple calls to Minibatch. The only thing to pay attention to is using the same minibatch_index for dims that should be paired.

1 Like

Having said that, I am still surprised by your original error… it doesn’t make much sense to me. Let me know if you still see it after you fix the use of minibatch

Many thanks. I am giving it a try now.

I opened an issue here to think about batching in multiple dimensions, thanks for bringing the case to our attention: Assert that `Minibatch` `batch_size` is an integer · Issue #6554 · pymc-devs/pymc · GitHub

Not sure if I have done it correctly, but both options raised errors:

  1. Use two calls to Minibatch:
ab = np.ones((1000, 500))
c = np.ones((1000,))
with pm.Model() as m:
    mab, mc = pm.Minibatch(ab, c, batch_size=(10,))  # mab.shape == (10, 500), c.shape == (10,)
    mmab = pm.Minibatch(mab.T, batch_size=(5,)).T  # mmab.shape == (10, 5)

ValueError                                Traceback (most recent call last)
Cell In[16], line 5
      3 with pm.Model() as m:
      4     mab, mc = pm.Minibatch(ab, c, batch_size=(10,))  # mab.shape == (10, 500), c.shape == (10,)
----> 5     mmab = pm.Minibatch(mab.T, batch_size=(5,)).T  # mmab.shape == (10, 5)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/data.py:197, in Minibatch(variable, batch_size, *variables)
    195 for i, v in enumerate((tensor, *tensors)):
    196     if not valid_for_minibatch(v):
--> 197         raise ValueError(
    198             f"{i}: {v} is not valid for Minibatch, only constants or constants.astype(dtype) are allowed"
    199         )
    200 result = tuple([v[slc] for v in (tensor, *tensors)])
    201 for i, r in enumerate(result):

ValueError: 0: minibatch.0.T is not valid for Minibatch, only constants or constants.astype(dtype) are allowed
  2. Use the lower level minibatch_index directly
with pm.Model() as m:
    mb_idx_dim1 = minibatch_index(0, 1000, size=(10,))
    mb_idx_dim2 = minibatch_index(0, 500, size=(5,))
    mc = c[mb_idx_dim1]
    mmab = ab[mb_idx_dim1, mb_idx_dim2]

yielded

IndexError                                Traceback (most recent call last)
Cell In[14], line 4
      2 mb_idx_dim1 = minibatch_index(0, 1000, size=(10,))
      3 mb_idx_dim2 = minibatch_index(0, 500, size=(5,))
----> 4 mc = c[mb_idx_dim1]
      5 mmab = ab[mb_idx_dim1, mb_idx_dim2]

IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

@HongyuXie Can you perhaps provide a smaller completely reproducible example so that I can try to replicate locally?

That will be great! Here is a small dataset and the mcmc version code, miniBatch_test. I have just tested it and it was ok. First time sharing data so please let me know if there are better ways for sharing these.

I am afraid a pickle file is a bit dangerous for a third party to open, as it can contain arbitrary code. Can you save your data as CSV (or just generate fake data with numpy routines on the fly)?

Also, can you reduce your model to the bare minimum that still shows the issue? There is probably no need for 5 MutableData and all those variables.

Finally, can you also include the use of Minibatch that you tried without success?

Apologies if my request sounds annoying, but having safe and small reproducible examples really makes our lives easier.

Totally understood. Here is a simplified MCMC version with mock data made on the fly as input. Please let me know if you want it to be more simplified or need any information.

# mock data
rng = np.random.default_rng(123)
samplePerPos = 10
nPos = 100
dataDict = {
        'TD': rng.normal(3800, 500, size=(samplePerPos, nPos)).astype(int),
        'AD': rng.poisson(1, size=(samplePerPos, nPos)),
        'covariate1': rng.integers(low=0, high=16, size=nPos)}

with pm.Model() as m:

    m.add_coord('nC1', np.unique(dataDict['covariate1']), mutable=True)
    m.add_coord('dimA', np.arange(dataDict['AD'].shape[1]), mutable=True)
    m.add_coord('dimB', np.arange(dataDict['AD'].shape[0]), mutable=True)

    TD_obs = pm.MutableData('TD_obs', dataDict['TD'], dims=('dimB', 'dimA'))
    AD_obs = pm.MutableData('AD_obs', dataDict['AD'], dims=('dimB', 'dimA'))
    c1_idx = pm.MutableData('c1_idx', dataDict['covariate1'], dims='dimA')

    mu_c1 = pm.TruncatedNormal('mu_c1', mu=6, sigma=2, lower=3, dims='nC1')
    mu_p = pm.Deterministic('mu_p', mu_c1[c1_idx], dims='dimA')

    std_p = pm.HalfNormal('std_p', sigma=2, shape=1)
    ER_p = pm.Gamma('ER_p', alpha=mu_p ** 2 / std_p ** 2, beta=mu_p / std_p ** 2, dims='dimA')

    psi_p = pm.Beta('psi_p', alpha=2, beta=5, dims='dimA')
    AD_predicted = pm.ZeroInflatedBinomial('AD_predicted', psi=psi_p, n=TD_obs, p=pm.invlogit(-ER_p), observed=AD_obs)
    idata = pm.sample(300, init="adapt_diag", tune=300, cores=2, chains=2, target_accept=.9, return_inferencedata=True)

Regarding errors with using minibatch, I was hoping I could overhaul my model after making your two suggested methods work, because they seem promising and that way I wouldn’t have to change the dimensions in all of my lengthy data-generating code (not shown).
The errors associated with the two methods shown above have nothing to do with my data or model. I just copied your code and ran it, so there is nothing more I can provide besides that. I figured they might just need some minor revisions to work. That being said, I still tried a minibatch version anyway and here is the code:

with pm.Model() as m:
    m.add_coord('nC1', np.unique(dataDict['covariate1']), mutable=True)
    m.add_coord('dimA', np.arange(dataDict['AD'].shape[1]), mutable=True)
    m.add_coord('dimB', np.arange(dataDict['AD'].shape[0]), mutable=True)

    TD_obs_m, AD_obs_m, c1_idx_m = pm.Minibatch(
            dataDict['TD'].T,
            dataDict['AD'].T,
            dataDict['covariate1'], batch_size=posBatchSize)
    TD_obs_mt = TD_obs_m.T
    AD_obs_mt = AD_obs_m.T

    mu_c1 = pm.TruncatedNormal('mu_c1', mu=6, sigma=2, lower=3, dims='nC1')
    mu_p = pm.Deterministic('mu_p', mu_c1[c1_idx_m], dims='dimA')

    std_p = pm.HalfNormal('std_p', sigma=2, shape=1)
    ER_p = pm.Gamma('ER_p', alpha=mu_p ** 2 / std_p ** 2, beta=mu_p / std_p ** 2, dims='dimA')

    psi_p = pm.Beta('psi_p', alpha=2, beta=5, dims='dimA')
    AD_predicted = pm.ZeroInflatedBinomial('AD_predicted', psi=psi_p, n=TD_obs_mt, p=pm.invlogit(-ER_p), observed=AD_obs_mt,
                                               total_size=dataDict['AD'].shape)

   approx = pm.fit(1_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])

The error was:

ERROR (pytensor.graph.rewriting.basic): Rewrite failure due to: constant_folding
ERROR (pytensor.graph.rewriting.basic): node: Assert{msg=All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code}(TensorConstant{100}, TensorConstant{False})
ERROR (pytensor.graph.rewriting.basic): TRACEBACK:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File "/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1925, in process_node
    replacements = node_rewriter.transform(fgraph, node)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py", line 1084, in transform
    return self.fn(fgraph, node)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py", line 1142, in constant_folding
    required = thunk()
               ^^^^^^^
  File "/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/link/c/op.py", line 103, in rval
    thunk()
  File "/root/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/link/c/basic.py", line 1786, in __call__
    raise exc_value.with_traceback(exc_trace)
AssertionError: All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[61], line 3
      1 # m = SNV_depth_GLM_VI_1bc_pm5_v11(trainDict, 8, 4, 4)
      2 # m = SNV_depth_GLM_1bc_v45(trainDict, 6, 3)
----> 3 m = minibatchVersion_simplified()

File ~/Codes/Bagger/development/../../Bagger/bagger/models_pm5.py:179, in minibatchVersion_simplified()
    174     psi_p = pm.Beta('psi_p', alpha=2, beta=5, dims='dimA')
    175     AD_predicted = pm.ZeroInflatedBinomial('AD_predicted', psi=psi_p, n=TD_obs_mt,
    176                                            p=pm.invlogit(-ER_p), observed=AD_obs_mt,
    177                                            total_size=dataDict['AD'].shape)
--> 179     approx = pm.fit(1_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])
    181 return m

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/inference.py:740, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    738 method = method.lower()
    739 if method in _select:
--> 740     inference = _select[method](model=model, **inf_kwargs)
    741 else:
    742     raise KeyError(f"method should be one of {set(_select.keys())} or Inference instance")

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/inference.py:456, in ADVI.__init__(self, *args, **kwargs)
    455 def __init__(self, *args, **kwargs):
--> 456     super().__init__(MeanField(*args, **kwargs))

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/approximations.py:338, in SingleGroupApproximation.__init__(self, *args, **kwargs)
    336 def __init__(self, *args, **kwargs):
    337     groups = [self._group_class(None, *args, **kwargs)]
--> 338     super().__init__(groups, model=kwargs.get("model"))

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/opvi.py:1175, in Approximation.__init__(self, groups, model)
   1173         raise GroupError("No approximation is specified for the rest variables")
   1174     else:
-> 1175         rest.__init_group__(unseen_free_RVs)
   1176         self.groups.append(rest)
   1177 self.model = model

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/configparser.py:47, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     44 @wraps(f)
     45 def res(*args, **kwargs):
     46     with self:
---> 47         return f(*args, **kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/approximations.py:71, in MeanFieldGroup.__init_group__(self, group)
     69 @pytensor.config.change_flags(compute_test_value="off")
     70 def __init_group__(self, group):
---> 71     super().__init_group__(group)
     72     if not self._check_user_params():
     73         self.shared_params = self.create_shared_params(
     74             self._kwargs.get("start", None), self._kwargs.get("start_sigma", None)
     75         )

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/configparser.py:47, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     44 @wraps(f)
     45 def res(*args, **kwargs):
     46     with self:
---> 47         return f(*args, **kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/variational/opvi.py:849, in Group.__init_group__(self, group)
    844 self.input = self._input_type(self.__class__.__name__ + "_symbolic_input")
    845 # I do some staff that is not supported by standard __init__
    846 # so I have to to it by myself
    847 
    848 # 1) we need initial point (transformed space)
--> 849 model_initial_point = self.model.initial_point(0)
    850 # 2) we'll work with a single group, a subset of the model
    851 # here we need to create a mapping to replace value_vars with slices from the approximation
    852 start_idx = 0

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/model.py:1126, in Model.initial_point(self, random_seed)
   1113 def initial_point(self, random_seed: SeedSequenceSeed = None) -> Dict[str, np.ndarray]:
   1114     """Computes the initial point of the model.
   1115 
   1116     Parameters
   (...)
   1124         Maps names of transformed variables to numeric initial values in the transformed space.
   1125     """
-> 1126     fn = make_initial_point_fn(model=self, return_transformed=True)
   1127     return Point(fn(random_seed), model=self)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/initial_point.py:152, in make_initial_point_fn(model, overrides, jitter_rvs, default_strategy, return_transformed)
    149 # Replace original rng shared variables so that we don't mess with them
    150 # when calling the final seeded function
    151 initial_values = replace_rng_nodes(initial_values)
--> 152 func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
    154 varnames = []
    155 for var in model.free_RVs:

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pymc/pytensorf.py:1104, in compile_pymc(inputs, outputs, random_seed, mode, **kwargs)
   1102 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
   1103 mode = Mode(linker=mode.linker, optimizer=opt_qry)
-> 1104 pytensor_function = pytensor.function(
   1105     inputs,
   1106     outputs,
   1107     updates={**rng_updates, **kwargs.pop("updates", {})},
   1108     mode=mode,
   1109     **kwargs,
   1110 )
   1111 return pytensor_function

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/compile/function/__init__.py:315, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input)
    309     fn = orig_function(
    310         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    311     )
    312 else:
    313     # note: pfunc will also call orig_function -- orig_function is
    314     #      a choke point that all compilation must pass through
--> 315     fn = pfunc(
    316         params=inputs,
    317         outputs=outputs,
    318         mode=mode,
    319         updates=updates,
    320         givens=givens,
    321         no_default_updates=no_default_updates,
    322         accept_inplace=accept_inplace,
    323         name=name,
    324         rebuild_strict=rebuild_strict,
    325         allow_input_downcast=allow_input_downcast,
    326         on_unused_input=on_unused_input,
    327         profile=profile,
    328         output_keys=output_keys,
    329     )
    330 return fn

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py:367, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph)
    353     profile = ProfileStats(message=profile)
    355 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    356     params,
    357     outputs,
   (...)
    364     fgraph=fgraph,
    365 )
--> 367 return orig_function(
    368     inputs,
    369     cloned_outputs,
    370     mode,
    371     accept_inplace=accept_inplace,
    372     name=name,
    373     profile=profile,
    374     on_unused_input=on_unused_input,
    375     output_keys=output_keys,
    376     fgraph=fgraph,
    377 )

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/compile/function/types.py:1751, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph)
   1749 try:
   1750     Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1751     m = Maker(
   1752         inputs,
   1753         outputs,
   1754         mode,
   1755         accept_inplace=accept_inplace,
   1756         profile=profile,
   1757         on_unused_input=on_unused_input,
   1758         output_keys=output_keys,
   1759         name=name,
   1760         fgraph=fgraph,
   1761     )
   1762     with config.change_flags(compute_test_value="off"):
   1763         fn = m.create(defaults)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/compile/function/types.py:1524, in FunctionMaker.__init__(self, inputs, outputs, mode, accept_inplace, function_builder, profile, on_unused_input, fgraph, output_keys, name, no_fgraph_prep)
   1521 self.fgraph = fgraph
   1523 if not no_fgraph_prep:
-> 1524     self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
   1526 assert len(fgraph.outputs) == len(outputs + found_updates)
   1528 # The 'no_borrow' outputs are the ones for which that we can't
   1529 # return the internal storage pointer.

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/compile/function/types.py:1416, in FunctionMaker.prepare_fgraph(inputs, outputs, additional_outputs, fgraph, mode, profile)
   1409 rewrite_time = None
   1411 with config.change_flags(
   1412     mode=mode,
   1413     compute_test_value=config.compute_test_value_opt,
   1414     traceback__limit=config.traceback__compile_limit,
   1415 ):
-> 1416     rewriter_profile = rewriter(fgraph)
   1418     end_rewriter = time.perf_counter()
   1419     rewrite_time = end_rewriter - start_rewriter

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:127, in GraphRewriter.__call__(self, fgraph)
    125 def __call__(self, fgraph):
    126     """Rewrite a `FunctionGraph`."""
--> 127     return self.rewrite(fgraph)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:123, in GraphRewriter.rewrite(self, fgraph, *args, **kwargs)
    114 """
    115 
    116 This is meant as a shortcut for the following::
   (...)
    120 
    121 """
    122 self.add_requirements(fgraph)
--> 123 return self.apply(fgraph, *args, **kwargs)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:294, in SequentialGraphRewriter.apply(self, fgraph)
    292 nb_nodes_before = len(fgraph.apply_nodes)
    293 t0 = time.perf_counter()
--> 294 sub_prof = rewriter.apply(fgraph)
    295 l.append(float(time.perf_counter() - t0))
    296 sub_profs.append(sub_prof)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:2461, in EquilibriumGraphRewriter.apply(self, fgraph, start_from)
   2459 nb = change_tracker.nb_imported
   2460 t_rewrite = time.perf_counter()
-> 2461 sub_prof = grewrite.apply(fgraph)
   2462 time_rewriters[grewrite] += time.perf_counter() - t_rewrite
   2463 sub_profs.append(sub_prof)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:2043, in WalkingGraphRewriter.apply(self, fgraph, start_from)
   2041             continue
   2042         current_node = node
-> 2043         nb += self.process_node(fgraph, node)
   2044     loop_t = time.perf_counter() - t0
   2045 finally:

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:1928, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
   1926 except Exception as e:
   1927     if self.failure_callback is not None:
-> 1928         self.failure_callback(
   1929             e, self, [(x, None) for x in node.outputs], node_rewriter, node
   1930         )
   1931         return False
   1932     else:

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:1781, in NodeProcessingGraphRewriter.warn_inplace(cls, exc, nav, repl_pairs, node_rewriter, node)
   1779 if isinstance(exc, InconsistencyError):
   1780     return
-> 1781 return cls.warn(exc, nav, repl_pairs, node_rewriter, node)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:1769, in NodeProcessingGraphRewriter.warn(cls, exc, nav, repl_pairs, node_rewriter, node)
   1765     pdb.post_mortem(sys.exc_info()[2])
   1766 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
   1767     # We always crash on AssertionError because something may be
   1768     # seriously wrong if such an exception is raised.
-> 1769     raise exc

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:1925, in NodeProcessingGraphRewriter.process_node(self, fgraph, node, node_rewriter)
   1923 assert node_rewriter is not None
   1924 try:
-> 1925     replacements = node_rewriter.transform(fgraph, node)
   1926 except Exception as e:
   1927     if self.failure_callback is not None:

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/graph/rewriting/basic.py:1084, in FromFunctionNodeRewriter.transform(self, fgraph, node)
   1079     if not (
   1080         node.op in self._tracks or isinstance(node.op, self._tracked_types)
   1081     ):
   1082         return False
-> 1084 return self.fn(fgraph, node)

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/tensor/rewriting/basic.py:1142, in constant_folding(fgraph, node)
   1139     compute_map[o] = [False]
   1141 thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
-> 1142 required = thunk()
   1144 # A node whose inputs are all provided should always return successfully
   1145 assert not required

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/link/c/op.py:103, in COp.make_c_thunk.<locals>.rval()
    101 @is_cthunk_wrapper_type
    102 def rval():
--> 103     thunk()
    104     for o in node.outputs:
    105         compute_map[o][0] = True

File ~/anaconda3/envs/pm5/lib/python3.11/site-packages/pytensor/link/c/basic.py:1786, in _CThunk.__call__(self)
   1784     print(self.error_storage, file=sys.stderr)
   1785     raise
-> 1786 raise exc_value.with_traceback(exc_trace)

AssertionError: All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code

I don’t understand this error, because simply running the following did not raise such an error:

rng = np.random.default_rng(123)
samplePerPos = 20
nPos = 100
dataDict = {
    'TD': rng.normal(3800, 500, size=(samplePerPos, nPos)).astype(int),
    'AD': rng.poisson(1, size=(samplePerPos, nPos)),
    'covariate1': rng.integers(low=0, high=16, size=nPos)}
posBatchSize = 4
TD_obs_m, AD_obs_m, c1_idx_m = pm.Minibatch(
            dataDict['TD'].T,
            dataDict['AD'].T,
            dataDict['covariate1'], batch_size=posBatchSize)

Thanks, that was perfect (adding imports is also helpful but not a problem). I could narrow it down to this failing case:

import pymc as pm
import numpy as np

A = rng.normal(size=(1000, 100))
B = rng.normal(size=(1000, 100))

mA, mB = pm.Minibatch(A, B, batch_size=10)
mA.eval().shape
# AssertionError: All variables shape[0] in Minibatch should be equal, check your Minibatch(data1, data2, ...) code

I’ll open an issue on our repo!

I also replicated your issue when using minibatch_index directly… Will try to sort it out in the next days.