83 lines
2.3 KiB
Python
83 lines
2.3 KiB
Python
#====================================================================
|
||
# 使用 Python API , 拟合一个平面,训练、测试,并画图显示
|
||
# 麻雀虽小五脏俱全!
|
||
#====================================================================
|
||
import tensorflow as tf
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from mpl_toolkits.mplot3d import Axes3D
|
||
|
||
# 训练数据
|
||
x_train = np.float32(np.random.rand(2, 100)) # 随机输入
|
||
y_train = np.dot([5.000, 7.000], x_train) + 2.000
|
||
|
||
# 构造一个线性模型
|
||
b = tf.Variable(tf.zeros([1]))
|
||
W = tf.Variable(tf.zeros([1,2]))
|
||
y = tf.matmul(W, x_train) + b
|
||
|
||
# 最小化方差
|
||
loss = tf.reduce_mean(tf.square(y - y_train))
|
||
optimizer = tf.train.GradientDescentOptimizer(0.5)
|
||
train = optimizer.minimize(loss)
|
||
|
||
# 启动图 (graph)
|
||
sess = tf.Session()
|
||
sess.run(tf.global_variables_initializer())
|
||
|
||
# 训练:拟合平面
|
||
m = 101
|
||
n = 0
|
||
W_temp = np.zeros([m // 20 + 1, 2])
|
||
b_temp = np.zeros([1, m // 20 + 1])
|
||
for step in range(0, m):
|
||
sess.run(train)
|
||
if step % 20 == 0:
|
||
temp = sess.run(W)
|
||
W_temp[n] = temp[0] # 注意:列表和数组属于不同类型数据,否则赋值报错!!
|
||
b_temp[0][n] = sess.run(b)
|
||
print(step, sess.run(W), sess.run(b))
|
||
n = n + 1
|
||
|
||
W_temp = np.transpose(W_temp)
|
||
|
||
# 参数画图
|
||
fig = plt.figure()
|
||
step=np.arange(0,m,20)
|
||
plt.plot(step,W_temp[0], marker='o', mec='b', mfc='b',label=u'W1')
|
||
plt.plot(step,W_temp[1], marker='*', ms=10,label=u'W2')
|
||
plt.plot(step,b_temp[0], marker='^', ms=10,label=u'b')
|
||
plt.legend()
|
||
plt.xlabel('step')
|
||
plt.ylabel('value')
|
||
plt.title('Parameter')
|
||
|
||
|
||
# 测试数据.
|
||
x_test = np.float32(np.random.rand(2, 100))
|
||
y_test = np.dot([5.000, 7.000], x_test) + 2.000
|
||
|
||
# 测试:Test trained model
|
||
y_=sess.run(tf.matmul(W, x_test) + b)
|
||
y = abs(y_-y_test)
|
||
y=np.where(y>0.01,y,1) # 数组y_中,如果元素大于一个给定的数0.001,则值保留,否则元素被1替代。
|
||
accuracy=np.sum(y==1)/len(y[0,:])
|
||
print('accuracy:',accuracy)
|
||
|
||
sess.close()
|
||
|
||
# 测试结果数据画图对比
|
||
fig = plt.figure()
|
||
ax = fig.add_subplot(111, projection='3d')
|
||
xs = x_test[0]
|
||
ys = x_test[1]
|
||
zs1 = y_test
|
||
zs2 = y_
|
||
ax.scatter(xs, ys, zs1, c='r', marker='v')
|
||
ax.scatter(xs, ys, zs2, c='b', marker='^')
|
||
ax.set_xlabel('x')
|
||
ax.set_ylabel('y')
|
||
ax.set_zlabel('z')
|
||
|
||
# 显示以上画的两个图
|
||
plt.show() |