• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <vector>
23 
24 #include "tensorflow/core/framework/collective.h"
25 #include "tensorflow/core/lib/gtl/flatmap.h"
26 #include "tensorflow/core/protobuf/config.pb.h"
27 
28 namespace tensorflow {
29 class CompleteGroupRequest;
30 class CompleteGroupResponse;
31 class CompleteInstanceRequest;
32 class CompleteInstanceResponse;
33 class DeviceMgr;
34 
35 // Implements ParamResolverInterface for a single-task context.
36 // It also implements the functionality necessary to serve as the
37 // group leader for param resolution in a multi-task context.
38 class CollectiveParamResolverLocal : public ParamResolverInterface {
39  public:
40   CollectiveParamResolverLocal(const ConfigProto& config,
41                                const DeviceMgr* dev_mgr,
42                                DeviceResolverInterface* dev_resolver,
43                                const string& task_name);
44 
~CollectiveParamResolverLocal()45   ~CollectiveParamResolverLocal() override {}
46 
47   void CompleteParamsAsync(const string& device, CollectiveParams* cp,
48                            CancellationManager* cancel_mgr,
49                            const StatusCallback& done) override;
50 
51   void CompleteGroupAsync(const CompleteGroupRequest* request,
52                           CompleteGroupResponse* response,
53                           CancellationManager* cancel_mgr,
54                           const StatusCallback& done) override;
55 
56   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
57                              CompleteInstanceResponse* response,
58                              CancellationManager* cancel_mgr,
59                              const StatusCallback& done) override;
60 
61  protected:
62   // For access to InstanceRec and CompleteDefaultRanking.
63   friend class CollectiveParamResolverLocalTest;
64 
65   // Used to complete/verify CollGroup.
66   struct GroupRec {
67     CollGroupParams group;
68     mutable mutex mu;
69     Status status GUARDED_BY(mu);
70     std::set<string> device_set GUARDED_BY(mu);
71     std::vector<string> device_list GUARDED_BY(mu);
72     std::set<string> task_set GUARDED_BY(mu);
73     std::vector<string> task_list GUARDED_BY(mu);
74     std::vector<StatusCallback> waiting GUARDED_BY(mu);
75   };
76 
77   // Finds the GroupRec that corresponds to cp->group_key.
78   // Also populates cp->group from that group_rec.
79   // Will wait until GroupRec is fully populated or an error arises before
80   // calling done.  Callback GroupRec* arg is only valid if status is ok.
81   // Ownership of GroupRec stays with this object and does not pass to the
82   // callback.
83   typedef std::function<void(const Status& s, const GroupRec* gr)>
84       GroupRecCallback;
85   void CompleteGroupLocal(const string& device, CollectiveParams* cp,
86                           const GroupRecCallback& done)
87       LOCKS_EXCLUDED(group_mu_);
88 
89   // Used to complete/verify CollInstance.
90   struct InstanceRec;
91 
92   typedef std::function<void(InstanceRec*)> IRConsumer;
93   struct InstanceRec {
94     // This structure has two mutexes so that a possibly long
95     // initialization can be done without holding the instance_mu_
96     // table lock the whole time (which can cause an excessive number
97     // of threads to block on it), and because the compiler may not
98     // permit mutex locks to be taken in more than one order.
99     //
100     // out_mu guards access to most of the fields.
101     // in_mu guards access to a queue of consumer callbacks wanting to
102     // read the fields guarded by out_mu.
103     //
104     // The in_mu should be locked only while holding instance_mu_; the
105     // out_mu should be locked only while not holding
106     // instance_mu_.
107     //
108     // When is_init is false (the initial value) any potential user
109     // other than the creator should queue a callback on init_waiters.
110     // As soon as the shared member of this structure is fully
111     // initialized is_init will be set true and those callbacks will
112     // be invoked.
113     //
114     // Once inserted in the table this structure will never be replaced
115     // so users can capture the pointer while holding instance_mu_,
116     // drop that lock, then take a lock on out_mu before
117     // reading/modifying its values.
118     mutex in_mu;
119     bool is_init GUARDED_BY(in_mu);
120     std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu);
121 
122     // A thread that wishes to acquire out_mu must ensure that it is available
123     // by invoking WaitForOutMu().
124     mutex out_mu;
125     condition_variable out_cv;
126     bool out_mu_available GUARDED_BY(out_mu);
127     // Values to be shared by all instances, constant after initialization.
128     CollectiveParams shared GUARDED_BY(out_mu);
129     // If an error occurs during initialization this structure stays in
130     // the table with a non-OK status.  Purging the table and restarting
131     // needs to be done at a higher level.
132     Status status GUARDED_BY(out_mu);
133 
134     // These fields are used to count the instances that have called
135     // in and become known while resolving broadcast source identity and
136     // communicator key.
137     int source_rank GUARDED_BY(out_mu);
138     string communicator_key GUARDED_BY(out_mu);
139     int known_count GUARDED_BY(out_mu);
140     std::vector<bool> known GUARDED_BY(out_mu);
141     std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu);
142 
InstanceRecInstanceRec143     InstanceRec()
144         : is_init(false),
145           out_mu_available(true),
146           source_rank(-1),
147           known_count(0) {}
148 
149     // If out_mu is unavailable during distributed device locality
150     // initialization, wait on out_cv until it is available again.
151     void WaitForOutMu(mutex_lock& lock) EXCLUSIVE_LOCKS_REQUIRED(out_mu);
152   };
153 
154   // Find the InstanceRec with the same instance_key as cp.  If it doesn't
155   // already exist, create and initialize from gr and cp.
156   //
157   // Precondition: *gr must be a complete GroupRec, i.e. the value set
158   // by CompleteGroupLocal. *cp must be populated with all the fields
159   // required by InitInstanceSharedParams.  Ownership of InstanceRec stays
160   // with this object and does not pass to the callback.
161   typedef std::function<void(const Status& s, InstanceRec* ir)>
162       InstanceRecCallback;
163   void FindInstanceRec(const GroupRec* gr, CollectiveParams* cp,
164                        const InstanceRecCallback& done)
165       LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
166 
167   // Populate *ir with device membership from gr, then initialize to be specific
168   // to cp->instance_key, i.e. order the devices and tasks.
169   //
170   // Preconditions:
171   //  cp is populated with all DeviceLocalities
172   void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
173                                 InstanceRec* ir, const StatusCallback& done)
174       UNLOCK_FUNCTION(ir->out_mu) LOCKS_EXCLUDED(gr->mu);
175 
176   void CallInitInstanceSharedParams(const GroupRec* gr,
177                                     const CollectiveParams* cp, InstanceRec* ir,
178                                     const InstanceRecCallback& done)
179       LOCKS_EXCLUDED(ir->out_mu, gr->mu);
180 
181   // Establishes the final order of ir->shared.instance.device_names and
182   // ir->shared.instance.task_names by considering localities of all devices.
183   void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp,
184                               InstanceRec* ir,
185                               const std::vector<DeviceLocality>& localities)
186       EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu);
187 
188   // Finish populating *cp.
189   // Precondition: *gr has been fully populated by CompleteGroupLocal.
190   void CompleteInstanceLocal(const string& device, const GroupRec* gr,
191                              CollectiveParams* cp, bool is_source,
192                              const StatusCallback& done)
193       LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
194 
195   // Finish populating *cp from fully initialized *ir.
196   // Precondition: *gr and *ir are fully populated.
197   void CompleteInstanceFromInitializedIRec(const string& device,
198                                            const GroupRec* gr,
199                                            CollectiveParams* cp,
200                                            InstanceRec* ir, bool is_source,
201                                            const StatusCallback& done)
202       LOCKS_EXCLUDED(ir->out_mu);
203 
204   // Complete source data and/or nccl communicator key.
205   // Precondition: *cp has complete group data and default_rank.
206   void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, bool is_source,
207                     bool init_source, bool init_nccl, const IRConsumer& f)
208       LOCKS_EXCLUDED(ir->out_mu);
209 
210   // If cp.device_names contains only devices local to this process
211   // populates *localities, else returns an error.
212   Status GetLocalDeviceLocalities(const CollectiveParams& cp,
213                                   std::vector<DeviceLocality>* localities);
214 
215   // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
216   // Precondition: cp->device_names is fully populated and in final order.
217   void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
218 
219   // Sets cp->instance_default_rank according to location of device in
220   // current ordering of cp->instance.device_names.
221   void SetDefaultRank(const string& device, CollectiveParams* cp);
222 
223   // Sets cp->instance.type based on collective op type, and attempts to assign
224   // best implementation.
225   void AssignCollectiveType(CollectiveParams* cp);
226 
227   // Helper to grab status under lock, invoke callback out of lock.
228   void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
229       LOCKS_EXCLUDED(irec->out_mu);
230 
231   const bool nccl_;
232   const DeviceMgr* dev_mgr_;
233   DeviceResolverInterface* dev_resolver_;  // Not owned.
234   string task_name_;
235   mutex group_mu_;
236   gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
237       GUARDED_BY(group_mu_);
238   mutex instance_mu_;
239   gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_
240       GUARDED_BY(instance_mu_);
241 };
242 
243 }  // namespace tensorflow
244 
245 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
246