Is it possible to build a GP model in PyMC such that a latent GP produces vector outputs rather than scalars? I'm interested in this because I often have to model several target types (e.g. multiclass, binary and multilabel classification, and regression), and I would like a single model for all of them rather than many small ones. My idea is that if a GP can output vectors, these can be turned into probability distributions with a softmax (for multiclass), into single probabilities with sigmoids, and so on, similar to what you can do with a neural net.
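A minimal NumPy sketch of the link-function idea, assuming the K latent outputs at N inputs are arranged as an (N, K) array with one row per input and one column per output. The random matrix `F` is only a stand-in for draws from K latent GPs, and `softmax`/`sigmoid` are written out by hand here rather than taken from PyMC.

```python
import numpy as np

def softmax(f, axis=-1):
    # subtracting the row max improves numerical stability without changing the result
    z = f - f.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(f):
    return 1.0 / (1.0 + np.exp(-f))

rng = np.random.default_rng(0)
N, K = 4, 3
F = rng.normal(size=(N, K))  # stand-in for K latent GP values at N inputs

p_multiclass = softmax(F)    # each row sums to 1: a distribution over K classes
p_multilabel = sigmoid(F)    # each entry in (0, 1): K independent binary probabilities
```

For multiclass targets the softmax couples the K outputs in each row; for binary or multilabel targets each column gets its own sigmoid and no normalization.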
After much experimentation, I have a model that works (at least it samples):
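```python
import numpy as np
import pymc
import aesara.tensor as at

# X_train, Y_train, core_mappings and MultiLayerPerceptronKernel are defined elsewhere
with pymc.Model() as categorical_simple_model:
    factors = 5
    N, M = X_train.shape
    gps = []
    _f = []
    ϵ = 1e-9
    μ = pymc.gp.mean.Constant(0)
    # note: σ_c ~ Normal can take negative values, which would make this covariance invalid
    σ_c = pymc.Normal('σ_c', mu=.7, sigma=.2)
    σ_w = pymc.MvNormal('σ_w', mu=np.ones(M) * 3.1, cov=σ_c * np.eye(M), shape=M)
    σ_b = pymc.Normal('σ_b', mu=3.1, sigma=.7)
    _η = pymc.Exponential('_η', lam=2.0)
    η = pymc.Deterministic('η', _η + ϵ)
    # a single kernel shared by all the latent GPs
    κ = MultiLayerPerceptronKernel(M, variance=η**2, bias_variance=σ_b**2, weight_variance=σ_w**2)
    # one independent latent GP per class (location)
    for factor, loc in core_mappings[('origin', 'location', 'loc')].items():
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        gps.append(gp)
        f = gp.prior(f'f_{loc}', X=X_train.values)
        _f.append(f)
    # stacking gives shape (n_classes, N); the normalization below runs over axis 0 (classes)
    f = pymc.Deterministic('f', at.math.sigmoid(at.stack(_f)))
    p = pymc.Deterministic('p', f / f.sum(axis=0))
    # note: Categorical treats the last axis of p as the categories
    y = pymc.Categorical('y', p=p, observed=Y_train.values, shape=Y_train.shape)
```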
```python
with categorical_simple_model:
    fs = []
    # predictive (conditional) distribution of each latent GP at the test inputs
    for (factor, loc), gp in zip(core_mappings[('origin', 'location', 'loc')].items(), gps):
        f = gp.conditional(f'_f_star_{loc}1', Xnew=X_test.values)
        fs.append(f)
    f_star = pymc.Deterministic('f_star', at.stack(fs))
    p_star = pymc.Deterministic('p_star', pymc.math.sigmoid(f_star) / pymc.math.sigmoid(f_star).sum(axis=0))
```
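A NumPy sketch of the shape question in the model above, with random stand-in values instead of GP draws. Stacking K latent vectors of length N along a new leading axis (the default for `at.stack` and `np.stack`) gives shape (K, N), and normalizing over axis 0 makes each column sum to 1. As far as I understand, PyMC's Categorical reads the last axis of `p` as the categories, so per-observation probabilities need shape (N, K), which is what a transpose provides.

```python
import numpy as np

rng = np.random.default_rng(1)
K, N = 5, 7
fs = [rng.normal(size=N) for _ in range(K)]  # one latent vector per class, as in the loop over GPs

stacked = np.stack(fs)            # shape (K, N): classes on axis 0
s = 1.0 / (1.0 + np.exp(-stacked))  # sigmoid
p_kn = s / s.sum(axis=0)          # each column (one observation) sums to 1

p_nk = p_kn.T                     # shape (N, K): one row of class probabilities per observation
```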
This seems to produce plausible-looking outputs. However, I've no idea whether it's the model I'm trying to formulate. I'm particularly confused about `pymc.Categorical`. Without the shape parameter it tends to throw shape errors that are hard to interpret. Furthermore, the model seems to work OK both when Y is one-hot encoded and when it is just a single column of integers. Can someone clarify what exactly Categorical expects to see?
Categorical expects integer class labels 0, 1, 2, …, K-1, one per observation. When you one-hot encode, it assumes it is seeing many independent observations, mostly zeros with an occasional one. That's not what you want. For one-hot encoded data you can use Multinomial (with n=1) instead.
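A small NumPy check of the point above; it computes the log-likelihoods by hand rather than calling PyMC, and the probabilities and labels are made up. Integer labels scored Categorical-style give the same log-likelihood as the equivalent one-hot rows scored as a Multinomial with n=1. In PyMC terms this corresponds to `pymc.Categorical('y', p=p, observed=labels)` versus `pymc.Multinomial('y', n=1, p=p, observed=onehot)`, both with `p` of shape (N, K).

```python
import numpy as np

p = np.array([[0.7, 0.2, 0.1],
              [0.1, 0.3, 0.6],
              [0.2, 0.5, 0.3]])      # (N, K) class probabilities, one row per observation
labels = np.array([0, 2, 1])         # integer classes in {0, ..., K-1}: what Categorical expects
onehot = np.eye(p.shape[1])[labels]  # the same data one-hot encoded: what Multinomial(n=1) expects

# Categorical log-likelihood: log p[i, y_i] summed over observations
ll_categorical = np.log(p[np.arange(len(labels)), labels]).sum()

# Multinomial with n=1: sum_k y_ik * log p_ik (the multinomial coefficient is 1 when n=1)
ll_multinomial = (onehot * np.log(p)).sum()

# converting one-hot rows back to integer labels
labels_back = onehot.argmax(axis=1)
```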
Thanks for the reply. After more tinkering, I’ve started to run into sampling errors:
```python
with pymc.Model() as categorical_simple_model:
    factors = 5
    N, M = X_train.shape
    gps = []
    fs = []
    for factor, loc in core_mappings[('origin', 'location', 'loc')].items():
        μ = pymc.gp.mean.Constant(0)
        # separate kernel hyperparameters for each location's GP
        σ_w = pymc.Normal(f'σ_w_{loc}', mu=10.0, sigma=.7, shape=M)
        σ_b = pymc.Normal(f'σ_b_{loc}', mu=.0, sigma=.01)
        η = pymc.Normal(f'η_{loc}', mu=1.0, sigma=0.01)
        κ = MultiLayerPerceptronKernel(M, variance=η**2, bias_variance=σ_b**2, weight_variance=σ_w**2)
        gp = pymc.gp.Latent(mean_func=μ, cov_func=κ)
        gps.append(gp)
        f = gp.prior(f'f_{loc}', X=X_train.values)
        fs.append(f)
    # shape (N, factors): one row per observation, one column per class
    _f = pymc.Deterministic('_f', at.stack(fs).T)
    f = pymc.Deterministic('f', at.math.sigmoid(_f))
    _s = pymc.Deterministic('s_', f.sum(axis=1)[:, None])
    # fallback row [1, 0, ..., 0], used wherever a row sum is exactly zero
    dummy = np.tile(np.array([1.0] + [.0] * (factors - 1))[None, :], (N, 1))
    p = pymc.Deterministic('p', at.switch(at.eq(_s, at.zeros_like(_s)), dummy, f / _s))
    y = pymc.Categorical('y', p=p, observed=Y_train.values[:, 0], shape=Y_train.shape[0])
```
(Note: my priors are all in an incoherent state after endless trial and error.) Checking the prior outputs of each GP yields somewhat sensible results.
Yet I get hit with a generic sampling error:
```
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
---------------------------------------------------------------------------
SamplingError                             Traceback (most recent call last)
Cell In [171], line 2
      1 with categorical_simple_model:
----> 2     idata = pymc.sample(draws= 1000, tune=1500, cores=2, chains=2, random_seed=44)

File /media/alexander-fyrogenis/Elements/Διδακτορικό/Olive Oil/notebooks/venv/lib/python3.10/site-packages/pymc/sampling.py:561, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, idata_kwargs, mp_ctx, **kwargs)
    559 # One final check that shapes and logps at the starting points are okay.
    560 for ip in initial_points:
--> 561     model.check_start_vals(ip)
    562     _check_start_shape(model, ip)
    564 sample_args = {
    565     "draws": draws,
    566     "step": step,
   (...)
    575     "discard_tuned_samples": discard_tuned_samples,
    576 }

File /media/alexander-fyrogenis/Elements/Διδακτορικό/Olive Oil/notebooks/venv/lib/python3.10/site-packages/pymc/model.py:1801, in Model.check_start_vals(self, start)
   1798 initial_eval = self.point_logps(point=elem)
   1800 if not all(np.isfinite(v) for v in initial_eval.values()):
-> 1801     raise SamplingError(
   1802         "Initial evaluation of model at starting point failed!\n"
   1803         f"Starting values:\n{elem}\n\n"
   1804         f"Initial evaluation results:\n{initial_eval}"
   1805     )

SamplingError: Initial evaluation of model at starting point failed!
```
Starting values:
{'σ_w_LESVOS': array([ 9.15389669, 10.85270312, 9.44986538, 10.65616583, 9.86568392,
9.55440129, 10.93130078, 10.940588 , 9.82515486, 9.52135951,
10.74580886, 9.58950133, 10.43082604, 9.17414448, 9.22438013,
10.55692819, 10.01655802, 10.58205922, 9.02345872, 9.74216879,
10.46810645, 10.07352035, 9.28861348, 9.60086126, 10.56862279,
10.70950238, 9.10977794, 9.77351756, 9.3025502 , 10.0548006 ,
10.19049653, 9.08146313, 9.76030827, 9.46656961, 10.12176842,
9.34737303, 10.27021216, 10.10108567, 9.85696949, 9.14323312,
10.07580724, 10.14372062, 10.14633824, 9.07256442, 9.36337599,
10.7717392 , 9.37710522, 10.14230229, 9.75886691, 9.68385258,
10.73053422, 9.87517513, 9.25581399, 9.91026346, 9.27934233,
9.62345242, 10.11493438, 9.08867222, 10.82228144, 9.02693697,
10.50623103, 9.08846799, 9.22064129]), 'σ_b_LESVOS': array(0.70921005), 'η_LESVOS': array(1.62290096), 'f_LESVOS_rotated_': array([ 0.48606202, -0.05750464, 0.39775823, -0.73506174, 0.18924769,
-0.79205822, -0.60330083, -0.91940572, 0.30999119, 0.74735712,
0.56623856, 0.83417569, -0.05784597, -0.21402646, -0.96875847,
-0.92535721, 0.84683007, -0.82805718, 0.19342466, 0.78279287,
-0.06573217, -0.0984265 , -0.03966686, 0.69550449, 0.49175055,
-0.62112312, -0.29253861, 0.51988189, -0.48311947, -0.05685895,
-0.87701023, -0.49639025, 0.47754525, -0.85901368, 0.76777466,
0.60161249, 0.11460279, -0.4222967 , -0.3333136 , -0.96208543,
-0.75479352, -0.67267036, -0.31673458, 0.2832573 , 0.13782709,
0.85818576, -0.33240687, -0.06605224, -0.89846575, -0.41776298,
0.36399779, 0.12166975, -0.55771186, 0.10527268, 0.60839088,
0.50254195, -0.05602044, -0.49496405, 0.40177229, 0.77954297,
0.43582885, -0.16038181, -0.04358258, 0.38788849, -0.4015904 ,
0.96939575, -0.52990232, 0.81265588, 0.06512055, 0.90973364,
-0.46138376, -0.85866802, -0.02531684, 0.30186283, -0.26810522,
-0.44478761, 0.18333775, -0.63726822, 0.6975101 , 0.65135364,
0.9892628 , -0.0553152 , 0.13955964, 0.74278527, -0.34539583,
0.61756714, -0.63967225, 0.5944891 , 0.31933709, 0.22292139,
0.18771531, 0.13105963, 0.34021337, -0.86983251, -0.22363494,
-0.41914142, -0.33309724, -0.24702408, 0.45128761, -0.47457353,
0.38379287, -0.75904167, -0.77078648, 0.50091378, -0.29534457,
-0.385392 , 0.91488685, -0.02172829, -0.57859994, -0.32275252,
0.23209637, -0.4291677 , 0.57120151, 0.90852484, -0.51146913,
0.82481469, -0.79504096, -0.37723923, 0.10826941, -0.12484881,
-0.11216902, -0.58125022, 0.06556108, -0.72667636, 0.83587009,
0.29247122, -0.12196493, -0.71611837, 0.86396981, -0.35988189,
-0.12384023, 0.45293806, -0.77279569, -0.9910372 , 0.20212587,
-0.92086343, -0.22572825, -0.12506967, 0.8051224 , 0.03835518,
0.87467874, -0.59992519, -0.10090282, -0.9048303 , 0.83248862,
0.83882907, 0.40730734, -0.92124882, -0.42170137, -0.08626875,
0.77283201, -0.82084868, -0.14841494, 0.45029237, 0.27661259,
0.44060156, -0.33672568, -0.46138687, -0.88363887, -0.07689968,
0.31085364, -0.98858449, -0.31700121, 0.21144317, 0.06261854,
0.36236879, 0.54185604, 0.23242172, -0.26541738, 0.73528826,
0.36156801, -0.36288286, -0.17884688, -0.41012916, 0.60446187,
0.67412714, 0.47165996, 0.12175196, 0.64446294, -0.07718248,
-0.10220434, -0.57194466, 0.67438814, -0.36951548, 0.97305376,
0.70666903, -0.67239419, -0.21630591, -0.13992048, 0.6942951 ,
0.9492795 , 0.60840455, -0.24632258, -0.77791368, -0.42730229,
-0.1190319 , 0.93997779, 0.40344643, -0.9107019 , 0.12310798,
0.67055378, 0.78422849, 0.72392885, 0.9053756 , 0.20930128,
0.11611487, 0.94173718, 0.31018225, -0.26292217, -0.20211832,
0.68700175, 0.34338105, 0.09102248, 0.16590437, 0.39013556,
-0.60289715, 0.85420581, 0.78423054, -0.77061865, -0.53076217,
0.50791463, -0.74316801, 0.44564758, -0.60896336, -0.50457433,
0.50039073, -0.56013186, -0.41561322, 0.8720142 , -0.50067418,
0.76287962, 0.68348549, -0.86013304, 0.95458716, 0.68242175,
0.09745051, -0.04145826, 0.75769 , -0.19702269, -0.0338294 ,
0.71996494, -0.4837162 , -0.09077571, 0.94018659, 0.3729279 ,
-0.36230459, 0.77445179, -0.72677449, 0.48078482, 0.50581907,
0.99125788, 0.14528036, -0.3164143 , -0.35214186, 0.05464637,
0.33523223, -0.49240213, -0.14493891, 0.58246109, -0.34688059,
-0.94206736, 0.89684004, -0.7421782 , -0.81153334, 0.02151323,
0.9434007 , -0.74797198, -0.98670408, -0.15924378, -0.75799781,
0.14191314, -0.75861898, 0.73981023, -0.02256103, 0.93160883,
0.87281622, 0.63372847, -0.75374709, 0.86407595, -0.83897426,
-0.65116815, -0.13878746, 0.68521175, 0.27816359, 0.64666496,
-0.68166694, 0.19674112, -0.75250003, 0.84833841, -0.05258209,
-0.50343253, 0.18088408, -0.87896426, 0.57941075, -0.78760233,
-0.20115428, -0.12380109, -0.40550318, -0.63408434, 0.79398135,
-0.73878101, 0.03437339, 0.18816869, -0.24104893, -0.91192019,
-0.80004882, -0.87585259, 0.538598 , 0.85966293, 0.53485427,
-0.79257814, -0.54344126, -0.07192188, 0.69661188, -0.99336402,
0.65856514, -0.9859744 , 0.71973767, 0.16406195, -0.64105502,
-0.73872274, -0.89386223, 0.56795237, 0.46226757, -0.79521809,
0.38258295, -0.47987754, -0.91335276, 0.30390317, 0.01575346,
-0.06892708, -0.73739227, -0.94395353, -0.78706455, -0.73927979,
0.21979897, 0.83908211, 0.36305091, 0.36759142, 0.17915262,
-0.75946344, 0.56188368, 0.82582536, -0.86403284, 0.70741271,
0.55239022, -0.0237508 , -0.61855772, 0.91165122, -0.70771664,
-0.78505166, -0.93856211, 0.16596274, 0.08685433, -0.35224752,
0.13862893, 0.1314268 , 0.73904479, 0.15068114, 0.06561571,
-0.67232765, -0.53999699, 0.32825998, -0.89981346, 0.13918157,
-0.02071166, -0.25270416, -0.47887448, -0.80040932, -0.52179009,
-0.93727722, 0.11938124, -0.60214144, 0.92809085, 0.15069965,
0.12023757, -0.4319954 , 0.95014544, 0.17716056, -0.83042521,
0.86325665, 0.75015591, -0.82928655, 0.81665911, 0.58666328,
0.53280148, 0.38688033, -0.27655464, 0.59784245, -0.87440181,
-0.81036219, -0.54573094, -0.31516507, -0.69884707, 0.70818205,
0.51771508, -0.84546186, 0.49186937, 0.1268696 ]), 'σ_w_IKARIA': array([ 9.21494362, 10.69765723, 10.52219506, 10.82648558, 9.62780153,
10.62089311, 10.67443963, 10.28101984, 9.94223576, 9.10331004,
10.51714201, 9.90448965, 10.94807566, 10.67459207, 9.18998324,
10.63577916, 9.1964881 , 10.44334554, 10.69654897, 9.5197523 ,
9.50610035, 10.42415076, 10.77374167, 9.17221711, 10.79809018,
9.35620927, 10.27234331, 10.89739256, 9.84993142, 10.72389463,
9.10432842, 9.84038286, 10.21474642]), 'σ_b_IKARIA': array(-0.47105749), 'η_IKARIA': array(1.50507627), 'f_IKARIA_rotated_': array([-0.27948203, -0.39037051, -0.80441018, -0.2028417 , 0.33666347,
-0.91366135, -0.11786481, 0.9936126 , -0.92285819, 0.44724691,
0.23186736, -0.28812537, 0.29880623, -0.06649264, 0.47457421,
0.93279294, -0.69575604, -0.03617461, 0.50512821, -0.6356363 ,
-0.3298101 , 0.62306306, -0.89120689, 0.68278601, 0.93174378,
0.95517806, -0.27287669, 0.63782505, 0.13242919, -0.88263241,
0.56662343, 0.86178375, -0.11706782, -0.94725375, 0.6196216 ,
0.50170609, 0.33706222, 0.25419795, -0.7354072 , 0.74212949,
0.91508001, -0.77096085, -0.35835153, -0.29613842, 0.93758883,
-0.33040874, -0.70027456, -0.26470037, 0.06104933, -0.97368939,
0.47830373, 0.68366075, 0.25691619, -0.55541296, -0.87332899,
0.21011257, -0.72990397, -0.48704953, -0.50603564, -0.97427106,
-0.63992956, 0.96362526, -0.32897696, -0.65462311, -0.86076531,
-0.21689051, 0.437032 , -0.90258268, -0.18465247, 0.65768317,
0.10505895, -0.53424249, 0.35707991, -0.03910111, -0.5782608 ,
-0.70724433, 0.27216229, -0.73573242, 0.2686824 , -0.19995845,
-0.11793359, -0.52450426, 0.24411859, 0.77099508, -0.63883249,
0.09972287, 0.28418921, -0.78116082, 0.30795318, 0.83271334,
0.11929135, -0.63090027, -0.63922105, 0.32525192, 0.9949957 ,
-0.52158884, -0.11099115, 0.78514752, 0.93714179, 0.65663547,
0.4551711 , -0.89977151, 0.51874976, 0.52851773, 0.63470914,
0.32690264, -0.10442083, -0.22657619, 0.47078002, 0.75417924,
-0.1769126 , -0.87948168, 0.63648271, -0.87532117, -0.78362362,
0.59663556, -0.8766123 , 0.331311 , 0.17902656, 0.35563398,
0.21281197, 0.01168888, 0.57534042, 0.29649029, 0.90763528,
0.89161291, 0.01573708, 0.12611317, 0.12109387, -0.15383683,
0.17758962, 0.65371869, -0.36597453, -0.20420019, 0.97426758,
0.13985259, -0.90608478, 0.67376478, 0.95558517, 0.05439007,
-0.11691677, -0.24941819, -0.84215802, -0.38252799, 0.16286306,
0.90463155, 0.72657282, 0.84332545, -0.34549113, 0.18340231,
-0.64138555, -0.13969498, -0.67951917, 0.26504536, 0.47107957,
0.31325995, -0.91330355, -0.77977864, -0.22834488, -0.82950776,
-0.43478197, 0.4317437 , -0.80898551, 0.58372512, 0.23792121,
0.53531073, -0.24918119, -0.64844689, -0.24058978, -0.47535967,
-0.3388392 , -0.04526641, -0.18171857, -0.20685575, -0.03598036,
-0.53220813, -0.35585681, 0.94456868, 0.11598018, 0.48441732,
0.75093342, 0.68660801, 0.14298974, 0.32119572, 0.47927278,
0.70346316, -0.11036186, -0.58045823, 0.53553899, 0.74019028,
-0.88170812, 0.14287552, -0.19237979, 0.01744791, 0.33012509,
0.31516136, 0.57154341, 0.63324059, 0.55171728, 0.43273322,
0.67060517, -0.91263772, 0.39502389, -0.73515489, 0.17602126,
-0.58788015, -0.46186564, -0.52755813, -0.83745151, 0.50575306,
0.22695921, 0.22509206, -0.97794513, 0.78102943, -0.39314543,
-0.48351958, 0.69888094, 0.35290789, -0.74973345, -0.10720196,
-0.21770319, -0.19948587, -0.4669718 , 0.09693138, 0.21184672,
0.54150039, 0.73848466, -0.33167776, -0.63046675, -0.48865088,
0.65588298, 0.30040299, 0.87153996, 0.35498385, 0.50127758,
0.74501066, -0.76349962, -0.65442908, -0.41294734, 0.44716831,
-0.55596078, 0.83651696, 0.5142481 , 0.46857172, -0.46044239,
0.28445636, 0.10714396, -0.11700016, -0.86218257, -0.04477057,
0.34355535, -0.14767439, -0.79204087, 0.64902172, 0.65239501,
0.14668265, -0.25806695, 0.22708576, 0.80371109, -0.1216073 ,
-0.11337834, 0.92453582, -0.98715673, -0.86063292, -0.75003766,
0.17585209, 0.78084521, -0.85198475, -0.94108158, 0.19619444,
-0.08831064, 0.59278515, -0.0244028 , 0.98702759, 0.93913542,
-0.33784074, -0.4574391 , -0.90798093, -0.70322187, 0.19524623,
-0.99871723, -0.50801984, -0.16170868, 0.85088972, -0.67442472,
0.65773729, -0.40098471, 0.09161588, 0.83786002, 0.54346601,
0.01736629, 0.10784845, 0.32900574, 0.15317762, 0.55516649,
0.61755485, -0.71300043, -0.58162011, -0.91731162, -0.14326466,
-0.94322453, 0.16225505, 0.15473652, -0.49417958, -0.34964539,
-0.21733528, -0.58379056, -0.85234656, 0.47416692, 0.70713682,
-0.03221922, 0.03523189, 0.04571195, 0.17054439, -0.37914808,
-0.9545246 , -0.83439622, -0.23108025, -0.24650903]), 'σ_w_SAMOS': array([10.18692239, 10.15548637, 9.6890952 , 9.85771665, 9.21202937,
10.98213583, 9.63121548, 10.59680639, 9.45200574, 10.65284439,
9.81525158, 10.97862545, 9.04996588, 9.8200355 , 10.20277911,
9.93994377, 9.25621225, 10.12314184, 9.98556261, 10.81837332,
10.14758508, 10.61628615, 9.11069931, 10.87425799, 9.11678893,
10.37474661, 10.42492221, 9.76238371, 9.72092666, 10.05785007,
9.52149489, 9.20260739, 9.37123183, 10.91415715, 10.48881682,
9.49054663, 10.98986968, 10.86928939, 10.50607239, 10.74960645,
9.89163306, 10.94083801, 10.36574912]), 'σ_b_SAMOS': array(0.08874031), 'η_SAMOS': array(1.84375236), 'f_SAMOS_rotated_': array([ 6.72047424e-01, 8.99247179e-04, -8.68739473e-01, 8.77473628e-01,
8.54988823e-01, 3.54413515e-01, -4.29107182e-01, -5.71963173e-02,
-6.59897210e-01, -4.94928671e-01, 3.14148898e-02, -9.20920171e-01,
-5.28165081e-01, 9.48962807e-01, 2.72881317e-01, 8.89889459e-01,
8.41552321e-01, 4.79466518e-01, -5.45694971e-01, 3.59571284e-01,
-8.66889456e-01, -2.96377154e-01, -1.98801726e-01, 1.39233115e-01,
3.42756588e-01, -6.05313283e-01, -9.23766012e-01, -5.39733994e-01,
-6.13123216e-01, -8.02285696e-01, -5.39906572e-01, -5.50472626e-01,
-5.16047242e-01, 3.19877954e-01, 2.97925712e-01, -3.65818533e-01,
5.09248145e-01, -3.63776122e-02, 1.30758010e-01, 1.39254296e-02,
-3.37624658e-01, -5.73184253e-01, 1.58570715e-02, 6.04669290e-01,
1.00045557e-02, 6.05803719e-01, 9.03520587e-01, -8.70227128e-01,
4.41040589e-01, -3.38563828e-01, 8.25278294e-01, 3.95261538e-01,
4.49434962e-01, -4.84723012e-01, -3.16089644e-01, 9.85856074e-01,
9.02263931e-01, 4.08331871e-01, 3.93385682e-01, 8.60688966e-01,
-9.01977495e-01, -9.08208671e-01, 3.71834812e-01, 3.66435823e-01,
-7.85322041e-01, -3.33026307e-01, 2.78392188e-01, -2.86695240e-01,
-1.01505997e-01, -7.23259696e-01, 9.43830496e-01, -2.49606772e-01,
4.37556323e-01, -9.44962281e-02, 8.13106311e-01, 7.67737909e-01,
6.18440795e-02, -9.74608045e-01, 5.92403757e-01, 3.40703552e-01,
5.89602672e-01, -7.64037621e-01, -9.84788180e-01, 1.70642552e-01,
6.17462135e-01, -1.23870282e-01, 8.29921192e-02, 5.96812072e-01,
4.64164234e-01, -5.95147759e-01, -3.43670400e-01, -9.37060512e-01,
7.77130728e-01, 8.26558118e-01, 2.64800319e-01, -3.21290329e-01,
5.18619812e-01, -1.93036409e-01, 8.50219741e-01, -3.74988559e-01,
-8.36921434e-01, 8.66810036e-01, 7.26266157e-01, -1.51680793e-02,
-1.29088066e-01, -1.25795414e-01, -6.70603292e-01, 7.05885602e-01,
5.03493255e-01, -5.97832274e-01, 3.50193186e-02, 7.85649819e-02,
-1.67657861e-01, -4.97760331e-01, -5.75729714e-01, 8.24995861e-01,
-5.17815035e-01, -3.17854552e-01, -4.19093049e-01, 6.63259811e-01,
-3.69317749e-01, -6.51843984e-01, -6.35897037e-01, 9.08977183e-01,
-8.79009691e-01, 9.26260153e-01, -9.34687046e-01, 7.57535804e-01,
4.14710716e-01, 9.24928943e-01, 7.06090118e-01, 8.90402165e-01,
-5.68926483e-01, 6.24782694e-01, 9.59451362e-01, 3.24981571e-01,
-9.93714760e-01, -2.57546343e-01, 2.31703891e-01, 3.43771522e-01,
4.04389444e-01, 7.31885593e-01, 1.69572976e-01, 8.38118073e-01,
4.62325894e-01, 7.05893914e-01, -6.56247505e-01, -9.64997622e-01,
-1.16845123e-01, -4.85002898e-01, 7.26642885e-02, -3.52441192e-01,
1.65671929e-01, 8.21853791e-01, -6.61359365e-01, -1.15244422e-01,
2.42187524e-01, 8.17286695e-01, -1.65606018e-01, 9.37500272e-01,
-2.05241676e-01, -7.61401335e-01, 4.89017896e-01, -7.25363800e-01,
8.16962289e-01, 3.09857373e-01, -7.98204641e-01, 8.45322463e-02,
9.27873440e-01, -4.01678049e-01, 1.68972192e-02, 2.59915567e-01,
2.41189421e-01, 3.20008792e-01, -3.02608608e-01, 8.98069737e-01,
4.75589063e-01, -4.73205343e-01, 2.67624178e-01, -3.08441235e-01,
-8.53213259e-01, -2.27533001e-01, 8.02616549e-02, 9.16945169e-01,
-8.89499333e-01, 7.54751311e-01, 3.22353879e-01, -5.17115937e-01,
-2.94443829e-01, -6.47316228e-01, -5.19127286e-01]), 'σ_w_FOURNOI': array([ 9.5489369 , 10.77977131, 9.4538371 , 9.33103961, 10.69335592,
9.3333864 , 9.27642861, 10.69094443, 10.21996653, 9.40913956,
10.54219502, 9.05731878, 9.39658146, 9.77360533, 10.94081895,
10.84177394, 10.09743025, 10.63296598, 10.51407624, 10.86785604,
10.50029705, 10.51386372, 10.20350355, 10.57265985, 10.62788594,
10.00872969, 9.19832164, 9.07303835, 9.41240352, 10.36151362,
9.89984589, 10.94599143, 9.0420823 , 10.4690349 , 10.6610043 ,
10.18399225, 9.14191398, 9.39271576, 10.68431681, 10.89876478,
10.27226484, 9.91792894, 10.39164921, 9.69105533, 10.26322218,
10.80251152, 9.60909037, 9.71315524]), 'σ_b_FOURNOI': array(-0.94781615), 'η_FOURNOI': array(0.80100611), 'f_FOURNOI_rotated_': array([-3.97228937e-01, 7.99523421e-02, -4.70200757e-01, 6.24252281e-01,
3.61592850e-01, 6.92856896e-01, -9.58915052e-01, -1.50817891e-01,
-8.23119959e-01, 2.01106346e-01, 5.88014219e-01, -8.78096956e-01,
-2.66865559e-01, -1.97475308e-01, -5.74118623e-01, 7.45968951e-01,
-5.86100627e-01, -3.30842680e-03, -1.73384394e-01, -3.06457049e-01,
-9.67728365e-02, 6.41616905e-02, -9.02874015e-02, -2.83434487e-01,
3.39228918e-01, -5.99231498e-02, 5.72319557e-01, -5.27547541e-01,
-9.78259824e-01, 8.82324783e-01, 3.98367563e-01, -8.23504103e-01,
2.29034833e-01, 5.08419355e-02, -4.68958866e-01, 4.30606611e-01,
-7.82844705e-01, -7.24444332e-01, 4.15036651e-01, 4.88033805e-02,
-9.89303817e-01, -1.95089676e-01, -9.37707220e-01, 3.41806124e-01,
7.29085892e-01, -8.61538823e-01, -7.46261517e-01, -1.29583481e-01,
-5.27773004e-01, -8.99770827e-01, 9.83137845e-01, 9.73918499e-02,
-1.16046095e-01, -5.32323658e-01, -5.81883472e-01, -6.05581983e-01,
-1.88182382e-01, 9.64779217e-01, -9.72327761e-01, -6.54042899e-01,
8.51160442e-01, 5.33781483e-01, -2.40535622e-01, 9.23324376e-01,
-3.69347807e-02, -1.69268385e-01, -8.34101390e-01, -3.43568300e-01,
-2.27356770e-01, -9.09077343e-01, 2.55231860e-01, 3.41537746e-01,
-9.11834903e-02, 1.01098700e-01, -5.17521455e-01, 9.38554298e-01,
-9.94700004e-01, 8.39109963e-01, 5.58946607e-01, 6.02651877e-01,
4.24727475e-03, 1.04099801e-01, -8.91609269e-02, 1.44787449e-01,
6.67546964e-01, -9.07588118e-01, -3.42350003e-01, 5.46606747e-01,
6.71862755e-01, -3.08514111e-01, -1.82284976e-01, -3.67059330e-01,
-5.95956838e-01, -6.37005066e-01, -6.78453699e-01, 6.69278877e-01,
-1.06206787e-01, 4.94111905e-01, -4.61558727e-01, 8.88696078e-01,
2.77218747e-01, 1.67444080e-01, -3.02989870e-01, -8.34514364e-01,
1.96356776e-01, -1.12623121e-01, 8.70064375e-01, -2.58101743e-01,
4.89509938e-01, -3.14376592e-01, -7.02122601e-01, 9.30525996e-01,
-5.35555340e-01, 3.16542107e-02, -6.57610632e-01, -4.88674770e-01,
-6.96602203e-01, -8.95303894e-01, -1.18345528e-01, 7.16415312e-02,
6.22768291e-01, -3.72859682e-01, 1.82658418e-01, -4.97489143e-01,
-5.02119858e-01, 9.72305385e-01, 3.16558983e-01, -3.15023871e-01,
-3.10047409e-01, 7.78602658e-01, -7.30367415e-02, -3.71267674e-01,
-6.24387174e-01, -1.97703595e-01, -8.56258572e-01, -5.82475964e-01,
-6.29914619e-01, -1.97560453e-01, 2.93745836e-01, 6.58612732e-01,
6.95329386e-01, 4.36179228e-01, 8.13879030e-01, -5.32242436e-01,
8.75491174e-01, 8.40942545e-02, 8.35713340e-01, 8.31605981e-01,
-1.28009639e-01, 4.33529926e-01, 8.24683705e-01, 5.76037961e-01,
2.56245525e-01, 9.07778175e-01, -5.61672803e-01, 7.39864576e-01,
3.38142187e-01, -4.56562613e-01, -4.68771588e-01, -9.57035613e-01,
-3.19151709e-01, 7.66949662e-01, 2.46579513e-01, -4.69592162e-01,
-9.77265599e-01, -4.86177570e-01, -1.00060151e-01, 9.87048662e-01,
-6.88439298e-01, -6.55885369e-01, -5.25664120e-01, 5.88978172e-01,
-7.22709171e-01, -7.47925633e-01, -9.72897255e-01, 8.80817801e-01,
-7.98941189e-01, -9.73064781e-01, -7.62033765e-01, -2.65681145e-01,
1.24386781e-01, 4.02619668e-01, 1.91952726e-02, 4.80242473e-01,
-1.37141028e-01, 8.25066721e-01, 3.37083448e-01, 5.64565201e-01,
3.01520036e-02, -6.10088300e-01, -9.50393783e-01, -1.02512354e-01,
-6.75755337e-01, 1.92606977e-01, -3.60347201e-01, -1.87732101e-01,
7.57165960e-01, -1.81846596e-01, -7.74718143e-01, 3.67677040e-01,
2.54096073e-01, 7.61220605e-01, -3.30416686e-01, 1.97579628e-02,
9.54803628e-01, 1.31272760e-01, 6.16910199e-01, -9.95016123e-01,
-6.01827247e-01, -3.47328919e-01, 9.79029919e-01, -3.13571725e-01,
3.92727869e-01, 3.12021804e-01, -7.36271937e-01, -7.10835398e-01,
6.70706507e-01, -4.90674713e-01, -4.62314943e-01, 5.86700173e-01,
-9.14635643e-01, -9.53558568e-02, -6.36862992e-02, -2.47978998e-02,
-1.33616489e-01, 6.32515186e-01, 9.98222804e-01, 8.68327703e-01,
-5.99692454e-01, -2.31811233e-01, 5.75533730e-01, 1.41536067e-01,
-6.05036790e-02, -4.08517779e-01, 4.65523101e-01, 1.05807070e-01,
-1.30159280e-02, -9.70459895e-01, -4.09338250e-01, 7.86467614e-01,
-1.64630750e-01, -3.81857397e-01, 2.49573405e-01, -7.57738539e-01,
-5.69775840e-01, -9.61198205e-01, -2.67736211e-01, 3.20451286e-01,
3.22788042e-02, -7.68116932e-01, -3.63163822e-01, -8.05479590e-01,
2.06979889e-01, -8.75547518e-01, -4.60068120e-01, -3.91805213e-01,
1.11814471e-04, 1.34376033e-01, 7.70014985e-01, -6.95498087e-01,
-8.54759898e-01, 4.23335232e-01, -3.84724603e-01, -6.09505708e-01,
-6.50201798e-02, 1.40525310e-01, -5.53722031e-01, 5.93510202e-01,
6.12454658e-01, -2.24638582e-01, -8.83705876e-01, -3.09590824e-01,
3.64193921e-01, -2.98688047e-01, 4.93076324e-01, 1.29577480e-01,
3.79579314e-01, -4.08191518e-01, 3.56835673e-01, -4.95195152e-01,
-5.02218297e-01, -9.66577645e-01, -4.09654750e-01, 6.49611651e-01,
2.21050389e-01, -1.84653904e-01, 9.91445116e-01, 9.32620868e-01,
-3.38506471e-01, -1.29050688e-01, -4.00248226e-01, 8.68893012e-01,
-3.32464066e-01, -6.20937835e-01, 5.98108100e-01, -8.28208321e-01,
9.97516404e-01, -3.02203544e-01, -8.35008880e-01, -5.01918767e-01,
3.23782771e-01, -5.03160902e-01, -1.38063505e-02, -1.94168328e-01,
-1.04008104e-01, 6.07082622e-01, -3.86213935e-01, 7.35792215e-01,
-1.15815896e-01, -2.46873345e-01, 2.84643484e-02, -4.34099003e-01,
-6.39990806e-01, 6.60487970e-01, 7.74316140e-01, 1.21762503e-01,
8.56800711e-01, 2.04393917e-01, 6.16646853e-01, 1.38388212e-01,
3.80496733e-01, -6.03769752e-01, 4.88071554e-01, -9.08363565e-01,
4.34786964e-01, 7.87288634e-01, 3.38746971e-01]), 'σ_w_CHIOS': array([10.17120168, 10.1782227 , 10.73685522, 9.76525574, 10.4665285 ,
10.01731353, 9.94830092, 10.31727308, 10.14967229, 9.65455023,
10.26413 , 9.50327049, 9.31747371, 9.00053912, 10.44134165,
10.20261506, 9.99836577, 9.63783051, 9.11248378, 9.46232361,
10.9423304 , 10.99508382, 10.55153047]), 'σ_b_CHIOS': array(-0.96337696), 'η_CHIOS': array(1.88378994), 'f_CHIOS_rotated_': array([ 9.23616006e-01, 7.77901917e-01, -4.21498370e-01, 9.79736189e-01,
7.42142313e-01, -1.18727426e-01, 4.53994594e-01, -4.20801284e-01,
7.43915009e-03, 3.55987828e-01, -1.84489393e-01, 7.84232474e-01,
-1.46664866e-02, 5.46611866e-01, 5.07858412e-01, -9.47863605e-01,
1.35310451e-01, -4.26369620e-01, 3.19788497e-01, 3.59855091e-01,
2.64946338e-01, 7.84376137e-01, -7.21563615e-01, -5.11396794e-01,
2.65216016e-01, -1.95440060e-01, -4.21435559e-01, -7.25225528e-01,
3.33736999e-01, 3.92834118e-01, 1.17783636e-01, -5.15275119e-01,
9.32732896e-01, -7.42402903e-01, -7.20636743e-01, -8.57224323e-01,
1.80174136e-01, 7.03799368e-01, -8.16897522e-01, 5.96079844e-01,
-3.02083227e-01, -6.08818905e-01, -5.36624200e-01, 1.25554575e-01,
-4.43466349e-01, 4.96767912e-03, 5.46370781e-01, -7.44642655e-01,
7.40036971e-01, -3.88127333e-01, -9.09175491e-02, -6.92525072e-01,
-6.20290752e-02, 1.14593625e-01, -8.28744441e-01, -5.21146361e-01,
-2.19191814e-01, 4.89850783e-01, 4.89657254e-01, -9.82109061e-01,
-7.62067942e-01, -2.32210762e-01, 4.52798298e-01, 4.47310715e-01,
-6.92318686e-01, -3.64211009e-01, -5.52521943e-01, -6.35685293e-01,
-6.06508579e-01, -2.44457398e-01, -9.26360372e-01, 6.98040702e-01,
-8.87170740e-01, 5.06090568e-01, 3.01782351e-02, -5.93163557e-01,
-3.64999150e-01, 5.98768005e-02, 9.42416284e-01, -3.27173157e-01,
-8.61732361e-01, 4.19540875e-01, -7.73493704e-01, 4.30363584e-01,
7.19158268e-01, 9.54380065e-01, -7.74835117e-01, 9.53104576e-01,
-2.02408242e-01, 9.87064157e-01, -6.49260025e-01, 3.61029478e-01,
-3.90492149e-01, -4.58991182e-01, -9.95726837e-01, -6.48892968e-01,
-7.46897004e-01, -2.35958480e-01, 4.26610381e-01, 5.21497252e-01,
5.27145888e-01, -9.46341703e-01, 1.25268829e-01, -1.79295674e-01,
-3.01004685e-01, 3.51001142e-02, 5.99601758e-01, -1.03226080e-01,
-6.46101008e-01, 6.03664602e-02, -6.14892373e-02, 4.69585921e-01,
-9.53703976e-01, 1.10466054e-01, 4.53452267e-01, -2.16886702e-01,
-9.54602415e-01, 2.01389599e-01, -6.45606051e-01, -3.43417297e-02,
-3.24020292e-01, -7.17803823e-01, 5.21611286e-01, 1.29182622e-01,
4.39019845e-01, -4.55945712e-01, 5.18581296e-01, 5.55395318e-01,
-5.84372785e-01, -5.75107275e-01, -8.82776953e-01, -6.31164012e-01,
-5.10222818e-01, 8.39033418e-01, -7.12910091e-01, 3.10943419e-01,
3.32839155e-01, 5.13707675e-01, 3.74061581e-01, -3.73405589e-01,
-2.46819474e-01, 1.41563951e-01, -1.93409010e-02, -4.83414629e-01,
-2.00544916e-01, 1.69270297e-01, -8.18609697e-01, 6.18390635e-01,
-9.90606682e-01, -2.79144405e-01, -4.99728449e-02, -4.49627165e-01,
-8.34815413e-01, -7.61258590e-01, 8.17354111e-01, -1.67018677e-01,
1.19302866e-02, 9.63026417e-01, -1.49358852e-02, -7.41632448e-01,
-6.34982346e-01, 8.17666343e-01, 1.48252156e-01, 8.97417336e-01,
7.34107463e-01, 7.97307374e-01, -9.93167524e-01, -5.86855322e-01,
7.69551696e-01, -7.24265961e-01, -9.66041146e-01, 4.11554797e-01,
9.59580391e-01, 4.40021544e-01, -8.16263896e-01, -9.76384710e-01,
3.04327377e-01, 9.00838395e-01, 3.99404994e-02, 6.53495241e-01,
-9.23477101e-01, 7.44699622e-02, -1.13543422e-01, 3.42359134e-01,
2.16669551e-01, -2.57469320e-01, 4.50825908e-01, 4.12335612e-01,
-9.97725847e-01, -2.85264916e-01, 1.41502106e-01, -9.32997232e-01,
-3.59810035e-01, 9.64132366e-01, 2.62705861e-01, -9.84762331e-01,
-5.95309175e-02, -8.08496452e-01, 7.39996507e-01, -7.88039551e-01,
-9.21900651e-01, 2.63767932e-01, -9.61649896e-01, 8.90901657e-01,
4.31572140e-02, -2.01120261e-02, -9.52584035e-01, -7.84213998e-01,
8.56496365e-01, 8.96517097e-01, 4.57219594e-01, 7.03872473e-01,
4.32698451e-01, -1.10951382e-01, -7.36517657e-01, -8.51071007e-01,
5.00539830e-01, 1.87584275e-02, -5.01886368e-02])}
Initial evaluation results:
{'σ_w_LESVOS': -58.1, 'σ_b_LESVOS': -2511.21, 'η_LESVOS': -1936.34, 'f_LESVOS_rotated_': -435.36, 'σ_w_IKARIA': -61.48, 'σ_b_IKARIA': -1105.79, 'η_IKARIA': -1271.82, 'f_IKARIA_rotated_': -432.46, 'σ_w_SAMOS': -58.62, 'σ_b_SAMOS': -35.69, 'η_SAMOS': -3555.9, 'f_SAMOS_rotated_': -431.25, 'σ_w_FOURNOI': -58.91, 'σ_b_FOURNOI': -4488.09, 'η_FOURNOI': -194.31, 'f_FOURNOI_rotated_': -431.28, 'σ_w_CHIOS': -55.87, 'σ_b_CHIOS': -4636.79, 'η_CHIOS': -3901.74, 'f_CHIOS_rotated_': -434.5, 'y': -inf}
My working theory is that all the GPs are outputting very large negative values, which the sigmoid sends to 0, resulting in a division by zero. I'm not familiar with the internals of MCMC and don't understand where these gigantic values come from, or why they seem to disagree with the values in the large vectors above.
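The numbers under "Initial evaluation results" are the log-probabilities of each variable at the starting point, not outputs of the GPs, and they can be reproduced by hand. For example, in the second model σ_b_LESVOS has prior Normal(0, 0.01) and starting value 0.70921005, which is about 71 prior standard deviations from the mean. The starting values in the dump all lie within ±1 of their prior means, which looks consistent with the jitter added by `jitter+adapt_diag`; with σ = 0.01, a shift of that size produces very negative log-densities. A plain-Python sketch (no PyMC involved; `normal_logpdf` is written out here for illustration):

```python
import math

def normal_logpdf(x, mu, sigma):
    """Log-density of Normal(mu, sigma) at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

# starting values copied from the error message; priors from the second model
lp_sigma_b = normal_logpdf(0.70921005, mu=0.0, sigma=0.01)      # σ_b_LESVOS ~ Normal(0, 0.01)
lp_eta = normal_logpdf(1.62290096, mu=1.0, sigma=0.01)          # η_LESVOS ~ Normal(1, 0.01)
lp_sigma_b_wide = normal_logpdf(0.70921005, mu=0.0, sigma=1.0)  # same value under Normal(0, 1)

print(round(lp_sigma_b, 2))       # -2511.21, the σ_b_LESVOS entry above
print(round(lp_eta, 2))           # -1936.34, the η_LESVOS entry above
print(round(lp_sigma_b_wide, 2))  # -1.17
```

All of these per-variable entries are finite, though; the only infinite one is `y`.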
P.S. The MultiLayerPerceptronKernel is pretty much copy-pasted from GPy's implementation. The categories are encoded as {1, 2, …} instead of {0, 1, …}; not sure if that's relevant. Outputs were arbitrarily cropped to meet the character limit.
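The label coding is likely relevant. `pymc.Categorical` has support {0, 1, …, K-1}, so with K = 5 columns in `p` an observed label of 5 lies outside the support and gets log-probability -inf, which may be what the `y: -inf` entry is reporting. A minimal NumPy sketch of checking and shifting the labels before passing them as `observed` (the label values are made up):

```python
import numpy as np

K = 5                                 # number of classes (columns of p)
y_1based = np.array([1, 3, 5, 2, 5])  # hypothetical labels coded 1..K

# Categorical's support is {0, ..., K-1}; label K falls outside it
out_of_range = (y_1based < 0) | (y_1based > K - 1)

y_0based = y_1based - 1               # shift to 0..K-1 before using as observed data
```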
For some inspiration on the GP side of things, you might take a look at @DanhPhan's in-progress example: "Multi-ouput Gaussian Process Coregion models with Hadamard product" (pull request #454 to pymc-devs/pymc-examples on GitHub, by danhphan).