• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Code to generate inputs/outputs exclusion lists for GradientTape."""
15
16import sys
17
18import gast
19
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.autograph.pyct import cfg
22from tensorflow.python.autograph.pyct import parser
23from tensorflow.python.autograph.pyct import qual_names
24from tensorflow.python.autograph.pyct import transformer
25from tensorflow.python.autograph.pyct.static_analysis import activity
26from tensorflow.python.autograph.pyct.static_analysis import liveness
27from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs
28from tensorflow.python.framework import op_def_registry
29from tensorflow.python.framework import ops
30
# Header emitted verbatim at the top of the generated C++ file: Apache-2.0
# license banner plus a "MACHINE GENERATED" warning pointing back here.
_GENERATED_FILE_HEADER = """/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Inputs/Outputs exclusion lists for GradientTape.
//
// This file is MACHINE GENERATED! Do not edit.
// Generated by: tensorflow/python/eager/gen_gradient_input_output_exclusions.py
"""
51
# C++ preamble for the generated file: includes, plus an anonymous namespace
# holding OpIndexInfo (a statically-initializable record of unused indices per
# op) and OpGradientInfoInit, which converts an array of OpIndexInfo into a
# heap-allocated FlatMap<string, FlatSet<int>> at first use.
_INCLUDES = """
#include "tensorflow/python/eager/pywrap_gradient_exclusions.h"

#include "absl/types/optional.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"

using tensorflow::string;

namespace {
// Keep static data in a format that's easy to init statically.
struct OpIndexInfo {
  const char *op_name;
  int num_indices;
  std::array<int, 4> unused_indices;
};

// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo.
template <typename T>
auto OpGradientInfoInit(const T &a) {
  auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>;
  for (const auto &item : a) {
    m->emplace(string(item.op_name),
               tensorflow::gtl::FlatSet<int>(
                   item.unused_indices.begin(),
                   item.unused_indices.begin() + item.num_indices));
  }
  return m;
}
}  // namespace
"""
83
# Ops deliberately left out of the generated exclusion tables.
_EXCLUDED_OPS = [
    # Composite ops with custom gradient functions.
    "If",
    "StatelessIf",
    "While",
    "StatelessWhile",
    "Case",

    # TF Lite. These ops only appear in OSS.
    # TODO(srbs): Find a better way to filter these out.
    "AudioMicrofrontend",
]
96
97
class _SubscriptUseTracker(transformer.Base):
  """Track uses of composite names, excluding certain names when subscripted.

  Collects into:
    reads: qualified names loaded anywhere (attributes or subscripts that
      carry a QN annotation).
    complex_reads: base names subscripted with a tuple or slice (e.g.
      `op.inputs[:3]`), whose individual indices cannot be determined.
  """

  def __init__(self, ctx, exclude_when_subscripted):
    super(_SubscriptUseTracker, self).__init__(ctx)
    # Names whose own subscript expression should not be recursed into with
    # full tracking (only generic_visit is applied to the value).
    self.exclude = exclude_when_subscripted
    self.reads = set()
    self.complex_reads = set()

  def visit_Attribute(self, node):
    """Records loads of attribute expressions that have a QN annotation."""
    if anno.hasanno(node, anno.Basic.QN):
      qn = anno.getanno(node, anno.Basic.QN)
      if isinstance(node.ctx, gast.Load):
        self.reads.add(qn)
    node = self.generic_visit(node)
    return node

  def visit_Subscript(self, node):
    """Records subscript reads; tuple/slice subscripts go to complex_reads."""
    s = node.slice
    if anno.hasanno(node, anno.Basic.QN):
      qn = anno.getanno(node, anno.Basic.QN)
      if isinstance(node.ctx, gast.Load):
        self.reads.add(qn)
    # NOTE: only reached when the subscript expression itself has no QN
    # annotation (e.g. non-literal indices like slices).
    elif isinstance(s, (gast.Tuple, gast.Slice)):
      if anno.hasanno(node.value, anno.Basic.QN):
        self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN))
    value_qn = anno.getanno(node.value, anno.Basic.QN, None)
    if value_qn in self.exclude:
      node.value = self.generic_visit(node.value)
    else:
      node.value = self.visit(node.value)
    node.slice = self.visit(s)
    return node
134
class _FunctionCallsTracker(transformer.Base):
  """Tracks any function calls made with a given first argument name.

  Resolves callee objects statically through the entity's namespace and
  records, in `calls`, every function object invoked as `fn(<first_arg>, ...)`
  where `<first_arg>` matches `first_argument_name`.
  """

  def __init__(self, ctx, first_argument_name):
    super(_FunctionCallsTracker, self).__init__(ctx)
    self.first_argument_name = first_argument_name
    self.calls = set()

  def visit_Name(self, node):
    """Annotates loaded names with their value from the entity namespace."""
    node = self.generic_visit(node)
    if isinstance(node.ctx, gast.Load) and node.id in self.ctx.info.namespace:
      anno.setanno(node, "static_value", self.ctx.info.namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates static values through attribute accesses (e.g. mod.fn)."""
    node = self.generic_visit(node)
    parent_val = anno.getanno(node.value, "static_value", default=None)
    if parent_val is not None:
      if hasattr(parent_val, node.attr):
        anno.setanno(node, "static_value", getattr(parent_val, node.attr))
    return node

  def visit_Call(self, node):
    """Records statically-resolved callees whose first arg is the tracked name."""
    node = self.generic_visit(node)
    if (node.args and anno.getanno(node.args[0], anno.Basic.QN,
                                   None) == self.first_argument_name):
      fn_object = anno.getanno(node.func, "static_value", None)
      if fn_object is not None:
        self.calls.add(fn_object)
    return node
165
166
# Unique sentinel meaning "all indices are (or may be) used"; compared with
# `is`, so it can never collide with a real set of indices.
_ALL = object()
168
169
def _live_tensors(f, attr_name="inputs"):
  """Returns the indices of the used inputs.

  Note: This currently only handles direct index accesses e.g. op.inputs[1].
  If the function has slicing or list comprehension on attr_name then returns
  _ALL. This ensures that this is correct even if inefficient.

  Args:
    f: A grad function, taking the op as first argument.
    attr_name: op attr to track. "inputs" or "outputs".

  Returns:
    Either one of:
      * set of integers representing individual indices of inputs used
      * the value _ALL, if indices are used but cannot be determined which
      * empty set, if no inputs are used
  """
  node, _ = parser.parse_entity(f, ())
  entity_info = transformer.EntityInfo(
      name=f.__name__,
      source_code=None,
      source_file=None,
      future_features=(),
      namespace=sys.modules[f.__module__].__dict__)
  ctx = transformer.Context(entity_info, None, None)

  # Run the standard autograph static analyses so that liveness and qualified
  # name annotations are available on the AST.
  graphs = cfg.build(node)
  node = qual_names.resolve(node)
  node = activity.resolve(node, ctx, None)
  node = reaching_fndefs.resolve(node, ctx, graphs)
  node = liveness.resolve(node, ctx, graphs)

  # Qualified name of e.g. `op.inputs`, where `op` is f's first parameter.
  op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN)
  op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name)

  special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,))
  node = special_tracker.visit(node)

  live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN)
  inputs_outputs_used_qns = set()
  for v in special_tracker.complex_reads:
    # Complicated patterns like op.inputs[:3]. Could be smarter about them
    # if they matter much.
    if v == op_inputs_outputs_name:
      return _ALL
  for v in live_vars_in:
    if v in special_tracker.reads:
      if (v.has_subscript() and v.parent == op_inputs_outputs_name):
        inputs_outputs_used_qns.add(v)
      elif v == op_inputs_outputs_name:
        # When op.{attr_name} is used directly, assume all tensors are
        # used for now. In that case, no point digging further.
        # TODO(mdan): We can descend into tuple expansions.
        return _ALL

  # Gradient functions often delegate to helpers called with the op as first
  # argument; recurse into those and union their used indices.
  function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name)
  node = function_calls_tracker.visit(node)

  input_output_indices = set()

  for called_f in function_calls_tracker.calls:
    child_indices = _live_tensors(called_f, attr_name=attr_name)
    if child_indices is _ALL:
      return _ALL
    input_output_indices |= child_indices

  for v in inputs_outputs_used_qns:
    assert v.has_subscript()
    _, subscript = v.qn
    if not subscript.is_simple():
      # Not a number, assuming it can be anything.
      return _ALL
    subscript_val, = subscript.qn
    # NOTE(review): the `and` here means a Literal with a non-int value is NOT
    # treated as "anything" — looks like `or` may have been intended; kept
    # as-is to preserve upstream behavior. TODO: confirm.
    if (not isinstance(subscript_val, qual_names.Literal) and
        not isinstance(subscript_val.value, int)):
      # Not a number, assuming it can be anything.
      return _ALL
    input_output_indices.add(subscript_val.value)
  return input_output_indices
