/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {

namespace {
bool IsCancelled(CancellationManager* cancel_mgr) {
  return cancel_mgr != nullptr &&
         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
}
}  // namespace

/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64_t elt_bytes, int64_t total_elts,
                                          int64_t num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64_t base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple of the
  // various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
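  // For example, assuming EIGEN_MAX_ALIGN_BYTES == 64 and float data
  // (elt_bytes == 4): with total_elts == 10 and num_chunks == 3,
  // base_chunk_elts starts at 4 (16 bytes), diff == 48, and the padded chunk
  // size becomes 16 elements (64 bytes), so every chunk starts on an aligned
  // boundary.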
  int64_t chunk_bytes = base_chunk_elts * elt_bytes;
  int64_t diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64_t num_chunks,
                        Allocator* allocator, bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
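  // For example, with total_elts_ == 10 and chunk_elts_ == 4 the chunk sizes
  // are 4, 4 and 2; any chunk past the end of the data is empty.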
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64_t start = chunk_elts_ * i;
    int64_t num_elts = ChunkElts(i);
    // If this chunk is empty, the prior chunk might also be short, so always
    // take an empty slice from the front of the tensor to avoid tripping an
    // illegal-offset check elsewhere.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64_t base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_BFLOAT16:
      return new CollectiveAdapterImpl<Eigen::bfloat16>(
          output, num_chunks, allocator, align_chunks);
      break;
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}
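// Illustrative usage sketch (variable names here are hypothetical): a
// collective implementation typically wraps its output tensor once and then
// works on aligned per-chunk views, e.g.
//
//   std::unique_ptr<CollectiveAdapter> adapter(MakeCollectiveAdapter(
//       &output_tensor, group_size * num_subdivs, device_allocator,
//       /*align_chunks=*/true));
//   Tensor chunk = adapter->ChunkAlias(chunk_index);
//
// The adapter temporarily flattens the tensor; ConsumeFinalValue() restores
// the original shape and hands the result back.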

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  Status status;
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
              << s;
      return;
    }
    status_ = StatusGroup::MakeDerived(Status(
        s.code(),
        absl::StrCat(
            "Collective ops are aborted by: ", s.error_message(),
            "\nThe error could be from a previous operation. Restart your "
            "program to reset.")));
    status = status_;
  }
  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
  cem_->GetParamResolver()->StartAbort(status);
  remote_access_->StartAbort(status);
  if (cem_->GetNcclCommunicator() != nullptr) {
    cem_->GetNcclCommunicator()->StartAbort(status);
  }
}

Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  // If the collective executor is already aborted, use the aborted status,
  // which is more likely the actual error than an artifact of the abort.
  if (!status_.ok()) {
    VLOG(2) << "Overriding status with collective ops executor status. "
               "Original status: "
            << s;
    return status_;
  }
  return s;
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams* col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // See CompleteParamsAsync() for how done() and the timeout callback
  // interact.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
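  // Whichever of done_safe and the timeout callback below wins the exchange on
  // is_callback_called invokes done() exactly once; the other becomes a no-op.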
  auto timeout_microseconds = static_cast<int64>(
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done] {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(error::DEADLINE_EXCEEDED,
                          "Collective has timed out during execution.");
            StartAbort(status);
            done(status);
          }
        });
  }

  Tensor* output = ctx->mutable_output(0);
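  // Only some collective types consume an input tensor; a broadcast reads
  // ctx->input(0) only on the source rank.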
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  core::ScopedUnref unref(col_impl);
  auto col_ctx = std::make_shared<CollectiveContext>(
      this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
      col_params, exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
  col_impl->Ref();
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
  RunClosure([col_impl, col_ctx, done_safe, ctx,
              context_id = producer.GetContextId()]() {
    core::ScopedUnref unref(col_impl);
    profiler::TraceMeConsumer consumer(
        [ctx] {
          string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
                                          ctx->op_kernel().type_string_view());
          return profiler::TraceMeEncode(std::move(op),
                                         {{"id", ctx->step_id()}});
        },
        context_id);
    col_impl->Ref();
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      core::ScopedUnref unref(col_impl);
      done_safe(s);
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both alive. After done()
  // is called, CollectiveExecutorMgr may be destructed and we don't have a way
  // to keep it without making the ownerships more complicated. Therefore if
  // the timeout callback executes, done_safe will become a no-op and the
  // timeout callback is responsible for invoking done() at the end.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto trace_id =
      profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
                    done](const Status& s) {
    profiler::TraceMe::ActivityEnd(trace_id);
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds =
      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done]() {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(
                error::DEADLINE_EXCEEDED,
                "Collective has timed out waiting for other workers.");
            StartAbort(status);
            done(status);
          }
        });
  }
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                done_safe);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_BFLOAT16:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_BFLOAT16 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32_t instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}
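
// The count for an instance is initialized to the number of devices in the
// local task; each call below decrements it, and when it reaches zero any
// collectives blocked in WaitForDependencies on that instance are woken up.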
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.group.task_names[col_params.default_rank];
    const int32_t num_devices =
        col_params.group.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow