/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

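// When true, DebugString() includes a summary of the adapter's tensor values.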
#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {
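// Rounds the per-chunk element count up so that each chunk (except possibly
// the last) starts on an EIGEN_MAX_ALIGN_BYTES boundary.
//
// Worked example (illustrative only, assuming EIGEN_MAX_ALIGN_BYTES == 64):
// with elt_bytes == 4, total_elts == 100 and num_chunks == 3, the base chunk
// size is ceil(100 / 3) == 34 elements (136 bytes), which is rounded up by
// 14 elements to 48 elements (192 bytes), a multiple of 64.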
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple of
  // the various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
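// CollectiveAdapterImpl wraps the output tensor of a collective op,
// temporarily flattening it to rank 1 so that collective algorithms (e.g. the
// ring reducer) can treat it as num_chunks contiguous chunks, each exposed as
// a Tensor that aliases the underlying buffer.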
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
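  // All chunks hold chunk_elts_ elements except possibly the last, which may
  // be shorter (or empty) when total_elts_ is not an exact multiple of
  // chunk_elts_; the clamp against data_end_ accounts for this.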
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty the prior chunk might also be short
    // so always take an empty slice from the front of the tensor
    // to avoid an illegal offset check failure somewhere.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

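  // Allocates a scratch tensor, using this adapter's allocator, sized to hold
  // chunk i.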
  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override {
    Tensor t(dt_, TensorShape({}));
    t.scalar<T>()() = v;
    return t;
  }

  Tensor Scalar(Allocator* a) const override {
    Tensor t(a, dt_, TensorShape({}));
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

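// Constructs a CollectiveAdapter specialized for the output's dtype. Only
// DT_FLOAT, DT_DOUBLE, DT_INT32 and DT_INT64 are supported; any other dtype
// is a fatal error.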
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << output->dtype()
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  LOG(WARNING) << "BaseCollectiveExecutor::StartAbort " << s;
  remote_access_->StartAbort(s);
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams& col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // On any individual collective Op failure we need to abort the
  // BufRendezvous so that other Ops in the instance don't hang
  // waiting for transmissions that will never happen.  Do so after a
  // delay so that the original error status is more likely to
  // propagate up, and peers are unlikely to re-create the purged
  // BufRendezvous by late-arriving requests.
  StatusCallback done_safe = [this, done](const Status& s) {
    if (!s.ok()) {
      Ref();  // Ensure this lasts until the closure executes.
      SchedNonBlockingClosureAfter(1000000, [this, s] {
        remote_access_->buf_rendezvous()->StartAbort(s);
        Unref();
      });
    }
    done(s);
  };

  Tensor* output = ctx->mutable_output(0);
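  // Reductions and gathers always read input 0; a broadcast reads it only at
  // the source rank, so input is left null for broadcast receivers.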
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                         col_params.instance.type == GATHER_COLLECTIVE ||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
                          col_params.is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  CollectiveContext* col_ctx =
      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
                            exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    delete col_ctx;
    delete col_impl;
    return;
  }
  // Run in an I/O thread, so as not to starve the executor threads.
  // TODO(b/80529858): Instead of forking every per-device Collective
  // Op off into its own thread, consider queuing them on a
  // fixed-size thread-pool dedicated to running CollectiveOps.
  SchedClosure([col_impl, col_ctx, done_safe]() {
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      done_safe(s);
      delete col_ctx;
      delete col_impl;
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    StatusCallback done) {
  cp->instance.gpu_ring_order = *gpu_ring_order_;
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
}

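// Looks up the CollectiveImplementationInterface registered under the name in
// col_params.instance.impl_details.collective_name. Only float, double, int32
// and int64 payloads are supported.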
Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  *col_impl = nullptr;
  Status status;
  switch (col_params.instance.data_type) {
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU) {
        status = errors::Internal(
            "CollectiveImplementation does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      }
      TF_FALLTHROUGH_INTENDED;
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      status = CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
      break;
    }
    default:
      status = errors::Internal(
          "CollectiveImplementation does not support datatype ",
          col_params.instance.data_type);
  }
  return status;
}

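// Returns true if every collective instance listed in
// col_params.instance.impl_details.dependencies has finished launching on all
// of this task's devices (i.e. its counter in launched_ has reached zero).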
bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

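// Records that one local device has launched this collective instance. The
// first call for an instance_key initializes its counter to the number of
// devices this task contributes; once the counter reaches zero, collectives
// blocked in WaitForDependencies() are woken to re-check their dependencies.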
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.instance.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.instance.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow