/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {

namespace {
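// Returns true if `cancel_mgr` is non-null and has either been cancelled or is
// in the process of cancelling.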
bool IsCancelled(CancellationManager* cancel_mgr) {
  return cancel_mgr != nullptr &&
         (cancel_mgr->IsCancelled() || cancel_mgr->IsCancelling());
}
}  // namespace

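// Example (for illustration): with elt_bytes=4, total_elts=7, num_chunks=2 and
// EIGEN_MAX_ALIGN_BYTES=64, the base chunk size of 4 elements is rounded up to
// 16 elements so that each chunk begins on a 64-byte boundary.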
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64_t elt_bytes, int64_t total_elts,
                                          int64_t num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64_t base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which
  // must be a common multiple of the various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64_t chunk_bytes = base_chunk_elts * elt_bytes;
  int64_t diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
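// Takes ownership of a collective's output tensor, flattens it to rank 1, and
// exposes aligned per-chunk views (aliases or freshly allocated temporary
// chunks) for use by collective implementations.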
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64_t num_chunks,
                        Allocator* allocator, bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64_t start = chunk_elts_ * i;
    int64_t num_elts = ChunkElts(i);
    // If this chunk is empty the prior chunk might also be short
    // so always take an empty slice from the front of the tensor
    // to avoid an illegal offset check failure somewhere.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    ScopedMemoryDebugAnnotation op_annotation(
        "CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_, " chunk_elts",
        chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64_t base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

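// Creates a CollectiveAdapter wrapping `output`, specialized to the tensor's
// dtype. LOG(FATAL)s if the dtype is unsupported.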
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_BFLOAT16:
      return new CollectiveAdapterImpl<Eigen::bfloat16>(
          output, num_chunks, allocator, align_chunks);
      break;
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

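// Records an abort status derived from `s` (only the first abort is kept) and
// propagates it to the param resolver, the remote access layer, and the NCCL
// communicator, if present.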
void BaseCollectiveExecutor::StartAbort(const Status& s) {
  Status status;
  {
    mutex_lock l(status_mu_);
    if (!status_.ok()) {
      VLOG(2) << "BaseCollectiveExecutor already aborted, ignoring StartAbort: "
              << s;
      return;
    }
    status_ = StatusGroup::MakeDerived(Status(
        s.code(),
        absl::StrCat(
            "Collective ops is aborted by: ", s.error_message(),
            "\nThe error could be from a previous operation. Restart your "
            "program to reset.")));
    status = status_;
  }
  LOG(ERROR) << "BaseCollectiveExecutor::StartAbort " << s;
  cem_->GetParamResolver()->StartAbort(status);
  remote_access_->StartAbort(status);
  if (cem_->GetNcclCommunicator() != nullptr) {
    cem_->GetNcclCommunicator()->StartAbort(status);
  }
}

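// Returns `s` unless the executor has already been aborted, in which case the
// stored abort status is returned.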
Status BaseCollectiveExecutor::GetStatus(const Status& s) {
  if (s.ok()) return s;
  mutex_lock l(status_mu_);
  // If the collective executor has already been aborted, use the abort status,
  // which is more likely the actual error than an artifact of the abort.
  if (!status_.ok()) {
    VLOG(2) << "Overriding status with collective ops executor status. "
               "Original status: "
            << s;
    return status_;
  }
  return s;
}

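// Runs the collective described by `col_params` asynchronously on an unbounded
// work queue. `done` is invoked at most once: either with the collective's
// status or with DEADLINE_EXCEEDED if the configured timeout fires first.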
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams* col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // See CompleteParamsAsync() for how done() and the timeout callback interact.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto done_safe = [this, done, ctx, is_callback_called](const Status& s) {
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(ctx->cancellation_manager())) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds = static_cast<int64>(
      col_params->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done] {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(error::DEADLINE_EXCEEDED,
                          "Collective has timed out during execution.");
            StartAbort(status);
            done(status);
          }
        });
  }

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params->instance.type == REDUCTION_COLLECTIVE ||
                         col_params->instance.type == GATHER_COLLECTIVE ||
                         col_params->instance.type == PERMUTE_COLLECTIVE ||
                         col_params->instance.type == ALL_TO_ALL_COLLECTIVE ||
                         (col_params->instance.type == BROADCAST_COLLECTIVE &&
                          col_params->is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(*col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  core::ScopedUnref unref(col_impl);
  auto col_ctx = std::make_shared<CollectiveContext>(
      this, cem_->GetNcclCommunicator(), dev_mgr_, ctx, CtxParams(ctx),
      col_params, exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
  col_impl->Ref();
  profiler::TraceMeProducer producer("BaseCollectiveExecutor::ExecuteAsync");
  RunClosure([col_impl, col_ctx, done_safe, ctx,
              context_id = producer.GetContextId()]() {
    core::ScopedUnref unref(col_impl);
    profiler::TraceMeConsumer consumer(
        [ctx] {
          string op = profiler::TraceMeOp(ctx->op_kernel().name_view(),
                                          ctx->op_kernel().type_string_view());
          return profiler::TraceMeEncode(std::move(op),
                                         {{"id", ctx->step_id()}});
        },
        context_id);
    col_impl->Ref();
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      core::ScopedUnref unref(col_impl);
      done_safe(s);
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const DeviceAttributes& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  // We need to make sure that when the timeout callback executes,
  // CollectiveExecutor and CollectiveExecutorMgr are both still alive. After
  // done() is called, CollectiveExecutorMgr may be destroyed and we have no way
  // to keep it alive without complicating the ownership model. Therefore, if
  // the timeout callback executes, done_safe becomes a no-op and the timeout
  // callback is responsible for invoking done() at the end.
  const auto is_callback_called = std::make_shared<std::atomic<bool>>(false);
  auto trace_id =
      profiler::TraceMe::ActivityStart("CollectiveExecutor::CompleteParams");
  auto done_safe = [this, is_callback_called, cancel_mgr, trace_id,
                    done](const Status& s) {
    profiler::TraceMe::ActivityEnd(trace_id);
    bool called = is_callback_called->exchange(true);
    if (!called) {
      if (!s.ok() && !IsCancelled(cancel_mgr)) {
        // This is a collective error. Abort CollectiveExecutor so that this
        // error can propagate to other workers.
        StartAbort(s);
      }
      done(GetStatus(s));
    }
  };
  auto timeout_microseconds =
      static_cast<int64>(cp->instance.impl_details.timeout_seconds * 1'000'000);
  if (timeout_microseconds > 0) {
    // TODO(xldrx): Share the timeout watchdog thread among collectives.
    SchedNonBlockingClosureAfter(
        timeout_microseconds, [this, is_callback_called, done]() {
          bool called = is_callback_called->exchange(true);
          if (!called) {
            Status status(
                error::DEADLINE_EXCEEDED,
                "Collective has timed out waiting for other workers.");
            StartAbort(status);
            done(status);
          }
        });
  }
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr,
                                                done_safe);
}

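// Looks up the registered collective implementation named in `col_params`,
// after rejecting dtype/device combinations that are not supported.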
Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_BFLOAT16:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_BFLOAT16 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

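// Returns true iff every dependency listed for `col_params` has fully
// launched, i.e. its pending-device count in launched_ has reached zero.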
bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32_t instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

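// Blocks until every dependency of `col_params` has launched, waiting on
// launch_cv_, which is notified by UnblockDependencies().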
void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

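// Records one device launch for the instance in `col_params`. The counter is
// initialized to the number of devices this task contributes; once it reaches
// zero, all waiters on launch_cv_ are notified.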
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.group.task_names[col_params.default_rank];
    const int32_t num_devices =
        col_params.group.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow