1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Thread-local context managers for AutoGraph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import enum 22import threading 23 24from tensorflow.python.util.tf_export import tf_export 25 26 27stacks = threading.local() 28 29 30def _control_ctx(): 31 if not hasattr(stacks, 'control_status'): 32 stacks.control_status = [_default_control_status_ctx()] 33 return stacks.control_status 34 35 36@tf_export('__internal__.autograph.control_status_ctx', v1=[]) 37def control_status_ctx(): 38 """Returns the current control context for autograph. 39 40 This method is useful when calling `tf.__internal__.autograph.tf_convert`, 41 The context will be used by tf_convert to determine whether it should convert 42 the input function. See the sample usage like below: 43 44 ``` 45 def foo(func): 46 return tf.__internal__.autograph.tf_convert( 47 input_fn, ctx=tf.__internal__.autograph.control_status_ctx())() 48 ``` 49 50 Returns: 51 The current control context of autograph. 52 """ 53 ret = _control_ctx()[-1] 54 return ret 55 56 57class Status(enum.Enum): 58 UNSPECIFIED = 0 59 ENABLED = 1 60 DISABLED = 2 61 62 63class ControlStatusCtx(object): 64 """A context that tracks whether autograph is enabled by the user.""" 65 66 def __init__(self, status, options=None): 67 self.status = status 68 self.options = options 69 70 def __enter__(self): 71 _control_ctx().append(self) 72 return self 73 74 def __repr__(self): 75 return '{}[status={}, options={}]'.format( 76 self.__class__.__name__, self.status, self.options) 77 78 def __exit__(self, unused_type, unused_value, unused_traceback): 79 assert _control_ctx()[-1] is self 80 _control_ctx().pop() 81 82 83class NullCtx(object): 84 """Helper substitute for contextlib.nullcontext.""" 85 86 def __enter__(self): 87 pass 88 89 def __exit__(self, unused_type, unused_value, unused_traceback): 90 pass 91 92 93def _default_control_status_ctx(): 94 return ControlStatusCtx(status=Status.UNSPECIFIED) 95