# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for AutomaticControlDependencies."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.util import object_identity

# Name of the op attr listing the indices of read-only resource inputs.
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op types registered (via `register_read_only_resource_op`) as never
# updating any resource they receive.
RESOURCE_READ_OPS = set()


COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"


def register_read_only_resource_op(op_type):
  """Declares that `op_type` does not update its touched resource."""
  RESOURCE_READ_OPS.add(op_type)


def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource input of `func_graph` is read-only when none of its consumer ops
  may write to it, as classified by `get_read_write_resource_inputs`. A handle
  that one op receives at both a read-only and a writable input position is
  treated as written to.

  Args:
    func_graph: Graph whose `inputs` are inspected; the `consumers()` of each
      resource-typed input are examined.

  Returns:
    Ascending list of indices into `func_graph.inputs`.
  """
  result = []
  # Cache of the resource handles each consumer op may write to, so ops that
  # consume several graph inputs are classified only once.
  # Operation -> ObjectIdentitySet of resource handles.
  op_written_resource_inputs = {}
  for input_index, t in enumerate(func_graph.inputs):
    if t.dtype != dtypes.resource:
      continue
    read_only = True
    for op in t.consumers():
      writes = op_written_resource_inputs.get(op)
      if writes is None:
        writes = get_read_write_resource_inputs(op)[1]
        op_written_resource_inputs[op] = writes
      # Checking the *written* set (rather than the read-only set) means a
      # handle passed to `op` at both a read-only and a writable position is
      # correctly classified as written to.
      if t in writes:
        read_only = False
        break
    if read_only:
      result.append(input_index)
  return result


def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  Args:
    op: Operation.

  Returns:
    Ascending list of indices of resource-typed inputs of `op` that it does
    not write to: every resource input if `op.type` was registered with
    `register_read_only_resource_op`, otherwise those listed in the
    `READ_ONLY_RESOURCE_INPUTS_ATTR` attr, or none if the attr is unset.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively treat every resource input as written
    # to, i.e. report no read-only inputs.
    return []

  # Set membership (instead of a sorted-merge walk) does not depend on the
  # attr being sorted and matches the check in `_op_writes_to_resource`.
  read_only_input_indices = set(read_only_input_indices)
  return [
      i for i, t in enumerate(op.inputs)
      if t.dtype == dtypes.resource and i in read_only_input_indices
  ]


def get_read_write_resource_inputs(op):
  """Returns a tuple of resource reads, writes in op.inputs.

  Args:
    op: Operation

  Returns:
    A 2-tuple of ObjectIdentitySets, the first entry containing read-only
    resource handles and the second containing read-write resource handles in
    `op.inputs`.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  read_only_indices = set(_get_read_only_resource_input_indices_op(op))
  for i, t in enumerate(op.inputs):
    if t.dtype != dtypes.resource:
      continue
    # Any resource input not known to be read-only is conservatively assumed
    # to be written to.
    if i in read_only_indices:
      reads.add(t)
    else:
      writes.add(t)
  return (reads, writes)


def _op_writes_to_resource(handle, op):
  """Returns whether op writes to resource handle.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    Returns False if op is a read-only op registered using
    `register_read_only_resource_op` or if `handle` is an input at one of
    the indices in the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr of the op, True
    otherwise. If `handle` appears at several input positions, only its
    first position is checked.

  Raises:
    ValueError: if `handle` is not an input of `op`. This is only checked
      when `op.type` is not a registered read-only op type.
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  input_index = _input_index(op, handle)
  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively assume that the resource is written to.
    return True
  return input_index not in read_only_input_indices


def _input_index(op, handle):
  """Returns the index of `handle` in `op.inputs`.

  Args:
    op: Operation.
    handle: Resource handle.

  Returns:
    Index of the first input in `op.inputs` that is `handle` (compared by
    identity).

  Raises:
    ValueError: If `handle` is not found in `op.inputs`.
  """
  for i, t in enumerate(op.inputs):
    if handle is t:
      return i
  raise ValueError("%s not in list" % str(handle))