# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Create threads to run multiple enqueue ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading
import weakref

from tensorflow.core.protobuf import queue_runner_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_DEPRECATION_INSTRUCTION = (
    "To construct input pipelines, use the `tf.data` module.")


@tf_export(v1=["train.queue_runner.QueueRunner", "train.QueueRunner"])
class QueueRunner(object):
  """Holds a list of enqueue operations for a queue, each to be run in a thread.

  Queues are a convenient TensorFlow mechanism to compute tensors
  asynchronously using multiple threads. For example, in the canonical 'Input
  Reader' setup, one set of threads generates filenames in a queue; a second
  set of threads reads records from the files, processes them, and enqueues
  tensors on a second queue; a third set of threads dequeues these input
  records to construct batches and runs them through training operations.

  There are several delicate issues when running multiple threads that way:
  closing the queues in sequence as the input is exhausted, correctly catching
  and reporting exceptions, etc.

  The `QueueRunner`, combined with the `Coordinator`, helps handle these issues.

  @compatibility(eager)
  QueueRunners are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
  def __init__(self, queue=None, enqueue_ops=None, close_op=None,
               cancel_op=None, queue_closed_exception_types=None,
               queue_runner_def=None, import_scope=None):
    """Create a QueueRunner.

    On construction the `QueueRunner` adds an op to close the queue. That op
    will be run if the enqueue ops raise exceptions.

    When you later call the `create_threads()` method, the `QueueRunner` will
    create one thread for each op in `enqueue_ops`. Each thread will run its
    enqueue op in parallel with the other threads. The enqueue ops do not have
    to all be the same op, but it is expected that they all enqueue tensors in
    `queue`.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Optional tuple of Exception types that
        indicate that the queue has been closed when raised during an enqueue
        operation. Defaults to `(tf.errors.OutOfRangeError,)`. Another common
        case is `(tf.errors.OutOfRangeError, tf.errors.CancelledError)`, when
        some of the enqueue ops may dequeue from other Queues.
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer. If specified,
        recreates the QueueRunner from its contents. `queue_runner_def` and the
        other arguments are mutually exclusive.
      import_scope: Optional `string`. Name scope to add. Only used when
        initializing from protocol buffer.

    Raises:
      ValueError: If both `queue_runner_def` and `queue` are specified.
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "QueueRunners are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    if queue_runner_def:
      if queue or enqueue_ops:
        raise ValueError("queue_runner_def and queue are mutually exclusive.")
      self._init_from_proto(queue_runner_def,
                            import_scope=import_scope)
    else:
      self._init_from_args(
          queue=queue, enqueue_ops=enqueue_ops,
          close_op=close_op, cancel_op=cancel_op,
          queue_closed_exception_types=queue_closed_exception_types)
    # Protect the count of runs to wait for.
    self._lock = threading.Lock()
    # A map from a session object to the number of outstanding queue runner
    # threads for that session.
    self._runs_per_session = weakref.WeakKeyDictionary()
    # List of exceptions raised by the running threads.
    self._exceptions_raised = []

  def _init_from_args(self, queue=None, enqueue_ops=None, close_op=None,
                      cancel_op=None, queue_closed_exception_types=None):
    """Create a QueueRunner from arguments.

    Args:
      queue: A `Queue`.
      enqueue_ops: List of enqueue ops to run in threads later.
      close_op: Op to close the queue. Pending enqueue ops are preserved.
      cancel_op: Op to close the queue and cancel pending enqueue ops.
      queue_closed_exception_types: Tuple of exception types, which indicate
        the queue has been safely closed.

    Raises:
      ValueError: If `queue` or `enqueue_ops` are not provided when not
        restoring from `queue_runner_def`.
      TypeError: If `queue_closed_exception_types` is provided, but is not
        a non-empty tuple of error types (subclasses of `tf.errors.OpError`).
    """
    if not queue or not enqueue_ops:
      raise ValueError("Must provide queue and enqueue_ops.")
    self._queue = queue
    self._enqueue_ops = enqueue_ops
    self._close_op = close_op
    self._cancel_op = cancel_op
    if queue_closed_exception_types is not None:
      if (not isinstance(queue_closed_exception_types, tuple)
          or not queue_closed_exception_types
          or not all(issubclass(t, errors.OpError)
                     for t in queue_closed_exception_types)):
        raise TypeError(
            "queue_closed_exception_types, when provided, "
            "must be a tuple of tf.error types, but saw: %s"
            % queue_closed_exception_types)
    self._queue_closed_exception_types = queue_closed_exception_types
    # Close when no more will be produced, but pending enqueues should be
    # preserved.
    if self._close_op is None:
      self._close_op = self._queue.close()
    # Close and cancel pending enqueues since there was an error and we want
    # to unblock everything so we can cleanly exit.
    if self._cancel_op is None:
      self._cancel_op = self._queue.close(cancel_pending_enqueues=True)
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)
    else:
      self._queue_closed_exception_types = tuple(
          self._queue_closed_exception_types)

  def _init_from_proto(self, queue_runner_def, import_scope=None):
    """Create a QueueRunner from `QueueRunnerDef`.

    Args:
      queue_runner_def: Optional `QueueRunnerDef` protocol buffer.
      import_scope: Optional `string`. Name scope to add.
    """
    assert isinstance(queue_runner_def, queue_runner_pb2.QueueRunnerDef)
    g = ops.get_default_graph()
    self._queue = g.as_graph_element(
        ops.prepend_name_scope(queue_runner_def.queue_name, import_scope))
    self._enqueue_ops = [g.as_graph_element(
        ops.prepend_name_scope(op, import_scope))
                         for op in queue_runner_def.enqueue_op_name]
    self._close_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.close_op_name, import_scope))
    self._cancel_op = g.as_graph_element(ops.prepend_name_scope(
        queue_runner_def.cancel_op_name, import_scope))
    self._queue_closed_exception_types = tuple(
        errors.exception_type_from_error_code(code)
        for code in queue_runner_def.queue_closed_exception_types)
    # Legacy support for old QueueRunnerDefs created before this field
    # was added.
    if not self._queue_closed_exception_types:
      self._queue_closed_exception_types = (errors.OutOfRangeError,)

  @property
  def queue(self):
    return self._queue

  @property
  def enqueue_ops(self):
    return self._enqueue_ops

  @property
  def close_op(self):
    return self._close_op

  @property
  def cancel_op(self):
    return self._cancel_op

  @property
  def queue_closed_exception_types(self):
    return self._queue_closed_exception_types

  @property
  def exceptions_raised(self):
    """Exceptions raised but not handled by the `QueueRunner` threads.

    Exceptions raised in queue runner threads are handled in one of two ways
    depending on whether or not a `Coordinator` was passed to
    `create_threads()`:

    * With a `Coordinator`, exceptions are reported to the coordinator and
      forgotten by the `QueueRunner`.
    * Without a `Coordinator`, exceptions are captured by the `QueueRunner` and
      made available in this `exceptions_raised` property.

    Returns:
      A list of Python `Exception` objects. The list is empty if no exception
      was captured. (No exceptions are captured when using a Coordinator.)
    """
    return self._exceptions_raised

  @property
  def name(self):
    """The string name of the underlying Queue."""
    return self._queue.name

  # pylint: disable=broad-except
  def _run(self, sess, enqueue_op, coord=None):
    """Execute the enqueue op in a loop, close the queue in case of error.

    Args:
      sess: A Session.
      enqueue_op: The Operation to run.
      coord: Optional Coordinator object for reporting errors and checking
        for stop conditions.
    """
    decremented = False
    try:
      # Make a cached callable from the `enqueue_op` to decrease the
      # Python overhead in the queue-runner loop.
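      # (`Session.make_callable` performs the fetch setup once up front, so
      # each call below avoids the per-call overhead of `Session.run`.)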
      enqueue_callable = sess.make_callable(enqueue_op)
      while True:
        if coord and coord.should_stop():
          break
        try:
          enqueue_callable()
        except self._queue_closed_exception_types:  # pylint: disable=catching-non-exception
          # This exception indicates that a queue was closed.
          with self._lock:
            self._runs_per_session[sess] -= 1
            decremented = True
            if self._runs_per_session[sess] == 0:
              try:
                sess.run(self._close_op)
              except Exception as e:
                # Intentionally ignore errors from close_op.
                logging.vlog(1, "Ignored exception: %s", str(e))
            return
    except Exception as e:
      # This catches all other exceptions.
      if coord:
        coord.request_stop(e)
      else:
        logging.error("Exception in QueueRunner: %s", str(e))
        with self._lock:
          self._exceptions_raised.append(e)
        raise
    finally:
      # Make sure we account for all terminations: normal or errors.
      if not decremented:
        with self._lock:
          self._runs_per_session[sess] -= 1

  def _close_on_stop(self, sess, cancel_op, coord):
    """Close the queue when the Coordinator requests stop.

    Args:
      sess: A Session.
      cancel_op: The Operation to run.
      coord: Coordinator.
    """
    coord.wait_for_stop()
    try:
      sess.run(cancel_op)
    except Exception as e:
      # Intentionally ignore errors from cancel_op.
      logging.vlog(1, "Ignored exception: %s", str(e))
  # pylint: enable=broad-except

  def create_threads(self, sess, coord=None, daemon=False, start=False):
    """Create threads to run the enqueue ops for the given session.

    This method requires a session in which the graph was launched. It creates
    a list of threads, optionally starting them. There is one thread for each
    op passed in `enqueue_ops`.

    The `coord` argument is an optional coordinator that the threads will use
    to terminate together and report exceptions. If a coordinator is given,
    this method starts an additional thread to close the queue when the
    coordinator requests a stop.

    If previously created threads for the given session are still running, no
    new threads will be created.

    Args:
      sess: A `Session`.
      coord: Optional `Coordinator` object for reporting errors and checking
        stop conditions.
      daemon: Boolean. If `True` make the threads daemon threads.
      start: Boolean. If `True` starts the threads. If `False` the
        caller must call the `start()` method of the returned threads.

    Returns:
      A list of threads.
    """
    with self._lock:
      try:
        if self._runs_per_session[sess] > 0:
          # Already started: no new threads to return.
          return []
      except KeyError:
        # We haven't seen this session yet.
        pass
      self._runs_per_session[sess] = len(self._enqueue_ops)
      self._exceptions_raised = []

    ret_threads = []
    for op in self._enqueue_ops:
      name = "QueueRunnerThread-{}-{}".format(self.name, op.name)
      ret_threads.append(threading.Thread(target=self._run,
                                          args=(sess, op, coord),
                                          name=name))
    if coord:
      name = "QueueRunnerThread-{}-close_on_stop".format(self.name)
      ret_threads.append(threading.Thread(target=self._close_on_stop,
                                          args=(sess, self._cancel_op, coord),
                                          name=name))
    for t in ret_threads:
      if coord:
        coord.register_thread(t)
      if daemon:
        t.daemon = True
      if start:
        t.start()
    return ret_threads

  def to_proto(self, export_scope=None):
    """Converts this `QueueRunner` to a `QueueRunnerDef` protocol buffer.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `QueueRunnerDef` protocol buffer, or `None` if the queue is not in
      the specified name scope.
    """
    if (export_scope is None or
        self.queue.name.startswith(export_scope)):
      queue_runner_def = queue_runner_pb2.QueueRunnerDef()
      queue_runner_def.queue_name = ops.strip_name_scope(
          self.queue.name, export_scope)
      for enqueue_op in self.enqueue_ops:
        queue_runner_def.enqueue_op_name.append(
            ops.strip_name_scope(enqueue_op.name, export_scope))
      queue_runner_def.close_op_name = ops.strip_name_scope(
          self.close_op.name, export_scope)
      queue_runner_def.cancel_op_name = ops.strip_name_scope(
          self.cancel_op.name, export_scope)
      queue_runner_def.queue_closed_exception_types.extend([
          errors.error_code_from_exception_type(cls)
          for cls in self._queue_closed_exception_types])
      return queue_runner_def
    else:
      return None

  @staticmethod
  def from_proto(queue_runner_def, import_scope=None):
    """Returns a `QueueRunner` object created from `queue_runner_def`."""
    return QueueRunner(queue_runner_def=queue_runner_def,
                       import_scope=import_scope)


@tf_export(v1=["train.queue_runner.add_queue_runner", "train.add_queue_runner"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def add_queue_runner(qr, collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Adds a `QueueRunner` to a collection in the graph.

  When building a complex model that uses many queues it is often difficult to
  gather all the queue runners that need to be run. This convenience function
  allows you to add a queue runner to a well-known collection in the graph.

  The companion method `start_queue_runners()` can be used to start threads for
  all the collected queue runners.

  Args:
    qr: A `QueueRunner`.
    collection: A `GraphKey` specifying the graph collection to add
      the queue runner to. Defaults to `GraphKeys.QUEUE_RUNNERS`.
  """
  ops.add_to_collection(collection, qr)


@tf_export(v1=["train.queue_runner.start_queue_runners",
               "train.start_queue_runners"])
@deprecation.deprecated(None, _DEPRECATION_INSTRUCTION)
def start_queue_runners(sess=None, coord=None, daemon=True, start=True,
                        collection=ops.GraphKeys.QUEUE_RUNNERS):
  """Starts all queue runners collected in the graph.

  This is a companion method to `add_queue_runner()`. It just starts
  threads for all queue runners collected in the graph. It returns
  the list of all threads.

  Args:
    sess: `Session` used to run the queue ops. Defaults to the
      default session.
    coord: Optional `Coordinator` for coordinating the started threads.
    daemon: Whether the threads should be marked as `daemon` threads,
      meaning they don't block program exit.
    start: Set to `False` to only create the threads, not start them.
    collection: A `GraphKey` specifying the graph collection to
      get the queue runners from. Defaults to `GraphKeys.QUEUE_RUNNERS`.

  Returns:
    A list of threads.

  Raises:
    RuntimeError: If called with eager execution enabled.
    ValueError: If `sess` is `None` and no default `tf.compat.v1.Session` is
      registered.
    TypeError: If `sess` is not a `tf.compat.v1.Session` object.

  @compatibility(eager)
  Not compatible with eager execution.
  To ingest data under eager execution, use the `tf.data` API instead.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("Queues are not compatible with eager execution.")
  if sess is None:
    sess = ops.get_default_session()
    if not sess:
      raise ValueError("Cannot start queue runners: No default session is "
                       "registered. Use `with sess.as_default()` or pass an "
                       "explicit session to tf.start_queue_runners(sess=sess)")

  if not isinstance(sess, session.SessionInterface):
    # Following check is due to backward compatibility. (b/62061352)
    if sess.__class__.__name__ in [
        "MonitoredSession", "SingularMonitoredSession"]:
      return []
    raise TypeError("sess must be a `tf.Session` object. "
                    "Given class: {}".format(sess.__class__))

  queue_runners = ops.get_collection(collection)
  if not queue_runners:
    logging.warning(
        "`tf.train.start_queue_runners()` was called when no queue runners "
        "were defined. You can safely remove the call to this deprecated "
        "function.")

  with sess.graph.as_default():
    threads = []
    for qr in ops.get_collection(collection):
      threads.extend(qr.create_threads(sess, coord=coord, daemon=daemon,
                                       start=start))
  return threads


ops.register_proto_function(ops.GraphKeys.QUEUE_RUNNERS,
                            proto_type=queue_runner_pb2.QueueRunnerDef,
                            to_proto=QueueRunner.to_proto,
                            from_proto=QueueRunner.from_proto)
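

# The commented block below is a minimal, illustrative sketch of the legacy
# graph-mode pattern this module supports; it is not part of the library API,
# and the queue capacity, dtype, and random-scalar input are placeholder
# choices for the example. New code should use `tf.data` instead.
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_eager_execution()
#
#   queue = tf.FIFOQueue(capacity=32, dtypes=[tf.float32])
#   enqueue_op = queue.enqueue(tf.random.uniform([]))
#   dequeue_op = queue.dequeue()
#
#   # One QueueRunner with two copies of the enqueue op -> two threads.
#   qr = tf.train.QueueRunner(queue, [enqueue_op] * 2)
#   tf.train.add_queue_runner(qr)
#
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     for _ in range(10):
#       print(sess.run(dequeue_op))
#     coord.request_stop()
#     coord.join(threads)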