My dependent variable is continuous (please excuse my English), and I have 15 explanatory variables, all continuous as well. The relationship between the dependent variable and the explanatory variables is nonlinear, so I want to model it with a Bayesian neural network. Y has length 24 and X has 24 rows. After standardization the data look like this:
Y:
[ 1.94646436e+00 1.69789286e+00 1.09761075e+00 1.27649423e-01
3.30047496e-01 7.90309870e-02 -6.80997874e-01 -6.96567735e-01
1.93587499e-01 1.81847484e-03 -6.70038130e-01 -6.14039642e-01
-1.07661282e-01 8.83198175e-01 1.68850271e-01 2.64080960e+00
8.27846809e-04 -9.28953761e-01 -1.10603110e+00 -1.25667752e+00
-1.11093131e+00 -3.79662376e-01 -6.74221778e-01 -9.42005225e-01]
X:
[[ 2.01524961e+00 1.72653251e+00 2.21105318e+00 -7.77997843e-01
-1.51118262e+00 3.53464590e-01 1.84499313e+00 -9.35420755e-01
-2.07155504e+00 -1.78400932e-01 1.48169886e+00 3.59069027e+00
-1.16651064e+00 4.58630703e-01 1.97122894e+00]
[ 1.58426553e+00 3.28455545e+00 2.45517176e+00 -3.98765034e-01
-3.55398207e+00 7.18519679e-01 7.26404969e-01 -1.60276749e+00
-6.27768408e-01 3.78372942e+00 1.14361027e+00 -6.92733671e-02
-1.25667971e+00 -1.68201296e+00 -3.04623158e-01]
[ 6.26523146e-01 2.12432560e+00 1.69230119e+00 -6.45174363e-01
-1.73936767e+00 8.64541715e-01 1.58191591e+00 -1.58900731e+00
-1.03463795e+00 1.25857963e-01 6.82580375e-01 -4.03170126e-01
-1.33083066e+00 -1.73486836e+00 4.40489810e-01]
[ 9.97648322e-02 6.65750937e-01 1.05148992e+00 -8.67724970e-01
-4.21055737e-02 1.03976816e+00 6.76047337e-01 -1.40829592e+00
-8.13378256e-01 -7.43704401e-01 1.90815152e-01 -6.34317187e-01
-1.02091709e+00 -2.36748142e-02 7.41400816e-01]
[ -2.35445004e-01 -2.45858230e-01 -1.69102974e-01 -8.74890305e-01
7.42416341e-01 1.22959680e+00 9.25567972e-01 -8.31006132e-01
2.62280245e-01 1.76228667e-01 5.90374395e-01 -6.73405106e-01
-7.01641029e-01 3.59526829e-01 4.08093593e-01]
[ -9.05864676e-01 -8.42547866e-01 -5.50538255e-01 -5.03100069e-01
1.18574729e+00 1.81368495e+00 -2.39155440e-01 -3.34583883e-01
-1.88320565e-02 -7.71291425e-01 1.69684615e+00 -7.57541531e-01
-6.66021288e-01 -2.28489486e-01 -7.41109452e-01]
[ -9.53751796e-01 -9.41996139e-01 -5.65795666e-01 3.87400929e-01
1.04883626e+00 2.10572902e+00 -1.43167227e+00 -2.07525472e-01
-4.22653463e-01 -5.01382592e-01 1.78905213e+00 -9.37298372e-01
-4.22645771e-01 1.21677533e-01 -1.42843365e+00]
[ -5.70654840e-01 -6.43651321e-01 -9.31973536e-01 9.78132182e-01
6.57661896e-01 5.28691033e-01 -1.81774029e+00 -4.96039454e-01
-5.07186427e-01 5.12332359e-01 1.17434560e+00 -1.09751003e+00
2.99099223e-01 1.33735171e+00 -1.37087644e+00]
[ -6.66429079e-01 -5.93927184e-01 -6.11567900e-01 7.51630970e-01
5.92466169e-01 1.06897256e+00 -6.50649282e-01 -3.50961532e-01
-6.48382790e-01 -4.73784559e-01 6.51845048e-01 1.54573931e+00
-3.61546928e-02 2.67029881e-01 -3.82483664e-01]
[ -2.83332123e-01 -8.42547866e-01 -2.45390031e-01 1.08722868e+00
8.01092496e-01 7.91530697e-01 -6.06510819e-03 7.84397027e-02
-2.51392295e-01 2.61223033e-01 7.74786354e-01 7.85901997e-01
-3.35502928e-01 8.20359840e-02 -7.67090778e-01]
[ 1.47651952e-01 -5.11053624e-01 1.81817484e-01 6.09602915e-01
3.66454316e-01 3.82668997e-01 6.93192814e-01 5.91183624e-01
-1.85420340e-01 8.40321599e-01 4.67433090e-01 3.74235430e-01
-3.58410527e-01 2.91805849e-02 -6.51387269e-01]
[ 1.95539071e-01 -6.35363965e-02 -5.65795666e-01 -2.82081369e-01
-1.61631073e-01 -1.43010331e-01 -8.08573263e-01 2.25059393e-01
1.48455982e-01 8.71214476e-01 -6.69773989e-01 1.44494352e-01
-6.66704926e-01 -7.37222702e-01 9.32928329e-01]
[ 8.18071624e-01 -4.11605351e-01 -1.23330741e-01 -3.60223175e-01
2.94739016e-01 -4.78861013e-01 -1.27819233e+00 -1.79036779e-01
9.59116022e-01 5.07344553e-01 -7.00509315e-01 -4.44952763e-03
-8.81890173e-01 -1.27899054e+00 9.01031967e-01]
[ 1.44060418e+00 -4.61329487e-01 -2.30132619e-01 -8.46359526e-01
4.49035570e-01 -5.88377540e-01 -1.97776786e-01 -2.60394287e-01
1.90163120e+00 3.32226524e-01 -1.00786258e+00 -1.95013310e-01
2.10573294e-02 -2.02061786e-01 -6.26068773e-01]
[ 2.15891097e+00 8.56360126e-02 -2.60647442e-01 -2.24923475e+00
-6.38374827e-02 -6.54087456e-01 -1.34292656e-01 -6.43633165e-01
2.44508967e+00 1.81332588e-01 -1.03859791e+00 -4.67953870e-01
3.11642330e+00 2.24250042e+00 2.34837016e-01]
[ 3.99059329e-03 2.67957846e-01 -9.28159183e-02 -1.22945609e+00
-4.31106745e-01 -9.31529324e-01 -1.66181683e-01 -7.67025533e-01
2.16009568e+00 -4.50623384e-01 -1.10006856e+00 -7.85513195e-01
-8.52186012e-02 -1.25916977e+00 -1.66795535e+00]
[ -1.39670765e-01 -8.25973154e-01 1.53972708e+00 7.08541712e-01
7.79360587e-01 -9.89938138e-01 2.46383538e+00 5.12414154e-01
5.90347428e-01 -1.65169713e+00 -7.61979968e-01 -8.54261355e-01
1.14402338e+00 6.76659224e-01 2.48223348e+00]
[ 4.34974668e-01 -1.62984669e-01 1.97074895e-01 8.57261638e-02
-7.33451929e-03 -1.02425332e+00 4.92447109e-01 6.12148737e-01
2.81704623e-01 6.43079799e-01 -1.16153921e+00 -6.54369784e-01
9.27511650e-01 5.90769200e-01 -2.92727304e-01]
[ -9.17836456e-02 1.85084285e-01 -7.33627190e-01 -4.86526787e-01
-3.28966773e-01 -1.03666519e+00 -4.30670644e-01 3.48046980e-01
-4.70235584e-01 -7.76599100e-02 -1.13080389e+00 -4.40866363e-01
1.42405892e-01 -8.36326575e-01 -2.79735781e-01]
[ -8.57977557e-01 -2.79007654e-01 -7.03112367e-01 5.42036654e-02
2.46928816e-01 -1.03739530e+00 -3.28974355e-01 6.48234391e-01
-2.12939474e-01 -5.57152563e-01 -1.16153921e+00 -3.91608939e-01
9.38079382e-02 -9.09002749e-01 1.28333184e-01]
[ -9.05864676e-01 -2.79007654e-01 -7.33627190e-01 3.06192336e-01
2.12157762e-01 -1.03374475e+00 -4.65617289e-01 9.24111734e-01
-5.61356606e-01 -5.08049774e-01 -1.16153921e+00 -2.76762385e-01
-1.20322964e-03 -5.25801106e-01 1.48029295e-01]
[ -1.14530027e+00 -3.78455927e-01 -9.47230947e-01 2.26178559e-01
2.55621580e-01 -1.01549199e+00 -5.54346834e-01 1.32012443e+00
-3.07518247e-01 -5.59075380e-01 -8.54185947e-01 1.06847935e+00
6.80431624e-01 4.05775303e-01 -1.10999684e+00]
[ -1.33684875e+00 -4.77904199e-01 -7.79399423e-01 1.97341116e+00
2.23023716e-01 -9.84827367e-01 -3.16827422e-01 1.94341118e+00
-1.29591057e-01 -8.79668919e-01 -1.16538113e-01 4.71373690e-01
1.42340427e+00 1.76019491e+00 2.89307468e-01]
[ -1.43262299e+00 -3.78455927e-01 -1.08454765e+00 2.35328501e+00
-1.60272829e-02 -9.78986485e-01 -5.77668960e-01 2.40252338e+00
-4.85872849e-01 -8.82400014e-01 2.21550478e-01 6.62400048e-01
1.08216666e+00 1.08628857e+00 9.44574557e-01]]
The model is built as follows:
X = X.astype(floatX)
Y = Y.astype(floatX)

def construct_nn(ann_input, ann_output):
    n_hidden = 3
    # Initialize random weights between each layer
    # (the extra row/element is for the bias term)
    init_1 = np.random.randn(X.shape[1] + 1, n_hidden).astype(floatX)
    init_out = np.random.randn(n_hidden + 1).astype(floatX)
    with pm.Model() as neural_network:
        # Weights from input (plus bias column) to hidden layer
        weights_in_1 = pm.Normal('w_in_1', 0, sd=2,
                                 shape=(X.shape[1] + 1, n_hidden),
                                 testval=init_1)
        # Weights from hidden layer (plus bias) to output
        weights_1_out = pm.Normal('w_1_out', 0, sd=20,
                                  shape=(n_hidden + 1,),
                                  testval=init_out)
        # Observation noise
        sigma = pm.HalfNormal('sigma', sd=0.1)
        # Build the network using a sigmoid activation in the hidden layer
        one = theano.shared(np.ones((24, 1)).astype(floatX))
        act_1 = pm.math.sigmoid(pm.math.dot(T.concatenate([one, ann_input], axis=1),
                                            weights_in_1))
        act_out = pm.math.dot(T.concatenate([one, act_1], axis=1), weights_1_out)
        # Continuous regression -> Gaussian likelihood
        out = pm.Normal('out',
                        mu=act_out, sd=sigma,
                        observed=ann_output,
                        total_size=int(Y.shape[0])  # IMPORTANT for minibatches
                        )
    return neural_network

ann_input = theano.shared(X)
ann_output = theano.shared(Y)
neural_network = construct_nn(ann_input, ann_output)
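As a sanity check on the architecture, the same forward pass can be sketched in plain numpy to confirm the shapes line up (a sketch assuming 24 samples, 15 features, and 3 hidden units, mirroring `w_in_1` and `w_1_out` above; the sigmoid is implemented inline):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

n_samples, n_features, n_hidden = 24, 15, 3
rng = np.random.default_rng(0)
X = rng.standard_normal((n_samples, n_features))

# Weights carry one extra row/element for the bias term,
# matching the shapes of w_in_1 and w_1_out in the PyMC3 model.
w_in_1 = rng.standard_normal((n_features + 1, n_hidden))
w_1_out = rng.standard_normal(n_hidden + 1)

# Prepend a column of ones so the bias is absorbed into the weight matrix.
one = np.ones((n_samples, 1))
act_1 = sigmoid(np.concatenate([one, X], axis=1) @ w_in_1)
act_out = np.concatenate([one, act_1], axis=1) @ w_1_out

print(act_1.shape)    # (24, 3)
print(act_out.shape)  # (24,)
```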
Running the model:
with neural_network:
    inference = pm.ADVI()
    approx = pm.fit(n=100000, method=inference)
Finally, I compared the predicted values on X with the actual values:
x = T.matrix('X')
n = T.iscalar('n')
x.tag.test_value = np.empty_like(X[:10])
n.tag.test_value = 100
_sample_proba = approx.sample_node(neural_network.out.distribution.mu,
                                   size=n,
                                   more_replacements={ann_input: x})
sample_proba = theano.function([x, n], _sample_proba)
pred = sample_proba(X, 50000).mean(0)
print(pred)
print(Y)
fig, ax = plt.subplots()
ax.scatter(Y, pred)
ax.set_xlim(-1.5, 2)
ax.set_ylim(-1.5, 2)
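To quantify the fit beyond eyeballing the scatter plot, one can compute R² between the observed values and the posterior-mean predictions. A minimal numpy sketch (the short `Y` and `pred` arrays here are made-up stand-ins, not the real data):

```python
import numpy as np

# Stand-ins for the observed values and the posterior-mean predictions.
Y = np.array([1.9, 1.7, 1.1, 0.1, 0.3, -0.7, -0.6, 0.9])
pred = np.array([1.5, 1.4, 0.8, 0.2, 0.1, -0.5, -0.4, 0.6])

# Coefficient of determination: R^2 = 1 - SS_res / SS_tot.
ss_res = np.sum((Y - pred) ** 2)
ss_tot = np.sum((Y - Y.mean()) ** 2)
r2 = 1.0 - ss_res / ss_tot
print(round(r2, 3))
```

The same computation applied to the real `Y` and `pred` gives a single number that makes the comparison with R's nnet output concrete.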
The predictions don't look good, especially when compared with the results from the nnet function in R:
Could anyone take a look at the model and tell me where it goes wrong, and how I can improve the prediction accuracy?