1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""SavedModel utility functions implementation.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23from tensorflow.core.framework import types_pb2 24from tensorflow.core.protobuf import meta_graph_pb2 25from tensorflow.core.protobuf import struct_pb2 26from tensorflow.python.eager import context 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.lib.io import file_io 33from tensorflow.python.saved_model import constants 34from tensorflow.python.saved_model import nested_structure_coder 35from tensorflow.python.util import compat 36from tensorflow.python.util import deprecation 37from tensorflow.python.util import nest 38from tensorflow.python.util.tf_export import tf_export 39 40 41# TensorInfo helpers. 
@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape are used to
      build the TensorInfo. For SparseTensors, the names of the three
      constituent Tensors are used.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  This API is not compatible with eager execution as `tensor` needs to be a
  graph tensor, and there is no replacement for it in TensorFlow 2.x. To start
  writing programs using TensorFlow 2.x, please refer to the [Effective
  TensorFlow 2](https://www.tensorflow.org/guide/effective_tf2) guide.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("build_tensor_info is not supported in Eager mode.")
  return build_tensor_info_internal(tensor)


def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor."""
  # SparseTensor is itself a CompositeTensor but has a dedicated coo_sparse
  # encoding in TensorInfo, so it must be excluded from the generic
  # composite-tensor path below.
  if (isinstance(tensor, composite_tensor.CompositeTensor) and
      not isinstance(tensor, sparse_tensor.SparseTensor)):
    return _build_composite_tensor_info_internal(tensor)

  tensor_info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if isinstance(tensor, sparse_tensor.SparseTensor):
    # A SparseTensor is recorded via the names of its three component tensors.
    tensor_info.coo_sparse.values_tensor_name = tensor.values.name
    tensor_info.coo_sparse.indices_tensor_name = tensor.indices.name
    tensor_info.coo_sparse.dense_shape_tensor_name = tensor.dense_shape.name
  else:
    tensor_info.name = tensor.name
  return tensor_info


def _build_composite_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a CompositeTensor."""
  spec = tensor._type_spec  # pylint: disable=protected-access
  tensor_info = meta_graph_pb2.TensorInfo()
  # Serialize the TypeSpec so the composite structure can be reconstructed
  # at load time; each flattened component gets its own nested TensorInfo.
  struct_coder = nested_structure_coder.StructureCoder()
  spec_proto = struct_coder.encode_structure(spec)
  tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
  for component in nest.flatten(tensor, expand_composites=True):
    tensor_info.composite_tensor.components.add().CopyFrom(
        build_tensor_info_internal(component))
  return tensor_info


def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly
  restricted to TensorFlow internal use-cases only. Please make sure you do
  need it before using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape as None. One typical
  usage is for the Op of the call site for the defunned function:
  ```python
  @function.defun
  def some_variable_initialization_fn(value_a, value_b):
    a = value_a
    b = value_b

  value_a = constant_op.constant(1, name="a")
  value_b = constant_op.constant(2, name="b")
  op_info = utils.build_op_info(
      some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
      to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "build_tensor_info_from_op is not supported in Eager mode.")
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)


@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
      current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()

  def _get_tensor(name):
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))

  # Dispatch on which of the TensorInfo oneof encodings is populated.
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    # Decode the stored TypeSpec, look up each flat component by name, then
    # repack the components into the original composite structure.
    struct_coder = nested_structure_coder.StructureCoder()
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = struct_coder.decode_proto(spec_proto)
    components = [_get_tensor(component.name) for component in
                  tensor_info.composite_tensor.components]
    return nest.pack_sequence_as(spec, components, expand_composites=True)
  else:
    raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)


def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in
      `graph`
  """
  graph = graph or ops.get_default_graph()
  return graph.as_graph_element(
      ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))


# Path helpers.


def get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return os.path.join(
      compat.as_text(get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))


def get_or_create_debug_dir(export_dir):
  """Returns path to the debug sub-directory, creating if it does not exist."""
  debug_dir = get_debug_dir(export_dir)

  file_io.recursive_create_dir(debug_dir)

  return debug_dir


def get_saved_model_pbtxt_path(export_dir):
  """Returns path to the text-format SavedModel proto inside `export_dir`."""
  # NOTE(review): returns bytes (as_bytes) while the directory helpers above
  # return text (as_text) — presumably for a caller that needs a bytes path;
  # kept as-is for backward compatibility.
  return os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))


def get_saved_model_pb_path(export_dir):
  """Returns path to the binary SavedModel proto inside `export_dir`."""
  # NOTE(review): returns a bytes path, mirroring get_saved_model_pbtxt_path.
  return os.path.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))


def get_debug_dir(export_dir):
  """Returns path to the debug sub-directory in the SavedModel."""
  return os.path.join(
      compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))


# Based on tensor_bundle/byte_swap.cc
byte_swappable = [
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
    dtypes.complex64, dtypes.complex128, dtypes.uint16, dtypes.uint32,
    dtypes.uint64, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.qint16,
    dtypes.quint16, dtypes.qint32
]


def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
  """Byte-swaps the tensor_content of every Const node in function library.

  Args:
    meta_graph_def: MetaGraphDef whose graph_def.library functions are
      rewritten in place.
    from_endiness: Source byte order, "big" or "little".
    to_endiness: Target byte order, "big" or "little".
  """
  functions = meta_graph_def.graph_def.library.function
  for function in functions:
    node_def = function.node_def
    for node in node_def:
      if node.op == "Const":
        tensor = node.attr["value"].tensor
        byte_swap_tensor_content(tensor, from_endiness, to_endiness)


def byte_swap_tensor_content(tensor, from_endiness, to_endiness):
  """Byte swaps the `tensor_content` of a TensorProto in place.

  Only dtypes listed in `byte_swappable` (multi-byte numeric types) are
  touched; other dtypes (e.g. strings, int8) are left unchanged.

  Args:
    tensor: A TensorProto whose tensor_content is rewritten in place.
    from_endiness: Source byte order, "big" or "little".
    to_endiness: Target byte order, "big" or "little".
  """
  if tensor.dtype in byte_swappable:
    tshape = tensor.tensor_shape.dim
    tensor_bytes = tensor.tensor_content
    if tensor_bytes:
      # Element count; empty dim list means a scalar (size 1).
      tensor_size = 1
      for sz in tshape:
        tensor_size = tensor_size * sz.size
      # Fix: use exact integer floor division instead of int(a / b) — float
      # division loses precision for very large tensor_content buffers.
      chunksize = len(tensor_bytes) // tensor_size
      # Split tensor_data into chunks for byte swapping.
      to_swap = [
          tensor_bytes[i:i + chunksize]
          for i in range(0, len(tensor_bytes), chunksize)
      ]
      # Swap and replace tensor_content.
      tensor.tensor_content = b"".join([
          int.from_bytes(byteswap,
                         from_endiness).to_bytes(chunksize, to_endiness)
          for byteswap in to_swap
      ])