/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/platform/unbounded_work_queue.h"

namespace tensorflow {
class CollectiveImplementation;
class DeviceMgr;
class Device;

// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
class CollectiveAdapter {
 public:
  virtual ~CollectiveAdapter() {}

  // Move the backing tensor to 'output' with its original storage and
  // shape. After this call this CollectiveAdapter object should be
  // deleted immediately without calling any of its other methods.
  virtual void ConsumeFinalValue(Tensor* output) = 0;

  // Const access to the entire intermediate value, for debugging.
  virtual const Tensor& Value() const = 0;

  // Returns tensor for chunk i which aliases the backing buffer.
  virtual Tensor ChunkAlias(int i) = 0;

  // Returns a tensor allocated on the same device but with its own
  // separate backing buffer.  Will have the same type and size as
  // chunk i.
  virtual Tensor TempChunk(int i) const = 0;

  // Bytes in chunk i.
  virtual int64_t ChunkBytes(int i) const = 0;

  // Generate a CPU RAM scalar tensor of the same DataType as the
  // backing tensor with the given integer value.
  virtual Tensor Scalar(int v) const = 0;

  // Generate a scalar tensor of the same DataType and on the same device
  // as the backing tensor.
  virtual Tensor Scalar(Allocator* a,
                        const AllocationAttributes& attr) const = 0;

  // Debugging string describing the buffer location.
  virtual string TBounds(const Tensor& t) const = 0;

  virtual string DebugString() const = 0;

  // Computes the number of elements per alias chunk tensor.
  //
  // A CHECK in tensor.cc expects that the memory buffer backing a
  // Tensor will be aligned according to EIGEN_MAX_ALIGN_BYTES.  To
  // ensure that all chunk aliasing Tensors maintain this alignment we
  // need to pick a chunk size that preserves it.  Note that in extreme
  // cases (impractical, but possible with very small tensors) one or
  // more tail chunks can end up empty.
  static int64_t AlignedChunkElts(int64_t elt_bytes, int64_t total_elts,
                                  int64_t num_chunks);
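  //
  // Illustrative worked example (a hedged sketch; assumes float elements of 4
  // bytes, EIGEN_MAX_ALIGN_BYTES == 64, and that the naive chunk size is
  // rounded up to a whole number of alignment units):
  //   AlignedChunkElts(/*elt_bytes=*/4, /*total_elts=*/1000, /*num_chunks=*/3)
  //   The naive chunk size is ceil(1000 / 3) = 334 elements, but 64 bytes hold
  //   16 floats, so 334 rounds up to 336; the three chunks then cover 336, 336
  //   and 328 elements.  For a very small tensor (say 10 elements split 4
  //   ways) the aligned chunk size already exceeds the element count, so the
  //   tail chunks end up empty.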
};

// Create a CollectiveAdapter wrapping 'output', specialized to its
// data-type and shape.  If align_chunks == true then the chunk size may
// be larger than output->NumElements() / num_chunks and one or more
// of the suffix chunks may be empty.  Chunks will be arranged to start
// and end on alignment boundaries.  If align_chunks == false then
// output->NumElements() % num_chunks must be 0 and all chunks will
// have exactly the same size, ignoring alignment issues.
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks = true);
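
// Illustrative usage sketch (a hedged assumption about typical use; 'output',
// 'allocator' and the chunk count are placeholders, and the per-chunk peer
// exchange in the loop body is elided):
//   std::unique_ptr<CollectiveAdapter> ca(
//       MakeCollectiveAdapter(&output, /*num_chunks=*/4, allocator));
//   for (int i = 0; i < 4; ++i) {
//     Tensor chunk = ca->ChunkAlias(i);  // aliases output's backing buffer
//     // ... reduce/exchange 'chunk' with peer devices ...
//   }
//   ca->ConsumeFinalValue(&output);  // move the result out, then delete 'ca'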

// Default implementation of CollectiveExecutor.  Delegates the actual
// work of moving data to a class specialized for the operation type,
// arguments and device+interconnect topology.
class BaseCollectiveExecutor : public CollectiveExecutor {
 public:
  BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem,
                         CollectiveRemoteAccess* remote_access, int64_t step_id,
                         const DeviceMgr* dev_mgr,
                         std::shared_ptr<UnboundedWorkQueue> work_queue)
      : CollectiveExecutor(cem),
        step_id_(step_id),
        dev_mgr_(dev_mgr),
        remote_access_(remote_access),
        work_queue_(std::move(work_queue)) {}

  ~BaseCollectiveExecutor() override;

  void StartAbort(const Status& s) override TF_LOCKS_EXCLUDED(status_mu_);

  void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams* col_params,
                    const string& exec_key, StatusCallback done) override;

  void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
                           CancellationManager* cancel_mgr,
                           StatusCallback done) override;
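
  // Illustrative call-order sketch (a hedged assumption about typical usage;
  // 'executor', 'ctx', 'device_attrs', 'col_params', 'exec_key', 'cancel_mgr'
  // and 'done' are placeholder names, not defined in this header):
  //   executor->CompleteParamsAsync(device_attrs, col_params, cancel_mgr,
  //       [=](const Status& s) {
  //         if (s.ok()) executor->ExecuteAsync(ctx, col_params, exec_key, done);
  //       });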

  CollectiveRemoteAccess* remote_access() override {
    return remote_access_.get();
  }

  void RunClosure(std::function<void()> closure) override {
    work_queue_->Schedule(std::move(closure));
  }

  // If we need to enforce an ordering on any portion of the collective
  // implementation, and the ordering is encoded via an attribute on the
  // collective op, this function will block until all dependencies for this
  // collective have completed.
  void WaitForDependencies(const CollectiveParams& col_params) override;
  // Record that this collective has completed the portion of the
  // implementation that needs to be ordered with respect to other
  // collectives, to unblock any of its dependent ops.
  void UnblockDependencies(const CollectiveParams& col_params) override;
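
  // Illustrative ordering sketch (a hedged assumption; the launch step in the
  // middle stands for whatever device work the concrete collective
  // implementation performs, which this header does not prescribe):
  //   executor->WaitForDependencies(col_params);   // block until deps launch
  //   ... launch this collective's device kernels ...
  //   executor->UnblockDependencies(col_params);   // release dependent ops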

 protected:
  const int64_t step_id_;
  const DeviceMgr* dev_mgr_;  // Not owned.
  std::unique_ptr<CollectiveRemoteAccess> remote_access_;
  // Ownership of `work_queue_` is shared between `this` and
  // `CollectiveExecutorMgr`.
  std::shared_ptr<UnboundedWorkQueue> work_queue_;
  mutex launch_mu_;
  condition_variable launch_cv_;
  // collective instance key -> number of local devices for which NCCL ops have
  // been launched.
  std::unordered_map<int32, int32> launched_ TF_GUARDED_BY(launch_mu_);
  mutex status_mu_;
  Status status_ TF_GUARDED_BY(status_mu_);

 private:
  Status CreateCollective(const CollectiveParams& col_params,
                          CollectiveImplementationInterface** col_impl);
  // Checks whether all ops on which this collective depends have launched.
  bool CheckDependencies(const CollectiveParams& col_params)
      TF_EXCLUSIVE_LOCKS_REQUIRED(launch_mu_);
  // Tries to return the original error status; returns the aborted status if
  // the collective executor has been aborted.
  Status GetStatus(const Status& s) TF_LOCKS_EXCLUDED(status_mu_);
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_