• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel utility functions implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22
23from tensorflow.core.framework import types_pb2
24from tensorflow.core.protobuf import meta_graph_pb2
25from tensorflow.core.protobuf import struct_pb2
26from tensorflow.python.eager import context
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.saved_model import constants
34from tensorflow.python.saved_model import nested_structure_coder
35from tensorflow.python.util import compat
36from tensorflow.python.util import deprecation
37from tensorflow.python.util import nest
38from tensorflow.python.util.tf_export import tf_export
39
40
# TensorInfo helpers.
42
43
@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Builds a TensorInfo proto describing a graph tensor.

  Args:
    tensor: A Tensor, SparseTensor or other CompositeTensor. Its dtype and
        shape are recorded along with its name; for a SparseTensor the names
        of its indices, values and dense_shape Tensors are recorded instead.

  Returns:
    A TensorInfo protocol buffer describing `tensor`.

  Raises:
    RuntimeError: If eager execution is enabled.

  @compatibility(TF2)
  This API is not compatible with eager execution as `tensor` needs to be a
  graph tensor, and there is no replacement for it in TensorFlow 2.x. To start
  writing programs using TensorFlow 2.x, please refer to the [Effective
  TensorFlow 2](https://www.tensorflow.org/guide/effective_tf2) guide.
  @end_compatibility
  """
  # Graph mode: delegate to the unchecked implementation.
  if not context.executing_eagerly():
    return build_tensor_info_internal(tensor)
  raise RuntimeError("build_tensor_info is not supported in Eager mode.")
75
76
def build_tensor_info_internal(tensor):
  """Builds a TensorInfo proto from a Tensor, SparseTensor or CompositeTensor.

  Unlike `build_tensor_info`, no eager-execution check is performed.

  Args:
    tensor: The tensor to describe. A CompositeTensor other than SparseTensor
      is encoded through its TypeSpec and flattened components. A SparseTensor
      is encoded by the names of its three constituent Tensors. Anything else
      is encoded by its own name.

  Returns:
    A TensorInfo protocol buffer describing `tensor`.
  """
  is_sparse = isinstance(tensor, sparse_tensor.SparseTensor)
  if isinstance(tensor, composite_tensor.CompositeTensor) and not is_sparse:
    return _build_composite_tensor_info_internal(tensor)

  info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())

  if not is_sparse:
    info.name = tensor.name
    return info

  # SparseTensor: record the three component tensor names (coo_sparse oneof).
  coo = info.coo_sparse
  coo.values_tensor_name = tensor.values.name
  coo.indices_tensor_name = tensor.indices.name
  coo.dense_shape_tensor_name = tensor.dense_shape.name
  return info
93
94
def _build_composite_tensor_info_internal(tensor):
  """Builds a TensorInfo proto for a CompositeTensor.

  `build_tensor_info_internal` calls this for CompositeTensors other than
  SparseTensor. The tensor's TypeSpec is encoded into
  `composite_tensor.type_spec`. Each component obtained by fully flattening
  the tensor (`expand_composites=True`) is appended, in order, to
  `composite_tensor.components`.

  Args:
    tensor: The CompositeTensor to describe.

  Returns:
    A TensorInfo protocol buffer using the `composite_tensor` encoding.
  """
  type_spec = tensor._type_spec  # pylint: disable=protected-access
  info = meta_graph_pb2.TensorInfo()
  encoded = nested_structure_coder.StructureCoder().encode_structure(type_spec)
  composite = info.composite_tensor
  composite.type_spec.CopyFrom(encoded.type_spec_value)
  flat_components = nest.flatten(tensor, expand_composites=True)
  for piece in flat_components:
    composite.components.add().CopyFrom(build_tensor_info_internal(piece))
  return info
106
107
def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
        to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "build_tensor_info_from_op is not supported in Eager mode.")
  # An op has no single dtype/shape, so mark them invalid/unknown and keep
  # only the name for lookup.
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)
147
148
@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Looks up the Tensor, SparseTensor or CompositeTensor in a TensorInfo.

  Args:
    tensor_info: A TensorInfo proto whose `encoding` oneof is `name`,
      `coo_sparse` or `composite_tensor`.
    graph: The tf.Graph to search. Falls back to the current default graph
        when not given.
    import_scope: If not None, every tensor name from `tensor_info` is
        prefixed with this scope before lookup.

  Returns:
    The Tensor, SparseTensor or CompositeTensor in `graph` described by
    `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed.
  """
  if not graph:
    graph = ops.get_default_graph()

  def _lookup(tensor_name):
    scoped_name = ops.prepend_name_scope(tensor_name, import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")

  if encoding == "name":
    return _lookup(tensor_info.name)

  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    indices = _lookup(coo.indices_tensor_name)
    values = _lookup(coo.values_tensor_name)
    dense_shape = _lookup(coo.dense_shape_tensor_name)
    return sparse_tensor.SparseTensor(indices, values, dense_shape)

  if encoding == "composite_tensor":
    composite = tensor_info.composite_tensor
    # Rebuild the TypeSpec, then pack the looked-up components into it.
    spec = nested_structure_coder.StructureCoder().decode_proto(
        struct_pb2.StructuredValue(type_spec_value=composite.type_spec))
    components = [_lookup(part.name) for part in composite.components]
    return nest.pack_sequence_as(spec, components, expand_composites=True)

  raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
197
198
def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolves the op or tensor named by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto whose `name` identifies an Op or Tensor.
    graph: The tf.Graph to search. Falls back to the current default graph
      when not given.
    import_scope: If not None, `tensor_info.name` is prefixed with this scope
      before lookup.

  Returns:
    The Op or Tensor in `graph` that `graph.as_graph_element` resolves the
    (scoped) name to.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in `graph`
  """
  target_graph = graph or ops.get_default_graph()
  element_name = ops.prepend_name_scope(
      tensor_info.name, import_scope=import_scope)
  return target_graph.as_graph_element(element_name)
218
219
# Path helpers.
221
222
def get_or_create_variables_dir(export_dir):
  """Returns the SavedModel variables sub-directory, creating it if needed.

  Args:
    export_dir: SavedModel directory.

  Returns:
    The path from `get_variables_dir(export_dir)`, after
    `file_io.recursive_create_dir` has been called on it.
  """
  path = get_variables_dir(export_dir)
  file_io.recursive_create_dir(path)
  return path
228
229
def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel.

  Args:
    export_dir: SavedModel directory as a str, bytes, or os.PathLike object
      (e.g. pathlib.Path).

  Returns:
    The variables sub-directory path as a text string.
  """
  # compat.as_text rejects anything but str/bytes, so convert PathLike inputs
  # first, as get_saved_model_pb_path already does.
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.VARIABLES_DIRECTORY))
235
236
def get_variables_path(export_dir):
  """Returns the checkpoint-file prefix for the SavedModel's variables.

  Args:
    export_dir: SavedModel directory.

  Returns:
    The text path `constants.VARIABLES_FILENAME` joined onto the variables
    sub-directory; it is used as the prefix for checkpoint files.
  """
  variables_dir = compat.as_text(get_variables_dir(export_dir))
  variables_filename = compat.as_text(constants.VARIABLES_FILENAME)
  return os.path.join(variables_dir, variables_filename)
