• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_
19 
20 #include <set>
21 #include <mutex>
22 #include <vector>
23 #include <string>
24 #include <memory>
25 #include <condition_variable>
26 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
27 #include "plugin/device/cpu/hal/device/cpu_device_address.h"
28 
29 namespace mindspore {
30 namespace runtime {
31 using CPUDeviceAddress = device::cpu::CPUDeviceAddress;
32 // RecvActor inherits from RpcActor and it's used to receive data from other processes.
33 class RecvActor : public RpcActor {
34  public:
RecvActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes)35   explicit RecvActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
36                      const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
37                      GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
38                      const std::set<size_t> &modifiable_ref_output_indexes)
39       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
40                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kRecvActor),
41         server_(nullptr),
42         is_context_valid_(false),
43         recv_data_(nullptr),
44         ip_(""),
45         port_(0),
46         rdma_buf_(nullptr) {}
47   ~RecvActor() override;
48 
49   // Besides set the op context, this method also notify the message handler to 'RunOpInterProcessData'.
50   void SetOpcontext(OpContext<DeviceTensor> *const op_context) override;
51 
52   // This method means the op context is invalid now. If the message handler is called while the op context is invalid,
53   // it should be blocked until 'SetOpcontext' is called.
54   void ResetOpcontext() override;
55 
56   // Update the context status after loop_count_actor is launched.
57   void UpdateStatus() override;
58 
59   // Set recv actor's source peer info, in another word, recv actor's input.
60   void SetRouteInfo(uint32_t src_rank, const std::string &src_role, const std::string &recv_src_node_name,
61                     const std::string &recv_dst_node_name) override;
62 
63   // Start recv actor server and register this server address to actor route table in scheduler by proxy.
64   bool StartServer();
65 
66   // Finalize rpc server.
67   void Clear() override;
68 
69   void StopRpcAtException() override;
70 
71  protected:
72   // Besides the checking method in base class AbstractActor, condition of inter-process arrows should be checked for
73   // recv actor.
74   bool CheckRunningCondition(const OpContext<DeviceTensor> *context) const override;
75 
76   // When an inter-process data received, this method is called.
77   void RunOpInterProcessData(MessageBase *const msg, OpContext<DeviceTensor> *const context);
78 
79   // Besides erasing input data and input controls when finish actor running, inter-process inputs should be erased.
80   void EraseInput(const OpContext<DeviceTensor> *context) override;
81 
82   // Before calling the Run method in KernelActor, some preprocess like inferring shape should be done. So rewrite the
83   // Run method.
84   void Run(OpContext<DeviceTensor> *const context) override;
85 
86   // Set the message handler of the server.
87   virtual void SetMessageHandler();
88 
89   // Parse finalize command message from received message.
ParseFinalizeReqData(size_t data_len,const MessageBase * const msg,bool * need_finalize)90   virtual void ParseFinalizeReqData(size_t data_len, const MessageBase *const msg, bool *need_finalize) {}
91 
92   /**
93    * @description: The callback set to rpc module to allocate message(Raw pointer).
94    * @param {size_t} size: The message size.
95    * @return {void *}: A pointer to the newly allocated memory.
96    */
97   virtual void *AllocateMessage(size_t size);
98 
99   /**
100    * @description: Allocate memory by DeviceResManager.
101    * @param {size_t} size: memory buffer's size.
102    * @return {void *}
103    */
104   void *AllocateMemByDeviceRes(size_t size);
105 
106   std::unique_ptr<RPCServerBase> server_;
107 
108   // The variables used to ensure thread-safe of op context visited by recv actor.
109   bool is_context_valid_;
110   std::mutex context_mtx_;
111   std::condition_variable context_cv_;
112 
113   // The received data which should be allocated by framework.
114   // It will be used for copying the buffer from the kernel function.
115   std::shared_ptr<CPUDeviceAddress> recv_data_;
116 
117  private:
118   // Create abstract and add to the abstract list.
119   void AddArgSpecForInput(AbstractBasePtrList *args_spec_list, const ShapeVector &shapes, TypeId data_type,
120                           size_t input_index) const;
121 
122   // Parse the protobuf message from the given buffer. The format is as below.
123   // |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes |
124   // |RPC_DYNAMIC_SHAPE_DATA|PB data size|      PB data     | real data       |
125   // Return dynamic shape data length.
126   size_t ParseDynamicShapeData(const RpcDataPtr &dynamic_shape_data, size_t data_size,
127                                AbstractBasePtrList *args_spec_list, size_t count);
128 
129   // After Recv actor receives data from a remote peer, the data could be with dynamic shape so we need to preprocess
130   // it, e.g., infer shape for RpcRecv kernel and call Resize().
131   void PreprocessRemoteInput(const MessageBase *const msg, bool *need_finalize);
132 
133   // The message callback of the rpc server.
134   MessageBase *HandleMessage(MessageBase *const msg);
135 
136   // The network address of this recv actor. It's generated automatically by rpc module.
137   std::string ip_;
138   uint32_t port_;
139 
140   // Data returned by URPC. It should be released by RecvActor.
141   void *rdma_buf_;
142 };
143 
144 using RecvActorPtr = std::shared_ptr<RecvActor>;
145 }  // namespace runtime
146 }  // namespace mindspore
147 
148 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_RECV_ACTOR_H_
149