# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A type for representing values that may or may not exist."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.data.util import structure
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("experimental.Optional", "data.experimental.Optional")
@deprecation.deprecated_endpoints("data.experimental.Optional")
@six.add_metaclass(abc.ABCMeta)
class Optional(composite_tensor.CompositeTensor):
  """Represents a value that may or may not be present.

  A `tf.experimental.Optional` lets an operation that may fail report its
  outcome as a value instead of raising an exception and halting execution.
  For example, `tf.data.Iterator.get_next_as_optional()` returns a
  `tf.experimental.Optional` that either holds the next element of an
  iterator, if one exists, or is "empty" to signal that the end of the
  sequence has been reached.

  `tf.experimental.Optional` can only be used with values that are convertible
  to `tf.Tensor` or `tf.CompositeTensor`.

  Use the `from_value()` method to build a `tf.experimental.Optional` that
  holds a value:

  >>> optional = tf.experimental.Optional.from_value(42)
  >>> print(optional.has_value())
  tf.Tensor(True, shape=(), dtype=bool)
  >>> print(optional.get_value())
  tf.Tensor(42, shape=(), dtype=int32)

  Use the `empty()` method to build one that holds no value:

  >>> optional = tf.experimental.Optional.empty(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
  >>> print(optional.has_value())
  tf.Tensor(False, shape=(), dtype=bool)
  """

  @abc.abstractmethod
  def has_value(self, name=None):
    """Returns a tensor that evaluates to `True` if this optional has a value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      A scalar `tf.Tensor` of type `tf.bool`.
    """
    raise NotImplementedError("Optional.has_value()")

  @abc.abstractmethod
  def get_value(self, name=None):
    """Returns the value wrapped by this optional.

    When this optional is empty (i.e. `self.has_value()` evaluates to
    `False`), this operation raises `tf.errors.InvalidArgumentError` at
    runtime.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      name: (Optional.) A name for the created operation.

    Returns:
      The wrapped value.
    """
    raise NotImplementedError("Optional.get_value()")

  @abc.abstractproperty
  def element_spec(self):
    """The type specification of an element of this optional.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.element_spec)
    tf.TensorSpec(shape=(), dtype=tf.int32, name=None)

    Returns:
      A nested structure of `tf.TypeSpec` objects that mirrors the structure
      of an element of this optional and gives the type of each component.
    """
    raise NotImplementedError("Optional.element_spec")

  @staticmethod
  def empty(element_spec):
    """Returns an `Optional` that has no value.

    NOTE: The argument describes the structure that the value of the returned
    `Optional` would have, were a value present.

    >>> optional = tf.experimental.Optional.empty(
    ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))
    >>> print(optional.has_value())
    tf.Tensor(False, shape=(), dtype=bool)

    Args:
      element_spec: A nested structure of `tf.TypeSpec` objects matching the
        structure of an element of this optional.

    Returns:
      A `tf.experimental.Optional` with no value.
    """
    variant = gen_dataset_ops.optional_none()
    return _OptionalImpl(variant, element_spec)

  @staticmethod
  def from_value(value):
    """Returns a `tf.experimental.Optional` that wraps the given value.

    >>> optional = tf.experimental.Optional.from_value(42)
    >>> print(optional.has_value())
    tf.Tensor(True, shape=(), dtype=bool)
    >>> print(optional.get_value())
    tf.Tensor(42, shape=(), dtype=int32)

    Args:
      value: A value to wrap. The value must be convertible to `tf.Tensor` or
        `tf.CompositeTensor`.

    Returns:
      A `tf.experimental.Optional` that wraps `value`.
    """
    with ops.name_scope("optional") as scope:
      # The spec inference and flattening happen under the nested "value"
      # scope; the wrapping op itself is named after the outer scope.
      with ops.name_scope("value"):
        spec = structure.type_spec_from_value(value)
        flat_tensors = structure.to_tensor_list(spec, value)

    variant = gen_dataset_ops.optional_from_value(flat_tensors, name=scope)
    return _OptionalImpl(variant, spec)


