Liu/ne_PM.py

83 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#====================================================================
# 使用 Python API , 拟合一个平面,训练、测试,并画图显示
# 麻雀虽小五脏俱全!
#====================================================================
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# Training data: 100 random points (x1, x2) drawn by np.random.rand and
# their exact heights on the plane z = 5*x1 + 7*x2 + 2.
x_train = np.float32(np.random.rand(2, 100))  # shape (2, 100)
y_train = np.dot([5.000, 7.000], x_train) + 2.000  # shape (100,)
# This script uses the TF1 graph/session API. In TensorFlow 2 that API is only
# reachable via tf.compat.v1 and needs graph mode, so v2 behavior is switched
# off here, before any variable or op is created. (On TF1 releases that ship
# tf.compat.v1 this should be harmless — NOTE(review): confirm on the pinned TF.)
tf.compat.v1.disable_v2_behavior()
# Linear model y = W @ x + b, with W of shape (1, 2) and b of shape (1,).
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.zeros([1, 2]))
y = tf.matmul(W, x_train) + b
# Minimize the mean squared error with plain gradient descent (lr = 0.5).
loss = tf.reduce_mean(tf.square(y - y_train))
optimizer = tf.compat.v1.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# Launch the graph and initialize all variables.
sess = tf.compat.v1.Session()
sess.run(tf.compat.v1.global_variables_initializer())
# Training: fit the plane, snapshotting W and b every 20 steps for plotting.
m = 101  # number of gradient-descent steps
# One snapshot is taken for each step in range(0, m, 20); size the buffers
# from that range so they match the np.arange(0, m, 20) x-axis of the
# parameter plot for any value of m.
n_records = len(range(0, m, 20))
n = 0
W_temp = np.zeros([n_records, 2])  # row i holds [W1, W2] at snapshot i
b_temp = np.zeros([1, n_records])  # b_temp[0][i] holds b at snapshot i
for step in range(m):
    sess.run(train)
    if step % 20 == 0:
        # Fetch both parameters in a single run: W is (1, 2), b is (1,).
        W_val, b_val = sess.run([W, b])
        W_temp[n] = W_val[0]
        b_temp[0][n] = b_val[0]  # store the scalar, not a size-1 array
        print(step, W_val, b_val)
        n = n + 1
# Transpose to shape (2, n_records) so W_temp[0] and W_temp[1] are the
# W1 and W2 time series.
W_temp = np.transpose(W_temp)
# Figure 1: how the recorded parameters W1, W2 and b evolved over training.
fig, param_ax = plt.subplots()
step = np.arange(0, m, 20)  # the training steps at which snapshots were taken
traces = (
    (W_temp[0], dict(marker='o', mec='b', mfc='b', label='W1')),
    (W_temp[1], dict(marker='*', ms=10, label='W2')),
    (b_temp[0], dict(marker='^', ms=10, label='b')),
)
for values, style in traces:
    param_ax.plot(step, values, **style)
param_ax.legend()
param_ax.set_xlabel('step')
param_ax.set_ylabel('value')
param_ax.set_title('Parameter')
# Test data: a fresh random sample from the same plane z = 5*x1 + 7*x2 + 2.
x_test = np.float32(np.random.rand(2, 100))
y_test = np.dot([5.000, 7.000], x_test) + 2.000
# Run the trained model on the test inputs; y_ has shape (1, 100).
y_ = sess.run(tf.matmul(W, x_test) + b)
# Accuracy = fraction of test points whose absolute error is at most
# `tolerance`. The comparison is done directly on the errors so that no
# error value can be mistaken for a hit.
tolerance = 0.01
abs_err = np.abs(y_ - y_test)  # broadcasts to shape (1, 100)
accuracy = np.mean(abs_err <= tolerance)
print('accuracy:', accuracy)
sess.close()
# Figure 2: 3-D scatter of the test targets (red 'v') against the model's
# predictions (blue '^') at the same (x1, x2) inputs.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
xs, ys = x_test[0], x_test[1]
zs1, zs2 = y_test, y_
for zs, colour, glyph in ((zs1, 'r', 'v'), (zs2, 'b', '^')):
    ax.scatter(xs, ys, zs, c=colour, marker=glyph)
for set_label, text in ((ax.set_xlabel, 'x'), (ax.set_ylabel, 'y'), (ax.set_zlabel, 'z')):
    set_label(text)
# Show both figures drawn above.
plt.show()