51 lines
2.2 KiB
Python
51 lines
2.2 KiB
Python
|
#===============================================
|
|||
|
# 训练简单的神经网络,并显示运行时间
|
|||
|
# 数据集:mnnist
|
|||
|
#===============================================
|
|||
|
from __future__ import absolute_import
|
|||
|
from __future__ import division
|
|||
|
from __future__ import print_function
|
|||
|
|
|||
|
import datetime
|
|||
|
starttime = datetime.datetime.now()
|
|||
|
|
|||
|
import tensorflow as tf
|
|||
|
import numpy as np
|
|||
|
|
|||
|
# Import data
|
|||
|
from tensorflow.examples.tutorials.mnist import input_data
|
|||
|
flags = tf.app.flags
|
|||
|
FLAGS = flags.FLAGS
|
|||
|
flags.DEFINE_string('data_dir', '/learn/tensorflow/python/data/', 'Directory for storing data') # 把数据放在/data文件夹中
|
|||
|
mnist_data = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) # 读取数据集
|
|||
|
|
|||
|
# 建立抽象模型
|
|||
|
x = tf.placeholder(tf.float32, [None, 784]) # 占位符
|
|||
|
y = tf.placeholder(tf.float32, [None, 10])
|
|||
|
W = tf.Variable(tf.zeros([784, 10]))
|
|||
|
b = tf.Variable(tf.zeros([10]))
|
|||
|
a = tf.nn.softmax(tf.matmul(x, W) + b)
|
|||
|
|
|||
|
# 定义损失函数和训练方法
|
|||
|
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y * tf.log(a), reduction_indices=[1])) # 损失函数为交叉熵,学习速率要设为0.3量级
|
|||
|
#cross_entropy = -tf.reduce_sum(y * tf.log(a)) # 损失函数为交叉熵
|
|||
|
optimizer = tf.train.GradientDescentOptimizer(0.3) # 梯度下降法,学习速率要设为0.001量级
|
|||
|
train_next = optimizer.minimize(cross_entropy) # 训练目标:最小化损失函数
|
|||
|
|
|||
|
# Train
|
|||
|
sess = tf.InteractiveSession() # 建立交互式会话
|
|||
|
# tf.global_variables_initializer().run()
|
|||
|
sess.run(tf.global_variables_initializer())
|
|||
|
for i in range(1000):
|
|||
|
batch_xs, batch_ys = mnist_data.train.next_batch(100) # 随机抓取100个数据
|
|||
|
# train_next.run({x: batch_xs, y: batch_ys})
|
|||
|
sess.run(train_next, feed_dict={x: batch_xs, y: batch_ys})
|
|||
|
|
|||
|
#测试
|
|||
|
correct_prediction = tf.equal(tf.argmax(a, 1), tf.argmax(y, 1))
|
|||
|
# tf.cast先将数据转换成float,防止求平均不准确:比如 tf.float32就是正确,写成tf.float16导致不准确,超出范围。
|
|||
|
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
|
|||
|
print(sess.run(accuracy,feed_dict={x:mnist_data.test.images,y:mnist_data.test.labels}))
|
|||
|
|
|||
|
endtime=datetime.datetime.now()
|
|||
|
print('total time endtime-starttime:', endtime-starttime)
|