/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/data/optional_ops.h"

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {
namespace data {
namespace {

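// Copies an OptionalVariant between devices. Components that are
// DMA-copyable (or are themselves variants) are copied with `copy`;
// everything else is forwarded by value, sharing the underlying buffer.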
static Status OptionalDeviceCopy(
    const OptionalVariant& from, OptionalVariant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
  if (from.has_value()) {
    const std::vector<Tensor>& from_values = from.get_values();
    std::vector<Tensor> to_values;
    to_values.reserve(from_values.size());
    for (const Tensor& t : from_values) {
      if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
        // NOTE(skyewm): we're careful to make sure the lifetime of the 'to'
        // Tensor passed to `copy` (i.e. to_values.back()) is the same as the
        // returned 'to' OptionalVariant. This is because `copy` may spawn async
        // callbacks that don't run until after this function returns and
        // access the 'to' Tensor (e.g. BaseGPUDevice::MaybeCopyTensorToGPU).
        to_values.emplace_back(t.dtype());
        TF_RETURN_IF_ERROR(copy(t, &to_values.back()));
      } else {
        to_values.push_back(t);
      }
    }
    *to = OptionalVariant(std::move(to_values));
  } else {
    *to = from;
  }
  return Status::OK();
}

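// Registers OptionalDeviceCopy for all three copy directions, so
// DT_VARIANT tensors holding optionals can move between host and device.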
#define REGISTER_OPTIONAL_COPY(DIRECTION)               \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      OptionalVariant, DIRECTION, OptionalDeviceCopy)

REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);

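// Registers the function used to decode an OptionalVariant when a
// serialized DT_VARIANT tensor is deserialized under
// kOptionalVariantTypeName.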
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
                                       kOptionalVariantTypeName);

}  // namespace

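// Outputs a scalar variant tensor holding an optional with no value.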
void OptionalNoneOp::Compute(OpKernelContext* ctx) {
  OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
}

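// Outputs a scalar variant tensor holding an optional that wraps the
// input components.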
void OptionalFromValueOp::Compute(OpKernelContext* ctx) {
  OpInputList components_input;
  OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
  std::vector<Tensor> components(components_input.begin(),
                                 components_input.end());
  OP_REQUIRES_OK(ctx,
                 WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}

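// Outputs a scalar bool indicating whether the input optional holds a
// value.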
void OptionalHasValueOp::Compute(OpKernelContext* ctx) {
  const Tensor* optional_input;
  OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
              errors::InvalidArgument(
                  "Input to OptionalHasValue must be a scalar tensor "
                  "containing an OptionalVariant object."));
  const OptionalVariant* optional =
      optional_input->scalar<Variant>()().get<OptionalVariant>();
  OP_REQUIRES(
      ctx, optional != nullptr,
      errors::InvalidArgument(
          "Input to OptionalHasValue must be an OptionalVariant object."));
  Tensor* result;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
  result->scalar<bool>()() = optional->has_value();
}

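// Unpacks the components stored in the optional into the op's outputs,
// after checking that each component matches the expected dtype and shape.
// A rough Python-level sketch of how these four kernels compose (API names
// assumed from tf.data, not defined in this file):
//
//   opt = iterator.get_next_as_optional()   # OptionalFromValue / OptionalNone
//   if opt.has_value():                     # OptionalHasValue
//     components = opt.get_value()          # OptionalGetValue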
void OptionalGetValueOp::Compute(OpKernelContext* ctx) {
  const Tensor* optional_input;
  OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
              errors::InvalidArgument(
                  "Input to OptionalGetValue must be a scalar tensor "
                  "containing an OptionalVariant object."));
  const OptionalVariant* optional =
      optional_input->scalar<Variant>()().get<OptionalVariant>();
  OP_REQUIRES(
      ctx, optional != nullptr,
      errors::InvalidArgument(
          "Input to OptionalGetValue must be an OptionalVariant object."));
  OP_REQUIRES(
      ctx, optional->has_value(),
      errors::InvalidArgument("The given optional does not have a value."));
  const auto& components = optional->get_values();
  OP_REQUIRES(
      ctx, components.size() == output_types_.size(),
      errors::InvalidArgument("The given optional has ", components.size(),
                              " components, expected ", output_types_.size()));
  for (int i = 0; i < components.size(); ++i) {
    OP_REQUIRES(ctx, components[i].dtype() == output_types_[i],
                errors::InvalidArgument(
                    "The given optional does not match the expected type for "
                    "component ",
                    i, ". Expected: ", DataTypeString(output_types_[i]),
                    ". Actual: ", DataTypeString(components[i].dtype()), "."));
    OP_REQUIRES(ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
                errors::InvalidArgument(
                    "The given optional does not match the expected shape "
                    "for component ",
                    i, ". Expected: ", output_shapes_[i].DebugString(),
                    ". Actual: ", components[i].shape().DebugString(), "."));
    ctx->set_output(i, components[i]);
  }
}

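// Wraps `value` in an OptionalVariant and writes it to output
// `output_index` as a scalar DT_VARIANT tensor allocated in host memory.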
Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
                                      std::vector<Tensor> value) {
  OptionalVariant v(std::move(value));
  Tensor* variant_t;
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}

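// Same as above, but writes an empty optional.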
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
  OptionalVariant v;
  Tensor* variant_t;
  AllocatorAttributes cpu_alloc;
  cpu_alloc.set_on_host(true);
  TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
                                          &variant_t, cpu_alloc));
  variant_t->scalar<Variant>()() = v;
  return Status::OK();
}

namespace {

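// Each op is registered for both CPU and GPU. The CPU kernels carry a
// higher priority (2) than the GPU kernels (1); when both registrations
// match, the higher-priority kernel is preferred. OptionalHasValue's GPU
// kernel pins its boolean output to host memory.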
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU).Priority(2),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU).Priority(1),
                        OptionalNoneOp);
REGISTER_KERNEL_BUILDER(
    Name("OptionalFromValue").Device(DEVICE_CPU).Priority(2),
    OptionalFromValueOp);
REGISTER_KERNEL_BUILDER(
    Name("OptionalFromValue").Device(DEVICE_GPU).Priority(1),
    OptionalFromValueOp);

REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU).Priority(2),
                        OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalHasValue")
                            .Device(DEVICE_GPU)
                            .HostMemory("has_value")
                            .Priority(1),
                        OptionalHasValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU).Priority(2),
                        OptionalGetValueOp);
REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU).Priority(1),
                        OptionalGetValueOp);

}  // namespace

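// Registers CPU implementations of zeros_like and add for OptionalVariant,
// allowing ZerosLike and Add to operate on DT_VARIANT tensors that hold
// optionals (for example, when variant gradients are accumulated).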
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, OptionalVariant,
                                         OptionalZerosLike<CPUDevice>);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          OptionalVariant,
                                          OptionalBinaryAdd<CPUDevice>);

}  // namespace data
}  // namespace tensorflow