/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <array>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h"

// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly
// is another option.

namespace tensorflow {
namespace parallel_device {
TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}

TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}

TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Skip the test if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}

TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  std::array<const char*, 2> underlying_devices{first_device_name,
                                                second_device_name};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Copying on to a parallel device is OK.
  TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* backing_device =
      TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(std::string(device_name), backing_device);

  // Un-pack the parallel tensor to verify that the copy was successful.
  {
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context.get(), device_value.get(), &components,
                           status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // The value of the original tensor is replicated on each device.
    ExpectScalarEq<float>(components[0].get(), 3.);
    ExpectScalarEq<float>(components[1].get(), 3.);

    // Verify that the mirrors are placed on the component devices.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
}

TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths.
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(),
                                              size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  // We can create the handle, but fetching the shape is an error at the
  // moment.
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandleNumDims(combined_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED);
}

TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs.
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), first_device_name,
                         first_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::array<const char*, 2> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  RegisterParallelDevice(context.get(), second_device_name,
                         second_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device.
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                              value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second.
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Multiply the nested parallel tensor by a scalar 3.
  TensorHandlePtr three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           three.get(), status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  // second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  ExpectScalarEq<float>(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too.
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  ExpectScalarEq<float>(first_components[0].get(), 3.);
  ExpectScalarEq<float>(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}

TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 1> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor.
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput.
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{
        combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput.
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetDevice(op.get(), device_name, status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;

    TFE_TensorHandle* result_handles;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}

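// Runs a CollectiveReduce with an Add merge_op over `input`, summing it across
// `group_size` participants, and returns the reduced handle (or nullptr if
// building or executing the op fails).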
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

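// Packs 1. and 2. onto a two-CPU parallel device, runs CollectiveSum across
// the components, and checks that each component then holds 3. `async`
// toggles asynchronous eager execution in the context options.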
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the parallel device.
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                              value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Run a collective sum, so each component should now be the same.
  TensorHandlePtr reduced(
      CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test just checks that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

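// Registers a single-input function `function_name` whose body is a
// CollectiveReduce with a Mul merge_op, so running it on a parallel device
// multiplies the per-device components together.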
void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
                                                            TF_DeleteGraph);
  TF_OperationDescription* placeholder_desc =
      TF_NewOperation(body.get(), "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Output x{placeholder_op, 0};

  TF_OperationDescription* reduce_desc =
      TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
  TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
  TF_SetAttrInt(reduce_desc, "group_size", group_size);
  TF_SetAttrInt(reduce_desc, "group_key", 0);
  TF_SetAttrInt(reduce_desc, "instance_key", 0);

  const std::string merge_op("Mul");
  TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
                   merge_op.length());
  const std::string final_op("Id");
  TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
                   final_op.length());
  TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
  TF_AddInput(reduce_desc, x);
  TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Operation* operations[]{placeholder_op, reduce_op};
  TF_Output y{reduce_op, 0};
  const char* output_name = "y";
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
      TF_GraphToFunction(
          /* fn_body */ body.get(), /* fn_name */ function_name,
          /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
          /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
          /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
          /* opts */ nullptr, /* description */ "", /* status */ status),
      TF_DeleteFunction);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_ContextAddFunction(context, function.get(), status);
}

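// Registers the collective-mul function, packs 7. and 9. onto a two-CPU
// parallel device, executes the function on it, and checks that each
// component of the result holds 7. * 9. on its expected underlying device.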
TEST(PARALLEL_DEVICE, TestFunction) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* function_name = "test_reduce_mul";
  RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                              value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* raw_result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr reduced(raw_result_handle);

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);

  std::string first_device = TFE_TensorHandleBackingDeviceName(
      result_components[0].get(), status.get());
  ASSERT_EQ(underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      result_components[1].get(), status.get());
  ASSERT_EQ(underlying_devices[1], second_device);
}

}  // namespace parallel_device
}  // namespace tensorflow