# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utility functions implementation."""

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import struct_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


# TensorInfo helpers.
_DEPRECATION_MSG = (
    "This API was designed for TensorFlow v1. See "
    "https://www.tensorflow.org/guide/migrate for instructions on how to "
    "migrate your code to TensorFlow v2.")


@tf_export(
    v1=["saved_model.build_tensor_info", "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape are used to
        build the TensorInfo. For SparseTensors, the names of the three
        constituent Tensors are used.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  This API is not compatible with eager execution as `tensor` needs to be a
  graph tensor, and there is no replacement for it in TensorFlow 2.x. To start
  writing programs using TensorFlow 2.x, please refer to the [Effective
  TensorFlow 2](https://www.tensorflow.org/guide/effective_tf2) guide.
  @end_compatibility
  """
  if context.executing_eagerly():
    raise RuntimeError("`build_tensor_info` is not supported in eager "
                       "execution.")
  return build_tensor_info_internal(tensor)

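
# Illustrative sketch only (not used by TensorFlow itself): how
# `build_tensor_info` is typically called from TF1-style graph-building code.
# The helper name `_example_build_tensor_info` is hypothetical and exists
# purely as documentation; it is never invoked by this module.
def _example_build_tensor_info():
  """Builds TensorInfo protos for a dense and a sparse graph tensor."""
  with ops.Graph().as_default():
    dense = ops.convert_to_tensor([[1.0, 2.0]], name="dense")
    dense_info = build_tensor_info(dense)
    # dense_info.name == "dense:0"; dtype and shape are taken from the tensor.
    sparse = sparse_tensor.SparseTensor(
        indices=[[0, 0]], values=[1.0], dense_shape=[1, 2])
    sparse_info = build_tensor_info(sparse)
    # For a SparseTensor, the coo_sparse encoding records the names of the
    # indices, values and dense_shape tensors instead of a single name.
    return dense_info, sparse_info
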


def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor."""
  if (isinstance(tensor, composite_tensor.CompositeTensor) and
      not isinstance(tensor, sparse_tensor.SparseTensor) and
      not isinstance(tensor, resource_variable_ops.ResourceVariable)):
    return _build_composite_tensor_info_internal(tensor)

  tensor_info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info.coo_sparse.values_tensor_name = tensor.values.name
    tensor_info.coo_sparse.indices_tensor_name = tensor.indices.name
    tensor_info.coo_sparse.dense_shape_tensor_name = tensor.dense_shape.name
  else:
    tensor_info.name = tensor.name
  return tensor_info


def _build_composite_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a CompositeTensor."""
  spec = tensor._type_spec  # pylint: disable=protected-access
  tensor_info = meta_graph_pb2.TensorInfo()
  spec_proto = nested_structure_coder.encode_structure(spec)
  tensor_info.composite_tensor.type_spec.CopyFrom(spec_proto.type_spec_value)
  for component in nest.flatten(tensor, expand_composites=True):
    tensor_info.composite_tensor.components.add().CopyFrom(
        build_tensor_info_internal(component))
  return tensor_info

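
# Illustrative sketch only: for a non-sparse CompositeTensor (a RaggedTensor
# here) the TensorInfo stores the serialized TypeSpec plus one name-encoded
# TensorInfo per flattened component tensor. The local ragged import and the
# helper name `_example_composite_tensor_info` are assumptions of this sketch
# and are never used by this module.
def _example_composite_tensor_info():
  """Builds a TensorInfo for a RaggedTensor in a TF1-style graph."""
  from tensorflow.python.ops.ragged import ragged_factory_ops  # sketch-only
  with ops.Graph().as_default():
    ragged = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
    info = build_tensor_info_internal(ragged)
    # info.composite_tensor.type_spec encodes the RaggedTensorSpec;
    # info.composite_tensor.components names the flat values and row splits.
    return info
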

def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution: it is restricted to
  TensorFlow internal use cases. Please make sure you really need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op at the call site of a defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
        to the Op can be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "`build_tensor_info_from_op` is not supported in eager execution.")
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)


@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(None, _DEPRECATION_MSG)
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or CompositeTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor or
      CompositeTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
        current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
        string before lookup.

  Returns:
    The Tensor or SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  graph = graph or ops.get_default_graph()
  def _get_tensor(name):
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))
  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  elif encoding == "composite_tensor":
    spec_proto = struct_pb2.StructuredValue(
        type_spec_value=tensor_info.composite_tensor.type_spec)
    spec = nested_structure_coder.decode_proto(spec_proto)
    components = [_get_tensor(component.name) for component in
                  tensor_info.composite_tensor.components]
    return nest.pack_sequence_as(spec, components, expand_composites=True)
  else:
    raise ValueError(f"Invalid TensorInfo.encoding: {encoding}. Expected `"
                     "coo_sparse`, `composite_tensor`, or `name` for a dense "
                     "tensor.")

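
# Illustrative sketch only: the usual round trip in which a TensorInfo produced
# at export time is resolved back to the live tensor it describes once the
# graph is available, optionally under an import scope. The helper name
# `_example_round_trip` is hypothetical and is never invoked by this module.
def _example_round_trip():
  """Builds a TensorInfo and resolves it back to the tensor it names."""
  with ops.Graph().as_default() as graph:
    tensor = ops.convert_to_tensor([1.0, 2.0], name="outputs")
    info = build_tensor_info(tensor)  # name-encoded TensorInfo ("outputs:0")
    restored = get_tensor_from_tensor_info(info, graph=graph)
    assert restored is tensor
    return restored
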

def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in
      `graph`.
  """
  graph = graph or ops.get_default_graph()
  return graph.as_graph_element(
      ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))

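
# Illustrative sketch only: a TensorInfo built from an Op carries just the op
# name (dtype DT_INVALID, unknown shape), so it is resolved with
# `get_element_from_tensor_info`, which may return an Operation rather than a
# Tensor. The helper name `_example_lookup_op` is hypothetical and is never
# invoked by this module.
def _example_lookup_op():
  """Builds a TensorInfo from an Op and looks the Op up again by name."""
  with ops.Graph().as_default() as graph:
    tensor = ops.convert_to_tensor([1.0], name="init")
    op_info = build_tensor_info_from_op(tensor.op)
    # op_info.name == "init" (note: no ":0" output suffix).
    element = get_element_from_tensor_info(op_info, graph=graph)
    assert element is tensor.op
    return element
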

# Path helpers.


def get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist."""
  variables_dir = get_variables_dir(export_dir)
  file_io.recursive_create_dir(variables_dir)
  return variables_dir


def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.VARIABLES_DIRECTORY))


def get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files."""
  return file_io.join(
      compat.as_text(get_variables_dir(export_dir)),
      compat.as_text(constants.VARIABLES_FILENAME))


def get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist."""
  assets_destination_dir = get_assets_dir(export_dir)

  file_io.recursive_create_dir(assets_destination_dir)

  return assets_destination_dir


def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.ASSETS_DIRECTORY))


def get_or_create_debug_dir(export_dir):
  """Returns path to the debug sub-directory, creating if it does not exist."""
  debug_dir = get_debug_dir(export_dir)

  file_io.recursive_create_dir(debug_dir)

  return debug_dir


def get_saved_model_pbtxt_path(export_dir):
  """Returns path to the saved_model.pbtxt file in the SavedModel."""
  return file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))


def get_saved_model_pb_path(export_dir):
  """Returns path to the saved_model.pb file in the SavedModel."""
  return file_io.join(
      compat.as_bytes(compat.path_to_str(export_dir)),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))


def get_debug_dir(export_dir):
  """Returns path to the debug sub-directory in the SavedModel."""
  return file_io.join(
      compat.as_text(export_dir), compat.as_text(constants.DEBUG_DIRECTORY))
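
# For reference, the path helpers above compose the standard on-disk layout of
# a SavedModel (names come from `constants`); for an `export_dir` of, say,
# "/tmp/model" the layout is, schematically:
#
#   /tmp/model/
#     saved_model.pb (or saved_model.pbtxt)  # get_saved_model_pb(txt)_path
#     variables/variables.*                  # get_variables_dir / get_variables_path
#     assets/                                # get_assets_dir
#     debug/                                 # get_debug_dir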

# Based on tensor_bundle/byte_swap.cc
byte_swappable = [
    dtypes.float16, dtypes.float32, dtypes.float64, dtypes.bfloat16,
    dtypes.complex64, dtypes.complex128, dtypes.uint16, dtypes.uint32,
    dtypes.uint64, dtypes.int16, dtypes.int32, dtypes.int64, dtypes.qint16,
    dtypes.quint16, dtypes.qint32
]


def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
  """Byte swaps the Const node tensors of every function in `meta_graph_def`."""
  functions = meta_graph_def.graph_def.library.function
  for function in functions:
    node_def = function.node_def
    for node in node_def:
      if node.op == "Const":
        tensor = node.attr["value"].tensor
        byte_swap_tensor_content(tensor, from_endiness, to_endiness)

def byte_swap_tensor_content(tensor, from_endiness, to_endiness):
  """Byte swaps a TensorProto's tensor_content between endiannesses."""
  if tensor.dtype in byte_swappable:
    tshape = tensor.tensor_shape.dim
    tensor_bytes = tensor.tensor_content
    if tensor_bytes:
      tensor_size = 1
      for sz in tshape:
        tensor_size = tensor_size * sz.size
      chunksize = len(tensor_bytes) // tensor_size
      # Split tensor_data into chunks for byte swapping.
      to_swap = [
          tensor_bytes[i:i + chunksize]
          for i in range(0, len(tensor_bytes), chunksize)
      ]
      # Swap and replace tensor_content.
      tensor.tensor_content = b"".join([
          int.from_bytes(byteswap,
                         from_endiness).to_bytes(chunksize, to_endiness)
          for byteswap in to_swap
      ])
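

# Illustrative sketch only: how the byte-swap helpers above can be exercised on
# a hand-built TensorProto. The local `tensor_pb2` import and the helper name
# `_example_byte_swap` are assumptions of this sketch, not part of this
# module's API, and the function is never invoked here.
def _example_byte_swap():
  """Byte swaps two little-endian int32 values to big-endian in place."""
  from tensorflow.core.framework import tensor_pb2  # sketch-only import
  proto = tensor_pb2.TensorProto(dtype=types_pb2.DT_INT32)
  proto.tensor_shape.dim.add().size = 2
  proto.tensor_content = (1).to_bytes(4, "little") + (2).to_bytes(4, "little")
  byte_swap_tensor_content(proto, "little", "big")
  # proto.tensor_content is now b"\x00\x00\x00\x01\x00\x00\x00\x02", i.e. the
  # same two int32 values laid out in big-endian byte order.
  return proto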