1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Core conversion logic, serves as main point of access.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import functools 22import imp 23import unittest 24 25import gast 26 27from tensorflow.python.autograph import operators 28from tensorflow.python.autograph import utils 29from tensorflow.python.autograph.converters import arg_defaults 30from tensorflow.python.autograph.converters import asserts 31from tensorflow.python.autograph.converters import break_statements 32from tensorflow.python.autograph.converters import builtin_functions 33from tensorflow.python.autograph.converters import call_trees 34from tensorflow.python.autograph.converters import conditional_expressions 35from tensorflow.python.autograph.converters import continue_statements 36from tensorflow.python.autograph.converters import control_flow 37from tensorflow.python.autograph.converters import directives 38from tensorflow.python.autograph.converters import error_handlers 39from tensorflow.python.autograph.converters import function_scopes 40from tensorflow.python.autograph.converters import lists 41from tensorflow.python.autograph.converters import logical_expressions 42from tensorflow.python.autograph.converters import return_statements 43from tensorflow.python.autograph.converters import side_effect_guards 44from tensorflow.python.autograph.converters import slices 45from tensorflow.python.autograph.core import config 46from tensorflow.python.autograph.core import converter 47from tensorflow.python.autograph.core import errors as ag_errors 48from tensorflow.python.autograph.core import function_wrapping 49from tensorflow.python.autograph.core import naming 50from tensorflow.python.autograph.core import unsupported_features_checker 51from tensorflow.python.autograph.lang import special_functions 52from tensorflow.python.autograph.pyct import ast_util 53from tensorflow.python.autograph.pyct import compiler 54from tensorflow.python.autograph.pyct import errors 55from tensorflow.python.autograph.pyct import inspect_utils 56from tensorflow.python.autograph.pyct import origin_info 57from tensorflow.python.autograph.pyct import parser 58from tensorflow.python.autograph.pyct import pretty_printer 59from tensorflow.python.autograph.pyct import qual_names 60from tensorflow.python.autograph.pyct import templates 61from tensorflow.python.autograph.pyct import transformer 62from tensorflow.python.autograph.utils import ag_logging as logging 63from tensorflow.python.util import tf_inspect 64 65 66# TODO(mdan): Might we not need any renaming at all? 67 68 69def is_whitelisted_for_graph(o): 70 """Check whether an entity is whitelisted for use in graph mode. 71 72 Examples of whitelisted entities include all members of the tensorflow 73 package. 74 75 Args: 76 o: A Python entity. 77 Returns: 78 Boolean 79 """ 80 # TODO(b/120224672): Fix this. 81 if isinstance(o, functools.partial): 82 # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since 83 # functools.partial objects do not have a __module__ attribute. 84 m = functools 85 else: 86 m = tf_inspect.getmodule(o) 87 88 if hasattr(m, '__name__'): 89 # Builtins typically have unnamed modules. 90 for prefix, in config.DEFAULT_UNCOMPILED_MODULES: 91 if m.__name__.startswith(prefix): 92 logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix) 93 return True 94 95 # Temporary -- whitelist tensorboard modules. 96 # TODO(b/122731813): Remove. 97 if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__: 98 logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o) 99 return True 100 101 if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'): 102 logging.log(2, 'Whitelisted: %s: already converted', o) 103 return True 104 105 if hasattr(o, '__call__'): 106 # Callable objects: whitelisted if their __call__ method is. 107 # The type check avoids infinite recursion around the __call__ method 108 # of function objects. 109 if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck 110 logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o) 111 return True 112 113 owner_class = None 114 if tf_inspect.ismethod(o): 115 # Methods of whitelisted classes are also whitelisted, even if they are 116 # bound via user subclasses. 117 # 118 # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is 119 # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also 120 # whitelisted. 121 # 122 # class Custom(tf.Foo): 123 # pass 124 # 125 # baz = Custom() 126 # 127 # For the example above, if `Custom` did overload `bar`, then it would no 128 # longer be whitelisted. 129 130 owner_class = inspect_utils.getmethodclass(o) 131 if owner_class is not None: 132 if issubclass(owner_class, unittest.TestCase): 133 logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o) 134 return True 135 136 owner_class = inspect_utils.getdefiningclass(o, owner_class) 137 if is_whitelisted_for_graph(owner_class): 138 logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o, 139 owner_class) 140 return True 141 142 if inspect_utils.isnamedtuple(o): 143 # Due to the way they're constructed, namedtuple types cannot be converted 144 # because they don't expose source code. But we assume they are safe for 145 # graph mode since they are just containers. 146 if tf_inspect.isclass(o) and len(o.__bases__) > 1: 147 logging.warn( 148 'Entity {} looks like a namedtuple subclass. Its constructor will' 149 ' not be converted by AutoGraph, but if it has any custom methods,' 150 ' those will be.'.format(o), 1) 151 logging.log(2, 'Whitelisted: %s: named tuple', o) 152 return True 153 154 logging.log(2, 'Not whitelisted: %s: default rule', o) 155 return False 156 157 158def entity_to_graph(o, program_ctx, arg_values, arg_types): 159 """Compile a Python entity into equivalent TensorFlow. 160 161 The function will also recursively compile all the entities that `o` 162 references, updating `dependency_cache`. 163 164 This function is reentrant, and relies on dependency_cache to avoid 165 generating duplicate code. 166 167 Args: 168 o: A Python entity. 169 program_ctx: A ProgramContext object. 170 arg_values: A dict containing value hints for symbols like function 171 parameters. 172 arg_types: A dict containing type hints for symbols like function 173 parameters. 174 175 Returns: 176 A tuple (ast, new_name, namespace): 177 * ast: An AST representing an entity with interface equivalent to `o`, 178 but which when executed it creates TF a graph. 179 * new_name: The symbol name under which the new entity can be found. 180 * namespace: A dict mapping all symbols visible to the converted entity, 181 keyed by their symbol name. 182 183 Raises: 184 ValueError: if the entity type is not supported. 185 """ 186 logging.log(1, 'Converting %s', o) 187 188 if tf_inspect.isclass(o): 189 nodes, name, ns = class_to_graph(o, program_ctx) 190 elif tf_inspect.isfunction(o): 191 nodes, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) 192 elif tf_inspect.ismethod(o): 193 nodes, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types) 194 # TODO(mdan,yashkatariya): Remove when object conversion is implemented. 195 elif hasattr(o, '__class__'): 196 raise NotImplementedError( 197 'Object conversion is not yet supported. If you are ' 198 'trying to convert code that uses an existing object, ' 199 'try including the creation of that object in the ' 200 'conversion. For example, instead of converting the method ' 201 'of a class, try converting the entire class instead. ' 202 'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/' 203 'python/autograph/README.md#using-the-functional-api ' 204 'for more information.') 205 else: 206 raise ValueError( 207 'Entity "%s" has unsupported type "%s". Only functions and classes are ' 208 'supported for now.' % (o, type(o))) 209 210 # TODO(mdan): This is temporary. it should be created using a converter. 211 # TODO(mdan): The attribute should be added with a helper, not directly. 212 # The helper can ensure there are no collisions. 213 template = ''' 214 entity.autograph_info__ = {} 215 ''' 216 nodes.extend(templates.replace(template, entity=name)) 217 218 if logging.has_verbosity(2): 219 logging.log(2, 'Compiled output of %s:\n\n%s\n', o, 220 compiler.ast_to_source(nodes)) 221 if logging.has_verbosity(4): 222 for n in nodes: 223 logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o, 224 pretty_printer.fmt(n, color=False)) 225 226 return nodes, name, ns 227 228 229def class_to_graph(c, program_ctx): 230 """Specialization of `entity_to_graph` for classes.""" 231 # TODO(mdan): Revisit this altogether. Not sure we still need it. 232 converted_members = {} 233 method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m) 234 members = tf_inspect.getmembers(c, predicate=method_filter) 235 if not members: 236 raise ValueError('Cannot convert %s: it has no member methods.' % c) 237 238 class_namespace = {} 239 for _, m in members: 240 # Only convert the members that are directly defined by the class. 241 if inspect_utils.getdefiningclass(m, c) is not c: 242 continue 243 nodes, _, namespace = function_to_graph( 244 m, 245 program_ctx=program_ctx, 246 arg_values={}, 247 arg_types={'self': (c.__name__, c)}, 248 do_rename=False) 249 if class_namespace is None: 250 class_namespace = namespace 251 else: 252 class_namespace.update(namespace) 253 converted_members[m] = nodes[0] 254 namer = naming.Namer(class_namespace) 255 class_name = namer.class_name(c.__name__) 256 257 # Process any base classes: if the superclass if of a whitelisted type, an 258 # absolute import line is generated. 259 output_nodes = [] 260 renames = {} 261 base_names = [] 262 for base in c.__bases__: 263 if isinstance(object, base): 264 base_names.append('object') 265 continue 266 if is_whitelisted_for_graph(base): 267 alias = namer.new_symbol(base.__name__, ()) 268 output_nodes.append( 269 gast.ImportFrom( 270 module=base.__module__, 271 names=[gast.alias(name=base.__name__, asname=alias)], 272 level=0)) 273 else: 274 raise NotImplementedError( 275 'Conversion of classes that do not directly extend classes from' 276 ' whitelisted modules is temporarily suspended. If this breaks' 277 ' existing code please notify the AutoGraph team immediately.') 278 base_names.append(alias) 279 renames[qual_names.QN(base.__name__)] = qual_names.QN(alias) 280 281 # Generate the definition of the converted class. 282 bases = [gast.Name(n, gast.Load(), None) for n in base_names] 283 class_def = gast.ClassDef( 284 class_name, 285 bases=bases, 286 keywords=[], 287 body=list(converted_members.values()), 288 decorator_list=[]) 289 # Make a final pass to replace references to the class or its base classes. 290 # Most commonly, this occurs when making super().__init__() calls. 291 # TODO(mdan): Making direct references to superclass' superclass will fail. 292 class_def = qual_names.resolve(class_def) 293 renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name) 294 class_def = ast_util.rename_symbols(class_def, renames) 295 296 output_nodes.append(class_def) 297 298 return output_nodes, class_name, class_namespace 299 300 301def _add_reserved_symbol(namespace, name, entity): 302 if name not in namespace: 303 namespace[name] = entity 304 elif namespace[name] != entity: 305 raise ValueError('The name "%s" is reserved and may not be used.' % name) 306 307 308ag_internal = None 309 310 311# TODO(mdan): Move into core or replace with an actual importable module. 312def _add_self_references(namespace, autograph_module): 313 """Adds namespace references to the module that exposes the api itself.""" 314 global ag_internal 315 if ag_internal is None: 316 # Craft a module that exposes parts of the external API as well as certain 317 # internal modules. 318 ag_internal = imp.new_module('autograph') 319 ag_internal.__dict__.update(autograph_module.__dict__) 320 ag_internal.ConversionOptions = converter.ConversionOptions 321 ag_internal.Feature = converter.Feature 322 ag_internal.utils = utils 323 ag_internal.function_scope = function_wrapping.function_scope 324 ag_internal.rewrite_graph_construction_error = ( 325 ag_errors.rewrite_graph_construction_error) 326 # TODO(mdan): Add safeguards against name clashes. 327 # We don't want to create a submodule because we want the operators to be 328 # accessible as ag__.<operator> 329 ag_internal.__dict__.update(special_functions.__dict__) 330 ag_internal.__dict__.update(operators.__dict__) 331 332 _add_reserved_symbol(namespace, 'ag__', ag_internal) 333 334 335def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True): 336 """Specialization of `entity_to_graph` for callable functions.""" 337 338 node, source, _ = parser.parse_entity(f) 339 logging.log(3, 'Source code of %s:\n\n%s\n', f, source) 340 341 # In general, the output of inspect.getsource is inexact for lambdas because 342 # it uses regex matching to adjust the exact location around the line number 343 # that CPython records. Then, the entire containing line is returned, which 344 # we may have trouble disambiguating. For example: 345 # x, y = lambda: 1, lambda: 2 346 if f.__name__ == '<lambda>': 347 nodes = ast_util.find_matching_definitions(node, f) 348 if len(nodes) != 1: 349 raise ValueError( 350 'Unable to identify source code of lambda function {}. It was' 351 ' defined on this line: {}, which must contain a single lambda with' 352 ' matching signature. To avoid ambiguity, define each lambda' 353 ' in a separate expression.'.format(f, source)) 354 node, = nodes 355 356 # TODO(znado): Place inside standard_analysis. 357 origin_info.resolve(node, source, f) 358 namespace = inspect_utils.getnamespace(f) 359 _add_self_references(namespace, program_ctx.autograph_module) 360 namer = naming.Namer(namespace) 361 362 entity_info = transformer.EntityInfo( 363 source_code=source, 364 source_file='<fragment>', 365 namespace=namespace, 366 arg_values=arg_values, 367 arg_types=arg_types) 368 context = converter.EntityContext(namer, entity_info, program_ctx) 369 try: 370 node = node_to_graph(node, context) 371 except (ValueError, AttributeError, KeyError, NotImplementedError) as e: 372 logging.error(1, 'Error converting %s', f, exc_info=True) 373 raise errors.InternalError('conversion', e) 374 # TODO(mdan): Catch and rethrow syntax errors. 375 376 if isinstance(node, gast.Lambda): 377 new_name = namer.new_symbol('tf__lambda', ()) 378 node = gast.Assign( 379 targets=[gast.Name(new_name, gast.Store(), None)], value=node) 380 381 elif do_rename: 382 # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py 383 new_name = namer.function_name(f.__name__) 384 node.name = new_name 385 else: 386 new_name = f.__name__ 387 assert node.name == new_name 388 389 return [node], new_name, namespace 390 391 392def node_to_graph(node, context): 393 """Convert Python code to equivalent TF graph mode code. 394 395 Args: 396 node: AST, the code to convert. 397 context: converter.EntityContext 398 399 Returns: 400 A tuple (node, deps): 401 * node: A Python ast node, representing the converted code. 402 * deps: A set of strings, the fully qualified names of entity 403 dependencies that this node has. 404 """ 405 # TODO(mdan): Insert list_comprehensions somewhere. 406 unsupported_features_checker.verify(node) 407 408 node = converter.standard_analysis(node, context, is_initial=True) 409 # Past this point, line numbers are no longer accurate so we ignore the 410 # source. 411 # TODO(mdan): Is it feasible to reconstruct intermediate source code? 412 context.info.source_code = None 413 node = converter.apply_(node, context, arg_defaults) 414 node = converter.apply_(node, context, directives) 415 node = converter.apply_(node, context, break_statements) 416 if context.program.options.uses(converter.Feature.ASSERT_STATEMENTS): 417 node = converter.apply_(node, context, asserts) 418 # Note: sequencing continue canonicalization before for loop one avoids 419 # dealing with the extra loop increment operation that the for 420 # canonicalization creates. 421 node = converter.apply_(node, context, continue_statements) 422 node = converter.apply_(node, context, return_statements) 423 if context.program.options.uses(converter.Feature.LISTS): 424 node = converter.apply_(node, context, lists) 425 node = converter.apply_(node, context, slices) 426 if context.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS): 427 node = converter.apply_(node, context, builtin_functions) 428 node = converter.apply_(node, context, call_trees) 429 node = converter.apply_(node, context, control_flow) 430 node = converter.apply_(node, context, conditional_expressions) 431 if context.program.options.uses(converter.Feature.LOGICAL_EXPRESSIONS): 432 node = converter.apply_(node, context, logical_expressions) 433 if context.program.options.uses(converter.Feature.AUTO_CONTROL_DEPS): 434 node = converter.apply_(node, context, side_effect_guards) 435 # TODO(mdan): If function scopes ever does more, the toggle will need moving. 436 if context.program.options.uses(converter.Feature.NAME_SCOPES): 437 node = converter.apply_(node, context, function_scopes) 438 if context.program.options.uses(converter.Feature.ERROR_REWRITING): 439 node = converter.apply_(node, context, error_handlers) 440 return node 441