• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_
19 
20 #include <set>
21 #include <vector>
22 #include <string>
23 #include <memory>
24 #include "runtime/graph_scheduler/actor/rpc/rpc_actor.h"
25 
26 namespace mindspore {
27 namespace runtime {
28 // SendActor inherits from RpcActor and it's used to send data to other processes.
29 class SendActor : public RpcActor {
30  public:
SendActor(const std::string & name,const CNodePtr & kernel,const DeviceContext * device_context,const AID & memory_manager_aid,const AID * debug_aid,const AID * recorder_aid,GraphExecutionStrategy strategy,const std::set<size_t> & modifiable_ref_input_indexes,const std::set<size_t> & modifiable_ref_output_indexes)31   explicit SendActor(const std::string &name, const CNodePtr &kernel, const DeviceContext *device_context,
32                      const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
33                      GraphExecutionStrategy strategy, const std::set<size_t> &modifiable_ref_input_indexes,
34                      const std::set<size_t> &modifiable_ref_output_indexes)
35       : RpcActor(name, kernel, device_context, memory_manager_aid, debug_aid, recorder_aid, strategy,
36                  modifiable_ref_input_indexes, modifiable_ref_output_indexes, KernelTransformType::kSendActor),
37         client_(nullptr),
38         context_(nullptr),
39         server_url_("") {}
40   ~SendActor() override;
41 
42   // Set send actor's destination peer info, in another word, send actor's output.
43   void SetRouteInfo(uint32_t dst_rank, const std::string &dst_role, const std::string &send_src_node_name,
44                     const std::string &send_dst_node_name) override;
45 
46   // Lookup peer actors' route and create connection to them.
47   bool ConnectServer();
48 
49   // Flush and wait for sent data to be passed to kernel.
50   void FlushData() override;
51 
52   // Finalize rpc client.
53   void Clear() override;
54 
55  protected:
56   // Do real send operation in this method.
57   bool LaunchKernel(OpContext<DeviceTensor> *const context, bool is_skip_launch = false) override;
58 
59   // Erase inter-process inputs for this sequential number.
60   void EraseInput(const OpContext<DeviceTensor> *context) override;
61 
62   // Client only supports to send MessageBase, so build MessageBase with data and url.
63   std::unique_ptr<MessageBase> BuildRpcMessage(const std::string &server_url);
64 
65   /**
66    * @description: Free message after it's sent to remote.
67    * @param {void} *data: Raw pointer data needs to be freed.
68    * @return {bool}: Whether the data is successfully freed.
69    */
70   virtual bool FreeMessage(void *data);
71 
72   /**
73    * @description: Flush the message to kernel so that the memory could be released. This method is used for synchronize
74    * sending operations.
75    * @return {void}
76    */
77   virtual void Flush();
78 
79   // The rpc client connection to multiple servers.
80   std::unique_ptr<RPCClientBase> client_;
81 
82  private:
83   /**
84    * @description: Find the memory list needs to be freed after the data is sent to remote. This should be called by
85    * FreeMessage.
86    * @param {const void} *data: Raw pointer data needs to be freed.
87    * @return {std::vector<DeviceTensor *>}: The memory list needs to be freed.
88    */
89   std::vector<DeviceTensor *> FindDeviceTensorNeedsFree(const void *data) const;
90 
91   /**
92    * @description: Serialize one dynamic shape input data to a piece of memory and returns the serialized data
93    * size for accessing memory by offset.
94    * The format is shown below:
95    * |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes |
96    * |RPC_DYNAMIC_SHAPE_DATA|PB data size|      PB data     | real data       |
97    * @param {RpcDataPtr} &rpc_data: A piece of memory which is allocated by the caller for serialized data to copy to.
98    * @param {ShapeVector} &shape_vec: Input data's shape vector.
99    * @param {TypeId} &data_type: Input data's type.
100    * @param {DeviceTensor} *addr: Input data's device tensor.
101    * @return {size_t}: Size of the serialized data.
102    */
103   size_t SerializeSingleDynamicShapeInput(RpcDataPtr rpc_data, const ShapeVector &shape_vec, const TypeId &data_type,
104                                           const DeviceTensor *addr) const;
105 
106   // Serialize dynamic shape data. The format is shown below:
107   // |--------22 bytes------|---4 bytes--|PB data size bytes| data size bytes |
108   // |RPC_DYNAMIC_SHAPE_DATA|PB data size|      PB data     | real data       |
109   /**
110    * @description: Serialize message with dynamic shape data. For each input in dynamic shape scenario, extra meta info
111    * like data shape, data type will be serialized as protobuffer and copied to message.
112    *
113    * @param {MessageBase} *message: MessageBase object.
114    * @param {DeviceTensor} *workspace_addr: Workspace device tensor.
115    * @return {void}
116    */
117   void SerializeDynamicShapeMessage(MessageBase *message, const DeviceTensor *workspace_addr) const;
118 
119   /**
120    * @description: Serialize common message without extra info, which means: the data of raw pointer will be directly
121    * copied to the message.
122    * @param {MessageBase} *message: MessageBase object.
123    * @param {DeviceTensor} *workspace_addr: Workspace device tensor.
124    * @return {void}
125    */
126   void SerializeCommonMessage(MessageBase *message, const DeviceTensor *workspace_addr) const;
127 
128   friend class GraphScheduler;
129 
130   // OpC ontext passed by graph scheduler.
131   OpContext<DeviceTensor> *context_;
132 
133   // This send actor's destination peers' actor ids and route table.
134   std::vector<std::string> peer_actor_ids_;
135   mindspore::HashMap<std::string, std::string> peer_actor_urls_;
136 
137   // The url of the peer recv actor's server.
138   std::string server_url_;
139 
140   // The remote function id this client will call.
141   uint32_t remote_func_id_;
142 };
143 
144 using SendActorPtr = std::shared_ptr<SendActor>;  // Shared-ownership pointer alias for SendActor.
145 }  // namespace runtime
146 }  // namespace mindspore
147 
148 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_RPC_SEND_ACTOR_H_
149