第一个TF程序:MNIST案例简介

来自CloudWiki
(重定向自MNIST案例简介
跳转至: 导航搜索

这个教程的目标读者是对机器学习和TensorFlow都不太了解的新手。

当我们开始学习编程的时候,第一件事往往是学习打印"Hello World"。就好比编程入门有Hello World,机器学习入门有MNIST。

MNIST是一个入门级的计算机视觉数据集,它包含各种手写数字图片:

  • MNIST.png

它也包含每一张图片对应的标签,告诉我们这个是数字几。比如,上面这四张图片的标签分别是5,0,4,1。

在此教程中,我们将训练一个机器学习模型用于预测图片里面的数字。我们的目的不是要设计一个世界一流的复杂模型 -- 尽管我们会在之后给你源代码去实现一流的预测模型 -- 而是要介绍下如何使用TensorFlow。所以,我们这里会从一个很简单的数学模型开始,它叫做Softmax Regression。

对应这个教程的实现代码很短,而且真正有意思的内容只包含在三行代码里面。但是,去理解包含在这些代码里面的设计思想是非常重要的:TensorFlow工作流程和机器学习的基本概念。因此,这个教程会很详细地介绍这些代码的实现原理。

准备工作

  • 在下载数据之前,需要在你的Linux系统中安装python-numpy、python-six、python-pip、python-dev、tensorflow等包。
  28  sudo apt-get install -y python-numpy
  34  sudo apt-get install -y python-six
  37  sudo apt-get install python-pip python-dev
  41  sudo pip install --upgrade https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-0.8.0-cp27-none-linux_x86_64.whl

参考文档:[ubuntu/linux系统下tensorflow的安装 https://jingyan.baidu.com/article/455a99504fb489a1662778b8.html]

下载数据

  • MNIST数据集的官网是Yann LeCun's website。在这里,我们提供了一份python源代码用于自动下载和安装这个数据集。
  • 源代码在GitHub上,下面也贴出源码:
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tempfile
import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets

使用方法

  • 将以上代码存为input_data.py文件
  • 新建test.py文件,与input_data.py在同一工程目录下
  • test.py内容如下:
#下载用于训练和测试的mnist数据集的源码
import input_data # 调用input_data
mnist = input_data.read_data_sets('data/', one_hot=True)
  • 运行test.py文件:
python test.py
  • 这时可以看到在同一目录下多了一个data目录,这就是下载的数据集:
(tensorflow) cloud@ubuntu:~/ai$ ls
data  hello.py  input_data.py  input_data.pyc  test.py
(tensorflow) cloud@ubuntu:~/ai$ ls data
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz


MNIST数据集

下载下来的数据集被分成两部分:60000行的训练数据集(mnist.train)和10000行的测试数据集(mnist.test)。这样的切分很重要,在机器学习模型设计时必须有一个单独的测试数据集不用于训练而是用来评估这个模型的性能,从而更加容易把设计的模型推广到其他数据集上(泛化)。

正如前面提到的一样,每一个MNIST数据单元有两部分组成:一张包含手写数字的图片和一个对应的标签。我们把这些图片设为“xs”,把这些标签设为“ys”。训练数据集和测试数据集都包含xs和ys,比如训练数据集的图片是 mnist.train.images ,训练数据集的标签是 mnist.train.labels。

训练集中的每一张图片包含28像素X28像素。我们可以用一个数字数组来表示这张图片:

T1-5.png

跟我们用手机拍摄的图片类似,这些小图片拥有28x28 = 784个像素,或784个向量空间。从这个角度来看,MNIST数据集的图片就是在784维向量空间里面的点。

因此,在MNIST训练数据集中,mnist.train.images 是一个形状为 [60000, 784] 的张量,第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。在此张量里的每一个元素,都表示某张图片里的某个像素的强度值,值介于0和1之间。

T1-6.png

相对应的MNIST数据集的标签是数字,用来描述给定图片里的数字是几?可选的值是0~9。比如,标签0将表示成([1,0,0,0,0,0,0,0,0,0,0]),标签9将表示成([0,0,0,0,0,0,0,0,0,0,1])。因此, mnist.train.labels 是一个 [60000, 10] 的数字矩阵。

T1-7.png

好了,我们准备好可以开始构建我们的模型啦!

下一节 Softmax回归介绍


参考文档: