Liu/ne_PM.py

83 lines
2.3 KiB
Python
Raw Normal View History

2023-11-27 15:56:15 +08:00
#====================================================================
# 使用 Python API , 拟合一个平面,训练、测试,并画图显示
# 麻雀虽小五脏俱全!
#====================================================================
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 训练数据
x_train = np.float32(np.random.rand(2, 100)) # 随机输入
y_train = np.dot([5.000, 7.000], x_train) + 2.000
# 构造一个线性模型
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.zeros([1,2]))
y = tf.matmul(W, x_train) + b
# 最小化方差
loss = tf.reduce_mean(tf.square(y - y_train))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
# 启动图 (graph)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# 训练:拟合平面
m = 101
n = 0
W_temp = np.zeros([m // 20 + 1, 2])
b_temp = np.zeros([1, m // 20 + 1])
for step in range(0, m):
sess.run(train)
if step % 20 == 0:
temp = sess.run(W)
W_temp[n] = temp[0] # 注意:列表和数组属于不同类型数据,否则赋值报错!!
b_temp[0][n] = sess.run(b)
print(step, sess.run(W), sess.run(b))
n = n + 1
W_temp = np.transpose(W_temp)
# 参数画图
fig = plt.figure()
step=np.arange(0,m,20)
plt.plot(step,W_temp[0], marker='o', mec='b', mfc='b',label=u'W1')
plt.plot(step,W_temp[1], marker='*', ms=10,label=u'W2')
plt.plot(step,b_temp[0], marker='^', ms=10,label=u'b')
plt.legend()
plt.xlabel('step')
plt.ylabel('value')
plt.title('Parameter')
# 测试数据.
x_test = np.float32(np.random.rand(2, 100))
y_test = np.dot([5.000, 7.000], x_test) + 2.000
# 测试Test trained model
y_=sess.run(tf.matmul(W, x_test) + b)
y = abs(y_-y_test)
y=np.where(y>0.01,y,1) # 数组y_中如果元素大于一个给定的数0.001则值保留否则元素被1替代。
accuracy=np.sum(y==1)/len(y[0,:])
print('accuracy:',accuracy)
sess.close()
# 测试结果数据画图对比
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
xs = x_test[0]
ys = x_test[1]
zs1 = y_test
zs2 = y_
ax.scatter(xs, ys, zs1, c='r', marker='v')
ax.scatter(xs, ys, zs2, c='b', marker='^')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')
# 显示以上画的两个图
plt.show()