• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for basic_loops.py."""
16
17import os
18import shutil
19
20from tensorflow.python.framework import errors_impl
21from tensorflow.python.framework import ops
22from tensorflow.python.platform import test
23from tensorflow.python.training import basic_loops
24from tensorflow.python.training import supervisor
25
26
def _test_dir(test_name):
  """Returns the path `<test temp dir>/<test_name>`, removing any old copy.

  If something already exists at that path (e.g. a log directory left over
  from an earlier run), it is deleted with `shutil.rmtree` before the path is
  returned, so callers start from a clean log directory.

  Args:
    test_name: Name of the sub-directory to use under the test temp dir.

  Returns:
    The joined path as a string.
  """
  path = os.path.join(test.get_temp_dir(), test_name)
  if not os.path.exists(path):
    return path
  # Stale directory from a previous run: wipe it so the test starts fresh.
  shutil.rmtree(path)
  return path
32
33
class BasicTrainLoopTest(test.TestCase):
  """Tests for `basic_loops.basic_train_loop`."""

  def testBasicTrainLoop(self):
    """Step function gets the forwarded args/kwargs; loop stops on request."""
    logdir = _test_dir("basic_train_loop")
    # Counts the number of calls; a one-element list so the closure can
    # mutate it.
    num_calls = [0]

    def train_fn(unused_sess, sv, y, a):
      num_calls[0] += 1
      self.assertEqual("y", y)
      self.assertEqual("A", a)
      # Ask the supervisor to stop after the third step.
      if num_calls[0] == 3:
        sv.request_stop()

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      basic_loops.basic_train_loop(
          sv, train_fn, args=(sv, "y"), kwargs={"a": "A"})
      self.assertEqual(3, num_calls[0])

  def testBasicTrainLoopExceptionAborts(self):
    """A RuntimeError raised by the step function propagates to the caller."""
    logdir = _test_dir("basic_train_loop_exception_aborts")

    def train_fn(unused_sess):
      train_fn.counter += 1
      if train_fn.counter == 3:
        raise RuntimeError("Failed")

    # Function attribute used to count the number of calls.
    train_fn.counter = 0

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      with self.assertRaisesRegex(RuntimeError, "Failed"):
        basic_loops.basic_train_loop(sv, train_fn)
      # The error surfaced on the third step and no further steps ran.
      self.assertEqual(3, train_fn.counter)

  def testBasicTrainLoopRetryOnAborted(self):
    """AbortedError is retried; a later RuntimeError still propagates."""
    # Distinct from the directory used by testBasicTrainLoopExceptionAborts:
    # _test_dir deletes any existing directory, so sharing a name would let
    # the two tests clobber each other's log directory.
    logdir = _test_dir("basic_train_loop_retry_on_aborted")

    class AbortAndRetry:
      """Step function that raises AbortedError until its retries run out."""

      def __init__(self):
        self.num_calls = 0
        self.retries_left = 2

      def train_fn(self, unused_sess):
        self.num_calls += 1
        # Calls 2, 5, 8, ... each consume one retry. With 2 retries, calls
        # 1-4 raise AbortedError and call 5 raises RuntimeError.
        if self.num_calls % 3 == 2:
          self.retries_left -= 1
        if self.retries_left > 0:
          raise errors_impl.AbortedError(None, None, "Aborted here")
        else:
          raise RuntimeError("Failed Again")

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      aar = AbortAndRetry()
      with self.assertRaisesRegex(RuntimeError, "Failed Again"):
        basic_loops.basic_train_loop(sv, aar.train_fn)
      self.assertEqual(0, aar.retries_left)
      # Four aborted attempts were retried before the fifth call failed.
      self.assertEqual(5, aar.num_calls)
94
95
if __name__ == "__main__":
  # Run this module's tests via the test entry point imported from
  # tensorflow.python.platform.
  test.main()
98