はじめに
Tensorflowを使う際にコードによって若干の違いが見られたのでその点を理解しておきたいと思います。
- run() と eval()
- InteractiveSession() と Session()
この2点に違いについて説明します。
run() vs eval()
例えば、以下のような簡単なMLPの実装の一部を見て下さい。
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=t, logits=h_fc)) train_step = tf.train.GradientDescentOptimizer(0.1).minimize(cost) correct_prediction = tf.equal(tf.argmax(h_fc, 1), tf.argmax(t, 1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) sess = tf.InteractiveSession() init = tf.global_variables_initializer() sess.run(init) n_epochs = 10 batch_size = 100 n_batches = train_X.shape[0] // batch_size train_X, train_y = shuffle(train_X, train_y) for epoch in range(n_epochs): for i in range(n_batches): start = i * batch_size end = start + batch_size train_step.run(feed_dict={x: train_X[start:end], t: train_y[start:end]}) train_accuracy = accuracy.eval(feed_dict={x: valid_X, t: valid_y}) print("EPOCH::%i, training_accuracy %g" % (epoch+1, train_accuracy)) print("test accuracy %g" % accuracy.eval(feed_dict={x: mnist.test.images, t: mnist.test.labels})) sess.close()
上のコードの中で、例えば、
train_step.run(feed_dict={x: train_X[start:end], t: train_y[start:end]})
の部分では run
が使われているのに
accuracy.eval(feed_dict={x: valid_X, t: valid_y})
の部分では eval
が使われているじゃないですか。
eval
と run
って何が違うのでしょうか?
上記のAnswerとして以下のようにあります。
op.run() is a shortcut for calling tf.get_default_session().run(op) t.eval() is a shortcut for calling tf.get_default_session().run(t)
ここでいう tf.get_default_session().run()
は
sess = tf.Session() sess.run()
の sess = tf.get_default_session
と考えれば分かりやすいと思います。
じゃ、「結局どっちも同じじゃん」って感じですけど、run
は Operation
クラスで eval
は Tensor
クラスに属するのでオブジェクトに応じてメソッドを変える必要があるということです。これが結論です。
ここで、「あれれ、じゃあ sess.run
ってどういうやつだっけ?」ともなっているかもしれません。次で説明します。
InteractiveSession() vs Session()
InteractiveSession() がTensorflowの公式に載っていました。Session()に対してInteractiveSession() は何が違うのでしょうか?
A TensorFlow Session for use in interactive contexts, such as a shell. The only difference with a regular Session is that an InteractiveSession installs itself as the default session on construction. The methods Tensor.eval() and Operation.run() will use that session to run ops. This is convenient in interactive shells and IPython notebooks, as it avoids having to pass an explicit Session object to run ops.
つまり、 InteractiveSessonを使うとsess = Session()
のようにして指定したsessを明示的に指定しなくてもよくなるよ、ということです。IPython notebookで使うときとかに便利だということですね。
examples/faq.md at master · tensorflow/examples · GitHub
以下は上記リンク先のコード例です。わざわざ sess.run()
のような記述はいらなくなります。
sess = tf.InteractiveSession() a = tf.constant(5.0) b = tf.constant(6.0) c = a * b # We can just use 'c.eval()' without passing 'sess' print(c.eval()) sess.close()
ちなみに with
公文を使えば、tf.Session()
を使っても同様の記述が出来るようです。こちらも上記リンク先のコードです。
a = tf.constant(5.0) b = tf.constant(6.0) c = a * b with tf.Session(): # We can also use 'c.eval()' here. print(c.eval())
sess.run
を使って書いてみます。
a = tf.constant(5.0) b = tf.constant(6.0) c = a * b sess = tf.Session() sess.run(c) sess.close()
以上で違いが理解出来たのではないでしょうか?
おわりに
IPython notebookを使うときはInteractiveSession
が便利のような気もしますが(ちょっと楽)、Session
は run
のメソッドで(eval
とごちゃごちゃにならず)統一的に書けるので良いなと思ったりもしました。