/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_

#include <string>
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include <queue>
#include "runtime/graph_scheduler/actor/debug_aware_actor.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/kernel_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_launch_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_infer_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_resize_actor.h"
#include "runtime/hardware/device_context.h"
#include "ir/anf.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceAddress;
using mindspore::device::DeviceContext;

struct OutputMemoryInfo {
  size_t size;
  std::string node_full_name;
};

// The super kernel actor represents the sink execution of a whole graph, which is the combination of its kernels.
class SuperKernelActor : public DebugAwareActor {
 public:
  SuperKernelActor(const std::string &name, const KernelGraphPtr &graph, const DeviceContext *device_context,
                   const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                   KernelTransformType type = KernelTransformType::kSuperKernelActor)
      : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid, nullptr), graph_(graph) {
    (void)device_contexts_.emplace_back(device_context);
    input_device_tensors_.resize(graph->input_nodes().size());
    enable_kbk_sub_graph_execute_ = EnableKbkSubGraphExecute();
    enable_trace_memory_ = EnableTraceMemory();
    kernel_async_infer_aid_ = KernelAsyncInferActor::GetInstance()->GetAID();
    kernel_async_resize_aid_ = KernelAsyncResizeActor::GetInstance()->GetAID();
    kernel_async_launch_aid_ = KernelAsyncLaunchActor::GetInstance()->GetAID();
    somas_info_ = graph_->MutableSomasInfo();
  }
  ~SuperKernelActor() override = default;

  size_t FetchInputNodePosition(const AnfNodePtr &input_node);
  virtual void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
  // The debug related operation interface.
  void SendDebugReq(OpContext<DeviceTensor> *const context) override;

  // The memory related operation interface.
  void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
  // The callback after memory alloc finished.
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
  // The input may come from the control actor, so the input memory needs to be freed by the dynamic ref count.
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
  bool CopyInputData(const OpContext<DeviceTensor> *context, const KernelGraphPtr &graph);

  const KernelGraphPtr &graph() const { return graph_; }

 protected:
  void Init() override;
  void Run(OpContext<DeviceTensor> *const context) override;
  // The input device tensors for launch.
  std::vector<DeviceTensor *> input_device_tensors_;
  // The device tensors of the graph input parameters, which are used to compare with the received input data.
  std::vector<DeviceTensorPtr> node_device_tensors_;
  // The device tensors for memory alloc.
  std::vector<DeviceTensor *> memory_alloc_list_;
  // The lists of device tensors which need to be freed by dynamic ref count, cleared at the end of the step.
  std::queue<std::vector<DeviceTensor *>> memory_free_lists_;

 private:
  bool CopyInputDataPersistedHandle(const DeviceContext *device_context, DeviceTensor *input_device_tensor,
                                    const DeviceTensorPtr &node_device_tensor, size_t i);
  void RunGraphKernelByKernel(OpContext<DeviceTensor> *const context);

  void UpdateMemoryTraceMangerStatus(OpContext<DeviceTensor> *const context);
  void SetTraceMemoryForKernel(const KernelActorPtr &kernel_actor);

  void FetchPersistentDeviceTensor();

  friend class GraphScheduler;
  KernelGraphPtr graph_;

  // In the scheduler, check whether the parameters need to be copied after launch. A parameter needs to be copied
  // only when it has the ref attribute and is directly used by a kernel in the graph.
  std::vector<bool> is_parameters_need_copy_;

  // Record the address map of ref nodes to copy back when running is finished.
  std::map<DeviceAddress *, DeviceAddress *> ref_node_addr_map_;

  // The received input device type and format may be different from the formal parameter in control flow scenarios,
  // so the input data needs to be copied to the real data that the graph launch needs.
  std::vector<DeviceTensorPtr> copy_input_device_tensors_;
  // Record the device address to the output node of the graph.
  std::map<DeviceAddress *, OutputMemoryInfo> device_address_to_node_;

  // For kernel-by-kernel execution of a sub graph.
  void BuildKernelActors();
  // Cache the kernel input indices whose inputs are the graph's inputs.
  void ParseInputIndex();

  void CalcRefCount();

  // Kernel by kernel sub graph execute mode does not need to send actor messages.
  bool enable_kbk_sub_graph_execute_;
  bool already_fetch_persistent_device_tensor_{false};
  std::vector<KernelActorPtr> kernel_actors_;
  mindspore::HashMap<AnfNode *, std::vector<std::pair<size_t, size_t>>> kernel_input_to_graph_input_indices_;
  SomasInfo *somas_info_;

  AID kernel_async_infer_aid_;
  AID kernel_async_resize_aid_;
  AID kernel_async_launch_aid_;

  bool enable_trace_memory_;
};

using SuperKernelActorPtr = std::shared_ptr<SuperKernelActor>;
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_
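// Illustrative usage sketch (editor's note, not part of the original header): a minimal, hedged example of how a
// SuperKernelActor could be constructed. The helpers BuildKernelGraph() and FetchDeviceContext() below are
// hypothetical placeholders for wherever the kernel graph and device context come from; in practice the
// GraphScheduler (a friend of this class) owns actor creation and supplies the memory manager / debug / recorder
// actor AIDs.
//
//   KernelGraphPtr graph = BuildKernelGraph();           // hypothetical helper
//   const DeviceContext *ctx = FetchDeviceContext();     // hypothetical helper
//   auto actor = std::make_shared<SuperKernelActor>("SuperKernelActor_0", graph, ctx, memory_manager_aid,
//                                                   &debug_aid, &recorder_aid);
//   // The scheduler then initializes the actor, and the actor framework drives Run(&op_context) to execute
//   // the sunk graph (kernel by kernel when enable_kbk_sub_graph_execute_ is set).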