1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "tensorflow/core/kernels/data/optional_ops.h"
19
20 #include "tensorflow/core/common_runtime/dma_helper.h"
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/variant_encode_decode.h"
23 #include "tensorflow/core/framework/variant_op_registry.h"
24
25 namespace tensorflow {
26 namespace data {
27 namespace {
28
OptionalDeviceCopy(const OptionalVariant & from,OptionalVariant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy)29 static Status OptionalDeviceCopy(
30 const OptionalVariant& from, OptionalVariant* to,
31 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
32 if (from.has_value()) {
33 const std::vector<Tensor>& from_values = from.get_values();
34 std::vector<Tensor> to_values;
35 to_values.reserve(from_values.size());
36 for (const Tensor& t : from_values) {
37 if (DMAHelper::CanUseDMA(&t) || t.dtype() == DT_VARIANT) {
38 // NOTE(skyewm): we're careful to make sure the lifetime of the 'to'
39 // Tensor passed to `copy` (i.e. to_values.back()) is the same as the
40 // returned 'to' OptionalVariant. This is because `copy` may spawn async
41 // callbacks that don't run until after this function returns and access
42 // the 'to' Tensor (e.g. BaseGPUDevice::MaybeCopyTensorToGPU).
43 to_values.emplace_back(t.dtype());
44 TF_RETURN_IF_ERROR(copy(t, &to_values.back()));
45 } else {
46 to_values.push_back(t);
47 }
48 }
49 *to = OptionalVariant(std::move(to_values));
50 } else {
51 *to = from;
52 }
53 return Status::OK();
54 }
55
56 #define REGISTER_OPTIONAL_COPY(DIRECTION) \
57 INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
58 OptionalVariant, DIRECTION, OptionalDeviceCopy)
59
60 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
61 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
62 REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_DEVICE);
63
64 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(OptionalVariant,
65 kOptionalVariantTypeName);
66
67 } // namespace
68
Compute(OpKernelContext * ctx)69 void OptionalNoneOp::Compute(OpKernelContext* ctx) {
70 OP_REQUIRES_OK(ctx, WriteOptionalNoneToOutput(ctx, 0));
71 }
72
Compute(OpKernelContext * ctx)73 void OptionalFromValueOp::Compute(OpKernelContext* ctx) {
74 OpInputList components_input;
75 OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
76 std::vector<Tensor> components(components_input.begin(),
77 components_input.end());
78 OP_REQUIRES_OK(ctx,
79 WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
80 }
81
Compute(OpKernelContext * ctx)82 void OptionalHasValueOp::Compute(OpKernelContext* ctx) {
83 const Tensor* optional_input;
84 OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
85 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
86 errors::InvalidArgument(
87 "Input to OptionalHasValue must be a scalar tensor "
88 "containing an OptionalVariant object."));
89 const OptionalVariant* optional =
90 optional_input->scalar<Variant>()().get<OptionalVariant>();
91 OP_REQUIRES(
92 ctx, optional != nullptr,
93 errors::InvalidArgument(
94 "Input to OptionalHasValue must be an OptionalVariant object."));
95 Tensor* result;
96 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {}, &result));
97 result->scalar<bool>()() = optional->has_value();
98 }
99
Compute(OpKernelContext * ctx)100 void OptionalGetValueOp::Compute(OpKernelContext* ctx) {
101 const Tensor* optional_input;
102 OP_REQUIRES_OK(ctx, ctx->input("optional", &optional_input));
103 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(optional_input->shape()),
104 errors::InvalidArgument(
105 "Input to OptionalHasValue must be a scalar tensor "
106 "containing an OptionalVariant object."));
107 const OptionalVariant* optional =
108 optional_input->scalar<Variant>()().get<OptionalVariant>();
109 OP_REQUIRES(
110 ctx, optional != nullptr,
111 errors::InvalidArgument(
112 "Input to OptionalHasValue must be an OptionalVariant object."));
113 OP_REQUIRES(
114 ctx, optional->has_value(),
115 errors::InvalidArgument("The given optional does not have a value."));
116 const auto& components = optional->get_values();
117 OP_REQUIRES(
118 ctx, components.size() == output_types_.size(),
119 errors::InvalidArgument("The given optional has ", components.size(),
120 " components, expected ", output_types_.size()));
121 for (int i = 0; i < components.size(); ++i) {
122 OP_REQUIRES(ctx, components[i].dtype() == output_types_[i],
123 errors::InvalidArgument(
124 "The given optional does not match the expected type for "
125 "component ",
126 i, ". Expected: ", DataTypeString(output_types_[i]),
127 ". Actual: ", DataTypeString(components[i].dtype()), "."));
128 OP_REQUIRES(ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
129 errors::InvalidArgument(
130 "The given optional does not match the expected shape "
131 "for component ",
132 i, ". Expected: ", output_shapes_[i].DebugString(),
133 ". Actual: ", components[i].shape().DebugString(), "."));
134 ctx->set_output(i, components[i]);
135 }
136 }
137
WriteOptionalWithValueToOutput(OpKernelContext * ctx,int output_index,std::vector<Tensor> value)138 Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
139 std::vector<Tensor> value) {
140 OptionalVariant v(std::move(value));
141 Tensor* variant_t;
142 AllocatorAttributes cpu_alloc;
143 cpu_alloc.set_on_host(true);
144 TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
145 &variant_t, cpu_alloc));
146 variant_t->scalar<Variant>()() = v;
147 return Status::OK();
148 }
149
WriteOptionalNoneToOutput(OpKernelContext * ctx,int output_index)150 Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
151 OptionalVariant v;
152 Tensor* variant_t;
153 AllocatorAttributes cpu_alloc;
154 cpu_alloc.set_on_host(true);
155 TF_RETURN_IF_ERROR(ctx->allocate_output(output_index, TensorShape({}),
156 &variant_t, cpu_alloc));
157 variant_t->scalar<Variant>()() = v;
158 return Status::OK();
159 }
160
161 namespace {
162
163 REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_CPU).Priority(2),
164 OptionalNoneOp);
165 REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE_GPU).Priority(1),
166 OptionalNoneOp);
167 REGISTER_KERNEL_BUILDER(
168 Name("OptionalFromValue").Device(DEVICE_CPU).Priority(2),
169 OptionalFromValueOp);
170 REGISTER_KERNEL_BUILDER(
171 Name("OptionalFromValue").Device(DEVICE_GPU).Priority(1),
172 OptionalFromValueOp);
173
174 REGISTER_KERNEL_BUILDER(Name("OptionalHasValue").Device(DEVICE_CPU).Priority(2),
175 OptionalHasValueOp);
176 REGISTER_KERNEL_BUILDER(Name("OptionalHasValue")
177 .Device(DEVICE_GPU)
178 .HostMemory("has_value")
179 .Priority(1),
180 OptionalHasValueOp);
181 REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_CPU).Priority(2),
182 OptionalGetValueOp);
183 REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE_GPU).Priority(1),
184 OptionalGetValueOp);
185
186 } // namespace
187
188 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
189 DEVICE_CPU, OptionalVariant,
190 OptionalZerosLike<CPUDevice>);
191
192 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
193 OptionalVariant,
194 OptionalBinaryAdd<CPUDevice>);
195
196 } // namespace data
197 } // namespace tensorflow
198