• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A TensorSpec class."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.framework import common_shapes
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import tensor_shape
27from tensorflow.python.framework import type_spec
28from tensorflow.python.util import _pywrap_utils
29from tensorflow.python.util.tf_export import tf_export
30
31
class DenseSpec(type_spec.TypeSpec):
  """Type specification for a dense value: a shape, a dtype and a name."""

  __slots__ = ["_shape", "_shape_tuple", "_dtype", "_name"]

  @property
  def _component_specs(self):
    """The component spec of a dense spec is the spec itself."""
    return self

  def __init__(self, shape, dtype=dtypes.float32, name=None):
    """Builds the spec from a shape, a dtype and an optional name.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      name: Optional name for the Tensor.

    Raises:
      TypeError: If shape is not convertible to a `tf.TensorShape`, or dtype is
        not convertible to a `tf.DType`.
    """
    self._shape = tensor_shape.TensorShape(shape)
    # Keep a plain-tuple copy of the dimensions for `__eq__` and `__hash__`.
    # When `as_list()` raises ValueError the cached value is None instead.
    try:
      dims = tuple(self.shape.as_list())
    except ValueError:
      dims = None
    self._shape_tuple = dims
    self._dtype = dtypes.as_dtype(dtype)
    self._name = name

  @property
  def shape(self):
    """The `TensorShape` stored in this spec."""
    return self._shape

  @property
  def dtype(self):
    """The `dtype` stored in this spec."""
    return self._dtype

  @property
  def name(self):
    """The name given at construction time, or None if there was none."""
    return self._name

  def is_compatible_with(self, spec_or_value):
    """Returns True if `spec_or_value` matches this spec's dtype and shape.

    Args:
      spec_or_value: A `DenseSpec` or an instance of `self.value_type`.

    Returns:
      False for any other kind of object; otherwise whether both the dtype and
      the shape report compatibility.
    """
    if not isinstance(spec_or_value, (DenseSpec, self.value_type)):
      return False
    return (self._dtype.is_compatible_with(spec_or_value.dtype) and
            self._shape.is_compatible_with(spec_or_value.shape))

  def __repr__(self):
    fields = (type(self).__name__, self.shape, repr(self.dtype),
              repr(self.name))
    return "{}(shape={}, dtype={}, name={})".format(*fields)

  def __hash__(self):
    # The name is deliberately left out of the key; `__eq__` still compares it.
    key = (self._shape_tuple, self.dtype)
    return hash(key)

  def __eq__(self, other):
    # Only instances of exactly the same class can be equal.
    if type(self) is not type(other):
      return False
    # pylint: disable=protected-access
    return (self._shape_tuple == other._shape_tuple
            and self._dtype == other._dtype
            and self._name == other._name)

  def __ne__(self, other):
    is_equal = self == other
    return not is_equal

  def most_specific_compatible_type(self, other):
    """Returns a spec of the same class that is compatible with both inputs.

    Args:
      other: The spec to combine with `self`.

    Returns:
      A new spec whose shape is the most specific shape compatible with both
      shapes. The name is kept only when both names are equal.

    Raises:
      ValueError: If the two specs differ in class or in dtype.
    """
    mismatch = (type(self) is not type(other)) or (self._dtype != other.dtype)
    if mismatch:
      raise ValueError("Types are not compatible: %r vs %r" % (self, other))
    merged_shape = self._shape.most_specific_compatible_shape(other.shape)
    if self._name == other.name:
      merged_name = self._name
    else:
      merged_name = None
    return type(self)(merged_shape, self._dtype, merged_name)

  def _serialize(self):
    """Returns the `(shape, dtype, name)` tuple describing this spec."""
    return (self._shape, self._dtype, self._name)

  def _to_legacy_output_types(self):
    """Legacy output type: the dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Legacy output shape: the shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Legacy output class: `self.value_type`."""
    return self.value_type
114
115
@tf_export("TensorSpec")
@type_spec.register("tf.TensorSpec")
class TensorSpec(DenseSpec, type_spec.BatchableTypeSpec):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
  by some TensorFlow APIs.
  """

  __slots__ = []

  def is_compatible_with(self, spec_or_tensor):  # pylint:disable=useless-super-delegation
    """Returns True if spec_or_tensor is compatible with this TensorSpec.

    Two tensors are considered compatible if they have the same dtype
    and their shapes are compatible (see `tf.TensorShape.is_compatible_with`).

    Args:
      spec_or_tensor: A tf.TensorSpec or a tf.Tensor

    Returns:
      True if spec_or_tensor is compatible with self.
    """
    # Pure delegation; the override exists to carry the docstring above.
    return super(TensorSpec, self).is_compatible_with(spec_or_tensor)

  @classmethod
  def from_spec(cls, spec, name=None):
    """Returns a `TensorSpec` with the same shape and dtype as `spec`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="OriginalName")
    >>> tf.TensorSpec.from_spec(spec, "NewName")
    TensorSpec(shape=(8, 3), dtype=tf.int32, name='NewName')

    Args:
      spec: The `TypeSpec` used to create the new `TensorSpec`.
      name: The name for the new `TensorSpec`.  Defaults to `spec.name`.
    """
    shape = spec.shape
    dtype = spec.dtype
    # Any falsy `name` (None, "") falls back to the name of `spec`.
    if not name:
      name = spec.name
    return cls(shape, dtype, name)

  @classmethod
  def from_tensor(cls, tensor, name=None):
    """Returns a `TensorSpec` that describes `tensor`.

    >>> tf.TensorSpec.from_tensor(tf.constant([1, 2, 3]))
    TensorSpec(shape=(3,), dtype=tf.int32, name=None)

    Args:
      tensor: The `tf.Tensor` that should be described.
      name: A name for the `TensorSpec`.  Defaults to `tensor.op.name`.

    Returns:
      A `TensorSpec` that describes `tensor`.

    Raises:
      ValueError: If `tensor` is neither an `ops.EagerTensor` nor an
        `ops.Tensor`.
    """
    is_eager = isinstance(tensor, ops.EagerTensor)
    if not (is_eager or isinstance(tensor, ops.Tensor)):
      raise ValueError("`tensor` should be a tf.Tensor")
    shape = tensor.shape
    dtype = tensor.dtype
    # The op name is consulted only for non-eager tensors, and only when the
    # caller did not supply a (truthy) name.
    if not (is_eager or name):
      name = tensor.op.name
    return TensorSpec(shape, dtype, name)

  @property
  def value_type(self):
    """The Python type of the values this spec describes (`ops.Tensor`)."""
    return ops.Tensor

  def _to_components(self, value):
    """Converts `value` to a tensor of this spec's dtype and checks its shape.

    Args:
      value: The value to convert.

    Returns:
      The converted tensor.

    Raises:
      ValueError: If the conversion fails or the converted tensor's shape is
        not compatible with this spec's shape.
    """
    try:
      tensor = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError):
      raise ValueError("Value %r is not convertible to a tensor with dtype %s "
                       "and shape %s." % (value, self._dtype, self._shape))
    if tensor.shape.is_compatible_with(self._shape):
      return tensor
    raise ValueError("Value %r is not convertible to a tensor with dtype %s "
                     "and shape %s." % (tensor, self._dtype, self._shape))

  def _from_components(self, components):
    """Returns `components` unchanged."""
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    """Returns the single tensor in `tensor_list` after setting its shape."""
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    only_tensor = tensor_list[0]
    only_tensor.set_shape(self._shape)
    return only_tensor

  def _to_batchable_tensor_list(self, value, batched=False):
    """Returns `value` converted via `_to_components`.

    Args:
      value: The value to convert.
      batched: When true, a merged shape of rank 0 is rejected first.

    Raises:
      ValueError: If `batched` is true and the merged shape has rank 0.
    """
    if batched:
      merged_shape = self._shape.merge_with(value.shape)
      if merged_shape.ndims == 0:
        raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    """Returns an unnamed spec with `batch_size` prepended to the shape."""
    leading_dim = tensor_shape.TensorShape([batch_size])
    batched_shape = leading_dim.concatenate(self._shape)
    return TensorSpec(batched_shape, self._dtype)

  def _unbatch(self):
    """Returns an unnamed spec with the leading dimension removed.

    Raises:
      ValueError: If this spec's shape has rank 0.
    """
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    element_shape = self._shape[1:]
    return TensorSpec(element_shape, self._dtype)
220
221
# TODO(b/133606651): Should is_compatible_with check min/max bounds?
@type_spec.register("tf.BoundedTensorSpec")
class BoundedTensorSpec(TensorSpec):
  """A `TensorSpec` that specifies minimum and maximum values.

  Example usage:
  ```python
  spec = tensor_spec.BoundedTensorSpec((1, 2, 3), tf.float32, 0, (5, 5, 5))
  tf_minimum = tf.convert_to_tensor(spec.minimum, dtype=spec.dtype)
  tf_maximum = tf.convert_to_tensor(spec.maximum, dtype=spec.dtype)
  ```

  Bounds are meant to be inclusive. This is especially important for
  integer types. The following spec will be satisfied by tensors
  with values in the set {0, 1, 2}:
  ```python
  spec = tensor_spec.BoundedTensorSpec((3, 5), tf.int32, 0, 2)
  ```
  """

  __slots__ = ("_minimum", "_maximum")

  def __init__(self, shape, dtype, minimum, maximum, name=None):
    """Initializes a new `BoundedTensorSpec`.

    Args:
      shape: Value convertible to `tf.TensorShape`. The shape of the tensor.
      dtype: Value convertible to `tf.DType`. The type of the tensor values.
      minimum: Number or sequence specifying the minimum element bounds
        (inclusive). Must be broadcastable to `shape`.
      maximum: Number or sequence specifying the maximum element bounds
        (inclusive). Must be broadcastable to `shape`.
      name: Optional string containing a semantic name for the corresponding
        array. Defaults to `None`.

    Raises:
      ValueError: If `minimum` or `maximum` are not provided or not
        broadcastable to `shape`.
      TypeError: If the shape is not an iterable or if the `dtype` is an invalid
        numpy dtype.
    """
    super(BoundedTensorSpec, self).__init__(shape, dtype, name)

    if minimum is None or maximum is None:
      raise ValueError("minimum and maximum must be provided; but saw "
                       "'%s' and '%s'" % (minimum, maximum))

    # Each bound is validated by attempting to broadcast its shape against the
    # spec's shape; the broadcast result itself is discarded.
    try:
      minimum_shape = np.shape(minimum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(minimum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("minimum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    try:
      maximum_shape = np.shape(maximum)
      common_shapes.broadcast_shape(
          tensor_shape.TensorShape(maximum_shape), self.shape)
    except ValueError as exception:
      raise ValueError("maximum is not compatible with shape. "
                       "Message: {!r}.".format(exception))

    # The bounds are stored as read-only NumPy arrays of the spec's dtype.
    self._minimum = np.array(minimum, dtype=self.dtype.as_numpy_dtype)
    self._minimum.setflags(write=False)

    self._maximum = np.array(maximum, dtype=self.dtype.as_numpy_dtype)
    self._maximum.setflags(write=False)

  @classmethod
  def from_spec(cls, spec):
    """Returns a `BoundedTensorSpec` with the same shape and dtype as `spec`.

    If `spec` is a `BoundedTensorSpec`, then the new spec's bounds are set to
    `spec.minimum` and `spec.maximum`; otherwise, the bounds are set to
    `spec.dtype.min` and `spec.dtype.max`.

    >>> spec = tf.TensorSpec(shape=[8, 3], dtype=tf.int32, name="x")
    >>> BoundedTensorSpec.from_spec(spec)
    BoundedTensorSpec(shape=(8, 3), dtype=tf.int32, name='x',
        minimum=array(-2147483648, dtype=int32),
        maximum=array(2147483647, dtype=int32))

    Args:
      spec: The `TypeSpec` used to create the new `BoundedTensorSpec`.

    Returns:
      A new `BoundedTensorSpec` carrying the shape, dtype and name of `spec`.
    """
    dtype = dtypes.as_dtype(spec.dtype)
    # Fall back to the dtype's limits only when `spec` has no bound of its own.
    # `DType.min`/`DType.max` are documented to raise TypeError for dtypes
    # without a numeric range, so looking them up unconditionally would make
    # copying an already-bounded spec of such a dtype fail needlessly.
    try:
      minimum = spec.minimum
    except AttributeError:
      minimum = dtype.min
    try:
      maximum = spec.maximum
    except AttributeError:
      maximum = dtype.max
    return BoundedTensorSpec(spec.shape, dtype, minimum, maximum, spec.name)

  @property
  def minimum(self):
    """Returns a NumPy array specifying the minimum bounds (inclusive)."""
    return self._minimum

  @property
  def maximum(self):
    """Returns a NumPy array specifying the maximum bounds (inclusive)."""
    return self._maximum

  def __repr__(self):
    s = "BoundedTensorSpec(shape={}, dtype={}, name={}, minimum={}, maximum={})"
    return s.format(self.shape, repr(self.dtype), repr(self.name),
                    repr(self.minimum), repr(self.maximum))

  def __eq__(self, other):
    # The base comparison already requires `other` to be of exactly the same
    # class, so `other.minimum`/`other.maximum` are only read when they exist.
    # Bounds are compared approximately via `np.allclose`.
    tensor_spec_eq = super(BoundedTensorSpec, self).__eq__(other)
    return (tensor_spec_eq and np.allclose(self.minimum, other.minimum) and
            np.allclose(self.maximum, other.maximum))

  def __hash__(self):
    # Bounds are not part of the hash key; specs that compare equal share the
    # same shape tuple and dtype, so they still hash alike.
    return hash((self._shape_tuple, self.dtype))

  def __reduce__(self):
    # Pickle support: rebuild through the constructor with all five arguments.
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    """Returns the `(shape, dtype, minimum, maximum, name)` tuple."""
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
342
343
# Register the `TensorSpec` class with `_pywrap_utils` under the name
# "TensorSpec".
_pywrap_utils.RegisterType("TensorSpec", TensorSpec)


# Value -> TypeSpec converters. Only the shape and dtype are forwarded, so the
# resulting specs never carry a name (Tensor names are intentionally dropped).
type_spec.register_type_spec_from_value_converter(
    ops.Tensor,
    lambda value: TensorSpec(shape=value.shape, dtype=value.dtype))

type_spec.register_type_spec_from_value_converter(
    np.ndarray,
    lambda ndarray: TensorSpec(shape=ndarray.shape, dtype=ndarray.dtype))
355