• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Coordinator."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import sys
22import threading
23import time
24
25from tensorflow.python.framework import errors_impl
26from tensorflow.python.platform import test
27from tensorflow.python.training import coordinator
28
29
def StopOnEvent(coord, wait_for_stop, set_when_stopped):
  """Blocks on an event, then asks the coordinator to stop.

  Args:
    coord: Object exposing `request_stop()` (a `coordinator.Coordinator` in
      these tests).
    wait_for_stop: Event this thread blocks on before requesting the stop.
    set_when_stopped: Event signalled right after `request_stop()` returns.
  """
  trigger, done = wait_for_stop, set_when_stopped
  trigger.wait()
  coord.request_stop()
  # Lets the spawning test know the stop request has been issued.
  done.set()
34
35
def RaiseOnEvent(coord, wait_for_stop, set_when_stopped, ex, report_exception):
  """Waits for an event, raises `ex`, and reports it to `coord`.

  Args:
    coord: Object exposing `request_stop(ex)` (a `coordinator.Coordinator` in
      these tests).
    wait_for_stop: Event to block on before raising.
    ex: Exception instance to raise and then report.
    set_when_stopped: Event set on exit, even if reporting itself raises.
      Falsy values (e.g. None) skip the signal.
    report_exception: If True, pass the exception object to `request_stop`;
      otherwise pass the `sys.exc_info()` tuple.
  """
  try:
    wait_for_stop.wait()
    raise ex
  # Catch every Exception, not only RuntimeError: tests also pass ValueError
  # and OutOfRangeError (a TF OpError), neither of which is a RuntimeError.
  # With a narrower clause those escaped the thread unreported, so the
  # clean-stop tests never reached the coordinator at all.
  except Exception as e:  # pylint: disable=broad-except
    if report_exception:
      coord.request_stop(e)
    else:
      coord.request_stop(sys.exc_info())
  finally:
    if set_when_stopped:
      set_when_stopped.set()
48
49
def RaiseOnEventUsingContextHandler(coord, wait_for_stop, set_when_stopped, ex):
  """Raises `ex` inside `coord.stop_on_exception()` once an event fires.

  Args:
    coord: Object exposing the `stop_on_exception()` context manager (a
      `coordinator.Coordinator` in these tests).
    wait_for_stop: Event to block on before raising.
    set_when_stopped: Event set on exit, or a falsy value to skip the signal.
    ex: Exception instance raised inside the context manager.
  """
  # `finally` mirrors RaiseOnEvent: the event is set even if the context
  # manager lets the exception propagate. Otherwise a sibling thread waiting
  # on it would block forever.
  try:
    with coord.stop_on_exception():
      wait_for_stop.wait()
      raise ex
  finally:
    if set_when_stopped:
      set_when_stopped.set()
56
57
def SleepABit(n_secs, coord=None):
  """Optionally registers the running thread with `coord`, then sleeps.

  Args:
    n_secs: Seconds to sleep, forwarded to `time.sleep`.
    coord: Optional object exposing `register_thread(thread)`. When truthy,
      the current thread is registered with it before sleeping.
  """
  if coord:
    me = threading.current_thread()
    coord.register_thread(me)
  time.sleep(n_secs)
62
63
def WaitForThreadsToRegister(coord, num_threads):
  """Polls until at least `num_threads` threads are registered with `coord`.

  Reads the coordinator's private `_registered_threads` collection while
  holding its private `_lock`, sleeping 1 ms between polls.

  Args:
    coord: Coordinator whose registered-thread collection is polled.
    num_threads: Minimum number of registered threads to wait for.
  """
  while True:
    with coord._lock:  # pylint: disable=protected-access
      registered = len(coord._registered_threads)  # pylint: disable=protected-access
      # `>=` rather than `==`: if the count overshoots before we poll, an
      # exact-match check would spin forever.
      if registered >= num_threads:
        return
    time.sleep(0.001)
