Weird posterior predictive shape when only one value is in test set?

So, basically: when I use the posterior predictive distribution for logistic regression and the test set has more than one point, the posterior predictive has the shape we would expect. But when the test set has exactly one point, the shape becomes awkward. Could you check the code example below? It should be easily runnable; just try changing EXAMPLES_IN_TEST_SET from 25 to 1.

import numpy as np 
import pandas as pd 
import pymc3 as pm

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

EXAMPLES_IN_TEST_SET = 1 #try changing this from 25 to 1 

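X1 = np.random.normal(0, 10, 100)
X2 = np.random.normal(2, 10, 100)
# y was not defined in the original snippet; any score in [0, 1] works here,
# e.g. this illustrative logistic function of the two features:
y = 1 / (1 + np.exp(-(X1 + X2) / 10))

df = pd.DataFrame({"X1": X1, "X2": X2, "y": y})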

scaler = StandardScaler()
X = scaler.fit_transform(df[["X1", "X2"]])
y = 1*(df["y"].values>0.5)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=EXAMPLES_IN_TEST_SET/100)

def train(X,y):
    model = pm.Model()
    with model:
        alpha = pm.Normal("alpha",mu=0.0,sigma=np.sqrt(5))
        betas = pm.Normal("betas", mu=0.0, sigma=np.ones(X.shape[1])*np.sqrt(5), shape=X.shape[1])

        # set predictors as shared variable to change them for PPCs:
        data = pm.Data("data", X)
        p = pm.Deterministic("p", pm.math.invlogit(alpha + pm.math.dot(data, betas)))

        outcome = pm.Bernoulli("outcome", p=p, observed=y)
        trace = pm.sample(return_inferencedata=True)  # NUTS sampling; 1000 draws per chain by default
    return model, trace

def predict(model, trace, X):
    with model:
        pm.set_data({"data": X})
        post_pred = pm.sample_posterior_predictive(trace)
    return post_pred["outcome"]

model, trace = train(X_train,y_train)

predicted = predict(model,trace,X_test)
# when EXAMPLES_IN_TEST_SET is set to e.g. 25, sample_posterior_predictive["outcome"] shape is (25,2000)
#if you change EXAMPLES_IN_TEST_SET to 1 (only one example), sample_posterior_predictive["outcome"] shape becomes (2000,99)

Yes, this is a known issue (ticket 3640) that should be fixed in the new major version, which uses Aesara.
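Why 99 in particular? A plausible explanation, assuming PyMC3 keeps the shape of the observed `y_train` (99 rows) as the fixed shape of `outcome` even after `pm.set_data` swaps in new predictors: a new `p` of length 1 is broadcast-compatible with that stale shape and gets stretched to 99 columns, while a `p` of length 25 is not compatible, so its own length is used. The toy function below (the name `predictive_shape` is hypothetical) is a NumPy-only model of that rule, not PyMC3's actual code:

```python
import numpy as np


def predictive_shape(n_draws, p_shape, observed_shape):
    """Hypothetical toy model of how the output shape could be resolved when the
    likelihood still remembers the training-set size (assumption, not PyMC3 code)."""
    try:
        # (1,) broadcasts against (99,), so the stale training size wins
        event_shape = np.broadcast_shapes(p_shape, observed_shape)
    except ValueError:
        # incompatible shapes such as (25,) vs (99,): fall back to p's own shape
        event_shape = p_shape
    return (n_draws,) + tuple(event_shape)


print(predictive_shape(2000, (25,), (99,)))  # (2000, 25)
print(predictive_shape(2000, (1,), (99,)))   # (2000, 99)
```

Length 1 is the only size that NumPy-style broadcasting silently stretches, which would explain why the problem only shows up with a single test example.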