# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AutomaticControlDependencies and related functionality."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator

# Op types that should not run in program order, e.g. because they need to run
# asynchronously to avoid deadlock.
ASYNC_STATEFUL_OPS = [
    "CollectiveGather",
    "CollectiveReduce",
    "CollectiveBcastSend",
    "CollectiveBcastRecv",
    "NcclAllReduce",
]

LEGACY_RANDOM_OPS = [
    # These may be used in variable initializers -- thus their execution should
    # not be dependent on other stateful operations. This is because although
    # according to program order, tf.Variables may be created in sequence,
    # their initialization happens outside of the program order (specifically,
    # in graph mode their initialization happens by calling a grouped
    # initializer operation or in eager mode, where initialization is lifted
    # out of the tf.function and executed the first time the function is
    # executed).
    #
    # Unless there is a specific dependency between the initializers
    # themselves (e.g. one initializer depends on a Variable whose value
    # depends on another initializer), the initialization can happen in any
    # order so long as it's before the associated Variable read operations.
    #
    # Note that in general the randomness of legacy random operations is only
    # guaranteed by providing a graph-level and op-level seed (and ordering of
    # the same op across multiple iterations of a while_loop is specifically
    # not guaranteed; see the discussion below).
    #
    # There is a possible race condition inside while_loop where the same
    # random OpKernel instantiation is reused across multiple steps
    # of the loop. Since legacy Random OpKernels have an internal rng state,
    # automatic dependency tracking across loop steps would likely
    # fix this race; and for that case this blacklist is problematic.
    # However, since automatic dependency tracking inside while loops is not
    # currently supported, and there are no other examples of OpKernel reuse
    # (each OpKernel is associated with a unique op in graph mode),
    # this blacklist has no effect on the aforementioned behavior.
    #
    # TODO(ebrevdo,skyewm): Modify the check against this blacklist to
    # only occur when the op is inside a "variable initialization scope"; and
    # add proper autodeps inside while_loops that respects this updated check.
    "RandomUniform",
    "RandomUniformInt",
    "RandomStandardNormal",
    "ParameterizedTruncatedNormal",
    "TruncatedNormal",
    "RandomShuffle",
    "Multinomial",
    "RandomGamma",
    "RandomGammaGrad",
    "RandomPoisson",
    "RandomPoissonV2",
]

_ALL_BLACKLISTED_OPS = set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS)


def op_is_stateful(op_def):
  return op_def.is_stateful and op_def.name not in _ALL_BLACKLISTED_OPS
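
# Illustrative check (a sketch, not used by this module): assuming the
# op_def_registry module from tensorflow.python.framework is importable,
# op_is_stateful distinguishes stateful ops that must be serialized from
# blacklisted ones that must not be:
#
#   registered = op_def_registry.get_registered_ops()
#   op_is_stateful(registered["AssignAddVariableOp"])  # True: stateful, not blacklisted
#   op_is_stateful(registered["RandomUniform"])        # False: in LEGACY_RANDOM_OPS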


class AutomaticControlDependencies(object):
  """Context manager to automatically add control dependencies.

  Code under this context manager will act as if a sensible set of control
  dependencies were present. More specifically:
    1. All stateful ops in the scope will execute (with the exception of ops
       in ASYNC_STATEFUL_OPS and LEGACY_RANDOM_OPS)
    2. Stateful ops which modify the same resource will execute in program
       order

  Note: creating variables in an automatic control dependencies context is not
  supported (the value of the variables will never change as they will keep
  getting reinitialized).

  NOT THREAD SAFE
  """
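
  # Illustrative usage (a sketch, not a test; `v` stands for a resource
  # variable created outside the scope, since creating variables inside the
  # scope is unsupported):
  #
  #   with AutomaticControlDependencies() as a:
  #     assign_1 = v.assign(1.0)
  #     assign_2 = v.assign(2.0)
  #     out = a.mark_as_return(v.read_value())
  #
  # On exit, assign_2 gets a control input on assign_1, the read gets a
  # control input on assign_2, and `out` depends on the stateful ops in the
  # scope, so evaluating `out` runs both assignments in order and yields 2.0.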

  def __init__(self):
    self._returned_tensors = set()
    self.ops_which_must_run = set()

  def mark_as_return(self, tensor):
    """Acts like identity but marks the `Tensor` as a return value.

    This will possibly return a copy of the `Tensor`. Usage:

    ```
      with AutomaticControlDependencies() as a:
        ...
        t = a.mark_as_return(t)
      _ = ...(t...)  # i.e. it's safe to use t here
    ```

    Args:
      tensor: the `Tensor` to be marked

    Returns:
      a copy of the `Tensor`.
    """
    if isinstance(tensor, ops.IndexedSlices):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return ops.IndexedSlices(values, indices, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, sparse_tensor.SparseTensor):
      values = array_ops.identity(tensor.values)
      indices = array_ops.identity(tensor.indices)
      self._returned_tensors.add(indices)
      self._returned_tensors.add(values)
      return sparse_tensor.SparseTensor(
          indices, values, dense_shape=tensor.dense_shape)
    elif isinstance(tensor, tensor_array_ops.TensorArray):
      flow = array_ops.identity(tensor.flow)
      self._returned_tensors.add(flow)
      return tensor_array_ops.build_ta_with_new_flow(tensor, flow)
    # We want to make the return values depend on the stateful operations, but
    # we don't want to introduce a cycle, so we make the return value the
    # result of a new identity operation that the stateful operations
    # definitely don't depend on.
    tensor = array_ops.identity(tensor)
    self._returned_tensors.add(tensor)
    return tensor

  def __enter__(self):
    if context.executing_eagerly():
      return self
    # This code assumes no other thread is adding ops to the graph while
    # we're adding ops to the graph.
    # TODO(apassos): Fix this by locking the graph or using a temporary
    # graph (but that would mess up devices and collections at least,
    # probably other things as well).
    self._graph = ops.get_default_graph()
    self._graph._add_control_dependencies = True  # pylint: disable=protected-access
    self._n_operations = len(self._graph.get_operations())
    return self

  def _process_switch(self, switch_op, ops_which_must_run,
                      last_op_using_resource_tensor, merge_for_resource):
    """Processes a switch node for a resource input.

    When tensorflow creates a cond, it creates a control flow context for each
    branch of the cond. Each external tensor accessed by that branch is routed
    through a switch op, which gets created in the graph _after_ the op which
    uses that tensor gets created.

    If the resource comes from another switch op we process that one first.

    _process_switch creates a corresponding merge node for the switch node.
    This merge node is added to the outer control flow context of the switch
    node. We also ensure that:

      1. The switch node executes after the previous op which used the
         resource tensor

      2. Any op which uses a resource output of the switch node executes
         before the merge for the switch node.

      3. The next op which uses the input resource to the switch node (which
         might be another switch node for the other branch of the conditional)
         will execute after the merge node is done.

      4. The merge node is marked as must_run so it will run even if no
         subsequent operation uses the resource.

    Args:
      switch_op: the switch op to be processed
      ops_which_must_run: the set of ops which must run
      last_op_using_resource_tensor: map from resource tensor to last op using
        it
      merge_for_resource: map from resource tensor to merge which must follow
        all usages of it.
    """
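    # Illustrative scenario (a sketch; `v` is a hypothetical resource variable
    # defined outside the cond):
    #
    #   tf.cond(pred, lambda: v.assign(1.0), lambda: v.assign(2.0))
    #
    # TensorFlow routes v.handle through a Switch op whose two outputs feed
    # the two branches. _process_switch builds an "artificial_merge" of those
    # outputs so that the branch assigns must run before the merge and any
    # later user of v.handle must run after it.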
    inp = switch_op.inputs[0]
    if inp.dtype == dtypes_module.resource and inp.op.type == "Switch":
      self._process_switch(inp.op, ops_which_must_run,
                           last_op_using_resource_tensor, merge_for_resource)
    if switch_op.outputs[0] in merge_for_resource:
      return
    new_merge = control_flow_ops.merge(switch_op.outputs,
                                       name="artificial_merge")
    new_merge[0].op._control_flow_context = (  # pylint: disable=protected-access
        switch_op._control_flow_context.outer_context)  # pylint: disable=protected-access
    # Ensures the merge always runs
    ops_which_must_run.add(new_merge[0].op)
    if inp in last_op_using_resource_tensor:
      # Ensures the switch executes after the previous op using the resource.
      switch_op._add_control_input(last_op_using_resource_tensor[inp])  # pylint: disable=protected-access
    # Ensure the next op outside the cond happens after the merge.
    last_op_using_resource_tensor[inp] = new_merge[0].op
    if inp in merge_for_resource:
      merge_for_resource[inp]._add_control_input(new_merge[0].op)  # pylint: disable=protected-access
    for o in switch_op.outputs:
      # Ensures the merge will execute after all ops inside the cond
      merge_for_resource[o] = new_merge[0].op

  def __exit__(self, unused_type, unused_value, unused_traceback):
    if context.executing_eagerly():
      return

    if self._graph is not ops.get_default_graph():
      raise RuntimeError(
          "Graph changed while trying to add control dependencies.")

    # pylint: disable=protected-access
    if hasattr(self._graph, "outer_graph"):
      outer_val = self._graph.outer_graph._add_control_dependencies
      self._graph._add_control_dependencies = outer_val
    else:
      self._graph._add_control_dependencies = False
    # pylint: enable=protected-access

    # map from resource tensor to the last op which used it
    last_op_using_resource_tensor = {}
    # set of conditional and loop exits
    ops_which_must_run = set()
    # merge which must depend on ops which use this resource
    merge_for_resource = {}

    new_operations = self._graph.get_operations()[self._n_operations:]

    # Ensures that uses of resource tensors get serialized properly and all
    # execute. This is done by keeping a map from resource tensor to the last
    # op in graph-construction order which used it
    # (last_op_using_resource_tensor).
    #
    # Conditionals are written in TensorFlow such that every external tensor
    # accessed in the conditional goes through a switch op and every return
    # tensor (it's guaranteed that there will be at least one) goes through a
    # merge op.
    #
    # To handle conditionals, switches are handled in a special way (see
    # comments for _process_switch). Merge nodes created by TF's conditional
    # logic (as opposed to by _process_switch) are forced to run and also get
    # a control dependency added to them to ensure all stateful ops inside
    # their control flow context run.
    #
    # We also ensure that if an op is using a resource output by a switch node
    # (that is, a resource tensor for which there's a value in
    # merge_for_resource) this op will run before the merge for that resource.
    #
    # We try to add control inputs to nodes respecting their control flow
    # contexts to avoid dead nodes propagating everywhere and leading to
    # "retval[0] doesn't have value" errors. If a node gets a control
    # dependency on a dead node (i.e. a node from an untaken control flow
    # branch) that node will be marked as dead unless it's a merge node.
    #
    # TODO(apassos): serialize non-resource-taking stateful ops as well, and
    # test that it works. Support while loops. Support init_scope escaping
    # from this.
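    #
    # As a concrete (illustrative) example of the serialization below: if the
    # new operations are, in construction order,
    #
    #   AssignVariableOp(handle, 1.0)   # assign_1
    #   AssignVariableOp(handle, 2.0)   # assign_2
    #   ReadVariableOp(handle)          # read
    #
    # then assign_2 gets a control input on assign_1 and read gets a control
    # input on assign_2, because each op picks up the previous entry in
    # last_op_using_resource_tensor for `handle`; all three are also added to
    # ops_which_must_run because they are stateful.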
    for op in new_operations:
      # TODO(apassos) make this code safely support while loops.
      if control_flow_util.IsInWhileLoop(op):
        continue
      control_inputs = set()
      # Ensure stateful ops run
      if (op.type not in self._graph._registered_ops  # pylint: disable=protected-access
          or op_is_stateful(self._graph._registered_ops[op.type])):  # pylint: disable=protected-access
        ops_which_must_run.add(op)
      # Ignore switches (they're handled separately)
      if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource:
        continue
      # Make merges trigger all other computation which must run
      if op.type == "Merge":
        for o in ops_which_must_run:
          op._add_control_input(o)  # pylint: disable=protected-access
          for inp in o.inputs:
            if inp in last_op_using_resource_tensor:
              last_op_using_resource_tensor[inp] = op
        ops_which_must_run = set([op])
        continue
      found_resource = False
      # Check for any resource inputs. If we find any, we update control_inputs
      # and last_op_using_resource_tensor. Note that we dedup op.inputs in case
      # op receives the same resource tensor twice as input, which would result
      # in op getting a control dependency on itself.
      for inp in set(op.inputs):
        if inp.dtype != dtypes_module.resource:
          continue
        found_resource = True
        # Deal with switches, finally.
        if inp.op.type == "Switch":
          self._process_switch(inp.op, ops_which_must_run,
                               last_op_using_resource_tensor,
                               merge_for_resource)
        # Ensure uses of resources are serialized
        if inp in last_op_using_resource_tensor:
          if (last_op_using_resource_tensor[inp]._control_flow_context  # pylint: disable=protected-access
              is op._control_flow_context):  # pylint: disable=protected-access
            control_inputs.add(last_op_using_resource_tensor[inp])
        # Ensure merges happen after the closing of a cond block
        if inp in merge_for_resource:
          merge_for_resource[inp]._add_control_input(op)  # pylint: disable=protected-access
        last_op_using_resource_tensor[inp] = op
      if (op_is_stateful(op.op_def) and not found_resource
          and op._control_flow_context is None):  # pylint: disable=protected-access
        if None in last_op_using_resource_tensor:
          op._add_control_input(last_op_using_resource_tensor[None])  # pylint: disable=protected-access
        last_op_using_resource_tensor[None] = op
      control_inputs = [c for c in control_inputs
                        if c._control_flow_context is op._control_flow_context]  # pylint: disable=protected-access
      op._add_control_inputs(control_inputs)  # pylint: disable=protected-access

    # Ensure all ops which must run do run
    self.ops_which_must_run.update(ops_which_must_run)
    for r in self._returned_tensors:
      if self.ops_which_must_run:
        r.op._add_control_inputs(  # pylint: disable=protected-access
            [o for o in self.ops_which_must_run
             if o._control_flow_context is r.op._control_flow_context])  # pylint: disable=protected-access


def automatic_control_dependencies(f):
  """Wraps f to automatically insert control dependencies.

  The inserted dependencies ensure that:
    1. All stateful ops in f run when the result of f runs
    2. Updates to the same resources happen in order.

  Args:
    f: the function to be wrapped.

  Returns:
    The wrapped function.
  """

  def wrapper(*args, **kwargs):
    with AutomaticControlDependencies() as a:
      result = f(*args, **kwargs)
      result_flat = [a.mark_as_return(t) for t in nest.flatten(result)]
      return nest.pack_sequence_as(result, result_flat)

  return tf_decorator.make_decorator(f, wrapper)