242
243
def get_or_create_assets_dir(export_dir):
  """Returns the SavedModel assets sub-directory, creating it if needed.

  Args:
    export_dir: SavedModel directory.

  Returns:
    The path from `get_assets_dir(export_dir)`, after
    `file_io.recursive_create_dir` has been called on it.
  """
  assets_dir = get_assets_dir(export_dir)
  file_io.recursive_create_dir(assets_dir)
  return assets_dir
251
252
def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel.

  Args:
    export_dir: SavedModel directory as a str, bytes, or os.PathLike object
      (e.g. pathlib.Path).

  Returns:
    The assets sub-directory path as a text string.
  """
  # compat.as_text rejects anything but str/bytes, so convert PathLike inputs
  # first, as get_saved_model_pb_path already does.
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.ASSETS_DIRECTORY))
258
259
def get_or_create_debug_dir(export_dir):
  """Returns the SavedModel debug sub-directory, creating it if needed.

  Args:
    export_dir: SavedModel directory.

  Returns:
    The path from `get_debug_dir(export_dir)`, after
    `file_io.recursive_create_dir` has been called on it.
  """
  target = get_debug_dir(export_dir)
  file_io.recursive_create_dir(target)
  return target
267
268
def get_saved_model_pbtxt_path(export_dir):
  """Returns the path of the text-format SavedModel proto file as bytes.

  Args:
    export_dir: SavedModel directory as a str, bytes, or os.PathLike object.

  Returns:
    `constants.SAVED_MODEL_FILENAME_PBTXT` joined onto `export_dir`, as bytes.
  """
  base_dir = compat.as_bytes(compat.path_to_str(export_dir))
  filename = compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT)
  return os.path.join(base_dir, filename)
273
274
def get_saved_model_pb_path(export_dir):
  """Returns the path of the binary SavedModel proto file as bytes.

  Args:
    export_dir: SavedModel directory as a str, bytes, or os.PathLike object.

  Returns:
    `constants.SAVED_MODEL_FILENAME_PB` joined onto `export_dir`, as bytes.
  """
  base_dir = compat.as_bytes(compat.path_to_str(export_dir))
  filename = compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB)
  return os.path.join(base_dir, filename)
279
280
def get_debug_dir(export_dir):
  """Returns path to the debug sub-directory in the SavedModel.

  Args:
    export_dir: SavedModel directory as a str, bytes, or os.PathLike object
      (e.g. pathlib.Path).

  Returns:
    The debug sub-directory path as a text string.
  """
  # compat.as_text rejects anything but str/bytes, so convert PathLike inputs
  # first, as get_saved_model_pb_path already does.
  return os.path.join(
      compat.as_text(compat.path_to_str(export_dir)),
      compat.as_text(constants.DEBUG_DIRECTORY))
285
# Based on tensor_bundle/byte_swap.cc
# DTypes whose packed `tensor_content` byte_swap_tensor_content converts;
# tensors of any other dtype are left untouched.
byte_swappable = [
    # Floating point.
    dtypes.float16,
    dtypes.float32,
    dtypes.float64,
    dtypes.bfloat16,
    # Complex.
    dtypes.complex64,
    dtypes.complex128,
    # Unsigned integers.
    dtypes.uint16,
    dtypes.uint32,
    dtypes.uint64,
    # Signed integers.
    dtypes.int16,
    dtypes.int32,
    dtypes.int64,
    # Quantized.
    dtypes.qint16,
    dtypes.quint16,
    dtypes.qint32,
]
293
294
def swap_function_tensor_content(meta_graph_def, from_endiness, to_endiness):
  """Byte swaps Const tensors inside the function library of a MetaGraphDef.

  Every node whose op is "Const" in each function of
  `meta_graph_def.graph_def.library` has its `value` tensor converted in place
  by byte_swap_tensor_content. Nodes of the top-level graph are not visited.

  Args:
    meta_graph_def: MetaGraphDef proto; modified in place.
    from_endiness: Current byte order, "big" or "little".
    to_endiness: Target byte order, "big" or "little".
  """
  library_functions = meta_graph_def.graph_def.library.function
  const_nodes = (node
                 for fn in library_functions
                 for node in fn.node_def
                 if node.op == "Const")
  for const_node in const_nodes:
    byte_swap_tensor_content(
        const_node.attr["value"].tensor, from_endiness, to_endiness)
303
304
def _swap_byte_order(data, itemsize, from_endiness, to_endiness):
  """Re-encodes every `itemsize`-byte word of `data` in a new byte order.

  Args:
    data: bytes to convert; its length must be a multiple of `itemsize`.
    itemsize: Width in bytes of each word that is swapped independently.
    from_endiness: Byte order `data` is currently in, "big" or "little".
    to_endiness: Byte order to convert to, "big" or "little".

  Returns:
    The converted bytes, with the same length as `data`.

  Raises:
    ValueError: If `itemsize` is not positive or `len(data)` is not a multiple
      of `itemsize` (a partial trailing word would otherwise be zero-padded).
  """
  if itemsize <= 0 or len(data) % itemsize:
    raise ValueError(
        "Cannot byte swap %d bytes in %d-byte words." % (len(data), itemsize))
  return b"".join(
      int.from_bytes(data[start:start + itemsize], from_endiness).to_bytes(
          itemsize, to_endiness)
      for start in range(0, len(data), itemsize))


def byte_swap_tensor_content(tensor, from_endiness, to_endiness):
  """Byte swaps the packed `tensor_content` of a TensorProto in place.

  Only tensors whose dtype is listed in `byte_swappable` and whose
  `tensor_content` is non-empty are converted; all others are left unchanged.

  Args:
    tensor: TensorProto to convert; its `tensor_content` is replaced.
    from_endiness: Current byte order, "big" or "little".
    to_endiness: Target byte order, "big" or "little".
  """
  if tensor.dtype not in byte_swappable:
    return
  tensor_bytes = tensor.tensor_content
  if not tensor_bytes:
    return
  dtype = dtypes.as_dtype(tensor.dtype)
  # Take the word size from the dtype rather than len(content) / num_elements.
  # A complex value is a (real, imaginary) pair of floats, so each half must
  # be swapped on its own; swapping the whole element would also exchange the
  # real and imaginary parts.
  itemsize = dtype.size
  if dtype in (dtypes.complex64, dtypes.complex128):
    itemsize //= 2
  tensor.tensor_content = _swap_byte_order(
      tensor_bytes, itemsize, from_endiness, to_endiness)
326