• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/collective.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/lib/gtl/flatmap.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 
30 namespace tensorflow {
31 class CompleteGroupRequest;
32 class CompleteGroupResponse;
33 class CompleteInstanceRequest;
34 class CompleteInstanceResponse;
35 class ConfigProto;
36 class DeviceMgr;
37 
38 // Implements ParamResolverInterface for a single-task context.
39 // It also implements the functionality necessary to serve as the
40 // group leader for param resolution in a multi-task context.
41 class CollectiveParamResolverLocal : public ParamResolverInterface {
42  public:
43   CollectiveParamResolverLocal(const ConfigProto& config,
44                                const DeviceMgr* dev_mgr,
45                                DeviceResolverInterface* dev_resolver,
46                                NcclCommunicatorInterface* nccl_communicator,
47                                const string& task_name);
48 
~CollectiveParamResolverLocal()49   ~CollectiveParamResolverLocal() override {}
50 
51   void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
52                            CancellationManager* cancel_mgr,
53                            const StatusCallback& done) override;
54 
55   void CompleteGroupAsync(const DeviceAttributes& device,
56                           CollGroupParams* group_params,
57                           CancellationManager* cancel_mgr,
58                           const StatusCallback& done) override;
59 
60   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
61                              CompleteInstanceResponse* response,
62                              CancellationManager* cancel_mgr,
63                              const StatusCallback& done) override;
64 
65   void StartAbort(const Status& s) override;
66 
67  protected:
68   // For access to InstanceRec and CompleteDefaultRanking.
69   friend class CollectiveParamResolverLocalTest;
70 
71   // Used to complete/verify CollGroup.
72   struct GroupRec {
73     mutable mutex mu;
74     CollGroupParams group TF_GUARDED_BY(mu);
75     Status status TF_GUARDED_BY(mu);
76     std::unordered_map<string, int64> incarnations_by_device_name
77         TF_GUARDED_BY(mu);
78     std::vector<CollGroupParams*> pending_params TF_GUARDED_BY(mu);
79     std::vector<StatusCallback> pending_done TF_GUARDED_BY(mu);
80   };
81 
82   // Finds the GroupRec that corresponds to group_params->group_key.
83   // Also populates group_params from that group_rec.
84   // Will wait until GroupRec is fully populated or an error arises before
85   // calling done.  Callback GroupRec* arg is only valid if status is ok.
86   // Ownership of GroupRec stays with this object and does not pass to the
87   // callback.
88   void CompleteGroupLocal(const DeviceAttributes& device,
89                           CollGroupParams* group_params,
90                           CancellationManager* cancel_mgr, StatusCallback done)
91       TF_LOCKS_EXCLUDED(group_mu_);
92 
93   // Finishes the group parameters once all members of the group are there.
94   void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu);
95 
96   // Cancels the group if it's still pending.
97   void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
98 
99   // Used to complete/verify CollInstance.
100   struct InstanceRec;
101 
102   typedef std::function<void(InstanceRec*)> IRConsumer;
103   struct InstanceRec {
104     mutex mu;
105     // Values to be shared by all instances, constant after initialization.
106     CollectiveParams* shared;
107     // If an error occurs during initialization this structure stays in the
108     // table with a non-OK status. Purging the table and restarting needs to be
109     // done at a higher level.
110     Status status TF_GUARDED_BY(mu);
111 
112     // These fields are used to count the instances that have called
113     // in and become known while resolving broadcast source identity and
114     // communicator key.
115     int source_rank TF_GUARDED_BY(mu);
116     string communicator_key TF_GUARDED_BY(mu);
117     int known_count TF_GUARDED_BY(mu);
118     std::vector<bool> known TF_GUARDED_BY(mu);
119     std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
120 
InstanceRecInstanceRec121     InstanceRec()
122         : shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
~InstanceRecInstanceRec123     ~InstanceRec() { shared->Unref(); }
124   };
125 
126   // Find the InstanceRec with the same instance_key as cp.  If it doesn't
127   // already exist, create and initialize from gr and cp.
128   // created is set to true if a new IRec is created, false otherwise.
129   //
130   // Precondition: *gr must be a complete GroupRec, i.e. the value set
131   // by CompleteGroupLocal. *cp must be populated with all the fields
132   // required by InitInstanceSharedParams.  Ownership of InstanceRec stays
133   // with this object and does not pass to the callback.
134   InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created)
135       TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
136 
137   // Populate *ir with device membership from gr, then initialize to be specific
138   // to cp->instance_key, i.e. order the devices and tasks.
139   //
140   // Preconditions:
141   //  cp is populated with all DeviceLocalities
142   void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir);
143 
144   // Establishes the final order of gp->device_names and gp->task_names by
145   // considering localities of all devices.
146   void CompleteDefaultRanking(CollGroupParams* gp);
147 
148   // Finish populating *cp.
149   // Precondition: *gr has been fully populated by CompleteGroupLocal.
150   void CompleteInstanceLocal(const string& device, CollectiveParams* cp,
151                              const StatusCallback& done)
152       TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
153 
154   // Finish populating *cp from fully initialized *ir.
155   // Precondition: *gr and *ir are fully populated.
156   void CompleteInstanceFromInitializedIRec(const string& device,
157                                            CollectiveParams* cp,
158                                            InstanceRec* ir,
159                                            const StatusCallback& done)
160       TF_LOCKS_EXCLUDED(ir->mu);
161 
162   // Complete instance params after waiting for group.
163   // Precondition: *cp has complete group data and default_rank.
164   void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f)
165       TF_LOCKS_EXCLUDED(ir->mu);
166 
167   // If cp.device_names contains only devices local to this process
168   // populates *localities, else returns an error.
169   Status GetLocalDeviceLocalities(const CollectiveParams& cp,
170                                   std::vector<DeviceLocality>* localities);
171 
172   // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
173   // Precondition: cp->device_names is fully populated and in final order.
174   void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
175 
176   // Sets cp->instance_default_rank according to location of device in
177   // current ordering of cp->instance.device_names.
178   void SetDefaultRank(const string& device, CollectiveParams* cp);
179 
180   // Sets cp->instance.type based on collective op type, and attempts to assign
181   // best implementation.
182   void AssignCollectiveType(CollectiveParams* cp);
183 
184   void StartAbortLocal(const Status& s)
185       TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_);
186 
187   const bool nccl_;
188   const DeviceMgr* dev_mgr_;
189   DeviceResolverInterface* dev_resolver_;  // Not owned.
190   NcclCommunicatorInterface* nccl_communicator_;  // Not owned.
191   string task_name_;
192   string gpu_ring_order_;
193   mutex group_mu_;
194   gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
195       TF_GUARDED_BY(group_mu_);
196   mutex instance_mu_;
197   gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
198       instance_table_ TF_GUARDED_BY(instance_mu_);
199   mutex status_mu_;
200   Status status_ TF_GUARDED_BY(status_mu_);
201 };
202 
203 }  // namespace tensorflow
204 
205 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
206