1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""SavedModel utils.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from google.protobuf import message 24from google.protobuf import text_format 25from tensorflow.core.protobuf import saved_model_pb2 26from tensorflow.python.lib.io import file_io 27from tensorflow.python.saved_model import constants 28from tensorflow.python.util import compat 29 30 31def read_saved_model(saved_model_dir): 32 """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`. 33 34 Args: 35 saved_model_dir: Directory containing the SavedModel file. 36 37 Returns: 38 A `SavedModel` protocol buffer. 39 40 Raises: 41 IOError: If the file does not exist, or cannot be successfully parsed. 42 """ 43 # Build the path to the SavedModel in pbtxt format. 44 path_to_pbtxt = os.path.join( 45 compat.as_bytes(saved_model_dir), 46 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)) 47 # Build the path to the SavedModel in pb format. 48 path_to_pb = os.path.join( 49 compat.as_bytes(saved_model_dir), 50 compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)) 51 52 # Ensure that the SavedModel exists at either path. 53 if not file_io.file_exists(path_to_pbtxt) and not file_io.file_exists( 54 path_to_pb): 55 raise IOError("SavedModel file does not exist at: %s" % saved_model_dir) 56 57 # Parse the SavedModel protocol buffer. 58 saved_model = saved_model_pb2.SavedModel() 59 if file_io.file_exists(path_to_pb): 60 try: 61 file_content = file_io.FileIO(path_to_pb, "rb").read() 62 saved_model.ParseFromString(file_content) 63 return saved_model 64 except message.DecodeError as e: 65 raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e))) 66 elif file_io.file_exists(path_to_pbtxt): 67 try: 68 file_content = file_io.FileIO(path_to_pbtxt, "rb").read() 69 text_format.Merge(file_content.decode("utf-8"), saved_model) 70 return saved_model 71 except text_format.ParseError as e: 72 raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e))) 73 else: 74 raise IOError("SavedModel file does not exist at: %s/{%s|%s}" % 75 (saved_model_dir, constants.SAVED_MODEL_FILENAME_PBTXT, 76 constants.SAVED_MODEL_FILENAME_PB)) 77 78 79def get_saved_model_tag_sets(saved_model_dir): 80 """Retrieves all the tag-sets available in the SavedModel. 81 82 Args: 83 saved_model_dir: Directory containing the SavedModel. 84 85 Returns: 86 String representation of all tag-sets in the SavedModel. 87 """ 88 saved_model = read_saved_model(saved_model_dir) 89 all_tags = [] 90 for meta_graph_def in saved_model.meta_graphs: 91 all_tags.append(list(meta_graph_def.meta_info_def.tags)) 92 return all_tags 93 94 95def get_meta_graph_def(saved_model_dir, tag_set): 96 """Gets MetaGraphDef from SavedModel. 97 98 Returns the MetaGraphDef for the given tag-set and SavedModel directory. 99 100 Args: 101 saved_model_dir: Directory containing the SavedModel to inspect or execute. 102 tag_set: Group of tag(s) of the MetaGraphDef to load, in string format, 103 separated by ','. For tag-set contains multiple tags, all tags must be 104 passed in. 105 106 Raises: 107 RuntimeError: An error when the given tag-set does not exist in the 108 SavedModel. 109 110 Returns: 111 A MetaGraphDef corresponding to the tag-set. 112 """ 113 saved_model = read_saved_model(saved_model_dir) 114 set_of_tags = set(tag_set.split(',')) 115 for meta_graph_def in saved_model.meta_graphs: 116 if set(meta_graph_def.meta_info_def.tags) == set_of_tags: 117 return meta_graph_def 118 119 raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set + 120 ' could not be found in SavedModel') 121