class _OptionalImpl(Optional):
  """Concrete implementation of `tf.experimental.Optional`.

  NOTE(mrry): This implementation is kept private, to avoid defining
  `Optional.__init__()` in the public API.
  """

  def __init__(self, variant_tensor, element_spec):
    """Stores the backing variant tensor and the spec of the wrapped element.

    Args:
      variant_tensor: The tensor that backs this optional.
      element_spec: A nested structure of `tf.TypeSpec` objects describing the
        element this optional may hold.
    """
    self._element_spec = element_spec
    self._variant_tensor = variant_tensor

  def has_value(self, name=None):
    """See `Optional.has_value()`."""
    variant = self._variant_tensor
    with ops.colocate_with(variant):
      return gen_dataset_ops.optional_has_value(variant, name=name)

  def get_value(self, name=None):
    """See `Optional.get_value()`."""
    # TODO(b/110122868): Consolidate the restructuring logic with similar logic
    # in `Iterator.get_next()` and `StructuredFunctionWrapper`.
    variant = self._variant_tensor
    spec = self._element_spec
    with ops.name_scope(name, "OptionalGetValue", [variant]) as scope:
      with ops.colocate_with(variant):
        flat_tensors = gen_dataset_ops.optional_get_value(
            variant,
            name=scope,
            output_types=structure.get_flat_tensor_types(spec),
            output_shapes=structure.get_flat_tensor_shapes(spec))
      # NOTE: We do not colocate the deserialization of composite tensors
      # because not all ops are guaranteed to have non-GPU kernels.
      return structure.from_tensor_list(spec, flat_tensors)

  @property
  def element_spec(self):
    """See `Optional.element_spec`."""
    return self._element_spec

  @property
  def _type_spec(self):
    """An `OptionalSpec` built from this optional."""
    return OptionalSpec.from_value(self)


@tf_export(
    "OptionalSpec", v1=["OptionalSpec", "data.experimental.OptionalStructure"])
class OptionalSpec(type_spec.TypeSpec):
  """Type specification for `tf.experimental.Optional`.

  For instance, `tf.OptionalSpec` can be used to define a tf.function that takes
  `tf.experimental.Optional` as an input argument:

  >>> @tf.function(input_signature=[tf.OptionalSpec(
  ...   tf.TensorSpec(shape=(), dtype=tf.int32, name=None))])
  ... def maybe_square(optional):
  ...   if optional.has_value():
  ...     x = optional.get_value()
  ...     return x * x
  ...   return -1
  >>> optional = tf.experimental.Optional.from_value(5)
  >>> print(maybe_square(optional))
  tf.Tensor(25, shape=(), dtype=int32)

  Attributes:
    element_spec: A nested structure of `TypeSpec` objects that represents the
      type specification of the optional element.
  """

  __slots__ = ["_element_spec"]

  def __init__(self, element_spec):
    """Records the spec of the element that the described optional may hold.

    Args:
      element_spec: A nested structure of `TypeSpec` objects.
    """
    self._element_spec = element_spec

  @property
  def value_type(self):
    """The Python class of values described by this spec (`_OptionalImpl`)."""
    return _OptionalImpl

  def _serialize(self):
    """Returns a 1-tuple containing the element spec."""
    return (self._element_spec,)

  @property
  def _component_specs(self):
    """A single-element list holding a scalar variant `TensorSpec`."""
    variant_spec = tensor_spec.TensorSpec((), dtypes.variant)
    return [variant_spec]

  def _to_components(self, value):
    """Returns a single-element list with the variant tensor behind `value`."""
    variant = value._variant_tensor  # pylint: disable=protected-access
    return [variant]

  def _from_components(self, flat_value):
    """Rebuilds an optional from `flat_value[0]` and this spec's element spec."""
    # pylint: disable=protected-access
    variant = flat_value[0]
    return _OptionalImpl(variant, self._element_spec)

  @staticmethod
  def from_value(value):
    """Returns an `OptionalSpec` built from `value.element_spec`."""
    return OptionalSpec(value.element_spec)

  def _to_legacy_output_types(self):
    """Returns this spec unchanged."""
    return self

  def _to_legacy_output_shapes(self):
    """Returns this spec unchanged."""
    return self

  def _to_legacy_output_classes(self):
    """Returns this spec unchanged."""
    return self