読者です 読者をやめる 読者になる 読者になる

真・全力失踪

全力で道を見失うブログ

TensorFlowのチュートリアルスクリプトを理解する(続編)

翌々調べてみると
tensorflow/example/tutorial/mnist/fully_connected_feed.py
ちゃんとしたチュートリアルが存在していた。

色々とラップした関数があったけど、
今は一通り処理を追い駆けたいので全部展開してみました。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange

import time
import math
import tensorflow as tf

モジュールをインポートしてます。

from tensorflow.examples.tutorials.mnist import input_data
NUM_CLASSES = 10
IMAGE_SIZE = 28
IMAGE_PIXELS = IMAGE_SIZE * IMAGE_SIZE

MNISTのデータを使うので、ここら辺もインポートしてます。

flags = tf.app.flags
flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
flags.DEFINE_integer('max_steps', 2000, 'Number of steps to run trainer.')
flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
flags.DEFINE_integer('batch_size', 100, 'Batch size. Must divide evenly into the dataset sizes.')
flags.DEFINE_string('train_dir', 'data', 'Directory to put the training data.')
flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data for unit testing.')
FLAGS = flags.FLAGS

基本的なパラメータをFLAGとしてセットしてます。

def do_eval(sess, eval_correct, images_placeholder, labels_placeholder, data_set):
  true_count = 0  # Counts the number of correct predictions.
  steps_per_epoch = data_set.num_examples // FLAGS.batch_size
  num_examples = steps_per_epoch * FLAGS.batch_size
  for step in xrange(steps_per_epoch):
    images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size, FLAGS.fake_data)
    feed_dict = {
      images_placeholder: images_feed,
      labels_placeholder: labels_feed,
    }
    true_count += sess.run(eval_correct, feed_dict=feed_dict)
  precision = true_count / num_examples
  print('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' % (num_examples, true_count, precision))

画像識別率を評価する関数ですね。
演算子「//」は切り捨て除算だそうです。
「FLAGS.batch_size」のような形式で、先ほど設定したパラメータを取り出せるっぽいです。
「data_set.next_batch」は、データセットから指定したサイズのDNNの入力信号と教師信号を取り出してタプルで返してくれる便利メソッドです。
あと本筋とは離れますがxrangeは遅延評価されるレンジオブジェクトとのこと(参考)

def run_training():
  data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)

  with tf.Graph().as_default():
    ### Generate placeholders for the images and labels.
    images_placeholder = tf.placeholder(tf.float32, shape=(FLAGS.batch_size, IMAGE_PIXELS))
    labels_placeholder = tf.placeholder(tf.int32, shape=(FLAGS.batch_size))

訓練データのためのメモリを確保してますね。
imagesはDNNの入力ベクトル、labelsはDNNの教師信号のインデックス、
一度に食わせるデータの数はbatch_sizeです。

    #### Build a Graph that computes predictions from the inference model.
    # Hidden 1
    with tf.name_scope('hidden1'):
      weights = tf.Variable(tf.truncated_normal(
                                                [IMAGE_PIXELS, FLAGS.hidden1],
                                                stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))
                                               ), name='weights')
      biases = tf.Variable(tf.zeros([FLAGS.hidden1]), name='biases')
      hidden1 = tf.nn.relu(tf.matmul(images_placeholder, weights) + biases)

    # Hidden 2
    with tf.name_scope('hidden2'):
      weights = tf.Variable(tf.truncated_normal(
                                                [FLAGS.hidden1, FLAGS.hidden2],
                                                stddev=1.0 / math.sqrt(float(FLAGS.hidden1))
                                               ), name='weights')
      biases = tf.Variable(tf.zeros([FLAGS.hidden2]), name='biases')
      hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)

ここで出力層以外のネットワークを構築してます。
tf.name_scope('hidden1')のwithブロックの中で定義したtf.Variableは、
'hidden1'というコンテキストの中でのみ有効ということみたいです。
あと、チュートリアルのレベルで活性化関数に「ReLU」が使われているとは驚きでした。

    # Linear
    with tf.name_scope('softmax_linear'):
      weights = tf.Variable(tf.truncated_normal(
                                                [FLAGS.hidden2, NUM_CLASSES],
                                                stddev=1.0 / math.sqrt(float(FLAGS.hidden2))
                                               ), name='weights')
      biases = tf.Variable(tf.zeros([NUM_CLASSES]), name='biases')
      logits = tf.matmul(hidden2, weights) + biases

    labels_placeholder = tf.to_int64(labels_placeholder)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
                                                                   logits,
                                                                   labels_placeholder,
                                                                   name='xentropy'
                                                                  )
    loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

なぜここでint64にキャストしてるのかはちょっと謎ですが、
グラフに損失を計算するOperationを追加しています。
ところでlogitsってなんだろうと思って調べてみると、
ロジスティック関数の逆関数だそうです。
活性化関数がReLUの場合でもロジットになるのだろうか…?
tf.nn.sparse_softmax_cross_entropy_with_logitsは、
logitsとlabelベクトル間のソフトマックス・クロスエントロピーを計算する関数ですね。

続きはまた今度。

    ### Add to the Graph the Ops that calculate and apply gradients.
    # Add a scalar summary for the snapshot loss.
    tf.scalar_summary(loss.op.name, loss)

    # Create the gradient descent optimizer with the given learning rate.
    optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)

    # Create a variable to track the global step.
    global_step = tf.Variable(0, name='global_step', trainable=False)

    # Use the optimizer to apply the gradients that minimize the loss
    # (and also increment the global step counter) as a single training step.
    train_op = optimizer.minimize(loss, global_step=global_step)


    # Add the Op to compare the logits to the labels during evaluation.
    correct = tf.nn.in_top_k(logits, labels_placeholder, 1)
    eval_correct = tf.reduce_sum(tf.cast(correct, tf.int32))

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.merge_all_summaries()

    # Add the variable initializer Op.
    init = tf.initialize_all_variables()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    # Create a session for running Ops on the Graph.
    sess = tf.Session()

    # Instantiate a SummaryWriter to output summaries and the Graph.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir, sess.graph)

    # And then after everything is built:

    # Run the Op to initialize the variables.
    sess.run(init)

    # Start the training loop.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()

      # Fill a feed dictionary with the actual set of images and labels
      # for this particular training step.
      images_feed, labels_feed = data_sets.train.next_batch(FLAGS.batch_size, FLAGS.fake_data)
      feed_dict = {
        images_placeholder: images_feed,
        labels_placeholder: labels_feed,
      }

      # Run one step of the model.  The return values are the activations
      # from the `train_op` (which is discarded) and the `loss` Op.  To
      # inspect the values of your Ops or variables, you may include them
      # in the list passed to sess.run() and the value tensors will be
      # returned in the tuple from the call.
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      duration = time.time() - start_time

      # Write the summaries and print an overview fairly often.
      if step % 100 == 0:
        # Print status to stdout.
        print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))
        # Update the events file.
        summary_str = sess.run(summary_op, feed_dict=feed_dict)
        summary_writer.add_summary(summary_str, step)
        summary_writer.flush()

      # Save a checkpoint and evaluate the model periodically.
      if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, FLAGS.train_dir, global_step=step)
        # Evaluate against the training set.
        print('Training Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.train)
        # Evaluate against the validation set.
        print('Validation Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.validation)
        # Evaluate against the test set.
        print('Test Data Eval:')
        do_eval(sess,
                eval_correct,
                images_placeholder,
                labels_placeholder,
                data_sets.test)


def main(_):
  run_training()


if __name__ == '__main__':
  tf.app.run()