1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converter construction support. 16 17This module contains a base class for all converters, as well as supporting 18structures. These structures are referred to as contexts. 19 20The class hierarchy is as follows: 21 22 <your converter> 23 [extends] converter.Base 24 [extends] transformer.Base 25 [extends] gast.nodeTransformer 26 [uses] transfomer.SourceInfo 27 [uses] converter.EntityContext 28 [uses] converter.ProgramContext 29 [uses] transfomer.SourceInfo 30 31converter.Base is a specialization of transformer.Base for AutoGraph. It's a 32very lightweight subclass that adds a `ctx` attribute holding the corresponding 33EntityContext object (see below). Note that converters are not reusable, and 34`visit` will raise an error if called more than once. 35 36converter.EntityContext contains mutable state associated with an entity that 37the converter processes. 38 39converter.ProgramContext contains mutable state across related entities. For 40example, when converting several functions that call one another, the 41ProgramContext should be shared across these entities. 42 43Below is the overall flow at conversion: 44 45 program_ctx = ProgramContext(<entities to convert>, <global settings>, ...) 46 while <program_ctx has more entities to convert>: 47 entity, source_info = <get next entity from program_ctx> 48 entity_ctx = EntityContext(program_ctx, source_info) 49 for <each ConverterClass>: 50 converter = ConverterClass(entity_ctx) 51 52 # May update entity_ctx and program_ctx 53 entity = converter.visit(entity) 54 55 <add entity's dependencies to program_ctx> 56 57Note that pyct contains a small number of transformers used for static analysis. 58These implement transformer.Base, rather than converter.Base, to avoid a 59dependency on AutoGraph. 60""" 61 62from __future__ import absolute_import 63from __future__ import division 64from __future__ import print_function 65 66import enum 67 68from tensorflow.python.autograph.core import config 69from tensorflow.python.autograph.pyct import anno 70from tensorflow.python.autograph.pyct import ast_util 71from tensorflow.python.autograph.pyct import cfg 72from tensorflow.python.autograph.pyct import compiler 73from tensorflow.python.autograph.pyct import parser 74from tensorflow.python.autograph.pyct import qual_names 75from tensorflow.python.autograph.pyct import templates 76from tensorflow.python.autograph.pyct import transformer 77from tensorflow.python.autograph.pyct.static_analysis import activity 78from tensorflow.python.autograph.pyct.static_analysis import live_values 79from tensorflow.python.autograph.pyct.static_analysis import liveness 80from tensorflow.python.autograph.pyct.static_analysis import reaching_definitions 81from tensorflow.python.autograph.pyct.static_analysis import type_info 82from tensorflow.python.util.tf_export import tf_export 83 84# TODO(mdan): These contexts can be refactored into first class objects. 85# For example, we could define Program and Entity abstractions that hold on 86# to the actual entity and have conversion methods. 87 88# TODO(mdan): Add a test specific to this converter. 89 90 91@tf_export('autograph.experimental.Feature') 92class Feature(enum.Enum): 93 """Represents conversion options that can be toggled on or off. 94 95 Attributes: 96 ALL: Enable all features. 97 AUTO_CONTROL_DEPS: Insert of control dependencies in the generated code. 98 ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert. 99 BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to 100 their TF counterparts. 101 ERROR_REWRITING: Rewrite errors that occur in the generated code to 102 indicate the source code to which the failing code corresponds. 103 LISTS: Convert list idioms, like initializers, slices, append, etc. 104 LOGICAL_EXPRESSIONS: Convert data-dependent logical expressions applied to 105 Tensors to their TF counterparts. 106 NAME_SCOPES: Insert name scopes that name ops according to context, like the 107 function they were defined in. 108 """ 109 110 ALL = 'ALL' 111 112 AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS' 113 ASSERT_STATEMENTS = 'ASSERT_STATEMENTS' 114 BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS' 115 ERROR_REWRITING = 'ERROR_REWRITING' 116 LISTS = 'LISTS' 117 LOGICAL_EXPRESSIONS = 'LOGICAL_EXPRESSIONS' 118 NAME_SCOPES = 'NAME_SCOPES' 119 120 @classmethod 121 def all(cls): 122 """Returns a tuple that enables all options.""" 123 return tuple(cls.__members__.values()) 124 125 @classmethod 126 def all_but(cls, exclude): 127 """Returns a tuple that enables all but the excluded options.""" 128 if not isinstance(exclude, (list, tuple, set)): 129 exclude = (exclude,) 130 return tuple(set(cls.all()) - set(exclude) - {cls.ALL}) 131 132 133class ConversionOptions(object): 134 """Immutable container for global conversion flags. 135 136 Attributes: 137 recursive: bool, whether to recursively convert any user functions or 138 classes that the converted function may use. 139 force_conversion: bool, whether to force convertinng the target entity. When 140 force_conversion is turned off, the converter may decide to return the 141 function as-is. 142 optional_features: Union[Feature, Set[Feature]], controls the use of 143 optional features in the conversion process. See Feature for available 144 options. 145 """ 146 147 def __init__(self, 148 recursive=False, 149 force_conversion=False, 150 internal_convert_user_code=True, 151 optional_features=Feature.ALL): 152 self.recursive = recursive 153 self.force_conversion = force_conversion 154 # TODO(mdan): Rename to conversion_recursion_depth? 155 self.internal_convert_user_code = internal_convert_user_code 156 157 if optional_features is None: 158 optional_features = () 159 elif isinstance(optional_features, Feature): 160 optional_features = (optional_features,) 161 optional_features = frozenset(optional_features) 162 self.optional_features = optional_features 163 164 def uses(self, feature): 165 return (Feature.ALL in self.optional_features or 166 feature in self.optional_features) 167 168 def to_ast(self, internal_convert_user_code=None): 169 """Returns a representation of this object as an AST node. 170 171 The AST node encodes a constructor that would create an object with the 172 same contents. 173 174 Args: 175 internal_convert_user_code: Optional[bool], allows ovrriding the 176 corresponding value. 177 178 Returns: 179 ast.Node 180 """ 181 template = """ 182 ag__.ConversionOptions( 183 recursive=recursive_val, 184 force_conversion=force_conversion_val, 185 optional_features=optional_features_val, 186 internal_convert_user_code=internal_convert_user_code_val) 187 """ 188 189 def list_of_features(values): 190 return parser.parse_expression('({})'.format(', '.join( 191 'ag__.{}'.format(str(v)) for v in values))) 192 193 if internal_convert_user_code is None: 194 internal_convert_user_code = self.internal_convert_user_code 195 196 expr_ast = templates.replace( 197 template, 198 recursive_val=parser.parse_expression(str(self.recursive)), 199 force_conversion_val=parser.parse_expression( 200 str(self.force_conversion)), 201 internal_convert_user_code_val=parser.parse_expression( 202 str(internal_convert_user_code)), 203 optional_features_val=list_of_features(self.optional_features)) 204 return expr_ast[0].value 205 206 207class ProgramContext(object): 208 """ProgramContext keeps track of converting function hierarchies. 209 210 This object is mutable, and is updated during conversion. Not thread safe. 211 212 Attributes: 213 options: ConversionOptions 214 autograph_module: Module, a reference to the autograph module. This needs to 215 be specified by the caller to avoid circular dependencies. 216 required_imports: str, containing an import statement on each line. These 217 are all the imports necessary for the compiled code to run, in addition to 218 the closures of each entity, which are attached dynamically. 219 """ 220 221 def __init__( 222 self, 223 options, 224 autograph_module, 225 ): 226 self.options = options 227 self.autograph_module = autograph_module 228 229 @property 230 def required_imports(self): 231 """Returns a block containing all imports required by the converted code.""" 232 # TODO(mdan): Check that these don't clobber one another. 233 return '\n'.join(config.COMPILED_IMPORT_STATEMENTS) 234 235 236class EntityContext(transformer.Context): 237 """Tracks the conversion of a single entity. 238 239 This object is mutable, and is updated during conversion. Not thread safe. 240 241 Attributes: 242 namer: Namer 243 info: transformer.EntityInfo 244 program: ProgramContext 245 """ 246 247 def __init__(self, namer, entity_info, program_ctx): 248 super(EntityContext, self).__init__(entity_info) 249 self.namer = namer 250 self.program = program_ctx 251 252 253class Base(transformer.Base): 254 """All converters should inherit from this class. 255 256 Attributes: 257 ctx: EntityContext 258 """ 259 260 def __init__(self, ctx): 261 super(Base, self).__init__(ctx) 262 263 self._used = False 264 self._ast_depth = 0 265 266 def get_definition_directive(self, node, directive, arg, default): 267 """Returns the unique directive argument for a symbol. 268 269 See lang/directives.py for details on directives. 270 271 Example: 272 # Given a directive in the code: 273 ag.foo_directive(bar, baz=1) 274 275 # One can write for an AST node Name(id='bar'): 276 get_definition_directive(node, ag.foo_directive, 'baz') 277 278 Args: 279 node: ast.AST, the node representing the symbol for which the directive 280 argument is needed. 281 directive: Callable[..., Any], the directive to search. 282 arg: str, the directive argument to return. 283 default: Any 284 285 Raises: 286 ValueError: if conflicting annotations have been found 287 """ 288 defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ()) 289 if not defs: 290 return default 291 292 arg_values_found = [] 293 for def_ in defs: 294 if (directive in def_.directives and arg in def_.directives[directive]): 295 arg_values_found.append(def_.directives[directive][arg]) 296 297 if not arg_values_found: 298 return default 299 300 if len(arg_values_found) == 1: 301 return arg_values_found[0] 302 303 # If multiple annotations reach the symbol, they must all match. If they do, 304 # return any of them. 305 first_value = arg_values_found[0] 306 for other_value in arg_values_found[1:]: 307 if not ast_util.matches(first_value, other_value): 308 qn = anno.getanno(node, anno.Basic.QN) 309 raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' % 310 (qn, directive.__name__, arg, 311 compiler.ast_to_source(other_value).strip(), 312 compiler.ast_to_source(first_value).strip())) 313 return first_value 314 315 def visit(self, node): 316 if not self._ast_depth: 317 if self._used: 318 raise ValueError('converter objects cannot be reused') 319 self._used = True 320 321 self._ast_depth += 1 322 try: 323 return super(Base, self).visit(node) 324 finally: 325 self._ast_depth -= 1 326 327 328class AnnotatedDef(reaching_definitions.Definition): 329 330 def __init__(self): 331 super(AnnotatedDef, self).__init__() 332 self.directives = {} 333 334 335class AgAnno(enum.Enum): 336 """Annotation labels specific to AutoGraph. See anno.py.""" 337 338 DIRECTIVES = 'User directives associated with the annotated statement.' 339 340 def __repr__(self): 341 return self.name 342 343 344def standard_analysis(node, context, is_initial=False): 345 """Performs a complete static analysis of the given code. 346 347 Args: 348 node: ast.AST 349 context: converter.EntityContext 350 is_initial: bool, whether this is the initial analysis done on the input 351 source code 352 353 Returns: 354 ast.AST, same as node, with the static analysis annotations added 355 """ 356 # TODO(mdan): Clear static analysis here. 357 # TODO(mdan): Consider not running all analyses every time. 358 # TODO(mdan): Don't return a node because it's modified by reference. 359 graphs = cfg.build(node) 360 node = qual_names.resolve(node) 361 node = activity.resolve(node, context, None) 362 node = reaching_definitions.resolve(node, context, graphs, AnnotatedDef) 363 node = liveness.resolve(node, context, graphs) 364 node = live_values.resolve(node, context, config.PYTHON_LITERALS) 365 node = type_info.resolve(node, context) 366 # This second call allows resolving first-order class attributes. 367 node = live_values.resolve(node, context, config.PYTHON_LITERALS) 368 if is_initial: 369 anno.dup( 370 node, 371 { 372 anno.Static.DEFINITIONS: anno.Static.ORIG_DEFINITIONS, 373 }, 374 ) 375 return node 376 377 378def apply_(node, context, converter_module): 379 """Applies a converter to an AST. 380 381 Args: 382 node: ast.AST 383 context: converter.EntityContext 384 converter_module: converter.Base 385 386 Returns: 387 ast.AST, the result of applying converter to node 388 """ 389 node = standard_analysis(node, context) 390 node = converter_module.transform(node, context) 391 return node 392