• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
16 #define TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
17 
18 #include <vector>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/framework/variant_tensor_data.h"
22 #include "tensorflow/core/util/tensor_ops_util.h"
23 
24 namespace tensorflow {
25 namespace data {
26 
// Type name under which `OptionalVariant` values are recorded inside
// DT_VARIANT tensors; `OptionalVariant::TypeName()` returns this string and
// `OptionalVariant::Decode()` compares against it.
constexpr char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
28 
29 // Stores a DT_VARIANT value representing an Optional with the given value
30 // in the `output_index`^th output of the given kernel execution context.
31 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
32                                       std::vector<Tensor> value);
33 
34 // Stores a DT_VARIANT value representing an Optional with no value
35 // in the `output_index`^th output of the given kernel execution context.
36 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
37 
38 // An `OptionalVariant` can represent either an "actual value" (a tuple of
39 // tensors) or "none", and may be stored in a DT_VARIANT tensor.
40 class OptionalVariant {
41  public:
42   // Create an `OptionalVariant` with no actual value.
OptionalVariant()43   OptionalVariant() : values_(nullptr) {}
44 
45   // Create an `OptionalVariant` with the actual value given by the tuple of
46   // tensors in `values`.
OptionalVariant(std::vector<Tensor> values)47   explicit OptionalVariant(std::vector<Tensor> values) {
48     values_ = std::make_shared<std::vector<Tensor>>(std::move(values));
49   }
50 
OptionalVariant(const OptionalVariant & other)51   OptionalVariant(const OptionalVariant& other) : values_(other.values_) {}
52 
53   // Returns true if `this` represents an actual value.
has_value()54   bool has_value() const { return values_ != nullptr; }
55 
56   // REQUIRES: `this->has_value()` must be true.
get_values()57   const std::vector<Tensor>& get_values() const {
58     DCHECK(values_) << "Tried to get values from an empty OptionalVariant";
59     return *values_;
60   }
61 
62   // Implementations of the necessary methods for using `OptionalVariant`
63   // objects in DT_VARIANT tensors.
TypeName()64   string TypeName() const { return kOptionalVariantTypeName; }
Encode(VariantTensorData * data)65   void Encode(VariantTensorData* data) const {
66     data->set_metadata(values_ != nullptr);
67     if (values_ != nullptr) {
68       for (const auto& t : *values_) {
69         *(data->add_tensors()) = t;
70       }
71     }
72   }
73 
Decode(const VariantTensorData & data)74   bool Decode(const VariantTensorData& data) {
75     if (data.type_name() != TypeName()) {
76       return false;
77     }
78     bool has_value = false;
79     if (!data.get_metadata(&has_value)) {
80       return false;
81     }
82     if (has_value) {
83       values_ = std::make_shared<std::vector<Tensor>>(data.tensors());
84     } else {
85       values_.reset();
86     }
87     return true;
88   }
89 
DebugString()90   string DebugString() const {
91     if (values_) {
92       return strings::StrCat("OptionalVariant<", "values: (",
93                              absl::StrJoin(*values_, ", ",
94                                            [](string* s, const Tensor& elem) {
95                                              *s = elem.DebugString();
96                                            }),
97                              ")>");
98     } else {
99       return strings::StrCat("OptionalVariant<None>");
100     }
101   }
102 
103  private:
104   std::shared_ptr<const std::vector<Tensor>> values_;
105 };
106 
107 template <typename Device>
OptionalZerosLike(OpKernelContext * ctx,const OptionalVariant & x,OptionalVariant * y)108 Status OptionalZerosLike(OpKernelContext* ctx, const OptionalVariant& x,
109                          OptionalVariant* y) {
110   if (!x.has_value()) {
111     *y = x;
112     return Status::OK();
113   }
114   std::vector<Tensor> zero_tensors;
115   for (const Tensor& tensor : x.get_values()) {
116     Tensor zero_t;
117     TF_RETURN_IF_ERROR(ZerosLikeTensor<Device>(ctx, tensor, &zero_t));
118     zero_tensors.push_back(std::move(zero_t));
119   }
120   *y = OptionalVariant(zero_tensors);
121   return Status::OK();
122 }
123 
124 template <typename Device>
OptionalBinaryAdd(OpKernelContext * ctx,const OptionalVariant & a,const OptionalVariant & b,OptionalVariant * out)125 Status OptionalBinaryAdd(OpKernelContext* ctx, const OptionalVariant& a,
126                          const OptionalVariant& b, OptionalVariant* out) {
127   // TODO(skyewm): should adding a value to a non-value be a no-op instead?
128   if (a.has_value() != b.has_value()) {
129     return errors::InvalidArgument(
130         "Cannot add optionals because one has a value and the other doesn't.");
131   }
132   if (!a.has_value()) {
133     *out = a;
134     return Status::OK();
135   }
136   if (a.get_values().size() != b.get_values().size()) {
137     return errors::InvalidArgument(
138         "Cannot add optionals because they have different numbers of "
139         "components (",
140         a.get_values().size(), " vs. ", b.get_values().size(), ").");
141   }
142   std::vector<Tensor> out_tensors;
143   for (int i = 0; i < a.get_values().size(); ++i) {
144     const Tensor& a_tensor = a.get_values()[i];
145     const Tensor& b_tensor = b.get_values()[i];
146     Tensor out_tensor;
147     TF_RETURN_IF_ERROR(
148         BinaryAddTensors<Device>(ctx, a_tensor, b_tensor, &out_tensor));
149     out_tensors.push_back(std::move(out_tensor));
150   }
151   *out = OptionalVariant(out_tensors);
152   return Status::OK();
153 }
154 
155 class OptionalNoneOp : public OpKernel {
156  public:
OptionalNoneOp(OpKernelConstruction * ctx)157   explicit OptionalNoneOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
158 
159   void Compute(OpKernelContext* ctx) override;
160 };
161 
162 class OptionalFromValueOp : public OpKernel {
163  public:
OptionalFromValueOp(OpKernelConstruction * ctx)164   explicit OptionalFromValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
165 
166   void Compute(OpKernelContext* ctx) override;
167 };
168 
169 class OptionalHasValueOp : public OpKernel {
170  public:
OptionalHasValueOp(OpKernelConstruction * ctx)171   explicit OptionalHasValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
172 
173   void Compute(OpKernelContext* ctx) override;
174 };
175 
176 class OptionalGetValueOp : public OpKernel {
177  public:
OptionalGetValueOp(OpKernelConstruction * ctx)178   explicit OptionalGetValueOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
179     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
180     OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
181     OP_REQUIRES(
182         ctx, output_shapes_.size() == output_types_.size(),
183         errors::InvalidArgument(
184             "output_types and output_shapes must be same length, got:\n",
185             "output_types: ", output_types_.size(), "\n",
186             "output_shapes: ", output_shapes_.size()));
187   }
188 
189   void Compute(OpKernelContext* ctx) override;
190 
191  private:
192   DataTypeVector output_types_;
193   std::vector<PartialTensorShape> output_shapes_;
194 };
195 
196 }  // namespace data
197 }  // namespace tensorflow
198 
199 #endif  // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
200