"""Fit a two-feature linear model ``y = W @ x + b`` with TensorFlow 2.x and plot the results.

Inputs are drawn uniformly from [0, 1) and labelled with the ground-truth
parameters ``TRUE_W`` / ``TRUE_B``. The script trains with plain SGD, prints
the loss and parameters every ``LOG_EVERY`` steps, plots the parameter
trajectories, and reports test accuracy as the fraction of test predictions
within ``TOLERANCE`` of the true target.
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib
# Kept for backward compatibility only: the optimizer below is built from
# tf.keras so that the model and the optimizer come from the same Keras.
from keras.optimizers import SGD  # noqa: F401

# Make sure we are running TensorFlow 2.x.
print("TensorFlow version:", tf.__version__)

# Ground-truth parameters of the linear function being learned.
TRUE_W = np.array([5.0, 7.0], dtype=np.float32)
TRUE_B = np.float32(2.0)
# Number of samples in each generated data set.
NUM_SAMPLES = 100
# A test prediction counts as correct when |prediction - target| <= TOLERANCE.
TOLERANCE = 0.01
# Print and record the parameters every LOG_EVERY training steps.
LOG_EVERY = 20


def make_linear_data(num_samples=NUM_SAMPLES):
    """Return ``(x, y)`` where x has shape (2, num_samples) and y = TRUE_W @ x + TRUE_B.

    Both arrays are float32 so they match the dtype of the model's weights.
    """
    x = np.random.rand(2, num_samples).astype(np.float32)
    y = (TRUE_W @ x + TRUE_B).astype(np.float32)
    return x, y


# Generate training data.
x_train, y_train = make_linear_data()


class LinearModel(tf.keras.Model):
    """Linear model computing ``W @ inputs + b`` for inputs of shape (2, N)."""

    def __init__(self):
        super().__init__()
        # Use add_weight rather than bare tf.Variable attributes: Keras 3 does
        # not track bare tf.Variables, which would leave trainable_variables
        # empty and make training a silent no-op.
        self.W = self.add_weight(name='weight', shape=(1, 2), initializer='zeros', trainable=True)
        self.b = self.add_weight(name='bias', shape=(1,), initializer='zeros', trainable=True)

    def call(self, inputs):
        """Return predictions of shape (1, N) for ``inputs`` of shape (2, N)."""
        return tf.matmul(self.W, inputs) + self.b


model = LinearModel()

# Loss function and optimizer (both from tf.keras, matching the model).
loss_object = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.SGD(learning_rate=0.5)


@tf.function
def train_step(inputs, outputs):
    """Run one SGD update on ``(inputs, outputs)`` and return the batch loss."""
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(outputs, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss


# Train the model, recording the parameters every LOG_EVERY steps.
m = 101
W_temp, b_temp = [], []
for step in range(m):
    loss = train_step(x_train, y_train)
    if step % LOG_EVERY == 0:
        W_now, b_now = model.W.numpy(), model.b.numpy()
        print("Step {:02d}: Loss {:.3f}, W {}, b {}".format(
            step, float(loss), W_now, b_now))
        W_temp.append(W_now[0])
        b_temp.append(b_now[0])

# Plot how the parameters evolved during training.
fig, ax = plt.subplots()
logged_steps = np.arange(0, m, LOG_EVERY)  # the steps at which W_temp/b_temp were recorded
ax.plot(logged_steps, [w[0] for w in W_temp], marker='o', label='W1')
ax.plot(logged_steps, [w[1] for w in W_temp], marker='*', label='W2')
ax.plot(logged_steps, b_temp, marker='^', label='b')
ax.legend()
ax.set_xlabel('step')
ax.set_ylabel('value')
ax.set_title('Parameter')

# Generate test data.
x_test, y_test = make_linear_data()

# Evaluate the model: the fraction of test points predicted within TOLERANCE.
# A boolean mask is used directly instead of a sentinel value, so no error
# magnitude can be confused with "correct".
y_pred = model(x_test)
abs_err = np.abs(y_pred.numpy()[0] - y_test)
accuracy = np.mean(abs_err <= TOLERANCE)
print('accuracy:', accuracy)

# Visualize the test results in 3D.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
xs, ys = x_test[0], x_test[1]
zs1, zs2 = y_test, y_pred.numpy()[0]
ax.scatter(xs, ys, zs1, c='r', marker='v')  # ground truth
ax.scatter(xs, ys, zs2, c='b', marker='^')  # predictions
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# Show the figures.
plt.show()