/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/scoped_memory_debug_annotation.h"
#include "tensorflow/core/profiler/lib/traceme.h"

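// When set to true, CollectiveAdapter::DebugString() (below) includes a
// summary of the tensor's values instead of printing "<hidden>".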
#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {

namespace {
bool IsCancelled(CancellationManager* cancel_mgr) {
  return cancel_mgr != nullptr &&
         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
}
}  // namespace

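// AlignedChunkElts returns the number of elements per chunk to use when
// splitting total_elts elements of elt_bytes bytes each into num_chunks
// pieces whose byte sizes are multiples of EIGEN_MAX_ALIGN_BYTES. Worked
// example (assuming EIGEN_MAX_ALIGN_BYTES == 64 and 4-byte floats): splitting
// 1000 elements into 3 chunks gives a base of ceil(1000 / 3) = 334 elements
// (1336 bytes), which rounds up to 336 elements (1344 bytes = 21 * 64); the
// final chunk simply holds whatever remains.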
/*static*/
int64_t CollectiveAdapter::AlignedChunkElts(int64_t elt_bytes,
                                            int64_t total_elts,
                                            int64_t num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64_t base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES.
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple of the
  // various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64_t chunk_bytes = base_chunk_elts * elt_bytes;
  int64_t diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64_t num_chunks,
                        Allocator* allocator, bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
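  // With aligned chunks the nominal chunk size can overshoot the end of the
  // buffer, so the bounds are clamped to data_end_; the final chunk may
  // therefore be shorter than chunk_elts_, possibly empty.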
  inline int64_t ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64_t ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64_t start = chunk_elts_ * i;
    int64_t num_elts = ChunkElts(i);
    // If this chunk is empty, the prior chunk might also be short, so always
    // take an empty slice from the front of the tensor to avoid failing an
    // illegal-offset check somewhere downstream.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    profiler::ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64_t>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64_t base_addr = reinterpret_cast<int64_t>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64_t num_chunks_;
  Allocator* allocator_;
  const int64_t total_elts_;
  const int64_t chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

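// Sketch of how a collective implementation might use the adapter (names such
// as `out`, `alloc`, and `num_devices` are illustrative, not part of this
// file, and ownership handling is simplified):
//
//   std::unique_ptr<CollectiveAdapter> ca(MakeCollectiveAdapter(
//       &out, /*num_chunks=*/num_devices, alloc, /*align_chunks=*/true));
//   for (int i = 0; i < num_devices; ++i) {
//     Tensor chunk = ca->ChunkAlias(i);  // aliases the i-th chunk of the
//                                        // adapter's flattened buffer
//     // ... exchange `chunk` with peer i ...
//   }
//   ca->ConsumeFinalValue(&out);  // moves the result back into `out` with
//                                 // its original shape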
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_BFLOAT16:
      return new CollectiveAdapterImpl<Eigen::bfloat16>(
          output, num_chunks, allocator, align_chunks);
      break;
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64_t>(output, num_chunks, allocator,
                                                align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  Status status;
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
              << s;
      return;
    }
    status_ = StatusGroup::MakeDerived(Status(
        s.code(),
        absl::StrCat(
244             "Collective ops is aborted by: ", s.error_message(),
245             "\nThe error could be from a previous operation. Restart your "
246             "program to reset.")));
247     status = status_;
248   }
249   LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
250   cem_->GetParamResolver()->StartAbort(status);
251   remote_access_->StartAbort(status);
252   if (cem_->GetNcclCommunicator() != nullptr) {
253     cem_->GetNcclCommunicator()->StartAbort(status);
254   }
255 }

Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  // If the collective executor has already been aborted, return the abort
  // status, which is more likely the actual error than an artifact of the
  // abort.
  if (!status_.ok()) {
    VLOG(2) << "Overriding status with collective ops executor status. "
               "Original status: "
            << s;
    return status_;
  }
  return s;
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams* col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // See CompleteParamsAsync() for how done() and the timeout callback
  // interact.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds = static_cast<int64_t>(
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done] {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(error::DEADLINE_EXCEEDED,
                          "Collective has timed out during execution.");
            StartAbort(status);
            done(status);
          }
        });
  }

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  core::ScopedUnref unref(col_impl);
  auto col_ctx = std::make_shared<CollectiveContext>(
      this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
      col_params, exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
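  // col_impl is Ref()'d once before being captured by the closure below and
  // once more before Run(); each reference is released by a matching
  // ScopedUnref, so the implementation stays alive for both the closure and
  // the asynchronous completion callback.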
  col_impl->Ref();
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
  RunClosure([col_impl, col_ctx, done_safe, ctx,
              context_id = producer.GetContextId()]() {
    core::ScopedUnref unref(col_impl);
    profiler::TraceMeConsumer consumer(
        [ctx, col_ctx] {
          string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
                                          ctx->op_kernel().type_string_view());
          return profiler::TraceMeEncode(
              std::move(op),
              {{"id", ctx->step_id()},
               {"instance_key", col_ctx->col_params->instance.instance_key},
               {"collective", col_ctx->col_params->instance.type}});
        },
        context_id);
    col_impl->Ref();
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      core::ScopedUnref unref(col_impl);
      done_safe(s);
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both still alive. After
  // done() is called, CollectiveExecutorMgr may be destroyed and we have no
  // way to keep it alive without complicating the ownership. Therefore, if
  // the timeout callback executes, done_safe becomes a no-op and the timeout
  // callback is responsible for invoking done() at the end.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  int64_t trace_id = profiler::TraceMe::ActivityStart([cp]() {
    return profiler::TraceMeEncode("CollectiveExecutor::CompleteParams",
                                   {{"group_key", cp->group.group_key},
                                    {"group_size", cp->group.group_size}});
  });

  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
                    done](const Status& s) {
    profiler::TraceMe::ActivityEnd(trace_id);
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds = static_cast<int64_t>(
      cp->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done]() {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(
                error::DEADLINE_EXCEEDED,
                "Collective has timed out waiting for other workers.");
            StartAbort(status);
            done(status);
          }
        });
  }
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                done_safe);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_BFLOAT16:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_BFLOAT16 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32_t instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

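// A note on the bookkeeping below: launched_[instance_key] starts at the
// number of devices this task contributes to the group and is decremented
// once per local device. CheckDependencies() treats a dependency as satisfied
// only when its count reaches zero, i.e. after every local device has passed
// through UnblockDependencies() for that instance.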
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.group.members[col_params.default_rank].task;
    const int32_t num_devices =
        col_params.group.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow