• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "ps/core/node.h"
18 
19 namespace mindspore {
20 namespace ps {
21 namespace core {
node_id() const22 std::string Node::node_id() const { return node_info_.node_id_; }
23 
rank_id() const24 uint32_t Node::rank_id() const { return node_info_.rank_id_; }
25 
role() const26 NodeRole Node::role() const { return node_info_.node_role_; }
27 
BoundPort() const28 uint16_t Node::BoundPort() const { return node_info_.port_; }
29 
BoundIp() const30 std::string Node::BoundIp() const { return node_info_.ip_; }
31 
WaitForStart(const uint32_t & timeout)32 bool Node::WaitForStart(const uint32_t &timeout) {
33   std::unique_lock<std::mutex> lock(wait_start_mutex_);
34   bool res = wait_start_cond_.wait_for(lock, std::chrono::seconds(timeout), [this] {
35     bool result = this->is_ready_.load();
36     if (result) {
37       MS_LOG(INFO) << "The node id:" << node_info_.node_id_ << " is success start!";
38     }
39     return result;
40   });
41   return res;
42 }
43 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const CommMessage & message,const uint32_t & timeout)44 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const CommMessage &message,
45                            const uint32_t &timeout) {
46   MS_EXCEPTION_IF_NULL(client);
47   uint64_t request_id = AddMessageTrack(1);
48   const_cast<CommMessage &>(message).mutable_pb_meta()->set_request_id(request_id);
49   if (!client->SendMessage(message)) {
50     MS_LOG(WARNING) << "Client send message failed.";
51   }
52   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
53                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
54   return Wait(request_id, timeout);
55 }
56 
SendMessageAsync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size)57 uint64_t Node::SendMessageAsync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
58                                 const Protos &protos, const void *data, size_t size) {
59   MS_EXCEPTION_IF_NULL(client);
60   MS_EXCEPTION_IF_NULL(meta);
61   MS_EXCEPTION_IF_NULL(data);
62   uint64_t request_id = AddMessageTrack(1);
63   meta->set_request_id(request_id);
64   if (!client->SendMessage(meta, protos, data, size)) {
65     MS_LOG(WARNING) << "Client send message failed.";
66   }
67   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
68                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
69   return request_id;
70 }
71 
SendMessageSync(const std::shared_ptr<TcpClient> & client,const std::shared_ptr<MessageMeta> & meta,const Protos & protos,const void * data,size_t size,const uint32_t & timeout)72 bool Node::SendMessageSync(const std::shared_ptr<TcpClient> &client, const std::shared_ptr<MessageMeta> &meta,
73                            const Protos &protos, const void *data, size_t size, const uint32_t &timeout) {
74   MS_EXCEPTION_IF_NULL(client);
75   MS_EXCEPTION_IF_NULL(meta);
76   MS_EXCEPTION_IF_NULL(data);
77   uint64_t request_id = AddMessageTrack(1);
78   meta->set_request_id(request_id);
79   if (!client->SendMessage(meta, protos, data, size)) {
80     MS_LOG(WARNING) << "Client send message failed.";
81   }
82   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
83                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
84   return Wait(request_id, timeout);
85 }
86 
Wait(uint64_t request_id,const uint32_t & timeout)87 bool Node::Wait(uint64_t request_id, const uint32_t &timeout) {
88   std::unique_lock<std::mutex> tracker_lock(message_tracker_mutex_);
89   bool res = message_tracker_cond_.wait_for(tracker_lock, std::chrono::seconds(timeout), [&] {
90     if (message_tracker_.count(request_id)) {
91       bool ret = message_tracker_[request_id].first == message_tracker_[request_id].second;
92       return ret;
93     }
94     return false;
95   });
96   (void)message_tracker_.erase(request_id);
97   tracker_lock.unlock();
98 
99   std::unique_lock<std::mutex> msgs_lock(receive_messages_mutex_);
100   if (receive_messages_.count(request_id) != 0) {
101     (void)receive_messages_.erase(request_id);
102   }
103   msgs_lock.unlock();
104   return res;
105 }
106 
AddMessageTrack(const uint32_t & expected_response)107 uint64_t Node::AddMessageTrack(const uint32_t &expected_response) {
108   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
109   uint64_t request_id = ++next_request_id_;
110   message_tracker_[request_id] = std::make_pair(expected_response, 0);
111   return request_id;
112 }
113 
CheckMessageTrack(const uint64_t & request_id)114 bool Node::CheckMessageTrack(const uint64_t &request_id) {
115   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
116   if (message_tracker_.count(request_id)) {
117     return message_tracker_[request_id].first == message_tracker_[request_id].second + 1;
118   }
119   MS_LOG(INFO) << "The message tracker is not contain the id:" << request_id;
120   return false;
121 }
122 
NotifyMessageArrival(const std::shared_ptr<MessageMeta> & meta)123 void Node::NotifyMessageArrival(const std::shared_ptr<MessageMeta> &meta) {
124   std::lock_guard<std::mutex> lock(message_tracker_mutex_);
125   uint64_t request_id = meta->request_id();
126   if (message_tracker_.count(request_id)) {
127     message_tracker_[request_id].second++;
128     message_tracker_cond_.notify_all();
129   }
130 }
131 
set_message_callback(const uint64_t & request_id,const MessageCallback & callback)132 void Node::set_message_callback(const uint64_t &request_id, const MessageCallback &callback) {
133   if (!callback) {
134     return;
135   }
136   std::lock_guard<std::mutex> lock(message_callbacks_mutex_);
137   message_callbacks_[request_id] = callback;
138 }
139 
ProcessSendDataResp(const std::shared_ptr<MessageMeta> & meta,const Protos &,const void * data,size_t size)140 void Node::ProcessSendDataResp(const std::shared_ptr<MessageMeta> &meta, const Protos &, const void *data,
141                                size_t size) {
142   MS_EXCEPTION_IF_NULL(meta);
143   MS_EXCEPTION_IF_NULL(data);
144   std::lock_guard<std::mutex> lock(receive_messages_mutex_);
145   const uint32_t &rank_id = meta->rank_id();
146   const uint64_t request_id = meta->request_id();
147   MS_LOG(DEBUG) << "The node role is:" << CommUtil::NodeRoleToString(node_info_.node_role_)
148                 << ", the node id is:" << node_info_.node_id_ << " send the request id is:" << request_id;
149   if (meta->role() == NodeRole::SERVER) {
150     auto it = receive_messages_.find(request_id);
151     VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
152     if (size > 0) {
153       size_t dest_size = size;
154       size_t src_size = size;
155       if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) {
156         MS_LOG(EXCEPTION) << "The memcpy_s error";
157       }
158     }
159     if (it != receive_messages_.end()) {
160       it->second[rank_id] = received_data;
161     } else {
162       std::unordered_map<uint32_t, VectorPtr> res;
163       (void)res.insert(std::make_pair(rank_id, received_data));
164       receive_messages_[request_id] = res;
165     }
166   } else {
167     auto it = workder_receive_messages_.find(request_id);
168     VectorPtr received_data = std::make_shared<std::vector<unsigned char>>(size, 0);
169     if (size > 0) {
170       size_t dest_size = size;
171       size_t src_size = size;
172       if (memcpy_s(received_data.get()->data(), dest_size, data, src_size) != EOK) {
173         MS_LOG(EXCEPTION) << "The memcpy_s error";
174       }
175     }
176     if (it != workder_receive_messages_.end()) {
177       it->second[rank_id] = received_data;
178     } else {
179       std::unordered_map<uint32_t, VectorPtr> res;
180       (void)res.insert(std::make_pair(rank_id, received_data));
181       workder_receive_messages_[request_id] = res;
182     }
183   }
184 }
185 
RunMessageCallback(const uint64_t & request_id)186 void Node::RunMessageCallback(const uint64_t &request_id) {
187   message_callbacks_mutex_.lock();
188   // When receiving a message's response, Then compare with the desired number of responses,
189   // If they are equal, then call the callback function
190   if (CheckMessageTrack(request_id)) {
191     auto it = message_callbacks_.find(request_id);
192     if (it != message_callbacks_.end()) {
193       message_callbacks_mutex_.unlock();
194 
195       if (it->second) {
196         it->second();
197       }
198 
199       message_callbacks_mutex_.lock();
200       message_callbacks_.erase(it);
201     }
202   }
203   message_callbacks_mutex_.unlock();
204 }
205 }  // namespace core
206 }  // namespace ps
207 }  // namespace mindspore
208