1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""TensorSignature class and utilities (deprecated). 17 18This module and all its submodules are deprecated. See 19[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 20for migration instructions. 21""" 22 23from __future__ import absolute_import 24from __future__ import division 25from __future__ import print_function 26 27import collections 28 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.ops import parsing_ops 35 36 37class TensorSignature(collections.namedtuple( 38 "TensorSignature", ["dtype", "shape", "is_sparse"])): 39 """Signature of the `Tensor` object. 40 41 THIS CLASS IS DEPRECATED. See 42 [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) 43 for general migration instructions. 44 45 Useful to check compatibility of tensors. 46 47 Example: 48 49 ```python 50 examples = tf.placeholder(...) 51 inputs = {'a': var_a, 'b': var_b} 52 signatures = tensor_signature.create_signatures(inputs) 53 result = tensor_signature.create_example_parser_from_signatures( 54 signatures, examples) 55 self.assertTrue(tensor_signature.tensors_compatible(result, signatures)) 56 ``` 57 58 Attributes: 59 dtype: `DType` object. 60 shape: `TensorShape` object. 61 """ 62 63 def __new__(cls, tensor): 64 if isinstance(tensor, sparse_tensor.SparseTensor): 65 return super(TensorSignature, cls).__new__( 66 cls, dtype=tensor.values.dtype, shape=None, is_sparse=True) 67 return super(TensorSignature, cls).__new__( 68 cls, dtype=tensor.dtype, shape=tensor.get_shape(), is_sparse=False) 69 70 def is_compatible_with(self, other): 71 """Returns True if signatures are compatible.""" 72 73 def _shape_is_compatible_0dim(this, other): 74 """Checks that shapes are compatible skipping dim 0.""" 75 other = tensor_shape.as_shape(other) 76 # If shapes are None (unknown) they may be compatible. 77 if this.dims is None or other.dims is None: 78 return True 79 if this.ndims != other.ndims: 80 return False 81 for dim, (x_dim, y_dim) in enumerate(zip(this.dims, other.dims)): 82 if dim == 0: 83 continue 84 if not x_dim.is_compatible_with(y_dim): 85 return False 86 return True 87 88 if other.is_sparse: 89 return self.is_sparse and self.dtype.is_compatible_with(other.dtype) 90 return (self.dtype.is_compatible_with(other.dtype) and 91 _shape_is_compatible_0dim(self.shape, other.shape) and 92 not self.is_sparse) 93 94 def get_placeholder(self): 95 if self.is_sparse: 96 return array_ops.sparse_placeholder(dtype=self.dtype) 97 return array_ops.placeholder(dtype=self.dtype, 98 shape=[None] + list(self.shape[1:])) 99 100 def get_feature_spec(self): 101 dtype = self.dtype 102 # Convert, because example parser only supports float32, int64 and string. 103 if dtype == dtypes.int32: 104 dtype = dtypes.int64 105 if dtype == dtypes.float64: 106 dtype = dtypes.float32 107 if self.is_sparse: 108 return parsing_ops.VarLenFeature(dtype=dtype) 109 return parsing_ops.FixedLenFeature(shape=self.shape[1:], dtype=dtype) 110 111 112def tensors_compatible(tensors, signatures): 113 """Check that tensors are compatible with signatures. 114 115 Args: 116 tensors: Dict of `Tensor` objects or single `Tensor` object. 117 signatures: Dict of `TensorSignature` objects or 118 single `TensorSignature` object. 119 120 Returns: 121 True if all tensors are compatible, False otherwise. 122 """ 123 # Dict of Tensors as input. 124 if tensors is None: 125 return signatures is None 126 127 if isinstance(tensors, dict): 128 if not isinstance(signatures, dict): 129 return False 130 for key in signatures: 131 if key not in tensors: 132 return False 133 if not TensorSignature(tensors[key]).is_compatible_with(signatures[key]): 134 return False 135 return True 136 137 # Single tensor as input. 138 if signatures is None or isinstance(signatures, dict): 139 return False 140 return TensorSignature(tensors).is_compatible_with(signatures) 141 142 143def create_signatures(tensors): 144 """Creates TensorSignature objects for given tensors. 145 146 Args: 147 tensors: Dict of `Tensor` objects or single `Tensor`. 148 149 Returns: 150 Dict of `TensorSignature` objects or single `TensorSignature`. 151 """ 152 if isinstance(tensors, dict): 153 return { 154 key: TensorSignature(tensors[key]) for key in tensors} 155 if tensors is None: 156 return None 157 return TensorSignature(tensors) 158 159 160def create_placeholders_from_signatures(signatures): 161 """Creates placeholders from given signatures. 162 163 Args: 164 signatures: Dict of `TensorSignature` objects or single `TensorSignature`, 165 or `None`. 166 167 Returns: 168 Dict of `tf.placeholder` objects or single `tf.placeholder`, or `None`. 169 """ 170 if signatures is None: 171 return None 172 if not isinstance(signatures, dict): 173 return signatures.get_placeholder() 174 return { 175 key: signatures[key].get_placeholder() 176 for key in signatures} 177 178 179def create_example_parser_from_signatures(signatures, examples_batch, 180 single_feature_name="feature"): 181 """Creates example parser from given signatures. 182 183 Args: 184 signatures: Dict of `TensorSignature` objects or single `TensorSignature`. 185 examples_batch: string `Tensor` of serialized `Example` proto. 186 single_feature_name: string, single feature name. 187 188 Returns: 189 features: `Tensor` or `dict` of `Tensor` objects. 190 """ 191 feature_spec = {} 192 if not isinstance(signatures, dict): 193 feature_spec[single_feature_name] = signatures.get_feature_spec() 194 else: 195 feature_spec = {key: signatures[key].get_feature_spec() 196 for key in signatures} 197 features = parsing_ops.parse_example(examples_batch, feature_spec) 198 if not isinstance(signatures, dict): 199 # Returns single feature, casts if needed. 200 features = features[single_feature_name] 201 if not signatures.dtype.is_compatible_with(features.dtype): 202 features = math_ops.cast(features, signatures.dtype) 203 return features 204 # Returns dict of features, casts if needed. 205 for name in features: 206 if not signatures[name].dtype.is_compatible_with(features[name].dtype): 207 features[name] = math_ops.cast(features[name], signatures[name].dtype) 208 return features 209