1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include <unordered_map> 25 #include "proto/ps.pb.h" 26 #include "ps/core/server_node.h" 27 #include "ps/core/cluster_metadata.h" 28 #include "ps/core/cluster_config.h" 29 #include "ps/ps_context.h" 30 #include "ps/core/communicator/task_executor.h" 31 #include "ps/core/communicator/communicator_base.h" 32 #include "ps/core/communicator/tcp_msg_handler.h" 33 #include "ps/core/comm_util.h" 34 #include "ps/constants.h" 35 36 namespace mindspore { 37 namespace ps { 38 namespace core { 39 const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = { 40 {TcpUserCommand::kPush, "push"}, 41 {TcpUserCommand::kPull, "pull"}, 42 {TcpUserCommand::kCount, "count"}, 43 {TcpUserCommand::kReachThreshold, "countReachThreshold"}, 44 {TcpUserCommand::kResetCount, "resetCnt"}, 45 {TcpUserCommand::kGetMetadata, "getMetadata"}, 46 {TcpUserCommand::kUpdateMetadata, "updateMetadata"}, 47 {TcpUserCommand::kCounterEvent, "counterEvent"}, 48 {TcpUserCommand::kPullWeight, "pullWeight"}, 49 {TcpUserCommand::kPushWeight, "pushWeight"}, 50 {TcpUserCommand::kSyncIteration, "syncIteration"}, 51 {TcpUserCommand::kNotifyLeaderToNextIter, "notifyLeaderToNextIter"}, 52 {TcpUserCommand::kPrepareForNextIter, "prepareForNextIter"}, 53 {TcpUserCommand::kProceedToNextIter, "proceedToNextIter"}, 54 {TcpUserCommand::kEndLastIter, "endLastIter"}, 55 {TcpUserCommand::kStartFLJob, "startFLJob"}, 56 {TcpUserCommand::kUpdateModel, "updateModel"}, 57 {TcpUserCommand::kGetModel, "getModel"}, 58 {TcpUserCommand::kPushMetrics, "pushMetrics"}, 59 {TcpUserCommand::kNewInstance, "newInstance"}, 60 {TcpUserCommand::kQueryInstance, "queryInstance"}, 61 {TcpUserCommand::kEnableFLS, "enableFLS"}, 62 {TcpUserCommand::kDisableFLS, "disableFLS"}}; 63 64 class TcpCommunicator : public CommunicatorBase { 65 public: TcpCommunicator(const std::shared_ptr<TaskExecutor> & task_executor,AbstractNode * node)66 explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, AbstractNode *node) 67 : task_executor_(task_executor), 68 server_num_(0), 69 worker_num_(0), 70 scheduler_ip_(""), 71 scheduler_port_(0), 72 abstrace_node_(node) {} 73 ~TcpCommunicator() = default; 74 75 bool Start() override; 76 bool Stop() override; 77 78 void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override; 79 void RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb); 80 81 template <class T> 82 bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command, 83 std::shared_ptr<std::vector<unsigned char>> *output = nullptr) { 84 const std::string &msg_str = pb_msg.SerializeAsString(); 85 std::shared_ptr<unsigned char[]> msg(new unsigned char[msg_str.size()]); 86 MS_ERROR_IF_NULL_W_RET_VAL(msg, false); 87 size_t dest_size = msg_str.size(); 88 size_t src_size = msg_str.size(); 89 if (memcpy_s(msg.get(), dest_size, msg_str.c_str(), src_size) != EOK) { 90 MS_LOG(EXCEPTION) << "Memcpy_s error"; 91 } 92 93 if (output != nullptr) { 94 if (!abstrace_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command), output)) { 95 MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed."; 96 return false; 97 } 98 } else { 99 if (!abstrace_node_->Send(NodeRole::SERVER, rank_id, msg, msg_str.size(), static_cast<int>(command))) { 100 MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed."; 101 return false; 102 } 103 } 104 return true; 105 } 106 107 private: 108 std::shared_ptr<TaskExecutor> task_executor_; 109 110 TcpMsgCallback tcp_msg_callback_; 111 OnNodeEventCallback event_callback_; 112 113 uint32_t server_num_; 114 uint32_t worker_num_; 115 116 std::string scheduler_ip_; 117 uint16_t scheduler_port_; 118 119 AbstractNode *abstrace_node_; 120 }; 121 } // namespace core 122 } // namespace ps 123 } // namespace mindspore 124 #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 125