• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Create threads to run multiple enqueue ops."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22import weakref
23
24from tensorflow.core.protobuf import queue_runner_pb2
25from tensorflow.python.client import session
26from tensorflow.python.eager import context
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import tf_logging as logging
30from tensorflow.python.util import deprecation
31from tensorflow.python.util.tf_export import tf_export
32
# Guidance string handed to the `deprecation.deprecated` decorators used in
# this module.
_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")
35
36
37@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
38class QueueRunner(object):
39  """Holds a list of enqueue operations for a queue, each to be run in a thread.
40
41  Queues are a convenient TensorFlow mechanism to compute tensors
42  asynchronously using multiple threads. For example in the canonical 'Input
43  Reader' setup one set of threads generates filenames in a queue; a second set
44  of threads read records from the files, processes them, and enqueues tensors
45  on a second queue; a third set of threads dequeues these input records to
46  construct batches and runs them through training operations.
47
48  There are several delicate issues when running multiple threads that way:
49  closing the queues in sequence as the input is exhausted, correctly catching
50  and reporting exceptions, etc.
51
52  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.
53
54  @compatibility(TF2)
55  QueueRunners are not compatible with eager execution. Instead, please
56  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
57  model.
58  @end_compatibility
59  """
60
61  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
62  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
63               cancel_op=None, queue_closed_exception_types=None,
64               queue_runner_def=None, import_scope=None):
65    """Create a QueueRunner.
66
67    On construction the `QueueRunner` adds an op to close the queue.  That op
68    will be run if the enqueue ops raise exceptions.
69
70    When you later call the `create_threads()` method, the `QueueRunner` will
71    create one thread for each op in `enqueue_ops`.  Each thread will run its
72    enqueue op in parallel with the other threads.  The enqueue ops do not have
73    to all be the same op, but it is expected that they all enqueue tensors in
74    `queue`.
75
76    Args:
77      queue: A `Queue`.
78      enqueue_ops: List of enqueue ops to run in threads later.
79      close_op: Op to close the queue. Pending enqueue ops are preserved.
80      cancel_op: Op to close the queue and cancel pending enqueue ops.
81      queue_closed_exception_types: Optional tuple of Exception types that
82        indicate that the queue has been closed when raised during an enqueue
83        operation.  Defaults to `(tf.errors.OutOfRangeError,)`.  Another common
84        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
85        when some of the enqueue ops may dequeue from other Queues.
86      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
87        recreates the QueueRunner from its contents. `queue_runner_def` and the
88        other arguments are mutually exclusive.
89      import_scope: Optional `string`. Name scope to add. Only used when
90        initializing from protocol buffer.
91
92    Raises:
93      ValueError: If both `queue_runner_def` and `queue` are both specified.
94      ValueError: If `queue` or `enqueue_ops` are not provided when not
95        restoring from `queue_runner_def`.
96      RuntimeError: If eager execution is enabled.
97    """
98    if context.executing_eagerly():
99      raise RuntimeError(
100          "QueueRunners are not supported when eager execution is enabled. "
101          "Instead, please use tf.data to get data into your model.")
102
103    if queue_runner_def:
104      if queue or enqueue_ops:
105        raise ValueError("queue_runner_def and queue are mutually exclusive.")
106      self._init_from_proto(queue_runner_def,
107                            import_scope=import_scope)
108    else:
109      self._init_from_args(
110          queue=queue, enqueue_ops=enqueue_ops,
111          close_op=close_op, cancel_op=cancel_op,
112          queue_closed_exception_types=queue_closed_exception_types)
113    # Protect the count of runs to wait for.
114    self._lock = threading.Lock()
115    # A map from a session object to the number of outstanding queue runner
116    # threads for that session.
117    self._runs_per_session = weakref.WeakKeyDictionary()
118    # List of exceptions raised by the running threads.
119    self._exceptions_raised = []
120
121  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
122                      cancel_op=None, queue_closed_exception_types=None):
123    """Create a QueueRunner from arguments.
124
125    Args:
126      queue: A `Queue`.
127      enqueue_ops: List of enqueue ops to run in threads later.
128      close_op: Op to close the queue. Pending enqueue ops are preserved.
129      cancel_op: Op to close the queue and cancel pending enqueue ops.
130      queue_closed_exception_types: Tuple of exception types, which indicate
131        the queue has been safely closed.
132
133    Raises:
134      ValueError: If `queue` or `enqueue_ops` are not provided when not
135        restoring from `queue_runner_def`.
136      TypeError: If `queue_closed_exception_types` is provided, but is not
137        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
138    """
139    if not queue or not enqueue_ops:
140      raise ValueError("Must provide queue and enqueue_ops.")
141    self._queue = queue
142    self._enqueue_ops = enqueue_ops
143    self._close_op = close_op
144    self._cancel_op = cancel_op
145    if queue_closed_exception_types is not None:
146      if (not isinstance(queue_closed_exception_types, tuple)
147          or not queue_closed_exception_types
148          or not all(issubclass(t, errors.OpError)
149                     for t in queue_closed_exception_types)):
150        raise TypeError(
151            "queue_closed_exception_types, when provided, "
152            "must be a tuple of tf.error types, but saw: %s"
153            % queue_closed_exception_types)
154    self._queue_closed_exception_types = queue_closed_exception_types
155    # Close when no more will be produced, but pending enqueues should be
156    # preserved.
157    if self._close_op is None:
158      self._close_op = self._queue.close()
159    # Close and cancel pending enqueues since there was an error and we want
160    # to unblock everything so we can cleanly exit.
161    if self._cancel_op is None:
162      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
163    if not self._queue_closed_exception_types:
164      self._queue_closed_exception_types = (errors.OutOfRangeError,)
165    else:
166      self._queue_closed_exception_types = tuple(
167          self._queue_closed_exception_types)
168
169  def _init_from_proto(self, queue_runner_def, import_scope=None):
170    """Create a QueueRunner from `QueueRunnerDef`.
171
172    Args:
173      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
174      import_scope: Optional `string`. Name scope to add.
175    """
176    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
177    g = ops.get_default_graph()
178    self._queue = g.as_graph_element(
179        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
180    self._enqueue_ops = [g.as_graph_element(
181        ops.prepend_name_scope(op, import_scope))
182                         for op in queue_runner_def.enqueue_op_name]
183    self._close_op = g.as_graph_element(ops.prepend_name_scope(
184        queue_runner_def.close_op_name, import_scope))
185    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
186        queue_runner_def.cancel_op_name, import_scope))
187    self._queue_closed_exception_types = tuple(
188        errors.exception_type_from_error_code(code)
189        for code in queue_runner_def.queue_closed_exception_types)
190    # Legacy support for old QueueRunnerDefs created before this field
191    # was added.
192    if not self._queue_closed_exception_types:
193      self._queue_closed_exception_types = (errors.OutOfRangeError,)
194
195  @property
196  def queue(self):
197    return self._queue
198
199  @property
200  def enqueue_ops(self):
201    return self._enqueue_ops
202
203  @property
204  def close_op(self):
205    return self._close_op
206
207  @property
208  def cancel_op(self):
209    return self._cancel_op
210
211  @property
212  def queue_closed_exception_types(self):
213    return self._queue_closed_exception_types
214
215  @property
216  def exceptions_raised(self):
217    """Exceptions raised but not handled by the `QueueRunner` threads.
218
219    Exceptions raised in queue runner threads are handled in one of two ways
220    depending on whether or not a `Coordinator` was passed to
221    `create_threads()`:
222
223    * With a `Coordinator`, exceptions are reported to the coordinator and
224      forgotten by the `QueueRunner`.
225    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
226      made available in this `exceptions_raised` property.
227
228    Returns:
229      A list of Python `Exception` objects.  The list is empty if no exception
230      was captured.  (No exceptions are captured when using a Coordinator.)
231    """
232    return self._exceptions_raised
233
234  @property
235  def name(self):
236    """The string name of the underlying Queue."""
237    return self._queue.name
238
239  # pylint: disable=broad-except
240  def _run(self, sess, enqueue_op, coord=None):
241    """Execute the enqueue op in a loop, close the queue in case of error.
242
243    Args:
244      sess: A Session.
245      enqueue_op: The Operation to run.
246      coord: Optional Coordinator object for reporting errors and checking
247        for stop conditions.
248    """
249    decremented = False
250    try:
251      # Make a cached callable from the `enqueue_op` to decrease the
252      # Python overhead in the queue-runner loop.
253      enqueue_callable = sess.make_callable(enqueue_op)
254      while True:
255        if coord and coord.should_stop():
256          break
257        try:
258          enqueue_callable()
259        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
260          # This exception indicates that a queue was closed.
261          with self._lock:
262            self._runs_per_session[sess] -= 1
263            decremented = True
264            if self._runs_per_session[sess] == 0:
265              try:
266                sess.run(self._close_op)
267              except Exception as e:
268                # Intentionally ignore errors from close_op.
269                logging.vlog(1, "Ignored exception: %s", str(e))
270            return
271    except Exception as e:
272      # This catches all other exceptions.
273      if coord:
274        coord.request_stop(e)
275      else:
276        logging.error("Exception in QueueRunner: %s", str(e))
277        with self._lock:
278          self._exceptions_raised.append(e)
279        raise
280    finally:
281      # Make sure we account for all terminations: normal or errors.
282      if not decremented:
283        with self._lock:
284          self._runs_per_session[sess] -= 1
285
286  def _close_on_stop(self, sess, cancel_op, coord):
287    """Close the queue when the Coordinator requests stop.
288
289    Args:
290      sess: A Session.
291      cancel_op: The Operation to run.
292      coord: Coordinator.
293    """
294    coord.wait_for_stop()
295    try:
296      sess.run(cancel_op)
297    except Exception as e:
298      # Intentionally ignore errors from cancel_op.
299      logging.vlog(1, "Ignored exception: %s", str(e))
300  # pylint: enable=broad-except
301
302  def create_threads(self, sess, coord=None, daemon=False, start=False):
303    """Create threads to run the enqueue ops for the given session.
304
305    This method requires a session in which the graph was launched.  It creates
306    a list of threads, optionally starting them.  There is one thread for each
307    op passed in `enqueue_ops`.
308
309    The `coord` argument is an optional coordinator that the threads will use
310    to terminate together and report exceptions.  If a coordinator is given,
311    this method starts an additional thread to close the queue when the
312    coordinator requests a stop.
313
314    If previously created threads for the given session are still running, no
315    new threads will be created.
316
317    Args:
318      sess: A `Session`.
319      coord: Optional `Coordinator` object for reporting errors and checking
320        stop conditions.
321      daemon: Boolean.  If `True` make the threads daemon threads.
322      start: Boolean.  If `True` starts the threads.  If `False` the
323        caller must call the `start()` method of the returned threads.
324
325    Returns:
326      A list of threads.
327    """
328    with self._lock:
329      try:
330        if self._runs_per_session[sess] > 0:
331          # Already started: no new threads to return.
332          return []
333      except KeyError:
334        # We haven't seen this session yet.
335        pass
336      self._runs_per_session[sess] = len(self._enqueue_ops)
337      self._exceptions_raised = []
338
339    ret_threads = []
340    for op in self._enqueue_ops:
341      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
342      ret_threads.append(threading.Thread(target=self._run,
343                                          args=(sess, op, coord),
344                                          name=name))
345    if coord:
346      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
347      ret_threads.append(threading.Thread(target=self._close_on_stop,
348                                          args=(sess, self._cancel_op, coord),
349                                          name=name))
350    for t in ret_threads:
351      if coord:
352        coord.register_thread(t)
353      if daemon:
354        t.daemon = True
355      if start:
356        t.start()
357    return ret_threads
358
359  def to_proto(self, export_scope=None):
360    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.
361
362    Args:
363      export_scope: Optional `string`. Name scope to remove.
364
365    Returns:
366      A `QueueRunnerDef` protocol buffer, or `None` if the `Variable` is not in
367      the specified name scope.
368    """
369    if (export_scope is None or
370        self.queue.name.startswith(export_scope)):
371      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
372      queue_runner_def.queue_name = ops.strip_name_scope(
373          self.queue.name, export_scope)
374      for enqueue_op in self.enqueue_ops:
375        queue_runner_def.enqueue_op_name.append(
376            ops.strip_name_scope(enqueue_op.name, export_scope))
377      queue_runner_def.close_op_name = ops.strip_name_scope(
378          self.close_op.name, export_scope)
379      queue_runner_def.cancel_op_name = ops.strip_name_scope(
380          self.cancel_op.name, export_scope)
381      queue_runner_def.queue_closed_exception_types.extend([
382          errors.error_code_from_exception_type(cls)
383          for cls in self._queue_closed_exception_types])
384      return queue_runner_def
385    else:
386      return None
387
388  @staticmethod
389  def from_proto(queue_runner_def, import_scope=None):
390    """Returns a `QueueRunner` object created from `queue_runner_def`."""
391    return QueueRunner(queue_runner_def=queue_runner_def,
392                       import_scope=import_scope)
393
394
395@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
396@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  Models that use many queues make it hard to keep track of every queue
  runner that has to be run.  This convenience function records a queue runner
  in a well known graph collection so it can be found later.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to.  Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)
419
420
421@tf_export(v1=["train.queue_runner.start_queue_runners",
422               "train.start_queue_runners"])
423@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`.  It just starts
  threads for all queue runners collected in the graph.  It returns
  the list of all threads.

  The queue runners are read from `collection` with the session's graph set
  as the default graph, and the "no queue runners" warning is based on that
  same lookup.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    sess: `Session` used to run the queue ops.  Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from.  Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.  The list is empty when `sess` is a
    `MonitoredSession` or `SingularMonitoredSession`.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: if `sess` is None and there isn't any default session.
    TypeError: if `sess` is not a `tf.compat.v1.Session` object.
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  with sess.graph.as_default():
    # Look the collection up once, with the session's graph as the default,
    # so the warning below describes the same queue runners that get started.
    queue_runners = ops.get_collection(collection)
    if not queue_runners:
      logging.warning(
          "`tf.train.start_queue_runners()` was called when no queue runners "
          "were defined. You can safely remove the call to this deprecated "
          "function.")

    threads = []
    for qr in queue_runners:
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads
489
490
# Register the `QueueRunner` <-> `QueueRunnerDef` conversion functions for the
# QUEUE_RUNNERS graph collection.
ops.register_proto_function(
    ops.GraphKeys.QUEUE_RUNNERS,
    proto_type=queue_runner_pb2.QueueRunnerDef,
    to_proto=QueueRunner.to_proto,
    from_proto=QueueRunner.from_proto)
495