/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"

#include <algorithm>
#include <functional>
#include <utility>

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"

#define VALUE_IN_DEBUG_STRING false

namespace tensorflow {
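// Computes the per-chunk element count for splitting total_elts elements of
// elt_bytes bytes each into num_chunks pieces: ceil(total_elts / num_chunks),
// rounded up further so that a full chunk's byte size is a multiple of
// EIGEN_MAX_ALIGN_BYTES.  For example, elt_bytes=4, total_elts=1000,
// num_chunks=3 and EIGEN_MAX_ALIGN_BYTES=64 gives ceil(1000/3) = 334 elts
// = 1336 bytes, rounded up to 1344 bytes = 336 elts per full chunk; the last
// chunk then covers the remaining 1000 - 2*336 = 328 elements.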
/*static*/
int64 CollectiveAdapter::AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                          int64 num_chunks) {
  DCHECK_GT(num_chunks, 0);
  int64 base_chunk_elts = (total_elts + (num_chunks - 1)) / num_chunks;
  if (EIGEN_MAX_ALIGN_BYTES == 0) return base_chunk_elts;
  if (EIGEN_MAX_ALIGN_BYTES <= elt_bytes) {
    // Tolerate weird small values of EIGEN_MAX_ALIGN_BYTES.
    DCHECK_EQ(0, elt_bytes % EIGEN_MAX_ALIGN_BYTES);
    return base_chunk_elts;
  }
  // Here elt_bytes < EIGEN_MAX_ALIGN_BYTES, which must be a common multiple
  // of the sizes of the various atomic data types.
  DCHECK_EQ(0, EIGEN_MAX_ALIGN_BYTES % elt_bytes)
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " elt_bytes=" << elt_bytes;
  // Round bytes per chunk up to the next multiple of EIGEN_MAX_ALIGN_BYTES.
  int64 chunk_bytes = base_chunk_elts * elt_bytes;
  int64 diff =
      (chunk_bytes < EIGEN_MAX_ALIGN_BYTES)
          ? (EIGEN_MAX_ALIGN_BYTES - chunk_bytes)
          : (EIGEN_MAX_ALIGN_BYTES - (chunk_bytes % EIGEN_MAX_ALIGN_BYTES));
  DCHECK_EQ(0, diff % elt_bytes);
  base_chunk_elts += (diff / elt_bytes);
  DCHECK_EQ(0, ((base_chunk_elts * elt_bytes) % EIGEN_MAX_ALIGN_BYTES))
      << "total_elts=" << total_elts << " num_chunks=" << num_chunks
      << " EIGEN_MAX_ALIGN_BYTES=" << EIGEN_MAX_ALIGN_BYTES
      << " base_chunk_elts=" << base_chunk_elts << " elt_bytes=" << elt_bytes;
  return base_chunk_elts;
}

namespace {
template <typename T>
class CollectiveAdapterImpl : public CollectiveAdapter {
 public:
  // Takes ownership of output and prepares to properly alias its chunks.
  // Ownership is taken because the shape may temporarily change.
  CollectiveAdapterImpl(Tensor* output, int64 num_chunks, Allocator* allocator,
                        bool align_chunks)
      : output_(std::move(*output)),
        dt_(output_.dtype()),
        old_shape_(output_.shape()),
        num_chunks_(num_chunks),
        allocator_(allocator),
        total_elts_(output_.NumElements()),
        chunk_elts_(align_chunks
                        ? AlignedChunkElts(sizeof(T), total_elts_, num_chunks_)
                        : total_elts_ / num_chunks_),
        data_start_(reinterpret_cast<T*>(DMAHelper::base(&output_))),
        data_end_(data_start_ + total_elts_) {
    if (!align_chunks) {
      DCHECK_EQ(total_elts_, num_chunks_ * chunk_elts_);
    }
    DCHECK_GT(chunk_elts_, 0);
    Flatten();
  }

  ~CollectiveAdapterImpl() override {}

  const Tensor& Value() const override { return output_; }

  // If necessary, flatten output.
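  // A 1-D shape lets Tensor::Slice() hand out contiguous chunk aliases over
  // the same backing buffer; the original shape is kept in old_shape_ and
  // restored by ConsumeFinalValue() before the tensor is handed back.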
  void Flatten() {
    if (old_shape_.dims() != 1) {
      TensorShape new_shape = TensorShape({old_shape_.num_elements()});
      DMAHelper::UnsafeSetShape(&output_, new_shape);
    }
  }

  void ConsumeFinalValue(Tensor* output) override {
    if (old_shape_ != output_.shape()) {
      DMAHelper::UnsafeSetShape(&output_, old_shape_);
    }
    *output = std::move(output_);
  }

  // Number of T elements in a particular chunk.
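  // Chunks 0 .. num_chunks_-2 hold chunk_elts_ elements each; the final chunk
  // holds whatever remains and may be shorter, possibly empty (e.g. 10
  // elements in chunks of 4 yield sizes 4, 4, 2).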
  inline int64 ChunkElts(int i) const {
    DCHECK_LT(i, num_chunks_);
    const T* chunk_start = std::min(data_end_, data_start_ + i * chunk_elts_);
    const T* chunk_end = std::min(data_end_, chunk_start + chunk_elts_);
    return chunk_end - chunk_start;
  }

  int64 ChunkBytes(int i) const override { return sizeof(T) * ChunkElts(i); }

  // Returns a new Tensor that aliases the required chunk.
  Tensor ChunkAlias(int i) override {
    int64 start = chunk_elts_ * i;
    int64 num_elts = ChunkElts(i);
    // If this chunk is empty, the prior chunk might also be short, so always
    // take an empty slice from the front of the tensor to avoid an illegal
    // offset check failure.
    return (num_elts > 0) ? output_.Slice(start, start + num_elts)
                          : output_.Slice(0, 0);
  }

  Tensor TempChunk(int i) const override {
    AllocationAttributes empty;
    return Tensor(allocator_, dt_, {ChunkElts(i)}, empty);
  }

  string DebugString() const override {
    return strings::StrCat(
        "base addr ", reinterpret_cast<int64>(DMAHelper::base(&output_)),
        " num_chunks ", num_chunks_, " total_elts ", total_elts_,
        " chunk_elts ", chunk_elts_, " value ",
        VALUE_IN_DEBUG_STRING ? output_.SummarizeValue(1024) : "<hidden>");
  }

  string TBounds(const Tensor& t) const override {
    int64 base_addr = reinterpret_cast<int64>(DMAHelper::base(&t));
    return strings::StrCat("(", base_addr, ", ", (base_addr + t.TotalBytes()),
                           ")");
  }

  Tensor Scalar(int v) const override {
    Tensor t(dt_, TensorShape({}));
    t.scalar<T>()() = v;
    return t;
  }

  Tensor Scalar(Allocator* a) const override {
    Tensor t(a, dt_, TensorShape({}));
    return t;
  }

  Tensor output_;
  const DataType dt_;
  const TensorShape old_shape_;
  const int64 num_chunks_;
  Allocator* allocator_;
  const int64 total_elts_;
  const int64 chunk_elts_;
  const T* data_start_;
  const T* data_end_;
};

}  // namespace

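// Usage sketch for the factory below (illustrative only; `group_size`,
// `rank`, and `alloc` stand in for caller-provided values):
//
//   std::unique_ptr<CollectiveAdapter> ca(MakeCollectiveAdapter(
//       &output, /*num_chunks=*/group_size, alloc, /*align_chunks=*/true));
//   Tensor my_chunk = ca->ChunkAlias(rank);  // aliases output's buffer
//   ...exchange/accumulate chunks...
//   ca->ConsumeFinalValue(&output);  // restores the original shape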
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks) {
  switch (output->dtype()) {
    case DT_FLOAT:
      return new CollectiveAdapterImpl<float>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_DOUBLE:
      return new CollectiveAdapterImpl<double>(output, num_chunks, allocator,
                                               align_chunks);
      break;
    case DT_INT32:
      return new CollectiveAdapterImpl<int32>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    case DT_INT64:
      return new CollectiveAdapterImpl<int64>(output, num_chunks, allocator,
                                              align_chunks);
      break;
    default:
      LOG(FATAL) << "Unsupported type " << output->dtype()
                 << " to MakeCollectiveAdapter";
      return nullptr;
  }
}

BaseCollectiveExecutor::~BaseCollectiveExecutor() {}

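// Logs and forwards the abort to the remote-access layer so that pending and
// future peer exchanges for this step fail promptly instead of hanging.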
void BaseCollectiveExecutor::StartAbort(const Status& s) {
  LOG(WARNING) << "BaseCollectiveExecutor::StartAbort " << s;
  remote_access_->StartAbort(s);
}

void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
                                          const CollectiveParams& col_params,
                                          const string& exec_key,
                                          StatusCallback done) {
  // On any individual collective Op failure we need to abort the
  // BufRendezvous so that other Ops in the instance don't hang
  // waiting for transmissions that will never happen. Do so after a
  // delay so that the original error status is more likely to
  // propagate up, and peers are unlikely to re-create the purged
  // BufRendezvous by late-arriving requests.
  StatusCallback done_safe = [this, done](const Status& s) {
    if (!s.ok()) {
      Ref();  // Ensure this lasts until the closure executes.
      SchedNonBlockingClosureAfter(1000000, [this, s] {
        remote_access_->buf_rendezvous()->StartAbort(s);
        Unref();
      });
    }
    done(s);
  };

  Tensor* output = ctx->mutable_output(0);
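  // Only reductions, gathers, and the broadcast source consume an input
  // tensor; a broadcast receiver only produces output.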
  const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
                         col_params.instance.type == GATHER_COLLECTIVE ||
                         (col_params.instance.type == BROADCAST_COLLECTIVE &&
                          col_params.is_source))
                            ? &ctx->input(0)
                            : nullptr;
  CollectiveImplementationInterface* col_impl = nullptr;
  Status status = CreateCollective(col_params, &col_impl);
  if (!status.ok()) {
    done_safe(status);
    DCHECK_EQ(nullptr, col_impl);
    return;
  }
  CollectiveContext* col_ctx =
      new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
                            exec_key, step_id_, input, output);
  status = col_impl->InitializeCollectiveContext(col_ctx);
  if (!status.ok()) {
    done_safe(status);
    delete col_ctx;
    delete col_impl;
    return;
  }
  // Run in an I/O thread, so as not to starve the executor threads.
  // TODO(b/80529858): Instead of forking every per-device Collective
  // Op off into its own thread, consider queuing them on a
  // fixed-size thread-pool dedicated to running CollectiveOps.
  SchedClosure([col_impl, col_ctx, done_safe]() {
    col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
      done_safe(s);
      delete col_ctx;
      delete col_impl;
    });
  });
}

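// Seeds the GPU ring-order override on the instance, then delegates full
// group/instance resolution to the param resolver owned by the
// CollectiveExecutorMgr.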
void BaseCollectiveExecutor::CompleteParamsAsync(
    const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
    StatusCallback done) {
  cp->instance.gpu_ring_order = *gpu_ring_order_;
  cem_->GetParamResolver()->CompleteParamsAsync(device, cp, cancel_mgr, done);
}

Status BaseCollectiveExecutor::CreateCollective(
    const CollectiveParams& col_params,
    CollectiveImplementationInterface** col_impl) {
  *col_impl = nullptr;
  Status status;
  switch (col_params.instance.data_type) {
    case DT_INT32:
      if (col_params.group.device_type == DEVICE_GPU) {
        status = errors::Internal(
            "CollectiveImplementation does not support datatype DT_INT32 on "
            "DEVICE_GPU");
        // Break so the error is not overwritten by the registry lookup below.
        break;
      }
      TF_FALLTHROUGH_INTENDED;
    case DT_FLOAT:
    case DT_DOUBLE:
    case DT_INT64: {
      status = CollectiveRegistry::Lookup(
          col_params.instance.impl_details.collective_name, col_impl);
      break;
    }
    default:
      status = errors::Internal(
          "CollectiveImplementation does not support datatype ",
          col_params.instance.data_type);
  }
  return status;
}

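// Returns true iff every instance listed in this collective's dependencies has
// finished launching on this task, i.e. its count in launched_ has reached
// zero (see Launched() below).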
bool BaseCollectiveExecutor::CheckDependencies(
    const CollectiveParams& col_params) {
  for (int32 instance : col_params.instance.impl_details.dependencies) {
    auto find_iter = launched_.find(instance);
    if (find_iter == launched_.end() || find_iter->second != 0) {
      VLOG(1) << "Collective " << col_params.ToString()
              << " blocked by instance " << instance;
      return false;
    }
  }
  return true;
}

void BaseCollectiveExecutor::WaitForDependencies(
    const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  while (!CheckDependencies(col_params)) {
    launch_cv_.wait(l);
  }
  VLOG(1) << "Unblocking collective " << col_params.ToString();
}

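// Records that this device has launched the given collective instance. The
// per-instance count starts at the number of devices this task contributes
// and counts down to zero, at which point collectives blocked on this
// instance in WaitForDependencies() are notified.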
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
  mutex_lock l(launch_mu_);
  if (launched_.find(col_params.instance.instance_key) == launched_.end()) {
    const string& task_name =
        col_params.instance.task_names[col_params.default_rank];
    const int32 num_devices =
        col_params.instance.num_devices_per_task.at(task_name);
    launched_[col_params.instance.instance_key] = num_devices;
  }
  if (--launched_[col_params.instance.instance_key] == 0) {
    VLOG(1) << "Unblocking dependencies for collective instance "
            << col_params.instance.instance_key;
    launch_cv_.notify_all();
  }
}

}  // namespace tensorflow