1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "ps/core/node.h"

#include <utility>
18
19 namespace mindspore {
20 namespace ps {
21 namespace core {
node_id() const22 std::string Node::node_id() const { return node_info_.node_id_; }
23
rank_id() const24 uint32_t Node::rank_id() const { return node_info_.rank_id_; }
25
role() const26 NodeRole Node::role() const { return node_info_.node_role_; }
27
WaitForStart(const uint32_t & timeout)28 bool Node::WaitForStart(const uint32_t &timeout) {
29 MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is Waiting for start!";
30 std::unique_lock<std::mutex> lock(wait_start_mutex_);
31 bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
32 bool result = this->is_ready_.load();
33 if (result) {
34 MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
35 }
36 return result;
37 });
38 return res;
39 }
40
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)41 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
42 const uint32_t &timeout) {
43 MS_EXCEPTION_IF_NULL(client);
44 uint64_t request_id = AddMessageTrack(1);
45 const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
46 if (!client->SendMessage(message)) {
47 MS_LOG(WARNING) << "Client send message failed.";
48 }
49 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
50 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
51 return Wait(request_id, timeout);
52 }
53
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)54 bool Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
55 const Protos &protos, const void *data, size_t size) {
56 MS_EXCEPTION_IF_NULL(client);
57 MS_EXCEPTION_IF_NULL(meta);
58 MS_EXCEPTION_IF_NULL(data);
59 if (!client->SendMessage(meta, protos, data, size)) {
60 MS_LOG(WARNING) << "Client send message failed.";
61 return false;
62 }
63 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
64 << ", the node id is:" << node_info_.node_id_;
65 return true;
66 }
67
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)68 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
69 const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
70 MS_EXCEPTION_IF_NULL(client);
71 MS_EXCEPTION_IF_NULL(meta);
72 MS_EXCEPTION_IF_NULL(data);
73 uint64_t request_id = AddMessageTrack(1);
74 meta->set_request_id(request_id);
75 if (!client->SendMessage(meta, protos, data, size)) {
76 MS_LOG(WARNING) << "Client send message failed.";
77 }
78 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
79 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
80 return Wait(request_id, timeout);
81 }
82
EnableRecovery() const83 bool Node::EnableRecovery() const {
84 MS_EXCEPTION_IF_NULL(config_);
85 return config_->Exists(kKeyRecovery);
86 }
87
Wait(uint64_t request_id,const uint32_t & timeout)88 bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
89 std::unique_lock<std::mutex> tracker_lock(message_tracker_mutex_);
90 bool res = message_tracker_cond_.wait_for(tracker_lock, std::chrono::seconds(timeout), [&] {
91 if (message_tracker_.count(request_id)) {
92 bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
93 return ret;
94 }
95 return false;
96 });
97 (void)message_tracker_.erase(request_id);
98 tracker_lock.unlock();
99
100 std::unique_lock<std::mutex> msgs_lock(receive_messages_mutex_);
101 // The messages should be already copied before message_tracker_cond_ is notified.
102 if (receive_messages_.count(request_id) != 0) {
103 (void)receive_messages_.erase(request_id);
104 }
105 msgs_lock.unlock();
106 return res;
107 }
108
AddMessageTrack(const uint32_t & expected_response)109 uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
110 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
111 uint64_t request_id = ++next_request_id_;
112 message_tracker_[request_id] = std::make_pair(expected_response, 0);
113 return request_id;
114 }
115
CheckMessageTrack(const uint64_t & request_id)116 bool Node::CheckMessageTrack(const uint64_t &request_id) {
117 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
118 if (message_tracker_.count(request_id)) {
119 return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
120 }
121 MS_LOG(INFO) << "The message tracker is not contain the id:" << request_id;
122 return false;
123 }
124
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)125 void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
126 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
127 uint64_t request_id = meta->request_id();
128 if (message_tracker_.count(request_id)) {
129 message_tracker_[request_id].second++;
130 message_tracker_cond_.notify_all();
131 }
132 }
133
set_message_callback(const uint64_t & request_id,const MessageCallback & callback)134 void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
135 if (!callback) {
136 return;
137 }
138 std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
139 message_callbacks_[request_id] = callback;
140 }
141
ProcessSendDataResp(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)142 void Node::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
143 size_t size) {
144 MS_EXCEPTION_IF_NULL(meta);
145 MS_EXCEPTION_IF_NULL(data);
146 std::lock_guard<std::mutex> lock(receive_messages_mutex_);
147 const uint32_t &rank_id = meta->rank_id();
148 const uint64_t request_id = meta->request_id();
149 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
150 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
151 if (meta->role() == NodeRole::SERVER) {
152 auto it = receive_messages_.find(request_id);
153 VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
154 if (size > 0) {
155 size_t dest_size = size;
156 size_t src_size = size;
157 errno_t ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
158 if (ret != EOK) {
159 MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
160 }
161 }
162 if (it != receive_messages_.end()) {
163 it->second[rank_id] = received_data;
164 } else {
165 std::unordered_map<uint32_t, VectorPtr> res;
166 (void)res.insert(std::make_pair(rank_id, received_data));
167 receive_messages_[request_id] = res;
168 }
169 } else {
170 auto it = workder_receive_messages_.find(request_id);
171 VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
172 if (size > 0) {
173 size_t dest_size = size;
174 size_t src_size = size;
175 errno_t ret = memcpy_s(received_data.get()->data(), dest_size, data, src_size);
176 if (ret != EOK) {
177 MS_LOG(EXCEPTION) << "The memcpy_s error, errorno(" << ret << ")";
178 }
179 }
180 if (it != workder_receive_messages_.end()) {
181 it->second[rank_id] = received_data;
182 } else {
183 std::unordered_map<uint32_t, VectorPtr> res;
184 (void)res.insert(std::make_pair(rank_id, received_data));
185 workder_receive_messages_[request_id] = res;
186 }
187 }
188 }
189
RunMessageCallback(const uint64_t & request_id)190 void Node::RunMessageCallback(const uint64_t &request_id) {
191 message_callbacks_mutex_.lock();
192 // When receiving a message's response, Then compare with the desired number of responses,
193 // If they are equal, then call the callback function
194 if (CheckMessageTrack(request_id)) {
195 auto it = message_callbacks_.find(request_id);
196 if (it != message_callbacks_.end()) {
197 message_callbacks_mutex_.unlock();
198
199 if (it->second) {
200 it->second();
201 }
202
203 message_callbacks_mutex_.lock();
204 (void)message_callbacks_.erase(it);
205 }
206 }
207 message_callbacks_mutex_.unlock();
208 }
209 } // namespace core
210 } // namespace ps
211 } // namespace mindspore
212