• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from google.protobuf import message
24from google.protobuf import text_format
25from tensorflow.core.protobuf import saved_model_pb2
26from tensorflow.python.lib.io import file_io
27from tensorflow.python.saved_model import constants
28from tensorflow.python.util import compat
29
30
def read_saved_model(saved_model_dir):
  """Reads the `SavedModel` protocol buffer stored in `saved_model_dir`.

  The binary file (`constants.SAVED_MODEL_FILENAME_PB`) is preferred; the text
  file (`constants.SAVED_MODEL_FILENAME_PBTXT`) is only read when the binary
  one does not exist.

  Args:
    saved_model_dir: Directory containing the SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If neither file exists, or the file found cannot be successfully
      parsed.
  """
  # Build the paths to the SavedModel in pb and pbtxt formats.
  path_to_pb = os.path.join(
      compat.as_bytes(saved_model_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
  path_to_pbtxt = os.path.join(
      compat.as_bytes(saved_model_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))

  saved_model = saved_model_pb2.SavedModel()

  # Binary format takes precedence when both files are present.
  # Each existence check may be a remote call (e.g. on a cloud filesystem), so
  # every path is checked at most once.
  if file_io.file_exists(path_to_pb):
    # Close the handle deterministically instead of leaving it to GC.
    with file_io.FileIO(path_to_pb, "rb") as f:
      file_content = f.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
    return saved_model

  if file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as f:
      file_content = f.read()
    try:
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
    return saved_model

  # Keeps the original "SavedModel file does not exist at: <dir>" prefix and
  # adds the file names that were looked for.
  raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                (saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT,
                 constants.SAVED_MODEL_FILENAME_PB))
77
78
def get_saved_model_tag_sets(saved_model_dir):
  """Lists the tag-set of every MetaGraphDef stored in a SavedModel.

  Args:
    saved_model_dir: Directory containing the SavedModel.

  Returns:
    A list with one entry per MetaGraphDef, in the order they appear in the
    SavedModel; each entry is a list of that MetaGraphDef's tags.

  Raises:
    IOError: Propagated from `read_saved_model` when the SavedModel file is
      missing or cannot be parsed.
  """
  model = read_saved_model(saved_model_dir)
  return [list(graph.meta_info_def.tags) for graph in model.meta_graphs]
93
94
def get_meta_graph_def(saved_model_dir, tag_set):
  """Gets MetaGraphDef from SavedModel.

  Returns the MetaGraphDef for the given tag-set and SavedModel directory.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
        separated by ','. For a tag-set containing multiple tags, all tags
        must be passed in (order does not matter). Empty entries are ignored,
        so "" selects a MetaGraphDef that was saved with no tags.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
        SavedModel. The message lists the tag-sets that are available.

  Returns:
    A MetaGraphDef corresponding to the tag-set.
  """
  saved_model = read_saved_model(saved_model_dir)
  # Discard empty entries: "".split(',') yields [''], which could never match
  # a MetaGraphDef saved without tags; likewise for "serve," or "a,,b".
  set_of_tags = set(tag for tag in tag_set.split(',') if tag)

  available_tag_sets = []
  for meta_graph_def in saved_model.meta_graphs:
    meta_graph_tags = set(meta_graph_def.meta_info_def.tags)
    if meta_graph_tags == set_of_tags:
      return meta_graph_def
    # Remember what exists so the error below is actionable.
    available_tag_sets.append(','.join(sorted(meta_graph_tags)))

  raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                     ' could not be found in SavedModel. '
                     'Available tag-sets: %s' % available_tag_sets)
121