# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import enum

from tensorflow.python.eager import context
from tensorflow.python.framework import auto_control_deps_utils as utils
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
from tensorflow.python.framework import registry
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_decorator

# LINT.IfChange
# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = [
    "CollectiveGather",
    "CollectiveGatherV2",
    "CollectiveReduce",
    "CollectiveReduceV2",
    "CollectiveBcastSend",
    "CollectiveBcastSendV2",
    "CollectiveBcastRecv",
    "CollectiveBcastRecvV2",
    "NcclAllReduce",
    # We do not add "Send" here since we want it to be added as a control output
    # in order to avoid being pruned.
    "Recv",
]

LEGACY_RANDOM_OPS = [
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations. This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value depends
    # on another initializer), the initialization can happen in any order so
    # long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically not
    # guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this denylist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this denylist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this denylist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
]

_ORDER_INSENSITIVE_STATEFUL_OPS = [
    "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
    "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
    "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
    "EnqueueTPUEmbeddingSparseTensorBatch",
    "EnqueueTPUEmbeddingRaggedTensorBatch", "RestoreV2", "SaveV2"
]
# LINT.ThenChange(//tensorflow/core/grappler/optimizers/function_optimizer.cc)

# Union of the three lists above: all op types that, although stateful, are
# exempted from program-order sequencing by ACD.
_ALL_DENYLISTED_OPS = (
    set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)
    | set(_ORDER_INSENSITIVE_STATEFUL_OPS))

# Op types that are marked as stateless, but should be allowlisted to add auto
# control dependencies.
_ALLOWLIST_STATELESS_OPS = [
    # As TPU collective ops are blocking, if there are more than one collective
    # op in the function, we need to make sure different collectives ops are
    # scheduled in certain orders. Otherwise if at the same time all the
    # replicas are launching different collective ops/programs, it may cause
    # deadlock.
    "AllToAll",
    "CrossReplicaSum",
    "CollectivePermute",
]


def op_is_stateful(op):
  """Returns True if ACD should treat `op` as stateful for ordering purposes.

  An op counts as stateful when its kernel is stateful and its type is not in
  `_ALL_DENYLISTED_OPS`, or when its type is explicitly allowlisted in
  `_ALLOWLIST_STATELESS_OPS` (even though the kernel is marked stateless).

  Args:
    op: an `Operation`.

  Returns:
    A bool.
  """
  # pylint: disable=protected-access
  return (op._is_stateful and op.type not in _ALL_DENYLISTED_OPS) or (
      op.type in _ALLOWLIST_STATELESS_OPS)


class ResourceType(enum.Enum):
  """How an op accesses a resource input: read-only or read-write."""
  READ_ONLY = "read-only"
  READ_WRITE = "read-write"


