1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_ 19 20 #include <set> 21 #include <vector> 22 #include <string> 23 #include <memory> 24 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h" 25 26 namespace mindspore { 27 namespace runtime { 28 // SendActor inherits from RpcActor and it's used to send data to other processes. 29 class SendActor : public RpcActor { 30 public: SendActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes)31 explicit SendActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context, 32 const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid, 33 GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes, 34 const std::set<size_t> &modifiable_ref_output_indexes) 35 : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy, 36 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor), 37 client_(nullptr), 38 context_(nullptr), 39 server_url_("") {} 40 ~SendActor() override; 41 42 // Set send actor's destination peer info, in another word, send actor's output. 43 void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name, 44 const std::string &send_dst_node_name) override; 45 46 // Lookup peer actors' route and create connection to them. 47 bool ConnectServer(); 48 49 // Flush and wait for sent data to be passed to kernel. 50 void FlushData() override; 51 52 // Finalize rpc client. 53 void Clear() override; 54 55 protected: 56 // Do real send operation in this method. 57 bool LaunchKernel(OpContext<DeviceTensor> *const context, bool is_skip_launch = false) override; 58 59 // Erase inter-process inputs for this sequential number. 60 void EraseInput(const OpContext<DeviceTensor> *context) override; 61 62 // Client only supports to send MessageBase, so build MessageBase with data and url. 63 std::unique_ptr<MessageBase> BuildRpcMessage(const std::string &server_url); 64 65 /** 66 * @description: Free message after it's sent to remote. 67 * @param {void} *data: Raw pointer data needs to be freed. 68 * @return {bool}: Whether the data is successfully freed. 69 */ 70 virtual bool FreeMessage(void *data); 71 72 /** 73 * @description: Flush the message to kernel so that the memory could be released. This method is used for synchronize 74 * sending operations. 75 * @return {void} 76 */ 77 virtual void Flush(); 78 79 // The rpc client connection to multiple servers. 80 std::unique_ptr<RPCClientBase> client_; 81 82 private: 83 /** 84 * @description: Find the memory list needs to be freed after the data is sent to remote. This should be called by 85 * FreeMessage. 86 * @param {const void} *data: Raw pointer data needs to be freed. 87 * @return {std::vector<DeviceTensor *>}: The memory list needs to be freed. 88 */ 89 std::vector<DeviceTensor *> FindDeviceTensorNeedsFree(const void *data) const; 90 91 /** 92 * @description: Serialize one dynamic shape input data to a piece of memory and returns the serialized data 93 * size for accessing memory by offset. 94 * The format is shown below: 95 * |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes | 96 * |RPC_DYNAMIC_SHAPE_DATA|PB data size| PB data | real data | 97 * @param {RpcDataPtr} &rpc_data: A piece of memory which is allocated by the caller for serialized data to copy to. 98 * @param {ShapeVector} &shape_vec: Input data's shape vector. 99 * @param {TypeId} &data_type: Input data's type. 100 * @param {DeviceTensor} *addr: Input data's device tensor. 101 * @return {size_t}: Size of the serialized data. 102 */ 103 size_t SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data, const ShapeVector &shape_vec, const TypeId &data_type, 104 const DeviceTensor *addr) const; 105 106 // Serialize dynamic shape data. The format is shown below: 107 // |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes | 108 // |RPC_DYNAMIC_SHAPE_DATA|PB data size| PB data | real data | 109 /** 110 * @description: Serialize message with dynamic shape data. For each input in dynamic shape scenario, extra meta info 111 * like data shape, data type will be serialized as protobuffer and copied to message. 112 * 113 * @param {MessageBase} *message: MessageBase object. 114 * @param {DeviceTensor} *workspace_addr: Workspace device tensor. 115 * @return {void} 116 */ 117 void SerializeDynamicShapeMessage(MessageBase *message, const DeviceTensor *workspace_addr) const; 118 119 /** 120 * @description: Serialize common message without extra info, which means: the data of raw pointer will be directly 121 * copied to the message. 122 * @param {MessageBase} *message: MessageBase object. 123 * @param {DeviceTensor} *workspace_addr: Workspace device tensor. 124 * @return {void} 125 */ 126 void SerializeCommonMessage(MessageBase *message, const DeviceTensor *workspace_addr) const; 127 128 friend class GraphScheduler; 129 130 // OpC ontext passed by graph scheduler. 131 OpContext<DeviceTensor> *context_; 132 133 // This send actor's destination peers' actor ids and route table. 134 std::vector<std::string> peer_actor_ids_; 135 mindspore::HashMap<std::string, std::string> peer_actor_urls_; 136 137 // The url of the peer recv actor's server. 138 std::string server_url_; 139 140 // The remote function id this client will call. 141 uint32_t remote_func_id_; 142 }; 143 144 using SendActorPtr = std::shared_ptr<SendActor>; 145 } // namespace runtime 146 } // namespace mindspore 147 148 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_ 149