/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_

#include <memory>
#include <string>

#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"

namespace tensorflow {
class CollectiveImplementation;
class DeviceMgr;
class Device;

// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
class CollectiveAdapter {
 public:
  virtual ~CollectiveAdapter() {}

  // Move the backing tensor to 'output' with its original storage and
  // shape. After this call this CollectiveAdapter object should be
  // deleted immediately without calling any of its other methods.
  virtual void ConsumeFinalValue(Tensor* output) = 0;

  // const access to entire intermediate value for debugging
  virtual const Tensor& Value() const = 0;

  // Returns tensor for chunk i which aliases the backing buffer.
  virtual Tensor ChunkAlias(int i) = 0;

  // Returns tensor allocated on the same device but with its own
  // separate backing buffer.  Will have same type and size as
  // chunk i.
  virtual Tensor TempChunk(int i) const = 0;

  // Bytes in chunk i
  virtual int64 ChunkBytes(int i) const = 0;

  // Generate a CPU RAM scalar tensor of the same DataType as the
  // backing tensor with the given integer value.
  virtual Tensor Scalar(int v) const = 0;

  // Generate a scalar tensor of same DataType and on the same device
  // as the backing tensor.
  virtual Tensor Scalar(Allocator* a,
                        const AllocationAttributes& attr) const = 0;

  // Debugging string describing buffer location
  virtual string TBounds(const Tensor& t) const = 0;

  virtual string DebugString() const = 0;

  // Computes the number of elements per alias chunk tensor.
  //
  // A CHECK in tensor.cc expects that the memory buffer backing a
  // Tensor will be aligned according to EIGEN_MAX_ALIGN_BYTES.  To
  // ensure that all chunk aliasing Tensors maintain this alignment we
  // need to pick a chunk size that preserves it.  Note that in extreme
  // cases (impractical, but possible with very small tensors) one or
  // more tail chunks can end up empty.
  static int64 AlignedChunkElts(int64 elt_bytes, int64 total_elts,
                                int64 num_chunks);
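  //
  // Worked example (illustrative only, assuming EIGEN_MAX_ALIGN_BYTES == 64,
  // i.e. 16 float elements per alignment unit): with elt_bytes = 4 (float),
  // total_elts = 100 and num_chunks = 3, the naive chunk size
  // ceil(100 / 3) = 34 elements (136 bytes) would not start every chunk on a
  // 64-byte boundary, so the per-chunk byte count is rounded up to the next
  // multiple of 64, giving 48 elements per chunk.  The three chunks then
  // cover 48, 48 and 4 elements; with a smaller tensor a tail chunk could
  // cover 0.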
};

// Create a CollectiveAdapter wrapping 'output', specialized to its
// data-type and shape.  If align_chunks == true then chunk size may
// be larger than output->NumElements() / num_chunks and one or more
// of the suffix chunks may be empty.  Chunks will be arranged to start
// and end on alignment boundaries.  If align_chunks == false then
// output->NumElements() % num_chunks must be 0 and all chunks will
// have exactly the same size, ignoring alignment issues.
CollectiveAdapter* MakeCollectiveAdapter(Tensor* output, int num_chunks,
                                         Allocator* allocator,
                                         bool align_chunks = true);
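//
// Example usage (a minimal sketch; `output_tensor`, `group_size`,
// `device_allocator` and the in-place exchange step are assumptions, not
// part of this header):
//
//   std::unique_ptr<CollectiveAdapter> ca(
//       MakeCollectiveAdapter(output_tensor, group_size, device_allocator));
//   for (int i = 0; i < group_size; ++i) {
//     Tensor chunk = ca->ChunkAlias(i);  // aliases output_tensor's buffer
//     // ... exchange/reduce `chunk` in place with peer devices ...
//   }
//   ca->ConsumeFinalValue(output_tensor);  // then discard `ca` immediately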

// Default implementation of CollectiveExecutor.  Delegates the actual
// work of moving data to a class specialized for the operation type,
// arguments and device+interconnect topology.
class BaseCollectiveExecutor : public CollectiveExecutor {
 public:
  BaseCollectiveExecutor(CollectiveExecutorMgrInterface* cem,
                         PerStepCollectiveRemoteAccess* remote_access,
                         int64 step_id, const DeviceMgr* dev_mgr,
                         const string* gpu_ring_order)
      : CollectiveExecutor(cem),
        step_id_(step_id),
        dev_mgr_(dev_mgr),
        remote_access_(remote_access),
        gpu_ring_order_(gpu_ring_order) {}

  ~BaseCollectiveExecutor() override;

  void StartAbort(const Status& s) override;

  void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
                    const string& exec_key, StatusCallback done) override;

  void CompleteParamsAsync(const string& device, CollectiveParams* cp,
                           CancellationManager* cancel_mgr,
                           StatusCallback done) override;
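  //
  // Typical call pattern (a minimal sketch; `col_exec`, `ctx`, `col_params`,
  // `exec_key`, `cancel_mgr`, `device_name` and `done` are assumed to come
  // from the enclosing collective op kernel):
  //
  //   col_exec->CompleteParamsAsync(
  //       device_name, &col_params, cancel_mgr,
  //       [=, &col_params](const Status& s) {
  //         if (!s.ok()) {
  //           done(s);
  //           return;
  //         }
  //         col_exec->ExecuteAsync(ctx, col_params, exec_key, done);
  //       });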

  PerStepCollectiveRemoteAccess* remote_access() override {
    return remote_access_.get();
  }

  void RecvFromPeer(const string& peer_device, const string& peer_task,
                    bool peer_is_local, const string& key, Device* to_device,
                    DeviceContext* to_device_ctx,
                    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
                    const DeviceLocality& client_locality, int stream_index,
                    const StatusCallback& done) override {
    remote_access_->RecvFromPeer(
        peer_device, peer_task, peer_is_local, key, to_device, to_device_ctx,
        to_alloc_attr, to_tensor, client_locality, stream_index, done);
  }

  void PostToPeer(const string& peer_device, const string& peer_task,
                  const string& key, Device* from_device,
                  DeviceContext* from_device_ctx,
                  const AllocatorAttributes& from_alloc_attr,
                  const Tensor* from_tensor,
                  const DeviceLocality& client_locality,
                  const StatusCallback& done) override {
    remote_access_->PostToPeer(peer_device, peer_task, key, from_device,
                               from_device_ctx, from_alloc_attr, from_tensor,
                               client_locality, done);
  }

  void RunClosure(std::function<void()> closure) override {
    remote_access_->RunClosure(std::move(closure));
  }

  // If we need to enforce an ordering on any portion of the collective
  // implementation, and the ordering is encoded via an attribute on the
  // collective op, this function will block until all dependencies for this
  // collective have completed.
  void WaitForDependencies(const CollectiveParams& col_params) override;
  // Record that this collective has completed the portion of the
  // implementation that needs to be ordered wrt other collectives, to unblock
  // any of its dependent ops.
  void UnblockDependencies(const CollectiveParams& col_params) override;
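  //
  // Ordering sketch (illustrative only; the launch step in the middle stands
  // in for whatever ordered work the collective implementation performs):
  //
  //   col_exec->WaitForDependencies(col_params);  // block on ordered-before
  //                                               // collectives
  //   // ... launch the ordered portion, e.g. enqueue the NCCL kernel ...
  //   col_exec->UnblockDependencies(col_params);  // let dependent ops proceed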

 protected:
  const int64 step_id_;
  const DeviceMgr* dev_mgr_;  // Not owned.
  std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_;
  const string* gpu_ring_order_;  // Not owned.
  mutex launch_mu_;
  condition_variable launch_cv_;
  // collective instance key -> number of local devices for which NCCL ops have
  // been launched.
  std::unordered_map<int32, int32> launched_ GUARDED_BY(launch_mu_);

 private:
  Status CreateCollective(const CollectiveParams& col_params,
                          CollectiveImplementationInterface** col_impl);
  // Check if all ops on which this collective depends have launched.
  bool CheckDependencies(const CollectiveParams& col_params)
      EXCLUSIVE_LOCKS_REQUIRED(launch_mu_);
};

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_