/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"

#define VALUE_IN_DEBUG_STRING false
namespace tensorflow {
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES.
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // Here elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple
  // of the sizes of the atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
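  // Worked example with assumed values (illustrative only): for float chunks
  // (elt_bytes = 4) with total_elts = 10, num_chunks = 4 and
  // EIGEN_MAX_ALIGN_BYTES = 16, base_chunk_elts starts at 3 (12 bytes);
  // 12 bytes falls short of the 16-byte alignment unit, so diff = 4 bytes
  // = 1 element, and the aligned chunk size rounds up to 4 elements
  // (16 bytes).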
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
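  // With the assumed values from the AlignedChunkElts example above (10
  // elements split into 4 aligned chunks of 4), the clamping against
  // data_end_ below yields chunks of 4, 4, 2 and 0 elements respectively.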
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty, the prior chunk might also be short, so always
    // take an empty slice from the front of the tensor to avoid failing an
    // illegal-offset check.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    MEMDEBUG_CACHE_OP("CollectiveAdapterImpl::TempChunk");
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

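  // Note: the first Scalar() overload below materializes a host-side scalar
  // holding v, while the second allocates an uninitialized scalar through
  // the supplied allocator and attributes (e.g. when the value must live in
  // allocator-managed memory).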
  Tensor Scalar(int v) const override { return Tensor(static_cast<T>(v)); }

  Tensor Scalar(Allocator* a,
                const AllocationAttributes& attr) const override {
    Tensor t(a, dt_, TensorShape({}), attr);
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_HALF:
      return new CollectiveAdapterImpl<Eigen::half>(output, num_chunks,
                                                    allocator, align_chunks);
      break;
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << DataTypeString(output->dtype())
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}
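
// A minimal usage sketch (caller names below are illustrative, not from this
// file): a collective implementation typically wraps its output tensor once
// and then operates on per-chunk aliases, roughly
//   std::unique_ptr<CollectiveAdapter> ca(MakeCollectiveAdapter(
//       output_tensor, group_size, allocator, /*align_chunks=*/true));
//   Tensor my_chunk = ca->ChunkAlias(my_rank);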

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

void BaseCollectiveExecutor::StartAbort(const Status& s) {
  VLOG(1) << "BaseCollectiveExecutor::StartAbort " << s;
  remote_access_->StartAbort(s);
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams& col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // On any individual collective Op failure we need to abort the
  // BufRendezvous so that other Ops in the instance don't hang
  // waiting for transmissions that will never happen. Do so after a
  // delay so that the original error status is more likely to
  // propagate up, and peers are unlikely to re-create the purged
  // BufRendezvous by late-arriving requests.
  StatusCallback done_safe = [this, done](const Status& s) {
    if (!s.ok()) {
      Ref();  // Ensure this lasts until the closure executes.
      SchedNonBlockingClosureAfter(1000000, [this, s] {
        remote_access_->buf_rendezvous()->StartAbort(s);
        Unref();
      });
    }
    done(s);
  };

  Tensor* output = ctx->mutable_output(0);
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                         col_params.instance.type == GATHER_COLLECTIVE ||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
                          col_params.is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  CollectiveContext* col_ctx =
      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
                            exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    delete col_ctx;
    delete col_impl;
    return;
  }
  // Run on an unbounded work queue that can handle blocking work so as to not
  // starve executor threads.
  remote_access_->RunClosure([col_impl, col_ctx, done_safe, ctx]() {
    profiler::TraceMe activity(
        [&] {
          return strings::StrCat(ctx->op_kernel().name_view(), ":",
                                 ctx->op_kernel().type_string_view(),
                                 "#id=", ctx->step_id(), "#");
        },
        profiler::TraceMeLevel::kInfo);
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      done_safe(s);
      delete col_ctx;
      delete col_impl;
    });
  });
}

void BaseCollectiveExecutor::CompleteParamsAsync(
    const string& device, CollectiveParams* cp,
    CancellationManager* cancel_mgr, StatusCallback done) {
  cp->instance.gpu_ring_order = *gpu_ring_order_;
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  VLOG(2) << "CreateCollective type "
          << DataTypeString(col_params.instance.data_type) << " name "
          << col_params.instance.impl_details.collective_name;
  *col_impl = nullptr;
  switch (col_params.instance.data_type) {
    case DT_BOOL:
      if (col_params.instance.type == BROADCAST_COLLECTIVE) {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      } else {
        return errors::Internal(
            "No collective other than broadcast supports DT_BOOL");
      }
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU &&
          col_params.instance.type == REDUCTION_COLLECTIVE) {
        // TODO(b/139421603): enable int32 all-reduce on GPU.
        return errors::Internal(
            "Collective all-reduce does not support datatype DT_INT32 on "
            "DEVICE_GPU");
      } else {
        return CollectiveRegistry::Lookup(
            col_params.instance.impl_details.collective_name, col_impl);
      }
    case DT_HALF:
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      return CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
    }
    default:
      return errors::Internal(
          "CollectiveImplementation does not support datatype ",
          DataTypeString(col_params.instance.data_type));
  }
}

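// Returns true iff every collective instance listed in this op's
// impl_details.dependencies has already been launched and fully counted
// down in launched_.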
bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

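// Blocks the calling thread until CheckDependencies() passes, i.e. until
// every declared dependency has finished on the local task's devices.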
void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

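// Called once per local device participating in this instance: the per-task
// count in launched_ is initialized to the task's device count on first
// call, decremented on every call, and waiters are notified once it reaches
// zero.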
void BaseCollectiveExecutor::UnblockDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.instance.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.instance.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow