1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Support for wrapping converted functions bodies with auxiliary logic.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from tensorflow.python.autograph.core import ag_ctx 22from tensorflow.python.autograph.core import converter 23from tensorflow.python.autograph.operators import variables 24from tensorflow.python.framework import auto_control_deps 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.util import nest 28 29 30# TODO(mdan): Move this into operators - it represents a function definition. 31 32 33class FunctionScope(object): 34 """Context manager that wraps the body of a converted function. 35 36 This context manager handles various operations related to the scope of a 37 function: 38 * optional TF name scopes - these name scopes match the name of the 39 function, for easy visualization in tensorBoard; 40 * optional automatic control dependencies - this adds the same mechanism 41 for control dependencies that is used by `@tf.function`; it can be 42 optionally enabled when using `tf.autograph.to_graph`; 43 * tracking of autograph conversion state (whether it's enabled by the user, 44 conversion options; 45 """ 46 47 def __init__(self, function_name, scope_name, options): 48 self.name = scope_name 49 self.options = options 50 51 if options.user_requested: 52 self.autograph_ctx = ag_ctx.ControlStatusCtx(ag_ctx.Status.ENABLED, 53 options) 54 self.callopts = options.call_options() 55 56 use_name_scope = options.uses(converter.Feature.NAME_SCOPES) 57 self.use_name_scope = use_name_scope 58 if use_name_scope: 59 self.name_scope = ops.name_scope(self._sanitize(function_name)) 60 61 use_auto_deps = self.options.uses(converter.Feature.AUTO_CONTROL_DEPS) 62 self.use_auto_deps = use_auto_deps 63 if use_auto_deps: 64 self.autodeps_scope = auto_control_deps.AutomaticControlDependencies() 65 self._return_value_marked = False 66 67 def _sanitize(self, name): 68 """See https://www.tensorflow.org/api_docs/python/tf/Graph#name_scope.""" 69 # TensorFlow doesn't like leading underscores at the top level. 70 if name and name.startswith('_'): 71 name = 'fn' + name 72 return name 73 74 def __enter__(self): 75 if self.options.user_requested: 76 self.autograph_ctx.__enter__() 77 if self.use_name_scope: 78 self.name_scope.__enter__() 79 if self.use_auto_deps: 80 self.autodeps_scope.__enter__() 81 return self 82 83 def __exit__(self, exc_type, exc_val, exc_tb): 84 if self.options.user_requested: 85 self.autograph_ctx.__exit__(exc_type, exc_val, exc_tb) 86 if self.use_name_scope: 87 self.name_scope.__exit__(exc_type, exc_val, exc_tb) 88 if self.use_auto_deps: 89 self.autodeps_scope.__exit__(exc_type, exc_val, exc_tb) 90 91 def ret(self, value, did_return): 92 """Marks a value as returned from the function guarded by the scope.""" 93 del did_return 94 95 if isinstance(value, variables.UndefinedReturnValue): 96 return None 97 98 if self.use_auto_deps: 99 self._return_value_marked = True 100 if value is None: 101 # We don't create dummy returns, to preserve Python semantics. The user 102 # is responsible for adding a return value to the top-level function. 103 return None 104 105 def _mark_return_if_tensor(t): 106 if tensor_util.is_tf_type(t): 107 return self.autodeps_scope.mark_as_return(t) 108 return t 109 110 value = nest.map_structure(_mark_return_if_tensor, value) 111 return value 112 113 114def with_function_scope(thunk, scope_name, options): 115 """Inline version of the FunctionScope context manager.""" 116 with FunctionScope('lambda_', scope_name, options) as scope: 117 return thunk(scope) 118