/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_compiler.h"

#include <memory>
#include <algorithm>
#include <vector>
#include <unordered_set>
#include "mindspore/core/ops/op_utils.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "ops/nn_op_name.h"
#include "ops/conv_pool_op_name.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/device_address_utils.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#ifdef ENABLE_D
#include "transform/acl_ir/acl_adapter_info.h"
#endif

namespace mindspore {
using runtime::DeviceAddressUtils;
namespace pynative {
namespace {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
mindspore::HashSet<std::string> kExcludedAttr = {"input_names", "output_names", "IsFeatureMapOutput",
                                                 "IsFeatureMapInputList", "pri_format"};
std::vector<std::string> kNumStrCache;

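// Return the cached decimal string for n when it falls inside kNumStrCache; otherwise fall back to
// std::to_string. The cache is populated once in the OpCompiler constructor.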
inline std::string GetNumString(int n) {
  if (n >= static_cast<int>(kNumStrCache.size())) {
    return std::to_string(n);
  }

  return kNumStrCache[n];
}

void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
  // Under PyNative mode, graph building and kernel creation are asynchronous, and ref info is bound to the kernel.
  // So the ref info has to be collected here, before kernel creation, to generate the output addresses.
  if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
      op_run_info->base_op_run_info.device_target != kGPUDevice) {
    // Only the Ascend ref mode differs from CPU and GPU, so skip it here.
    return;
  }

  AnfAlgo::AddOutInRefToGraph(graph);
}

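// Create device addresses for the graph's parameters, value nodes and kernel outputs, and refresh the
// addresses of inplace and ref nodes. Workspace addresses are created later, during kernel build.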
void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                         bool is_gradient_out) {
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}

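// Mark the input edges whose host-to-device copy can be skipped, according to the launch-ignored
// input indices reported by each kernel mod.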
void SetIgnoreSyncHostToDeviceList(const SimpleGraphPtr &simple_graph) {
  const auto &single_ops = simple_graph->single_ops_;
  for (const auto &single_op : single_ops) {
    const auto &kernel = single_op->kernel_;
    const auto &edges = single_op->inputs_;

    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    std::vector<size_t> ignore_input_index_list = kernel_mod->GetLaunchIgnoredInputAddressIdx();
    for (size_t index : ignore_input_index_list) {
      // Some inputs may have been converted to attributes, or the input size may be wrong.
      // This behavior is incorrect, but it does exist in current kernels and needs to be
      // rectified by the developers of those kernels.
      if (index >= edges.size()) {
        MS_LOG(INFO) << simple_graph->name_ << " ignore input index is " << index << ", but total input num is "
                     << edges.size();
        continue;
      }
      edges[index]->ignore_h2d_ = true;
      MS_LOG(INFO) << "For graph " << simple_graph->name_ << " ignore input host to device " << index;
    }
  }
}
}  // namespace

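// The constructor creates the session used to build single-op graphs and prefills kNumStrCache with the
// decimal strings "0" .. kNumberTypeEnd - 1, so GetNumString can avoid repeated std::to_string calls.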
OpCompiler::OpCompiler() {
  session_ = session::SessionFactory::Get().Create(kSessionBasic);
  for (size_t i = 0; i < kNumberTypeEnd; i++) {
    (void)kNumStrCache.emplace_back(std::to_string(i));
  }
}

OpCompiler &OpCompiler::GetInstance() {
  static OpCompiler instance;
  return instance;
}

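// ready_ is set by the compile/build task and polled by the launch side: UpdateStatus publishes the flag
// with release semantics, and WaitReady spins (yielding) until it is observed with acquire semantics.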
void OpCompilerInfo::UpdateStatus(bool ready) { ready_.store(ready, std::memory_order_release); }

void OpCompilerInfo::WaitReady() const {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kWaitTaskFinish,
                                     graph_info_, true);
  while (!ready_.load(std::memory_order_acquire)) {
    std::this_thread::yield();
  }
}

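// Ops whose single-op infer result is known to be unreliable; Compile sets need_refresh_abstract for them
// so the output abstract is refreshed after execution.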
bool OpCompiler::IsInvalidInferResultOp(const std::string &op_name) const {
  static const std::unordered_set<std::string> kInvalidInferResultOp = {kDropoutOpName, kMaxPoolWithArgmaxOpName,
                                                                        kLSTMOpName};
  return kInvalidInferResultOp.find(op_name) != kInvalidInferResultOp.end();
}

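// Build a single-op kernel graph from the BackendOpRunInfo via the session and mark it as coming from a
// single op.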
KernelGraphPtr OpCompiler::GenerateKernelGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                               const device::DeviceContext *device_context) const {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(op_run_info->op_prim);
  KernelGraphPtr graph;
  graph = session_->ConstructSingleOpGraph(op_run_info, op_run_info->base_op_run_info.expanded_input_values,
                                           op_run_info->base_op_run_info.input_types);
  graph->set_is_from_single_op(true);
  return graph;
}

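// Propagate the given stream id to every kernel in the graph's execution order and to each of its inputs.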
void OpCompiler::AssignStreamIdForSingleOpGraph(const KernelGraphPtr &graph, uint32_t stream_id) {
  MS_EXCEPTION_IF_NULL(graph);

  for (const auto &cnode : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    AnfAlgo::SetStreamId(stream_id, cnode.get());
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t index = 0; index < input_num; ++index) {
      const auto &input_node = common::AnfAlgo::GetInputNode(cnode, index);
      AnfAlgo::SetStreamId(stream_id, input_node.get());
    }
  }
}

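// Compile a single op in PyNative mode: look up the graph-info cache first; on a miss, build the kernel
// graph, run the backend optimizations, create the device addresses, convert the graph to the simple IR
// and cache the resulting OpCompilerInfo.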
OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                                      const std::string &device_name, const uint32_t &device_id) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &graph_info = GetSingleOpGraphInfo(op_run_info->base_op_run_info, op_run_info->op_prim);
  const auto &iter = op_compiler_infos_.find(graph_info);
  // Check whether the graph cache exists.
  if (iter != op_compiler_infos_.end()) {
    const auto &op_compiler_info = iter->second;
    MS_EXCEPTION_IF_NULL(op_compiler_info);
    *single_op_cache_hit = true;
    return op_compiler_info;
  }

  MS_LOG(INFO) << "Run Op cache miss " << graph_info;
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeOpCompile,
                                     graph_info, true);

  *single_op_cache_hit = false;
  // Generate kernel graph.
  MS_EXCEPTION_IF_NULL(session_);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();
  py::gil_scoped_acquire acquire_gil;
  KernelGraphPtr graph = GenerateKernelGraph(op_run_info, device_context);
  MS_EXCEPTION_IF_NULL(graph);

  graph->set_run_mode(device::RunMode::kKernelMode);
  bool use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
  auto kernel_executor = device_context->GetKernelExecutor(use_dynamic_shape_process);
  MS_EXCEPTION_IF_NULL(kernel_executor);

  opt::OptimizationWithoutBackend(graph);
  // Unify the MindIR; this must run before the graph optimization.
  kernel_executor->AddMindIRPass(graph);

  // Select kernels and optimize the graph.
  kernel_executor->OptimizeGraph(graph);

  UpdateRefInfoBeforeCreateKernel(op_run_info, graph);
  AssignStreamIdForSingleOpGraph(graph, op_run_info->base_op_run_info.stream_id);
  // Create device addresses for all anf nodes of the graph.
  CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);

  auto output_nodes = graph->outputs();
  std::vector<KernelWithIndex> outputs_with_index;
  std::vector<size_t> outputs_tensor_num;
  std::vector<std::string> outputs_padding_type;
  bool need_refresh_abstract = IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name);
  for (auto &node : output_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    const auto &output_with_index = common::AnfAlgo::VisitKernel(node, 0);
    (void)outputs_with_index.emplace_back(output_with_index);
    (void)outputs_tensor_num.emplace_back(AnfAlgo::GetOutputTensorNum(output_with_index.first));
    const auto &padding_type = (device_context->GetDeviceType() == device::DeviceType::kAscend
                                  ? AnfAlgo::GetOutputReshapeType(output_with_index.first, output_with_index.second)
                                  : "");
    (void)outputs_padding_type.emplace_back(padding_type);

    MS_EXCEPTION_IF_NULL(output_with_index.first);
    const auto &abstract = output_with_index.first->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    const auto &shape = abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(shape);
    if (shape->IsDynamic()) {
      need_refresh_abstract = true;
    }
  }
  AnfAlgo::UpdateGraphValidRefPair(graph);
  UpdateRefNodeOutputDeviceAddress(graph);
  auto simple_graph = IrConverter::Convert(op_run_info->base_op_run_info.op_name, graph, device_context);
  MS_LOG(DEBUG) << "Generate new IR " << simple_graph->DebugInfo().dump();

  auto op_compiler_info = std::make_shared<OpCompilerInfo>(
    graph_info, graph->graph_id(), graph, device_context, op_run_info->base_op_run_info.need_earse_cache,
    need_refresh_abstract, outputs_with_index, outputs_tensor_num, outputs_padding_type, std::move(simple_graph));

  graph->set_graph_info(graph_info);
  op_compiler_infos_[graph_info] = op_compiler_info;
  return op_compiler_info;
}

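// Build the kernels for a compiled single-op graph: create kernel mods for every node in execution order,
// run the device-specific preprocessing, create the workspace device addresses and cache the op runtime info.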
void OpCompiler::KernelBuild(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context,
                             bool is_dynamic) const {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  // The compilation task may run in a child thread that has not yet set rt_context,
  // but loading AICPU.so needs rt_context.
  if (!device_context->device_res_manager_->BindDeviceToCurrentThread(true)) {
    MS_LOG(EXCEPTION) << "Bind device failed";
  }
  std::vector<CNodePtr> node_to_build;
  const auto &graph = op_compiler_info->graph_;
  MS_EXCEPTION_IF_NULL(graph);
  const auto &nodes = graph->execution_order();
  (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
  // Kernel build.
  auto kernel_executor = device_context->GetKernelExecutor(is_dynamic);
  MS_EXCEPTION_IF_NULL(kernel_executor);
  kernel_executor->CreateKernel(node_to_build);
  kernel_executor->PreprocessBeforeRun(graph);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
  // Must be executed after PreprocessBeforeRun.
  runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);

  // After the kernels are generated.
  SetIgnoreSyncHostToDeviceList(op_compiler_info->simple_graph_);
}

#ifdef ENABLE_D
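// On Ascend, some ACL adapters select kernel formats from the input dtypes and shapes; append the
// input/output selector results to the graph info so that such cases get distinct cache keys.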
std::string GetGraphInfoForAscendSpecial(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim,
                                         const std::string &graph_info) {
  std::string ascend_special_info = graph_info;
  MS_EXCEPTION_IF_NULL(op_prim);
  auto op_name = op_prim->name();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice &&
      transform::AclAdapterManager::GetInstance().CheckAclAdapter(op_name)) {
    auto acl_info = transform::AclAdapterManager::GetInstance().GetOpInfo(op_name);
    if (!acl_info.input_selector().empty() || acl_info.output_selector() != nullptr) {
      if (op_info.expanded_input_values.empty()) {
        return ascend_special_info;
      }
      TypeId first_dtype = TypeId::kTypeUnknown;
      std::vector<ShapeVector> input_shapes;
      (void)std::transform(op_info.expanded_input_values.begin(), op_info.expanded_input_values.end(),
                           std::back_inserter(input_shapes), [&first_dtype](const ValuePtr &value) -> ShapeVector {
                             auto tensor = value->cast<tensor::BaseTensorPtr>();
                             if (tensor != nullptr) {
                               if (first_dtype == TypeId::kTypeUnknown) {
                                 first_dtype = tensor->data_type();
                               }
                               return tensor->shape();
                             }
                             return {};
                           });

      auto in_func_map = acl_info.input_selector();
      for (auto [index, in_func] : in_func_map) {
        MS_EXCEPTION_IF_NULL(in_func);
        auto tensor = op_info.expanded_input_values[index]->cast<tensor::BaseTensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        ascend_special_info += in_func(tensor->data_type(), input_shapes);
      }

      auto out_func = acl_info.output_selector();
      if (out_func != nullptr) {
        auto tensor = op_info.expanded_input_values[0]->cast<tensor::BaseTensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        auto out_format = out_func(tensor->data_type(), input_shapes);
        ascend_special_info += out_format;
      }
    }
  }
  return ascend_special_info;
}
#endif

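// Collect the indices of inputs whose values are depended on at infer time, remapping them when dynamic
// inputs (dyn_input_sizes) expand a single logical input into several real inputs.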
inline std::set<int64_t> GetDependList(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim) {
  auto depend_list = mindspore::ops::GetInputDependValueList(op_prim);
  if (!op_info.dyn_input_sizes.empty()) {
    auto list_tmp = depend_list;
    depend_list.clear();
    for (const auto item : list_tmp) {
      int64_t bias = 0;
      for (int64_t i = 0; i < item; i++) {
        auto idx = static_cast<size_t>(i);
        if (op_info.dyn_input_sizes[idx] == -1) {
          bias += 1;
        } else {
          bias += op_info.dyn_input_sizes[idx];
        }
      }
      (void)depend_list.emplace(bias);
      MS_LOG(DEBUG) << "Adjust depend list from " << item << " to " << bias << " for op: " << op_prim->name();
    }
  }

  return depend_list;
}

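// Build the cache key (graph info) for a single op from the device target, op name, attributes, input
// shapes/dtypes/formats, depend values and stream id; operators with hidden side effects also encode the
// python primitive id so they are never shared across primitive instances.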
std::string OpCompiler::GetSingleOpGraphInfo(const pynative::BaseOpRunInfo &op_info,
                                             const PrimitivePtr &op_prim) const {
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_info.expanded_input_values.size() != op_info.input_types.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << op_info.expanded_input_values.size()
                      << " should be equal to tensors mask size " << op_info.input_types.size();
  }
  std::string graph_info = op_info.device_target;

  if (op_info.use_dynamic_shape_process) {
    graph_info += "_1_";
  } else {
    graph_info += "_0_";
  }
  auto op_name = op_prim->name();
  graph_info += op_name;
  bool has_hidden_side_effect;
  {
    PrimitiveReadLock read_lock(op_prim->shared_mutex());
    if (op_info.need_earse_cache) {
      return graph_info;
    }
    has_hidden_side_effect = op_prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
    // The values of the attributes affect operator selection.
    const auto &attr_map = op_prim->attrs();
    (void)std::for_each(attr_map.begin(), attr_map.end(), [&graph_info](const auto &element) {
      if (kExcludedAttr.find(element.first) != kExcludedAttr.end()) {
        return;
      }
      MS_EXCEPTION_IF_NULL(element.second);
      graph_info.append(element.second->ToString());
    });
  }

  const auto &depend_list = GetDependList(op_info, op_prim);
  for (size_t index = 0; index < op_info.expanded_input_values.size(); ++index) {
    auto const &value = op_info.expanded_input_values[index];
    if (value->isa<tensor::BaseTensor>()) {
      const auto &input_tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(input_tensor);
      if (op_info.use_dynamic_shape_process) {
        graph_info += GetNumString(static_cast<int>(input_tensor->shape().size()));
      } else {
        if (input_tensor->base_shape_ptr() != nullptr) {
          graph_info += input_tensor->base_shape_ptr()->ToString();
        } else if (!input_tensor->shape().empty()) {
          const auto &shape_str =
            std::accumulate(std::next(input_tensor->shape().begin()), input_tensor->shape().end(),
                            std::to_string(input_tensor->shape()[0]),
                            [](std::string cur, size_t n) { return cur.append("-").append(std::to_string(n)); });
          graph_info += shape_str;
        }
      }

      graph_info += GetNumString(input_tensor->data_type());
      // Distinguish inputs that have the same shape but a different dtype or format.
      auto tensor_addr = input_tensor->device_address();
      if (tensor_addr != nullptr && !has_hidden_side_effect) {
        auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
        MS_EXCEPTION_IF_NULL(p_address);
        graph_info += p_address->format();
        graph_info += p_address->padding_type();
      }

      if (op_info.input_types[index] == InputType::kConstant || depend_list.find(index) != depend_list.end()) {
        graph_info += common::AnfAlgo::GetTensorValueString(input_tensor);
      }
    } else {
      graph_info += value->ToString();
    }

    graph_info += "_";
  }

  graph_info += std::to_string(op_info.stream_id);

  // Operators with hidden side effects.
  if (has_hidden_side_effect) {
    (void)graph_info.append("r_").append(std::to_string(op_info.py_prim_id_)).append("_");
  }

#ifdef ENABLE_D
  // Ascend special info.
  graph_info = GetGraphInfoForAscendSpecial(op_info, op_prim, graph_info);
#endif

  return graph_info;
}

void OpCompiler::ClearOpCache(const GraphInfo &graph_info) { (void)op_compiler_infos_.erase(graph_info); }

void OpCompiler::ClearAllCache() { op_compiler_infos_.clear(); }

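// For every ref pair in the graph, make the ref node's output reuse the device address of the input it
// refers to, so both sides of the ref share the same memory.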
void OpCompiler::UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto ref_node_map = graph->GetRefMap();
  for (const auto &[output_pair, input_pair] : ref_node_map) {
    const auto &[ref_node, output_index] = output_pair;
    const auto &[input_node, input_node_output_index] = input_pair;
    if (!AnfAlgo::OutputAddrExist(input_node, input_node_output_index, false)) {
      MS_EXCEPTION_IF_NULL(input_node);
      MS_LOG(WARNING) << "Output address does not exist, node " << input_node->fullname_with_scope() << " index "
                      << input_node_output_index;
      continue;
    }
    auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index, false);
    AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
  }
}
}  // namespace pynative
}  // namespace mindspore