/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
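// Unit tests for UnaryVariantOpRegistry, the global registry that maps a
// Variant payload's type to its decode, host<->device copy, unary-op, and
// binary-op implementations.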

#include <memory>

#include "absl/strings/match.h"
#include "tensorflow/core/lib/strings/str_util.h"

#define EIGEN_USE_THREADS

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define EIGEN_USE_GPU
#endif

#include "tensorflow/core/framework/variant_op_registry.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

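// Minimal value type used as the Variant payload throughout these tests. Its
// static member functions are registered below as this type's zeros_like, add,
// and host-to-device copy implementations; `early_exit` forces the error path,
// and `value` records which device implementation ran.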
struct VariantValue {
  string TypeName() const { return "TEST VariantValue"; }
  static Status CPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
                               VariantValue* v_out) {
    if (v.early_exit) {
      return errors::InvalidArgument("early exit zeros_like!");
    }
    v_out->value = 1;  // CPU
    return Status::OK();
  }
  static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
                               VariantValue* v_out) {
    if (v.early_exit) {
      return errors::InvalidArgument("early exit zeros_like!");
    }
    v_out->value = 2;  // GPU
    return Status::OK();
  }
  static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a,
                         const VariantValue& b, VariantValue* out) {
    if (a.early_exit) {
      return errors::InvalidArgument("early exit add!");
    }
    out->value = a.value + b.value;  // CPU
    return Status::OK();
  }
  static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a,
                         const VariantValue& b, VariantValue* out) {
    if (a.early_exit) {
      return errors::InvalidArgument("early exit add!");
    }
    out->value = -(a.value + b.value);  // GPU
    return Status::OK();
  }
  static Status CPUToGPUCopyFn(
      const VariantValue& from, VariantValue* to,
      const std::function<Status(const Tensor&, Tensor*)>& copier) {
    TF_RETURN_IF_ERROR(copier(Tensor(), nullptr));
    to->value = 0xdeadbeef;
    return Status::OK();
  }
  bool early_exit;
  int value;
};

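// Wire VariantValue's functions into the global registry. Each macro keys the
// function by type (and, for decode, by type name string) plus, where
// applicable, the op and device.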
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");

INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
    VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
    VariantValue::CPUToGPUCopyFn);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, VariantValue,
                                         VariantValue::CPUZerosLikeFn);

REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_GPU, VariantValue,
                                         VariantValue::GPUZerosLikeFn);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          VariantValue, VariantValue::CPUAddFn);

REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
                                          VariantValue, VariantValue::GPUAddFn);

}  // namespace

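// Round-trips a VariantValue through its VariantTensorDataProto encoding:
// lookup of an unregistered type name yields nullptr, while the registered
// decode function restores the original payload.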
TEST(VariantOpDecodeRegistryTest, TestBasic) {
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDecodeFn("YOU SHALL NOT PASS"),
            nullptr);

  auto* decode_fn =
      UnaryVariantOpRegistry::Global()->GetDecodeFn("TEST VariantValue");
  EXPECT_NE(decode_fn, nullptr);

  VariantValue vv{true /* early_exit */};
  Variant v = vv;
  VariantTensorData data;
  v.Encode(&data);
  VariantTensorDataProto proto;
  data.ToProto(&proto);
  Variant encoded = std::move(proto);
  EXPECT_TRUE((*decode_fn)(&encoded));
  VariantValue* decoded = encoded.get<VariantValue>();
  EXPECT_NE(decoded, nullptr);
  EXPECT_EQ(decoded->early_exit, true);
}

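// An empty proto decodes to an empty Variant; a proto that carries data but no
// type name must fail to decode.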
TEST(VariantOpDecodeRegistryTest, TestEmpty) {
  VariantTensorDataProto empty_proto;
  Variant empty_encoded = std::move(empty_proto);
  EXPECT_TRUE(DecodeUnaryVariant(&empty_encoded));
  EXPECT_TRUE(empty_encoded.is_empty());

  VariantTensorData data;
  Variant number = 3.0f;
  number.Encode(&data);
  VariantTensorDataProto proto;
  data.ToProto(&proto);
  proto.set_type_name("");
  Variant encoded = std::move(proto);
  // Failure when type name is empty but there's data in the proto.
  EXPECT_FALSE(DecodeUnaryVariant(&encoded));
}

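// Registering two decode functions under one type name is a fatal error.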
TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantDecodeFn f;
  string kTypeName = "fjfjfj";
  registry.RegisterDecodeFn(kTypeName, f);
  EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f),
               "fjfjfj already registered");
}

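// Only HOST_TO_DEVICE was registered for VariantValue, so DEVICE_TO_DEVICE
// lookup yields nullptr. The registered function must run the caller-supplied
// tensor copier before filling in the destination payload.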
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
  // No registered copy fn for GPU<->GPU.
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
                VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
                TypeIndex::Make<VariantValue>()),
            nullptr);

  auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
      VariantDeviceCopyDirection::HOST_TO_DEVICE,
      TypeIndex::Make<VariantValue>());
  EXPECT_NE(copy_to_gpu_fn, nullptr);

  VariantValue vv{true /* early_exit */};
  Variant v = vv;
  Variant v_out;
  bool dummy_executed = false;
  auto dummy_copy_fn = [&dummy_executed](const Tensor& from,
                                         Tensor* to) -> Status {
    dummy_executed = true;
    return Status::OK();
  };
  TF_EXPECT_OK((*copy_to_gpu_fn)(v, &v_out, dummy_copy_fn));
  EXPECT_TRUE(dummy_executed);
  VariantValue* copied_value = v_out.get<VariantValue>();
  EXPECT_NE(copied_value, nullptr);
  EXPECT_EQ(copied_value->value, 0xdeadbeef);
}

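// Registering two copy functions for the same (direction, type) pair is a
// fatal error.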
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
  class FjFjFj {};
  const auto kTypeIndex = TypeIndex::Make<FjFjFj>();
  registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
                                kTypeIndex, f);
  EXPECT_DEATH(registry.RegisterDeviceCopyFn(
                   VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
               "FjFjFj already registered");
}

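// ZEROS_LIKE dispatch on CPU: unregistered payload types yield nullptr, an
// early-exit payload propagates its error, and a valid payload runs the CPU
// implementation (value == 1).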
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
  class Blah {};
  EXPECT_EQ(
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
          ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TypeIndex::Make<Blah>()),
      nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
  Variant v = vv_early_exit;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
                                        ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit zeros_like"));

  VariantValue vv_ok{false /* early_exit */, 0 /* value */};
  v = vv_ok;
  TF_EXPECT_OK(UnaryOpVariant<CPUDevice>(
      null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 1);  // CPU
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
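// As above, but dispatched to the GPU implementation (value == 2).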
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
  class Blah {};
  EXPECT_EQ(
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
          ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, TypeIndex::Make<Blah>()),
      nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
  Variant v = vv_early_exit;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
                                        ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit zeros_like"));

  VariantValue vv_ok{false /* early_exit */, 0 /* value */};
  v = vv_ok;
  TF_EXPECT_OK(UnaryOpVariant<GPUDevice>(
      null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 2);  // GPU
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

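// Registering two unary-op functions for the same (op, device, type) triple is
// a fatal error, on CPU and GPU alike.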
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantUnaryOpFn f;
  class FjFjFj {};
  const auto kTypeIndex = TypeIndex::Make<FjFjFj>();

  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
                             kTypeIndex, f);
  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_CPU, kTypeIndex, f),
               "FjFjFj already registered");

  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
                             kTypeIndex, f);
  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
                                          DEVICE_GPU, kTypeIndex, f),
               "FjFjFj already registered");
}

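// ADD dispatch on CPU: CPUAddFn sums the two payloads' values (3 + 4 == 7).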
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
  class Blah {};
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
                ADD_VARIANT_BINARY_OP, DEVICE_CPU, TypeIndex::Make<Blah>()),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
  VariantValue vv_other{true /* early_exit */, 4 /* value */};
  Variant v_a = vv_early_exit;
  Variant v_b = vv_other;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = BinaryOpVariants<CPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit add"));

  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
  v_a = vv_ok;
  TF_EXPECT_OK(BinaryOpVariants<CPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, 7);  // CPU
}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
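// ADD dispatch on GPU: GPUAddFn negates the sum (-(3 + 4) == -7), letting the
// test distinguish which device implementation ran.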
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
  class Blah {};
  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
                ADD_VARIANT_BINARY_OP, DEVICE_GPU, TypeIndex::Make<Blah>()),
            nullptr);

  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
  VariantValue vv_other{true /* early_exit */, 4 /* value */};
  Variant v_a = vv_early_exit;
  Variant v_b = vv_other;
  Variant v_out = VariantValue();

  OpKernelContext* null_context_pointer = nullptr;
  Status s0 = BinaryOpVariants<GPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
  EXPECT_FALSE(s0.ok());
  EXPECT_TRUE(absl::StrContains(s0.error_message(), "early exit add"));

  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
  v_a = vv_ok;
  TF_EXPECT_OK(BinaryOpVariants<GPUDevice>(
      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
  EXPECT_EQ(vv_out->value, -7);  // GPU
}
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

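// Registering two binary-op functions for the same (op, device, type) triple
// is a fatal error.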
TEST(VariantOpAddRegistryTest, TestDuplicate) {
  UnaryVariantOpRegistry registry;
  UnaryVariantOpRegistry::VariantBinaryOpFn f;
  class FjFjFj {};
  const auto kTypeIndex = TypeIndex::Make<FjFjFj>();

  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                           kTypeIndex, f),
               "FjFjFj already registered");

  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
                                           kTypeIndex, f),
               "FjFjFj already registered");
}

}  // namespace tensorflow