83 lines
2.3 KiB
Python
83 lines
2.3 KiB
Python
|
#====================================================================
|
|||
|
# 使用 Python API , 拟合一个平面,训练、测试,并画图显示
|
|||
|
# 麻雀虽小五脏俱全!
|
|||
|
#====================================================================
|
|||
|
import tensorflow as tf
|
|||
|
import numpy as np
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
from mpl_toolkits.mplot3d import Axes3D
|
|||
|
|
|||
|
# 训练数据
|
|||
|
x_train = np.float32(np.random.rand(2, 100)) # 随机输入
|
|||
|
y_train = np.dot([5.000, 7.000], x_train) + 2.000
|
|||
|
|
|||
|
# 构造一个线性模型
|
|||
|
b = tf.Variable(tf.zeros([1]))
|
|||
|
W = tf.Variable(tf.zeros([1,2]))
|
|||
|
y = tf.matmul(W, x_train) + b
|
|||
|
|
|||
|
# 最小化方差
|
|||
|
loss = tf.reduce_mean(tf.square(y - y_train))
|
|||
|
optimizer = tf.train.GradientDescentOptimizer(0.5)
|
|||
|
train = optimizer.minimize(loss)
|
|||
|
|
|||
|
# 启动图 (graph)
|
|||
|
sess = tf.Session()
|
|||
|
sess.run(tf.global_variables_initializer())
|
|||
|
|
|||
|
# 训练:拟合平面
|
|||
|
m = 101
|
|||
|
n = 0
|
|||
|
W_temp = np.zeros([m // 20 + 1, 2])
|
|||
|
b_temp = np.zeros([1, m // 20 + 1])
|
|||
|
for step in range(0, m):
|
|||
|
sess.run(train)
|
|||
|
if step % 20 == 0:
|
|||
|
temp = sess.run(W)
|
|||
|
W_temp[n] = temp[0] # 注意:列表和数组属于不同类型数据,否则赋值报错!!
|
|||
|
b_temp[0][n] = sess.run(b)
|
|||
|
print(step, sess.run(W), sess.run(b))
|
|||
|
n = n + 1
|
|||
|
|
|||
|
W_temp = np.transpose(W_temp)
|
|||
|
|
|||
|
# 参数画图
|
|||
|
fig = plt.figure()
|
|||
|
step=np.arange(0,m,20)
|
|||
|
plt.plot(step,W_temp[0], marker='o', mec='b', mfc='b',label=u'W1')
|
|||
|
plt.plot(step,W_temp[1], marker='*', ms=10,label=u'W2')
|
|||
|
plt.plot(step,b_temp[0], marker='^', ms=10,label=u'b')
|
|||
|
plt.legend()
|
|||
|
plt.xlabel('step')
|
|||
|
plt.ylabel('value')
|
|||
|
plt.title('Parameter')
|
|||
|
|
|||
|
|
|||
|
# 测试数据.
|
|||
|
x_test = np.float32(np.random.rand(2, 100))
|
|||
|
y_test = np.dot([5.000, 7.000], x_test) + 2.000
|
|||
|
|
|||
|
# 测试:Test trained model
|
|||
|
y_=sess.run(tf.matmul(W, x_test) + b)
|
|||
|
y = abs(y_-y_test)
|
|||
|
y=np.where(y>0.01,y,1) # 数组y_中,如果元素大于一个给定的数0.001,则值保留,否则元素被1替代。
|
|||
|
accuracy=np.sum(y==1)/len(y[0,:])
|
|||
|
print('accuracy:',accuracy)
|
|||
|
|
|||
|
sess.close()
|
|||
|
|
|||
|
# 测试结果数据画图对比
|
|||
|
fig = plt.figure()
|
|||
|
ax = fig.add_subplot(111, projection='3d')
|
|||
|
xs = x_test[0]
|
|||
|
ys = x_test[1]
|
|||
|
zs1 = y_test
|
|||
|
zs2 = y_
|
|||
|
ax.scatter(xs, ys, zs1, c='r', marker='v')
|
|||
|
ax.scatter(xs, ys, zs2, c='b', marker='^')
|
|||
|
ax.set_xlabel('x')
|
|||
|
ax.set_ylabel('y')
|
|||
|
ax.set_zlabel('z')
|
|||
|
|
|||
|
# 显示以上画的两个图
|
|||
|
plt.show()
|