70
71
class CoordinatorTest(test.TestCase):
  """Tests stop requests, joining and exception reporting of Coordinator."""

  def testStopAPI(self):
    """request_stop() flips should_stop() and wait_for_stop() to True."""
    coord = coordinator.Coordinator()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.01))
    coord.request_stop()
    self.assertTrue(coord.should_stop())
    self.assertTrue(coord.wait_for_stop(0.01))

  def testStopAsync(self):
    """A stop requested from another thread is observed by this one."""
    coord = coordinator.Coordinator()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.1))
    wait_for_stop_ev = threading.Event()
    has_stopped_ev = threading.Event()
    t = threading.Thread(
        target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev))
    t.start()
    self.assertFalse(coord.should_stop())
    self.assertFalse(coord.wait_for_stop(0.01))
    wait_for_stop_ev.set()
    has_stopped_ev.wait()
    self.assertTrue(coord.wait_for_stop(0.05))
    self.assertTrue(coord.should_stop())

  def testJoin(self):
    """join() with an explicit thread list returns after they all finish."""
    coord = coordinator.Coordinator()
    threads = [
        threading.Thread(target=SleepABit, args=(0.01,)),
        threading.Thread(target=SleepABit, args=(0.02,)),
        threading.Thread(target=SleepABit, args=(0.01,))
    ]
    for t in threads:
      t.start()
    coord.join(threads)
    for t in threads:
      self.assertFalse(t.is_alive())

  def testJoinAllRegistered(self):
    """join() without arguments waits for every registered thread."""
    coord = coordinator.Coordinator()
    threads = [
        threading.Thread(target=SleepABit, args=(0.01, coord)),
        threading.Thread(target=SleepABit, args=(0.02, coord)),
        threading.Thread(target=SleepABit, args=(0.01, coord))
    ]
    for t in threads:
      t.start()
    WaitForThreadsToRegister(coord, 3)
    coord.join()
    for t in threads:
      self.assertFalse(t.is_alive())

  def testJoinSomeRegistered(self):
    """join() waits for registered threads plus those passed explicitly."""
    coord = coordinator.Coordinator()
    threads = [
        threading.Thread(target=SleepABit, args=(0.01, coord)),
        threading.Thread(target=SleepABit, args=(0.02,)),
        threading.Thread(target=SleepABit, args=(0.01, coord))
    ]
    for t in threads:
      t.start()
    WaitForThreadsToRegister(coord, 2)
    # threads[1] is not registered, so we must pass it in.
    coord.join([threads[1]])
    for t in threads:
      self.assertFalse(t.is_alive())

  def testJoinGraceExpires(self):
    """join() raises when threads outlive the stop grace period."""

    def TestWithGracePeriod(stop_grace_period):
      coord = coordinator.Coordinator()
      wait_for_stop_ev = threading.Event()
      has_stopped_ev = threading.Event()
      threads = [
          threading.Thread(
              target=StopOnEvent,
              args=(coord, wait_for_stop_ev, has_stopped_ev)),
          # Sleeps far longer than any grace period tried below.
          threading.Thread(target=SleepABit, args=(10.0,))
      ]
      for t in threads:
        # Daemon so the long sleeper does not keep the process alive.
        t.daemon = True
        t.start()
      wait_for_stop_ev.set()
      has_stopped_ev.wait()
      with self.assertRaisesRegex(RuntimeError, "threads still running"):
        coord.join(threads, stop_grace_period_secs=stop_grace_period)

    TestWithGracePeriod(1e-10)
    TestWithGracePeriod(0.002)
    TestWithGracePeriod(1.0)

  def testJoinWithoutGraceExpires(self):
    """ignore_live_threads=True lets join() return despite a live thread."""
    coord = coordinator.Coordinator()
    wait_for_stop_ev = threading.Event()
    has_stopped_ev = threading.Event()
    threads = [
        threading.Thread(
            target=StopOnEvent, args=(coord, wait_for_stop_ev, has_stopped_ev)),
        threading.Thread(target=SleepABit, args=(10.0,))
    ]
    for t in threads:
      t.daemon = True
      t.start()
    wait_for_stop_ev.set()
    has_stopped_ev.wait()
    coord.join(threads, stop_grace_period_secs=1., ignore_live_threads=True)

  def testJoinRaiseReportExcInfo(self):
    """join() re-raises the first error reported as an exc_info tuple."""
    coord = coordinator.Coordinator()
    ev_1 = threading.Event()
    ev_2 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, ev_2, RuntimeError("First"), False)),
        # Only released once the first thread has reported its error.
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_2, None, RuntimeError("Too late"), False))
    ]
    for t in threads:
      t.start()

    ev_1.set()

    with self.assertRaisesRegex(RuntimeError, "First"):
      coord.join(threads)

  def testJoinRaiseReportException(self):
    """join() re-raises the first error reported as an exception object."""
    coord = coordinator.Coordinator()
    ev_1 = threading.Event()
    ev_2 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, ev_2, RuntimeError("First"), True)),
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_2, None, RuntimeError("Too late"), True))
    ]
    for t in threads:
      t.start()

    ev_1.set()
    with self.assertRaisesRegex(RuntimeError, "First"):
      coord.join(threads)

  def testJoinIgnoresOutOfRange(self):
    """join() does not raise for a thread that raised OutOfRangeError."""
    coord = coordinator.Coordinator()
    ev_1 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, None,
                  errors_impl.OutOfRangeError(None, None, "First"), True))
    ]
    for t in threads:
      t.start()

    ev_1.set()
    coord.join(threads)

  def testJoinIgnoresMyExceptionType(self):
    """join() does not raise for types in clean_stop_exception_types."""
    coord = coordinator.Coordinator(clean_stop_exception_types=(ValueError,))
    ev_1 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, None, ValueError("Clean stop"), True))
    ]
    for t in threads:
      t.start()

    ev_1.set()
    coord.join(threads)

  def testJoinRaiseReportExceptionUsingHandler(self):
    """join() re-raises errors caught by stop_on_exception()."""
    coord = coordinator.Coordinator()
    ev_1 = threading.Event()
    ev_2 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEventUsingContextHandler,
            args=(coord, ev_1, ev_2, RuntimeError("First"))),
        threading.Thread(
            target=RaiseOnEventUsingContextHandler,
            args=(coord, ev_2, None, RuntimeError("Too late")))
    ]
    for t in threads:
      t.start()

    ev_1.set()
    with self.assertRaisesRegex(RuntimeError, "First"):
      coord.join(threads)

  def testClearStopClearsExceptionToo(self):
    """After clear_stop(), join() reports the new error, not the old one."""
    coord = coordinator.Coordinator()
    ev_1 = threading.Event()
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, None, RuntimeError("First"), True)),
    ]
    for t in threads:
      t.start()

    with self.assertRaisesRegex(RuntimeError, "First"):
      ev_1.set()
      coord.join(threads)
    coord.clear_stop()
    # ev_1 is still set, so the new thread raises immediately.
    threads = [
        threading.Thread(
            target=RaiseOnEvent,
            args=(coord, ev_1, None, RuntimeError("Second"), True)),
    ]
    for t in threads:
      t.start()
    with self.assertRaisesRegex(RuntimeError, "Second"):
      ev_1.set()
      coord.join(threads)

  def testRequestStopRaisesIfJoined(self):
    """Once joined, request_stop(e) re-raises e until clear_stop()."""
    coord = coordinator.Coordinator()
    # Join the coordinator right away.
    coord.join([])
    reported = False
    with self.assertRaisesRegex(RuntimeError, "Too late"):
      try:
        raise RuntimeError("Too late")
      except RuntimeError as e:
        reported = True
        coord.request_stop(e)
    self.assertTrue(reported)
    # If we clear_stop the exceptions are handled normally.
    coord.clear_stop()
    try:
      raise RuntimeError("After clear")
    except RuntimeError as e:
      coord.request_stop(e)
    with self.assertRaisesRegex(RuntimeError, "After clear"):
      coord.join([])

  def testRequestStopRaisesIfJoined_ExcInfo(self):
    """Same as testRequestStopRaisesIfJoined but using sys.exc_info()."""
    coord = coordinator.Coordinator()
    # Join the coordinator right away.
    coord.join([])
    reported = False
    with self.assertRaisesRegex(RuntimeError, "Too late"):
      try:
        raise RuntimeError("Too late")
      except RuntimeError:
        reported = True
        coord.request_stop(sys.exc_info())
    self.assertTrue(reported)
    # If we clear_stop the exceptions are handled normally.
    coord.clear_stop()
    try:
      raise RuntimeError("After clear")
    except RuntimeError:
      coord.request_stop(sys.exc_info())
    with self.assertRaisesRegex(RuntimeError, "After clear"):
      coord.join([])
