• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""A type for representing values that may or may not exist."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import abc
21
22import six
23
24from tensorflow.python.data.util import structure
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import ops
28from tensorflow.python.framework import tensor_spec
29from tensorflow.python.framework import type_spec
30from tensorflow.python.ops import gen_dataset_ops
31from tensorflow.python.util import deprecation
32from tensorflow.python.util.tf_export import tf_export
33
34
@tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
@six.add_metaclass(abc.ABCMeta)
class Optional(composite_tensor.CompositeTensor):
  """Represents a value that may or may not be present.

  A `tf.experimental.Optional` can represent the result of an operation that may
  fail as a value, rather than raising an exception and halting execution. For
  example, `tf.data.Iterator.get_next_as_optional()` returns a
  `tf.experimental.Optional` that either contains the next element of an
  iterator if one exists, or an "empty" value that indicates the end of the
  sequence has been reached.

  `tf.experimental.Optional` can only be used with values that are convertible
  to `tf.Tensor` or `tf.CompositeTensor`.

  One can create a `tf.experimental.Optional` from a value using the
  `from_value()` method:

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  or without a value using the `empty()` method:

  >>> optional = tf.experimental.Optional.empty(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)

  This class is abstract: `has_value()`, `get_value()` and `element_spec` are
  supplied by a concrete subclass. The `from_value()` and `empty()` factory
  methods return such a concrete instance.
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Returns a tensor that evaluates to `True` if this optional has a value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Returns the value wrapped by this optional.

    If this optional does not have a value (i.e. `self.has_value()` evaluates to
    `False`), this operation will raise `tf.errors.InvalidArgumentError` at
    runtime.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value.
    """
    raise NotImplementedError("Optional.get_value()")

  # `abc.abstractproperty` is deprecated since Python 3.3. Stacking `property`
  # on top of `abc.abstractmethod` is the documented replacement, and `ABCMeta`
  # still treats the resulting property as abstract.
  @property
  @abc.abstractmethod
  def element_spec(self):
    """The type specification of an element of this optional.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.element_spec)
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A nested structure of `tf.TypeSpec` objects matching the structure of an
      element of this optional, specifying the type of individual components.
    """
    raise NotImplementedError("Optional.element_spec")

  @staticmethod
  def empty(element_spec):
    """Returns an `Optional` that has no value.

    NOTE: This method takes an argument that defines the structure of the value
    that would be contained in the returned `Optional` if it had a value.

    >>> optional = tf.experimental.Optional.empty(
    ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Args:
      element_spec: A nested structure of `tf.TypeSpec` objects matching the
        structure of an element of this optional.

    Returns:
      A `tf.experimental.Optional` with no value.
    """
    # The `optional_none` op produces the "no value" encoding. `element_spec`
    # is kept alongside it so callers can still inspect the would-be structure.
    return _OptionalImpl(gen_dataset_ops.optional_none(), element_spec)

  @staticmethod
  def from_value(value):
    """Returns a `tf.experimental.Optional` that wraps the given value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      value: A value to wrap. The value must be convertible to `tf.Tensor` or
        `tf.CompositeTensor`.

    Returns:
      A `tf.experimental.Optional` that wraps `value`.
    """
    # The outer "optional" scope name is captured and reused as the name of
    # the `optional_from_value` op below. Any ops created while converting
    # `value` into a flat tensor list are grouped under a nested "value" scope.
    with ops.name_scope("optional") as scope:
      with ops.name_scope("value"):
        element_spec = structure.type_spec_from_value(value)
        encoded_value = structure.to_tensor_list(element_spec, value)

    return _OptionalImpl(
        gen_dataset_ops.optional_from_value(encoded_value, name=scope),
        element_spec)
164
165
class _OptionalImpl(Optional):
  """Concrete implementation of `tf.experimental.Optional`.

  NOTE(mrry): This implementation is kept private, to avoid defining
  `Optional.__init__()` in the public API.

  The optional itself is held as a single variant tensor. The element spec
  records the structure needed to decode that tensor back into values.
  """

  def __init__(self, variant_tensor, element_spec):
    # Structure used by `get_value()` to decode the flat tensor list.
    self._element_spec = element_spec
    # Encoded optional produced by one of the `optional_*` dataset ops.
    self._variant_tensor = variant_tensor

  def has_value(self, name=None):
    handle = self._variant_tensor
    # Place the presence check on the same device as the variant tensor.
    with ops.colocate_with(handle):
      return gen_dataset_ops.optional_has_value(handle, name=name)

  def get_value(self, name=None):
    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    handle = self._variant_tensor
    spec = self._element_spec
    with ops.name_scope(name, "OptionalGetValue", [handle]) as op_scope:
      with ops.colocate_with(handle):
        # The op needs the flattened dtypes/shapes of the element structure.
        flat_types = structure.get_flat_tensor_types(spec)
        flat_shapes = structure.get_flat_tensor_shapes(spec)
        flat_tensors = gen_dataset_ops.optional_get_value(
            handle,
            name=op_scope,
            output_types=flat_types,
            output_shapes=flat_shapes)
      # NOTE: We do not colocate the deserialization of composite tensors
      # because not all ops are guaranteed to have non-GPU kernels.
      return structure.from_tensor_list(spec, flat_tensors)

  @property
  def element_spec(self):
    return self._element_spec

  @property
  def _type_spec(self):
    # Built on each access from this optional's current `element_spec`.
    return OptionalSpec.from_value(self)
203
204
@tf_export(
    "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
class OptionalSpec(type_spec.TypeSpec):
  """Type specification for `tf.experimental.Optional`.

  For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
  `tf.experimental.Optional` as an input argument:

  >>> @tf.function(input_signature=[tf.OptionalSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def maybe_square(optional):
  ...   if optional.has_value():
  ...     x = optional.get_value()
  ...     return x * x
  ...   return -1
  >>> optional = tf.experimental.Optional.from_value(5)
  >>> print(maybe_square(optional))
  tf.Tensor(25, shape=(), dtype=int32)
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Creates an `OptionalSpec`.

    Args:
      element_spec: A nested structure of `TypeSpec` objects that represents the
        type specification of the optional element.
    """
    # Kept private. This class does not define a public `element_spec`
    # accessor, so the value is documented as a constructor argument rather
    # than as an attribute.
    self._element_spec = element_spec

  @property
  def value_type(self):
    # Values described by this spec are instances of the private concrete
    # implementation.
    return _OptionalImpl

  def _serialize(self):
    # The element spec is the only field this spec carries.
    return (self._element_spec,)

  @property
  def _component_specs(self):
    # An optional is encoded as a single scalar `tf.variant` tensor, whatever
    # its element structure is.
    return [tensor_spec.TensorSpec((), dtypes.variant)]

  def _to_components(self, value):
    return [value._variant_tensor]  # pylint: disable=protected-access

  def _from_components(self, flat_value):
    # `flat_value` follows `_component_specs`: a single variant tensor.
    return _OptionalImpl(flat_value[0], self._element_spec)

  @staticmethod
  def from_value(value):
    """Returns an `OptionalSpec` describing `value`.

    Args:
      value: An optional exposing an `element_spec` property.

    Returns:
      An `OptionalSpec` built from `value.element_spec`.
    """
    return OptionalSpec(value.element_spec)

  # The legacy output types, shapes and classes of an optional are all
  # represented by the spec itself.
  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self
264