1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""SavedModel utility functions implementation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from tensorflow.core.framework import types_pb2 24from tensorflow.core.protobuf import meta_graph_pb2 25from tensorflow.core.protobuf import struct_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.saved_model import constants 34from tensorflow.python.saved_model import nested_structure_coder 35from tensorflow.python.util import compat 36from tensorflow.python.util import deprecation 37from tensorflow.python.util import nest 38from tensorflow.python.util.tf_export import tf_export 39 40 41# TensorInfo helpers. 
@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Args:
    tensor: Tensor, SparseTensor or other CompositeTensor whose name, dtype and
      shape are used to build the TensorInfo. For SparseTensors, the names of
      the three constituent Tensors are used. For other CompositeTensors, the
      TypeSpec plus one TensorInfo per flattened component is recorded.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  # Only this public entry point rejects eager mode;
  # build_tensor_info_internal performs no such check.
  if context.executing_eagerly():
    raise RuntimeError("build_tensor_info is not supported in Eager mode.")
  return build_tensor_info_internal(tensor)


def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Same as `build_tensor_info`, but without the eager-execution check.

  Args:
    tensor: A Tensor, SparseTensor, or other CompositeTensor.

  Returns:
    A TensorInfo proto. Plain Tensors use the `name` encoding, SparseTensors
    use the `coo_sparse` encoding, and every other CompositeTensor uses the
    `composite_tensor` encoding.
  """
  # SparseTensor is excluded here so that it keeps its dedicated coo_sparse
  # encoding instead of the generic composite_tensor one.
  if (isinstance(tensor, composite_tensor.CompositeTensor) and
      not isinstance(tensor, sparse_tensor.SparseTensor)):
    return _build_composite_tensor_info_internal(tensor)

  tensor_info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info.coo_sparse.values_tensor_name = tensor.values.name
    tensor_info.coo_sparse.indices_tensor_name = tensor.indices.name
    tensor_info.coo_sparse.dense_shape_tensor_name = tensor.dense_shape.name
  else:
    tensor_info.name = tensor.name
  return tensor_info


def _build_composite_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a CompositeTensor.

  Args:
    tensor: A CompositeTensor (other than SparseTensor).

  Returns:
    A TensorInfo proto whose `composite_tensor` field holds the encoded
    TypeSpec of `tensor` and one TensorInfo per component produced by
    `nest.flatten(tensor, expand_composites=True)`.
  """
  spec = tensor._type_spec  # pylint: disable=protected-access
  tensor_info = meta_graph_pb2.TensorInfo()
  struct_coder = nested_structure_coder.StructureCoder()
  spec_proto = struct_coder.encode_structure(spec)
  tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
  # Component order matches nest.flatten(..., expand_composites=True), which
  # get_tensor_from_tensor_info relies on when re-packing via
  # nest.pack_sequence_as.
  for component in nest.flatten(tensor, expand_composites=True):
    tensor_info.composite_tensor.components.add().CopyFrom(
        build_tensor_info_internal(component))
  return tensor_info


def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
      to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "build_tensor_info_from_op is not supported in Eager mode.")
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)


@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
      current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()

  def _get_tensor(name):
    # Apply the optional import scope before looking the tensor up by name.
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))

  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    struct_coder = nested_structure_coder.StructureCoder()
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = struct_coder.decode_proto(spec_proto)
    # NOTE: components are resolved through their `name` field only, i.e.
    # this assumes each component was stored with the plain `name` encoding.
    components = [_get_tensor(component.name) for component in
                  tensor_info.composite_tensor.components]
    return nest.pack_sequence_as(spec, components, expand_composites=True)
  else:
    raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)


def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Only `tensor_info.name` is consulted, so this works for TensorInfos built by
  `build_tensor_info_from_op` (op names) as well as plain tensor names.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  graph = graph or ops.get_default_graph()
  return graph.as_graph_element(
      ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))


# Path helpers.
def _get_or_create_dir(path):
  """Creates `path` (including missing parents) if absent; returns `path`."""
  if not file_io.file_exists(path):
    file_io.recursive_create_dir(path)
  return path


def get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  return _get_or_create_dir(get_variables_dir(export_dir))


def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel.

  Args:
    export_dir: The SavedModel directory, as `str`, `bytes` or `os.PathLike`.

  Returns:
    The variables directory path as text.
  """
  # path_to_str turns PathLike objects into str and passes str/bytes through
  # unchanged, matching the pb/pbtxt path helpers below.
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  return _get_or_create_dir(get_assets_dir(export_dir))


def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel.

  Args:
    export_dir: The SavedModel directory, as `str`, `bytes` or `os.PathLike`.

  Returns:
    The assets directory path as text.
  """
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.ASSETS_DIRECTORY))


def get_or_create_debug_dir(export_dir):
  """Returns path to the debug sub-directory, creating if it does not exist."""
  return _get_or_create_dir(get_debug_dir(export_dir))


def get_saved_model_pbtxt_path(export_dir):
  """Returns the path of the text-format SavedModel file, as bytes.

  Args:
    export_dir: The SavedModel directory, as `str`, `bytes` or `os.PathLike`.
  """
  return os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))


def get_saved_model_pb_path(export_dir):
  """Returns the path of the binary SavedModel file, as bytes.

  Args:
    export_dir: The SavedModel directory, as `str`, `bytes` or `os.PathLike`.
  """
  return os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))


def get_debug_dir(export_dir):
  """Returns path to the debug sub-directory in the SavedModel.

  Args:
    export_dir: The SavedModel directory, as `str`, `bytes` or `os.PathLike`.

  Returns:
    The debug directory path as text.
  """
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.DEBUG_DIRECTORY))