To complete the topic, here is how P(w_1) and P(w_2) evolve as functions of the observed value y (the swept value is called `x` in the code and table below).
# Sample observation: a one-element float array. The sweep further down
# calls prob_weights with scalar grid values rather than this array.
ynew = np.array([0.51], dtype=float)
def normalize_log_weights(log_weights):
    """Convert unnormalised log-weights into probabilities along the last axis.

    The per-row maximum is subtracted before exponentiating (a softmax).
    This keeps the result finite even when every log-weight is very negative.
    A plain ``np.exp(...)`` would underflow to 0 there, and the division
    would then produce NaN.

    Parameters
    ----------
    log_weights : array_like
        Log-weights; the last axis indexes the mixture components.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``log_weights``. Values are non-negative
        and sum to 1 along the last axis.
    """
    log_weights = np.asarray(log_weights, dtype=float)
    shifted = log_weights - log_weights.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    # Normalise each observation separately, not over the whole array.
    return weights / weights.sum(axis=-1, keepdims=True)


def prob_weights(model_mixed, trace_mixed, ynew):
    """Return per-component probabilities for ``ynew``, averaged over the trace.

    For every draw of every chain in ``trace_mixed``, the per-component
    log-probability of ``ynew`` is evaluated. It is then normalised over the
    components, and the normalised values are averaged over all draws.

    Parameters
    ----------
    model_mixed :
        Object exposing ``.model.fastfn``, used to compile the component
        log-probability expression.
    trace_mixed :
        Sampling trace exposing ``chains``, ``len()`` and
        ``point(idx, chain=...)`` (presumably a PyMC3 ``MultiTrace``).
    ynew : float or array_like
        Observed value(s) at which the component probabilities are evaluated.

    Returns
    -------
    numpy.ndarray
        Component probabilities averaged over draws, with singleton axes
        removed. For a scalar or one-element ``ynew`` this is presumably
        ``(n_components,)``.

    Raises
    ------
    ValueError
        If the trace contains no draws.

    Notes
    -----
    Reads the notebook-global ``y_obs`` (the observed mixture variable).
    The compiled function is rebuilt on every call, which is costly inside
    a loop.

    NOTE(review): ``_comp_logp`` presumably returns per-component
    log-likelihoods without the mixture weights (``w`` is popped from each
    point below). Confirm whether ``log(w)`` should be added to obtain
    posterior membership probabilities.
    """
    complogp = y_obs.distribution._comp_logp(theano.shared(ynew))
    f_complogp = model_mixed.model.fastfn(complogp)

    draw_weights = []
    # Iterate over the trace's actual chain ids; they need not be 0..n-1.
    for chain in trace_mixed.chains:
        for draw_idx in range(len(trace_mixed)):
            point = trace_mixed.point(draw_idx, chain=chain)
            # Drop the untransformed copies of these variables; presumably the
            # compiled function only accepts the model's free inputs.
            point.pop('sigma', None)
            point.pop('w', None)
            draw_weights.append(normalize_log_weights(f_complogp(point)))

    if not draw_weights:
        raise ValueError("trace_mixed contains no draws")
    # Average first, then squeeze, so a single-draw trace still yields one
    # probability per component instead of collapsing to a scalar.
    return np.mean(draw_weights, axis=0).squeeze()
