1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Code to generate inputs/outputs exclusion lists for GradientTape.""" 15 16import sys 17 18import gast 19 20from tensorflow.python.autograph.pyct import anno 21from tensorflow.python.autograph.pyct import cfg 22from tensorflow.python.autograph.pyct import parser 23from tensorflow.python.autograph.pyct import qual_names 24from tensorflow.python.autograph.pyct import transformer 25from tensorflow.python.autograph.pyct.static_analysis import activity 26from tensorflow.python.autograph.pyct.static_analysis import liveness 27from tensorflow.python.autograph.pyct.static_analysis import reaching_fndefs 28from tensorflow.python.framework import op_def_registry 29from tensorflow.python.framework import ops 30 31_GENERATED_FILE_HEADER = """/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 32 33Licensed under the Apache License, Version 2.0 (the "License"); 34you may not use this file except in compliance with the License. 35You may obtain a copy of the License at 36 37 http://www.apache.org/licenses/LICENSE-2.0 38 39Unless required by applicable law or agreed to in writing, software 40distributed under the License is distributed on an "AS IS" BASIS, 41WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 42See the License for the specific language governing permissions and 43limitations under the License. 44==============================================================================*/ 45 46// Inputs/Outputs exclusion lists for GradientTape. 47// 48// This file is MACHINE GENERATED! Do not edit. 49// Generated by: tensorflow/python/eager/gen_gradient_input_output_exclusions.py 50""" 51 52_INCLUDES = """ 53#include "tensorflow/python/eager/pywrap_gradient_exclusions.h" 54 55#include "absl/types/optional.h" 56#include "tensorflow/core/lib/gtl/flatmap.h" 57#include "tensorflow/core/lib/gtl/flatset.h" 58 59using tensorflow::string; 60 61namespace { 62// Keep static data in a format that's easy to init statically. 63struct OpIndexInfo { 64 const char *op_name; 65 int num_indices; 66 std::array<int, 4> unused_indices; 67}; 68 69// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo. 70template <typename T> 71auto OpGradientInfoInit(const T &a) { 72 auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>; 73 for (const auto &item : a) { 74 m->emplace(string(item.op_name), 75 tensorflow::gtl::FlatSet<int>( 76 item.unused_indices.begin(), 77 item.unused_indices.begin() + item.num_indices)); 78 } 79 return m; 80} 81} // namespace 82""" 83 84_EXCLUDED_OPS = [ 85 # Composite ops with custom gradient functions. 86 "If", 87 "StatelessIf", 88 "While", 89 "StatelessWhile", 90 "Case", 91 92 # TF Lite. These ops only appear in OSS. 93 # TODO(srbs): Find a better way to filter these out. 94 "AudioMicrofrontend", 95] 96 97 98class _SubscriptUseTracker(transformer.Base): 99 """Track uses of composite names, excluding certain names when subscripted.""" 100 101 def __init__(self, ctx, exclude_when_subscripted): 102 super(_SubscriptUseTracker, self).__init__(ctx) 103 self.exclude = exclude_when_subscripted 104 self.reads = set() 105 self.complex_reads = set() 106 107 def visit_Attribute(self, node): 108 """Visits attribute nodes in the AST.""" 109 if anno.hasanno(node, anno.Basic.QN): 110 qn = anno.getanno(node, anno.Basic.QN) 111 if isinstance(node.ctx, gast.Load): 112 self.reads.add(qn) 113 node = self.generic_visit(node) 114 return node 115 116 def visit_Subscript(self, node): 117 """Visits nodes with subscript in the AST.""" 118 s = node.slice 119 if anno.hasanno(node, anno.Basic.QN): 120 qn = anno.getanno(node, anno.Basic.QN) 121 if isinstance(node.ctx, gast.Load): 122 self.reads.add(qn) 123 elif isinstance(s, (gast.Tuple, gast.Slice)): 124 if anno.hasanno(node.value, anno.Basic.QN): 125 self.complex_reads.add(anno.getanno(node.value, anno.Basic.QN)) 126 value_qn = anno.getanno(node.value, anno.Basic.QN, None) 127 if value_qn in self.exclude: 128 node.value = self.generic_visit(node.value) 129 else: 130 node.value = self.visit(node.value) 131 node.slice = self.visit(s) 132 return node 133 134 135class _FunctionCallsTracker(transformer.Base): 136 """Tracks any function calls made with a given first argument name.""" 137 138 def __init__(self, ctx, first_argument_name): 139 super(_FunctionCallsTracker, self).__init__(ctx) 140 self.first_argument_name = first_argument_name 141 self.calls = set() 142 143 def visit_Name(self, node): 144 node = self.generic_visit(node) 145 if isinstance(node.ctx, gast.Load) and node.id in self.ctx.info.namespace: 146 anno.setanno(node, "static_value", self.ctx.info.namespace[node.id]) 147 return node 148 149 def visit_Attribute(self, node): 150 node = self.generic_visit(node) 151 parent_val = anno.getanno(node.value, "static_value", default=None) 152 if parent_val is not None: 153 if hasattr(parent_val, node.attr): 154 anno.setanno(node, "static_value", getattr(parent_val, node.attr)) 155 return node 156 157 def visit_Call(self, node): 158 node = self.generic_visit(node) 159 if (node.args and anno.getanno(node.args[0], anno.Basic.QN, 160 None) == self.first_argument_name): 161 fn_object = anno.getanno(node.func, "static_value", None) 162 if fn_object is not None: 163 self.calls.add(fn_object) 164 return node 165 166 167_ALL = object() 168 169 170def _live_tensors(f, attr_name="inputs"): 171 """Returns the indices of the used inputs. 172 173 Note: This currently only handles direct index accesses e.g. op.inputs[1]. 174 If the function has slicing or list comprehension on attr_name then returns 175 _ALL. This ensure that this is correct even if inefficient. 176 177 Args: 178 f: A grad function, taking the op as first argument. 179 attr_name: op attr to track. "inputs" or "outputs". 180 181 Returns: 182 Either one of: 183 * set of integers representing individual indices of inputs used 184 * the value _ALL, if indices are used but cannot be determined which 185 * empty set, if no inputs are used 186 """ 187 node, _ = parser.parse_entity(f, ()) 188 entity_info = transformer.EntityInfo( 189 name=f.__name__, 190 source_code=None, 191 source_file=None, 192 future_features=(), 193 namespace=sys.modules[f.__module__].__dict__) 194 ctx = transformer.Context(entity_info, None, None) 195 196 graphs = cfg.build(node) 197 node = qual_names.resolve(node) 198 node = activity.resolve(node, ctx, None) 199 node = reaching_fndefs.resolve(node, ctx, graphs) 200 node = liveness.resolve(node, ctx, graphs) 201 202 op_arg_name = anno.getanno(node.args.args[0], anno.Basic.QN) 203 op_inputs_outputs_name = qual_names.QN(op_arg_name, attr=attr_name) 204 205 special_tracker = _SubscriptUseTracker(ctx, (op_inputs_outputs_name,)) 206 node = special_tracker.visit(node) 207 208 live_vars_in = anno.getanno(node.body[0], anno.Static.LIVE_VARS_IN) 209 inputs_outputs_used_qns = set() 210 for v in special_tracker.complex_reads: 211 # Complicated patterns like op.inputs[:3]. Could be smarter about them 212 # if they matter much. 213 if v == op_inputs_outputs_name: 214 return _ALL 215 for v in live_vars_in: 216 if v in special_tracker.reads: 217 if (v.has_subscript() and v.parent == op_inputs_outputs_name): 218 inputs_outputs_used_qns.add(v) 219 elif v == op_inputs_outputs_name: 220 # When op.{attr_name} is used directly, assume all tensors are 221 # used for now. In that case, no point digging further. 222 # TODO(mdan): We can descend into tuple expansions. 223 return _ALL 224 225 function_calls_tracker = _FunctionCallsTracker(ctx, op_arg_name) 226 node = function_calls_tracker.visit(node) 227 228 input_output_indices = set() 229 230 for called_f in function_calls_tracker.calls: 231 child_indices = _live_tensors(called_f, attr_name=attr_name) 232 if child_indices is _ALL: 233 return _ALL 234 input_output_indices |= child_indices 235 236 for v in inputs_outputs_used_qns: 237 assert v.has_subscript() 238 _, subscript = v.qn 239 if not subscript.is_simple(): 240 # Not a number, assuming it can be anything. 241 return _ALL 242 subscript_val, = subscript.qn 243 if (not isinstance(subscript_val, qual_names.Literal) and 244 not isinstance(subscript_val.value, int)): 245 # Not a number, assuming it can be anything. 246 return _ALL 247 input_output_indices.add(subscript_val.value) 248 return input_output_indices 249 250 251def _get_num_inputs_outputs(op_type): 252 """Returns (num_inputs, num_outputs). 253 254 Args: 255 op_type: String. The type of the Operation. Used to lookup the op in the 256 registry. 257 258 Returns: 259 (num_inputs, num_outputs), for either num_inputs or num_outputs if the value 260 can't be statically inferred from the OpDef alone or of the OpDef lookup 261 fails, -1 is returned. 262 """ 263 264 def _is_list_arg(arg): 265 return arg.number_attr or arg.type_list_attr 266 267 def _count_args(arg_defs): 268 for arg in arg_defs: 269 if _is_list_arg(arg): 270 # Op has list type args which could be variable. 271 return -1 272 return len(arg_defs) 273 274 op_def = op_def_registry.get(op_type) 275 if not op_def: 276 return -1, -1 277 return _count_args(op_def.input_arg), _count_args(op_def.output_arg) 278 279 280def get_entries(attr_name): 281 """Returns the dict of entries. 282 283 Each entry is of the form {op_name, {true|false, indices}} 284 285 true: All values are unused. 286 false: `indices` are the only unused indices. 287 288 Note: ops for which all values are used are not printed. 289 290 Args: 291 attr_name: inputs or outputs. 292 293 Returns: 294 A dict from op_type to formatted entry in the dict. 295 """ 296 assert attr_name in ["inputs", "outputs"] 297 entries = {} 298 for op_type in ops._gradient_registry.list(): # pylint: disable=protected-access 299 if op_type in _EXCLUDED_OPS: 300 continue 301 num_values = _get_num_inputs_outputs(op_type)[0 if attr_name == 302 "inputs" else 1] 303 gradient_fn = ops._gradient_registry.lookup(op_type) # pylint: disable=protected-access 304 if gradient_fn is None: 305 # NotDifferentiable 306 if num_values != -1: 307 entries[op_type] = "{\"%s\"}," % op_type 308 continue 309 used_tensors = _live_tensors(gradient_fn, attr_name=attr_name) 310 if used_tensors is _ALL: 311 continue 312 elif not used_tensors: 313 entries[op_type] = "{\"%s\"}," % op_type 314 else: 315 all_tensors = set(range(num_values)) 316 unused_tensors = all_tensors - used_tensors 317 if unused_tensors: 318 unused_tensor_list = sorted(list(unused_tensors)) 319 entries[op_type] = "{\"%s\", %d, {%s}}," % ( 320 op_type, len(unused_tensor_list), ", ".join( 321 str(i) for i in unused_tensor_list)) 322 return entries 323 324 325def get_function(name, entries): 326 """Generates lookup function with given name and lookup table entries.""" 327 contents = """ 328absl::optional<tensorflow::gtl::FlatSet<int>> {name}( 329 const tensorflow::string &op_name) {{ 330 static std::array<OpIndexInfo, {count}> a = {{{{ 331""".format( 332 name=name, count=len(entries) + 1) 333 contents += " " 334 contents += "\n ".join(entries[op_type] for op_type in sorted(entries)) 335 contents += "\n {\"VarHandleOp\"}," 336 contents += """ 337 }}; 338 static const auto &m = *OpGradientInfoInit(a); 339 340 auto it = m.find(op_name); 341 if (it != m.end()) { 342 return it->second; 343 } 344 return absl::nullopt; 345} 346""" 347 return contents 348 349 350def get_contents(): 351 """Returns contents for the generated file.""" 352 contents = "" 353 contents += _GENERATED_FILE_HEADER + _INCLUDES 354 contents += get_function("OpGradientUnusedInputIndices", 355 get_entries("inputs")) 356 contents += get_function("OpGradientUnusedOutputIndices", 357 get_entries("outputs")) 358 return contents 359