1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 17 18 #include <functional> 19 #include <memory> 20 #include <set> 21 #include <string> 22 #include <unordered_map> 23 #include <vector> 24 25 #include "tensorflow/core/framework/collective.h" 26 #include "tensorflow/core/framework/device_attributes.pb.h" 27 #include "tensorflow/core/lib/gtl/flatmap.h" 28 #include "tensorflow/core/platform/thread_annotations.h" 29 30 namespace tensorflow { 31 class CompleteGroupRequest; 32 class CompleteGroupResponse; 33 class CompleteInstanceRequest; 34 class CompleteInstanceResponse; 35 class ConfigProto; 36 class DeviceMgr; 37 38 // Implements ParamResolverInterface for a single-task context. 39 // It also implements the functionality necessary to serve as the 40 // group leader for param resolution in a multi-task context. 41 class CollectiveParamResolverLocal : public ParamResolverInterface { 42 public: 43 CollectiveParamResolverLocal(const ConfigProto& config, 44 const DeviceMgr* dev_mgr, 45 DeviceResolverInterface* dev_resolver, 46 NcclCommunicatorInterface* nccl_communicator, 47 const string& task_name); 48 ~CollectiveParamResolverLocal()49 ~CollectiveParamResolverLocal() override {} 50 51 void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp, 52 CancellationManager* cancel_mgr, 53 const StatusCallback& done) override; 54 55 void CompleteGroupAsync(const DeviceAttributes& device, 56 CollGroupParams* group_params, 57 CancellationManager* cancel_mgr, 58 const StatusCallback& done) override; 59 60 void CompleteInstanceAsync(const CompleteInstanceRequest* request, 61 CompleteInstanceResponse* response, 62 CancellationManager* cancel_mgr, 63 const StatusCallback& done) override; 64 65 void StartAbort(const Status& s) override; 66 67 protected: 68 // For access to InstanceRec and CompleteDefaultRanking. 69 friend class CollectiveParamResolverLocalTest; 70 71 // Used to complete/verify CollGroup. 72 struct GroupRec { 73 mutable mutex mu; 74 CollGroupParams group TF_GUARDED_BY(mu); 75 Status status TF_GUARDED_BY(mu); 76 std::unordered_map<string, int64> incarnations_by_device_name 77 TF_GUARDED_BY(mu); 78 std::vector<CollGroupParams*> pending_params TF_GUARDED_BY(mu); 79 std::vector<StatusCallback> pending_done TF_GUARDED_BY(mu); 80 }; 81 82 // Finds the GroupRec that corresponds to group_params->group_key. 83 // Also populates group_params from that group_rec. 84 // Will wait until GroupRec is fully populated or an error arises before 85 // calling done. Callback GroupRec* arg is only valid if status is ok. 86 // Ownership of GroupRec stays with this object and does not pass to the 87 // callback. 88 void CompleteGroupLocal(const DeviceAttributes& device, 89 CollGroupParams* group_params, 90 CancellationManager* cancel_mgr, StatusCallback done) 91 TF_LOCKS_EXCLUDED(group_mu_); 92 93 // Finishes the group parameters once all members of the group are there. 94 void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu); 95 96 // Cancels the group if it's still pending. 97 void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_); 98 99 // Used to complete/verify CollInstance. 100 struct InstanceRec; 101 102 typedef std::function<void(InstanceRec*)> IRConsumer; 103 struct InstanceRec { 104 mutex mu; 105 // Values to be shared by all instances, constant after initialization. 106 CollectiveParams* shared; 107 // If an error occurs during initialization this structure stays in the 108 // table with a non-OK status. Purging the table and restarting needs to be 109 // done at a higher level. 110 Status status TF_GUARDED_BY(mu); 111 112 // These fields are used to count the instances that have called 113 // in and become known while resolving broadcast source identity and 114 // communicator key. 115 int source_rank TF_GUARDED_BY(mu); 116 string communicator_key TF_GUARDED_BY(mu); 117 int known_count TF_GUARDED_BY(mu); 118 std::vector<bool> known TF_GUARDED_BY(mu); 119 std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu); 120 InstanceRecInstanceRec121 InstanceRec() 122 : shared(new CollectiveParams()), source_rank(-1), known_count(0) {} ~InstanceRecInstanceRec123 ~InstanceRec() { shared->Unref(); } 124 }; 125 126 // Find the InstanceRec with the same instance_key as cp. If it doesn't 127 // already exist, create and initialize from gr and cp. 128 // created is set to true if a new IRec is created, false otherwise. 129 // 130 // Precondition: *gr must be a complete GroupRec, i.e. the value set 131 // by CompleteGroupLocal. *cp must be populated with all the fields 132 // required by InitInstanceSharedParams. Ownership of InstanceRec stays 133 // with this object and does not pass to the callback. 134 InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created) 135 TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); 136 137 // Populate *ir with device membership from gr, then initialize to be specific 138 // to cp->instance_key, i.e. order the devices and tasks. 139 // 140 // Preconditions: 141 // cp is populated with all DeviceLocalities 142 void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir); 143 144 // Establishes the final order of gp->device_names and gp->task_names by 145 // considering localities of all devices. 146 void CompleteDefaultRanking(CollGroupParams* gp); 147 148 // Finish populating *cp. 149 // Precondition: *gr has been fully populated by CompleteGroupLocal. 150 void CompleteInstanceLocal(const string& device, CollectiveParams* cp, 151 const StatusCallback& done) 152 TF_LOCKS_EXCLUDED(instance_mu_, group_mu_); 153 154 // Finish populating *cp from fully initialized *ir. 155 // Precondition: *gr and *ir are fully populated. 156 void CompleteInstanceFromInitializedIRec(const string& device, 157 CollectiveParams* cp, 158 InstanceRec* ir, 159 const StatusCallback& done) 160 TF_LOCKS_EXCLUDED(ir->mu); 161 162 // Complete instance params after waiting for group. 163 // Precondition: *cp has complete group data and default_rank. 164 void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f) 165 TF_LOCKS_EXCLUDED(ir->mu); 166 167 // If cp.device_names contains only devices local to this process 168 // populates *localities, else returns an error. 169 Status GetLocalDeviceLocalities(const CollectiveParams& cp, 170 std::vector<DeviceLocality>* localities); 171 172 // Sets CollTaskParams.is_local and CollectiveParams.default_rank. 173 // Precondition: cp->device_names is fully populated and in final order. 174 void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp); 175 176 // Sets cp->instance_default_rank according to location of device in 177 // current ordering of cp->instance.device_names. 178 void SetDefaultRank(const string& device, CollectiveParams* cp); 179 180 // Sets cp->instance.type based on collective op type, and attempts to assign 181 // best implementation. 182 void AssignCollectiveType(CollectiveParams* cp); 183 184 void StartAbortLocal(const Status& s) 185 TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_); 186 187 const bool nccl_; 188 const DeviceMgr* dev_mgr_; 189 DeviceResolverInterface* dev_resolver_; // Not owned. 190 NcclCommunicatorInterface* nccl_communicator_; // Not owned. 191 string task_name_; 192 string gpu_ring_order_; 193 mutex group_mu_; 194 gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_ 195 TF_GUARDED_BY(group_mu_); 196 mutex instance_mu_; 197 gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>> 198 instance_table_ TF_GUARDED_BY(instance_mu_); 199 mutex status_mu_; 200 Status status_ TF_GUARDED_BY(status_mu_); 201 }; 202 203 } // namespace tensorflow 204 205 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_ 206