• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for AutomaticControlDependencies."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.util import object_identity
19
# Name of the op attr that lists the indices (into `op.inputs`) of the
# resource inputs which the op reads but does not write.
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op types declared read-only via `register_read_only_resource_op`; every
# resource input of such an op is treated as a read.
RESOURCE_READ_OPS = set()


# NOTE(review): not referenced by the code visible in this module; presumably
# an attr name consumed elsewhere -- confirm against callers.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"
25
26
def register_read_only_resource_op(op_type):
  """Registers `op_type` as an op that never updates the resources it touches.

  Args:
    op_type: Op type to record in the module-level `RESOURCE_READ_OPS` set.
      The helpers in this module compare it against `op.type`.
  """
  RESOURCE_READ_OPS.add(op_type)
30
31
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource input is reported as read-only when every op consuming it takes
  it at one of that op's read-only resource input positions. A resource input
  with no consumers is therefore reported as read-only too.

  Args:
    func_graph: Object exposing `inputs`, a sequence of tensors that provide
      `dtype` and `consumers()`.

  Returns:
    List of indices into `func_graph.inputs`, in ascending order.
  """
  # Memo shared across inputs so each consumer op is analysed only once:
  # Operation -> ObjectIdentitySet of its read-only resource input tensors.
  read_only_handles_by_op = {}
  read_only_positions = []
  for position, tensor in enumerate(func_graph.inputs):
    if tensor.dtype != dtypes.resource:
      continue
    for consumer in tensor.consumers():
      handles = read_only_handles_by_op.get(consumer)
      if handles is None:
        handles = object_identity.ObjectIdentitySet([
            consumer.inputs[i]
            for i in _get_read_only_resource_input_indices_op(consumer)
        ])
        read_only_handles_by_op[consumer] = handles
      if tensor not in handles:
        # At least one consumer may write to this resource.
        break
    else:
      # No consumer disqualified the input.
      read_only_positions.append(position)
  return read_only_positions
57
58
def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  Args:
    op: Operation whose inputs are inspected.

  Returns:
    Ascending list of indices `i` such that `op.inputs[i]` is a resource
    tensor that `op` only reads:
      * every resource input, if `op.type` is in `RESOURCE_READ_OPS`;
      * otherwise the resource inputs matched against the op's
        `READ_ONLY_RESOURCE_INPUTS_ATTR` attr;
      * an empty list if that attr is not set.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set, so no resource input is known to be read-only.
    return []

  # The attr is walked with a single forward cursor, so only entries that
  # appear in ascending order and point at resource inputs can be matched; an
  # entry that does not match stalls the cursor for the rest of the scan.
  num_read_only = len(read_only_input_indices)
  read_only_index = 0
  result = []
  for i, t in enumerate(op.inputs):
    if read_only_index >= num_read_only:
      # Every declared index has been consumed; nothing left to match.
      break
    if t.dtype != dtypes.resource:
      continue
    if i == read_only_input_indices[read_only_index]:
      result.append(i)
      read_only_index += 1

  return result
83
84
def get_read_write_resource_inputs(op):
  """Splits the resource inputs of `op` into reads and writes.

  Args:
    op: Operation

  Returns:
    A 2-tuple `(reads, writes)` of ObjectIdentitySets. `reads` holds the
    resource handles in `op.inputs` that are only read, `writes` holds the
    remaining resource handles, which are treated as read-write.
  """
  read_handles = object_identity.ObjectIdentitySet()
  write_handles = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Registered read-only op: every resource input is a read.
    for tensor in op.inputs:
      if tensor.dtype == dtypes.resource:
        read_handles.add(tensor)
    return (read_handles, write_handles)

  try:
    declared_read_only = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # No attr on the op: conservatively classify every resource input as a
    # write.
    for tensor in op.inputs:
      if tensor.dtype == dtypes.resource:
        write_handles.add(tensor)
    return (read_handles, write_handles)

  # `cursor` points at the next declared read-only index still to be matched.
  cursor = 0
  for position, tensor in enumerate(op.inputs):
    if tensor.dtype != dtypes.resource:
      continue
    is_read_only = (cursor < len(declared_read_only) and
                    position == declared_read_only[cursor])
    if is_read_only:
      read_handles.add(tensor)
      cursor += 1
    else:
      write_handles.add(tensor)
  return (read_handles, write_handles)
122
123
def _op_writes_to_resource(handle, op):
  """Tells whether `op` may write to the resource behind `handle`.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    False when `op.type` was registered through
    `register_read_only_resource_op`, or when the position of `handle` in
    `op.inputs` is listed in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attr.
    True otherwise, including when the attr is not set.

  Raises:
    ValueError: if `handle` is not an input of `op` (only checked for ops
      that are not registered as read-only).
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  # Looked up outside the `try` so that a missing handle still surfaces as a
  # ValueError instead of being mistaken for a missing attr.
  handle_position = _input_index(op, handle)
  try:
    declared_read_only = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively assume that the resource is written to.
    return True
  else:
    return handle_position not in declared_read_only
149
150
def _input_index(op, handle):
  """Finds the position of `handle` among the inputs of `op`.

  Inputs are compared by identity (`is`), not by equality.

  Args:
    op: Operation.
    handle: Resource handle.

  Returns:
    Index of the first entry of `op.inputs` that is `handle`.

  Raises:
    ValueError: If `handle` is not one of `op.inputs`.
  """
  position = next(
      (idx for idx, tensor in enumerate(op.inputs) if tensor is handle), None)
  if position is None:
    raise ValueError(f"{handle!s} not in list of inputs for op: {op!r}")
  return position
169