1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 17 18 #include <functional> 19 #include <memory> 20 #include <set> 21 #include <string> 22 #include <vector> 23 24 #include "tensorflow/core/framework/collective.h" 25 #include "tensorflow/core/lib/gtl/flatmap.h" 26 #include "tensorflow/core/protobuf/config.pb.h" 27 28 namespace tensorflow { 29 class CompleteGroupRequest; 30 class CompleteGroupResponse; 31 class CompleteInstanceRequest; 32 class CompleteInstanceResponse; 33 class DeviceMgr; 34 35 // Implements ParamResolverInterface for a single-task context. 36 // It also implements the functionality necessary to serve as the 37 // group leader for param resolution in a multi-task context. 38 class CollectiveParamResolverLocal : public ParamResolverInterface { 39 public: 40 CollectiveParamResolverLocal(const ConfigProto& config, 41 const DeviceMgr* dev_mgr, 42 DeviceResolverInterface* dev_resolver, 43 const string& task_name); 44 ~CollectiveParamResolverLocal()45 ~CollectiveParamResolverLocal() override {} 46 47 void CompleteParamsAsync(const string& device, CollectiveParams* cp, 48 CancellationManager* cancel_mgr, 49 const StatusCallback& done) override; 50 51 void CompleteGroupAsync(const CompleteGroupRequest* request, 52 CompleteGroupResponse* response, 53 CancellationManager* cancel_mgr, 54 const StatusCallback& done) override; 55 56 void CompleteInstanceAsync(const CompleteInstanceRequest* request, 57 CompleteInstanceResponse* response, 58 CancellationManager* cancel_mgr, 59 const StatusCallback& done) override; 60 61 protected: 62 // For access to InstanceRec and CompleteDefaultRanking. 63 friend class CollectiveParamResolverLocalTest; 64 65 // Used to complete/verify CollGroup. 66 struct GroupRec { 67 CollGroupParams group; 68 mutable mutex mu; 69 Status status GUARDED_BY(mu); 70 std::set<string> device_set GUARDED_BY(mu); 71 std::vector<string> device_list GUARDED_BY(mu); 72 std::set<string> task_set GUARDED_BY(mu); 73 std::vector<string> task_list GUARDED_BY(mu); 74 std::vector<StatusCallback> waiting GUARDED_BY(mu); 75 }; 76 77 // Finds the GroupRec that corresponds to cp->group_key. 78 // Also populates cp->group from that group_rec. 79 // Will wait until GroupRec is fully populated or an error arises before 80 // calling done. Callback GroupRec* arg is only valid if status is ok. 81 // Ownership of GroupRec stays with this object and does not pass to the 82 // callback. 83 typedef std::function<void(const Status& s, const GroupRec* gr)> 84 GroupRecCallback; 85 void CompleteGroupLocal(const string& device, CollectiveParams* cp, 86 const GroupRecCallback& done) 87 LOCKS_EXCLUDED(group_mu_); 88 89 // Used to complete/verify CollInstance. 90 struct InstanceRec; 91 92 typedef std::function<void(InstanceRec*)> IRConsumer; 93 struct InstanceRec { 94 // This structure has two mutexes so that a possibly long 95 // initialization can be done without holding the instance_mu_ 96 // table lock the whole time (which can cause an excessive number 97 // of threads to block on it), and because the compiler may not 98 // permit mutex locks to be taken in more than one order. 99 // 100 // out_mu guards access to most of the fields. 101 // in_mu guards access to a queue of consumer callbacks wanting to 102 // read the fields guarded by out_mu. 103 // 104 // The in_mu should be locked only while holding instance_mu_; the 105 // out_mu should be locked only while not holding 106 // instance_mu_. 107 // 108 // When is_init is false (the initial value) any potential user 109 // other than the creator should queue a callback on init_waiters. 110 // As soon as the shared member of this structure is fully 111 // initialized is_init will be set true and those callbacks will 112 // be invoked. 113 // 114 // Once inserted in the table this structure will never be replaced 115 // so users can capture the pointer while holding instance_mu_, 116 // drop that lock, then take a lock on out_mu before 117 // reading/modifying its values. 118 mutex in_mu; 119 bool is_init GUARDED_BY(in_mu); 120 std::vector<IRConsumer> init_waiters GUARDED_BY(in_mu); 121 122 // A thread that wishes to acquire out_mu must ensure that it is available 123 // by invoking WaitForOutMu(). 124 mutex out_mu; 125 condition_variable out_cv; 126 bool out_mu_available GUARDED_BY(out_mu); 127 // Values to be shared by all instances, constant after initialization. 128 CollectiveParams shared GUARDED_BY(out_mu); 129 // If an error occurs during initialization this structure stays in 130 // the table with a non-OK status. Purging the table and restarting 131 // needs to be done at a higher level. 132 Status status GUARDED_BY(out_mu); 133 134 // These fields are used to count the instances that have called 135 // in and become known while resolving broadcast source identity and 136 // communicator key. 137 int source_rank GUARDED_BY(out_mu); 138 string communicator_key GUARDED_BY(out_mu); 139 int known_count GUARDED_BY(out_mu); 140 std::vector<bool> known GUARDED_BY(out_mu); 141 std::vector<IRConsumer> known_waiters GUARDED_BY(out_mu); 142 InstanceRecInstanceRec143 InstanceRec() 144 : is_init(false), 145 out_mu_available(true), 146 source_rank(-1), 147 known_count(0) {} 148 149 // If out_mu is unavailable during distributed device locality 150 // initialization, wait on out_cv until it is available again. 151 void WaitForOutMu(mutex_lock& lock) EXCLUSIVE_LOCKS_REQUIRED(out_mu); 152 }; 153 154 // Find the InstanceRec with the same instance_key as cp. If it doesn't 155 // already exist, create and initialize from gr and cp. 156 // 157 // Precondition: *gr must be a complete GroupRec, i.e. the value set 158 // by CompleteGroupLocal. *cp must be populated with all the fields 159 // required by InitInstanceSharedParams. Ownership of InstanceRec stays 160 // with this object and does not pass to the callback. 161 typedef std::function<void(const Status& s, InstanceRec* ir)> 162 InstanceRecCallback; 163 void FindInstanceRec(const GroupRec* gr, CollectiveParams* cp, 164 const InstanceRecCallback& done) 165 LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_); 166 167 // Populate *ir with device membership from gr, then initialize to be specific 168 // to cp->instance_key, i.e. order the devices and tasks. 169 // 170 // Preconditions: 171 // cp is populated with all DeviceLocalities 172 void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp, 173 InstanceRec* ir, const StatusCallback& done) 174 UNLOCK_FUNCTION(ir->out_mu) LOCKS_EXCLUDED(gr->mu); 175 176 void CallInitInstanceSharedParams(const GroupRec* gr, 177 const CollectiveParams* cp, InstanceRec* ir, 178 const InstanceRecCallback& done) 179 LOCKS_EXCLUDED(ir->out_mu, gr->mu); 180 181 // Establishes the final order of ir->shared.instance.device_names and 182 // ir->shared.instance.task_names by considering localities of all devices. 183 void CompleteDefaultRanking(const GroupRec* gr, const CollectiveParams* cp, 184 InstanceRec* ir, 185 const std::vector<DeviceLocality>& localities) 186 EXCLUSIVE_LOCKS_REQUIRED(ir->out_mu); 187 188 // Finish populating *cp. 189 // Precondition: *gr has been fully populated by CompleteGroupLocal. 190 void CompleteInstanceLocal(const string& device, const GroupRec* gr, 191 CollectiveParams* cp, bool is_source, 192 const StatusCallback& done) 193 LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_); 194 195 // Finish populating *cp from fully initialized *ir. 196 // Precondition: *gr and *ir are fully populated. 197 void CompleteInstanceFromInitializedIRec(const string& device, 198 const GroupRec* gr, 199 CollectiveParams* cp, 200 InstanceRec* ir, bool is_source, 201 const StatusCallback& done) 202 LOCKS_EXCLUDED(ir->out_mu); 203 204 // Complete source data and/or nccl communicator key. 205 // Precondition: *cp has complete group data and default_rank. 206 void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, bool is_source, 207 bool init_source, bool init_nccl, const IRConsumer& f) 208 LOCKS_EXCLUDED(ir->out_mu); 209 210 // If cp.device_names contains only devices local to this process 211 // populates *localities, else returns an error. 212 Status GetLocalDeviceLocalities(const CollectiveParams& cp, 213 std::vector<DeviceLocality>* localities); 214 215 // Sets CollTaskParams.is_local and CollectiveParams.default_rank. 216 // Precondition: cp->device_names is fully populated and in final order. 217 void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp); 218 219 // Sets cp->instance_default_rank according to location of device in 220 // current ordering of cp->instance.device_names. 221 void SetDefaultRank(const string& device, CollectiveParams* cp); 222 223 // Sets cp->instance.type based on collective op type, and attempts to assign 224 // best implementation. 225 void AssignCollectiveType(CollectiveParams* cp); 226 227 // Helper to grab status under lock, invoke callback out of lock. 228 void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec) 229 LOCKS_EXCLUDED(irec->out_mu); 230 231 const bool nccl_; 232 const DeviceMgr* dev_mgr_; 233 DeviceResolverInterface* dev_resolver_; // Not owned. 234 string task_name_; 235 mutex group_mu_; 236 gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_ 237 GUARDED_BY(group_mu_); 238 mutex instance_mu_; 239 gtl::FlatMap<int32, std::unique_ptr<InstanceRec>> instance_table_ 240 GUARDED_BY(instance_mu_); 241 }; 242 243 } // namespace tensorflow 244 245 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 246