1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Coordinator.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import sys 22import threading 23import time 24 25from tensorflow.python.framework import errors_impl 26from tensorflow.python.platform import test 27from tensorflow.python.training import coordinator 28 29 30def StopOnEvent(coord, wait_for_stop, set_when_stopped): 31 wait_for_stop.wait() 32 coord.request_stop() 33 set_when_stopped.set() 34 35 36def RaiseOnEvent(coord, wait_for_stop, set_when_stopped, ex, report_exception): 37 try: 38 wait_for_stop.wait() 39 raise ex 40 except RuntimeError as e: 41 if report_exception: 42 coord.request_stop(e) 43 else: 44 coord.request_stop(sys.exc_info()) 45 finally: 46 if set_when_stopped: 47 set_when_stopped.set() 48 49 50def RaiseOnEventUsingContextHandler(coord, wait_for_stop, set_when_stopped, ex): 51 with coord.stop_on_exception(): 52 wait_for_stop.wait() 53 raise ex 54 if set_when_stopped: 55 set_when_stopped.set() 56 57 58def SleepABit(n_secs, coord=None): 59 if coord: 60 coord.register_thread(threading.current_thread()) 61 time.sleep(n_secs) 62 63 64def WaitForThreadsToRegister(coord, num_threads): 65 while True: 66 with coord._lock: 67 if len(coord._registered_threads) == num_threads: 68 break 69 time.sleep(0.001) 70 71 72class CoordinatorTest(test.TestCase): 73 74 def testStopAPI(self): 75 coord = coordinator.Coordinator() 76 self.assertFalse(coord.should_stop()) 77 self.assertFalse(coord.wait_for_stop(0.01)) 78 coord.request_stop() 79 self.assertTrue(coord.should_stop()) 80 self.assertTrue(coord.wait_for_stop(0.01)) 81 82 def testStopAsync(self): 83 coord = coordinator.Coordinator() 84 self.assertFalse(coord.should_stop()) 85 self.assertFalse(coord.wait_for_stop(0.1)) 86 wait_for_stop_ev = threading.Event() 87 has_stopped_ev = threading.Event() 88 t = threading.Thread( 89 target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)) 90 t.start() 91 self.assertFalse(coord.should_stop()) 92 self.assertFalse(coord.wait_for_stop(0.01)) 93 wait_for_stop_ev.set() 94 has_stopped_ev.wait() 95 self.assertTrue(coord.wait_for_stop(0.05)) 96 self.assertTrue(coord.should_stop()) 97 98 def testJoin(self): 99 coord = coordinator.Coordinator() 100 threads = [ 101 threading.Thread(target=SleepABit, args=(0.01,)), 102 threading.Thread(target=SleepABit, args=(0.02,)), 103 threading.Thread(target=SleepABit, args=(0.01,)) 104 ] 105 for t in threads: 106 t.start() 107 coord.join(threads) 108 for t in threads: 109 self.assertFalse(t.is_alive()) 110 111 def testJoinAllRegistered(self): 112 coord = coordinator.Coordinator() 113 threads = [ 114 threading.Thread(target=SleepABit, args=(0.01, coord)), 115 threading.Thread(target=SleepABit, args=(0.02, coord)), 116 threading.Thread(target=SleepABit, args=(0.01, coord)) 117 ] 118 for t in threads: 119 t.start() 120 WaitForThreadsToRegister(coord, 3) 121 coord.join() 122 for t in threads: 123 self.assertFalse(t.is_alive()) 124 125 def testJoinSomeRegistered(self): 126 coord = coordinator.Coordinator() 127 threads = [ 128 threading.Thread(target=SleepABit, args=(0.01, coord)), 129 threading.Thread(target=SleepABit, args=(0.02,)), 130 threading.Thread(target=SleepABit, args=(0.01, coord)) 131 ] 132 for t in threads: 133 t.start() 134 WaitForThreadsToRegister(coord, 2) 135 # threads[1] is not registered we must pass it in. 136 coord.join([threads[1]]) 137 for t in threads: 138 self.assertFalse(t.is_alive()) 139 140 def testJoinGraceExpires(self): 141 142 def TestWithGracePeriod(stop_grace_period): 143 coord = coordinator.Coordinator() 144 wait_for_stop_ev = threading.Event() 145 has_stopped_ev = threading.Event() 146 threads = [ 147 threading.Thread( 148 target=StopOnEvent, 149 args=(coord, wait_for_stop_ev, has_stopped_ev)), 150 threading.Thread(target=SleepABit, args=(10.0,)) 151 ] 152 for t in threads: 153 t.daemon = True 154 t.start() 155 wait_for_stop_ev.set() 156 has_stopped_ev.wait() 157 with self.assertRaisesRegexp(RuntimeError, "threads still running"): 158 coord.join(threads, stop_grace_period_secs=stop_grace_period) 159 160 TestWithGracePeriod(1e-10) 161 TestWithGracePeriod(0.002) 162 TestWithGracePeriod(1.0) 163 164 def testJoinWithoutGraceExpires(self): 165 coord = coordinator.Coordinator() 166 wait_for_stop_ev = threading.Event() 167 has_stopped_ev = threading.Event() 168 threads = [ 169 threading.Thread( 170 target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)), 171 threading.Thread(target=SleepABit, args=(10.0,)) 172 ] 173 for t in threads: 174 t.daemon = True 175 t.start() 176 wait_for_stop_ev.set() 177 has_stopped_ev.wait() 178 coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True) 179 180 def testJoinRaiseReportExcInfo(self): 181 coord = coordinator.Coordinator() 182 ev_1 = threading.Event() 183 ev_2 = threading.Event() 184 threads = [ 185 threading.Thread( 186 target=RaiseOnEvent, 187 args=(coord, ev_1, ev_2, RuntimeError("First"), False)), 188 threading.Thread( 189 target=RaiseOnEvent, 190 args=(coord, ev_2, None, RuntimeError("Too late"), False)) 191 ] 192 for t in threads: 193 t.start() 194 195 ev_1.set() 196 197 with self.assertRaisesRegexp(RuntimeError, "First"): 198 coord.join(threads) 199 200 def testJoinRaiseReportException(self): 201 coord = coordinator.Coordinator() 202 ev_1 = threading.Event() 203 ev_2 = threading.Event() 204 threads = [ 205 threading.Thread( 206 target=RaiseOnEvent, 207 args=(coord, ev_1, ev_2, RuntimeError("First"), True)), 208 threading.Thread( 209 target=RaiseOnEvent, 210 args=(coord, ev_2, None, RuntimeError("Too late"), True)) 211 ] 212 for t in threads: 213 t.start() 214 215 ev_1.set() 216 with self.assertRaisesRegexp(RuntimeError, "First"): 217 coord.join(threads) 218 219 def testJoinIgnoresOutOfRange(self): 220 coord = coordinator.Coordinator() 221 ev_1 = threading.Event() 222 threads = [ 223 threading.Thread( 224 target=RaiseOnEvent, 225 args=(coord, ev_1, None, 226 errors_impl.OutOfRangeError(None, None, "First"), True)) 227 ] 228 for t in threads: 229 t.start() 230 231 ev_1.set() 232 coord.join(threads) 233 234 def testJoinIgnoresMyExceptionType(self): 235 coord = coordinator.Coordinator(clean_stop_exception_types=(ValueError,)) 236 ev_1 = threading.Event() 237 threads = [ 238 threading.Thread( 239 target=RaiseOnEvent, 240 args=(coord, ev_1, None, ValueError("Clean stop"), True)) 241 ] 242 for t in threads: 243 t.start() 244 245 ev_1.set() 246 coord.join(threads) 247 248 def testJoinRaiseReportExceptionUsingHandler(self): 249 coord = coordinator.Coordinator() 250 ev_1 = threading.Event() 251 ev_2 = threading.Event() 252 threads = [ 253 threading.Thread( 254 target=RaiseOnEventUsingContextHandler, 255 args=(coord, ev_1, ev_2, RuntimeError("First"))), 256 threading.Thread( 257 target=RaiseOnEventUsingContextHandler, 258 args=(coord, ev_2, None, RuntimeError("Too late"))) 259 ] 260 for t in threads: 261 t.start() 262 263 ev_1.set() 264 with self.assertRaisesRegexp(RuntimeError, "First"): 265 coord.join(threads) 266 267 def testClearStopClearsExceptionToo(self): 268 coord = coordinator.Coordinator() 269 ev_1 = threading.Event() 270 threads = [ 271 threading.Thread( 272 target=RaiseOnEvent, 273 args=(coord, ev_1, None, RuntimeError("First"), True)), 274 ] 275 for t in threads: 276 t.start() 277 278 with self.assertRaisesRegexp(RuntimeError, "First"): 279 ev_1.set() 280 coord.join(threads) 281 coord.clear_stop() 282 threads = [ 283 threading.Thread( 284 target=RaiseOnEvent, 285 args=(coord, ev_1, None, RuntimeError("Second"), True)), 286 ] 287 for t in threads: 288 t.start() 289 with self.assertRaisesRegexp(RuntimeError, "Second"): 290 ev_1.set() 291 coord.join(threads) 292 293 def testRequestStopRaisesIfJoined(self): 294 coord = coordinator.Coordinator() 295 # Join the coordinator right away. 296 coord.join([]) 297 reported = False 298 with self.assertRaisesRegexp(RuntimeError, "Too late"): 299 try: 300 raise RuntimeError("Too late") 301 except RuntimeError as e: 302 reported = True 303 coord.request_stop(e) 304 self.assertTrue(reported) 305 # If we clear_stop the exceptions are handled normally. 306 coord.clear_stop() 307 try: 308 raise RuntimeError("After clear") 309 except RuntimeError as e: 310 coord.request_stop(e) 311 with self.assertRaisesRegexp(RuntimeError, "After clear"): 312 coord.join([]) 313 314 def testRequestStopRaisesIfJoined_ExcInfo(self): 315 # Same as testRequestStopRaisesIfJoined but using syc.exc_info(). 316 coord = coordinator.Coordinator() 317 # Join the coordinator right away. 318 coord.join([]) 319 reported = False 320 with self.assertRaisesRegexp(RuntimeError, "Too late"): 321 try: 322 raise RuntimeError("Too late") 323 except RuntimeError: 324 reported = True 325 coord.request_stop(sys.exc_info()) 326 self.assertTrue(reported) 327 # If we clear_stop the exceptions are handled normally. 328 coord.clear_stop() 329 try: 330 raise RuntimeError("After clear") 331 except RuntimeError: 332 coord.request_stop(sys.exc_info()) 333 with self.assertRaisesRegexp(RuntimeError, "After clear"): 334 coord.join([]) 335 336 337def _StopAt0(coord, n): 338 if n[0] == 0: 339 coord.request_stop() 340 else: 341 n[0] -= 1 342 343 344class LooperTest(test.TestCase): 345 346 def testTargetArgs(self): 347 n = [3] 348 coord = coordinator.Coordinator() 349 thread = coordinator.LooperThread.loop( 350 coord, 0, target=_StopAt0, args=(coord, n)) 351 coord.join([thread]) 352 self.assertEqual(0, n[0]) 353 354 def testTargetKwargs(self): 355 n = [3] 356 coord = coordinator.Coordinator() 357 thread = coordinator.LooperThread.loop( 358 coord, 0, target=_StopAt0, kwargs={ 359 "coord": coord, 360 "n": n 361 }) 362 coord.join([thread]) 363 self.assertEqual(0, n[0]) 364 365 def testTargetMixedArgs(self): 366 n = [3] 367 coord = coordinator.Coordinator() 368 thread = coordinator.LooperThread.loop( 369 coord, 0, target=_StopAt0, args=(coord,), kwargs={ 370 "n": n 371 }) 372 coord.join([thread]) 373 self.assertEqual(0, n[0]) 374 375 376if __name__ == "__main__": 377 test.main() 378