/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
class FailTestRMA : public CollectiveRemoteAccessLocal {
 public:
  FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
              int64 step_id, int fail_after)
      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
        fail_after_(fail_after) {}

  bool MaybeFail(const StatusCallback& done) {
    bool fail_now = false;
    {
      mutex_lock l(mu_);
      if (fail_after_ > 0) {
        fail_now = (--fail_after_ == 0);
      }
    }
    if (fail_now) {
      done(errors::Internal("Deliberate failure"));
      return true;
    }
    return false;
  }

  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality,
                    int dev_to_dev_stream_index,
                    CancellationManager* cancellation_manager,
                    const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
        cancellation_manager, done);
  }

  void PostToPeer(const string& peer_device, const string& peer_task,
                  const string& key, Device* from_device,
                  DeviceContext* from_device_ctx,
                  const AllocatorAttributes& from_alloc_attr,
                  const Tensor* from_tensor,
                  const DeviceLocality& client_locality,
                  CancellationManager* cancellation_manager,
                  const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::PostToPeer(
        peer_device, peer_task, key, from_device, from_device_ctx,
        from_alloc_attr, from_tensor, client_locality, cancellation_manager,
        done);
  }

  mutex mu_;
  int fail_after_ TF_GUARDED_BY(mu_);
};

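// Instantiates an OpKernel for `node` on `device`, aborting via LOG(FATAL)
// if kernel creation fails.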
std::unique_ptr<OpKernel> GetKernel(const NodeDef& node,
                                    const DeviceType& device_type,
                                    DeviceBase* device) {
  Status status;
  std::unique_ptr<OpKernel> k = CreateOpKernel(
      device_type, device, device->GetAllocator(AllocatorAttributes()), node,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) {
    LOG(FATAL) << status;
  }
  return k;
}

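// Returns a two-input "Add" kernel, used below as the reduction merge_op.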
std::unique_ptr<OpKernel> GetAdd(DataType dtype, const DeviceType& device_type,
                                 DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("add_node", "Add");
  TF_CHECK_OK(builder.Attr("T", dtype)
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(dtype))
                  .Finalize(&node_def));
  return GetKernel(node_def, device_type, device);
}

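// Returns a two-input "Div" kernel, used below as the final_op so the
// reduction yields the mean.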
std::unique_ptr<OpKernel> GetDiv(DataType dtype, const DeviceType& device_type,
                                 DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("div_node", "Div");
  TF_CHECK_OK(builder.Attr("T", dtype)
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(dtype))
                  .Finalize(&node_def));
  return GetKernel(node_def, device_type, device);
}

static int64 kStepId = 123;

class RingReducerTest : public ::testing::Test {
 protected:
  RingReducerTest()
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  void InitGPUDevices() {
    auto device_factory = DeviceFactory::GetFactory("GPU");
    CHECK(device_factory);
    SessionOptions options;
    Status s = device_factory->CreateDevices(
        options, "/job:worker/replica:0/task:0", &gpu_devices_);
    CHECK(s.ok());
  }
#endif

  ~RingReducerTest() override {
    stop_ = true;
    for (auto i : instances_) delete i;
    if (col_exec_) col_exec_->Unref();
    if (col_params_) col_params_->Unref();
  }

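  // Builds a num_workers x num_devices test topology: the local devices, a
  // DeviceMgr, a collective executor backed by FailTestRMA, and the
  // CollectiveParams describing the ring and its subdivisions.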
  void Init(int num_workers, int num_devices, DataType dtype,
            const DeviceType& device_type, int num_subdivs, int fail_after) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    InitGPUDevices();
#endif
    device_type_ = device_type;
    std::vector<std::unique_ptr<Device>> local_devices;
    SessionOptions sess_opts;
    sess_opts.env = Env::Default();
    Bytes mem_limit(4 << 20);
    DeviceLocality dev_locality;
    for (int wi = 0; wi < num_workers; ++wi) {
      for (int di = 0; di < num_devices; ++di) {
        if (device_type == DEVICE_CPU) {
          string dev_name =
              strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
          local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
              sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
        } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
          int dev_idx = (wi * num_devices) + di;
          if (dev_idx >= static_cast<int>(gpu_devices_.size())) {
            LOG(INFO) << "dev_mgr has access to limited GPUs, reusing for more "
                         "than one ring node.";
          } else {
            local_devices.push_back(std::move(gpu_devices_[dev_idx]));
          }
        } else {
          LOG(FATAL) << "Unsupported device_type " << device_type;
        }
      }
    }
    if (!dev_mgr_ || device_type == DEVICE_CPU) {
      LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
                << " devices: ";
      dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
    }
    if (!gpu_ring_order_) {
      gpu_ring_order_ = absl::make_unique<string>();
    }
    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
                           fail_after);
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
                                           dev_mgr_.get(),
                                           gpu_ring_order_.get(), work_queue_);
    col_params_ = new CollectiveParams();
    col_params_->name = "test_collective";
    static const int kGroupKey = 5;
    col_params_->group.group_key = kGroupKey;
    col_params_->group.device_type = device_type;
    col_params_->group.group_size = num_workers * num_devices;
    static const int kInstanceKey = 17;
    col_params_->instance.instance_key = kInstanceKey;
    col_params_->instance.impl_details.subdiv_offsets.clear();
    col_params_->instance.type = REDUCTION_COLLECTIVE;
    col_params_->instance.impl_details.collective_name = "RingReduce";
    col_params_->instance.data_type = dtype;
    col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
    col_params_->subdiv_rank.resize(num_subdivs);
    int subdiv_stride = num_devices / num_subdivs;
    for (int sdi = 0; sdi < num_subdivs; ++sdi) {
      col_params_->instance.impl_details.subdiv_offsets.push_back(
          sdi * subdiv_stride);
      col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
    }

    // Set up a local device ring order that's not just 0,1,2...
    std::vector<int> local_ring_order;
    for (int di = 0; di < num_devices; ++di) {
      local_ring_order.push_back(di);
    }
    for (int di = 0; di < num_devices; ++di) {
      bool is_odd = ((di % 2) == 1);
      int other = (di + (is_odd ? 7 : 3)) % num_devices;
      if (di == other) continue;
      iter_swap(local_ring_order.begin() + di,
                local_ring_order.begin() + other);
    }
    string lro_buf;
    for (auto d : local_ring_order) strings::StrAppend(&lro_buf, d, ", ");
    VLOG(1) << "local_ring_order " << lro_buf;

    // Set up all of the fake device contexts.
    for (int wi = 0; wi < num_workers; ++wi) {
      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      col_params_->group.num_devices_per_task[task_name] = num_devices;
      for (int di = 0; di < num_devices; ++di) {
        string dev_name = strings::StrCat(task_name, "/cpu:", di);
        if (device_type == DEVICE_GPU) {
          dev_name =
              strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
        }
        col_params_->group.device_names.push_back(dev_name);
        col_params_->group.task_names.push_back(task_name);
        // Normally each device would set is_local to its own perspective but
        // this test runs in a single process so is_local is always true.
        col_params_->task.is_local.push_back(true);
        for (int sdi = 0; sdi < num_subdivs; ++sdi) {
          int rotated_di =
              (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
              num_devices;
          col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
              wi * num_devices + local_ring_order[rotated_di]);
        }
      }
    }
    for (int wi = 0; wi < num_workers; ++wi) {
      for (int di = 0; di < num_devices; ++di) {
        int rank = wi * num_devices + di;
        instances_.push_back(new DeviceInstance(
            rank, col_params_->group.device_names[rank], device_type_, this));
      }
    }
  }

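  // Schedules DoReduce() for every DeviceInstance in parallel and waits until
  // all of them have completed (or stop_ is set by the destructor).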
  void Reduce(int fail_after) {
    std::atomic<int> done(0);
    for (auto di : instances_) {
      SchedClosure([di, &done] {
        di->DoReduce();
        ++done;
      });
      if (fail_after > 0) {
        // Stagger the op execution starts.
        Env::Default()->SleepForMicroseconds(100);
      }
    }
    while (done < static_cast<int>(instances_.size())) {
      if (stop_) break;
      Env::Default()->SleepForMicroseconds(1000);
    }
  }

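  // End-to-end driver: initializes the test topology, seeds a distinct input
  // tensor on each device, runs the ring reduction, then checks either the
  // expected mean (success case) or the injected error (failure case).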
  template <typename T>
  void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
               int num_devices, int num_subdivs, int tensor_len,
               int fail_after) {
    Init(num_workers, num_devices, dtype, device_type, num_subdivs, fail_after);
    std::vector<T> expected(tensor_len, 0.0);
    for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
      DeviceInstance* instance = instances_[di];
      instance->InitTensor(
          dtype, TensorShape({tensor_len}), [&expected, dtype, di](Tensor* t) {
            for (size_t i = 0; i < t->NumElements(); ++i) {
              // The cast is necessary to prevent clang-tidy from insisting
              // that a faster non-open source function be substituted.
              float value = pow(10, static_cast<double>(di)) * i;
              if (dtype == DT_INT32 || dtype == DT_INT64) {
                value = di * 10 + i;
              }
              t->flat<T>()(i) = static_cast<T>(value);
              expected[i] += value;
            }
          });
    }
    Reduce(fail_after);
    if (fail_after > 0) {
      // Confirm that every device terminated with the expected error status.
      for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
        EXPECT_NE(
            instances_[di]->status_.error_message().find("Deliberate failure"),
            string::npos);
      }
    } else {
      // Confirm that every device computed the same correct reduction value.
      for (int i = 0; i < tensor_len; ++i) {
        expected[i] /= (num_workers * num_devices);
      }
      for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
        TF_EXPECT_OK(instances_[di]->status_);
        Tensor* inst = &instances_[di]->tensor_;
        CHECK(inst);
        Tensor actual(dtype, TensorShape({tensor_len}));
        if (device_type_ == DEVICE_CPU) {
          CHECK(actual.CopyFrom(*inst, inst->shape()));
          VLOG(1) << "actual " << actual.SummarizeValue(100);
        } else if (device_type_ == DEVICE_GPU) {
          Device* dev = instances_[di]->device_;
          auto* dev_info = dev->tensorflow_gpu_device_info();
          CHECK(dev_info);
          CHECK(dev_info->default_context
                    ->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
                                                &actual)
                    .ok());
        }

        auto alias = actual.template unaligned_flat<T>();
        for (int i = 0; i < tensor_len; ++i) {
          switch (dtype) {
            case DT_FLOAT:
              EXPECT_FLOAT_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            case DT_DOUBLE:
              EXPECT_DOUBLE_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            case DT_INT32:
            case DT_INT64:
              EXPECT_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            default:
              LOG(FATAL) << "unimplemented";
          }
        }
      }
    }
  }

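  // Builds a CollectiveReduce op kernel with attributes drawn from `params`.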
  std::unique_ptr<OpKernel> GetCollectiveReduce(const CollectiveParams& params,
                                                Tensor* input,
                                                const DeviceType& device_type,
                                                DeviceBase* device) {
    mutex_lock l(mu_);
    NodeDef node_def;
    NodeDefBuilder builder(
        strings::StrCat("collective_reduce_", reduce_counter_++),
        "CollectiveReduce");
    TF_CHECK_OK(
        builder.Attr("T", params.instance.data_type)
            .Attr("merge_op", "Add")
            .Attr("final_op", "Id")
            .Attr("group_size", params.group.group_size)
            .Attr("group_key", params.group.group_key)
            .Attr("instance_key", params.instance.instance_key)
            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
            .Input(FakeInput(params.instance.data_type))
            .Finalize(&node_def));
    return GetKernel(node_def, device_type, device);
  }

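  // Verifies that RingReducer::InitializeCollectiveParams produces the
  // expected subdivision permutations and ranks for `cp`.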
  void RunSubdivPermsTest(
      CollectiveParams* cp,
      const std::vector<std::vector<int>>& expected_subdiv_perms,
      const std::vector<int>& expected_subdiv_rank) {
    col_exec_ = nullptr;
    cp->instance.impl_details.subdiv_permutations.clear();
    cp->subdiv_rank.clear();
    // Create a stub ring reducer only for testing param initialization.
    RingReducer* reducer = new RingReducer;
    core::ScopedUnref unref(reducer);
    TF_CHECK_OK(reducer->InitializeCollectiveParams(cp));
    EXPECT_EQ(expected_subdiv_perms,
              cp->instance.impl_details.subdiv_permutations);
    EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
    reducer->group_size_tensor_ready_.Notify();  // To unblock destructor.
  }

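  // Represents one ring participant: it holds that device's copy of the
  // CollectiveParams, its input/output tensor, and the plumbing needed to run
  // a RingReducer on it.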
  class DeviceInstance {
   public:
    DeviceInstance(int rank, const string& dev_name,
                   const DeviceType& device_type, RingReducerTest* parent)
        : parent_(parent),
          dev_name_(dev_name),
          device_type_(device_type),
          rank_(rank),
          col_params_(new CollectiveParams()) {
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
          << "Couldn't find device " << dev_name
          << " existing devices: " << parent_->dev_mgr_->DebugString();
      col_params_->name = parent_->col_params_->name;
      col_params_->group = parent_->col_params_->group;
      col_params_->instance = parent->col_params_->instance;
      col_params_->task.is_local = parent_->col_params_->task.is_local;
      col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;

      int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
      int group_size = col_params_->group.group_size;
      CHECK_EQ(group_size,
               static_cast<int>(col_params_->group.device_names.size()));
      // Id of this device is at rank position in first subdiv perm.
      int my_device_id =
          col_params_->instance.impl_details.subdiv_permutations[0][rank];
      col_params_->default_rank = my_device_id;
      // Set rank for all other subdivs by finding that device_id.
      for (int sdi = 0; sdi < num_subdivs; ++sdi) {
        for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
                                                 .subdiv_permutations[sdi]
                                                 .size());
             ++r) {
          if (my_device_id ==
              col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
            col_params_->subdiv_rank[sdi] = r;
            break;
          }
        }
      }
    }

    ~DeviceInstance() { col_params_->Unref(); }

    void InitTensor(DataType dtype, const TensorShape& shape,
                    const std::function<void(Tensor*)>& init_f) {
      tensor_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      if (device_type_ == DEVICE_CPU) {
        init_f(&tensor_);
      } else if (device_type_ == DEVICE_GPU) {
        Tensor cpu_tensor(dtype, shape);
        init_f(&cpu_tensor);
        auto* dev_info = device_->tensorflow_gpu_device_info();
        CHECK(dev_info);
        CHECK(dev_info->default_context
                  ->CopyCPUTensorToDeviceSync(&cpu_tensor, device_, &tensor_)
                  .ok());
      } else {
        LOG(FATAL) << "Unsupported device_type " << device_type_;
      }
    }

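    // Assembles an OpKernelContext around tensor_, instantiates a RingReducer,
    // runs the reduction in place, and records the resulting status.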
    void DoReduce() {
      merge_op_ =
          GetAdd(col_params_->instance.data_type, device_type_, device_);
      final_op_ =
          GetDiv(col_params_->instance.data_type, device_type_, device_);
      col_params_->merge_op = merge_op_.get();
      col_params_->final_op = final_op_.get();

      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      op_params.step_id = kStepId;
      op_params.device = device_;
      op_params.cancellation_manager = &parent_->cancellation_manager_;
      gtl::InlinedVector<TensorValue, 4> inputs;
      inputs.push_back(TensorValue(&tensor_));
      op_params.inputs = &inputs;
      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
          {AllocatorAttributes()});
      op_params.input_alloc_attrs = &input_aa;
      DeviceContext* dev_ctx = nullptr;
      auto* dev_info = device_->tensorflow_gpu_device_info();
      if (dev_info) {
        dev_ctx = dev_info->default_context;
        dev_ctx->Ref();
      } else {
        dev_ctx = new DeviceContext;
      }
      op_params.op_device_context = dev_ctx;
      int forward_from = 0;
      op_params.forward_from_array = &forward_from;
      AllocatorAttributes generic_alloc_attr;
      op_params.output_attr_array = &generic_alloc_attr;
      std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce(
          *col_params_, &tensor_, DEVICE_CPU, device_);
      op_params.op_kernel = op.get();
      OpKernelContext ctx(&op_params, 1);

      // We never actually execute the kernel, so we need to do the output
      // allocation it would do, ourselves.
      Tensor* output_tensor_ptr = nullptr;
      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
                                                       &output_tensor_ptr));
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));

      // Prepare a RingReducer instance.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      RingReducer* reducer = new RingReducer;
      core::ScopedUnref unref(reducer);
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, /*nccl_communicator*/ nullptr,
          parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
          kStepId, &tensor_, &tensor_);
      TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));

      // Run the all-reduce.
      reducer->Run([this](Status s) { status_ = s; });
      if (status_.ok()) {
        CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
      }

      dev_ctx->Unref();
    }

    const Tensor& tensor() { return tensor_; }

    RingReducerTest* parent_;
    string dev_name_;
    DeviceType device_type_;
    int rank_;
    Tensor tensor_;
    Device* device_;
    CollectiveParams* col_params_;
    std::unique_ptr<OpKernel> merge_op_;
    std::unique_ptr<OpKernel> final_op_;
    std::unique_ptr<CollectiveAdapter> ca_;
    std::unique_ptr<OpKernelContext> ctx_;
    Status status_;
  };

  bool stop_ = false;
  DeviceType device_type_;
  TestCollectiveExecutorMgr col_exec_mgr_;
  CollectiveExecutor* col_exec_;
  CollectiveRemoteAccessLocal* rma_;
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  std::vector<DeviceInstance*> instances_;
  CollectiveParams* col_params_;
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
  std::unique_ptr<string> gpu_ring_order_;
  mutex mu_;
  int32 reduce_counter_ TF_GUARDED_BY(mu_) = 0;
  CancellationManager cancellation_manager_;
};

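// Builds a CollectiveParams for num_tasks workers with num_devs_per_task GPUs
// each, used by the parameter-initialization tests below.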
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
                                        const int num_tasks) {
  auto cp = new CollectiveParams();
  const int kNumDevs = num_devs_per_task * num_tasks;
  cp->group.group_key = 1;
  cp->group.group_size = kNumDevs;
  cp->group.device_type = DeviceType("GPU");
  cp->group.num_tasks = num_tasks;
  cp->instance.instance_key = 3;
  cp->instance.type = REDUCTION_COLLECTIVE;
  cp->instance.data_type = DataType(DT_FLOAT);
  cp->instance.shape = TensorShape({kNumDevs});
  cp->instance.impl_details.collective_name = "RingReduce";
  cp->instance.impl_details.subdiv_offsets.push_back(0);
  cp->is_source = false;
  for (int i = 0; i < kNumDevs; ++i) {
    int task_id = i / num_devs_per_task;
    int dev_id = i % num_devs_per_task;
    string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
    string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
    cp->group.task_names.push_back(task_name);
    cp->group.device_names.push_back(device_name);
  }
  return cp;
}

TEST_F(RingReducerTest, InitializeParams) {
  const int kNumDevsPerTask = 8;
  const int kNumTasks = 3;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets = {0, 4};
  RunSubdivPermsTest(cp,
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {4, 5, 6,  7,  0,  1,  2,  3,  12, 13, 14, 15,
                       8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
                     {0, 4});

  cp->instance.impl_details.subdiv_offsets = {0, -4};
  RunSubdivPermsTest(cp,
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8,
                       15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
                     {0, 3});

  cp->default_rank = 3;
  cp->instance.impl_details.subdiv_offsets = {3, -3};
  RunSubdivPermsTest(cp,
                     {{3,  4, 5, 6,  7,  0,  1,  2,  11, 12, 13, 14,
                       15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
                      {4, 3,  2,  1,  0,  7,  6,  5,  12, 11, 10, 9,
                       8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
                     {0, 1});
}

TEST_F(RingReducerTest, AutomaticSubdivs) {
  const int kNumDevsPerTask = 8;
  const int kNumTasks = 3;
  const int kNumDevs = kNumDevsPerTask * kNumTasks;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  // Test automatic generation of subdiv offsets.
  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets.clear();
  RunSubdivPermsTest(cp, {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
                     {0});

  // Set shape so that with 2 subdivs chunk_size is 3 MiB.  This should cause 2
  // offsets, {0, -4}, to be generated.
  {
    int num_subdivs = 2;
    int num_chunks = kNumDevs * num_subdivs;
    size_t chunk_size = 3 * 1048576;  // 3 MiB
    size_t tensor_size = chunk_size * num_chunks;
    cp->instance.shape =
        TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
  }
  cp->instance.impl_details.subdiv_offsets.clear();
  RunSubdivPermsTest(cp,
                     {{0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {3,  2,  1,  0,  7,  6,  5,  4,  11, 10, 9,  8,
                       15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
                     {0, 3});
}

TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
  const int kNumDevsPerTask = 1;
  const int kNumTasks = 4;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets.clear();
  cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
  RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
}

// TODO(b/113171733): change to use TEST_P.
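// DEF_TEST(B, T, W, D, S, L, A) instantiates a test with data type B on device
// type T, W workers, D devices per worker, S subdivisions, tensor length L,
// and a deliberate failure injected after A remote-access actions (0 = none).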
#define DEF_TEST(B, T, W, D, S, L, A)                                         \
  TEST_F(RingReducerTest,                                                     \
         DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
    DataType dtype = DT_##B;                                                  \
    switch (dtype) {                                                          \
      case DT_FLOAT: {                                                        \
        RunTest<float>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      case DT_DOUBLE: {                                                       \
        RunTest<double>(dtype, DEVICE_##T, W, D, S, L, A);                    \
      } break;                                                                \
      case DT_INT32: {                                                        \
        RunTest<int32>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      case DT_INT64: {                                                        \
        RunTest<int64>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      default:                                                                \
        LOG(FATAL) << "Unimplemented";                                        \
    }                                                                         \
  }

#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
// Success tests
DEF_TEST(FLOAT, CPU, 1, 2, 1, 1, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 2, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 8, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 16, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(FLOAT, CPU, 2, 4, 1, 128, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 1001, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 4096, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 3, 1045991, 0)
DEF_TEST(FLOAT, CPU, 4, 4, 4, 1045991, 0)
DEF_TEST(DOUBLE, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(DOUBLE, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(INT32, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT32, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)

// Failure tests
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 1)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// GPU tests.  So long as the device names are all in a single task we bypass
// inter-worker routing code and can fake multiple GPUs with a single GPU,
// from the perspective of the RingReducer logic.  So these tests are all
// single-worker.
DEF_TEST(FLOAT, GPU, 1, 2, 1, 1, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 2, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 8, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 16, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 1001, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1, 1001, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1, 4096, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 3, 4095, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 3, 1045991, 0)
DEF_TEST(FLOAT, GPU, 1, 4, 4, 1045991, 0)
DEF_TEST(DOUBLE, GPU, 1, 2, 1, 1001, 0)
// INT32 values are never on the GPU.
// DEF_TEST(INT32, GPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, GPU, 1, 2, 1, 1001, 0)

// Failure tests
DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
#endif

}  // namespace tensorflow