1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_ 19 20 #include <set> 21 #include <mutex> 22 #include <vector> 23 #include <string> 24 #include <memory> 25 #include <condition_variable> 26 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h" 27 #include "plugin/device/cpu/hal/device/cpu_device_address.h" 28 29 namespace mindspore { 30 namespace runtime { 31 using CPUDeviceAddress = device::cpu::CPUDeviceAddress; 32 // RecvActor inherits from RpcActor and it's used to receive data from other processes. 33 class RecvActor : public RpcActor { 34 public: RecvActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes)35 explicit RecvActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context, 36 const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid, 37 GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes, 38 const std::set<size_t> &modifiable_ref_output_indexes) 39 : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy, 40 modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor), 41 server_(nullptr), 42 is_context_valid_(false), 43 recv_data_(nullptr), 44 ip_(""), 45 port_(0), 46 rdma_buf_(nullptr) {} 47 ~RecvActor() override; 48 49 // Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'. 50 void SetOpcontext(OpContext<DeviceTensor> *const op_context) override; 51 52 // This method means the op context is invalid now. If the message handler is called while the op context is invalid, 53 // it should be blocked until 'SetOpcontext' is called. 54 void ResetOpcontext() override; 55 56 // Update the context status after loop_count_actor is launched. 57 void UpdateStatus() override; 58 59 // Set recv actor's source peer info, in another word, recv actor's input. 60 void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name, 61 const std::string &recv_dst_node_name) override; 62 63 // Start recv actor server and register this server address to actor route table in scheduler by proxy. 64 bool StartServer(); 65 66 // Finalize rpc server. 67 void Clear() override; 68 69 void StopRpcAtException() override; 70 71 protected: 72 // Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for 73 // recv actor. 74 bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override; 75 76 // When an inter-process data received, this method is called. 77 void RunOpInterProcessData(MessageBase *const msg, OpContext<DeviceTensor> *const context); 78 79 // Besides erasing input data and input controls when finish actor running, inter-process inputs should be erased. 80 void EraseInput(const OpContext<DeviceTensor> *context) override; 81 82 // Before calling the Run method in KernelActor, some preprocess like inferring shape should be done. So rewrite the 83 // Run method. 84 void Run(OpContext<DeviceTensor> *const context) override; 85 86 // Set the message handler of the server. 87 virtual void SetMessageHandler(); 88 89 // Parse finalize command message from received message. ParseFinalizeReqData(size_t data_len,const MessageBase * const msg,bool * need_finalize)90 virtual void ParseFinalizeReqData(size_t data_len, const MessageBase *const msg, bool *need_finalize) {} 91 92 /** 93 * @description: The callback set to rpc module to allocate message(Raw pointer). 94 * @param {size_t} size: The message size. 95 * @return {void *}: A pointer to the newly allocated memory. 96 */ 97 virtual void *AllocateMessage(size_t size); 98 99 /** 100 * @description: Allocate memory by DeviceResManager. 101 * @param {size_t} size: memory buffer's size. 102 * @return {void *} 103 */ 104 void *AllocateMemByDeviceRes(size_t size); 105 106 std::unique_ptr<RPCServerBase> server_; 107 108 // The variables used to ensure thread-safe of op context visited by recv actor. 109 bool is_context_valid_; 110 std::mutex context_mtx_; 111 std::condition_variable context_cv_; 112 113 // The received data which should be allocated by framework. 114 // It will be used for copying the buffer from the kernel function. 115 std::shared_ptr<CPUDeviceAddress> recv_data_; 116 117 private: 118 // Create abstract and add to the abstract list. 119 void AddArgSpecForInput(AbstractBasePtrList *args_spec_list, const ShapeVector &shapes, TypeId data_type, 120 size_t input_index) const; 121 122 // Parse the protobuf message from the given buffer. The format is as below. 123 // |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes | 124 // |RPC_DYNAMIC_SHAPE_DATA|PB data size| PB data | real data | 125 // Return dynamic shape data length. 126 size_t ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size, 127 AbstractBasePtrList *args_spec_list, size_t count); 128 129 // After Recv actor receives data from a remote peer, the data could be with dynamic shape so we need to preprocess 130 // it, e.g., infer shape for RpcRecv kernel and call Resize(). 131 void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize); 132 133 // The message callback of the rpc server. 134 MessageBase *HandleMessage(MessageBase *const msg); 135 136 // The network address of this recv actor. It's generated automatically by rpc module. 137 std::string ip_; 138 uint32_t port_; 139 140 // Data returned by URPC. It should be released by RecvActor. 141 void *rdma_buf_; 142 }; 143 144 using RecvActorPtr = std::shared_ptr<RecvActor>; 145 } // namespace runtime 146 } // namespace mindspore 147 148 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_ 149