/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_

#include <string>
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "include/common/utils/python_adapter.h"
#include "ir/anf.h"

namespace mindspore {
namespace runtime {
// The state marks the phase the actor is in: besides the initial state, it is either processing the inputs of the
// graph or processing the outputs of the graph.
enum AnyTypeKernelActorState { kAnyTypeKernelActorInit, kAnyTypeKernelActorSendInput, kAnyTypeKernelActorSendOutput };
using mindspore::device::DeviceContext;
using DataArrowGroupMap = mindspore::HashMap<std::string, std::vector<DataArrowPtr>>;
using ControlArrowGroupMap = mindspore::HashMap<std::string, std::vector<AID *>>;
using TransformFunc =
  std::function<std::vector<AbstractActorPtr>(const KernelGraphPtr &, const KernelGraphPtr &, const DeviceContext *)>;
using ScheduleFunc = std::function<void(const std::vector<AbstractActorPtr> &)>;
// The any type kernel actor represents a graph whose data types are uncertain, so the graph needs to be compiled
// when the actor runs and the real types are known.
// The execution is as follows:
// 1. Receive input
// 2. Send graph input to kernel/superkernel actor
// 3. Receive graph output from kernel/superkernel actor
// 4. Send graph output
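// A minimal sketch of how the two phases could be driven from RunOpData (an assumption for illustration only,
// not the actual implementation):
//   void AnyTypeKernelActor::RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) {
//     if (actor_state_ == kAnyTypeKernelActorSendOutput) {
//       RunForGraphOutput(context);  // steps 3-4: collect the graph outputs and forward them
//     } else {
//       RunForGraphInput(context);   // steps 1-2: collect the graph inputs, compile if needed, forward them
//     }
//   }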
class AnyTypeKernelActor : public SuperKernelActor {
 public:
  AnyTypeKernelActor(const std::string &name, const KernelGraphPtr &graph, const DeviceContext *device_context,
                     const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                     KernelTransformType type = KernelTransformType::kAnyTypeKernelActor);
  ~AnyTypeKernelActor() override = default;

  void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
  void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
  const std::string &current_data_type() const { return current_data_type_; }

 protected:
  void Init() override;

  // Handle the graph input.
  // When the graph inputs are received, the execution of the actor is divided into the following steps:
  // 1. generate the type key
  // 2. check whether a graph compiled for this key already exists; if not, execute step 3, otherwise execute step 4
  // 3. compile the corresponding kernel_graph according to the types and generate the corresponding actor_set
  // 4. send the graph inputs to the kernel actors of the current graph
  // (a rough sketch of steps 1-3 follows the declaration below)
  void RunForGraphInput(OpContext<DeviceTensor> *const context);
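  // A minimal sketch of steps 1-3 (illustrative only; the helper TypeIdToString, the local variable device_context
  // and the exact compile call are assumptions, not the real implementation):
  //   std::string key;
  //   for (size_t index : any_type_parameter_indexes_) {
  //     key += TypeIdToString(input_device_tensors_[index]->type_id()) + "_";
  //   }
  //   current_data_type_ = key;
  //   if (actors_.find(current_data_type_) == actors_.end()) {
  //     // First time this type combination is seen: compile a real graph for it and build its actors.
  //     KernelGraphPtr real_graph = /* compile a graph for the current types, e.g. via compile_func_ */;
  //     std::vector<AbstractActorPtr> actor_set = transform_func_(graph(), real_graph, device_context);
  //     schedule_func_(actor_set);
  //     actors_[current_data_type_] = actor_set;
  //     real_graphs_[current_data_type_] = real_graph;
  //   }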
  void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) override;
  void UpdataDynamicShapeParameterForGraphInput(OpContext<DeviceTensor> *const context);
  void SendOutput(OpContext<DeviceTensor> *const context) override;
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

  // Handle the graph output.
  bool CheckGraphOutputRunningCondition(const OpContext<DeviceTensor> *context);
  // Receive graph outputs:
  // 1. find the corresponding arrows according to the current type key, and send the outputs.
  void RunForGraphOutput(OpContext<DeviceTensor> *const context);
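  // A minimal sketch of the running-condition check that gates RunForGraphOutput (an assumption that mirrors the
  // usual actor counting pattern; not the real implementation):
  //   int seq = context->sequential_num_;
  //   bool data_ready = graph_output_op_data_[seq].size() == graph_output_data_num_[current_data_type_];
  //   bool control_ready = graph_output_op_control_[seq].size() == graph_output_control_num_[current_data_type_];
  //   return data_ready && control_ready;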
  void CheckParams(OpContext<DeviceTensor> *const context);
  void FetchGraphOutput(OpContext<DeviceTensor> *const context);
  void EraseGraphOutput(OpContext<DeviceTensor> *const context);
  void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                        const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;

 private:
  friend class AnyTypeGraphScheduler;

  // When the actor receives the inputs of the graph, it can determine the data types of the parameters and then
  // compile an executable kernel graph and its actors.
  mindspore::HashMap<string, std::vector<AbstractActorPtr>> actors_;
  // Kernel graphs that are actually executed.
  mindspore::HashMap<string, KernelGraphPtr> real_graphs_;
  // The positions of the any type parameters in the kernel graph.
  // After graph compilation, a unique key is generated according to the types of these parameters and used to save
  // the arrows corresponding to the graph.
  std::vector<size_t> any_type_parameter_indexes_;
  // The data types of the any type parameters in the currently received input; the format is like:
  // typeid1_typeid2_typeid3.
  std::string current_data_type_;

  // Parameters that have a dynamic shape.
  mindspore::HashMap<std::string, std::vector<AnfNodePtr>> graph_input_backend_parameters_;

  // Arrows sent to the kernel/superkernel actors of the graph.
  mindspore::HashMap<std::string, std::vector<DataArrowPtr>> graph_input_data_arrows_;
  mindspore::HashMap<std::string, std::vector<ControlArrowPtr>> graph_input_control_arrows_;
  // The graph_input_data_nodes_ and graph_input_data_ correspond to the graph_input_data_arrows_ one to one.
  mindspore::HashMap<std::string, std::vector<AnfNodePtr>> graph_input_data_nodes_;
  // The second element of the pair indicates the output data flag. See the constants prefixed with kOutputDataFlag
  // for details.
  mindspore::HashMap<std::string, std::vector<std::pair<OpDataUniquePtr<DeviceTensor>, size_t>>> graph_input_data_;
  // Record the fusion output index for output data arrow.
  mindspore::HashMap<std::string, mindspore::HashMap<DataArrow *, size_t>> data_arrow_to_graph_input_actor_indexs_;
  // Used to send batch data in the message required by RunBatchOpData; the key is the name of the destination actor.
  mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<OpData<DeviceTensor> *>>>
    batch_graph_input_data_;
  mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<DataArrowPtr>>>
    batch_graph_input_data_arrows_;

  // Graph outputs received from the kernel/superkernel actors of the graph.
  mindspore::HashMap<int, std::vector<OpData<DeviceTensor> *>> graph_output_op_data_;
  mindspore::HashMap<int, std::vector<AID *>> graph_output_op_control_;
  std::vector<DeviceTensor *> graph_ouput_device_tensors_;
  // In the any type kernel actor, a kernel in the model graph may hit a fallback scenario, in which case the device
  // type of the model graph and that of the real graph differ. New device addresses then need to be created for the
  // model graph and are placed here.
  std::vector<DeviceTensorPtr> fallback_device_tensors_;
  mindspore::HashMap<std::string, size_t> graph_output_data_num_;
  mindspore::HashMap<std::string, size_t> graph_output_control_num_;

  AnyTypeKernelActorState actor_state_{kAnyTypeKernelActorInit};

  static std::mutex instance_lock_;

  CompileFunc compile_func_;
  TransformFunc transform_func_;
  ScheduleFunc schedule_func_;
};

using AnyTypeKernelActorPtr = std::shared_ptr<AnyTypeKernelActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_