# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Thread-local context managers for AutoGraph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import threading

import enum


# Per-thread storage; each thread gets its own `control_status` stack.
stacks = threading.local()


def _control_ctx():
  """Returns the calling thread's stack of `ControlStatusCtx` objects.

  The stack is created lazily on first access in each thread and is seeded
  with a default context (status `Status.UNSPECIFIED`), so it is never empty
  unless that default entry is explicitly exited.
  """
  if not hasattr(stacks, 'control_status'):
    stacks.control_status = [_default_control_status_ctx()]
  return stacks.control_status


def control_status_ctx():
  """Returns the innermost `ControlStatusCtx` active on the calling thread."""
  return _control_ctx()[-1]


class Status(enum.Enum):
  """Tri-state flag recorded by `ControlStatusCtx`.

  UNSPECIFIED is the status of the default context each thread starts with;
  ENABLED / DISABLED presumably record an explicit user choice to turn
  autograph on or off (confirm against callers).
  """
  UNSPECIFIED = 0
  ENABLED = 1
  DISABLED = 2


class ControlStatusCtx(object):
  """A context that tracks whether autograph is enabled by the user.

  Entering pushes this object onto the calling thread's context stack, and
  exiting pops it, so `control_status_ctx()` always reports the innermost
  active context.

  Args:
    status: A `Status` value.
    options: Stored as-is; this module only keeps it and shows it in `repr`.
  """

  def __init__(self, status, options=None):
    self.status = status
    self.options = options

  def __enter__(self):
    _control_ctx().append(self)
    return self

  def __repr__(self):
    return '{}[status={}, options={}]'.format(
        self.__class__.__name__, self.status, self.options)

  def __exit__(self, unused_type, unused_value, unused_traceback):
    stack = _control_ctx()
    # Contexts must be exited in LIFO order. Raise explicitly instead of using
    # `assert` so the check survives `python -O`; otherwise a mismatched exit
    # would silently pop some other context and corrupt this thread's stack.
    if stack[-1] is not self:
      raise AssertionError(
          'Mismatched {} exit: expected {!r} on top of the stack, '
          'found {!r}'.format(self.__class__.__name__, self, stack[-1]))
    stack.pop()


class NullCtx(object):
  """Helper substitute for contextlib.nullcontext.

  Args:
    enter_result: Value returned from `__enter__`. Defaults to None, which
      matches both the previous behavior of this class and the default of
      `contextlib.nullcontext`.
  """

  def __init__(self, enter_result=None):
    self.enter_result = enter_result

  def __enter__(self):
    return self.enter_result

  def __exit__(self, unused_type, unused_value, unused_traceback):
    # Returns None (falsy), so exceptions raised in the body propagate.
    pass


def _default_control_status_ctx():
  """Builds the bottom-of-stack context each thread starts with."""
  return ControlStatusCtx(status=Status.UNSPECIFIED)