1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 18 #define MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include <unordered_map> 25 #include "proto/ps.pb.h" 26 #include "ps/core/server_node.h" 27 #include "ps/core/cluster_metadata.h" 28 #include "ps/core/cluster_config.h" 29 #include "include/backend/distributed/ps/ps_context.h" 30 #include "ps/core/communicator/task_executor.h" 31 #include "ps/core/communicator/communicator_base.h" 32 #include "ps/core/communicator/tcp_msg_handler.h" 33 #include "ps/core/comm_util.h" 34 #include "include/backend/distributed/ps/constants.h" 35 36 namespace mindspore { 37 namespace ps { 38 namespace core { 39 const std::unordered_map<TcpUserCommand, std::string> kUserCommandToMsgType = {{TcpUserCommand::kPush, "push"}, 40 {TcpUserCommand::kPull, "pull"}}; 41 42 class TcpCommunicator : public CommunicatorBase { 43 public: TcpCommunicator(const std::shared_ptr<TaskExecutor> & task_executor,AbstractNode * node)44 explicit TcpCommunicator(const std::shared_ptr<TaskExecutor> &task_executor, AbstractNode *node) 45 : task_executor_(task_executor), abstrace_node_(node) {} 46 ~TcpCommunicator() = default; 47 48 bool Start() override; 49 bool Stop() override; 50 51 void RegisterMsgCallBack(const std::string &msg_type, const MessageCallback &cb) override; 52 void RegisterEventCallback(const core::ClusterEvent &event, const EventCallback &event_cb); 53 54 template <class T> 55 bool SendPbRequest(const T &pb_msg, const uint32_t &rank_id, TcpUserCommand command, 56 std::shared_ptr<std::vector<unsigned char>> *output = nullptr) { 57 const std::string &msg_str = pb_msg.SerializeAsString(); 58 if (!abstrace_node_->Send(NodeRole::SERVER, rank_id, msg_str, static_cast<int>(command), output)) { 59 MS_LOG(ERROR) << "Sending protobuffer message to server " << rank_id << " failed."; 60 return false; 61 } 62 return true; 63 } 64 65 private: 66 std::shared_ptr<TaskExecutor> task_executor_; 67 68 TcpMsgCallback tcp_msg_callback_; 69 OnNodeEventCallback event_callback_; 70 71 AbstractNode *abstrace_node_; 72 }; 73 } // namespace core 74 } // namespace ps 75 } // namespace mindspore 76 #endif // MINDSPORE_CCSRC_PS_CORE_COMMUNICATOR_TCP_COMMUNICATOR_H_ 77