# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for AutomaticControlDependencies."""

from tensorflow.python.framework import dtypes
from tensorflow.python.util import object_identity

READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
RESOURCE_READ_OPS = set()


COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"


def register_read_only_resource_op(op_type):
  """Declares that `op_type` does not update its touched resource."""
  RESOURCE_READ_OPS.add(op_type)


def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource input counts as read-only when every op consuming it only reads
  it (see `_get_read_only_resource_input_indices_op`). A resource input with
  no consumers is therefore also reported as read-only.

  Args:
    func_graph: Graph whose `inputs` are inspected.

  Returns:
    Ascending list of indices into `func_graph.inputs`.
  """
  result = []
  # A cache to store the read only resource inputs of an Op. Several resource
  # inputs may share a consumer, so each op is analyzed at most once.
  # Operation -> ObjectIdentitySet of resource handles.
  op_read_only_resource_inputs = {}
  for input_index, t in enumerate(func_graph.inputs):
    if t.dtype != dtypes.resource:
      continue
    read_only = True
    for op in t.consumers():
      if op not in op_read_only_resource_inputs:
        indices = _get_read_only_resource_input_indices_op(op)
        op_read_only_resource_inputs[op] = object_identity.ObjectIdentitySet(
            [op.inputs[i] for i in indices])
      if t not in op_read_only_resource_inputs[op]:
        read_only = False
        break
    if read_only:
      result.append(input_index)
  return result


def _read_only_input_indices_from_attr(op):
  """Returns the indices stored in `op`'s read-only attr, or None if unset.

  Args:
    op: Operation.

  Returns:
    A frozenset of ints read from the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr of
    `op`, or None when `op.get_attr` raises ValueError (attr not set). Using a
    set makes classification independent of the order of the stored indices.
  """
  try:
    indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set.
    return None
  return frozenset(indices)


def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  Args:
    op: Operation.

  Returns:
    Ascending list of indices into `op.inputs`. If `op.type` was registered
    via `register_read_only_resource_op`, this is every resource input.
    Otherwise it is the resource inputs listed in the
    `READ_ONLY_RESOURCE_INPUTS_ATTR` attr, or an empty list if that attr is
    not set.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  read_only_input_indices = _read_only_input_indices_from_attr(op)
  if read_only_input_indices is None:
    # Attr was not set: no resource input is known to be read-only.
    return []

  return [
      i for i, t in enumerate(op.inputs)
      if t.dtype == dtypes.resource and i in read_only_input_indices
  ]


def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Add all resource inputs to `reads` and return.
    reads.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  read_only_input_indices = _read_only_input_indices_from_attr(op)
  if read_only_input_indices is None:
    # Attr was not set. Add all resource inputs to `writes` and return.
    writes.update(t for t in op.inputs if t.dtype == dtypes.resource)
    return (reads, writes)

  for i, t in enumerate(op.inputs):
    if t.dtype != dtypes.resource:
      continue
    if i in read_only_input_indices:
      reads.add(t)
    else:
      writes.add(t)
  return (reads, writes)


def _op_writes_to_resource(handle, op):
  """Returns whether op writes to resource handle.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    Returns False if op is a read-only op registered using
    `register_read_only_resource_op` or if `handle` is an input at one of
    the indices in the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr of the op, True
    otherwise (including when the attr is not set).

  Raises:
    ValueError: if `handle` is not an input of `op`. This check is skipped
      for ops registered with `register_read_only_resource_op`.
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  input_index = _input_index(op, handle)
  read_only_input_indices = _read_only_input_indices_from_attr(op)
  if read_only_input_indices is None:
    # Attr was not set. Conservatively assume that the resource is written to.
    return True
  return input_index not in read_only_input_indices


def _input_index(op, handle):
  """Returns the index of `handle` in `op.inputs`.

  Inputs are compared by identity (`is`). If `handle` feeds `op` at several
  positions, the first one is returned.

  Args:
    op: Operation.
    handle: Resource handle.

  Returns:
    Index in `op.inputs` receiving the resource `handle`.

  Raises:
    ValueError: If `handle` is not found in `op.inputs`.
  """
  for i, t in enumerate(op.inputs):
    if handle is t:
      return i
  raise ValueError(f"{handle!s} not in list of inputs for op: {op!r}")