• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Context for building SavedModel."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import threading
23
24
class SaveContext(threading.local):
  """Thread-local record of whether a SavedModel is currently being built.

  Because this subclasses `threading.local`, the active flag and the stored
  options are kept separately for every thread that uses the instance.
  """

  def __init__(self):
    super(SaveContext, self).__init__()
    # Start outside of any save context, with no options recorded.
    self._options = None
    self._in_save_context = False

  def in_save_context(self):
    """Returns True between `enter_save_context` and `exit_save_context`."""
    return self._in_save_context

  def enter_save_context(self, options):
    """Marks the context as active and records `options`.

    Args:
      options: The object later returned by `options()`; stored as given.
    """
    self._options = options
    self._in_save_context = True

  def exit_save_context(self):
    """Marks the context as inactive and drops the stored options."""
    self._options = None
    self._in_save_context = False

  def options(self):
    """Returns the options recorded by `enter_save_context`.

    Raises:
      ValueError: If called while no save context is active.
    """
    if self.in_save_context():
      return self._options
    raise ValueError("not in a SaveContext")


# Module-wide instance that the module-level helper functions operate on.
_save_context = SaveContext()
50
51
@contextlib.contextmanager
def save_context(options):
  """Context manager that marks the current thread as being in a save context.

  Args:
    options: Value recorded for the duration of the `with` body; it can be
      read back through `get_save_options()`.

  Yields:
    Nothing.

  Raises:
    ValueError: If a save context is already active.
  """
  if _save_context.in_save_context():
    raise ValueError("already in a SaveContext")
  _save_context.enter_save_context(options)
  try:
    yield
  finally:
    # Leave the context even when the `with` body raises.
    _save_context.exit_save_context()
61
62
def in_save_context():
  """Reports whether a save context is currently active.

  Returns:
    The flag held by the module-level `SaveContext` instance.
  """
  is_active = _save_context.in_save_context()
  return is_active
66
67
def get_save_options():
  """Fetches the options of the active save context.

  Returns:
    The options recorded when the save context was entered.

  Raises:
    ValueError: If no save context is active.
  """
  current_options = _save_context.options()
  return current_options
71