/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"

#include <algorithm>

#include "absl/memory/memory.h"
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/test_collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/threadpool_device.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {

// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
class FailTestRMA : public CollectiveRemoteAccessLocal {
 public:
  FailTestRMA(const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
              int64 step_id, int fail_after)
      : CollectiveRemoteAccessLocal(dev_mgr, dev_resolver, step_id),
        fail_after_(fail_after) {}

  bool MaybeFail(const StatusCallback& done) {
    bool fail_now = false;
    {
      mutex_lock l(mu_);
      if (fail_after_ > 0) {
        fail_now = (--fail_after_ == 0);
      }
    }
    if (fail_now) {
      done(errors::Internal("Deliberate failure"));
      return true;
    }
    return false;
  }

  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality,
                    int dev_to_dev_stream_index,
                    CancellationManager* cancellation_manager,
                    const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, dev_to_dev_stream_index,
        cancellation_manager, done);
  }

  void PostToPeer(const string& peer_device, const string& peer_task,
                  const string& key, Device* from_device,
                  DeviceContext* from_device_ctx,
                  const AllocatorAttributes& from_alloc_attr,
                  const Tensor* from_tensor,
                  const DeviceLocality& client_locality,
                  CancellationManager* cancellation_manager,
                  const StatusCallback& done) override {
    if (MaybeFail(done)) return;
    CollectiveRemoteAccessLocal::PostToPeer(
        peer_device, peer_task, key, from_device, from_device_ctx,
        from_alloc_attr, from_tensor, client_locality, cancellation_manager,
        done);
  }

  mutex mu_;
  int fail_after_ TF_GUARDED_BY(mu_);
};

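// Builds an OpKernel for `node` on `device`; the test fails fatally if
// kernel construction does not succeed.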
std::unique_ptr<OpKernel> GetKernel(const NodeDef& node,
                                    const DeviceType& device_type,
                                    DeviceBase* device) {
  Status status;
  std::unique_ptr<OpKernel> k = CreateOpKernel(
      device_type, device, device->GetAllocator(AllocatorAttributes()), node,
      TF_GRAPH_DEF_VERSION, &status);
  if (!status.ok()) {
    LOG(FATAL) << status;
  }
  return k;
}

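// Returns a two-input Add kernel for `dtype`; used below as the reduction
// merge_op.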
std::unique_ptr<OpKernel> GetAdd(DataType dtype, const DeviceType& device_type,
                                 DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("add_node", "Add");
  TF_CHECK_OK(builder.Attr("T", dtype)
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(dtype))
                  .Finalize(&node_def));
  return GetKernel(node_def, device_type, device);
}

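// Returns a two-input Div kernel for `dtype`; used below as the final_op to
// divide the summed values by the group size.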
std::unique_ptr<OpKernel> GetDiv(DataType dtype, const DeviceType& device_type,
                                 DeviceBase* device) {
  NodeDef node_def;
  NodeDefBuilder builder("div_node", "Div");
  TF_CHECK_OK(builder.Attr("T", dtype)
                  .Input(FakeInput(dtype))
                  .Input(FakeInput(dtype))
                  .Finalize(&node_def));
  return GetKernel(node_def, device_type, device);
}

static int64 kStepId = 123;

class RingReducerTest : public ::testing::Test {
 protected:
  RingReducerTest()
      : device_type_(DEVICE_CPU), col_exec_(nullptr), col_params_(nullptr) {}

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
  void InitGPUDevices() {
    auto device_factory = DeviceFactory::GetFactory("GPU");
    CHECK(device_factory);
    SessionOptions options;
    Status s = device_factory->CreateDevices(
        options, "/job:worker/replica:0/task:0", &gpu_devices_);
    CHECK(s.ok());
  }
#endif

  ~RingReducerTest() override {
    stop_ = true;
    for (auto i : instances_) delete i;
    if (col_exec_) col_exec_->Unref();
    if (col_params_) col_params_->Unref();
  }

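  // Builds the fixture for a ring of num_workers * num_devices devices: the
  // device set and DeviceMgr, a FailTestRMA that injects an error after
  // `fail_after` actions, a BaseCollectiveExecutor, shared CollectiveParams
  // (including subdiv offsets and permutations), and one DeviceInstance per
  // rank.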
  void Init(int num_workers, int num_devices, DataType dtype,
            const DeviceType& device_type, int num_subdivs, int fail_after) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
    InitGPUDevices();
#endif
    device_type_ = device_type;
    std::vector<std::unique_ptr<Device>> local_devices;
    SessionOptions sess_opts;
    sess_opts.env = Env::Default();
    Bytes mem_limit(4 << 20);
    DeviceLocality dev_locality;
    for (int wi = 0; wi < num_workers; ++wi) {
      for (int di = 0; di < num_devices; ++di) {
        if (device_type == DEVICE_CPU) {
          string dev_name =
              strings::StrCat("/job:worker/replica:0/task:", wi, "/cpu:", di);
          local_devices.push_back(absl::make_unique<ThreadPoolDevice>(
              sess_opts, dev_name, mem_limit, dev_locality, cpu_allocator()));
        } else if (device_type == DEVICE_GPU && !gpu_devices_.empty()) {
          int dev_idx = (wi * num_devices) + di;
          if (dev_idx >= static_cast<int>(gpu_devices_.size())) {
            LOG(INFO) << "dev_mgr has access to fewer GPUs than requested; "
                         "reusing one GPU for more than one ring node.";
          } else {
            local_devices.push_back(std::move(gpu_devices_[dev_idx]));
          }
        } else {
          LOG(FATAL) << "Unsupported device_type " << device_type;
        }
      }
    }
    if (!dev_mgr_ || device_type == DEVICE_CPU) {
      LOG(INFO) << "resetting dev_mgr for " << local_devices.size()
                << " devices: ";
      dev_mgr_ = absl::make_unique<StaticDeviceMgr>(std::move(local_devices));
    }
    if (!gpu_ring_order_) {
      gpu_ring_order_ = absl::make_unique<string>();
    }
    dev_resolver_ = absl::make_unique<DeviceResolverLocal>(dev_mgr_.get());
    work_queue_ = std::make_shared<UnboundedWorkQueue>(Env::Default(), "test");
    rma_ = new FailTestRMA(dev_mgr_.get(), dev_resolver_.get(), kStepId,
                           fail_after);
    col_exec_ = new BaseCollectiveExecutor(&col_exec_mgr_, rma_, kStepId,
                                           dev_mgr_.get(),
                                           gpu_ring_order_.get(), work_queue_);
    col_params_ = new CollectiveParams();
    col_params_->name = "test_collective";
    static const int kGroupKey = 5;
    col_params_->group.group_key = kGroupKey;
    col_params_->group.device_type = device_type;
    col_params_->group.group_size = num_workers * num_devices;
    static const int kInstanceKey = 17;
    col_params_->instance.instance_key = kInstanceKey;
    col_params_->instance.impl_details.subdiv_offsets.clear();
    col_params_->instance.type = REDUCTION_COLLECTIVE;
    col_params_->instance.impl_details.collective_name = "RingReduce";
    col_params_->instance.data_type = dtype;
    col_params_->instance.impl_details.subdiv_permutations.resize(num_subdivs);
    col_params_->subdiv_rank.resize(num_subdivs);
    int subdiv_stride = num_devices / num_subdivs;
    for (int sdi = 0; sdi < num_subdivs; ++sdi) {
      col_params_->instance.impl_details.subdiv_offsets.push_back(
          sdi * subdiv_stride);
      col_params_->subdiv_rank[sdi] = sdi * subdiv_stride;
    }

    // Set up a local device ring order that's not just 0,1,2...
    std::vector<int> local_ring_order;
    for (int di = 0; di < num_devices; ++di) {
      local_ring_order.push_back(di);
    }
    for (int di = 0; di < num_devices; ++di) {
      bool is_odd = ((di % 2) == 1);
      int other = (di + (is_odd ? 7 : 3)) % num_devices;
      if (di == other) continue;
      iter_swap(local_ring_order.begin() + di,
                local_ring_order.begin() + other);
    }
    string lro_buf;
    for (auto d : local_ring_order) strings::StrAppend(&lro_buf, d, ", ");
    VLOG(1) << "local_ring_order " << lro_buf;

    // Set up all of the fake device contexts.
    for (int wi = 0; wi < num_workers; ++wi) {
      string task_name = strings::StrCat("/job:worker/replica:0/task:", wi);
      col_params_->group.num_devices_per_task[task_name] = num_devices;
      for (int di = 0; di < num_devices; ++di) {
        string dev_name = strings::StrCat(task_name, "/cpu:", di);
        if (device_type == DEVICE_GPU) {
          dev_name =
              strings::StrCat(task_name, "/gpu:", di % gpu_devices_.size());
        }
        col_params_->group.device_names.push_back(dev_name);
        col_params_->group.task_names.push_back(task_name);
        // Normally each device would set is_local from its own perspective,
        // but this test runs in a single process so is_local is always true.
        col_params_->task.is_local.push_back(true);
        for (int sdi = 0; sdi < num_subdivs; ++sdi) {
          int rotated_di =
              (di + col_params_->instance.impl_details.subdiv_offsets[sdi]) %
              num_devices;
          col_params_->instance.impl_details.subdiv_permutations[sdi].push_back(
              wi * num_devices + local_ring_order[rotated_di]);
        }
      }
    }
    for (int wi = 0; wi < num_workers; ++wi) {
      for (int di = 0; di < num_devices; ++di) {
        int rank = wi * num_devices + di;
        instances_.push_back(new DeviceInstance(
            rank, col_params_->group.device_names[rank], device_type_, this));
      }
    }
  }

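  // Launches DoReduce on every DeviceInstance in parallel and waits until all
  // of them have finished, or until the destructor sets stop_.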
  void Reduce(int fail_after) {
    std::atomic<int> done(0);
    for (auto di : instances_) {
      SchedClosure([di, &done] {
        di->DoReduce();
        ++done;
      });
      if (fail_after > 0) {
        // Stagger the op execution starts.
        Env::Default()->SleepForMicroseconds(100);
      }
    }
    while (done < static_cast<int>(instances_.size())) {
      if (stop_) break;
      Env::Default()->SleepForMicroseconds(1000);
    }
  }

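  // End-to-end driver: initializes the ring, fills each device's input
  // tensor, runs the reduction, then either checks that every device reported
  // the injected error (fail_after > 0) or verifies the reduced mean values.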
  template <typename T>
  void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
               int num_devices, int num_subdivs, int tensor_len,
               int fail_after) {
    Init(num_workers, num_devices, dtype, device_type, num_subdivs, fail_after);
    std::vector<T> expected(tensor_len, 0.0);
    for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
      DeviceInstance* instance = instances_[di];
      instance->InitTensor(
          dtype, TensorShape({tensor_len}), [&expected, dtype, di](Tensor* t) {
            for (size_t i = 0; i < t->NumElements(); ++i) {
              // The cast is necessary to prevent clang-tidy from insisting
              // that a faster non-open source function be substituted.
              float value = pow(10, static_cast<double>(di)) * i;
              if (dtype == DT_INT32 || dtype == DT_INT64) {
                value = di * 10 + i;
              }
              t->flat<T>()(i) = static_cast<T>(value);
              expected[i] += value;
            }
          });
    }
    Reduce(fail_after);
    if (fail_after > 0) {
      // Confirm that every device terminated with the expected error status.
      for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
        EXPECT_NE(
            instances_[di]->status_.error_message().find("Deliberate failure"),
            string::npos);
      }
    } else {
      // Confirm that every device computed the same correct reduction value.
      for (int i = 0; i < tensor_len; ++i) {
        expected[i] /= (num_workers * num_devices);
      }
      for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
        TF_EXPECT_OK(instances_[di]->status_);
        Tensor* inst = &instances_[di]->tensor_;
        CHECK(inst);
        Tensor actual(dtype, TensorShape({tensor_len}));
        if (device_type_ == DEVICE_CPU) {
          CHECK(actual.CopyFrom(*inst, inst->shape()));
          VLOG(1) << "actual " << actual.SummarizeValue(100);
        } else if (device_type_ == DEVICE_GPU) {
          Device* dev = instances_[di]->device_;
          auto* dev_info = dev->tensorflow_gpu_device_info();
          CHECK(dev_info);
          CHECK(dev_info->default_context
                    ->CopyDeviceTensorToCPUSync(inst, "" /*tensor_name*/, dev,
                                                &actual)
                    .ok());
        }

        auto alias = actual.template unaligned_flat<T>();
        for (int i = 0; i < tensor_len; ++i) {
          switch (dtype) {
            case DT_FLOAT:
              EXPECT_FLOAT_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            case DT_DOUBLE:
              EXPECT_DOUBLE_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            case DT_INT32:
            case DT_INT64:
              EXPECT_EQ(expected[i], alias(i))
                  << "Mismatch at device " << di << " index " << i;
              break;
            default:
              LOG(FATAL) << "unimplemented";
          }
        }
      }
    }
  }

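  // Builds a CollectiveReduce op kernel whose attrs mirror `params`; each
  // call generates a unique node name via reduce_counter_.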
  std::unique_ptr<OpKernel> GetCollectiveReduce(const CollectiveParams& params,
                                                Tensor* input,
                                                const DeviceType& device_type,
                                                DeviceBase* device) {
    mutex_lock l(mu_);
    NodeDef node_def;
    NodeDefBuilder builder(
        strings::StrCat("collective_reduce_", reduce_counter_++),
        "CollectiveReduce");
    TF_CHECK_OK(
        builder.Attr("T", params.instance.data_type)
            .Attr("merge_op", "Add")
            .Attr("final_op", "Id")
            .Attr("group_size", params.group.group_size)
            .Attr("group_key", params.group.group_key)
            .Attr("instance_key", params.instance.instance_key)
            .Attr("subdiv_offsets", params.instance.impl_details.subdiv_offsets)
            .Input(FakeInput(params.instance.data_type))
            .Finalize(&node_def));
    return GetKernel(node_def, device_type, device);
  }

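  // Runs RingReducer::InitializeCollectiveParams on `cp` and verifies the
  // generated subdiv permutations and subdiv ranks against expectations.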
  void RunSubdivPermsTest(
      CollectiveParams* cp,
      const std::vector<std::vector<int>>& expected_subdiv_perms,
      const std::vector<int>& expected_subdiv_rank) {
    col_exec_ = nullptr;
    cp->instance.impl_details.subdiv_permutations.clear();
    cp->subdiv_rank.clear();
    // Create a stub ring reducer only for testing param initialization.
    RingReducer* reducer = new RingReducer;
    core::ScopedUnref unref(reducer);
    TF_CHECK_OK(reducer->InitializeCollectiveParams(cp));
    EXPECT_EQ(expected_subdiv_perms,
              cp->instance.impl_details.subdiv_permutations);
    EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
    reducer->group_size_tensor_ready_.Notify();  // To unblock destructor.
  }

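  // Models a single ring participant: owns that device's copy of the
  // CollectiveParams, its input/output tensor, and the Status produced by the
  // reducer run on that device.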
  class DeviceInstance {
   public:
    DeviceInstance(int rank, const string& dev_name,
                   const DeviceType& device_type, RingReducerTest* parent)
        : parent_(parent),
          dev_name_(dev_name),
          device_type_(device_type),
          rank_(rank),
          col_params_(new CollectiveParams()) {
      TF_CHECK_OK(parent_->dev_mgr_->LookupDevice(dev_name, &device_))
          << "Couldn't find device " << dev_name
          << " existing devices: " << parent_->dev_mgr_->DebugString();
      col_params_->name = parent_->col_params_->name;
      col_params_->group = parent_->col_params_->group;
      col_params_->instance = parent->col_params_->instance;
      col_params_->task.is_local = parent_->col_params_->task.is_local;
      col_params_->subdiv_rank = parent_->col_params_->subdiv_rank;

      int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
      int group_size = col_params_->group.group_size;
      CHECK_EQ(group_size,
               static_cast<int>(col_params_->group.device_names.size()));
      // The id of this device is at position `rank` in the first subdiv
      // permutation.
      int my_device_id =
          col_params_->instance.impl_details.subdiv_permutations[0][rank];
      col_params_->default_rank = my_device_id;
      // Set rank for all other subdivs by finding that device_id.
      for (int sdi = 0; sdi < num_subdivs; ++sdi) {
        for (int r = 0; r < static_cast<int>(col_params_->instance.impl_details
                                                 .subdiv_permutations[sdi]
                                                 .size());
             ++r) {
          if (my_device_id ==
              col_params_->instance.impl_details.subdiv_permutations[sdi][r]) {
            col_params_->subdiv_rank[sdi] = r;
            break;
          }
        }
      }
    }

    ~DeviceInstance() { col_params_->Unref(); }

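    // Fills tensor_ via init_f; for GPU devices the values are written into a
    // host tensor first and then copied to the device.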
    void InitTensor(DataType dtype, const TensorShape& shape,
                    const std::function<void(Tensor*)>& init_f) {
      tensor_ =
          Tensor(device_->GetAllocator(AllocatorAttributes()), dtype, shape);
      if (device_type_ == DEVICE_CPU) {
        init_f(&tensor_);
      } else if (device_type_ == DEVICE_GPU) {
        Tensor cpu_tensor(dtype, shape);
        init_f(&cpu_tensor);
        auto* dev_info = device_->tensorflow_gpu_device_info();
        CHECK(dev_info);
        CHECK(dev_info->default_context
                  ->CopyCPUTensorToDeviceSync(&cpu_tensor, device_, &tensor_)
                  .ok());
      } else {
        LOG(FATAL) << "Unsupported device_type " << device_type_;
      }
    }

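    // Runs one ring-reduce on this device: builds the merge/final op kernels
    // and an OpKernelContext for a CollectiveReduce node, then executes a
    // RingReducer in place on tensor_, recording the final Status.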
    void DoReduce() {
      merge_op_ =
          GetAdd(col_params_->instance.data_type, device_type_, device_);
      final_op_ =
          GetDiv(col_params_->instance.data_type, device_type_, device_);
      col_params_->merge_op = merge_op_.get();
      col_params_->final_op = final_op_.get();

      // Prepare an OpKernelContext.
      OpKernelContext::Params op_params;
      op_params.step_id = kStepId;
      op_params.device = device_;
      op_params.cancellation_manager = &parent_->cancellation_manager_;
      gtl::InlinedVector<TensorValue, 4> inputs;
      inputs.push_back(TensorValue(&tensor_));
      op_params.inputs = &inputs;
      gtl::InlinedVector<AllocatorAttributes, 4> input_aa(
          {AllocatorAttributes()});
      op_params.input_alloc_attrs = &input_aa;
      DeviceContext* dev_ctx = nullptr;
      auto* dev_info = device_->tensorflow_gpu_device_info();
      if (dev_info) {
        dev_ctx = dev_info->default_context;
        dev_ctx->Ref();
      } else {
        dev_ctx = new DeviceContext;
      }
      op_params.op_device_context = dev_ctx;
      int forward_from = 0;
      op_params.forward_from_array = &forward_from;
      AllocatorAttributes generic_alloc_attr;
      op_params.output_attr_array = &generic_alloc_attr;
      std::unique_ptr<OpKernel> op = parent_->GetCollectiveReduce(
          *col_params_, &tensor_, DEVICE_CPU, device_);
      op_params.op_kernel = op.get();
      OpKernelContext ctx(&op_params, 1);

      // We never actually execute the kernel, so we have to do the output
      // allocation that it would normally do ourselves.
      Tensor* output_tensor_ptr = nullptr;
      TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
                                                       &output_tensor_ptr));
      CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));

      // Prepare a RingReducer instance.
      string exec_key =
          strings::StrCat(col_params_->instance.instance_key, ":0:0");
      RingReducer* reducer = new RingReducer;
      core::ScopedUnref unref(reducer);
      auto col_ctx = std::make_shared<CollectiveContext>(
          parent_->col_exec_, /*nccl_communicator*/ nullptr,
          parent_->dev_mgr_.get(), &ctx, &op_params, col_params_, exec_key,
          kStepId, &tensor_, &tensor_);
      TF_CHECK_OK(reducer->InitializeCollectiveContext(col_ctx));

      // Run the all-reduce.
      reducer->Run([this](Status s) { status_ = s; });
      if (status_.ok()) {
        CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
      }

      dev_ctx->Unref();
    }

    const Tensor& tensor() { return tensor_; }

    RingReducerTest* parent_;
    string dev_name_;
    DeviceType device_type_;
    int rank_;
    Tensor tensor_;
    Device* device_;
    CollectiveParams* col_params_;
    std::unique_ptr<OpKernel> merge_op_;
    std::unique_ptr<OpKernel> final_op_;
    std::unique_ptr<CollectiveAdapter> ca_;
    std::unique_ptr<OpKernelContext> ctx_;
    Status status_;
  };

  bool stop_ = false;
  DeviceType device_type_;
  TestCollectiveExecutorMgr col_exec_mgr_;
  CollectiveExecutor* col_exec_;
  CollectiveRemoteAccessLocal* rma_;
  std::unique_ptr<DeviceResolverLocal> dev_resolver_;
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  std::vector<DeviceInstance*> instances_;
  CollectiveParams* col_params_;
  std::vector<std::unique_ptr<tensorflow::Device>> gpu_devices_;
  std::unique_ptr<tensorflow::DeviceMgr> dev_mgr_;
  std::unique_ptr<string> gpu_ring_order_;
  mutex mu_;
  int32 reduce_counter_ TF_GUARDED_BY(mu_) = 0;
  CancellationManager cancellation_manager_;
};

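// Builds CollectiveParams for a GPU group spanning num_tasks tasks with
// num_devs_per_task devices each, as used by the parameter-initialization
// tests below.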
CollectiveParams* SetUpCollectiveParams(const int num_devs_per_task,
                                        const int num_tasks) {
  auto cp = new CollectiveParams();
  const int kNumDevs = num_devs_per_task * num_tasks;
  cp->group.group_key = 1;
  cp->group.group_size = kNumDevs;
  cp->group.device_type = DeviceType("GPU");
  cp->group.num_tasks = num_tasks;
  cp->instance.instance_key = 3;
  cp->instance.type = REDUCTION_COLLECTIVE;
  cp->instance.data_type = DataType(DT_FLOAT);
  cp->instance.shape = TensorShape({kNumDevs});
  cp->instance.impl_details.collective_name = "RingReduce";
  cp->instance.impl_details.subdiv_offsets.push_back(0);
  cp->is_source = false;
  for (int i = 0; i < kNumDevs; ++i) {
    int task_id = i / num_devs_per_task;
    int dev_id = i % num_devs_per_task;
    string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
    string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
    cp->group.task_names.push_back(task_name);
    cp->group.device_names.push_back(device_name);
  }
  return cp;
}

TEST_F(RingReducerTest, InitializeParams) {
  const int kNumDevsPerTask = 8;
  const int kNumTasks = 3;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets = {0, 4};
  RunSubdivPermsTest(cp,
                     {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
                       8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
                     {0, 4});

  cp->instance.impl_details.subdiv_offsets = {0, -4};
  RunSubdivPermsTest(cp,
                     {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
                       15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
                     {0, 3});

  cp->default_rank = 3;
  cp->instance.impl_details.subdiv_offsets = {3, -3};
  RunSubdivPermsTest(cp,
                     {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
                       15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
                      {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
                       8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
                     {0, 1});
}

TEST_F(RingReducerTest, AutomaticSubdivs) {
  const int kNumDevsPerTask = 8;
  const int kNumTasks = 3;
  const int kNumDevs = kNumDevsPerTask * kNumTasks;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  // Test automatic generation of subdiv offsets.
  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets.clear();
  RunSubdivPermsTest(cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
                           12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
                     {0});

  // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2
  // offsets, {0, -4}, to be generated.
  {
    int num_subdivs = 2;
    int num_chunks = kNumDevs * num_subdivs;
    size_t chunk_size = 3 * 1048576;  // 3 MiB
    size_t tensor_size = chunk_size * num_chunks;
    cp->instance.shape =
        TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
  }
  cp->instance.impl_details.subdiv_offsets.clear();
  RunSubdivPermsTest(cp,
                     {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
                       12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
                      {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
                       15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
                     {0, 3});
}

TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
  const int kNumDevsPerTask = 1;
  const int kNumTasks = 4;
  CollectiveParams* cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
  core::ScopedUnref unref(cp);

  cp->default_rank = 0;
  cp->instance.impl_details.subdiv_offsets.clear();
  cp->instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
  RunSubdivPermsTest(cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
}

// TODO(b/113171733): change to use TEST_P.
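// DEF_TEST(B, T, W, D, S, L, A) defines a test where B is the data type,
// T the device type, W the number of workers, D the number of devices per
// worker, S the number of subdivs, L the tensor length, and A the number of
// actions after which FailTestRMA injects an error (0 means no failure).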
#define DEF_TEST(B, T, W, D, S, L, A)                                         \
  TEST_F(RingReducerTest,                                                     \
         DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
    DataType dtype = DT_##B;                                                  \
    switch (dtype) {                                                          \
      case DT_FLOAT: {                                                        \
        RunTest<float>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      case DT_DOUBLE: {                                                       \
        RunTest<double>(dtype, DEVICE_##T, W, D, S, L, A);                    \
      } break;                                                                \
      case DT_INT32: {                                                        \
        RunTest<int32>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      case DT_INT64: {                                                        \
        RunTest<int64>(dtype, DEVICE_##T, W, D, S, L, A);                     \
      } break;                                                                \
      default:                                                                \
        LOG(FATAL) << "Unimplemented";                                        \
    }                                                                         \
  }

#if !(GOOGLE_CUDA || TENSORFLOW_USE_ROCM)
// Success tests
DEF_TEST(FLOAT, CPU, 1, 2, 1, 1, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 2, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 8, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 16, 0)
DEF_TEST(FLOAT, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(FLOAT, CPU, 2, 4, 1, 128, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 1001, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 4096, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(FLOAT, CPU, 2, 8, 3, 1045991, 0)
DEF_TEST(FLOAT, CPU, 4, 4, 4, 1045991, 0)
DEF_TEST(DOUBLE, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(DOUBLE, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(INT32, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT32, CPU, 2, 8, 3, 4095, 0)
DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)

// Failure tests
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 1)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// GPU tests. So long as the device names are all in a single task, we
// bypass inter-worker routing code and can fake multiple GPUs with a single
// GPU, from the perspective of the RingReducer logic. So these tests
// are all single-worker.
DEF_TEST(FLOAT, GPU, 1, 2, 1, 1, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 2, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 8, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 16, 0)
DEF_TEST(FLOAT, GPU, 1, 2, 1, 1001, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1, 1001, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 1, 4096, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 3, 4095, 0)
DEF_TEST(FLOAT, GPU, 1, 8, 3, 1045991, 0)
DEF_TEST(FLOAT, GPU, 1, 4, 4, 1045991, 0)
DEF_TEST(DOUBLE, GPU, 1, 2, 1, 1001, 0)
// INT32 values are never on the GPU.
// DEF_TEST(INT32, GPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, GPU, 1, 2, 1, 1001, 0)

// Failure tests
DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
#endif

}  // namespace tensorflow