/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_

#include <string>
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include "runtime/graph_scheduler/actor/super_kernel_actor.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "include/common/utils/python_adapter.h"
#include "ir/anf.h"

namespace mindspore {
namespace runtime {
// State marks the phase of the actor: just initialized, processing the graph inputs, or processing the graph
// outputs.
enum AnyTypeKernelActorState { kAnyTypeKernelActorInit, kAnyTypeKernelActorSendInput, kAnyTypeKernelActorSendOutput };
using mindspore::device::DeviceContext;
using DataArrowGroupMap = mindspore::HashMap<std::string, std::vector<DataArrowPtr>>;
using ControlArrowGroupMap = mindspore::HashMap<std::string, std::vector<AID *>>;
using TransformFunc =
  std::function<std::vector<AbstractActorPtr>(const KernelGraphPtr &, const KernelGraphPtr &, const DeviceContext *)>;
using ScheduleFunc = std::function<void(const std::vector<AbstractActorPtr> &)>;
// The any type kernel actor represents a graph whose input data types are not fixed, so the graph can only be
// compiled when the actor runs and the real input types are known.
// The execution proceeds as follows:
// 1. Receive the inputs.
// 2. Send the graph inputs to the kernel/superkernel actors.
// 3. Receive the graph outputs from the kernel/superkernel actors.
// 4. Send the graph outputs.
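// For illustration only, a minimal sketch of the compile-or-reuse dispatch described above. OnGraphInput and
// GenerateTypeKey are hypothetical helpers, and the callback signatures are assumed here, not the actual ones:
//
//   void OnGraphInput(OpContext<DeviceTensor> *const context) {
//     current_data_type_ = GenerateTypeKey(input_device_tensors_);  // 1. Type key of the received inputs.
//     if (actors_.count(current_data_type_) == 0) {                 // 2. Look up the compiled-graph cache.
//       KernelGraphPtr real_graph = compile_func_(graph_, current_data_type_);  // 3. Compile on a cache miss.
//       actors_[current_data_type_] = transform_func_(graph_, real_graph, device_contexts_[0]);
//       schedule_func_(actors_[current_data_type_]);
//     }
//     SendGraphInput(context);  // 4. Forward the inputs to the actors of the selected graph.
//   }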
class AnyTypeKernelActor : public SuperKernelActor {
 public:
  AnyTypeKernelActor(const std::string &name, const KernelGraphPtr &graph, const DeviceContext *device_context,
                     const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                     KernelTransformType type = KernelTransformType::kAnyTypeKernelActor);
  ~AnyTypeKernelActor() override = default;

  void RunOpData(OpData<DeviceTensor> *const input_data, OpContext<DeviceTensor> *const context) override;
  void RunOpControl(AID *const input_control, OpContext<DeviceTensor> *const context) override;
  const std::string &current_data_type() const { return current_data_type_; }

 protected:
  void Init() override;

  // Handle the graph inputs.
  // The execution of the actor is divided into the following steps:
  // Receive graph inputs:
  // 1. Generate the type key.
  // 2. Check whether a graph compiled for this key already exists; if not, execute 3, otherwise execute 4.
  // 3. Compile the corresponding kernel_graph according to the types and generate the corresponding actor_set.
  // 4. Send the graph inputs to the kernel actors of the current graph.
  void RunForGraphInput(OpContext<DeviceTensor> *const context);
  void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context) override;
  void UpdataDynamicShapeParameterForGraphInput(OpContext<DeviceTensor> *const context);
  void SendOutput(OpContext<DeviceTensor> *const context) override;
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;

  // Handle the graph outputs.
  bool CheckGraphOutputRunningCondition(const OpContext<DeviceTensor> *context);
  // Receive graph outputs:
  // 1. Find the corresponding arrows according to the current type key, and send the outputs.
  void RunForGraphOutput(OpContext<DeviceTensor> *const context);
  void CheckParams(OpContext<DeviceTensor> *const context);
  void FetchGraphOutput(OpContext<DeviceTensor> *const context);
  void EraseGraphOutput(OpContext<DeviceTensor> *const context);
  void UpdateOutputData(OpData<DeviceTensor> *const output_data, const DataArrowPtr &data_arrow,
                        const AnfNodePtr &output_node, OpContext<DeviceTensor> *const context) override;

 private:
  friend class AnyTypeGraphScheduler;

  // Once the actor receives the graph inputs, it can determine the data types of the parameters and then compile
  // an executable kernel graph and its actors.
  mindspore::HashMap<string, std::vector<AbstractActorPtr>> actors_;
  // Kernel graphs that are actually executed.
  mindspore::HashMap<string, KernelGraphPtr> real_graphs_;
  // The positions of the any type parameters in the kernel graph.
  // After graph compilation, a unique key is generated from the types of these parameters and used to save the
  // arrows corresponding to the graph.
  std::vector<size_t> any_type_parameter_indexes_;
  // The data types of the any type parameters in the currently received inputs; the format is like:
  // typeid1_typeid2_typeid3.
  std::string current_data_type_;
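  // For illustration, a minimal sketch of how such a key could be derived from the any type parameter positions
  // above; BuildTypeKey is a hypothetical helper, not part of this class:
  //
  //   std::string BuildTypeKey(const std::vector<DeviceTensor *> &inputs, const std::vector<size_t> &indexes) {
  //     std::string key;
  //     for (size_t index : indexes) {
  //       key += (key.empty() ? "" : "_") + TypeIdToString(inputs[index]->type_id());
  //     }
  //     return key;
  //   }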
  // Parameters that have a dynamic shape.
  mindspore::HashMap<std::string, std::vector<AnfNodePtr>> graph_input_backend_parameters_;

  // Arrows sent to the kernel/superkernel actors of the graph.
  mindspore::HashMap<std::string, std::vector<DataArrowPtr>> graph_input_data_arrows_;
  mindspore::HashMap<std::string, std::vector<ControlArrowPtr>> graph_input_control_arrows_;
  // The graph_input_data_nodes_ and graph_input_data_ correspond to the graph_input_data_arrows_ one to one.
  mindspore::HashMap<std::string, std::vector<AnfNodePtr>> graph_input_data_nodes_;
  // The second element of the pair indicates the output data flag. See the constants prefixed with kOutputDataFlag
  // for details.
  mindspore::HashMap<std::string, std::vector<std::pair<OpDataUniquePtr<DeviceTensor>, size_t>>> graph_input_data_;
  // Record the fusion output index for each output data arrow.
  mindspore::HashMap<std::string, mindspore::HashMap<DataArrow *, size_t>> data_arrow_to_graph_input_actor_indexs_;
  // Used to send batch data in the message that RunBatchOpData needs; the key is the name of the destination actor.
  mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<OpData<DeviceTensor> *>>>
    batch_graph_input_data_;
  mindspore::HashMap<std::string, mindspore::HashMap<std::string, std::vector<DataArrowPtr>>>
    batch_graph_input_data_arrows_;

  // Graph outputs received from the kernel/superkernel actors of the graph.
  mindspore::HashMap<int, std::vector<OpData<DeviceTensor> *>> graph_output_op_data_;
  mindspore::HashMap<int, std::vector<AID *>> graph_output_op_control_;
  std::vector<DeviceTensor *> graph_ouput_device_tensors_;
  // In the any type kernel actor, kernels in the model graph may hit a fallback scenario, in which the device types
  // of the model graph and the real graph differ. In that case, a new device address needs to be created for the
  // model graph and placed here.
  std::vector<DeviceTensorPtr> fallback_device_tensors_;
  mindspore::HashMap<std::string, size_t> graph_output_data_num_;
  mindspore::HashMap<std::string, size_t> graph_output_control_num_;

  AnyTypeKernelActorState actor_state_{kAnyTypeKernelActorInit};

  static std::mutex instance_lock_;

  CompileFunc compile_func_;
  TransformFunc transform_func_;
  ScheduleFunc schedule_func_;
};

using AnyTypeKernelActorPtr = std::shared_ptr<AnyTypeKernelActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_ANY_TYPE_ACTOR_H_