Liu/ANN_mnist.py

51 lines
2.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#===============================================
# 训练简单的神经网络,并显示运行时间 (train a simple single-layer network and report the run time)
# 数据集: MNIST (dataset: MNIST)
#===============================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime

# Taken before the TensorFlow import, so the reported total covers importing,
# data loading, training and evaluation.
starttime = datetime.datetime.now()

import tensorflow as tf
import numpy as np

# Import data
from tensorflow.examples.tutorials.mnist import input_data

flags = tf.app.flags
FLAGS = flags.FLAGS
# Directory the MNIST data set is read from.
flags.DEFINE_string('data_dir', '/learn/tensorflow/python/data/', 'Directory for storing data')
# one_hot=True: labels come back as one-hot vectors, matching the [None, 10]
# shape of the `y` placeholder below.
mnist_data = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

# ---- Model: single-layer softmax regression ---------------------------------
x = tf.placeholder(tf.float32, [None, 784])  # batch of flattened images, 784 features each
y = tf.placeholder(tf.float32, [None, 10])   # one-hot ground-truth labels
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b
a = tf.nn.softmax(logits)  # class probabilities; used for prediction below

# ---- Loss and optimizer ------------------------------------------------------
# Mean cross-entropy over the batch, computed from the logits with the fused op
# instead of -sum(y * log(softmax)).  With the naive form, a softmax output that
# underflows to 0 makes tf.log return -inf and the loss/gradients turn into NaN.
# NOTE: TF >= 1.5 logs a deprecation notice suggesting the _v2 variant; the
# result is the same here because `y` is a placeholder (no label gradients).
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))
# Learning rate 0.3 suits the *mean* loss above.  A loss summed over the batch
# scales the gradients by the batch size (100) and needs a much smaller rate
# (on the order of 0.001).
optimizer = tf.train.GradientDescentOptimizer(0.3)
train_next = optimizer.minimize(cross_entropy)  # training op: minimize the loss

# ---- Train --------------------------------------------------------------------
# InteractiveSession installs itself as the default session, so
# `op.run(...)` / `tensor.eval(...)` would work as well as `sess.run(...)`.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
for _ in range(1000):
    # Next mini-batch of 100 training examples.
    batch_xs, batch_ys = mnist_data.train.next_batch(100)
    sess.run(train_next, feed_dict={x: batch_xs, y: batch_ys})

# ---- Test ---------------------------------------------------------------------
correct_prediction = tf.equal(tf.argmax(a, 1), tf.argmax(y, 1))
# tf.equal yields booleans; cast to float32 so reduce_mean gives the fraction of
# correct predictions (a narrower type such as float16 would lose precision).
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist_data.test.images, y: mnist_data.test.labels}))

endtime = datetime.datetime.now()
print('total time endtime-starttime:', endtime - starttime)