# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utility functions implementation."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TensorInfo helpers.


@tf_export(v1=["saved_model.build_tensor_info",
               "saved_model.utils.build_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.build_tensor_info or "
    "tf.compat.v1.saved_model.build_tensor_info.")
def build_tensor_info(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Args:
    tensor: Tensor or SparseTensor whose name, dtype and shape are used to
      build the TensorInfo. For SparseTensors, the names of the three
      constituent Tensors are used.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  # The public entry point refuses to run eagerly; the actual proto
  # construction lives in `build_tensor_info_internal`.
  if context.executing_eagerly():
    raise RuntimeError("build_tensor_info is not supported in Eager mode.")
  return build_tensor_info_internal(tensor)


def build_tensor_info_internal(tensor):
  """Utility function to build TensorInfo proto from a Tensor.

  Unlike `build_tensor_info`, this performs no eager-execution check.

  Args:
    tensor: Tensor or SparseTensor. Its dtype and static shape populate the
      proto's `dtype` and `tensor_shape` fields. A SparseTensor is encoded in
      the `coo_sparse` field using the names of its `values`, `indices` and
      `dense_shape` tensors; any other argument is encoded by its `name`.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.
  """
  tensor_info = meta_graph_pb2.TensorInfo(
      dtype=dtypes.as_dtype(tensor.dtype).as_datatype_enum,
      tensor_shape=tensor.get_shape().as_proto())
  if isinstance(tensor, sparse_tensor.SparseTensor):
    # Sparse tensors have no single name; record each component instead.
    tensor_info.coo_sparse.values_tensor_name = tensor.values.name
    tensor_info.coo_sparse.indices_tensor_name = tensor.indices.name
    tensor_info.coo_sparse.dense_shape_tensor_name = tensor.dense_shape.name
  else:
    tensor_info.name = tensor.name
  return tensor_info


def build_tensor_info_from_op(op):
  """Utility function to build TensorInfo proto from an Op.

  Note that this function should be used with caution. It is strictly restricted
  to TensorFlow internal use-cases only. Please make sure you do need it before
  using it.

  This utility function overloads the TensorInfo proto by setting the name to
  the Op's name, dtype to DT_INVALID and tensor_shape to an unknown shape. One
  typical usage is for the Op of the call site for the defunned function:
  ```python
    @function.defun
    def some_variable_initialization_fn(value_a, value_b):
      a = value_a
      b = value_b

    value_a = constant_op.constant(1, name="a")
    value_b = constant_op.constant(2, name="b")
    op_info = utils.build_tensor_info_from_op(
        some_variable_initialization_fn(value_a, value_b))
  ```

  Args:
    op: An Op whose name is used to build the TensorInfo. The name that points
      to the Op could be fetched at run time in the Loader session.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied argument.
  """
  # An Op has no dtype or shape of its own, so placeholder values are used.
  return meta_graph_pb2.TensorInfo(
      dtype=types_pb2.DT_INVALID,
      tensor_shape=tensor_shape.unknown_shape().as_proto(),
      name=op.name)


@tf_export(v1=["saved_model.get_tensor_from_tensor_info",
               "saved_model.utils.get_tensor_from_tensor_info"])
@deprecation.deprecated(
    None,
    "This function will only be available through the v1 compatibility "
    "library as tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info or "
    "tf.compat.v1.saved_model.get_tensor_from_tensor_info.")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the Tensor or SparseTensor described by a TensorInfo proto.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
      current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    The Tensor or SparseTensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed, i.e. its `encoding` oneof is
      neither `name` nor `coo_sparse` (including when it is unset).
  """
  graph = graph or ops.get_default_graph()

  def _get_tensor(name):
    return graph.get_tensor_by_name(
        ops.prepend_name_scope(name, import_scope=import_scope))

  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "name":
    return _get_tensor(tensor_info.name)
  elif encoding == "coo_sparse":
    # Reassemble the sparse tensor from its three component tensors, passed
    # as (indices, values, dense_shape).
    return sparse_tensor.SparseTensor(
        _get_tensor(tensor_info.coo_sparse.indices_tensor_name),
        _get_tensor(tensor_info.coo_sparse.values_tensor_name),
        _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
  else:
    raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)


def get_element_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Returns the element in the graph described by a TensorInfo proto.

  Unlike `get_tensor_from_tensor_info`, only the proto's `name` field is
  consulted, and the lookup goes through `graph.as_graph_element`, so the
  name may refer to either an Op or a Tensor.

  Args:
    tensor_info: A TensorInfo proto describing an Op or Tensor by name.
    graph: The tf.Graph in which tensors are looked up. If None, the current
      default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
      string before lookup.

  Returns:
    Op or tensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to an op or tensor in
      `graph`.
  """
  graph = graph or ops.get_default_graph()
  return graph.as_graph_element(
      ops.prepend_name_scope(tensor_info.name, import_scope=import_scope))


# Path helpers.
def _get_or_create_dir(directory):
  """Returns `directory`, first creating it if it does not already exist.

  Shared by `get_or_create_variables_dir` and `get_or_create_assets_dir` so
  both follow the same look-up-then-create logic.

  Args:
    directory: Path of the directory to look up or create.

  Returns:
    `directory`, unchanged.
  """
  if not file_io.file_exists(directory):
    # NOTE(review): recursive_create_dir is presumably also creating missing
    # parent directories (as its name suggests) -- confirm per filesystem.
    file_io.recursive_create_dir(directory)
  return directory


def get_or_create_variables_dir(export_dir):
  """Return variables sub-directory, or create one if it doesn't exist.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    The path given by `get_variables_dir(export_dir)`, created if missing.
  """
  return _get_or_create_dir(get_variables_dir(export_dir))


def get_variables_dir(export_dir):
  """Return variables sub-directory in the SavedModel.

  Args:
    export_dir: Root directory of the SavedModel; converted with
      `compat.as_text` before joining.

  Returns:
    `export_dir` joined with `constants.VARIABLES_DIRECTORY`, as text.
  """
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.VARIABLES_DIRECTORY))


def get_variables_path(export_dir):
  """Return the variables path, used as the prefix for checkpoint files.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    The variables sub-directory joined with `constants.VARIABLES_FILENAME`,
    as text.
  """
  # get_variables_dir already returns text, so only the filename constant
  # needs converting.
  return os.path.join(
      get_variables_dir(export_dir),
      compat.as_text(constants.VARIABLES_FILENAME))


def get_or_create_assets_dir(export_dir):
  """Return assets sub-directory, or create one if it doesn't exist.

  Args:
    export_dir: Root directory of the SavedModel.

  Returns:
    The path given by `get_assets_dir(export_dir)`, created if missing.
  """
  return _get_or_create_dir(get_assets_dir(export_dir))


def get_assets_dir(export_dir):
  """Return path to asset directory in the SavedModel.

  Args:
    export_dir: Root directory of the SavedModel; converted with
      `compat.as_text` before joining.

  Returns:
    `export_dir` joined with `constants.ASSETS_DIRECTORY`, as text.
  """
  return os.path.join(
      compat.as_text(export_dir),
      compat.as_text(constants.ASSETS_DIRECTORY))