• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A TensorSpec class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.util.tf_export import tf_export
28
29
@tf_export("TensorSpec")
class TensorSpec(object):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  # `_shape_tuple` caches the shape as a plain tuple (None when the rank is
  # unknown) so `__eq__` and `__hash__` can compare/hash it cheaply.
  __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Creates a TensorSpec.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    self._shape = tensor_shape.TensorShape(shape)
    try:
      self._shape_tuple = tuple(self.shape.as_list())
    except ValueError:
      # `as_list()` cannot describe a shape of unknown rank; record that case
      # as None so it still compares and hashes consistently.
      self._shape_tuple = None
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a spec of type `cls` with the same shape and dtype as `spec`.

    Args:
      spec: An object with `shape`, `dtype` and `name` attributes.
      name: Optional name for the new spec. A falsy value falls back to
        `spec.name`.
    """
    return cls(spec.shape, spec.dtype, name or spec.name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` describing the shape and dtype of `tensor`.

    Args:
      tensor: A `tf.Tensor` (graph or eager).
      name: Optional name for the spec. For graph tensors a falsy value falls
        back to `tensor.op.name`; eager tensors only use the explicit `name`.

    Raises:
      ValueError: If `tensor` is not a `tf.Tensor`.
    """
    # The eager check comes first: eager tensors are handled without touching
    # `tensor.op` (presumably not meaningful in eager mode -- see TF docs).
    if isinstance(tensor, ops.EagerTensor):
      return TensorSpec(tensor.shape, tensor.dtype, name)
    elif isinstance(tensor, ops.Tensor):
      return TensorSpec(tensor.shape, tensor.dtype, name or tensor.op.name)
    else:
      raise ValueError("`tensor` should be a tf.Tensor")

  @property
  def shape(self):
    """Returns the `TensorShape` that represents the shape of the tensor."""
    return self._shape

  @property
  def dtype(self):
    """Returns the `dtype` of elements in the tensor."""
    return self._dtype

  @property
  def name(self):
    """Returns the (optionally provided) name of the described tensor."""
    return self._name

  def is_compatible_with(self, spec_or_tensor):
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
            self._shape.is_compatible_with(spec_or_tensor.shape))

  def __repr__(self):
    return "TensorSpec(shape={}, dtype={}, name={})".format(
        self.shape, repr(self.dtype), repr(self.name))

  def __hash__(self):
    """Hashes the shape tuple and dtype.

    The name is deliberately excluded; this stays consistent with `__eq__`
    because equal specs always share shape tuple and dtype.
    """
    return hash((self._shape_tuple, self.dtype))

  def __eq__(self, other):
    """Returns True if `other` is a TensorSpec with equal shape, dtype, name.

    Values that are not `TensorSpec` instances compare unequal rather than
    raising AttributeError on the missing private fields.
    """
    if not isinstance(other, TensorSpec):
      return False
    return (self._shape_tuple == other._shape_tuple  # pylint: disable=protected-access
            and self.dtype == other.dtype
            and self._name == other._name)  # pylint: disable=protected-access

  def __ne__(self, other):
    # Python 2 does not derive `!=` from `__eq__`, so spell it out.
    return not self == other

  def __reduce__(self):
    """Pickle support: rebuilds the spec from shape, dtype and name."""
    return TensorSpec, (self._shape, self._dtype, self._name)
120
121
class BoundedTensorSpec(TensorSpec):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None or maximum is None:
      raise ValueError("minimum and maximum must be provided; but saw "
                       "'%s' and '%s'" % (minimum, maximum))

    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("minimum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("maximum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    # `DType.as_numpy_dtype` is a property holding the NumPy type itself.
    # Calling it would instantiate a value of that type instead; that only
    # happens to work for plain numeric scalar types.
    np_dtype = self.dtype.as_numpy_dtype

    self._minimum = np.array(minimum, dtype=np_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=np_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `BoundedTensorSpec` built from `spec`.

    Bounds are copied from `spec.minimum` / `spec.maximum` when present;
    otherwise the dtype's full range (`dtype.min` / `dtype.max`) is used.

    Args:
      spec: An object with `shape`, `dtype` and `name` attributes, and
        optionally `minimum` and `maximum` (e.g. a `TensorSpec` or a
        `BoundedTensorSpec`).
      name: Optional name for the new spec. A falsy value falls back to
        `spec.name`, matching `TensorSpec.from_spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Only consult the dtype range when a bound is actually missing, so a spec
    # that already carries bounds never depends on `dtype.min`/`dtype.max`.
    minimum = spec.minimum if hasattr(spec, "minimum") else dtype.min
    maximum = spec.maximum if hasattr(spec, "maximum") else dtype.max
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum,
                             name or spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  @staticmethod
  def _bounds_equal(a, b):
    """Returns True if bound arrays `a` and `b` are (broadcast-)equal.

    Floating-point bounds are compared with `np.allclose`. All other dtypes
    are compared exactly, because `np.allclose`'s default relative tolerance
    would otherwise treat e.g. integer bounds 100000 and 100001 as equal.
    """
    if (np.issubdtype(a.dtype, np.inexact) or
        np.issubdtype(b.dtype, np.inexact)):
      return bool(np.allclose(a, b))
    return bool(np.all(np.equal(a, b)))

  def __eq__(self, other):
    """Returns True if `other` is a `BoundedTensorSpec` with equal fields.

    Shape, dtype and name must match as in `TensorSpec`, and the minimum and
    maximum bounds must agree. Anything that is not a `BoundedTensorSpec`
    (including a plain `TensorSpec`) compares unequal instead of raising.
    """
    if not isinstance(other, BoundedTensorSpec):
      return False
    if not super(BoundedTensorSpec, self).__eq__(other):
      return False
    return (self._bounds_equal(self.minimum, other.minimum) and
            self._bounds_equal(self.maximum, other.maximum))

  def __hash__(self):
    """Hashes like `TensorSpec` (shape tuple and dtype only).

    Overriding `__eq__` implicitly sets `__hash__` to None on Python 3, which
    would make instances unhashable, so the base hash is restored explicitly.
    Bounds are excluded, which stays consistent with `__eq__`.
    """
    return super(BoundedTensorSpec, self).__hash__()

  def __reduce__(self):
    """Pickle support: rebuilds the spec from shape, dtype, bounds and name."""
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)
219