1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A TensorSpec class.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import common_shapes 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import tensor_shape 27from tensorflow.python.framework import type_spec 28from tensorflow.python.util import _pywrap_utils 29from tensorflow.python.util.tf_export import tf_export 30 31 32class DenseSpec(type_spec.TypeSpec): 33 """Describes a dense object with shape, dtype, and name.""" 34 35 __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"] 36 37 _component_specs = property(lambda self: self) 38 39 def __init__(self, shape, dtype=dtypes.float32, name=None): 40 """Creates a TensorSpec. 41 42 Args: 43 shape: Value convertible to `tf.TensorShape`. The shape of the tensor. 44 dtype: Value convertible to `tf.DType`. The type of the tensor values. 45 name: Optional name for the Tensor. 46 47 Raises: 48 TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is 49 not convertible to a `tf.DType`. 
50 """ 51 self._shape = tensor_shape.TensorShape(shape) 52 try: 53 self._shape_tuple = tuple(self.shape.as_list()) 54 except ValueError: 55 self._shape_tuple = None 56 self._dtype = dtypes.as_dtype(dtype) 57 self._name = name 58 59 @property 60 def shape(self): 61 """Returns the `TensorShape` that represents the shape of the tensor.""" 62 return self._shape 63 64 @property 65 def dtype(self): 66 """Returns the `dtype` of elements in the tensor.""" 67 return self._dtype 68 69 @property 70 def name(self): 71 """Returns the (optionally provided) name of the described tensor.""" 72 return self._name 73 74 def is_compatible_with(self, spec_or_value): 75 return (isinstance(spec_or_value, (DenseSpec, self.value_type)) and 76 self._dtype.is_compatible_with(spec_or_value.dtype) and 77 self._shape.is_compatible_with(spec_or_value.shape)) 78 79 def __repr__(self): 80 return "{}(shape={}, dtype={}, name={})".format( 81 type(self).__name__, self.shape, repr(self.dtype), repr(self.name)) 82 83 def __hash__(self): 84 return hash((self._shape_tuple, self.dtype)) 85 86 def __eq__(self, other): 87 # pylint: disable=protected-access 88 return (type(self) is type(other) and 89 self._shape_tuple == other._shape_tuple 90 and self._dtype == other._dtype 91 and self._name == other._name) 92 93 def __ne__(self, other): 94 return not self == other 95 96 def most_specific_compatible_type(self, other): 97 if (type(self) is not type(other)) or (self._dtype != other.dtype): 98 raise ValueError("Types are not compatible: %r vs %r" % (self, other)) 99 shape = self._shape.most_specific_compatible_shape(other.shape) 100 name = self._name if self._name == other.name else None 101 return type(self)(shape, self._dtype, name) 102 103 def _serialize(self): 104 return (self._shape, self._dtype, self._name) 105 106 def _to_legacy_output_types(self): 107 return self._dtype 108 109 def _to_legacy_output_shapes(self): 110 return self._shape 111 112 def _to_legacy_output_classes(self): 113 return 
self.value_type 114 115 116@tf_export("TensorSpec") 117@type_spec.register("tf.TensorSpec") 118class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec): 119 """Describes a tf.Tensor. 120 121 Metadata for describing the `tf.Tensor` objects accepted or returned 122 by some TensorFlow APIs. 123 """ 124 125 __slots__ = [] 126 127 def is_compatible_with(self, spec_or_tensor): # pylint:disable=useless-super-delegation 128 """Returns True if spec_or_tensor is compatible with this TensorSpec. 129 130 Two tensors are considered compatible if they have the same dtype 131 and their shapes are compatible (see `tf.TensorShape.is_compatible_with`). 132 133 Args: 134 spec_or_tensor: A tf.TensorSpec or a tf.Tensor 135 136 Returns: 137 True if spec_or_tensor is compatible with self. 138 """ 139 return super(TensorSpec, self).is_compatible_with(spec_or_tensor) 140 141 @classmethod 142 def from_spec(cls, spec, name=None): 143 """Returns a `TensorSpec` with the same shape and dtype as `spec`. 144 145 >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName") 146 >>> tf.TensorSpec.from_spec(spec, "NewName") 147 TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName') 148 149 Args: 150 spec: The `TypeSpec` used to create the new `TensorSpec`. 151 name: The name for the new `TensorSpec`. Defaults to `spec.name`. 152 """ 153 return cls(spec.shape, spec.dtype, name or spec.name) 154 155 @classmethod 156 def from_tensor(cls, tensor, name=None): 157 """Returns a `TensorSpec` that describes `tensor`. 158 159 >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3])) 160 TensorSpec(shape=(3,), dtype=tf.int32, name=None) 161 162 Args: 163 tensor: The `tf.Tensor` that should be described. 164 name: A name for the `TensorSpec`. Defaults to `tensor.op.name`. 165 166 Returns: 167 A `TensorSpec` that describes `tensor`. 
168 """ 169 if isinstance(tensor, ops.EagerTensor): 170 return TensorSpec(tensor.shape, tensor.dtype, name) 171 elif isinstance(tensor, ops.Tensor): 172 return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name) 173 else: 174 raise ValueError("`tensor` should be a tf.Tensor") 175 176 @property 177 def value_type(self): 178 """The Python type for values that are compatible with this TypeSpec.""" 179 return ops.Tensor 180 181 def _to_components(self, value): 182 try: 183 value = ops.convert_to_tensor(value, self._dtype) 184 except (TypeError, ValueError): 185 raise ValueError("Value %r is not convertible to a tensor with dtype %s " 186 "and shape %s." % (value, self._dtype, self._shape)) 187 if not value.shape.is_compatible_with(self._shape): 188 raise ValueError("Value %r is not convertible to a tensor with dtype %s " 189 "and shape %s." % (value, self._dtype, self._shape)) 190 return value 191 192 def _from_components(self, components): 193 return components 194 195 def _from_compatible_tensor_list(self, tensor_list): 196 # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()` 197 # op here and return that, instead of mutating the input's shape using 198 # `Tensor.set_shape()`. However, that would add extra ops, which could 199 # impact performance. When this bug is resolved, we should be able to add 200 # the `ensure_shape()` ops and optimize them away using contextual shape 201 # information. 
202 assert len(tensor_list) == 1 203 tensor_list[0].set_shape(self._shape) 204 return tensor_list[0] 205 206 def _to_batchable_tensor_list(self, value, batched=False): 207 if batched and self._shape.merge_with(value.shape).ndims == 0: 208 raise ValueError("Unbatching a tensor is only supported for rank >= 1") 209 return self._to_components(value) 210 211 def _batch(self, batch_size): 212 return TensorSpec( 213 tensor_shape.TensorShape([batch_size]).concatenate(self._shape), 214 self._dtype) 215 216 def _unbatch(self): 217 if self._shape.ndims == 0: 218 raise ValueError("Unbatching a tensor is only supported for rank >= 1") 219 return TensorSpec(self._shape[1:], self._dtype) 220 221 222# TODO(b/133606651): Should is_compatible_with should check min/max bounds? 223@type_spec.register("tf.BoundedTensorSpec") 224class BoundedTensorSpec(TensorSpec): 225 """A `TensorSpec` that specifies minimum and maximum values. 226 227 Example usage: 228 ```python 229 spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5)) 230 tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype) 231 tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype) 232 ``` 233 234 Bounds are meant to be inclusive. This is especially important for 235 integer types. The following spec will be satisfied by tensors 236 with values in the set {0, 1, 2}: 237 ```python 238 spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2) 239 ``` 240 """ 241 242 __slots__ = ("_minimum", "_maximum") 243 244 def __init__(self, shape, dtype, minimum, maximum, name=None): 245 """Initializes a new `BoundedTensorSpec`. 246 247 Args: 248 shape: Value convertible to `tf.TensorShape`. The shape of the tensor. 249 dtype: Value convertible to `tf.DType`. The type of the tensor values. 250 minimum: Number or sequence specifying the minimum element bounds 251 (inclusive). Must be broadcastable to `shape`. 252 maximum: Number or sequence specifying the maximum element bounds 253 (inclusive). 
Must be broadcastable to `shape`. 254 name: Optional string containing a semantic name for the corresponding 255 array. Defaults to `None`. 256 257 Raises: 258 ValueError: If `minimum` or `maximum` are not provided or not 259 broadcastable to `shape`. 260 TypeError: If the shape is not an iterable or if the `dtype` is an invalid 261 numpy dtype. 262 """ 263 super(BoundedTensorSpec, self).__init__(shape, dtype, name) 264 265 if minimum is None or maximum is None: 266 raise ValueError("minimum and maximum must be provided; but saw " 267 "'%s' and '%s'" % (minimum, maximum)) 268 269 try: 270 minimum_shape = np.shape(minimum) 271 common_shapes.broadcast_shape( 272 tensor_shape.TensorShape(minimum_shape), self.shape) 273 except ValueError as exception: 274 raise ValueError("minimum is not compatible with shape. " 275 "Message: {!r}.".format(exception)) 276 277 try: 278 maximum_shape = np.shape(maximum) 279 common_shapes.broadcast_shape( 280 tensor_shape.TensorShape(maximum_shape), self.shape) 281 except ValueError as exception: 282 raise ValueError("maximum is not compatible with shape. " 283 "Message: {!r}.".format(exception)) 284 285 self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype) 286 self._minimum.setflags(write=False) 287 288 self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype) 289 self._maximum.setflags(write=False) 290 291 @classmethod 292 def from_spec(cls, spec): 293 """Returns a `TensorSpec` with the same shape and dtype as `spec`. 294 295 If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to 296 `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to 297 `spec.dtype.min` and `spec.dtype.max`. 
298 299 >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x") 300 >>> BoundedTensorSpec.from_spec(spec) 301 BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x', 302 minimum=array(-2147483648, dtype=int32), 303 maximum=array(2147483647, dtype=int32)) 304 305 Args: 306 spec: The `TypeSpec` used to create the new `BoundedTensorSpec`. 307 """ 308 dtype = dtypes.as_dtype(spec.dtype) 309 minimum = getattr(spec, "minimum", dtype.min) 310 maximum = getattr(spec, "maximum", dtype.max) 311 return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name) 312 313 @property 314 def minimum(self): 315 """Returns a NumPy array specifying the minimum bounds (inclusive).""" 316 return self._minimum 317 318 @property 319 def maximum(self): 320 """Returns a NumPy array specifying the maximum bounds (inclusive).""" 321 return self._maximum 322 323 def __repr__(self): 324 s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})" 325 return s.format(self.shape, repr(self.dtype), repr(self.name), 326 repr(self.minimum), repr(self.maximum)) 327 328 def __eq__(self, other): 329 tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other) 330 return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and 331 np.allclose(self.maximum, other.maximum)) 332 333 def __hash__(self): 334 return hash((self._shape_tuple, self.dtype)) 335 336 def __reduce__(self): 337 return BoundedTensorSpec, (self._shape, self._dtype, self._minimum, 338 self._maximum, self._name) 339 340 def _serialize(self): 341 return (self._shape, self._dtype, self._minimum, self._maximum, self._name) 342 343 344_pywrap_utils.RegisterType("TensorSpec", TensorSpec) 345 346 347# Note: we do not include Tensor names when constructing TypeSpecs. 
def _unnamed_tensor_spec(value):
  """Returns an unnamed `TensorSpec` built from `value.shape`/`value.dtype`.

  Used as the TypeSpec converter for both `ops.Tensor` and `np.ndarray`.
  The value's name (if any) is intentionally not carried over.
  """
  return TensorSpec(value.shape, value.dtype)


type_spec.register_type_spec_from_value_converter(
    ops.Tensor, _unnamed_tensor_spec)
type_spec.register_type_spec_from_value_converter(
    np.ndarray, _unnamed_tensor_spec)