1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/common_runtime/base_collective_executor.h"
16
17 #include <algorithm>
18 #include <functional>
19 #include <utility>
20
21 #include "tensorflow/core/common_runtime/copy_tensor.h"
22 #include "tensorflow/core/common_runtime/device_mgr.h"
23 #include "tensorflow/core/common_runtime/dma_helper.h"
24 #include "tensorflow/core/common_runtime/process_util.h"
25 #include "tensorflow/core/framework/allocator.h"
26 #include "tensorflow/core/framework/cancellation.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/framework/tensor.h"
29 #include "tensorflow/core/framework/tensor_shape.h"
30 #include "tensorflow/core/framework/types.h"
31 #include "tensorflow/core/framework/types.pb.h"
32 #include "tensorflow/core/lib/core/errors.h"
33 #include "tensorflow/core/lib/core/notification.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/strings/strcat.h"
36 #include "tensorflow/core/platform/macros.h"
37 #include "tensorflow/core/platform/refcount.h"
38 #include "tensorflow/core/platform/tracing.h"
39 #include "tensorflow/core/platform/types.h"
40 #include "tensorflow/core/profiler/lib/connected_traceme.h"
41 #include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
42 #include "tensorflow/core/profiler/lib/traceme.h"
43
44 #define VALUE_IN_DEBUG_STRING false
45
46 namespace tensorflow {
47
48 namespace {
IsCancelled(CancellationManager * cancel_mgr)49 bool IsCancelled(CancellationManager* cancel_mgr) {
50 return cancel_mgr != nullptr &&
51 (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
52 }
53 } // namespace
54
55 /*static*/
AlignedChunkElts(int64_t elt_bytes,int64_t total_elts,int64_t num_chunks)56 int64_t CollectiveAdapter::AlignedChunkElts(int64_t elt_bytes,
57 int64_t total_elts,
58 int64_t num_chunks) {
59 DCHECK_GT(num_chunks, 0);
60 int64_t base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
61 if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
62 if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
63 // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
64 DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
65 return base_chunk_elts;
66 }
67 // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which
68 // must be a common multiple of the various atomic data types.
69 DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
70 << "total_elts=" << total_elts << " num_chunks=" << num_chunks
71 << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
72 << " elt_bytes=" << elt_bytes;
73 // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
74 int64_t chunk_bytes = base_chunk_elts * elt_bytes;
75 int64_t diff =
76 (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
77 ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
78 : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
79 DCHECK_EQ(0, diff % elt_bytes);
80 base_chunk_elts += (diff / elt_bytes);
81 DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
82 << "total_elts=" << total_elts << " num_chunks=" << num_chunks
83 << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
84 << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
85 return base_chunk_elts;
86 }
87
88 namespace {
89 template <typename T>
90 class CollectiveAdapterImpl : public CollectiveAdapter {
91 public:
92 // Takes ownership of output and prepares to properly alias its chunks.
93 // Ownership is taken because the shape may temporarily change.
CollectiveAdapterImpl(Tensor * output,int64_t num_chunks,Allocator * allocator,bool align_chunks)94 CollectiveAdapterImpl(Tensor* output, int64_t num_chunks,
95 Allocator* allocator, bool align_chunks)
96 : output_(std::move(*output)),
97 dt_(output_.dtype()),
98 old_shape_(output_.shape()),
99 num_chunks_(num_chunks),
100 allocator_(allocator),
101 total_elts_(output_.NumElements()),
102 chunk_elts_(align_chunks
103 ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
104 : total_elts_ / num_chunks_),
105 data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
106 data_end_(data_start_ + total_elts_) {
107 if (!align_chunks) {
108 DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
109 }
110 DCHECK_GT(chunk_elts_, 0);
111 Flatten();
112 }
113
~CollectiveAdapterImpl()114 ~CollectiveAdapterImpl() override {}
115
Value() const116 const Tensor& Value() const override { return output_; }
117
118 // If necessary, flatten output.
Flatten()119 void Flatten() {
120 if (old_shape_.dims() != 1) {
121 TensorShape new_shape = TensorShape({old_shape_.num_elements()});
122 DMAHelper::UnsafeSetShape(&output_, new_shape);
123 }
124 }
125
ConsumeFinalValue(Tensor * output)126 void ConsumeFinalValue(Tensor* output) override {
127 if (old_shape_ != output_.shape()) {
128 DMAHelper::UnsafeSetShape(&output_, old_shape_);
129 }
130 *output = std::move(output_);
131 }
132
133 // Number of T elements in a particular chunk.
ChunkElts(int i) const134 inline int64_t ChunkElts(int i) const {
135 DCHECK_LT(i, num_chunks_);
136 const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
137 const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
138 return chunk_end - chunk_start;
139 }
140
ChunkBytes(int i) const141 int64_t ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }
142
143 // Returns a new Tensor that aliases the required chunk.
ChunkAlias(int i)144 Tensor ChunkAlias(int i) override {
145 int64_t start = chunk_elts_ * i;
146 int64_t num_elts = ChunkElts(i);
147 // If this chunk is empty the prior chunk might also be short
148 // so always take an empty slice from the front of the tensor
149 // to avoid an illegal offset check failure somewhere.
150 return (num_elts > 0) ? output_.Slice(start, start + num_elts)
151 : output_.Slice(0, 0);
152 }
153
TempChunk(int i) const154 Tensor TempChunk(int i) const override {
155 AllocationAttributes empty;
156 profiler::ScopedMemoryDebugAnnotation op_annotation(
157 "CollectiveAdapterImpl::TempChunk");
158 return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
159 }
160
DebugString() const161 string DebugString() const override {
162 return strings::StrCat(
163 "base addr ", reinterpret_cast<int64_t>(DMAHelper::base(&output_)),
164 " num_chunks ", num_chunks_, " total_elts ", total_elts_, " chunk_elts",
165 chunk_elts_, " value ",
166 VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
167 }
168
TBounds(const Tensor & t) const169 string TBounds(const Tensor& t) const override {
170 int64_t base_addr = reinterpret_cast<int64_t>(DMAHelper::base(&t));
171 return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
172 ")");
173 }
174
Scalar(int v) const175 Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }
176
Scalar(Allocator * a,const AllocationAttributes & attr) const177 Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
178 Tensor t(a, dt_, TensorShape({}), attr);
179 return t;
180 }
181
182 Tensor output_;
183 const DataType dt_;
184 const TensorShape old_shape_;
185 const int64_t num_chunks_;
186 Allocator* allocator_;
187 const int64_t total_elts_;
188 const int64_t chunk_elts_;
189 const T* data_start_;
190 const T* data_end_;
191 };
192
193 } // namespace
194
MakeCollectiveAdapter(Tensor * output,int num_chunks,Allocator * allocator,bool align_chunks)195 CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
196 Allocator* allocator,
197 bool align_chunks) {
198 switch (output->dtype()) {
199 case DT_BFLOAT16:
200 return new CollectiveAdapterImpl<Eigen::bfloat16>(
201 output, num_chunks, allocator, align_chunks);
202 break;
203 case DT_HALF:
204 return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
205 allocator, align_chunks);
206 break;
207 case DT_FLOAT:
208 return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
209 align_chunks);
210 break;
211 case DT_DOUBLE:
212 return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
213 align_chunks);
214 break;
215 case DT_INT32:
216 return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
217 align_chunks);
218 break;
219 case DT_INT64:
220 return new CollectiveAdapterImpl<int64_t>(output, num_chunks, allocator,
221 align_chunks);
222 break;
223 default:
224 LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
225 << " to MakeCollectiveAdapter";
226 return nullptr;
227 }
228 }
229
~BaseCollectiveExecutor()230 BaseCollectiveExecutor::~BaseCollectiveExecutor() {}
231
StartAbort(const Status & s)232 void BaseCollectiveExecutor::StartAbort(const Status& s) {
233 Status status;
234 {
235 mutex_lock l(status_mu_);
236 if (!status_.ok()) {
237 VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
238 << s;
239 return;
240 }
241 status_ = StatusGroup::MakeDerived(Status(
242 s.code(),
243 absl::StrCat(
244 "Collective ops is aborted by: ", s.error_message(),
245 "\nThe error could be from a previous operation. Restart your "
246 "program to reset.")));
247 status = status_;
248 }
249 LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
250 cem_->GetParamResolver()->StartAbort(status);
251 remote_access_->StartAbort(status);
252 if (cem_->GetNcclCommunicator() != nullptr) {
253 cem_->GetNcclCommunicator()->StartAbort(status);
254 }
255 }
256
// Chooses the status to report for an operation that finished with `s`.
//
// An OK `s` is returned unchanged. For an error, if the executor has already
// been aborted, the stored abort status is returned instead, since it is more
// likely the actual error rather than an artifact of the abort. Otherwise `s`
// itself is returned.
Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  if (status_.ok()) {
    return s;
  }
  VLOG(2) << "Overriding status with collective ops executor status. "
             "Original status: "
          << s;
  return status_;
}
271
ExecuteAsync(OpKernelContext * ctx,const CollectiveParams * col_params,const string & exec_key,StatusCallback done)272 void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
273 const CollectiveParams* col_params,
274 const string& exec_key,
275 StatusCallback done) {
276 // See CompleteParamsAsync() how done() and the timeout callback interacts.
277 const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
278 auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
279 bool called = is_callback_called->exchange(true);
280 if (!called) {
281 if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
282 // This is a collective error. Abort CollectiveExecutor so that this
283 // error can propagate to other workers.
284 StartAbort(s);
285 }
286 done(GetStatus(s));
287 }
288 };
289 auto timeout_microseconds = static_cast<int64_t>(
290 col_params->instance.impl_details.timeout_seconds * 1'000'000);
291 if (timeout_microseconds > 0) {
292 // TODO(xldrx): Share the timeout watchdog thread among collectives.
293 SchedNonBlockingClosureAfter(
294 timeout_microseconds, [this, is_callback_called, done] {
295 bool called = is_callback_called->exchange(true);
296 if (!called) {
297 Status status(error::DEADLINE_EXCEEDED,
298 "Collective has timed out during execution.");
299 StartAbort(status);
300 done(status);
301 }
302 });
303 }
304
305 Tensor* output = ctx->mutable_output(0);
306 const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
307 col_params->instance.type == GATHER_COLLECTIVE ||
308 col_params->instance.type == PERMUTE_COLLECTIVE ||
309 col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
310 (col_params->instance.type == BROADCAST_COLLECTIVE &&
311 col_params->is_source))
312 ? &ctx->input(0)
313 : nullptr;
314 CollectiveImplementationInterface* col_impl = nullptr;
315 Status status = CreateCollective(*col_params, &col_impl);
316 if (!status.ok()) {
317 done_safe(status);
318 DCHECK_EQ(nullptr, col_impl);
319 return;
320 }
321 core::ScopedUnref unref(col_impl);
322 auto col_ctx = std::make_shared<CollectiveContext>(
323 this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
324 col_params, exec_key, step_id_, input, output);
325 status = col_impl->InitializeCollectiveContext(col_ctx);
326 if (!status.ok()) {
327 done_safe(status);
328 return;
329 }
330 // Run on an unbounded work queue that can handle blocking work so as to not
331 // starve executor threads.
332 col_impl->Ref();
333 profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
334 RunClosure([col_impl, col_ctx, done_safe, ctx,
335 context_id = producer.GetContextId()]() {
336 core::ScopedUnref unref(col_impl);
337 profiler::TraceMeConsumer consumer(
338 [ctx, col_ctx] {
339 string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
340 ctx->op_kernel().type_string_view());
341 return profiler::TraceMeEncode(
342 std::move(op),
343 {{"id", ctx->step_id()},
344 {"instance_key", col_ctx->col_params->instance.instance_key},
345 {"collective", col_ctx->col_params->instance.type}});
346 },
347 context_id);
348 col_impl->Ref();
349 col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
350 core::ScopedUnref unref(col_impl);
351 done_safe(s);
352 });
353 });
354 }
355
CompleteParamsAsync(const DeviceAttributes & device,CollectiveParams * cp,CancellationManager * cancel_mgr,StatusCallback done)356 void BaseCollectiveExecutor::CompleteParamsAsync(
357 const DeviceAttributes& device, CollectiveParams* cp,
358 CancellationManager* cancel_mgr, StatusCallback done) {
359 // We need to make sure that when the timeout callback executes,
360 // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
361 // is called, CollectiveExecutorMgr may be destructed and we don't have a way
362 // to keep it without making the ownerships more complicated. Therefore if the
363 // timeout callback executes, done_safe will become a no-op and the timeout
364 // callback is responsible for invoking done() at the end.
365 const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
366 int64_t trace_id = profiler::TraceMe::ActivityStart([cp]() {
367 return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams",
368 {{"group_key", cp->group.group_key},
369 {"group_size", cp->group.group_size}});
370 });
371
372 auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
373 done](const Status& s) {
374 profiler::TraceMe::ActivityEnd(trace_id);
375 bool called = is_callback_called->exchange(true);
376 if (!called) {
377 if (!s.ok() && !IsCancelled(cancel_mgr)) {
378 // This is a collective error. Abort CollectiveExecutor so that this
379 // error can propagate to other workers.
380 StartAbort(s);
381 }
382 done(GetStatus(s));
383 }
384 };
385 auto timeout_microseconds = static_cast<int64_t>(
386 cp->instance.impl_details.timeout_seconds * 1'000'000);
387 if (timeout_microseconds > 0) {
388 // TODO(xldrx): Share the timeout watchdog thread among collectives.
389 SchedNonBlockingClosureAfter(
390 timeout_microseconds, [this, is_callback_called, done]() {
391 bool called = is_callback_called->exchange(true);
392 if (!called) {
393 Status status(
394 error::DEADLINE_EXCEEDED,
395 "Collective has timed out waiting for other workers.");
396 StartAbort(status);
397 done(status);
398 }
399 });
400 }
401 cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
402 done_safe);
403 }
404
CreateCollective(const CollectiveParams & col_params,CollectiveImplementationInterface ** col_impl)405 Status BaseCollectiveExecutor::CreateCollective(
406 const CollectiveParams& col_params,
407 CollectiveImplementationInterface** col_impl) {
408 VLOG(2) << "CreateCollective type "
409 << DataTypeString(col_params.instance.data_type) << " name "
410 << col_params.instance.impl_details.collective_name;
411 *col_impl = nullptr;
412 switch (col_params.instance.data_type) {
413 case DT_BOOL:
414 if (col_params.instance.type == BROADCAST_COLLECTIVE) {
415 return CollectiveRegistry::Lookup(
416 col_params.instance.impl_details.collective_name, col_impl);
417 } else {
418 return errors::Internal(
419 "No collective other than broadcast supports DT_BOOL");
420 }
421 case DT_INT32:
422 if (col_params.group.device_type == DEVICE_GPU &&
423 col_params.instance.type == REDUCTION_COLLECTIVE) {
424 // TODO(b/139421603): enable int32 all-reduce on GPU.
425 return errors::Internal(
426 "Collective all-reduce does not support datatype DT_INT32 on "
427 "DEVICE_GPU");
428 } else {
429 return CollectiveRegistry::Lookup(
430 col_params.instance.impl_details.collective_name, col_impl);
431 }
432 case DT_BFLOAT16:
433 if (col_params.group.device_type == DEVICE_GPU &&
434 col_params.instance.type == REDUCTION_COLLECTIVE) {
435 return errors::Internal(
436 "Collective all-reduce does not support datatype DT_BFLOAT16 on "
437 "DEVICE_GPU");
438 } else {
439 return CollectiveRegistry::Lookup(
440 col_params.instance.impl_details.collective_name, col_impl);
441 }
442 case DT_HALF:
443 case DT_FLOAT:
444 case DT_DOUBLE:
445 case DT_INT64: {
446 return CollectiveRegistry::Lookup(
447 col_params.instance.impl_details.collective_name, col_impl);
448 }
449 default:
450 return errors::Internal(
451 "CollectiveImplementation does not support datatype ",
452 DataTypeString(col_params.instance.data_type));
453 }
454 }
455
CheckDependencies(const CollectiveParams & col_params)456 bool BaseCollectiveExecutor::CheckDependencies(
457 const CollectiveParams& col_params) {
458 for (int32_t instance : col_params.instance.impl_details.dependencies) {
459 auto find_iter = launched_.find(instance);
460 if (find_iter == launched_.end() || find_iter->second != 0) {
461 VLOG(1) << "Collective " << col_params.ToString()
462 << " blocked by instance " << instance;
463 return false;
464 }
465 }
466 return true;
467 }
468
WaitForDependencies(const CollectiveParams & col_params)469 void BaseCollectiveExecutor::WaitForDependencies(
470 const CollectiveParams& col_params) {
471 mutex_lock l(launch_mu_);
472 while (!CheckDependencies(col_params)) {
473 launch_cv_.wait(l);
474 }
475 VLOG(1) << "Unblocking collective " << col_params.ToString();
476 }
477
UnblockDependencies(const CollectiveParams & col_params)478 void BaseCollectiveExecutor::UnblockDependencies(
479 const CollectiveParams& col_params) {
480 mutex_lock l(launch_mu_);
481 if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
482 const string& task_name =
483 col_params.group.members[col_params.default_rank].task;
484 const int32_t num_devices =
485 col_params.group.num_devices_per_task.at(task_name);
486 launched_[col_params.instance.instance_key] = num_devices;
487 }
488 if (--launched_[col_params.instance.instance_key] == 0) {
489 VLOG(1) << "Unblocking dependencies for collective instance "
490 << col_params.instance.instance_key;
491 launch_cv_.notify_all();
492 }
493 }
494
495 } // namespace tensorflow
496