# Liu/ne_ZL_TF2.py
#
# Fits the line y = 2x + 1.6 with a one-weight, one-bias model trained by
# gradient descent in TensorFlow 2.

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Build the synthetic training set: 100 uniform samples in [0, 1) cast to
# float32 as inputs, with targets taken from the exact line y = 2x + 1.6
# (no noise is added).
x_data = np.random.rand(100).astype(np.float32)
y_data = 2 * x_data + 1.6
# Trainable parameters of the model y = W * x + b: shape-(1,) float32
# variables, both starting at zero.
weights = tf.Variable(tf.zeros(shape=(1,), dtype=tf.float32), name='weights')
biases = tf.Variable(tf.zeros(shape=(1,), dtype=tf.float32), name='biases')
# Model definition.
def linear_model(x):
    """Return the linear prediction ``weights * x + biases`` for input ``x``.

    Reads the module-level variables ``weights`` and ``biases``. Because the
    arithmetic involves ``tf.Variable`` objects, the result is a TensorFlow
    tensor that ``tf.GradientTape`` can differentiate through.
    """
    scaled = weights * x
    return scaled + biases
# Loss function.
def loss_fn(y_true, y_pred):
    """Return the mean squared error between ``y_true`` and ``y_pred``."""
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)
# Plain stochastic-gradient-descent optimizer with a learning rate of 0.5.
optimizer = tf.optimizers.SGD(learning_rate=0.5)
# Training step.
def train_step(x, y):
    """Apply one gradient-descent update to ``weights`` and ``biases``.

    Args:
        x: model inputs, fed to ``linear_model``.
        y: target values corresponding to ``x``.

    Returns:
        The loss computed before this update was applied, so callers can
        monitor convergence. Existing callers that ignore the return value are
        unaffected.
    """
    trainable = [weights, biases]
    # Record the forward pass so the tape can differentiate the loss.
    with tf.GradientTape() as tape:
        loss = loss_fn(y, linear_model(x))
    gradients = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(gradients, trainable))
    return loss
# Train for 201 steps (0..200), printing the current parameters every 20 steps.
for step in range(201):
    train_step(x_data, y_data)
    if step % 20 == 0:
        print(f"Step {step}: weights = {weights.numpy()}, biases = {biases.numpy()}")

# Plot (optional): the data points in red and the fitted line in green.
# Sort the inputs so the line is drawn left to right rather than through the
# samples in random order, and convert the model output to a NumPy array
# explicitly instead of relying on matplotlib to coerce a tf.Tensor.
x_line = np.sort(x_data)
plt.scatter(x_data, y_data, c='r')
plt.plot(x_line, linear_model(x_line).numpy(), c='g')
plt.show()