# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TensorSpec class."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util.tf_export import tf_export


@tf_export("TensorSpec")
class TensorSpec(object):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Creates a TensorSpec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    self._shape = tensor_shape.TensorShape(shape)
    # Cache a plain tuple of the dimensions for cheap hashing and equality.
    # When `as_list()` raises ValueError (the shape cannot be listed), fall
    # back to None.
    try:
      self._shape_tuple = tuple(self.shape.as_list())
    except ValueError:
      self._shape_tuple = None
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @classmethod
  def from_spec(cls, spec, name=None):
    """Creates a new spec of type `cls` from the shape and dtype of `spec`.

    Args:
      spec: An object exposing `shape`, `dtype` and `name` attributes.
      name: Optional name for the new spec. When falsy, `spec.name` is used.

    Returns:
      A new `cls` instance.
    """
    return cls(spec.shape, spec.dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Creates a `TensorSpec` describing the given tensor.

    Args:
      tensor: An `ops.EagerTensor` or `ops.Tensor`.
      name: Optional name for the spec. For an `ops.Tensor` that is not an
        `ops.EagerTensor`, a falsy `name` is replaced by `tensor.op.name`; for
        an `ops.EagerTensor` the given `name` is used as is.

    Returns:
      A `TensorSpec` with the shape and dtype of `tensor`.

    Raises:
      ValueError: If `tensor` is neither an `ops.EagerTensor` nor an
        `ops.Tensor`.
    """
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError("`tensor` should be a tf.Tensor")

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """Returns the `dtype` of elements in the tensor."""
    return self._dtype

  @property
  def name(self):
    """Returns the (optionally provided) name of the described tensor."""
    return self._name

  def is_compatible_with(self, spec_or_tensor):
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
            self._shape.is_compatible_with(spec_or_tensor.shape))

  def __repr__(self):
    return "TensorSpec(shape={}, dtype={}, name={})".format(
        self.shape, repr(self.dtype), repr(self.name))

  def __hash__(self):
    # The name is not part of the hash: specs that differ only by name hash
    # alike but still compare unequal, which keeps the hash/eq contract.
    return hash((self._shape_tuple, self.dtype))

  def __eq__(self, other):
    """Returns True if `other` has the same shape tuple, dtype and name.

    Returns `NotImplemented` when `other` is not a `TensorSpec`, so comparing
    against an unrelated object evaluates to False instead of raising
    `AttributeError`.
    """
    if not isinstance(other, TensorSpec):
      return NotImplemented
    # pylint: disable=protected-access
    return (self._shape_tuple == other._shape_tuple
            and self.dtype == other.dtype
            and self._name == other._name)

  def __ne__(self, other):
    return not self == other

  def __reduce__(self):
    # Pickle support: rebuild the spec from its constructor arguments.
    return TensorSpec, (self._shape, self._dtype, self._name)


class BoundedTensorSpec(TensorSpec):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None or maximum is None:
      raise ValueError("minimum and maximum must be provided; but saw "
                       "'%s' and '%s'" % (minimum, maximum))

    # The broadcast result is discarded; the call is made only for the
    # ValueError it raises when a bound cannot be broadcast to the spec shape.
    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("minimum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("maximum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    # `as_numpy_dtype` is a property that already returns the NumPy type, so
    # it is passed directly rather than called. Calling it would hand np.array
    # a scalar instance, which is wrong for object-backed dtypes.
    numpy_dtype = self.dtype.as_numpy_dtype

    # Bounds are stored as read-only arrays so the spec stays immutable.
    self._minimum = np.array(minimum, dtype=numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def from_spec(cls, spec):
    """Creates a `BoundedTensorSpec` from another spec.

    Args:
      spec: An object exposing `shape`, `dtype` and `name`. Its `minimum` and
        `maximum` attributes are used when present; otherwise the bounds
        default to the `min` and `max` of the dtype.

    Returns:
      A `BoundedTensorSpec` with the shape, dtype and name of `spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    minimum = getattr(spec, "minimum", dtype.min)
    maximum = getattr(spec, "maximum", dtype.max)
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __hash__(self):
    # Overriding __eq__ makes Python 3 set __hash__ to None, so hashing has to
    # be restored explicitly. The bounds are left out because they are
    # compared with np.allclose (approximate), and equal specs must hash
    # alike; shape and dtype are compared exactly, so the base hash is safe.
    return super(BoundedTensorSpec, self).__hash__()

  def __eq__(self, other):
    """Returns True if `other` is an equal spec with (approximately) equal bounds.

    Returns `NotImplemented` when `other` is not a `TensorSpec` at all, and
    False when it is a `TensorSpec` without bounds, rather than raising
    `AttributeError` on the missing `minimum`/`maximum` attributes.
    """
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    if tensor_spec_eq is NotImplemented:
      return NotImplemented
    if not tensor_spec_eq or not isinstance(other, BoundedTensorSpec):
      return False
    return (np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __reduce__(self):
    # Pickle support: rebuild the spec from its constructor arguments.
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)