def collective_manager_ids_from_op(op):
  """Returns CollectiveManager ID from the op if one exists, else None.

  CollectiveManager adds collective and no_op operations tagged with an ID,
  unique to the manager object. This function extracts that ID, or None, if the
  node was not generated by a CollectiveManager.

  Args:
    op: `Operation` to get the collective manager ID from.

  Returns:
    List of CollectiveManager IDs used by the op.
  """
  if op.type == "CollectiveReduce":
    try:
      return [op.get_attr("_collective_manager_id")]
    except ValueError:
      # Op has no _collective_manager_id attr; not manager-generated.
      pass
  elif op.type == "StatefulPartitionedCall":
    try:
      return op.get_attr(utils.COLLECTIVE_MANAGER_IDS)
    except ValueError:
      pass
  return []


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
  1. All stateful ops in the scope will execute (with the exception of ops in
     ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
  2. Stateful ops which modify the same resource will execute in program order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  NOT THREAD SAFE
  """

  __slots__ = [
      "_returned_tensors", "ops_which_must_run", "_graph", "_n_operations",
      "collective_manager_ids_used"
  ]

  def __init__(self):
    # Tensors marked via `mark_as_return`; on exit they receive control deps
    # on `ops_which_must_run` so that running them runs all stateful ops.
    self._returned_tensors = object_identity.ObjectIdentitySet()
    # Populated in `__exit__` with the ops that must execute in the scope.
    self.ops_which_must_run = set()

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
        ...
        t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the result
    # of a new identity operation that the stateful operations definitely don't
    # depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    """Starts tracking: flags the graph and records the current op count."""
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    # Ops added after this point are the ones ACD will process in __exit__.
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_write_to_resource, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor get created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node. This
    merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the resource
         tensor

      2. Any op which uses a resource output of the switch node executes before
         the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_write_to_resource: map from resource tensor to last op updating
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
    # pylint: disable=protected-access
    inp = switch_op.inputs[0]
    input_id = ops.tensor_id(inp)
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run, last_write_to_resource,
                           merge_for_resource)
    output = switch_op.outputs[0]
    output_id = ops.tensor_id(output)
    if output_id in merge_for_resource:
      # This switch was already processed (a merge exists for its output).
      return
    new_merge = control_flow_ops.merge(
        switch_op.outputs, name="artificial_merge")
    new_merge[0].op._control_flow_context = (
        switch_op._control_flow_context.outer_context)
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if input_id in last_write_to_resource:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_write_to_resource[input_id])
    # Ensure the next op outside the cond happens after the merge.
    last_write_to_resource[input_id] = new_merge[0].op
    if input_id in merge_for_resource:
      merge_for_resource[input_id]._add_control_input(new_merge[0].op)
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[ops.tensor_id(o)] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    """Adds control dependencies to the ops created inside the scope."""
    # pylint: disable=protected-access
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # Restore the graph's _add_control_dependencies flag: inherit the outer
    # graph's value when nested inside a FuncGraph, else turn it off.
    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False

    # map from resource tensor to the last op which wrote to it
    last_write_to_resource = {}
    # map from resource tensor to the list of reads from it since the last
    # write or since the beginning of the function.
    reads_since_last_write_to_resource = collections.defaultdict(list)
    # CollectiveManager manager_ids within a particular function call should not
    # be needed outside of that function call. So we keep them separate (though
    # the general idea of the maps is the same, in the future, we'll need to
    # correctly thread the control output outside).
    # Map from collective manager scope to the last op which used it
    collective_manager_scopes_opened = {}
    collective_manager_scopes_used = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last op
    # in graph-construction order which used it (last_write_to_resource).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get a
    # control dependency added to them to ensure all stateful ops inside their
    # control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control dependency
    # on a dead node (i.e. a note from an untaken control flow branch) that node
    # will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping from
    # this.
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run.
      # Read-only ops are added to control outputs if the read value is
      # consumed. This covers the case when the read value is returned from
      # the function since that goes through a tf.identity in mark_as_return.
      if (op_def_registry.get(op.type) is None or
          (op_is_stateful(op) and
           (op.type not in utils.RESOURCE_READ_OPS or
            any(output.consumers() for output in op.outputs)))):
        ops_which_must_run.add(op)
      # Make a note of all opened manager_ids.
      if op.type == "NoOp":
        try:
          collective_manager_scopes_opened[op.get_attr(
              "_collective_manager_id")] = op
        except ValueError:
          pass
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)
          for inp in o.inputs:
            input_id = ops.tensor_id(inp)
            if input_id in last_write_to_resource:
              last_write_to_resource[input_id] = op
        # The merge now stands in for everything that had to run before it.
        ops_which_must_run = set([op])
        continue

      resource_inputs = set()
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_write_to_resource.
      for inp, resource_type in _get_resource_inputs(op):
        is_read = resource_type == ResourceType.READ_ONLY
        input_id = ops.tensor_id(inp)

        # If the op receives the same resource tensor twice as an input, we skip
        # to avoid the op getting a control dependency on itself.
        if input_id in resource_inputs:
          continue

        resource_inputs.add(input_id)
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_write_to_resource, merge_for_resource)
        is_building_function = op.graph.building_function
        # Ensure uses of resources are serialized
        if input_id in last_write_to_resource:
          if is_building_function or (
              last_write_to_resource[input_id]._control_flow_context
              is op._control_flow_context):
            control_inputs.add(last_write_to_resource[input_id])
        # Ensure merges happen after the closing of a cond block
        if input_id in merge_for_resource:
          merge_for_resource[input_id]._add_control_input(op)
        if is_read:
          reads_since_last_write_to_resource[input_id].append(op)
        else:
          # A write must wait for all reads since the previous write.
          control_inputs.update(reads_since_last_write_to_resource[input_id])
          reads_since_last_write_to_resource[input_id] = []
          last_write_to_resource[input_id] = op

      # Stateful ops with no resource inputs are chained under the `None` key
      # so they still execute in program order relative to each other.
      if (op_is_stateful(op) and not resource_inputs
          and op._control_flow_context is None):
        if None in last_write_to_resource:
          op._add_control_input(last_write_to_resource[None])
        last_write_to_resource[None] = op

      # Ensure ordering of collective ops
      manager_ids = collective_manager_ids_from_op(op)
      for manager_id in manager_ids:
        if manager_id in collective_manager_scopes_opened:
          # Chain this function call if the scope was opened.
          op._add_control_input(collective_manager_scopes_opened[manager_id])
          collective_manager_scopes_opened[manager_id] = op
        else:
          # If this op is in a scope not created here, create a chain starting
          # at this op.
          if manager_id in collective_manager_scopes_used:
            op._add_control_input(collective_manager_scopes_used[manager_id])
          collective_manager_scopes_used[manager_id] = op

      if control_inputs and not is_building_function:
        # Keep only control inputs from the same control flow context to avoid
        # dead-node propagation (see the comment block above the loop).
        control_inputs = [
            c for c in control_inputs
            if c._control_flow_context is op._control_flow_context
        ]

      op._add_control_inputs(control_inputs)

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    for r in nest.flatten(list(self._returned_tensors), expand_composites=True):
      if self.ops_which_must_run:
        updated_ops_which_must_run = []
        if r.graph.building_function:
          updated_ops_which_must_run = self.ops_which_must_run
        else:
          updated_ops_which_must_run = [
              o for o in self.ops_which_must_run
              if o._control_flow_context is r.op._control_flow_context
          ]
        r.op._add_control_inputs(updated_ops_which_must_run)

    self.collective_manager_ids_used = collective_manager_scopes_used


_acd_resource_resolvers_registry = registry.Registry("acd_resource_resolvers")


def register_acd_resource_resolver(f):
  """Register a function for resolving resources touched by an op.

  `f` is called for every Operation added in the ACD context with the op's
  original resource reads and writes. `f` is expected to update the sets of
  resource reads and writes in-place and return True if it updated either of the
  sets, False otherwise.

  Example:
  @register_acd_resource_resolver
  def ResolveIdentity(op, resource_reads, resource_writes):
    # op: The `Operation` being processed by ACD currently.
    # resource_reads: An `ObjectIdentitySet` of read-only resources.
    # resource_writes: An `ObjectIdentitySet` of read-write resources.
    if not resource_reads or resource_writes:
      return False
    def update(resource_inputs):
      to_add = []
      to_remove = []
      for t in resource_inputs:
        if t.op.type == "Identity":
          to_remove.append(t)
          to_add.append(t.op.inputs[0])
      if not to_add and not to_remove:
        return False
      for t in to_remove:
        resource_inputs.discard(t)
      resource_inputs.update(to_add)
      return True
    return update(resource_reads) or update(resource_writes)

  Args:
    f: Python function with signature
    (Operation, ObjectIdentitySet, ObjectIdentitySet) -> bool

  Returns:
    The function `f` after adding it to the registry.
  """
  _acd_resource_resolvers_registry.register(f)
  return f


def _get_resource_inputs(op):
  """Returns an iterable of resources touched by this `op`."""
  reads, writes = utils.get_read_write_resource_inputs(op)
  saturated = False
  # Iterate all registered resolvers to a fixed point: stop once a full pass
  # makes no further updates.
  while not saturated:
    saturated = True
    for key in _acd_resource_resolvers_registry.list():
      # Resolvers should return true if they are updating the list of
      # resource_inputs.
      # TODO(srbs): An alternate would be to just compare the old and new set
      # but that may not be as fast.
      updated = _acd_resource_resolvers_registry.lookup(key)(op, reads, writes)
      if updated:
        # Conservatively remove any resources from `reads` that are also writes.
        reads = reads.difference(writes)
      saturated = saturated and not updated

  # Note: A resource handle that is not written to is treated as read-only. We
  # don't have a special way of denoting an unused resource.
  for t in reads:
    yield (t, ResourceType.READ_ONLY)
  for t in writes:
    yield (t, ResourceType.READ_WRITE)


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
  1. All stateful ops in f run when the result of f runs
  2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      # Route every returned tensor through mark_as_return so it picks up
      # control dependencies on the scope's stateful ops.
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)