Vector-Output GPs

Is it possible to build a GP model in PyMC in which a latent GP produces vector outputs rather than scalars? I'm interested in this because I often have to model several variable types (e.g. multiclass, binary, and multilabel classification, and regression) and would like a single model for all of them rather than many small ones. My idea is that if a GP can output vectors, these can be turned into probability distributions with a softmax (for multiclass), into single probabilities with sigmoids, etc., similar to what you can do with a neural net.
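
One way to picture the idea: a "vector output" is D latent functions evaluated at the same N inputs, i.e. a matrix of shape (N, D), whose columns are routed through different link functions, much like the output heads of a neural net. The NumPy sketch below only illustrates shapes and link functions; the latent values are random numbers standing in for GP draws, and the column split (3 classes, 1 binary target, 2 labels, 1 continuous target) is a hypothetical choice.

```python
import numpy as np

def softmax(z):
    # subtract the row-wise maximum before exponentiating for numerical stability
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical sizes: 4 observations, a 3-class target, one binary target,
# a 2-label multilabel target and one continuous target -> D = 3 + 1 + 2 + 1 = 7
N, n_classes, n_labels = 4, 3, 2
D = n_classes + 1 + n_labels + 1

# random numbers standing in for D latent GP functions evaluated at N inputs
rng = np.random.default_rng(0)
F = rng.normal(size=(N, D))

heads = {
    "multiclass": softmax(F[:, :n_classes]),                              # (N, 3), rows sum to 1
    "binary": sigmoid(F[:, n_classes]),                                   # (N,), one probability each
    "multilabel": sigmoid(F[:, n_classes + 1:n_classes + 1 + n_labels]),  # (N, 2), independent probabilities
    "regression": F[:, -1],                                               # (N,), e.g. the mean of a Normal likelihood
}
```

In a PyMC model, each head would then feed its own likelihood (for example Categorical, Bernoulli and Normal) while sharing the same set of latent functions.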

After much experimentation, I have a model that works (at least it samples):

```python
import numpy as np
import pymc
import aesara.tensor as at

# X_train, X_test, Y_train, core_mappings and MultiLayerPerceptronKernel are defined elsewhere
locations = list(core_mappings[('origin', 'location', 'loc')].items())

with pymc.Model() as categorical_simple_model:
    N, M = X_train.shape
    gps = []
    _f = []
    ϵ = 1e-9  # small offset so that η stays strictly positive
    μ = pymc.gp.mean.Constant(0)
    σ_c = pymc.Normal('σ_c', mu=.7, sigma=.2)
    σ_w = pymc.MvNormal('σ_w', mu=np.ones(M)*3.1, cov=σ_c*np.eye(M), shape=M)
    σ_b = pymc.Normal('σ_b', mu=3.1, sigma=.7)
    _η = pymc.Exponential('_η', lam=2.0)
    η = pymc.Deterministic('η', _η + ϵ)
    # one kernel shared by all classes
    κ = MultiLayerPerceptronKernel(M, variance=η**2, bias_variance=σ_b**2, weight_variance=σ_w**2)
    # one independent latent GP per class (location)
    for factor, loc in locations:
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        gps.append(gp)
        _f.append(gp.prior(f'f_{loc}', X=X_train.values))
    # at.stack adds a new leading axis, so f has shape (n_classes, N)
    f = pymc.Deterministic('f', at.math.sigmoid(at.stack(_f)))
    # normalise over axis 0, i.e. across the classes of each observation
    p = pymc.Deterministic('p', f / f.sum(axis=0))
    y = pymc.Categorical('y', p=p, observed=Y_train.values, shape=Y_train.shape)

# predictions at the test inputs, one conditional per latent GP
with categorical_simple_model:
    fs = []
    for (factor, loc), gp in zip(locations, gps):
        fs.append(gp.conditional(f'_f_star_{loc}1', Xnew=X_test.values))
    f_star = pymc.Deterministic('f_star', at.stack(fs))
    p_star = pymc.Deterministic('p_star', pymc.math.sigmoid(f_star) / pymc.math.sigmoid(f_star).sum(axis=0))
```
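
A note on orientation that may be behind the shape errors: pymc.Categorical reads the last axis of p as the categories, whereas stacking the per-class GPs along a new first axis gives p the layout (n_classes, N). The NumPy sketch below (np.stack adds the new axis the same way at.stack does) reproduces that layout with random stand-ins for the GP draws, assuming 5 classes and 8 observations, and shows the transpose that yields one distribution per row.

```python
import numpy as np

n_classes, N = 5, 8
rng = np.random.default_rng(1)
# stand-ins for the per-class latent GP draws, each of shape (N,)
latents = [rng.normal(size=N) for _ in range(n_classes)]

f = 1.0 / (1.0 + np.exp(-np.stack(latents)))  # np.stack adds a leading axis -> (n_classes, N)
p = f / f.sum(axis=0)                          # columns sum to 1, rows do not

p_rows = p.T                                   # (N, n_classes): one distribution per observation
```

With the (n_classes, N) layout, a Categorical would see n_classes batch entries over N categories, which could explain why an explicit shape was needed. Transposing, or stacking with axis=1, gives the (N, n_classes) layout instead.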

This seems to produce plausible-looking outputs. However, I have no idea whether it's the model I'm trying to formulate. I'm particularly confused about pymc.Categorical. Without the shape parameter it tends to throw shape errors that are hard to interpret. Furthermore, the model seems to work both when Y is one-hot encoded and when it is just a single column of integers. Can someone clarify what exactly Categorical expects to see?

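Categorical expects integer class labels 0, 1, 2, …, K−1, where K is the number of categories (the length of the last axis of p). When you pass one-hot encoded data, it treats every entry as a separate independent observation, mostly of class 0 with the occasional class 1. That's not what you want. If you want to keep the one-hot encoding, use Multinomial (with n=1) instead.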
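
A small NumPy check of this point, assuming the rows of p are per-observation class probabilities: the Categorical log-likelihood on integer labels equals the Multinomial(n=1) log-likelihood on the one-hot encoding, while reading the one-hot matrix itself as labels gives a different quantity. In PyMC terms this roughly corresponds to `pymc.Categorical('y', p=p, observed=labels)` versus `pymc.Multinomial('y', n=1, p=p, observed=one_hot)`; those calls are an untested sketch, and the "misread" line below assumes each row's probabilities are reused for all K entries of that row.

```python
import numpy as np

rng = np.random.default_rng(2)
N, K = 6, 4
logits = rng.normal(size=(N, K))
p = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # (N, K), rows sum to 1

labels = np.array([0, 3, 1, 1, 2, 0])   # integer classes in 0..K-1
one_hot = np.eye(K, dtype=int)[labels]  # (N, K) one-hot encoding

# Categorical log-likelihood: log p[i, y_i] summed over observations
ll_categorical = np.log(p[np.arange(N), labels]).sum()

# Multinomial with n=1: sum_k y_ik * log p_ik (the multinomial coefficient is 1 when n=1)
ll_multinomial = (one_hot * np.log(p)).sum()

# One-hot matrix read as labels: every 0/1 entry becomes its own observation of
# class 0 or class 1, which is a different likelihood altogether
ll_misread = np.log(p[np.arange(N)[:, None], one_hot]).sum()
```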

Thanks for the reply. After more tinkering, I’ve started to run into sampling errors:

```python
import numpy as np
import pymc
import aesara.tensor as at

# X_train, Y_train, core_mappings and MultiLayerPerceptronKernel are defined elsewhere
with pymc.Model() as categorical_simple_model:
    factors = 5  # number of classes (locations)
    N, M = X_train.shape
    gps = []
    fs = []
    for factor, loc in core_mappings[('origin', 'location', 'loc')].items():
        # separate mean, kernel hyperparameters and latent GP for each class
        μ = pymc.gp.mean.Constant(0)
        σ_w = pymc.Normal(f'σ_w_{loc}', mu=10.0, sigma=.7, shape=M)
        σ_b = pymc.Normal(f'σ_b_{loc}', mu=.0, sigma=.01)
        η = pymc.Normal(f'η_{loc}', mu=1.0, sigma=0.01)
        κ = MultiLayerPerceptronKernel(M, variance=η**2, bias_variance=σ_b**2, weight_variance=σ_w**2)
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        gps.append(gp)
        fs.append(gp.prior(f'f_{loc}', X=X_train.values))
    # stack and transpose to shape (N, factors): one row per observation, one column per class
    _f = pymc.Deterministic('_f', at.stack(fs).T)
    f = pymc.Deterministic('f', at.math.sigmoid(_f))
    _s = pymc.Deterministic('s_', f.sum(axis=1)[:, None])
    # fallback distribution (all mass on the first class) for rows whose sum is exactly 0
    dummy = np.tile(np.array([1.0] + [.0]*(factors - 1))[None, :], (N, 1))
    p = pymc.Deterministic('p', at.switch(at.eq(_s, at.zeros_like(_s)), dummy, f / _s))
    y = pymc.Categorical('y', p=p, observed=Y_train.values[:, 0], shape=Y_train.shape[0])
```
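
On the division-by-zero guard: in float64 the sigmoid only rounds to exactly 0 for latents below roughly −710, so the sum in the normalised-sigmoid link can only become 0 for extremely negative latent values. A max-subtracted softmax cannot hit 0/0, because its largest entry is always exp(0) = 1. The NumPy sketch below contrasts the two; note that softmax is a swapped-in alternative link, not a drop-in equivalent of normalised sigmoids, since the two give different probabilities for the same latents.

```python
import numpy as np

def normalised_sigmoid(f):
    # link used in the model above: squash each latent with a sigmoid, then renormalise
    with np.errstate(over="ignore", invalid="ignore"):
        s = 1.0 / (1.0 + np.exp(-f))
        return s / s.sum(axis=1, keepdims=True)

def softmax(f):
    # max-subtracted softmax: the largest entry becomes exp(0) = 1, so the sum is never 0
    z = f - f.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# one row of moderate latents, one row of extremely negative latents
f = np.array([[0.5, -1.0, 2.0],
              [-800.0, -850.0, -900.0]])

r_sigmoid = normalised_sigmoid(f)  # second row: the sigmoids underflow to 0 and 0/0 gives nan
r_softmax = softmax(f)             # second row stays a valid distribution
```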

(Note: my priors are all in an incoherent state after endless trial and error.) Checking the prior outputs of each GP yields somewhat sensible results.

Yet I get hit with a generic sampling error:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...

---------------------------------------------------------------------------
SamplingError                             Traceback (most recent call last)
Cell In [171], line 2
      1 with categorical_simple_model:
----> 2     idata = pymc.sample(draws= 1000, tune=1500, cores=2, chains=2, random_seed=44)

File /media/alexander-fyrogenis/Elements/Διδακτορικό/Olive Oil/notebooks/venv/lib/python3.10/site-packages/pymc/sampling.py:561, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    559 # One final check that shapes and logps at the starting points are okay.
    560 for ip in initial_points:
--> 561     model.check_start_vals(ip)
    562     _check_start_shape(model, ip)
    564 sample_args = {
    565     "draws": draws,
    566     "step": step,
   (...)
    575     "discard_tuned_samples": discard_tuned_samples,
    576 }

File /media/alexander-fyrogenis/Elements/Διδακτορικό/Olive Oil/notebooks/venv/lib/python3.10/site-packages/pymc/model.py:1801, in Model.check_start_vals(self, start)
   1798 initial_eval = self.point_logps(point=elem)
   1800 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1801     raise SamplingError(
   1802         "Initial evaluation of model at starting point failed!\n"
   1803         f"Starting values:\n{elem}\n\n"
   1804         f"Initial evaluation results:\n{initial_eval}"
   1805     )

SamplingError: Initial evaluation of model at starting point failed!
Starting values:
{'σ_w_LESVOS': array([ 9.15389669, 10.85270312,  9.44986538, 10.65616583,  9.86568392,
        9.55440129, 10.93130078, 10.940588  ,  9.82515486,  9.52135951,
       10.74580886,  9.58950133, 10.43082604,  9.17414448,  9.22438013,
       10.55692819, 10.01655802, 10.58205922,  9.02345872,  9.74216879,
       10.46810645, 10.07352035,  9.28861348,  9.60086126, 10.56862279,
       10.70950238,  9.10977794,  9.77351756,  9.3025502 , 10.0548006 ,
       10.19049653,  9.08146313,  9.76030827,  9.46656961, 10.12176842,
        9.34737303, 10.27021216, 10.10108567,  9.85696949,  9.14323312,
       10.07580724, 10.14372062, 10.14633824,  9.07256442,  9.36337599,
       10.7717392 ,  9.37710522, 10.14230229,  9.75886691,  9.68385258,
       10.73053422,  9.87517513,  9.25581399,  9.91026346,  9.27934233,
        9.62345242, 10.11493438,  9.08867222, 10.82228144,  9.02693697,
       10.50623103,  9.08846799,  9.22064129]), 'σ_b_LESVOS': array(0.70921005), 'η_LESVOS': array(1.62290096), 'f_LESVOS_rotated_': array([ 0.48606202, -0.05750464,  0.39775823, -0.73506174,  0.18924769,
       -0.79205822, -0.60330083, -0.91940572,  0.30999119,  0.74735712,
        0.56623856,  0.83417569, -0.05784597, -0.21402646, -0.96875847,
       -0.92535721,  0.84683007, -0.82805718,  0.19342466,  0.78279287,
       -0.06573217, -0.0984265 , -0.03966686,  0.69550449,  0.49175055,
       -0.62112312, -0.29253861,  0.51988189, -0.48311947, -0.05685895,
       -0.87701023, -0.49639025,  0.47754525, -0.85901368,  0.76777466,
        0.60161249,  0.11460279, -0.4222967 , -0.3333136 , -0.96208543,
       -0.75479352, -0.67267036, -0.31673458,  0.2832573 ,  0.13782709,
        0.85818576, -0.33240687, -0.06605224, -0.89846575, -0.41776298,
        0.36399779,  0.12166975, -0.55771186,  0.10527268,  0.60839088,
        0.50254195, -0.05602044, -0.49496405,  0.40177229,  0.77954297,
        0.43582885, -0.16038181, -0.04358258,  0.38788849, -0.4015904 ,
        0.96939575, -0.52990232,  0.81265588,  0.06512055,  0.90973364,
       -0.46138376, -0.85866802, -0.02531684,  0.30186283, -0.26810522,
       -0.44478761,  0.18333775, -0.63726822,  0.6975101 ,  0.65135364,
        0.9892628 , -0.0553152 ,  0.13955964,  0.74278527, -0.34539583,
        0.61756714, -0.63967225,  0.5944891 ,  0.31933709,  0.22292139,
        0.18771531,  0.13105963,  0.34021337, -0.86983251, -0.22363494,
       -0.41914142, -0.33309724, -0.24702408,  0.45128761, -0.47457353,
        0.38379287, -0.75904167, -0.77078648,  0.50091378, -0.29534457,
       -0.385392  ,  0.91488685, -0.02172829, -0.57859994, -0.32275252,
        0.23209637, -0.4291677 ,  0.57120151,  0.90852484, -0.51146913,
        0.82481469, -0.79504096, -0.37723923,  0.10826941, -0.12484881,
       -0.11216902, -0.58125022,  0.06556108, -0.72667636,  0.83587009,
        0.29247122, -0.12196493, -0.71611837,  0.86396981, -0.35988189,
       -0.12384023,  0.45293806, -0.77279569, -0.9910372 ,  0.20212587,
       -0.92086343, -0.22572825, -0.12506967,  0.8051224 ,  0.03835518,
        0.87467874, -0.59992519, -0.10090282, -0.9048303 ,  0.83248862,
        0.83882907,  0.40730734, -0.92124882, -0.42170137, -0.08626875,
        0.77283201, -0.82084868, -0.14841494,  0.45029237,  0.27661259,
        0.44060156, -0.33672568, -0.46138687, -0.88363887, -0.07689968,
        0.31085364, -0.98858449, -0.31700121,  0.21144317,  0.06261854,
        0.36236879,  0.54185604,  0.23242172, -0.26541738,  0.73528826,
        0.36156801, -0.36288286, -0.17884688, -0.41012916,  0.60446187,
        0.67412714,  0.47165996,  0.12175196,  0.64446294, -0.07718248,
       -0.10220434, -0.57194466,  0.67438814, -0.36951548,  0.97305376,
        0.70666903, -0.67239419, -0.21630591, -0.13992048,  0.6942951 ,
        0.9492795 ,  0.60840455, -0.24632258, -0.77791368, -0.42730229,
       -0.1190319 ,  0.93997779,  0.40344643, -0.9107019 ,  0.12310798,
        0.67055378,  0.78422849,  0.72392885,  0.9053756 ,  0.20930128,
        0.11611487,  0.94173718,  0.31018225, -0.26292217, -0.20211832,
        0.68700175,  0.34338105,  0.09102248,  0.16590437,  0.39013556,
       -0.60289715,  0.85420581,  0.78423054, -0.77061865, -0.53076217,
        0.50791463, -0.74316801,  0.44564758, -0.60896336, -0.50457433,
        0.50039073, -0.56013186, -0.41561322,  0.8720142 , -0.50067418,
        0.76287962,  0.68348549, -0.86013304,  0.95458716,  0.68242175,
        0.09745051, -0.04145826,  0.75769   , -0.19702269, -0.0338294 ,
        0.71996494, -0.4837162 , -0.09077571,  0.94018659,  0.3729279 ,
       -0.36230459,  0.77445179, -0.72677449,  0.48078482,  0.50581907,
        0.99125788,  0.14528036, -0.3164143 , -0.35214186,  0.05464637,
        0.33523223, -0.49240213, -0.14493891,  0.58246109, -0.34688059,
       -0.94206736,  0.89684004, -0.7421782 , -0.81153334,  0.02151323,
        0.9434007 , -0.74797198, -0.98670408, -0.15924378, -0.75799781,
        0.14191314, -0.75861898,  0.73981023, -0.02256103,  0.93160883,
        0.87281622,  0.63372847, -0.75374709,  0.86407595, -0.83897426,
       -0.65116815, -0.13878746,  0.68521175,  0.27816359,  0.64666496,
       -0.68166694,  0.19674112, -0.75250003,  0.84833841, -0.05258209,
       -0.50343253,  0.18088408, -0.87896426,  0.57941075, -0.78760233,
       -0.20115428, -0.12380109, -0.40550318, -0.63408434,  0.79398135,
       -0.73878101,  0.03437339,  0.18816869, -0.24104893, -0.91192019,
       -0.80004882, -0.87585259,  0.538598  ,  0.85966293,  0.53485427,
       -0.79257814, -0.54344126, -0.07192188,  0.69661188, -0.99336402,
        0.65856514, -0.9859744 ,  0.71973767,  0.16406195, -0.64105502,
       -0.73872274, -0.89386223,  0.56795237,  0.46226757, -0.79521809,
        0.38258295, -0.47987754, -0.91335276,  0.30390317,  0.01575346,
       -0.06892708, -0.73739227, -0.94395353, -0.78706455, -0.73927979,
        0.21979897,  0.83908211,  0.36305091,  0.36759142,  0.17915262,
       -0.75946344,  0.56188368,  0.82582536, -0.86403284,  0.70741271,
        0.55239022, -0.0237508 , -0.61855772,  0.91165122, -0.70771664,
       -0.78505166, -0.93856211,  0.16596274,  0.08685433, -0.35224752,
        0.13862893,  0.1314268 ,  0.73904479,  0.15068114,  0.06561571,
       -0.67232765, -0.53999699,  0.32825998, -0.89981346,  0.13918157,
       -0.02071166, -0.25270416, -0.47887448, -0.80040932, -0.52179009,
       -0.93727722,  0.11938124, -0.60214144,  0.92809085,  0.15069965,
        0.12023757, -0.4319954 ,  0.95014544,  0.17716056, -0.83042521,
        0.86325665,  0.75015591, -0.82928655,  0.81665911,  0.58666328,
        0.53280148,  0.38688033, -0.27655464,  0.59784245, -0.87440181,
       -0.81036219, -0.54573094, -0.31516507, -0.69884707,  0.70818205,
        0.51771508, -0.84546186,  0.49186937,  0.1268696 ]), 'σ_w_IKARIA': array([ 9.21494362, 10.69765723, 10.52219506, 10.82648558,  9.62780153,
       10.62089311, 10.67443963, 10.28101984,  9.94223576,  9.10331004,
       10.51714201,  9.90448965, 10.94807566, 10.67459207,  9.18998324,
       10.63577916,  9.1964881 , 10.44334554, 10.69654897,  9.5197523 ,
        9.50610035, 10.42415076, 10.77374167,  9.17221711, 10.79809018,
        9.35620927, 10.27234331, 10.89739256,  9.84993142, 10.72389463,
        9.10432842,  9.84038286, 10.21474642]), 'σ_b_IKARIA': array(-0.47105749), 'η_IKARIA': array(1.50507627), 'f_IKARIA_rotated_': array([-0.27948203, -0.39037051, -0.80441018, -0.2028417 ,  0.33666347,
       -0.91366135, -0.11786481,  0.9936126 , -0.92285819,  0.44724691,
        0.23186736, -0.28812537,  0.29880623, -0.06649264,  0.47457421,
        0.93279294, -0.69575604, -0.03617461,  0.50512821, -0.6356363 ,
       -0.3298101 ,  0.62306306, -0.89120689,  0.68278601,  0.93174378,
        0.95517806, -0.27287669,  0.63782505,  0.13242919, -0.88263241,
        0.56662343,  0.86178375, -0.11706782, -0.94725375,  0.6196216 ,
        0.50170609,  0.33706222,  0.25419795, -0.7354072 ,  0.74212949,
        0.91508001, -0.77096085, -0.35835153, -0.29613842,  0.93758883,
       -0.33040874, -0.70027456, -0.26470037,  0.06104933, -0.97368939,
        0.47830373,  0.68366075,  0.25691619, -0.55541296, -0.87332899,
        0.21011257, -0.72990397, -0.48704953, -0.50603564, -0.97427106,
       -0.63992956,  0.96362526, -0.32897696, -0.65462311, -0.86076531,
       -0.21689051,  0.437032  , -0.90258268, -0.18465247,  0.65768317,
        0.10505895, -0.53424249,  0.35707991, -0.03910111, -0.5782608 ,
       -0.70724433,  0.27216229, -0.73573242,  0.2686824 , -0.19995845,
       -0.11793359, -0.52450426,  0.24411859,  0.77099508, -0.63883249,
        0.09972287,  0.28418921, -0.78116082,  0.30795318,  0.83271334,
        0.11929135, -0.63090027, -0.63922105,  0.32525192,  0.9949957 ,
       -0.52158884, -0.11099115,  0.78514752,  0.93714179,  0.65663547,
        0.4551711 , -0.89977151,  0.51874976,  0.52851773,  0.63470914,
        0.32690264, -0.10442083, -0.22657619,  0.47078002,  0.75417924,
       -0.1769126 , -0.87948168,  0.63648271, -0.87532117, -0.78362362,
        0.59663556, -0.8766123 ,  0.331311  ,  0.17902656,  0.35563398,
        0.21281197,  0.01168888,  0.57534042,  0.29649029,  0.90763528,
        0.89161291,  0.01573708,  0.12611317,  0.12109387, -0.15383683,
        0.17758962,  0.65371869, -0.36597453, -0.20420019,  0.97426758,
        0.13985259, -0.90608478,  0.67376478,  0.95558517,  0.05439007,
       -0.11691677, -0.24941819, -0.84215802, -0.38252799,  0.16286306,
        0.90463155,  0.72657282,  0.84332545, -0.34549113,  0.18340231,
       -0.64138555, -0.13969498, -0.67951917,  0.26504536,  0.47107957,
        0.31325995, -0.91330355, -0.77977864, -0.22834488, -0.82950776,
       -0.43478197,  0.4317437 , -0.80898551,  0.58372512,  0.23792121,
        0.53531073, -0.24918119, -0.64844689, -0.24058978, -0.47535967,
       -0.3388392 , -0.04526641, -0.18171857, -0.20685575, -0.03598036,
       -0.53220813, -0.35585681,  0.94456868,  0.11598018,  0.48441732,
        0.75093342,  0.68660801,  0.14298974,  0.32119572,  0.47927278,
        0.70346316, -0.11036186, -0.58045823,  0.53553899,  0.74019028,
       -0.88170812,  0.14287552, -0.19237979,  0.01744791,  0.33012509,
        0.31516136,  0.57154341,  0.63324059,  0.55171728,  0.43273322,
        0.67060517, -0.91263772,  0.39502389, -0.73515489,  0.17602126,
       -0.58788015, -0.46186564, -0.52755813, -0.83745151,  0.50575306,
        0.22695921,  0.22509206, -0.97794513,  0.78102943, -0.39314543,
       -0.48351958,  0.69888094,  0.35290789, -0.74973345, -0.10720196,
       -0.21770319, -0.19948587, -0.4669718 ,  0.09693138,  0.21184672,
        0.54150039,  0.73848466, -0.33167776, -0.63046675, -0.48865088,
        0.65588298,  0.30040299,  0.87153996,  0.35498385,  0.50127758,
        0.74501066, -0.76349962, -0.65442908, -0.41294734,  0.44716831,
       -0.55596078,  0.83651696,  0.5142481 ,  0.46857172, -0.46044239,
        0.28445636,  0.10714396, -0.11700016, -0.86218257, -0.04477057,
        0.34355535, -0.14767439, -0.79204087,  0.64902172,  0.65239501,
        0.14668265, -0.25806695,  0.22708576,  0.80371109, -0.1216073 ,
       -0.11337834,  0.92453582, -0.98715673, -0.86063292, -0.75003766,
        0.17585209,  0.78084521, -0.85198475, -0.94108158,  0.19619444,
       -0.08831064,  0.59278515, -0.0244028 ,  0.98702759,  0.93913542,
       -0.33784074, -0.4574391 , -0.90798093, -0.70322187,  0.19524623,
       -0.99871723, -0.50801984, -0.16170868,  0.85088972, -0.67442472,
        0.65773729, -0.40098471,  0.09161588,  0.83786002,  0.54346601,
        0.01736629,  0.10784845,  0.32900574,  0.15317762,  0.55516649,
        0.61755485, -0.71300043, -0.58162011, -0.91731162, -0.14326466,
       -0.94322453,  0.16225505,  0.15473652, -0.49417958, -0.34964539,
       -0.21733528, -0.58379056, -0.85234656,  0.47416692,  0.70713682,
       -0.03221922,  0.03523189,  0.04571195,  0.17054439, -0.37914808,
       -0.9545246 , -0.83439622, -0.23108025, -0.24650903]), 'σ_w_SAMOS': array([10.18692239, 10.15548637,  9.6890952 ,  9.85771665,  9.21202937,
       10.98213583,  9.63121548, 10.59680639,  9.45200574, 10.65284439,
        9.81525158, 10.97862545,  9.04996588,  9.8200355 , 10.20277911,
        9.93994377,  9.25621225, 10.12314184,  9.98556261, 10.81837332,
       10.14758508, 10.61628615,  9.11069931, 10.87425799,  9.11678893,
       10.37474661, 10.42492221,  9.76238371,  9.72092666, 10.05785007,
        9.52149489,  9.20260739,  9.37123183, 10.91415715, 10.48881682,
        9.49054663, 10.98986968, 10.86928939, 10.50607239, 10.74960645,
        9.89163306, 10.94083801, 10.36574912]), 'σ_b_SAMOS': array(0.08874031), 'η_SAMOS': array(1.84375236), 'f_SAMOS_rotated_': array([ 6.72047424e-01,  8.99247179e-04, -8.68739473e-01,  8.77473628e-01,
        8.54988823e-01,  3.54413515e-01, -4.29107182e-01, -5.71963173e-02,
       -6.59897210e-01, -4.94928671e-01,  3.14148898e-02, -9.20920171e-01,
       -5.28165081e-01,  9.48962807e-01,  2.72881317e-01,  8.89889459e-01,
        8.41552321e-01,  4.79466518e-01, -5.45694971e-01,  3.59571284e-01,
       -8.66889456e-01, -2.96377154e-01, -1.98801726e-01,  1.39233115e-01,
        3.42756588e-01, -6.05313283e-01, -9.23766012e-01, -5.39733994e-01,
       -6.13123216e-01, -8.02285696e-01, -5.39906572e-01, -5.50472626e-01,
       -5.16047242e-01,  3.19877954e-01,  2.97925712e-01, -3.65818533e-01,
        5.09248145e-01, -3.63776122e-02,  1.30758010e-01,  1.39254296e-02,
       -3.37624658e-01, -5.73184253e-01,  1.58570715e-02,  6.04669290e-01,
        1.00045557e-02,  6.05803719e-01,  9.03520587e-01, -8.70227128e-01,
        4.41040589e-01, -3.38563828e-01,  8.25278294e-01,  3.95261538e-01,
        4.49434962e-01, -4.84723012e-01, -3.16089644e-01,  9.85856074e-01,
        9.02263931e-01,  4.08331871e-01,  3.93385682e-01,  8.60688966e-01,
       -9.01977495e-01, -9.08208671e-01,  3.71834812e-01,  3.66435823e-01,
       -7.85322041e-01, -3.33026307e-01,  2.78392188e-01, -2.86695240e-01,
       -1.01505997e-01, -7.23259696e-01,  9.43830496e-01, -2.49606772e-01,
        4.37556323e-01, -9.44962281e-02,  8.13106311e-01,  7.67737909e-01,
        6.18440795e-02, -9.74608045e-01,  5.92403757e-01,  3.40703552e-01,
        5.89602672e-01, -7.64037621e-01, -9.84788180e-01,  1.70642552e-01,
        6.17462135e-01, -1.23870282e-01,  8.29921192e-02,  5.96812072e-01,
        4.64164234e-01, -5.95147759e-01, -3.43670400e-01, -9.37060512e-01,
        7.77130728e-01,  8.26558118e-01,  2.64800319e-01, -3.21290329e-01,
        5.18619812e-01, -1.93036409e-01,  8.50219741e-01, -3.74988559e-01,
       -8.36921434e-01,  8.66810036e-01,  7.26266157e-01, -1.51680793e-02,
       -1.29088066e-01, -1.25795414e-01, -6.70603292e-01,  7.05885602e-01,
        5.03493255e-01, -5.97832274e-01,  3.50193186e-02,  7.85649819e-02,
       -1.67657861e-01, -4.97760331e-01, -5.75729714e-01,  8.24995861e-01,
       -5.17815035e-01, -3.17854552e-01, -4.19093049e-01,  6.63259811e-01,
       -3.69317749e-01, -6.51843984e-01, -6.35897037e-01,  9.08977183e-01,
       -8.79009691e-01,  9.26260153e-01, -9.34687046e-01,  7.57535804e-01,
        4.14710716e-01,  9.24928943e-01,  7.06090118e-01,  8.90402165e-01,
       -5.68926483e-01,  6.24782694e-01,  9.59451362e-01,  3.24981571e-01,
       -9.93714760e-01, -2.57546343e-01,  2.31703891e-01,  3.43771522e-01,
        4.04389444e-01,  7.31885593e-01,  1.69572976e-01,  8.38118073e-01,
        4.62325894e-01,  7.05893914e-01, -6.56247505e-01, -9.64997622e-01,
       -1.16845123e-01, -4.85002898e-01,  7.26642885e-02, -3.52441192e-01,
        1.65671929e-01,  8.21853791e-01, -6.61359365e-01, -1.15244422e-01,
        2.42187524e-01,  8.17286695e-01, -1.65606018e-01,  9.37500272e-01,
       -2.05241676e-01, -7.61401335e-01,  4.89017896e-01, -7.25363800e-01,
        8.16962289e-01,  3.09857373e-01, -7.98204641e-01,  8.45322463e-02,
        9.27873440e-01, -4.01678049e-01,  1.68972192e-02,  2.59915567e-01,
        2.41189421e-01,  3.20008792e-01, -3.02608608e-01,  8.98069737e-01,
        4.75589063e-01, -4.73205343e-01,  2.67624178e-01, -3.08441235e-01,
       -8.53213259e-01, -2.27533001e-01,  8.02616549e-02,  9.16945169e-01,
       -8.89499333e-01,  7.54751311e-01,  3.22353879e-01, -5.17115937e-01,
       -2.94443829e-01, -6.47316228e-01, -5.19127286e-01]), 'σ_w_FOURNOI': array([ 9.5489369 , 10.77977131,  9.4538371 ,  9.33103961, 10.69335592,
        9.3333864 ,  9.27642861, 10.69094443, 10.21996653,  9.40913956,
       10.54219502,  9.05731878,  9.39658146,  9.77360533, 10.94081895,
       10.84177394, 10.09743025, 10.63296598, 10.51407624, 10.86785604,
       10.50029705, 10.51386372, 10.20350355, 10.57265985, 10.62788594,
       10.00872969,  9.19832164,  9.07303835,  9.41240352, 10.36151362,
        9.89984589, 10.94599143,  9.0420823 , 10.4690349 , 10.6610043 ,
       10.18399225,  9.14191398,  9.39271576, 10.68431681, 10.89876478,
       10.27226484,  9.91792894, 10.39164921,  9.69105533, 10.26322218,
       10.80251152,  9.60909037,  9.71315524]), 'σ_b_FOURNOI': array(-0.94781615), 'η_FOURNOI': array(0.80100611), 'f_FOURNOI_rotated_': array([-3.97228937e-01,  7.99523421e-02, -4.70200757e-01,  6.24252281e-01,
        3.61592850e-01,  6.92856896e-01, -9.58915052e-01, -1.50817891e-01,
       -8.23119959e-01,  2.01106346e-01,  5.88014219e-01, -8.78096956e-01,
       -2.66865559e-01, -1.97475308e-01, -5.74118623e-01,  7.45968951e-01,
       -5.86100627e-01, -3.30842680e-03, -1.73384394e-01, -3.06457049e-01,
       -9.67728365e-02,  6.41616905e-02, -9.02874015e-02, -2.83434487e-01,
        3.39228918e-01, -5.99231498e-02,  5.72319557e-01, -5.27547541e-01,
       -9.78259824e-01,  8.82324783e-01,  3.98367563e-01, -8.23504103e-01,
        2.29034833e-01,  5.08419355e-02, -4.68958866e-01,  4.30606611e-01,
       -7.82844705e-01, -7.24444332e-01,  4.15036651e-01,  4.88033805e-02,
       -9.89303817e-01, -1.95089676e-01, -9.37707220e-01,  3.41806124e-01,
        7.29085892e-01, -8.61538823e-01, -7.46261517e-01, -1.29583481e-01,
       -5.27773004e-01, -8.99770827e-01,  9.83137845e-01,  9.73918499e-02,
       -1.16046095e-01, -5.32323658e-01, -5.81883472e-01, -6.05581983e-01,
       -1.88182382e-01,  9.64779217e-01, -9.72327761e-01, -6.54042899e-01,
        8.51160442e-01,  5.33781483e-01, -2.40535622e-01,  9.23324376e-01,
       -3.69347807e-02, -1.69268385e-01, -8.34101390e-01, -3.43568300e-01,
       -2.27356770e-01, -9.09077343e-01,  2.55231860e-01,  3.41537746e-01,
       -9.11834903e-02,  1.01098700e-01, -5.17521455e-01,  9.38554298e-01,
       -9.94700004e-01,  8.39109963e-01,  5.58946607e-01,  6.02651877e-01,
        4.24727475e-03,  1.04099801e-01, -8.91609269e-02,  1.44787449e-01,
        6.67546964e-01, -9.07588118e-01, -3.42350003e-01,  5.46606747e-01,
        6.71862755e-01, -3.08514111e-01, -1.82284976e-01, -3.67059330e-01,
       -5.95956838e-01, -6.37005066e-01, -6.78453699e-01,  6.69278877e-01,
       -1.06206787e-01,  4.94111905e-01, -4.61558727e-01,  8.88696078e-01,
        2.77218747e-01,  1.67444080e-01, -3.02989870e-01, -8.34514364e-01,
        1.96356776e-01, -1.12623121e-01,  8.70064375e-01, -2.58101743e-01,
        4.89509938e-01, -3.14376592e-01, -7.02122601e-01,  9.30525996e-01,
       -5.35555340e-01,  3.16542107e-02, -6.57610632e-01, -4.88674770e-01,
       -6.96602203e-01, -8.95303894e-01, -1.18345528e-01,  7.16415312e-02,
        6.22768291e-01, -3.72859682e-01,  1.82658418e-01, -4.97489143e-01,
       -5.02119858e-01,  9.72305385e-01,  3.16558983e-01, -3.15023871e-01,
       -3.10047409e-01,  7.78602658e-01, -7.30367415e-02, -3.71267674e-01,
       -6.24387174e-01, -1.97703595e-01, -8.56258572e-01, -5.82475964e-01,
       -6.29914619e-01, -1.97560453e-01,  2.93745836e-01,  6.58612732e-01,
        6.95329386e-01,  4.36179228e-01,  8.13879030e-01, -5.32242436e-01,
        8.75491174e-01,  8.40942545e-02,  8.35713340e-01,  8.31605981e-01,
       -1.28009639e-01,  4.33529926e-01,  8.24683705e-01,  5.76037961e-01,
        2.56245525e-01,  9.07778175e-01, -5.61672803e-01,  7.39864576e-01,
        3.38142187e-01, -4.56562613e-01, -4.68771588e-01, -9.57035613e-01,
       -3.19151709e-01,  7.66949662e-01,  2.46579513e-01, -4.69592162e-01,
       -9.77265599e-01, -4.86177570e-01, -1.00060151e-01,  9.87048662e-01,
       -6.88439298e-01, -6.55885369e-01, -5.25664120e-01,  5.88978172e-01,
       -7.22709171e-01, -7.47925633e-01, -9.72897255e-01,  8.80817801e-01,
       -7.98941189e-01, -9.73064781e-01, -7.62033765e-01, -2.65681145e-01,
        1.24386781e-01,  4.02619668e-01,  1.91952726e-02,  4.80242473e-01,
       -1.37141028e-01,  8.25066721e-01,  3.37083448e-01,  5.64565201e-01,
        3.01520036e-02, -6.10088300e-01, -9.50393783e-01, -1.02512354e-01,
       -6.75755337e-01,  1.92606977e-01, -3.60347201e-01, -1.87732101e-01,
        7.57165960e-01, -1.81846596e-01, -7.74718143e-01,  3.67677040e-01,
        2.54096073e-01,  7.61220605e-01, -3.30416686e-01,  1.97579628e-02,
        9.54803628e-01,  1.31272760e-01,  6.16910199e-01, -9.95016123e-01,
       -6.01827247e-01, -3.47328919e-01,  9.79029919e-01, -3.13571725e-01,
        3.92727869e-01,  3.12021804e-01, -7.36271937e-01, -7.10835398e-01,
        6.70706507e-01, -4.90674713e-01, -4.62314943e-01,  5.86700173e-01,
       -9.14635643e-01, -9.53558568e-02, -6.36862992e-02, -2.47978998e-02,
       -1.33616489e-01,  6.32515186e-01,  9.98222804e-01,  8.68327703e-01,
       -5.99692454e-01, -2.31811233e-01,  5.75533730e-01,  1.41536067e-01,
       -6.05036790e-02, -4.08517779e-01,  4.65523101e-01,  1.05807070e-01,
       -1.30159280e-02, -9.70459895e-01, -4.09338250e-01,  7.86467614e-01,
       -1.64630750e-01, -3.81857397e-01,  2.49573405e-01, -7.57738539e-01,
       -5.69775840e-01, -9.61198205e-01, -2.67736211e-01,  3.20451286e-01,
        3.22788042e-02, -7.68116932e-01, -3.63163822e-01, -8.05479590e-01,
        2.06979889e-01, -8.75547518e-01, -4.60068120e-01, -3.91805213e-01,
        1.11814471e-04,  1.34376033e-01,  7.70014985e-01, -6.95498087e-01,
       -8.54759898e-01,  4.23335232e-01, -3.84724603e-01, -6.09505708e-01,
       -6.50201798e-02,  1.40525310e-01, -5.53722031e-01,  5.93510202e-01,
        6.12454658e-01, -2.24638582e-01, -8.83705876e-01, -3.09590824e-01,
        3.64193921e-01, -2.98688047e-01,  4.93076324e-01,  1.29577480e-01,
        3.79579314e-01, -4.08191518e-01,  3.56835673e-01, -4.95195152e-01,
       -5.02218297e-01, -9.66577645e-01, -4.09654750e-01,  6.49611651e-01,
        2.21050389e-01, -1.84653904e-01,  9.91445116e-01,  9.32620868e-01,
       -3.38506471e-01, -1.29050688e-01, -4.00248226e-01,  8.68893012e-01,
       -3.32464066e-01, -6.20937835e-01,  5.98108100e-01, -8.28208321e-01,
        9.97516404e-01, -3.02203544e-01, -8.35008880e-01, -5.01918767e-01,
        3.23782771e-01, -5.03160902e-01, -1.38063505e-02, -1.94168328e-01,
       -1.04008104e-01,  6.07082622e-01, -3.86213935e-01,  7.35792215e-01,
       -1.15815896e-01, -2.46873345e-01,  2.84643484e-02, -4.34099003e-01,
       -6.39990806e-01,  6.60487970e-01,  7.74316140e-01,  1.21762503e-01,
        8.56800711e-01,  2.04393917e-01,  6.16646853e-01,  1.38388212e-01,
        3.80496733e-01, -6.03769752e-01,  4.88071554e-01, -9.08363565e-01,
        4.34786964e-01,  7.87288634e-01,  3.38746971e-01]), 'σ_w_CHIOS': array([10.17120168, 10.1782227 , 10.73685522,  9.76525574, 10.4665285 ,
       10.01731353,  9.94830092, 10.31727308, 10.14967229,  9.65455023,
       10.26413   ,  9.50327049,  9.31747371,  9.00053912, 10.44134165,
       10.20261506,  9.99836577,  9.63783051,  9.11248378,  9.46232361,
       10.9423304 , 10.99508382, 10.55153047]), 'σ_b_CHIOS': array(-0.96337696), 'η_CHIOS': array(1.88378994), 'f_CHIOS_rotated_': array([ 9.23616006e-01,  7.77901917e-01, -4.21498370e-01,  9.79736189e-01,
        7.42142313e-01, -1.18727426e-01,  4.53994594e-01, -4.20801284e-01,
        7.43915009e-03,  3.55987828e-01, -1.84489393e-01,  7.84232474e-01,
       -1.46664866e-02,  5.46611866e-01,  5.07858412e-01, -9.47863605e-01,
        1.35310451e-01, -4.26369620e-01,  3.19788497e-01,  3.59855091e-01,
        2.64946338e-01,  7.84376137e-01, -7.21563615e-01, -5.11396794e-01,
        2.65216016e-01, -1.95440060e-01, -4.21435559e-01, -7.25225528e-01,
        3.33736999e-01,  3.92834118e-01,  1.17783636e-01, -5.15275119e-01,
        9.32732896e-01, -7.42402903e-01, -7.20636743e-01, -8.57224323e-01,
        1.80174136e-01,  7.03799368e-01, -8.16897522e-01,  5.96079844e-01,
       -3.02083227e-01, -6.08818905e-01, -5.36624200e-01,  1.25554575e-01,
       -4.43466349e-01,  4.96767912e-03,  5.46370781e-01, -7.44642655e-01,
        7.40036971e-01, -3.88127333e-01, -9.09175491e-02, -6.92525072e-01,
       -6.20290752e-02,  1.14593625e-01, -8.28744441e-01, -5.21146361e-01,
       -2.19191814e-01,  4.89850783e-01,  4.89657254e-01, -9.82109061e-01,
       -7.62067942e-01, -2.32210762e-01,  4.52798298e-01,  4.47310715e-01,
       -6.92318686e-01, -3.64211009e-01, -5.52521943e-01, -6.35685293e-01,
       -6.06508579e-01, -2.44457398e-01, -9.26360372e-01,  6.98040702e-01,
       -8.87170740e-01,  5.06090568e-01,  3.01782351e-02, -5.93163557e-01,
       -3.64999150e-01,  5.98768005e-02,  9.42416284e-01, -3.27173157e-01,
       -8.61732361e-01,  4.19540875e-01, -7.73493704e-01,  4.30363584e-01,
        7.19158268e-01,  9.54380065e-01, -7.74835117e-01,  9.53104576e-01,
       -2.02408242e-01,  9.87064157e-01, -6.49260025e-01,  3.61029478e-01,
       -3.90492149e-01, -4.58991182e-01, -9.95726837e-01, -6.48892968e-01,
       -7.46897004e-01, -2.35958480e-01,  4.26610381e-01,  5.21497252e-01,
        5.27145888e-01, -9.46341703e-01,  1.25268829e-01, -1.79295674e-01,
       -3.01004685e-01,  3.51001142e-02,  5.99601758e-01, -1.03226080e-01,
       -6.46101008e-01,  6.03664602e-02, -6.14892373e-02,  4.69585921e-01,
       -9.53703976e-01,  1.10466054e-01,  4.53452267e-01, -2.16886702e-01,
       -9.54602415e-01,  2.01389599e-01, -6.45606051e-01, -3.43417297e-02,
       -3.24020292e-01, -7.17803823e-01,  5.21611286e-01,  1.29182622e-01,
        4.39019845e-01, -4.55945712e-01,  5.18581296e-01,  5.55395318e-01,
       -5.84372785e-01, -5.75107275e-01, -8.82776953e-01, -6.31164012e-01,
       -5.10222818e-01,  8.39033418e-01, -7.12910091e-01,  3.10943419e-01,
        3.32839155e-01,  5.13707675e-01,  3.74061581e-01, -3.73405589e-01,
       -2.46819474e-01,  1.41563951e-01, -1.93409010e-02, -4.83414629e-01,
       -2.00544916e-01,  1.69270297e-01, -8.18609697e-01,  6.18390635e-01,
       -9.90606682e-01, -2.79144405e-01, -4.99728449e-02, -4.49627165e-01,
       -8.34815413e-01, -7.61258590e-01,  8.17354111e-01, -1.67018677e-01,
        1.19302866e-02,  9.63026417e-01, -1.49358852e-02, -7.41632448e-01,
       -6.34982346e-01,  8.17666343e-01,  1.48252156e-01,  8.97417336e-01,
        7.34107463e-01,  7.97307374e-01, -9.93167524e-01, -5.86855322e-01,
        7.69551696e-01, -7.24265961e-01, -9.66041146e-01,  4.11554797e-01,
        9.59580391e-01,  4.40021544e-01, -8.16263896e-01, -9.76384710e-01,
        3.04327377e-01,  9.00838395e-01,  3.99404994e-02,  6.53495241e-01,
       -9.23477101e-01,  7.44699622e-02, -1.13543422e-01,  3.42359134e-01,
        2.16669551e-01, -2.57469320e-01,  4.50825908e-01,  4.12335612e-01,
       -9.97725847e-01, -2.85264916e-01,  1.41502106e-01, -9.32997232e-01,
       -3.59810035e-01,  9.64132366e-01,  2.62705861e-01, -9.84762331e-01,
       -5.95309175e-02, -8.08496452e-01,  7.39996507e-01, -7.88039551e-01,
       -9.21900651e-01,  2.63767932e-01, -9.61649896e-01,  8.90901657e-01,
        4.31572140e-02, -2.01120261e-02, -9.52584035e-01, -7.84213998e-01,
        8.56496365e-01,  8.96517097e-01,  4.57219594e-01,  7.03872473e-01,
        4.32698451e-01, -1.10951382e-01, -7.36517657e-01, -8.51071007e-01,
        5.00539830e-01,  1.87584275e-02, -5.01886368e-02])}

Initial evaluation results:
{'σ_w_LESVOS': -58.1, 'σ_b_LESVOS': -2511.21, 'η_LESVOS': -1936.34, 'f_LESVOS_rotated_': -435.36, 'σ_w_IKARIA': -61.48, 'σ_b_IKARIA': -1105.79, 'η_IKARIA': -1271.82, 'f_IKARIA_rotated_': -432.46, 'σ_w_SAMOS': -58.62, 'σ_b_SAMOS': -35.69, 'η_SAMOS': -3555.9, 'f_SAMOS_rotated_': -431.25, 'σ_w_FOURNOI': -58.91, 'σ_b_FOURNOI': -4488.09, 'η_FOURNOI': -194.31, 'f_FOURNOI_rotated_': -431.28, 'σ_w_CHIOS': -55.87, 'σ_b_CHIOS': -4636.79, 'η_CHIOS': -3901.74, 'f_CHIOS_rotated_': -434.5, 'y': -inf}

My working theory is that all the GPs are outputting very large negative values, which the sigmoid sends to 0, resulting in a division by zero. I'm not familiar with the internals of MCMC and don't understand where these gigantic values come from or why they seem to disagree with the values in the large vectors above.
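
One way to check that reading: the numbers under "Initial evaluation results" appear to be the log-densities of each variable at the starting point, not GP outputs. Recomputing a few of them from the priors in the second model (σ_b ~ Normal(0, 0.01), η ~ Normal(1, 0.01)) and the starting values printed above reproduces the reported numbers. The starting values look like the prior means plus a jitter of up to ±1 (which, as far as I know, is what the jitter+adapt_diag initialisation adds), and such a jitter lands far out in priors with sigma=0.01. These terms are large but finite; the only −inf entry is `y`. The `f_*_rotated_` entries are, to my understanding, the standard-normal whitened variables that `gp.Latent` uses internally, not the GP values themselves.

```python
import math

def normal_logpdf(x, mu, sigma):
    # log N(x | mu, sigma) = -0.5*log(2*pi) - log(sigma) - 0.5*((x - mu)/sigma)**2
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2

# starting values copied from the error message above, with the reported log-densities
checks = [
    ("σ_b_LESVOS", 0.70921005, 0.0, -2511.21),
    ("η_LESVOS", 1.62290096, 1.0, -1936.34),
    ("σ_b_SAMOS", 0.08874031, 0.0, -35.69),
    ("η_FOURNOI", 0.80100611, 1.0, -194.31),
]
for name, x, mu, reported in checks:
    print(name, round(normal_logpdf(x, mu, 0.01), 2), reported)
```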

P.S. The MultiLayerPerceptronKernel is pretty much copy-pasted from GPy's implementation. The categorical labels are encoded as {1, 2, …} instead of {0, 1, …}; I'm not sure whether that's relevant. The outputs above were arbitrarily cropped to meet the character limit.
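
On the label encoding: Categorical's support is 0..K−1, where K is the length of the last axis of p, so with five classes encoded as 1..5 every observation labelled 5 falls outside the support and gets log-probability −inf. That alone would be enough to make `y` evaluate to −inf. The function below is a hypothetical NumPy mimic of that behaviour, not PyMC's implementation; in the model above, the corresponding change would presumably be `observed=Y_train.values[:, 0] - 1`, assuming the labels really are 1..5.

```python
import numpy as np

def categorical_logp(labels, p):
    # hypothetical mimic: sum of log p[i, y_i] inside the support 0..K-1, -inf outside it
    K = p.shape[-1]
    out = np.full(labels.shape, -np.inf)
    ok = (labels >= 0) & (labels <= K - 1)
    rows = np.arange(len(labels))
    out[ok] = np.log(p[rows[ok], labels[ok]])
    return out.sum()

p = np.full((4, 5), 0.2)             # 4 observations, 5 equally likely classes
labels_1_based = np.array([1, 5, 3, 2])

ll_raw = categorical_logp(labels_1_based, p)          # -inf: label 5 is outside 0..4
ll_shifted = categorical_logp(labels_1_based - 1, p)  # finite once labels are shifted to 0..4
```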

For some inspiration on the GP side of things, you might take a look at @DanhPhan's in-progress example: "Multi-ouput Gaussian Process Coregion models with Hadamard product" by danhphan, Pull Request #454 in pymc-devs/pymc-examples on GitHub.
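
The coregionalization idea behind that example can be summarised in a few lines: stack all outputs into one long vector, tag each entry with its output index, and take the covariance to be an input kernel multiplied elementwise (the Hadamard product) by an output-covariance matrix B = W Wᵀ + diag(κ). The NumPy sketch below builds that covariance, assuming an RBF input kernel and arbitrary W and κ; it is not the code from the PR. In PyMC, `pymc.gp.cov.Coregion` (which takes W and kappa) appears to be the built-in counterpart of B.

```python
import numpy as np

def rbf(x1, x2, ls=1.0):
    # squared-exponential kernel on 1-D inputs
    d = x1[:, None] - x2[None, :]
    return np.exp(-0.5 * (d / ls) ** 2)

rng = np.random.default_rng(3)
n_outputs, rank = 3, 1
W = rng.normal(size=(n_outputs, rank))
kappa = np.array([0.5, 0.2, 0.1])
B = W @ W.T + np.diag(kappa)           # output (task) covariance, positive definite

x = np.array([0.0, 0.5, 1.0, 0.0, 1.5])  # inputs of all observations, stacked
task = np.array([0, 0, 1, 2, 2])         # which output each observation belongs to

# Hadamard product: K[(x, i), (x', j)] = k(x, x') * B[i, j]
K = rbf(x, x) * B[task][:, task]
```

A draw from N(0, K) gives one function per output, correlated through B, which is one concrete sense in which a single GP can be vector-valued.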