249
250
def _get_num_inputs_outputs(op_type):
  """Returns (num_inputs, num_outputs).

  Args:
    op_type: String. The type of the Operation. Used to lookup the op in the
      registry.

  Returns:
    (num_inputs, num_outputs), for either num_inputs or num_outputs if the
    value can't be statically inferred from the OpDef alone or if the OpDef
    lookup fails, -1 is returned.
  """

  def _is_list_arg(arg):
    # number_attr / type_list_attr mark variadic (list-typed) arguments.
    return arg.number_attr or arg.type_list_attr

  def _count_args(arg_defs):
    for arg in arg_defs:
      if _is_list_arg(arg):
        # Op has list type args which could be variable.
        return -1
    return len(arg_defs)

  op_def = op_def_registry.get(op_type)
  if not op_def:
    # Op not found in the registry; caller treats -1 as "unknown".
    return -1, -1
  return _count_args(op_def.input_arg), _count_args(op_def.output_arg)
278
279
def get_entries(attr_name):
  """Returns the dict of entries.

  Each entry is of the form {op_name, {true|false, indices}}

  true: All values are unused.
  false: `indices` are the only unused indices.

  Note: ops for which all values are used are not printed.

  Args:
    attr_name: inputs or outputs.

  Returns:
    A dict from op_type to formatted entry in the dict.
  """
  assert attr_name in ["inputs", "outputs"]
  entries = {}
  for op_type in ops._gradient_registry.list():  # pylint: disable=protected-access
    if op_type in _EXCLUDED_OPS:
      continue
    # -1 means the count could not be inferred from the OpDef.
    num_values = _get_num_inputs_outputs(op_type)[0 if attr_name ==
                                                  "inputs" else 1]
    gradient_fn = ops._gradient_registry.lookup(op_type)  # pylint: disable=protected-access
    if gradient_fn is None:
      # NotDifferentiable: every value is unused, emit the "all unused" form
      # (only when the count is statically known).
      if num_values != -1:
        entries[op_type] = "{\"%s\"}," % op_type
      continue
    used_tensors = _live_tensors(gradient_fn, attr_name=attr_name)
    if used_tensors is _ALL:
      # All (or an undeterminable subset of) values used: no entry emitted.
      continue
    elif not used_tensors:
      entries[op_type] = "{\"%s\"}," % op_type
    else:
      all_tensors = set(range(num_values))
      unused_tensors = all_tensors - used_tensors
      if unused_tensors:
        unused_tensor_list = sorted(list(unused_tensors))
        entries[op_type] = "{\"%s\", %d, {%s}}," % (
            op_type, len(unused_tensor_list), ", ".join(
                str(i) for i in unused_tensor_list))
  return entries
323
324
def get_function(name, entries):
  """Generates lookup function with given name and lookup table entries.

  Args:
    name: C++ function name to emit.
    entries: dict mapping op_type to its pre-formatted OpIndexInfo line.

  Returns:
    C++ source for a function that lazily builds a FlatMap from the static
    OpIndexInfo array and returns the FlatSet for `op_name`, or nullopt.
  """
  contents = """
absl::optional<tensorflow::gtl::FlatSet<int>> {name}(
    const tensorflow::string &op_name) {{
  static std::array<OpIndexInfo, {count}> a = {{{{
""".format(
    name=name, count=len(entries) + 1)  # +1 for the VarHandleOp entry below.
  contents += "      "
  # Sorted for deterministic output across runs.
  contents += "\n      ".join(entries[op_type] for op_type in sorted(entries))
  # VarHandleOp is always present with all values unused.
  contents += "\n      {\"VarHandleOp\"},"
  contents += """
  }};
  static const auto &m = *OpGradientInfoInit(a);

  auto it = m.find(op_name);
  if (it != m.end()) {
    return it->second;
  }
  return absl::nullopt;
}
"""
  return contents
348
349
def get_contents():
  """Returns contents for the generated file.

  Concatenates the license header, C++ includes/helpers, and the two lookup
  functions (one for unused input indices, one for unused output indices).
  """
  contents = ""
  contents += _GENERATED_FILE_HEADER + _INCLUDES
  contents += get_function("OpGradientUnusedInputIndices",
                           get_entries("inputs"))
  contents += get_function("OpGradientUnusedOutputIndices",
                           get_entries("outputs"))
  return contents
359