• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
18 #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
19 
20 #include <map>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include <unordered_map>
25 #include "proto/ps.pb.h"
26 #include "ps/core/server_node.h"
27 #include "ps/core/cluster_metadata.h"
28 #include "ps/core/cluster_config.h"
29 #include "ps/ps_context.h"
30 #include "ps/core/communicator/task_executor.h"
31 #include "ps/core/communicator/communicator_base.h"
32 #include "ps/core/communicator/tcp_msg_handler.h"
33 #include "ps/core/comm_util.h"
34 #include "ps/constants.h"
35 
36 namespace mindspore {
37 namespace ps {
38 namespace core {
39 const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {
40   {TcpUserCommand::kPush, "push"},
41   {TcpUserCommand::kPull, "pull"},
42   {TcpUserCommand::kCount, "count"},
43   {TcpUserCommand::kReachThreshold, "countReachThreshold"},
44   {TcpUserCommand::kResetCount, "resetCnt"},
45   {TcpUserCommand::kGetMetadata, "getMetadata"},
46   {TcpUserCommand::kUpdateMetadata, "updateMetadata"},
47   {TcpUserCommand::kCounterEvent, "counterEvent"},
48   {TcpUserCommand::kPullWeight, "pullWeight"},
49   {TcpUserCommand::kPushWeight, "pushWeight"},
50   {TcpUserCommand::kSyncIteration, "syncIteration"},
51   {TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"},
52   {TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"},
53   {TcpUserCommand::kProceedToNextIter, "proceedToNextIter"},
54   {TcpUserCommand::kEndLastIter, "endLastIter"},
55   {TcpUserCommand::kStartFLJob, "startFLJob"},
56   {TcpUserCommand::kUpdateModel, "updateModel"},
57   {TcpUserCommand::kGetModel, "getModel"},
58   {TcpUserCommand::kPushMetrics, "pushMetrics"},
59   {TcpUserCommand::kNewInstance, "newInstance"},
60   {TcpUserCommand::kQueryInstance, "queryInstance"},
61   {TcpUserCommand::kEnableFLS, "enableFLS"},
62   {TcpUserCommand::kDisableFLS, "disableFLS"}};
63 
64 class TcpCommunicator : public CommunicatorBase {
65  public:
TcpCommunicator(const std::shared_ptr<TaskExecutor> & task_executor,AbstractNode * node)66   explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, AbstractNode *node)
67       : task_executor_(task_executor),
68         server_num_(0),
69         worker_num_(0),
70         scheduler_ip_(""),
71         scheduler_port_(0),
72         abstrace_node_(node) {}
73   ~TcpCommunicator() = default;
74 
75   bool Start() override;
76   bool Stop() override;
77 
78   void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override;
79   void RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb);
80 
81   template <class T>
82   bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command,
83                      std::shared_ptr<std::vector<unsigned char>> *output = nullptr) {
84     const std::string &msg_str = pb_msg.SerializeAsString();
85     std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]);
86     MS_ERROR_IF_NULL_W_RET_VAL(msg, false);
87     size_t dest_size = msg_str.size();
88     size_t src_size = msg_str.size();
89     if (memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size) != EOK) {
90       MS_LOG(EXCEPTION) << "Memcpy_s error";
91     }
92 
93     if (output != nullptr) {
94       if (!abstrace_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) {
95         MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed.";
96         return false;
97       }
98     } else {
99       if (!abstrace_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) {
100         MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed.";
101         return false;
102       }
103     }
104     return true;
105   }
106 
107  private:
108   std::shared_ptr<TaskExecutor> task_executor_;
109 
110   TcpMsgCallback tcp_msg_callback_;
111   OnNodeEventCallback event_callback_;
112 
113   uint32_t server_num_;
114   uint32_t worker_num_;
115 
116   std::string scheduler_ip_;
117   uint16_t scheduler_port_;
118 
119   AbstractNode *abstrace_node_;
120 };
121 }  // namespace core
122 }  // namespace ps
123 }  // namespace mindspore
124 #endif  // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_
125