/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_
#define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_

#include <string>
#include <memory>
#include <map>
#include <utility>
#include <vector>
#include <queue>
#include "runtime/graph_scheduler/actor/debug_aware_actor.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/actor/kernel_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_launch_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_infer_actor.h"
#include "runtime/graph_scheduler/actor/kernel_async_resize_actor.h"
#include "runtime/hardware/device_context.h"
#include "ir/anf.h"

namespace mindspore {
namespace runtime {
using mindspore::device::DeviceAddress;
using mindspore::device::DeviceContext;

struct OutputMemoryInfo {
  size_t size;
  std::string node_full_name;
};

// The super kernel actor is used to represent the sink execution of a graph, which is a combination of kernels.
class SuperKernelActor : public DebugAwareActor {
 public:
  SuperKernelActor(const std::string &name, const KernelGraphPtr &graph, const DeviceContext *device_context,
                   const AID &memory_manager_aid, const AID *debug_aid, const AID *recorder_aid,
                   KernelTransformType type = KernelTransformType::kSuperKernelActor)
      : DebugAwareActor(name, type, recorder_aid, memory_manager_aid, debug_aid, nullptr), graph_(graph) {
    (void)device_contexts_.emplace_back(device_context);
    input_device_tensors_.resize(graph->input_nodes().size());
    enable_kbk_sub_graph_execute_ = EnableKbkSubGraphExecute();
    enable_trace_memory_ = EnableTraceMemory();
    kernel_async_infer_aid_ = KernelAsyncInferActor::GetInstance()->GetAID();
    kernel_async_resize_aid_ = KernelAsyncResizeActor::GetInstance()->GetAID();
    kernel_async_launch_aid_ = KernelAsyncLaunchActor::GetInstance()->GetAID();
    somas_info_ = graph_->MutableSomasInfo();
  }
  ~SuperKernelActor() override = default;

  size_t FetchInputNodePosition(const AnfNodePtr &input_node);
  virtual void FetchInputDeviceTensor(OpContext<DeviceTensor> *const context);
  // The debug related operation interface.
  void SendDebugReq(OpContext<DeviceTensor> *const context) override;

  // The memory related operation interface.
  void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override;
  // The callback after memory alloc is finished.
  void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
  // The input may come from a control actor, so the input memory needs to be freed by the dynamic ref count.
  void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override;
  bool CopyInputData(const OpContext<DeviceTensor> *context, const KernelGraphPtr &graph);

  const KernelGraphPtr &graph() const { return graph_; }

 protected:
  void Init() override;
  void Run(OpContext<DeviceTensor> *const context) override;
  // The input device tensors for launch.
  std::vector<DeviceTensor *> input_device_tensors_;
  // The device tensors of the graph input parameters, which are used to compare with the received input data.
  std::vector<DeviceTensorPtr> node_device_tensors_;
  // The device tensors for memory alloc.
  std::vector<DeviceTensor *> memory_alloc_list_;
  // The lists of device tensors which need to be freed by dynamic ref count; they are cleared at the end of the step.
  std::queue<std::vector<DeviceTensor *>> memory_free_lists_;

 private:
  bool CopyInputDataPersistedHandle(const DeviceContext *device_context, DeviceTensor *input_device_tensor,
                                    const DeviceTensorPtr &node_device_tensor, size_t i);
  void RunGraphKernelByKernel(OpContext<DeviceTensor> *const context);

  void UpdateMemoryTraceMangerStatus(OpContext<DeviceTensor> *const context);
  void SetTraceMemoryForKernel(const KernelActorPtr &kernel_actor);

  void FetchPersistentDeviceTensor();

  friend class GraphScheduler;
  KernelGraphPtr graph_;

  // In the scheduler, check whether the parameters need to be copied after launch. Only a parameter that has
  // the ref attribute and is directly used by a kernel in the graph needs to be copied.
  std::vector<bool> is_parameters_need_copy_;

  // Record the address map of ref nodes, used to copy back when running is finished.
  std::map<DeviceAddress *, DeviceAddress *> ref_node_addr_map_;

  // The received input device type and format may differ from the formal parameter in control flow scenarios,
  // so the input data needs to be copied to the real data that the graph launch needs.
  std::vector<DeviceTensorPtr> copy_input_device_tensors_;
  // Record the device address to the output node of the graph.
  std::map<DeviceAddress *, OutputMemoryInfo> device_address_to_node_;

  // For kernel-by-kernel execution of a sub graph.
  void BuildKernelActors();
  // Cache the kernel input indices whose inputs are the graph's inputs.
  void ParseInputIndex();

  void CalcRefCount();

  // The kernel-by-kernel sub graph execute mode does not need to send actor messages.
  bool enable_kbk_sub_graph_execute_;
  bool already_fetch_persistent_device_tensor_{false};
  std::vector<KernelActorPtr> kernel_actors_;
  mindspore::HashMap<AnfNode *, std::vector<std::pair<size_t, size_t>>> kernel_input_to_graph_input_indices_;
  SomasInfo *somas_info_;

  AID kernel_async_infer_aid_;
  AID kernel_async_resize_aid_;
  AID kernel_async_launch_aid_;

  bool enable_trace_memory_;
};

using SuperKernelActorPtr = std::shared_ptr<SuperKernelActor>;
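
// Illustrative usage sketch (comments only, not part of this header): roughly how a scheduler-side
// component could create and drive a SuperKernelActor. The `graph`, `device_context`, and AID values
// are assumed to be prepared elsewhere (normally by GraphScheduler), and the actor name
// "kernel_graph_0" is a placeholder:
//
//   auto actor = std::make_shared<SuperKernelActor>("kernel_graph_0", graph, device_context,
//                                                   memory_manager_aid, &debug_aid, &recorder_aid);
//   // Device memory is requested via SendMemoryAllocReq, the graph runs in Run() after
//   // OnMemoryAllocFinish is called back, and the inputs freed by dynamic ref count are
//   // released via SendMemoryFreeReq.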
}  // namespace runtime
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_SUPER_KERNEL_ACTOR_H_