# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for basic_loops.py."""

import os
import shutil

from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.training import basic_loops
from tensorflow.python.training import supervisor


def _test_dir(test_name):
  """Returns a path under the test temp dir, removing anything already there.

  Args:
    test_name: Name of the subdirectory. It should be unique per test so that
      tests do not share (and wipe) each other's logdir.

  Returns:
    The path `<test temp dir>/<test_name>`. Any existing directory at that path
    is deleted first; the directory itself is not created here.
  """
  test_dir = os.path.join(test.get_temp_dir(), test_name)
  if os.path.exists(test_dir):
    shutil.rmtree(test_dir)
  return test_dir


class BasicTrainLoopTest(test.TestCase):
  """Tests for `basic_loops.basic_train_loop`."""

  def testBasicTrainLoop(self):
    """Positional/keyword args are forwarded and request_stop() ends the loop."""
    logdir = _test_dir("basic_train_loop")
    # Counts the number of calls. A one-element list lets the closure mutate it.
    num_calls = [0]

    def train_fn(unused_sess, sv, y, a):
      num_calls[0] += 1
      # `y` comes from `args`, `a` from `kwargs` of basic_train_loop.
      self.assertEqual("y", y)
      self.assertEqual("A", a)
      if num_calls[0] == 3:
        sv.request_stop()

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      basic_loops.basic_train_loop(
          sv, train_fn, args=(sv, "y"), kwargs={"a": "A"})
      self.assertEqual(3, num_calls[0])

  def testBasicTrainLoopExceptionAborts(self):
    """A non-Aborted exception from train_fn propagates out of the loop."""
    logdir = _test_dir("basic_train_loop_exception_aborts")

    def train_fn(unused_sess):
      train_fn.counter += 1
      if train_fn.counter == 3:
        raise RuntimeError("Failed")

    # Function attribute used to count the number of calls.
    train_fn.counter = 0

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      with self.assertRaisesRegex(RuntimeError, "Failed"):
        basic_loops.basic_train_loop(sv, train_fn)
      # The loop must stop at the failing call rather than keep stepping.
      self.assertEqual(3, train_fn.counter)

  def testBasicTrainLoopRetryOnAborted(self):
    """AbortedError from train_fn is retried; other errors still propagate."""
    # Use a logdir distinct from testBasicTrainLoopExceptionAborts so the two
    # tests do not share (and rmtree) each other's directory.
    logdir = _test_dir("basic_train_loop_retry_on_aborted")

    class AbortAndRetry:
      """Fails on calls 2, 5, 8, ...: first with AbortedError, then Runtime."""

      def __init__(self):
        self.num_calls = 0
        self.retries_left = 2

      def train_fn(self, unused_sess):
        self.num_calls += 1
        if self.num_calls % 3 == 2:
          self.retries_left -= 1
          if self.retries_left > 0:
            # Expected to be swallowed and retried by basic_train_loop.
            raise errors_impl.AbortedError(None, None, "Aborted here")
          else:
            raise RuntimeError("Failed Again")

    with ops.Graph().as_default():
      sv = supervisor.Supervisor(logdir=logdir)
      aar = AbortAndRetry()
      with self.assertRaisesRegex(RuntimeError, "Failed Again"):
        basic_loops.basic_train_loop(sv, aar.train_fn)
      self.assertEqual(0, aar.retries_left)
      # Call 2 aborted (retried), calls 3-4 succeeded, call 5 raised.
      self.assertEqual(5, aar.num_calls)


if __name__ == "__main__":
  test.main()