• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/node.h"
18 
19 namespace mindspore {
20 namespace ps {
21 namespace core {
node_id() const22 std::string Node::node_id() const { return node_info_.node_id_; }
23 
rank_id() const24 uint32_t Node::rank_id() const { return node_info_.rank_id_; }
25 
role() const26 NodeRole Node::role() const { return node_info_.node_role_; }
27 
WaitForStart(const uint32_t & timeout)28 bool Node::WaitForStart(const uint32_t &timeout) {
29   MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is Waiting for start!";
30   std::unique_lock<std::mutex> lock(wait_start_mutex_);
31   bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
32     bool result = this->is_ready_.load();
33     if (result) {
34       MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
35     }
36     return result;
37   });
38   return res;
39 }
40 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)41 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
42                            const uint32_t &timeout) {
43   MS_EXCEPTION_IF_NULL(client);
44   uint64_t request_id = AddMessageTrack(1);
45   const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
46   if (!client->SendMessage(message)) {
47     MS_LOG(WARNING) << "Client send message failed.";
48   }
49   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
50                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
51   return Wait(request_id, timeout);
52 }
53 
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)54 bool Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
55                             const Protos &protos, const void *data, size_t size) {
56   MS_EXCEPTION_IF_NULL(client);
57   MS_EXCEPTION_IF_NULL(meta);
58   MS_EXCEPTION_IF_NULL(data);
59   if (!client->SendMessage(meta, protos, data, size)) {
60     MS_LOG(WARNING) << "Client send message failed.";
61     return false;
62   }
63   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
64                 << ", the node id is:" << node_info_.node_id_;
65   return true;
66 }
67 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)68 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
69                            const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
70   MS_EXCEPTION_IF_NULL(client);
71   MS_EXCEPTION_IF_NULL(meta);
72   MS_EXCEPTION_IF_NULL(data);
73   uint64_t request_id = AddMessageTrack(1);
74   meta->set_request_id(request_id);
75   if (!client->SendMessage(meta, protos, data, size)) {
76     MS_LOG(WARNING) << "Client send message failed.";
77   }
78   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
79                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
80   return Wait(request_id, timeout);
81 }
82 
EnableRecovery() const83 bool Node::EnableRecovery() const {
84   MS_EXCEPTION_IF_NULL(config_);
85   return config_->Exists(kKeyRecovery);
86 }
87 
Wait(uint64_t request_id,const uint32_t & timeout)88 bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
89   std::unique_lock<std::mutex> tracker_lock(message_tracker_mutex_);
90   bool res = message_tracker_cond_.wait_for(tracker_lock, std::chrono::seconds(timeout), [&] {
91     if (message_tracker_.count(request_id)) {
92       bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
93       return ret;
94     }
95     return false;
96   });
97   (void)message_tracker_.erase(request_id);
98   tracker_lock.unlock();
99 
100   std::unique_lock<std::mutex> msgs_lock(receive_messages_mutex_);
101   // The messages should be already copied before message_tracker_cond_ is notified.
102   if (receive_messages_.count(request_id) != 0) {
103     (void)receive_messages_.erase(request_id);
104   }
105   msgs_lock.unlock();
106   return res;
107 }
108 
AddMessageTrack(const uint32_t & expected_response)109 uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
110   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
111   uint64_t request_id = ++next_request_id_;
112   message_tracker_[request_id] = std::make_pair(expected_response, 0);
113   return request_id;
114 }
115 
CheckMessageTrack(const uint64_t & request_id)116 bool Node::CheckMessageTrack(const uint64_t &request_id) {
117   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
118   if (message_tracker_.count(request_id)) {
119     return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
120   }
121   MS_LOG(INFO) << "The message tracker is not contain the id:" << request_id;
122   return false;
123 }
124 
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)125 void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
126   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
127   uint64_t request_id = meta->request_id();
128   if (message_tracker_.count(request_id)) {
129     message_tracker_[request_id].second++;
130     message_tracker_cond_.notify_all();
131   }
132 }
133 
set_message_callback(const uint64_t & request_id,const MessageCallback & callback)134 void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
135   if (!callback) {
136     return;
137   }
138   std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
139   message_callbacks_[request_id] = callback;
140 }
141 
ProcessSendDataResp(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)142 void Node::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
143                                size_t size) {
144   MS_EXCEPTION_IF_NULL(meta);
145   MS_EXCEPTION_IF_NULL(data);
146   std::lock_guard<std::mutex> lock(receive_messages_mutex_);
147   const uint32_t &rank_id = meta->rank_id();
148   const uint64_t request_id = meta->request_id();
149   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
150                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
151   if (meta->role() == NodeRole::SERVER) {
152     auto it = receive_messages_.find(request_id);
153     VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
154     if (size > 0) {
155       size_t dest_size = size;
156       size_t src_size = size;
157       errno_t ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
158       if (ret != EOK) {
159         MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
160       }
161     }
162     if (it != receive_messages_.end()) {
163       it->second[rank_id] = received_data;
164     } else {
165       std::unordered_map<uint32_t, VectorPtr> res;
166       (void)res.insert(std::make_pair(rank_id, received_data));
167       receive_messages_[request_id] = res;
168     }
169   } else {
170     auto it = workder_receive_messages_.find(request_id);
171     VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
172     if (size > 0) {
173       size_t dest_size = size;
174       size_t src_size = size;
175       errno_t ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
176       if (ret != EOK) {
177         MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
178       }
179     }
180     if (it != workder_receive_messages_.end()) {
181       it->second[rank_id] = received_data;
182     } else {
183       std::unordered_map<uint32_t, VectorPtr> res;
184       (void)res.insert(std::make_pair(rank_id, received_data));
185       workder_receive_messages_[request_id] = res;
186     }
187   }
188 }
189 
RunMessageCallback(const uint64_t & request_id)190 void Node::RunMessageCallback(const uint64_t &request_id) {
191   message_callbacks_mutex_.lock();
192   // When receiving a message's response, Then compare with the desired number of responses,
193   // If they are equal, then call the callback function
194   if (CheckMessageTrack(request_id)) {
195     auto it = message_callbacks_.find(request_id);
196     if (it != message_callbacks_.end()) {
197       message_callbacks_mutex_.unlock();
198 
199       if (it->second) {
200         it->second();
201       }
202 
203       message_callbacks_mutex_.lock();
204       (void)message_callbacks_.erase(it);
205     }
206   }
207   message_callbacks_mutex_.unlock();
208 }
209 }  // namespace core
210 }  // namespace ps
211 }  // namespace mindspore
212