/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {
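// Returns the per-chunk element count obtained by dividing total_elts into
// num_chunks pieces and rounding up so that each chunk's byte size is a
// multiple of EIGEN_MAX_ALIGN_BYTES (the trailing chunk may end up shorter
// once clamped to the end of the buffer).  For example, with elt_bytes=4,
// total_elts=10, num_chunks=3 and EIGEN_MAX_ALIGN_BYTES=64, the naive
// ceiling of 4 elements per chunk is rounded up to 16 (16 * 4 = 64 bytes).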
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which
  // must be a common multiple of the various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty the prior chunk might also be short
    // so always take an empty slice from the front of the tensor
    // to avoid an illegal offset check failure somewhere.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

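  // Allocates a fresh temporary Tensor of dtype dt_ sized to hold chunk i.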
  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    MEMDEBUG_CACHE_OP("CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_, " chunk_elts",
        chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a, const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

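  // The (possibly temporarily flattened) output tensor plus the bookkeeping
  // needed to alias its chunks; data_start_/data_end_ bound the output buffer.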
  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

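// Factory returning a dtype-specialized CollectiveAdapterImpl wrapped around
// `output`; LOG(FATAL)s on unsupported dtypes.  A caller (e.g. a collective
// implementation) might use it roughly as in the following sketch, where the
// variable names are illustrative only:
//
//   std::unique_ptr<CollectiveAdapter> ca(MakeCollectiveAdapter(
//       output_tensor, num_chunks, device_allocator, /*align_chunks=*/true));
//   Tensor my_chunk = ca->ChunkAlias(chunk_index);   // aliases, no copy
//   ca->ConsumeFinalValue(output_tensor);            // restores original shape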
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

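// Aborts pending collective communication for this executor by forwarding the
// error status to the per-step remote access object.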
void BaseCollectiveExecutor::StartAbort(const Status& s) {
  VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
  remote_access_->StartAbort(s);
}

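// Launches one collective op instance: resolves the registered implementation,
// builds its CollectiveContext, and runs it on the unbounded work queue,
// invoking `done` (wrapped to trigger a delayed abort on error) when finished.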
void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams& col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // On any individual collective Op failure we need to abort the
  // BufRendezvous so that other Ops in the instance don't hang
  // waiting for transmissions that will never happen.  Do so after a
  // delay so that the original error status is more likely to
  // propagate up, and peers are unlikely to re-create the purged
  // BufRendezvous by late-arriving requests.
  StatusCallback done_safe = [this, done](const Status& s) {
    if (!s.ok()) {
      Ref();  // Ensure this lasts until the closure executes.
      SchedNonBlockingClosureAfter(1000000, [this, s] {
        remote_access_->buf_rendezvous()->StartAbort(s);
        Unref();
      });
    }
    done(s);
  };

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                         col_params.instance.type == GATHER_COLLECTIVE ||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
                          col_params.is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  CollectiveContext* col_ctx =
      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
                            exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    delete col_ctx;
    delete col_impl;
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
  remote_access_->RunClosure([col_impl, col_ctx, done_safe, ctx]() {
    profiler::TraceMe activity(
        [&] {
          return strings::StrCat(ctx->op_kernel().name_view(), ":",
                                 ctx->op_kernel().type_string_view(),
                                 "#id=", ctx->step_id(), "#");
        },
        profiler::TraceMeLevel::kInfo);
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      done_safe(s);
      delete col_ctx;
      delete col_impl;
    });
  });
}

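// Fills in the remaining group/instance fields of *cp by delegating to the
// shared param resolver, after applying the configured GPU ring order.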
void BaseCollectiveExecutor::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    StatusCallback done) {
  cp->instance.gpu_ring_order = *gpu_ring_order_;
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
}

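// Looks up the registered CollectiveImplementationInterface named in
// col_params, rejecting dtype/collective combinations that are unsupported
// (DT_BOOL outside broadcast, DT_INT32 all-reduce on GPU).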
Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

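// Returns true only if every instance listed in impl_details.dependencies has
// finished launching, i.e. its pending-device count in launched_ has reached
// zero.  Caller must hold launch_mu_.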
bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

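// Blocks the calling thread on launch_cv_ until all of col_params'
// dependencies have been unblocked.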
void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

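// Decrements the per-task pending-device count for col_params' instance
// (initializing it to the number of devices in this task on first call) and
// wakes any waiting collectives when the count reaches zero.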
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.instance.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.instance.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow