第一个TF程序:评估模型
来自CloudWiki
- 那么我们的模型性能如何呢?
- 首先让我们找出那些预测正确的标签。tf.argmax 是一个非常有用的函数,它能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表示匹配)。
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
- 这行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,[True, False, True, True] 会变成 [1,0,1,1] ,取平均值后得到 0.75.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
- 最后,我们计算所学习到的模型在测试数据集上面的正确率。
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})
- 这个最终结果值应该大约是91%。
- 这个结果好吗?嗯,并不太好。事实上,这个结果是很差的。这是因为我们仅仅使用了一个非常简单的模型。不过,做一些小小的改进,我们就可以得到97%的正确率。最好的模型甚至可以获得超过99.7%的准确率!(想了解更多信息,可以看看这个关于各种模型的性能对比列表。)
- 比结果更重要的是,我们从这个模型中学习到的设计思想。不过,如果你仍然对这里的结果有点失望,可以查看下一个教程,在那里你可以学习如何用FensorFlow构建更加复杂的模型以获得更好的性能!
源代码
- 以下是我们第一个tensorflow程序的源代码:
#inport Minst dataset import input_data mnist = input_data.read_data_sets("data0",one_hot=True) import tensorflow as tf x = tf.placeholder("float", [None, 784]) W = tf.Variable(tf.zeros([784,10])) b = tf.Variable(tf.zeros([10])) y = tf.nn.softmax(tf.matmul(x,W) + b)
# the correct value: y_ = tf.placeholder("float",[None,10]) #the cross_entropy: cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#backpropagation algorithm train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
#initialize variables init = tf.initialize_all_variables()
#start our model sess = tf.Session() sess.run(init)
#train our model for i in range(1000): batch_xs,batch_ys = mnist.train.next_batch(100) sess.run(train_step,feed_dict={x:batch_xs,y_: batch_ys})
#evaluate our model correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
返回 人工智能
参考文档:
- [1] MNIST机器学习入门 http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
- [2] TensorFlow下MNIST数据集下载脚本input_data.py http://blog.csdn.net/lwplwf/article/details/54896959