# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utility functions implementation."""

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TensorInfo helpers.
_DEPRECATION_MSG = (
    "This API was designed for TensorFlow v1. See "
    "https://www.tensorflow.org/guide/migrate for instructions on how to "
    "migrate your code to TensorFlow v2.")


@tf_export(
    v1=["saved_model.build_tensor_info", "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape are used to
        build the TensorInfo. For SparseTensors, the names of the three
        constituent Tensors are used. Other CompositeTensors are encoded via
        their TypeSpec plus one TensorInfo per flattened component.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  This API is not compatible with eager execution as `tensor` needs to be a
  graph tensor, and there is no replacement for it in TensorFlow 2.x. To start
  writing programs using TensorFlow 2.x, please refer to the [Effective
  TensorFlow 2](https://www.tensorflow.org/guide/effective_tf2) guide.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("`build_tensor_info` is not supported in eager "
                       "execution.")
  return build_tensor_info_internal(tensor)


def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Unlike `build_tensor_info`, this does not check for eager execution.

  Args:
    tensor: A Tensor, SparseTensor, ResourceVariable or other CompositeTensor.

  Returns:
    A TensorInfo proto. A SparseTensor uses the `coo_sparse` encoding. Any
    other CompositeTensor that is not a ResourceVariable uses the
    `composite_tensor` encoding. Everything else uses the `name` encoding.
  """
  # SparseTensor keeps its dedicated `coo_sparse` encoding, and
  # ResourceVariable is described by name, so both are excluded from the
  # generic composite path.
  if (isinstance(tensor, composite_tensor.CompositeTensor) and
      not isinstance(tensor, sparse_tensor.SparseTensor) and
      not isinstance(tensor, resource_variable_ops.ResourceVariable)):
    return _build_composite_tensor_info_internal(tensor)

  tensor_info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info.coo_sparse.values_tensor_name = tensor.values.name
    tensor_info.coo_sparse.indices_tensor_name = tensor.indices.name
    tensor_info.coo_sparse.dense_shape_tensor_name = tensor.dense_shape.name
  else:
    tensor_info.name = tensor.name
  return tensor_info


def _build_composite_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a CompositeTensor.

  The TypeSpec of `tensor` is serialized into `composite_tensor.type_spec`.
  Each component produced by `nest.flatten(..., expand_composites=True)` is
  appended, in order, to `composite_tensor.components`.
  """
  spec = tensor._type_spec  # pylint: disable=protected-access
  tensor_info = meta_graph_pb2.TensorInfo()
  spec_proto = nested_structure_coder.encode_structure(spec)
  tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
  for component in nest.flatten(tensor, expand_composites=True):
    tensor_info.composite_tensor.components.add().CopyFrom(
        build_tensor_info_internal(component))
  return tensor_info


def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape as None. One typical usage
  is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
        to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "`build_tensor_info_from_op` is not supported in eager execution.")
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)


@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
        current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
        string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()
  def _get_tensor(name):
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    # Rebuild the TypeSpec, look up each flattened component by name, then
    # repack the components into the composite structure.
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [_get_tensor(component.name) for component in
                  tensor_info.composite_tensor.components]
    return nest.pack_sequence_as(spec, components, expand_composites=True)
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Expected `"
                     "coo_sparse`, `composite_tensor`, or `name` for a dense "
                     "tensor.")


def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  graph = graph or ops.get_default_graph()
  return graph.as_graph_element(
      ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))


# Path helpers.


def get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.VARIABLES_DIRECTORY))


def get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return file_io.join(
      compat.as_text(get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.ASSETS_DIRECTORY))


def get_or_create_debug_dir(export_dir):
  """Returns path to the debug sub-directory, creating if it does not exist."""
  debug_dir = get_debug_dir(export_dir)

  file_io.recursive_create_dir(debug_dir)

  return debug_dir


def get_saved_model_pbtxt_path(export_dir):
  """Returns the path of `constants.SAVED_MODEL_FILENAME_PBTXT` in export_dir.

  Both path components are converted to bytes before being joined, unlike
  the text-based directory helpers above.
  """
  return file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))


def get_saved_model_pb_path(export_dir):
  """Returns the path of `constants.SAVED_MODEL_FILENAME_PB` in export_dir.

  Both path components are converted to bytes before being joined, unlike
  the text-based directory helpers above.
  """
  return file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))


def get_debug_dir(export_dir):
  """Returns path to the debug sub-directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))

# Based on tensor_bundle/byte_swap.cc
# Dtypes whose packed `tensor_content` must be byte swapped when the
# endianness changes. Single-byte dtypes are absent because they never need
# swapping.
byte_swappable = [
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
    dtypes.complex64, dtypes.complex128, dtypes.uint16, dtypes.uint32,
    dtypes.uint64, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.qint16,
    dtypes.quint16, dtypes.qint32
]


def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
  """Byte swaps Const tensors inside the function library of a MetaGraphDef.

  Only nodes with op "Const" inside `meta_graph_def.graph_def.library.function`
  are visited. Nodes of the top-level `graph_def` are not touched here.

  Args:
    meta_graph_def: A `MetaGraphDef` proto, modified in place.
    from_endiness: Current byte order ("big" or "little").
    to_endiness: Target byte order ("big" or "little").
  """
  functions = meta_graph_def.graph_def.library.function
  for function in functions:
    node_def = function.node_def
    for node in node_def:
      if node.op == "Const":
        tensor = node.attr["value"].tensor
        byte_swap_tensor_content(tensor, from_endiness, to_endiness)


def byte_swap_tensor_content(tensor, from_endiness, to_endiness):
  """Byte swaps the packed `tensor_content` of a `TensorProto` in place.

  Only tensors whose dtype is in `byte_swappable` and that store their values
  in the non-empty `tensor_content` field are modified. Anything else is left
  untouched.

  Args:
    tensor: A `TensorProto`. Its `tensor_content` field is rewritten.
    from_endiness: Byte order `tensor_content` is currently in, as accepted by
      `int.from_bytes` ("big" or "little").
    to_endiness: Byte order to convert to, as accepted by `int.to_bytes`.
  """
  if tensor.dtype not in byte_swappable:
    return
  tensor_bytes = tensor.tensor_content
  if not tensor_bytes:
    return
  dtype = dtypes.as_dtype(tensor.dtype)
  # Swap one scalar at a time. The width comes from the dtype itself, so the
  # tensor shape is not needed to recover it. Complex values are stored as a
  # (real, imag) pair, and each part is swapped separately, as in the C++
  # byte-swap implementation. Reversing the whole 8/16-byte element would
  # also exchange the real and imaginary parts.
  chunksize = dtype.size // 2 if dtype.is_complex else dtype.size
  tensor.tensor_content = b"".join(
      int.from_bytes(tensor_bytes[i:i + chunksize], from_endiness).to_bytes(
          chunksize, to_endiness)
      for i in range(0, len(tensor_bytes), chunksize))