# Sweep the observed value over a regular grid. For each grid value, record
# the posterior-averaged component probabilities returned by prob_weights.
n = 1000
min_x = -0.5
max_x = 1.5
delta_x = (max_x - min_x) / n
# Same grid as min_x + i * delta_x for i in range(n); max_x is excluded.
grid_x = min_x + np.arange(n) * delta_x
# Collect the results in a list and stack them once. Growing an array with
# np.concatenate on every iteration copies it each time (quadratic in n).
weights = np.array(
    [prob_weights(model_mixed, trace_mixed, xx) for xx in grid_x]
).reshape(n, -1)  # one row per grid value, one column per component
probability_weights = np.vstack([grid_x, weights.T])  # (1 + n_components, n)
print(probability_weights.shape)
column_names = ['x'] + [f'P(w_{k + 1})' for k in range(weights.shape[1])]
probability_weights = pd.DataFrame(probability_weights.T, columns=column_names)
probability_weights
(3, 1000)
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}
.dataframe tbody tr th {
vertical-align: top;
}
.dataframe thead th {
text-align: right;
}
|  | x | P(w_1) | P(w_2) |
|---|---|---|---|
| 0 | -0.500 | 1.000000e+00 | 4.424003e-42 |
| 1 | -0.498 | 1.000000e+00 | 5.257151e-42 |
| 2 | -0.496 | 1.000000e+00 | 6.247663e-42 |
| 3 | -0.494 | 1.000000e+00 | 7.425351e-42 |
| 4 | -0.492 | 1.000000e+00 | 8.825689e-42 |
| 5 | -0.490 | 1.000000e+00 | 1.049089e-41 |
| 6 | -0.488 | 1.000000e+00 | 1.247121e-41 |
| 7 | -0.486 | 1.000000e+00 | 1.482645e-41 |
| 8 | -0.484 | 1.000000e+00 | 1.762779e-41 |
| 9 | -0.482 | 1.000000e+00 | 2.095998e-41 |
| 10 | -0.480 | 1.000000e+00 | 2.492392e-41 |
| 11 | -0.478 | 1.000000e+00 | 2.963974e-41 |
| 12 | -0.476 | 1.000000e+00 | 3.525045e-41 |
| 13 | -0.474 | 1.000000e+00 | 4.192638e-41 |
| 14 | -0.472 | 1.000000e+00 | 4.987038e-41 |
| 15 | -0.470 | 1.000000e+00 | 5.932400e-41 |
| 16 | -0.468 | 1.000000e+00 | 7.057496e-41 |
| 17 | -0.466 | 1.000000e+00 | 8.396600e-41 |
| 18 | -0.464 | 1.000000e+00 | 9.990538e-41 |
| 19 | -0.462 | 1.000000e+00 | 1.188795e-40 |
| 20 | -0.460 | 1.000000e+00 | 1.414678e-40 |
| 21 | -0.458 | 1.000000e+00 | 1.683607e-40 |
| 22 | -0.456 | 1.000000e+00 | 2.003811e-40 |
| 23 | -0.454 | 1.000000e+00 | 2.385093e-40 |
| 24 | -0.452 | 1.000000e+00 | 2.839139e-40 |
| 25 | -0.450 | 1.000000e+00 | 3.379875e-40 |
| 26 | -0.448 | 1.000000e+00 | 4.023903e-40 |
| 27 | -0.446 | 1.000000e+00 | 4.791011e-40 |
| 28 | -0.444 | 1.000000e+00 | 5.704790e-40 |
| 29 | -0.442 | 1.000000e+00 | 6.793364e-40 |
| ... | ... | ... | ... |
| 970 | 1.440 | 2.567641e-38 | 1.000000e+00 |
| 971 | 1.442 | 2.194209e-38 | 1.000000e+00 |
| 972 | 1.444 | 1.875275e-38 | 1.000000e+00 |
| 973 | 1.446 | 1.602860e-38 | 1.000000e+00 |
| 974 | 1.448 | 1.370153e-38 | 1.000000e+00 |
| 975 | 1.450 | 1.171348e-38 | 1.000000e+00 |
| 976 | 1.452 | 1.001487e-38 | 1.000000e+00 |
| 977 | 1.454 | 8.563431e-39 | 1.000000e+00 |
| 978 | 1.456 | 7.323064e-39 | 1.000000e+00 |
| 979 | 1.458 | 6.262970e-39 | 1.000000e+00 |
| 980 | 1.460 | 5.356858e-39 | 1.000000e+00 |
| 981 | 1.462 | 4.582286e-39 | 1.000000e+00 |
| 982 | 1.464 | 3.920093e-39 | 1.000000e+00 |
| 983 | 1.466 | 3.353918e-39 | 1.000000e+00 |
| 984 | 1.468 | 2.869791e-39 | 1.000000e+00 |
| 985 | 1.470 | 2.455782e-39 | 1.000000e+00 |
| 986 | 1.472 | 2.101701e-39 | 1.000000e+00 |
| 987 | 1.474 | 1.798843e-39 | 1.000000e+00 |
| 988 | 1.476 | 1.539775e-39 | 1.000000e+00 |
| 989 | 1.478 | 1.318142e-39 | 1.000000e+00 |
| 990 | 1.480 | 1.128517e-39 | 1.000000e+00 |
| 991 | 1.482 | 9.662622e-40 | 1.000000e+00 |
| 992 | 1.484 | 8.274136e-40 | 1.000000e+00 |
| 993 | 1.486 | 7.085835e-40 | 1.000000e+00 |
| 994 | 1.488 | 6.068759e-40 | 1.000000e+00 |
| 995 | 1.490 | 5.198155e-40 | 1.000000e+00 |
| 996 | 1.492 | 4.452858e-40 | 1.000000e+00 |
| 997 | 1.494 | 3.814773e-40 | 1.000000e+00 |
| 998 | 1.496 | 3.268425e-40 | 1.000000e+00 |
| 999 | 1.498 | 2.800583e-40 | 1.000000e+00 |
1000 rows × 3 columns
# Make the grid values the index. Seaborn, given wide-form data and no x/y,
# then plots each probability column against that index.
probability_weights.set_index(keys='x', inplace=True)
sns.scatterplot(data=probability_weights)
