1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
16
17 #include "tensorflow/core/common_runtime/base_collective_executor.h"
18 #include "tensorflow/core/common_runtime/collective_executor_mgr.h"
19 #include "tensorflow/core/common_runtime/collective_rma_local.h"
20 #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
21 #include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
22 #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
23 #include "tensorflow/core/distributed_runtime/worker_cache.h"
24 #include "tensorflow/core/lib/random/random.h"
25
26 namespace tensorflow {
27
RpcCollectiveExecutorMgr(const ConfigProto & config,const DeviceMgr * dev_mgr,std::unique_ptr<DeviceResolverDistributed> dev_resolver,std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,WorkerCacheInterface * worker_cache,const string & task_name)28 RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
29 const ConfigProto& config, const DeviceMgr* dev_mgr,
30 std::unique_ptr<DeviceResolverDistributed> dev_resolver,
31 std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
32 WorkerCacheInterface* worker_cache, const string& task_name)
33 : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
34 std::move(param_resolver)),
35 worker_cache_(worker_cache),
36 task_name_(task_name) {
37 group_leader_ = (task_name == config.experimental().collective_group_leader())
38 ? ""
39 : config.experimental().collective_group_leader();
40 }
41
~RpcCollectiveExecutorMgr()42 RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
43 for (auto it : sequence_table_) {
44 delete it.second;
45 }
46 }
47
Create(int64 step_id)48 CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
49 CollectiveRemoteAccessDistributed* rma =
50 new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
51 worker_cache_, step_id);
52 return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
53 &gpu_ring_order_);
54 }
55
56 namespace {
57 // StepId must leave the most-significant 7 bits empty for future use.
58 static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));
59
NewRandomStepId()60 int64 NewRandomStepId() {
61 int64 step_id = random::New64();
62 // Leave MS 8 bits clear for future use.
63 step_id &= kStepIdMask;
64 return step_id;
65 }
66 } // namespace
67
RefreshStepIdSequenceAsync(int64 graph_key,const StatusCallback & done)68 void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
69 int64 graph_key, const StatusCallback& done) {
70 if (group_leader_.empty()) {
71 mutex_lock l(sequence_mu_);
72 GraphKeySequence* gks = nullptr;
73 auto it = sequence_table_.find(graph_key);
74 if (it == sequence_table_.end()) {
75 gks = new GraphKeySequence(graph_key);
76 sequence_table_[graph_key] = gks;
77 } else {
78 gks = it->second;
79 }
80 gks->next_step_id_ = NewRandomStepId();
81 done(Status::OK());
82 } else {
83 WorkerInterface* wi = worker_cache_->CreateWorker(group_leader_);
84 GetStepSequenceRequest* req = new GetStepSequenceRequest;
85 GetStepSequenceResponse* resp = new GetStepSequenceResponse;
86 req->add_graph_key(graph_key);
87 wi->GetStepSequenceAsync(
88 req, resp, [this, req, resp, done](const Status& s) {
89 if (!s.ok()) {
90 LOG(ERROR) << "Bad response [" << s
91 << "] from GetStepSequenceAsync call to "
92 << group_leader_;
93 done(s);
94 } else {
95 done(UpdateStepSequences(*resp));
96 }
97 delete req;
98 delete resp;
99 });
100 }
101 }
102
GetStepSequenceAsync(const GetStepSequenceRequest * request,GetStepSequenceResponse * response,const StatusCallback & done)103 void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
104 const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
105 const StatusCallback& done) {
106 if (!group_leader_.empty()) {
107 LOG(ERROR) << "GetStepSequence called at non-group-leader";
108 done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
109 } else {
110 mutex_lock l(sequence_mu_);
111 for (int64 graph_key : request->graph_key()) {
112 auto it = sequence_table_.find(graph_key);
113 GraphKeySequence* gks = nullptr;
114 if (it == sequence_table_.end()) {
115 gks = new GraphKeySequence(graph_key);
116 gks->next_step_id_ = NewRandomStepId();
117 sequence_table_[graph_key] = gks;
118 } else {
119 gks = it->second;
120 }
121 StepSequence* ss = response->add_step_sequence();
122 ss->set_graph_key(graph_key);
123 ss->set_next_step_id(gks->next_step_id_);
124 }
125 done(Status::OK());
126 }
127 }
128
UpdateStepSequences(const GetStepSequenceResponse & resp)129 Status RpcCollectiveExecutorMgr::UpdateStepSequences(
130 const GetStepSequenceResponse& resp) {
131 mutex_lock l(sequence_mu_);
132 for (const StepSequence& ss : resp.step_sequence()) {
133 GraphKeySequence* gks = nullptr;
134 auto it = sequence_table_.find(ss.graph_key());
135 if (it == sequence_table_.end()) {
136 gks = new GraphKeySequence(ss.graph_key());
137 sequence_table_[ss.graph_key()] = gks;
138 } else {
139 gks = it->second;
140 }
141 gks->next_step_id_ = ss.next_step_id();
142 }
143 return Status::OK();
144 }
145
NextStepId(int64 graph_key)146 int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
147 mutex_lock l(sequence_mu_);
148 auto it = sequence_table_.find(graph_key);
149 if (it != sequence_table_.end()) {
150 return it->second->next_step_id_;
151 }
152 return CollectiveExecutor::kInvalidId;
153 }
154
RetireStepId(int64 graph_key,int64 step_id)155 void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
156 mutex_lock l(sequence_mu_);
157 auto it = sequence_table_.find(graph_key);
158 if (it != sequence_table_.end()) {
159 if (step_id == it->second->next_step_id_) {
160 it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
161 } else {
162 it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
163 }
164 } else {
165 LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
166 }
167 }
168
169 } // namespace tensorflow
170