1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "ps/core/node.h"
18
19 namespace mindspore {
20 namespace ps {
21 namespace core {
node_id() const22 std::string Node::node_id() const { return node_info_.node_id_; }
23
rank_id() const24 uint32_t Node::rank_id() const { return node_info_.rank_id_; }
25
role() const26 NodeRole Node::role() const { return node_info_.node_role_; }
27
BoundPort() const28 uint16_t Node::BoundPort() const { return node_info_.port_; }
29
BoundIp() const30 std::string Node::BoundIp() const { return node_info_.ip_; }
31
WaitForStart(const uint32_t & timeout)32 bool Node::WaitForStart(const uint32_t &timeout) {
33 std::unique_lock<std::mutex> lock(wait_start_mutex_);
34 bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
35 bool result = this->is_ready_.load();
36 if (result) {
37 MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
38 }
39 return result;
40 });
41 return res;
42 }
43
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)44 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
45 const uint32_t &timeout) {
46 MS_EXCEPTION_IF_NULL(client);
47 uint64_t request_id = AddMessageTrack(1);
48 const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
49 if (!client->SendMessage(message)) {
50 MS_LOG(WARNING) << "Client send message failed.";
51 }
52 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
53 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
54 return Wait(request_id, timeout);
55 }
56
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)57 uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
58 const Protos &protos, const void *data, size_t size) {
59 MS_EXCEPTION_IF_NULL(client);
60 MS_EXCEPTION_IF_NULL(meta);
61 MS_EXCEPTION_IF_NULL(data);
62 uint64_t request_id = AddMessageTrack(1);
63 meta->set_request_id(request_id);
64 if (!client->SendMessage(meta, protos, data, size)) {
65 MS_LOG(WARNING) << "Client send message failed.";
66 }
67 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
68 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
69 return request_id;
70 }
71
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)72 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
73 const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
74 MS_EXCEPTION_IF_NULL(client);
75 MS_EXCEPTION_IF_NULL(meta);
76 MS_EXCEPTION_IF_NULL(data);
77 uint64_t request_id = AddMessageTrack(1);
78 meta->set_request_id(request_id);
79 if (!client->SendMessage(meta, protos, data, size)) {
80 MS_LOG(WARNING) << "Client send message failed.";
81 }
82 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
83 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
84 return Wait(request_id, timeout);
85 }
86
Wait(uint64_t request_id,const uint32_t & timeout)87 bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
88 std::unique_lock<std::mutex> tracker_lock(message_tracker_mutex_);
89 bool res = message_tracker_cond_.wait_for(tracker_lock, std::chrono::seconds(timeout), [&] {
90 if (message_tracker_.count(request_id)) {
91 bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
92 return ret;
93 }
94 return false;
95 });
96 (void)message_tracker_.erase(request_id);
97 tracker_lock.unlock();
98
99 std::unique_lock<std::mutex> msgs_lock(receive_messages_mutex_);
100 if (receive_messages_.count(request_id) != 0) {
101 (void)receive_messages_.erase(request_id);
102 }
103 msgs_lock.unlock();
104 return res;
105 }
106
AddMessageTrack(const uint32_t & expected_response)107 uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
108 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
109 uint64_t request_id = ++next_request_id_;
110 message_tracker_[request_id] = std::make_pair(expected_response, 0);
111 return request_id;
112 }
113
CheckMessageTrack(const uint64_t & request_id)114 bool Node::CheckMessageTrack(const uint64_t &request_id) {
115 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
116 if (message_tracker_.count(request_id)) {
117 return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
118 }
119 MS_LOG(INFO) << "The message tracker is not contain the id:" << request_id;
120 return false;
121 }
122
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)123 void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
124 std::lock_guard<std::mutex> lock(message_tracker_mutex_);
125 uint64_t request_id = meta->request_id();
126 if (message_tracker_.count(request_id)) {
127 message_tracker_[request_id].second++;
128 message_tracker_cond_.notify_all();
129 }
130 }
131
set_message_callback(const uint64_t & request_id,const MessageCallback & callback)132 void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
133 if (!callback) {
134 return;
135 }
136 std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
137 message_callbacks_[request_id] = callback;
138 }
139
ProcessSendDataResp(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)140 void Node::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
141 size_t size) {
142 MS_EXCEPTION_IF_NULL(meta);
143 MS_EXCEPTION_IF_NULL(data);
144 std::lock_guard<std::mutex> lock(receive_messages_mutex_);
145 const uint32_t &rank_id = meta->rank_id();
146 const uint64_t request_id = meta->request_id();
147 MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
148 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
149 if (meta->role() == NodeRole::SERVER) {
150 auto it = receive_messages_.find(request_id);
151 VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
152 if (size > 0) {
153 size_t dest_size = size;
154 size_t src_size = size;
155 if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) {
156 MS_LOG(EXCEPTION) << "The memcpy_s error";
157 }
158 }
159 if (it != receive_messages_.end()) {
160 it->second[rank_id] = received_data;
161 } else {
162 std::unordered_map<uint32_t, VectorPtr> res;
163 (void)res.insert(std::make_pair(rank_id, received_data));
164 receive_messages_[request_id] = res;
165 }
166 } else {
167 auto it = workder_receive_messages_.find(request_id);
168 VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
169 if (size > 0) {
170 size_t dest_size = size;
171 size_t src_size = size;
172 if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) {
173 MS_LOG(EXCEPTION) << "The memcpy_s error";
174 }
175 }
176 if (it != workder_receive_messages_.end()) {
177 it->second[rank_id] = received_data;
178 } else {
179 std::unordered_map<uint32_t, VectorPtr> res;
180 (void)res.insert(std::make_pair(rank_id, received_data));
181 workder_receive_messages_[request_id] = res;
182 }
183 }
184 }
185
RunMessageCallback(const uint64_t & request_id)186 void Node::RunMessageCallback(const uint64_t &request_id) {
187 message_callbacks_mutex_.lock();
188 // When receiving a message's response, Then compare with the desired number of responses,
189 // If they are equal, then call the callback function
190 if (CheckMessageTrack(request_id)) {
191 auto it = message_callbacks_.find(request_id);
192 if (it != message_callbacks_.end()) {
193 message_callbacks_mutex_.unlock();
194
195 if (it->second) {
196 it->second();
197 }
198
199 message_callbacks_mutex_.lock();
200 message_callbacks_.erase(it);
201 }
202 }
203 message_callbacks_mutex_.unlock();
204 }
205 } // namespace core
206 } // namespace ps
207 } // namespace mindspore
208