# try your best to do it!
# @Time : 2022/9/24 21:55
# @Author : LianghengZhang
# @File : MINIST_Classsification.py
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

import matplotlib

# Select the TkAgg backend before pyplot is imported, so pyplot starts on the
# requested backend instead of switching after the fact.
matplotlib.use('TkAgg')

import matplotlib.pyplot as plt  # noqa: E402  (must follow matplotlib.use)
def _show_samples(images, labels, rows=2, cols=10):
    """Plot the first ``rows * cols`` images in a grid, captioned with their labels.

    Args:
        images: Indexable sequence of 2-D images. Here these are the normalised
            MNIST training images before the channel axis is added.
        labels: Indexable sequence of labels aligned with ``images``.
        rows: Number of grid rows.
        cols: Number of grid columns.
    """
    plt.figure(figsize=(20, 10))
    for i in range(rows * cols):
        plt.subplot(rows, cols, i + 1)
        # Hide ticks and grid so only the digit itself is shown.
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(images[i], cmap=plt.cm.binary)
        plt.xlabel(labels[i])
    plt.show()


def _build_model():
    """Build the small CNN classifier for 28x28 single-channel images.

    The network has two Conv2D + MaxPooling2D stages, then Flatten, a 64-unit
    ReLU Dense layer and a 10-unit output layer. The output layer has no
    activation, so the model emits raw logits (one per class).
    """
    return models.Sequential([
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ])


def main():
    """Load MNIST, preview some samples, then build and train the CNN.

    Returns:
        The ``History`` object produced by ``model.fit``.
    """
    (train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

    # Scale pixel values from [0, 255] to [0, 1]. Casting to float32 first
    # avoids materialising float64 arrays (twice the memory) that plain
    # `/ 255.0` on integer arrays would produce.
    train_images = train_images.astype('float32') / 255.0
    test_images = test_images.astype('float32') / 255.0

    # Inspect the data dimensions.
    print("data_shape:", train_images.shape, test_images.shape, train_labels.shape, test_labels.shape)
    # Expected: (60000, 28, 28) (10000, 28, 28) (60000,) (10000,)

    _show_samples(train_images, train_labels)

    # Reshape the data into the format the model needs: Conv2D expects a
    # trailing channel axis, (N, 28, 28) -> (N, 28, 28, 1). Using -1 lets
    # NumPy infer N instead of hard-coding the split sizes.
    train_images = train_images.reshape((-1, 28, 28, 1))
    test_images = test_images.reshape((-1, 28, 28, 1))

    print("data_shape:", train_images.shape, test_images.shape, train_labels.shape, test_labels.shape)
    # Expected: (60000, 28, 28, 1) (10000, 28, 28, 1) (60000,) (10000,)

    model = _build_model()
    model.summary()

    model.compile(
        optimizer='adam',
        # The last Dense layer has no softmax, so the loss receives logits.
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy'],
    )

    # Report loss/accuracy on the held-out test split after every epoch, so
    # the prepared test data is actually used to measure generalisation.
    history = model.fit(
        train_images,
        train_labels,
        epochs=10,
        validation_data=(test_images, test_labels),
    )
    return history


if __name__ == '__main__':
    main()