# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")


@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example, in the canonical 'Input
  Reader' setup one set of threads generates filenames in a queue; a second
  set of threads reads records from the files, processes them, and enqueues
  tensors on a second queue; a third set of threads dequeues these input
  records to construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these
  issues.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue. That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`. Each thread will run its
    enqueue op in parallel with the other threads. The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation. Defaults to `(tf.errors.OutOfRangeError,)`. Another common
        case includes `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`,
        when some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If
        specified, recreates the QueueRunner from its contents.
        `queue_runner_def` and the other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If both `queue_runner_def` and `queue` are specified.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % queue_closed_exception_types)
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    return self._queue

  @property
  def enqueue_ops(self):
    return self._enqueue_ops

  @property
  def close_op(self):
    return self._close_op

  @property
  def cancel_op(self):
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects. The list is empty if no exception
      was captured. (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched. It creates
    a list of threads, optionally starting them. There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions. If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean. If `True` make the threads daemon threads.
      start: Boolean. If `True` starts the threads. If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      try:
        if self._runs_per_session[sess] > 0:
          # Already started: no new threads to return.
          return []
      except KeyError:
        # We haven't seen this session yet.
        pass
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the `QueueRunner` is
      not in the specified name scope.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run. This convenience function
  allows you to add a queue runner to a well known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads
  for all the collected queue runners.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`. It just starts
  threads for all queue runners collected in the graph. It returns
  the list of all threads.

  @compatibility(TF2)
  QueueRunners are not compatible with eager execution. Instead, please
  use [tf.data](https://www.tensorflow.org/guide/data) to get data into your
  model.
  @end_compatibility

  Args:
    sess: `Session` used to run the queue ops. Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemons`, meaning
      they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is None and there isn't any default session.
    TypeError: If `sess` is not a `tf.compat.v1.Session` object.
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  queue_runners = ops.get_collection(collection)
  if not queue_runners:
    logging.warning(
        "`tf.train.start_queue_runners()` was called when no queue runners "
        "were defined. You can safely remove the call to this deprecated "
        "function.")

  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)
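
# Example usage (a minimal sketch, kept as a comment so this module's behavior
# is unchanged): it shows how a `QueueRunner` is typically combined with
# `add_queue_runner()`, `start_queue_runners()`, and a `Coordinator` in TF1
# graph mode. The queue capacity, dtype, loop bound, and the `build_example()`
# helper below are illustrative assumptions, not values taken from this file.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   example = build_example()  # hypothetical helper returning a float32 tensor
#   queue = tf.FIFOQueue(capacity=32, dtypes=[tf.float32])
#   enqueue_op = queue.enqueue(example)
#   dequeued = queue.dequeue()
#
#   # Four threads, all running the same enqueue op.
#   qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)
#   tf.train.add_queue_runner(qr)
#
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     try:
#       for _ in range(100):
#         sess.run(dequeued)
#     finally:
#       # Ask the runner threads to stop and wait for them; the QueueRunner's
#       # cancel op is run automatically once the coordinator requests a stop.
#       coord.request_stop()
#       coord.join(threads)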