1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
16 #define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
17
18 #include <vector>
19
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant_tensor_data.h"
22 #include "tensorflow/core/util/tensor_ops_util.h"
23
24 namespace tensorflow {
25 namespace data {
26
// Type name reported by `OptionalVariant::TypeName()` and checked by
// `OptionalVariant::Decode()`.
constexpr char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
28
29 // Stores a DT_VARIANT value representing an Optional with the given value
30 // in the `output_index`^th output of the given kernel execution context.
31 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
32 std::vector<Tensor> value);
33
34 // Stores a DT_VARIANT value representing an Optional with no value
35 // in the `output_index`^th output of the given kernel execution context.
36 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
37
38 // An `OptionalVariant` can represent either an "actual value" (a tuple of
39 // tensors) or "none", and may be stored in a DT_VARIANT tensor.
40 class OptionalVariant {
41 public:
42 // Create an `OptionalVariant` with no actual value.
OptionalVariant()43 OptionalVariant() : values_(nullptr) {}
44
45 // Create an `OptionalVariant` with the actual value given by the tuple of
46 // tensors in `values`.
OptionalVariant(std::vector<Tensor> values)47 explicit OptionalVariant(std::vector<Tensor> values) {
48 values_ = std::make_shared<std::vector<Tensor>>(std::move(values));
49 }
50
OptionalVariant(const OptionalVariant & other)51 OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
52
53 // Returns true if `this` represents an actual value.
has_value()54 bool has_value() const { return values_ != nullptr; }
55
56 // REQUIRES: `this->has_value()` must be true.
get_values()57 const std::vector<Tensor>& get_values() const {
58 DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
59 return *values_;
60 }
61
62 // Implementations of the necessary methods for using `OptionalVariant`
63 // objects in DT_VARIANT tensors.
TypeName()64 string TypeName() const { return kOptionalVariantTypeName; }
Encode(VariantTensorData * data)65 void Encode(VariantTensorData* data) const {
66 data->set_metadata(values_ != nullptr);
67 if (values_ != nullptr) {
68 for (const auto& t : *values_) {
69 *(data->add_tensors()) = t;
70 }
71 }
72 }
73
Decode(const VariantTensorData & data)74 bool Decode(const VariantTensorData& data) {
75 if (data.type_name() != TypeName()) {
76 return false;
77 }
78 bool has_value = false;
79 if (!data.get_metadata(&has_value)) {
80 return false;
81 }
82 if (has_value) {
83 values_ = std::make_shared<std::vector<Tensor>>(data.tensors());
84 } else {
85 values_.reset();
86 }
87 return true;
88 }
89
DebugString()90 string DebugString() const {
91 if (values_) {
92 return strings::StrCat("OptionalVariant<", "values: (",
93 absl::StrJoin(*values_, ", ",
94 [](string* s, const Tensor& elem) {
95 *s = elem.DebugString();
96 }),
97 ")>");
98 } else {
99 return strings::StrCat("OptionalVariant<None>");
100 }
101 }
102
103 private:
104 std::shared_ptr<const std::vector<Tensor>> values_;
105 };
106
107 template <typename Device>
OptionalZerosLike(OpKernelContext * ctx,const OptionalVariant & x,OptionalVariant * y)108 Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
109 OptionalVariant* y) {
110 if (!x.has_value()) {
111 *y = x;
112 return Status::OK();
113 }
114 std::vector<Tensor> zero_tensors;
115 for (const Tensor& tensor : x.get_values()) {
116 Tensor zero_t;
117 TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
118 zero_tensors.push_back(std::move(zero_t));
119 }
120 *y = OptionalVariant(zero_tensors);
121 return Status::OK();
122 }
123
124 template <typename Device>
OptionalBinaryAdd(OpKernelContext * ctx,const OptionalVariant & a,const OptionalVariant & b,OptionalVariant * out)125 Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
126 const OptionalVariant& b, OptionalVariant* out) {
127 // TODO(skyewm): should adding a value to a non-value be a no-op instead?
128 if (a.has_value() != b.has_value()) {
129 return errors::InvalidArgument(
130 "Cannot add optionals because one has a value and the other doesn't.");
131 }
132 if (!a.has_value()) {
133 *out = a;
134 return Status::OK();
135 }
136 if (a.get_values().size() != b.get_values().size()) {
137 return errors::InvalidArgument(
138 "Cannot add optionals because they have different numbers of "
139 "components (",
140 a.get_values().size(), " vs. ", b.get_values().size(), ").");
141 }
142 std::vector<Tensor> out_tensors;
143 for (int i = 0; i < a.get_values().size(); ++i) {
144 const Tensor& a_tensor = a.get_values()[i];
145 const Tensor& b_tensor = b.get_values()[i];
146 Tensor out_tensor;
147 TF_RETURN_IF_ERROR(
148 BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
149 out_tensors.push_back(std::move(out_tensor));
150 }
151 *out = OptionalVariant(out_tensors);
152 return Status::OK();
153 }
154
155 class OptionalNoneOp : public OpKernel {
156 public:
OptionalNoneOp(OpKernelConstruction * ctx)157 explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
158
159 void Compute(OpKernelContext* ctx) override;
160 };
161
162 class OptionalFromValueOp : public OpKernel {
163 public:
OptionalFromValueOp(OpKernelConstruction * ctx)164 explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
165
166 void Compute(OpKernelContext* ctx) override;
167 };
168
169 class OptionalHasValueOp : public OpKernel {
170 public:
OptionalHasValueOp(OpKernelConstruction * ctx)171 explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
172
173 void Compute(OpKernelContext* ctx) override;
174 };
175
176 class OptionalGetValueOp : public OpKernel {
177 public:
OptionalGetValueOp(OpKernelConstruction * ctx)178 explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
179 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
180 OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
181 OP_REQUIRES(
182 ctx, output_shapes_.size() == output_types_.size(),
183 errors::InvalidArgument(
184 "output_types and output_shapes must be same length, got:\n",
185 "output_types: ", output_types_.size(), "\n",
186 "output_shapes: ", output_shapes_.size()));
187 }
188
189 void Compute(OpKernelContext* ctx) override;
190
191 private:
192 DataTypeVector output_types_;
193 std::vector<PartialTensorShape> output_shapes_;
194 };
195
196 } // namespace data
197 } // namespace tensorflow
198
199 #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
200