1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converter construction support. 16 17This module contains a base class for all converters, as well as supporting 18structures. These structures are referred to as contexts. 19 20The class hierarchy is as follows: 21 22 <your converter> 23 [extends] converter.Base 24 [extends] transformer.Base 25 [extends] gast.nodeTransformer 26 [uses] transformer.SourceInfo 27 [uses] converter.EntityContext 28 [uses] converter.ProgramContext 29 [uses] transformer.SourceInfo 30 31converter.Base is a specialization of transformer.Base for AutoGraph. It's a 32very lightweight subclass that adds a `ctx` attribute holding the corresponding 33EntityContext object (see below). Note that converters are not reusable, and 34`visit` will raise an error if called more than once. 35 36converter.EntityContext contains mutable state associated with an entity that 37the converter processes. 38 39converter.ProgramContext contains mutable state across related entities. For 40example, when converting several functions that call one another, the 41ProgramContext should be shared across these entities. 42 43Below is the overall flow at conversion: 44 45 program_ctx = ProgramContext(<entities to convert>, <global settings>, ...) 46 while <program_ctx has more entities to convert>: 47 entity, source_info = <get next entity from program_ctx> 48 entity_ctx = EntityContext(program_ctx, source_info) 49 for <each ConverterClass>: 50 converter = ConverterClass(entity_ctx) 51 52 # May update entity_ctx and program_ctx 53 entity = converter.visit(entity) 54 55 <add entity's dependencies to program_ctx> 56 57Note that pyct contains a small number of transformers used for static analysis. 58These implement transformer.Base, rather than converter.Base, to avoid a 59dependency on AutoGraph. 60""" 61 62from __future__ import absolute_import 63from __future__ import division 64from __future__ import print_function 65 66import enum 67 68from tensorflow.python.autograph.pyct import anno 69from tensorflow.python.autograph.pyct import ast_util 70from tensorflow.python.autograph.pyct import parser 71from tensorflow.python.autograph.pyct import templates 72from tensorflow.python.autograph.pyct import transformer 73from tensorflow.python.util.tf_export import tf_export 74 75# TODO(mdan): These contexts can be refactored into first class objects. 76# For example, we could define Program and Entity abstractions that hold on 77# to the actual entity and have conversion methods. 78 79# TODO(mdan): Add a test specific to this converter. 80 81 82@tf_export('autograph.experimental.Feature') 83class Feature(enum.Enum): 84 """This enumeration represents optional conversion options. 85 86 These conversion options are experimental. They are subject to change without 87 notice and offer no guarantees. 88 89 _Example Usage_ 90 91 ```python 92 optionals= tf.autograph.experimental.Feature.EQUALITY_OPERATORS 93 @tf.function(experimental_autograph_options=optionals) 94 def f(i): 95 if i == 0: # EQUALITY_OPERATORS allows the use of == here. 96 tf.print('i is zero') 97 ``` 98 99 Attributes: 100 ALL: Enable all features. 101 AUTO_CONTROL_DEPS: Insert of control dependencies in the generated code. 102 ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert. 103 BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to 104 their TF counterparts. 105 EQUALITY_OPERATORS: Whether to convert the comparison operators, like 106 equality. This is soon to be deprecated as support is being added to the 107 Tensor class. 108 LISTS: Convert list idioms, like initializers, slices, append, etc. 109 NAME_SCOPES: Insert name scopes that name ops according to context, like the 110 function they were defined in. 111 """ 112 113 ALL = 'ALL' 114 115 AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS' 116 ASSERT_STATEMENTS = 'ASSERT_STATEMENTS' 117 BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS' 118 EQUALITY_OPERATORS = 'EQUALITY_OPERATORS' 119 LISTS = 'LISTS' 120 NAME_SCOPES = 'NAME_SCOPES' 121 122 @classmethod 123 def all(cls): 124 """Returns a tuple that enables all options.""" 125 return tuple(cls.__members__.values()) 126 127 @classmethod 128 def all_but(cls, exclude): 129 """Returns a tuple that enables all but the excluded options.""" 130 if not isinstance(exclude, (list, tuple, set)): 131 exclude = (exclude,) 132 return tuple(set(cls.all()) - set(exclude) - {cls.ALL}) 133 134 135STANDARD_OPTIONS = None # Forward definition. 136 137 138class ConversionOptions(object): 139 """Immutable container for global conversion flags. 140 141 Attributes: 142 recursive: bool, whether to recursively convert any user functions or 143 classes that the converted function may use. 144 user_requested: bool, whether the conversion was explicitly requested by 145 the user, as opposed to being performed as a result of other logic. This 146 value always auto-resets to False in child conversions. 147 optional_features: Union[Feature, Set[Feature]], controls the use of 148 optional features in the conversion process. See Feature for available 149 options. 150 """ 151 152 def __init__(self, 153 recursive=False, 154 user_requested=False, 155 internal_convert_user_code=True, 156 optional_features=Feature.ALL): 157 self.recursive = recursive 158 self.user_requested = user_requested 159 # TODO(mdan): Rename to conversion_recursion_depth? 160 self.internal_convert_user_code = internal_convert_user_code 161 162 if optional_features is None: 163 optional_features = () 164 elif isinstance(optional_features, Feature): 165 optional_features = (optional_features,) 166 optional_features = frozenset(optional_features) 167 self.optional_features = optional_features 168 169 def as_tuple(self): 170 return (self.recursive, self.user_requested, 171 self.internal_convert_user_code, self.optional_features) 172 173 def __hash__(self): 174 return hash(self.as_tuple()) 175 176 def __eq__(self, other): 177 assert isinstance(other, ConversionOptions) 178 return self.as_tuple() == other.as_tuple() 179 180 def __str__(self): 181 return 'ConversionOptions[{}]' 182 183 def uses(self, feature): 184 return (Feature.ALL in self.optional_features or 185 feature in self.optional_features) 186 187 def call_options(self): 188 """Returns the corresponding options to be used for recursive conversion.""" 189 return ConversionOptions( 190 recursive=self.recursive, 191 user_requested=False, 192 internal_convert_user_code=self.recursive, 193 optional_features=self.optional_features) 194 195 def to_ast(self): 196 """Returns a representation of this object as an AST node. 197 198 The AST node encodes a constructor that would create an object with the 199 same contents. 200 201 Returns: 202 ast.Node 203 """ 204 if self == STANDARD_OPTIONS: 205 return parser.parse_expression('ag__.STD') 206 207 template = """ 208 ag__.ConversionOptions( 209 recursive=recursive_val, 210 user_requested=user_requested_val, 211 optional_features=optional_features_val, 212 internal_convert_user_code=internal_convert_user_code_val) 213 """ 214 215 def list_of_features(values): 216 return parser.parse_expression('({})'.format(', '.join( 217 'ag__.{}'.format(str(v)) for v in values))) 218 219 expr_ast = templates.replace( 220 template, 221 recursive_val=parser.parse_expression(str(self.recursive)), 222 user_requested_val=parser.parse_expression(str(self.user_requested)), 223 internal_convert_user_code_val=parser.parse_expression( 224 str(self.internal_convert_user_code)), 225 optional_features_val=list_of_features(self.optional_features)) 226 return expr_ast[0].value 227 228 229STANDARD_OPTIONS = ConversionOptions( 230 recursive=True, 231 user_requested=False, 232 internal_convert_user_code=True, 233 optional_features=None) 234 235 236class ProgramContext(object): 237 """ProgramContext keeps track of converting function hierarchies. 238 239 Attributes: 240 options: ConversionOptions 241 autograph_module: Deprecated. Do not use. 242 """ 243 244 def __init__(self, options, autograph_module=None): 245 self.options = options 246 self.autograph_module = autograph_module 247 248 249class Base(transformer.Base): 250 """All converters should inherit from this class. 251 252 Attributes: 253 ctx: EntityContext 254 """ 255 256 def __init__(self, ctx): 257 super(Base, self).__init__(ctx) 258 259 self._used = False 260 self._ast_depth = 0 261 262 def get_definition_directive(self, node, directive, arg, default): 263 """Returns the unique directive argument for a symbol. 264 265 See lang/directives.py for details on directives. 266 267 Example: 268 # Given a directive in the code: 269 ag.foo_directive(bar, baz=1) 270 271 # One can write for an AST node Name(id='bar'): 272 get_definition_directive(node, ag.foo_directive, 'baz') 273 274 Args: 275 node: ast.AST, the node representing the symbol for which the directive 276 argument is needed. 277 directive: Callable[..., Any], the directive to search. 278 arg: str, the directive argument to return. 279 default: Any 280 281 Raises: 282 ValueError: if conflicting annotations have been found 283 """ 284 defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ()) 285 if not defs: 286 return default 287 288 arg_values_found = [] 289 for def_ in defs: 290 if (directive in def_.directives and arg in def_.directives[directive]): 291 arg_values_found.append(def_.directives[directive][arg]) 292 293 if not arg_values_found: 294 return default 295 296 if len(arg_values_found) == 1: 297 return arg_values_found[0] 298 299 # If multiple annotations reach the symbol, they must all match. If they do, 300 # return any of them. 301 first_value = arg_values_found[0] 302 for other_value in arg_values_found[1:]: 303 if not ast_util.matches(first_value, other_value): 304 qn = anno.getanno(node, anno.Basic.QN) 305 raise ValueError( 306 '%s has ambiguous annotations for %s(%s): %s, %s' % 307 (qn, directive.__name__, arg, parser.unparse(other_value).strip(), 308 parser.unparse(first_value).strip())) 309 return first_value 310 311 def visit(self, node): 312 if not self._ast_depth: 313 if self._used: 314 raise ValueError('converter objects cannot be reused') 315 self._used = True 316 317 self._ast_depth += 1 318 try: 319 return super(Base, self).visit(node) 320 finally: 321 self._ast_depth -= 1 322