/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/parallel_device/parallel_device.h"

#include <array>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_testlib.h"
#include "tensorflow/core/platform/test.h"

// NOTE(allenl): These tests currently go through TFE_Execute and so are
// integration testing rather than purely testing the parallel device. They
// correspond fairly well to the implementation, but testing the C++ directly is
// another option.

namespace tensorflow {
namespace parallel_device {

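// Runs the shared two-device test cases across two distinct virtual CPU
// devices.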
TEST(PARALLEL_DEVICE, TestBasicCPU) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:1");
}

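// Runs the same two-device test cases with both components aliasing a single
// CPU device.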
TEST(PARALLEL_DEVICE, TestBasicCPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  BasicTestsForTwoDevices(context.get(),
                          "/job:localhost/replica:0/task:0/device:CPU:0",
                          "/job:localhost/replica:0/task:0/device:CPU:0");
}

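// Same aliased setup, but on a TPU when one is available.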
TEST(PARALLEL_DEVICE, TestBasicTPUAliased) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Skip the test if no TPU is available.
  std::unique_ptr<TF_DeviceList, decltype(&TF_DeleteDeviceList)> devices(
      TFE_ContextListDevices(context.get(), status.get()), TF_DeleteDeviceList);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  bool has_tpu = false;
  for (int device_index = 0; device_index < TF_DeviceListCount(devices.get());
       ++device_index) {
    std::string device_type =
        TF_DeviceListType(devices.get(), device_index, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    if (device_type == "TPU") {
      has_tpu = true;
      break;
    }
  }
  if (has_tpu) {
    BasicTestsForTwoDevices(context.get(),
                            "/job:localhost/replica:0/task:0/device:TPU:0",
                            "/job:localhost/replica:0/task:0/device:TPU:0");
  }
}

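// Copies onto the parallel device replicate the value to each component;
// copies off of the parallel device must be explicit and currently report
// TF_UNIMPLEMENTED.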
TEST(PARALLEL_DEVICE, TestExplicitCopies) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:0";
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CPU:1";
  std::array<const char*, 2> underlying_devices{first_device_name,
                                                second_device_name};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr cpu_value(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Copying onto a parallel device is OK.
  TensorHandlePtr device_value(TFE_TensorHandleCopyToDevice(
      cpu_value.get(), context.get(), device_name, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  const char* backing_device =
      TFE_TensorHandleBackingDeviceName(device_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ASSERT_EQ(std::string(device_name), backing_device);

  // Un-pack the parallel tensor to verify that the copy was successful.
  {
    std::array<TensorHandlePtr, 2> components;
    ExtractPerDeviceValues(context.get(), device_value.get(), &components,
                           status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    // The value of the original tensor is replicated on each device.
    ExpectScalarEq<float>(components[0].get(), 3.);
    ExpectScalarEq<float>(components[1].get(), 3.);

    // Verify that the mirrors are placed on the component devices.
    std::string first_device =
        TFE_TensorHandleBackingDeviceName(components[0].get(), status.get());
    ASSERT_EQ(underlying_devices[0], first_device);
    std::string second_device =
        TFE_TensorHandleBackingDeviceName(components[1].get(), status.get());
    ASSERT_EQ(underlying_devices[1], second_device);
  }

  // Copies off of parallel devices must be explicit.
  TensorHandlePtr copy_back(TFE_TensorHandleCopyToDevice(
      device_value.get(), context.get(), first_device_name, status.get()));
  ASSERT_EQ(TF_GetCode(status.get()), TF_UNIMPLEMENTED);
}

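// Components with different shapes can be packed into one parallel tensor,
// but querying the combined shape reports TF_UNIMPLEMENTED.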
TEST(PARALLEL_DEVICE, TestDifferentShapes) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create two vectors with different lengths
  std::vector<float> size_two_value{1., 2.};
  std::vector<float> size_three_value{1., 2., 3.};
  TensorHandlePtr size_two(
      VectorFloatTensorHandle(size_two_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr size_three(
      VectorFloatTensorHandle(size_three_value, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Try to combine these values into a single parallel tensor.
  std::array<TFE_TensorHandle*, 2> components{size_two.get(), size_three.get()};
  TensorHandlePtr combined_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  // We can create the handle, but fetching the shape is an error at the moment.
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_TensorHandleNumDims(combined_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_UNIMPLEMENTED);
}

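// A parallel device may itself be used as a component of another parallel
// device, so per-device values can be nested.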
TEST(PARALLEL_DEVICE, TestNestedParallelDevices) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          3),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a parallel device with two CPUs
  const char* first_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> first_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), first_device_name,
                         first_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a second parallel device with the first parallel device and one
  // additional CPU.
  const char* second_device_name =
      "/job:localhost/replica:0/task:0/device:CUSTOM:1";
  std::array<const char*, 2> second_underlying_devices{
      "/job:localhost/replica:0/task:0/device:CUSTOM:0",
      "/job:localhost/replica:0/task:0/device:CPU:2"};
  RegisterParallelDevice(context.get(), second_device_name,
                         second_underlying_devices, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the first parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr first_combined_value = CreatePerDeviceValues(
      context.get(), components, first_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Nest the first parallel tensor into a second
  TensorHandlePtr value_three(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  components[0] = first_combined_value.get();
  components[1] = value_three.get();
  TensorHandlePtr second_combined_value = CreatePerDeviceValues(
      context.get(), components, second_device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Multiply the nested parallel tensor by a scalar 3.
  TensorHandlePtr multiplier(FloatTensorHandle(3., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr multiply_result(Multiply(context.get(),
                                           second_combined_value.get(),
                                           multiplier.get(), status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Un-pack the parallel tensor to verify that the operation was
  // successful. The resulting structure should be:
  //   second_device{first_device{1. * 3., 2. * 3.}, 3. * 3.}.
  std::array<TensorHandlePtr, 2> second_components;
  ExtractPerDeviceValues(context.get(), multiply_result.get(),
                         &second_components, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  ExpectScalarEq<float>(second_components[1].get(), 9.);

  // Verify that the mirrors are placed on the component devices.
  std::string first_device = TFE_TensorHandleBackingDeviceName(
      second_components[0].get(), status.get());
  ASSERT_EQ(second_underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      second_components[1].get(), status.get());
  ASSERT_EQ(second_underlying_devices[1], second_device);

  // Un-pack the first parallel device's tensor too
  std::array<TensorHandlePtr, 2> first_components;
  ExtractPerDeviceValues(context.get(), second_components[0].get(),
                         &first_components, status.get());
  ExpectScalarEq<float>(first_components[0].get(), 3.);
  ExpectScalarEq<float>(first_components[1].get(), 6.);

  first_device = TFE_TensorHandleBackingDeviceName(first_components[0].get(),
                                                   status.get());
  ASSERT_EQ(first_underlying_devices[0], first_device);
  second_device = TFE_TensorHandleBackingDeviceName(first_components[1].get(),
                                                    status.get());
  ASSERT_EQ(first_underlying_devices[1], second_device);
}

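// Mis-uses of packing/unpacking (wrong component counts, re-packing an
// already-parallel tensor, unpacking a non-parallel tensor) should report
// TF_INVALID_ARGUMENT.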
TEST(PARALLEL_DEVICE, TestInvalidPacking) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 1> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  {
    // Try to pack two TensorHandles onto a parallel device with a single
    // component.
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
    std::array<TFE_TensorHandle*, 2> components{value_one.get(),
                                                value_two.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to extract the wrong number of components from a parallel tensor
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TensorHandlePtr, 2> incorrect_components;
    ExtractPerDeviceValues(context.get(), combined_value.get(),
                           &incorrect_components, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a ParallelTensor to TPUReplicatedInput
    std::array<TFE_TensorHandle*, 1> correct_components{value_one.get()};
    TensorHandlePtr combined_value = CreatePerDeviceValues(
        context.get(), correct_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

    std::array<TFE_TensorHandle*, 1> incorrect_components{combined_value.get()};
    TensorHandlePtr recombined_value = CreatePerDeviceValues(
        context.get(), incorrect_components, device_name, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }

  {
    // Try to pass a non-parallel tensor to TPUReplicatedOutput
    std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
        TFE_NewOp(context.get(), "TPUReplicatedOutput", status.get()),
        TFE_DeleteOp);
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetAttrInt(op.get(), "num_replicas", 1);
    TFE_OpAddInput(op.get(), value_one.get(), status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;
    TFE_OpSetDevice(op.get(), device_name, status.get());
    if (TF_GetCode(status.get()) != TF_OK) return;

    TFE_TensorHandle* result_handles;
    int num_retvals = 1;
    TFE_Execute(op.get(), &result_handles, &num_retvals, status.get());
    ASSERT_TRUE(TF_GetCode(status.get()) == TF_INVALID_ARGUMENT)
        << TF_Message(status.get());
  }
}

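// Helper that runs a CollectiveReduce with merge_op "Add" on `input`'s device
// and returns the summed tensor, or nullptr on error.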
TensorHandlePtr CollectiveSum(TFE_Context* context, TFE_TensorHandle* input,
                              int group_size, TF_Status* status) {
  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context, "CollectiveReduce", status), TFE_DeleteOp);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  const char* device = TFE_TensorHandleDeviceName(input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetDevice(op.get(), device, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(input));
  TFE_OpSetAttrInt(op.get(), "group_size", group_size);
  TFE_OpSetAttrInt(op.get(), "group_key", 0);
  TFE_OpSetAttrInt(op.get(), "instance_key", 0);
  const std::string merge_op("Add");
  TFE_OpSetAttrString(op.get(), "merge_op", merge_op.c_str(),
                      merge_op.length());
  const std::string final_op("Id");
  TFE_OpSetAttrString(op.get(), "final_op", final_op.c_str(),
                      final_op.length());
  TFE_OpSetAttrIntList(op.get(), "subdiv_offsets", nullptr, 0);

  TFE_OpAddInput(op.get(), input, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;

  TFE_TensorHandle* result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &result_handle, &num_retvals, status);
  if (TF_GetCode(status) != TF_OK) return nullptr;
  return TensorHandlePtr(result_handle);
}

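// Packs two scalars onto a two-CPU parallel device and checks that a
// collective sum leaves the same reduced value (1. + 2.) on each component.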
void TestCollective(bool async) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  TFE_ContextOptionsSetAsync(opts.get(), async);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Create a tensor on the parallel device
  TensorHandlePtr value_one(FloatTensorHandle(1., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(2., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  // Run a collective sum, so each component should now be the same.
  TensorHandlePtr reduced(
      CollectiveSum(context.get(), parallel_value.get(), 2, status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 3.);
  ExpectScalarEq<float>(result_components[1].get(), 3.);
}

TEST(PARALLEL_DEVICE, TestCollectiveSync) { TestCollective(/*async=*/false); }

// Note that ops on the parallel device currently don't execute
// asynchronously. The test is just that we don't get deadlocks.
TEST(PARALLEL_DEVICE, TestCollectiveAsync) { TestCollective(/*async=*/true); }

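// Helper that builds a one-input function wrapping a CollectiveReduce with
// merge_op "Mul" and registers it with the context under `function_name`.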
void RegisterCollectiveMulFunction(TFE_Context* context,
                                   const char* function_name, int group_size,
                                   TF_Status* status) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> body(TF_NewGraph(),
                                                            TF_DeleteGraph);
  TF_OperationDescription* placeholder_desc =
      TF_NewOperation(body.get(), "Placeholder", "Placeholder");
  TF_SetAttrType(placeholder_desc, "dtype", TF_FLOAT);
  TF_Operation* placeholder_op = TF_FinishOperation(placeholder_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Output x{placeholder_op, 0};

  TF_OperationDescription* reduce_desc =
      TF_NewOperation(body.get(), "CollectiveReduce", "CollectiveReduce");
  TF_SetAttrType(reduce_desc, "T", TF_FLOAT);
  TF_SetAttrInt(reduce_desc, "group_size", group_size);
  TF_SetAttrInt(reduce_desc, "group_key", 0);
  TF_SetAttrInt(reduce_desc, "instance_key", 0);

  const std::string merge_op("Mul");
  TF_SetAttrString(reduce_desc, "merge_op", merge_op.c_str(),
                   merge_op.length());
  const std::string final_op("Id");
  TF_SetAttrString(reduce_desc, "final_op", final_op.c_str(),
                   final_op.length());
  TF_SetAttrIntList(reduce_desc, "subdiv_offsets", nullptr, 0);
  TF_AddInput(reduce_desc, x);
  TF_Operation* reduce_op = TF_FinishOperation(reduce_desc, status);
  if (TF_GetCode(status) != TF_OK) return;
  TF_Operation* operations[]{placeholder_op, reduce_op};
  TF_Output y{reduce_op, 0};
  const char* output_name = "y";
  std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)> function(
      TF_GraphToFunction(
          /* fn_body */ body.get(), /* fn_name */ function_name,
          /* append_hash_to_fn_name */ 0, /* num_opers */ 2,
          /* opers */ operations, /* ninputs */ 1, /* inputs */ &x,
          /* noutputs */ 1, /* outputs */ &y, /* output_names */ &output_name,
          /* opts */ nullptr, /* description */ "", /* status */ status),
      TF_DeleteFunction);
  if (TF_GetCode(status) != TF_OK) return;
  TFE_ContextAddFunction(context, function.get(), status);
}

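// Executes the registered collective-mul function on the parallel device and
// checks that each component holds the product of the packed inputs.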
TEST(PARALLEL_DEVICE, TestFunction) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
      TFE_NewContextOptions(), TFE_DeleteContextOptions);
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
      TF_CreateConfig(
          /*xla*/ false,
          /* gpu_memory_allow_growth */ true, /* num_cpu_devices */
          2),
      TF_DeleteBuffer);
  TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
                              status.get());
  std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
      TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* device_name = "/job:localhost/replica:0/task:0/device:CUSTOM:0";
  std::array<const char*, 2> underlying_devices{
      "/job:localhost/replica:0/task:0/device:CPU:0",
      "/job:localhost/replica:0/task:0/device:CPU:1"};
  RegisterParallelDevice(context.get(), device_name, underlying_devices,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  const char* function_name = "test_reduce_mul";
  RegisterCollectiveMulFunction(context.get(), function_name, 2, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TensorHandlePtr value_one(FloatTensorHandle(7., status.get()));
  TensorHandlePtr value_two(FloatTensorHandle(9., status.get()));
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  std::array<TFE_TensorHandle*, 2> components{value_one.get(), value_two.get()};
  TensorHandlePtr parallel_value = CreatePerDeviceValues(
      context.get(), components, device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
      TFE_NewOp(context.get(), function_name, status.get()), TFE_DeleteOp);
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpSetDevice(op.get(), device_name, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TFE_OpAddInput(op.get(), parallel_value.get(), status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());

  TFE_TensorHandle* raw_result_handle;
  int num_retvals = 1;
  TFE_Execute(op.get(), &raw_result_handle, &num_retvals, status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  TensorHandlePtr reduced(raw_result_handle);

  std::array<TensorHandlePtr, 2> result_components;
  ExtractPerDeviceValues(context.get(), reduced.get(), &result_components,
                         status.get());
  ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
  ExpectScalarEq<float>(result_components[0].get(), 7. * 9.);
  ExpectScalarEq<float>(result_components[1].get(), 7. * 9.);

  std::string first_device = TFE_TensorHandleBackingDeviceName(
      result_components[0].get(), status.get());
  ASSERT_EQ(underlying_devices[0], first_device);
  std::string second_device = TFE_TensorHandleBackingDeviceName(
      result_components[1].get(), status.get());
  ASSERT_EQ(underlying_devices[1], second_device);
}

}  // namespace parallel_device
}  // namespace tensorflow