335
336
def _StopAt0(coord, n):
  """Counts `n[0]` down by one per call; requests a stop once it is zero.

  Args:
    coord: Object exposing `request_stop()`.
    n: Single-element list used as a mutable counter shared with the caller.
  """
  if n[0] != 0:
    n[0] -= 1
    return
  coord.request_stop()
342
343
class LooperTest(test.TestCase):
  """Checks that LooperThread.loop forwards args/kwargs to its target."""

  def _LoopUntilZero(self, counter, coord, **call_kwargs):
    """Loops `_StopAt0` with `call_kwargs`, joins, and checks the counter."""
    # The 0 is passed straight through as loop()'s second positional
    # argument; presumably the timer interval in seconds (TODO confirm).
    looper = coordinator.LooperThread.loop(
        coord, 0, target=_StopAt0, **call_kwargs)
    coord.join([looper])
    self.assertEqual(0, counter[0])

  def testTargetArgs(self):
    counter = [3]
    coord = coordinator.Coordinator()
    self._LoopUntilZero(counter, coord, args=(coord, counter))

  def testTargetKwargs(self):
    counter = [3]
    coord = coordinator.Coordinator()
    self._LoopUntilZero(
        counter, coord, kwargs={"coord": coord, "n": counter})

  def testTargetMixedArgs(self):
    counter = [3]
    coord = coordinator.Coordinator()
    self._LoopUntilZero(
        counter, coord, args=(coord,), kwargs={"n": counter})
374
375
if __name__ == "__main__":
  # Run the test cases defined in this module via TensorFlow's test runner.
  test.main()
378