Bayesian neural network regression

My dependent variable is continuous (apologies for my poor English), and I have 15 explanatory variables, also continuous. The relationship between the dependent variable and the explanatory variables is nonlinear, so I want to model it with a Bayesian neural network. Y has 24 observations and X has 24 rows (24 × 15). After standardization the data look like this:

Y:
[  1.94646436e+00   1.69789286e+00   1.09761075e+00   1.27649423e-01
   3.30047496e-01   7.90309870e-02  -6.80997874e-01  -6.96567735e-01
   1.93587499e-01   1.81847484e-03  -6.70038130e-01  -6.14039642e-01
  -1.07661282e-01   8.83198175e-01   1.68850271e-01   2.64080960e+00
   8.27846809e-04  -9.28953761e-01  -1.10603110e+00  -1.25667752e+00
  -1.11093131e+00  -3.79662376e-01  -6.74221778e-01  -9.42005225e-01]
X:
[[  2.01524961e+00   1.72653251e+00   2.21105318e+00  -7.77997843e-01
   -1.51118262e+00   3.53464590e-01   1.84499313e+00  -9.35420755e-01
   -2.07155504e+00  -1.78400932e-01   1.48169886e+00   3.59069027e+00
   -1.16651064e+00   4.58630703e-01   1.97122894e+00]
 [  1.58426553e+00   3.28455545e+00   2.45517176e+00  -3.98765034e-01
   -3.55398207e+00   7.18519679e-01   7.26404969e-01  -1.60276749e+00
   -6.27768408e-01   3.78372942e+00   1.14361027e+00  -6.92733671e-02
   -1.25667971e+00  -1.68201296e+00  -3.04623158e-01]
 [  6.26523146e-01   2.12432560e+00   1.69230119e+00  -6.45174363e-01
   -1.73936767e+00   8.64541715e-01   1.58191591e+00  -1.58900731e+00
   -1.03463795e+00   1.25857963e-01   6.82580375e-01  -4.03170126e-01
   -1.33083066e+00  -1.73486836e+00   4.40489810e-01]
 [  9.97648322e-02   6.65750937e-01   1.05148992e+00  -8.67724970e-01
   -4.21055737e-02   1.03976816e+00   6.76047337e-01  -1.40829592e+00
   -8.13378256e-01  -7.43704401e-01   1.90815152e-01  -6.34317187e-01
   -1.02091709e+00  -2.36748142e-02   7.41400816e-01]
 [ -2.35445004e-01  -2.45858230e-01  -1.69102974e-01  -8.74890305e-01
    7.42416341e-01   1.22959680e+00   9.25567972e-01  -8.31006132e-01
    2.62280245e-01   1.76228667e-01   5.90374395e-01  -6.73405106e-01
   -7.01641029e-01   3.59526829e-01   4.08093593e-01]
 [ -9.05864676e-01  -8.42547866e-01  -5.50538255e-01  -5.03100069e-01
    1.18574729e+00   1.81368495e+00  -2.39155440e-01  -3.34583883e-01
   -1.88320565e-02  -7.71291425e-01   1.69684615e+00  -7.57541531e-01
   -6.66021288e-01  -2.28489486e-01  -7.41109452e-01]
 [ -9.53751796e-01  -9.41996139e-01  -5.65795666e-01   3.87400929e-01
    1.04883626e+00   2.10572902e+00  -1.43167227e+00  -2.07525472e-01
   -4.22653463e-01  -5.01382592e-01   1.78905213e+00  -9.37298372e-01
   -4.22645771e-01   1.21677533e-01  -1.42843365e+00]
 [ -5.70654840e-01  -6.43651321e-01  -9.31973536e-01   9.78132182e-01
    6.57661896e-01   5.28691033e-01  -1.81774029e+00  -4.96039454e-01
   -5.07186427e-01   5.12332359e-01   1.17434560e+00  -1.09751003e+00
    2.99099223e-01   1.33735171e+00  -1.37087644e+00]
 [ -6.66429079e-01  -5.93927184e-01  -6.11567900e-01   7.51630970e-01
    5.92466169e-01   1.06897256e+00  -6.50649282e-01  -3.50961532e-01
   -6.48382790e-01  -4.73784559e-01   6.51845048e-01   1.54573931e+00
   -3.61546928e-02   2.67029881e-01  -3.82483664e-01]
 [ -2.83332123e-01  -8.42547866e-01  -2.45390031e-01   1.08722868e+00
    8.01092496e-01   7.91530697e-01  -6.06510819e-03   7.84397027e-02
   -2.51392295e-01   2.61223033e-01   7.74786354e-01   7.85901997e-01
   -3.35502928e-01   8.20359840e-02  -7.67090778e-01]
 [  1.47651952e-01  -5.11053624e-01   1.81817484e-01   6.09602915e-01
    3.66454316e-01   3.82668997e-01   6.93192814e-01   5.91183624e-01
   -1.85420340e-01   8.40321599e-01   4.67433090e-01   3.74235430e-01
   -3.58410527e-01   2.91805849e-02  -6.51387269e-01]
 [  1.95539071e-01  -6.35363965e-02  -5.65795666e-01  -2.82081369e-01
   -1.61631073e-01  -1.43010331e-01  -8.08573263e-01   2.25059393e-01
    1.48455982e-01   8.71214476e-01  -6.69773989e-01   1.44494352e-01
   -6.66704926e-01  -7.37222702e-01   9.32928329e-01]
 [  8.18071624e-01  -4.11605351e-01  -1.23330741e-01  -3.60223175e-01
    2.94739016e-01  -4.78861013e-01  -1.27819233e+00  -1.79036779e-01
    9.59116022e-01   5.07344553e-01  -7.00509315e-01  -4.44952763e-03
   -8.81890173e-01  -1.27899054e+00   9.01031967e-01]
 [  1.44060418e+00  -4.61329487e-01  -2.30132619e-01  -8.46359526e-01
    4.49035570e-01  -5.88377540e-01  -1.97776786e-01  -2.60394287e-01
    1.90163120e+00   3.32226524e-01  -1.00786258e+00  -1.95013310e-01
    2.10573294e-02  -2.02061786e-01  -6.26068773e-01]
 [  2.15891097e+00   8.56360126e-02  -2.60647442e-01  -2.24923475e+00
   -6.38374827e-02  -6.54087456e-01  -1.34292656e-01  -6.43633165e-01
    2.44508967e+00   1.81332588e-01  -1.03859791e+00  -4.67953870e-01
    3.11642330e+00   2.24250042e+00   2.34837016e-01]
 [  3.99059329e-03   2.67957846e-01  -9.28159183e-02  -1.22945609e+00
   -4.31106745e-01  -9.31529324e-01  -1.66181683e-01  -7.67025533e-01
    2.16009568e+00  -4.50623384e-01  -1.10006856e+00  -7.85513195e-01
   -8.52186012e-02  -1.25916977e+00  -1.66795535e+00]
 [ -1.39670765e-01  -8.25973154e-01   1.53972708e+00   7.08541712e-01
    7.79360587e-01  -9.89938138e-01   2.46383538e+00   5.12414154e-01
    5.90347428e-01  -1.65169713e+00  -7.61979968e-01  -8.54261355e-01
    1.14402338e+00   6.76659224e-01   2.48223348e+00]
 [  4.34974668e-01  -1.62984669e-01   1.97074895e-01   8.57261638e-02
   -7.33451929e-03  -1.02425332e+00   4.92447109e-01   6.12148737e-01
    2.81704623e-01   6.43079799e-01  -1.16153921e+00  -6.54369784e-01
    9.27511650e-01   5.90769200e-01  -2.92727304e-01]
 [ -9.17836456e-02   1.85084285e-01  -7.33627190e-01  -4.86526787e-01
   -3.28966773e-01  -1.03666519e+00  -4.30670644e-01   3.48046980e-01
   -4.70235584e-01  -7.76599100e-02  -1.13080389e+00  -4.40866363e-01
    1.42405892e-01  -8.36326575e-01  -2.79735781e-01]
 [ -8.57977557e-01  -2.79007654e-01  -7.03112367e-01   5.42036654e-02
    2.46928816e-01  -1.03739530e+00  -3.28974355e-01   6.48234391e-01
   -2.12939474e-01  -5.57152563e-01  -1.16153921e+00  -3.91608939e-01
    9.38079382e-02  -9.09002749e-01   1.28333184e-01]
 [ -9.05864676e-01  -2.79007654e-01  -7.33627190e-01   3.06192336e-01
    2.12157762e-01  -1.03374475e+00  -4.65617289e-01   9.24111734e-01
   -5.61356606e-01  -5.08049774e-01  -1.16153921e+00  -2.76762385e-01
   -1.20322964e-03  -5.25801106e-01   1.48029295e-01]
 [ -1.14530027e+00  -3.78455927e-01  -9.47230947e-01   2.26178559e-01
    2.55621580e-01  -1.01549199e+00  -5.54346834e-01   1.32012443e+00
   -3.07518247e-01  -5.59075380e-01  -8.54185947e-01   1.06847935e+00
    6.80431624e-01   4.05775303e-01  -1.10999684e+00]
 [ -1.33684875e+00  -4.77904199e-01  -7.79399423e-01   1.97341116e+00
    2.23023716e-01  -9.84827367e-01  -3.16827422e-01   1.94341118e+00
   -1.29591057e-01  -8.79668919e-01  -1.16538113e-01   4.71373690e-01
    1.42340427e+00   1.76019491e+00   2.89307468e-01]
 [ -1.43262299e+00  -3.78455927e-01  -1.08454765e+00   2.35328501e+00
   -1.60272829e-02  -9.78986485e-01  -5.77668960e-01   2.40252338e+00
   -4.85872849e-01  -8.82400014e-01   2.21550478e-01   6.62400048e-01
    1.08216666e+00   1.08628857e+00   9.44574557e-01]]

The model is built like this:

import numpy as np
import pymc3 as pm
import theano
import theano.tensor as T

floatX = theano.config.floatX

X = X.astype(floatX)
Y = Y.astype(floatX)

def construct_nn(ann_input, ann_output):
    n_hidden = 3

    # Initialize random weights between each layer
    init_1 = np.random.randn(X.shape[1] + 1, n_hidden).astype(floatX)
    init_out = np.random.randn(n_hidden + 1).astype(floatX)
    with pm.Model() as neural_network:
        # Weights from input to hidden layer (the +1 row is for the bias)
        weights_in_1 = pm.Normal('w_in_1', 0, sd=2,
                                 shape=(X.shape[1] + 1, n_hidden),
                                 testval=init_1)
        # Weights from hidden layer to output (the +1 is for the bias)
        weights_1_out = pm.Normal('w_1_out', 0, sd=20,
                                  shape=(n_hidden + 1,),
                                  testval=init_out)

        # Observation noise
        sigma = pm.HalfNormal('sigma', sd=0.1)

        # Build the network with a sigmoid activation; a column of ones is
        # concatenated onto the inputs to act as the bias term
        one = theano.shared(np.ones((X.shape[0], 1)).astype(floatX))
        act_1 = pm.math.sigmoid(pm.math.dot(T.concatenate([one, ann_input], axis=1),
                                            weights_in_1))
        act_out = pm.math.dot(T.concatenate([one, act_1], axis=1), weights_1_out)

        # Regression -> Normal likelihood
        out = pm.Normal('out',
                        mu=act_out, sd=sigma,
                        observed=ann_output,
                        total_size=int(Y.shape[0])  # IMPORTANT for minibatches
                        )
    return neural_network

ann_input = theano.shared(X)
ann_output = theano.shared(Y)
neural_network = construct_nn(ann_input, ann_output)
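To sanity-check the shapes in the forward pass above, here is the same one-hidden-layer network written as a plain NumPy sketch (hypothetical random weights, sigmoid spelled out explicitly):

```python
import numpy as np

# Same architecture as in the model: 24 samples, 15 features, 3 hidden units,
# with the bias handled by concatenating a column of ones.
n_samples, n_features, n_hidden = 24, 15, 3
X_demo = np.random.randn(n_samples, n_features)

w_in_1 = np.random.randn(n_features + 1, n_hidden)  # input -> hidden (+bias row)
w_1_out = np.random.randn(n_hidden + 1)             # hidden -> output (+bias)

one = np.ones((n_samples, 1))
act_1 = 1.0 / (1.0 + np.exp(-np.hstack([one, X_demo]) @ w_in_1))  # sigmoid
act_out = np.hstack([one, act_1]) @ w_1_out

print(act_1.shape, act_out.shape)  # (24, 3) (24,)
```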

Fitting the model:

with neural_network:
    inference = pm.ADVI()
    approx = pm.fit(n=100000, method=inference)

Finally, I compared the model's predictions on X with the actual values:

import matplotlib.pyplot as plt

x = T.matrix('X')
n = T.iscalar('n')
x.tag.test_value = np.empty_like(X[:10])
n.tag.test_value = 100
_sample_proba = approx.sample_node(neural_network.out.distribution.mu,
                                   size=n,
                                   more_replacements={ann_input: x})
sample_proba = theano.function([x, n], _sample_proba)
pred = sample_proba(X, 50000).mean(0)
print(pred)
print(Y)
fig, ax = plt.subplots()
ax.scatter(Y, pred)
ax.set_xlim(-1.5, 2)
ax.set_ylim(-1.5, 2)

[scatter plot of predicted vs. actual values]
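Beyond the visual comparison, the fit can be quantified with RMSE and R². A small sketch with hypothetical values standing in for `Y` and `pred`:

```python
import numpy as np

# Hypothetical stand-ins; in the post, pred = sample_proba(X, 50000).mean(0)
y = np.array([1.9, 1.7, 1.1, 0.1, -0.7, -0.9])
pred = np.array([1.5, 1.4, 0.9, 0.2, -0.5, -0.8])

# Root-mean-square error of the predictions
rmse = np.sqrt(np.mean((y - pred) ** 2))

# Coefficient of determination: 1 - residual SS / total SS
ss_res = np.sum((y - pred) ** 2)
ss_tot = np.sum((y - y.mean()) ** 2)
r2 = 1.0 - ss_res / ss_tot

print(round(rmse, 3), round(r2, 3))
```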

The predictions don't look good, especially compared with the results from the nnet function in R:

[scatter plot of the R nnet fit]

Could anyone take a look at the model and tell me where it goes wrong, and how I can improve the prediction accuracy?