/**
 * Copyright 2022-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/pynative/op_compiler.h"

#include <memory>
#include <algorithm>
#include <vector>
#include <unordered_set>
#include "mindspore/core/ops/op_utils.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "ops/nn_op_name.h"
#include "ops/conv_pool_op_name.h"
#include "runtime/pynative/op_executor.h"
#include "runtime/pynative/op_runtime_info.h"
#include "runtime/device/device_address_utils.h"
#include "backend/common/optimizer/common_backend_optimization.h"
#ifdef ENABLE_D
#include "transform/acl_ir/acl_adapter_info.h"
#endif

namespace mindspore {
using runtime::DeviceAddressUtils;
namespace pynative {
namespace {
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
mindspore::HashSet<std::string> kExcludedAttr = {"input_names", "output_names", "IsFeatureMapOutput",
                                                 "IsFeatureMapInputList", "pri_format"};
std::vector<std::string> kNumStrCache;

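// Look up the cached decimal string for n; fall back to std::to_string for values beyond the cache.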
inline std::string GetNumString(int n) {
  if (n >= static_cast<int>(kNumStrCache.size())) {
    return std::to_string(n);
  }

  return kNumStrCache[n];
}

void UpdateRefInfoBeforeCreateKernel(const session::BackendOpRunInfoPtr &op_run_info, const KernelGraphPtr &graph) {
  // Graph building and kernel creation are asynchronous in PyNative mode, and ref info is bound to the kernel.
  // So the ref info is needed here to generate output addresses before the kernel is created.
  if (op_run_info->base_op_run_info.device_target != kCPUDevice &&
      op_run_info->base_op_run_info.device_target != kGPUDevice) {
    // Only the Ascend ref mode differs from CPU and GPU.
    return;
  }

  AnfAlgo::AddOutInRefToGraph(graph);
}

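// Create device addresses for parameters, value nodes and kernel outputs of the single-op graph.
// Workspace device addresses are created later in KernelBuild, after the kernels have been created.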
void CreateDeviceAddressWithoutWorkspace(const KernelGraphPtr &graph, const DeviceContext *device_context,
                                         bool is_gradient_out) {
  DeviceAddressUtils::CreateParameterDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateValueNodeDeviceAddress(device_context, graph);
  DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
  DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
  DeviceAddressUtils::UpdateDeviceAddressForRefNode(graph);
}

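// Mark the input edges whose addresses are ignored at launch time, so the host-to-device copy
// for those inputs can be skipped.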
void SetIgnoreSyncHostToDeviceList(const SimpleGraphPtr &simple_graph) {
  const auto &single_ops = simple_graph->single_ops_;
  for (const auto &single_op : single_ops) {
    const auto &kernel = single_op->kernel_;
    const auto &edges = single_op->inputs_;

    auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
    MS_EXCEPTION_IF_NULL(kernel_mod);
    std::vector<size_t> ignore_input_index_list = kernel_mod->GetLaunchIgnoredInputAddressIdx();
    for (size_t index : ignore_input_index_list) {
      // Some inputs may have been converted to attributes, or the input size may be wrong.
      // This behavior is incorrect, but it does exist in current kernels
      // and needs to be rectified by the developers of those kernels.
      if (index >= edges.size()) {
        MS_LOG(INFO) << simple_graph->name_ << " ignore input index is " << index << ", but total input num is "
                     << edges.size();
        continue;
      }
      edges[index]->ignore_h2d_ = true;
      MS_LOG(INFO) << "For graph " << simple_graph->name_ << ", ignore host-to-device copy for input " << index;
    }
  }
}
}  // namespace

OpCompiler::OpCompiler() {
  session_ = session::SessionFactory::Get().Create(kSessionBasic);
  for (size_t i = 0; i < kNumberTypeEnd; i++) {
    (void)kNumStrCache.emplace_back(std::to_string(i));
  }
}

OpCompiler &OpCompiler::GetInstance() {
  static OpCompiler instance;
  return instance;
}

void OpCompilerInfo::UpdateStatus(bool ready) { ready_.store(ready, std::memory_order_release); }

void OpCompilerInfo::WaitReady() const {
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kWaitTaskFinish,
                                     graph_info_, true);
  while (!ready_.load(std::memory_order_acquire)) {
    std::this_thread::yield();
  }
}

bool OpCompiler::IsInvalidInferResultOp(const std::string &op_name) const {
  static const std::unordered_set<std::string> kInvalidInferResultOp = {kDropoutOpName, kMaxPoolWithArgmaxOpName,
                                                                        kLSTMOpName};
  return kInvalidInferResultOp.find(op_name) != kInvalidInferResultOp.end();
}

KernelGraphPtr OpCompiler::GenerateKernelGraph(const session::BackendOpRunInfoPtr &op_run_info,
                                               const device::DeviceContext *device_context) const {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(op_run_info->op_prim);
  KernelGraphPtr graph = session_->ConstructSingleOpGraph(
    op_run_info, op_run_info->base_op_run_info.expanded_input_values, op_run_info->base_op_run_info.input_types);
  graph->set_is_from_single_op(true);
  return graph;
}

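// Propagate the given stream id to every kernel in the execution order and to its input nodes.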
void OpCompiler::AssignStreamIdForSingleOpGraph(const KernelGraphPtr &graph, uint32_t stream_id) {
  MS_EXCEPTION_IF_NULL(graph);

  for (const auto &cnode : graph->execution_order()) {
    MS_EXCEPTION_IF_NULL(cnode);
    AnfAlgo::SetStreamId(stream_id, cnode.get());
    size_t input_num = common::AnfAlgo::GetInputTensorNum(cnode);
    for (size_t index = 0; index < input_num; ++index) {
      const auto &input_node = common::AnfAlgo::GetInputNode(cnode, index);
      AnfAlgo::SetStreamId(stream_id, input_node.get());
    }
  }
}

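// Compile a single-op kernel graph for the given op, or return the cached OpCompilerInfo when the
// generated graph key hits the cache.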
OpCompilerInfoPtr OpCompiler::Compile(const session::BackendOpRunInfoPtr &op_run_info, bool *single_op_cache_hit,
                                      const std::string &device_name, const uint32_t &device_id) {
  MS_EXCEPTION_IF_NULL(op_run_info);
  const auto &graph_info = GetSingleOpGraphInfo(op_run_info->base_op_run_info, op_run_info->op_prim);
  const auto &iter = op_compiler_infos_.find(graph_info);
  // Check if the graph cache exists.
  if (iter != op_compiler_infos_.end()) {
    const auto &op_compiler_info = iter->second;
    MS_EXCEPTION_IF_NULL(op_compiler_info);
    *single_op_cache_hit = true;
    return op_compiler_info;
  }

  MS_LOG(INFO) << "Run Op cache miss " << graph_info;
  runtime::ProfilerRecorder profiler(runtime::ProfilerModule::kPynative, runtime::ProfilerEvent::kPyNativeOpCompile,
                                     graph_info, true);

  *single_op_cache_hit = false;
  // Generate kernel graph.
  MS_EXCEPTION_IF_NULL(session_);
  const auto &device_context =
    device::DeviceContextManager::GetInstance().GetOrCreateDeviceContext({device_name, device_id});
  MS_EXCEPTION_IF_NULL(device_context);
  device_context->Initialize();
  py::gil_scoped_acquire acquire_gil;
  KernelGraphPtr graph = GenerateKernelGraph(op_run_info, device_context);
  MS_EXCEPTION_IF_NULL(graph);

  graph->set_run_mode(device::RunMode::kKernelMode);
  bool use_dynamic_shape_process = op_run_info->base_op_run_info.use_dynamic_shape_process;
  auto kernel_executor = device_context->GetKernelExecutor(use_dynamic_shape_process);
  MS_EXCEPTION_IF_NULL(kernel_executor);

  opt::OptimizationWithoutBackend(graph);
  // Unify the MindIR; this must run before the graph optimization.
  kernel_executor->AddMindIRPass(graph);

  // Select kernels and optimize the graph.
  kernel_executor->OptimizeGraph(graph);

  UpdateRefInfoBeforeCreateKernel(op_run_info, graph);
  AssignStreamIdForSingleOpGraph(graph, op_run_info->base_op_run_info.stream_id);
  // Create device addresses for all anf nodes of the graph.
  CreateDeviceAddressWithoutWorkspace(graph, device_context, op_run_info->is_gradient_out);

  auto output_nodes = graph->outputs();
  std::vector<KernelWithIndex> outputs_with_index;
  std::vector<size_t> outputs_tensor_num;
  std::vector<std::string> outputs_padding_type;
  bool need_refresh_abstract = IsInvalidInferResultOp(op_run_info->base_op_run_info.op_name);
  for (auto &node : output_nodes) {
    MS_EXCEPTION_IF_NULL(node);
    const auto &output_with_index = common::AnfAlgo::VisitKernel(node, 0);
    (void)outputs_with_index.emplace_back(output_with_index);
    (void)outputs_tensor_num.emplace_back(AnfAlgo::GetOutputTensorNum(output_with_index.first));
    const auto &padding_type = (device_context->GetDeviceType() == device::DeviceType::kAscend
                                  ? AnfAlgo::GetOutputReshapeType(output_with_index.first, output_with_index.second)
                                  : "");
    (void)outputs_padding_type.emplace_back(padding_type);

    MS_EXCEPTION_IF_NULL(output_with_index.first);
    const auto &abstract = output_with_index.first->abstract();
    MS_EXCEPTION_IF_NULL(abstract);
    const auto &shape = abstract->BuildShape();
    MS_EXCEPTION_IF_NULL(shape);
    if (shape->IsDynamic()) {
      need_refresh_abstract = true;
    }
  }
  AnfAlgo::UpdateGraphValidRefPair(graph);
  UpdateRefNodeOutputDeviceAddress(graph);
  auto simple_graph = IrConverter::Convert(op_run_info->base_op_run_info.op_name, graph, device_context);
  MS_LOG(DEBUG) << "Generate new IR " << simple_graph->DebugInfo().dump();

  auto op_compiler_info = std::make_shared<OpCompilerInfo>(
    graph_info, graph->graph_id(), graph, device_context, op_run_info->base_op_run_info.need_earse_cache,
    need_refresh_abstract, outputs_with_index, outputs_tensor_num, outputs_padding_type, std::move(simple_graph));

  graph->set_graph_info(graph_info);
  op_compiler_infos_[graph_info] = op_compiler_info;
  return op_compiler_info;
}

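// Create kernels for all nodes of the compiled graph, allocate workspace device addresses and cache
// the per-op runtime info.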
void OpCompiler::KernelBuild(const OpCompilerInfoPtr &op_compiler_info, const DeviceContext *device_context,
                             bool is_dynamic) const {
  MS_EXCEPTION_IF_NULL(device_context);
  MS_EXCEPTION_IF_NULL(device_context->device_res_manager_);
  // The compilation task may run in a child thread that has not yet set rt_context,
  // but loading AICPU.so needs to use rt_context.
  if (!device_context->device_res_manager_->BindDeviceToCurrentThread(true)) {
    MS_LOG(EXCEPTION) << "Bind device failed";
  }
  std::vector<CNodePtr> node_to_build;
  const auto &graph = op_compiler_info->graph_;
  MS_EXCEPTION_IF_NULL(graph);
  const auto &nodes = graph->execution_order();
  (void)std::copy(nodes.begin(), nodes.end(), std::back_inserter(node_to_build));
  // Kernel build.
  auto kernel_executor = device_context->GetKernelExecutor(is_dynamic);
  MS_EXCEPTION_IF_NULL(kernel_executor);
  kernel_executor->CreateKernel(node_to_build);
  kernel_executor->PreprocessBeforeRun(graph);
  DeviceAddressUtils::CreateKernelWorkspaceDeviceAddress(device_context, graph);
  // Need to execute after PreprocessBeforeRun.
  runtime::OpRuntimeInfo::CacheGraphOpRuntimeInfo(graph);

  // After the kernels are generated.
  SetIgnoreSyncHostToDeviceList(op_compiler_info->simple_graph_);
}

#ifdef ENABLE_D
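// On Ascend, append the format information chosen by the ACL adapter's input/output selectors to the
// graph key, so operators with the same shapes and dtypes but different selected formats are not
// cached under the same key.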
std::string GetGraphInfoForAscendSpecial(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim,
                                         const std::string &graph_info) {
  std::string ascend_special_info = graph_info;
  MS_EXCEPTION_IF_NULL(op_prim);
  auto op_name = op_prim->name();
  auto ms_context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(ms_context);
  if (ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice &&
      transform::AclAdapterManager::GetInstance().CheckAclAdapter(op_name)) {
    auto acl_info = transform::AclAdapterManager::GetInstance().GetOpInfo(op_name);
    if (!acl_info.input_selector().empty() || acl_info.output_selector() != nullptr) {
      if (op_info.expanded_input_values.empty()) {
        return ascend_special_info;
      }
      TypeId first_dtype = TypeId::kTypeUnknown;
      std::vector<ShapeVector> input_shapes;
      (void)std::transform(op_info.expanded_input_values.begin(), op_info.expanded_input_values.end(),
                           std::back_inserter(input_shapes), [&first_dtype](const ValuePtr &value) -> ShapeVector {
                             auto tensor = value->cast<tensor::BaseTensorPtr>();
                             if (tensor != nullptr) {
                               if (first_dtype == TypeId::kTypeUnknown) {
                                 first_dtype = tensor->data_type();
                               }
                               return tensor->shape();
                             }
                             return {};
                           });

      auto in_func_map = acl_info.input_selector();
      for (auto [index, in_func] : in_func_map) {
        MS_EXCEPTION_IF_NULL(in_func);
        auto tensor = op_info.expanded_input_values[index]->cast<tensor::BaseTensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        ascend_special_info += in_func(tensor->data_type(), input_shapes);
      }

      auto out_func = acl_info.output_selector();
      if (out_func != nullptr) {
        auto tensor = op_info.expanded_input_values[0]->cast<tensor::BaseTensorPtr>();
        MS_EXCEPTION_IF_NULL(tensor);
        auto out_format = out_func(tensor->data_type(), input_shapes);
        ascend_special_info += out_format;
      }
    }
  }
  return ascend_special_info;
}
#endif

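// Get the indices of inputs whose values are depended on at infer time. When the op has dynamic
// inputs (dyn_input_sizes), the original indices are rebased onto the expanded input list.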
inline std::set<int64_t> GetDependList(const pynative::BaseOpRunInfo &op_info, const PrimitivePtr &op_prim) {
  auto depend_list = mindspore::ops::GetInputDependValueList(op_prim);
  if (!op_info.dyn_input_sizes.empty()) {
    auto list_tmp = depend_list;
    depend_list.clear();
    for (const auto item : list_tmp) {
      int64_t bias = 0;
      for (int64_t i = 0; i < item; i++) {
        auto idx = static_cast<size_t>(i);
        if (op_info.dyn_input_sizes[idx] == -1) {
          bias += 1;
        } else {
          bias += op_info.dyn_input_sizes[idx];
        }
      }
      (void)depend_list.emplace(bias);
      MS_LOG(DEBUG) << "Adjust depend list from " << item << " to " << bias << " for op: " << op_prim->name();
    }
  }

  return depend_list;
}

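// Build the cache key of a single-op graph from the device target, dynamic-shape flag, op name,
// attributes, input shapes/dtypes/formats, value-depend inputs and stream id.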
std::string OpCompiler::GetSingleOpGraphInfo(const pynative::BaseOpRunInfo &op_info,
                                             const PrimitivePtr &op_prim) const {
  MS_EXCEPTION_IF_NULL(op_prim);
  if (op_info.expanded_input_values.size() != op_info.input_types.size()) {
    MS_LOG(EXCEPTION) << "Input tensors size " << op_info.expanded_input_values.size()
                      << " should be equal to tensors mask size " << op_info.input_types.size();
  }
  std::string graph_info = op_info.device_target;

  if (op_info.use_dynamic_shape_process) {
    graph_info += "_1_";
  } else {
    graph_info += "_0_";
  }
  auto op_name = op_prim->name();
  graph_info += op_name;
  bool has_hidden_side_effect;
  {
    PrimitiveReadLock read_lock(op_prim->shared_mutex());
    if (op_info.need_earse_cache) {
      return graph_info;
    }
    has_hidden_side_effect = op_prim->HasAttr(GRAPH_FLAG_SIDE_EFFECT_HIDDEN);
    // The values of the attributes affect operator selection.
    const auto &attr_map = op_prim->attrs();
    (void)std::for_each(attr_map.begin(), attr_map.end(), [&graph_info](const auto &element) {
      if (kExcludedAttr.find(element.first) != kExcludedAttr.end()) {
        return;
      }
      MS_EXCEPTION_IF_NULL(element.second);
      graph_info.append(element.second->ToString());
    });
  }

  const auto &depend_list = GetDependList(op_info, op_prim);
  for (size_t index = 0; index < op_info.expanded_input_values.size(); ++index) {
    auto const &value = op_info.expanded_input_values[index];
    if (value->isa<tensor::BaseTensor>()) {
      const auto &input_tensor = value->cast<tensor::BaseTensorPtr>();
      MS_EXCEPTION_IF_NULL(input_tensor);
      if (op_info.use_dynamic_shape_process) {
        graph_info += GetNumString(static_cast<int>(input_tensor->shape().size()));
      } else {
        if (input_tensor->base_shape_ptr() != nullptr) {
          graph_info += input_tensor->base_shape_ptr()->ToString();
        } else if (!input_tensor->shape().empty()) {
          const auto &shape_str =
            std::accumulate(std::next(input_tensor->shape().begin()), input_tensor->shape().end(),
                            std::to_string(input_tensor->shape()[0]),
                            [](std::string cur, size_t n) { return cur.append("-").append(std::to_string(n)); });
          graph_info += shape_str;
        }
      }

      graph_info += GetNumString(input_tensor->data_type());
      // Distinguish the case where shapes are the same but dtype or format differs.
      auto tensor_addr = input_tensor->device_address();
      if (tensor_addr != nullptr && !has_hidden_side_effect) {
        auto p_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor_addr);
        MS_EXCEPTION_IF_NULL(p_address);
        graph_info += p_address->format();
        graph_info += p_address->padding_type();
      }

      if (op_info.input_types[index] == InputType::kConstant || depend_list.find(index) != depend_list.end()) {
        graph_info += common::AnfAlgo::GetTensorValueString(input_tensor);
      }
    } else {
      graph_info += value->ToString();
    }

    graph_info += "_";
  }

  graph_info += std::to_string(op_info.stream_id);

  // Operator with hidden side effect.
  if (has_hidden_side_effect) {
    (void)graph_info.append("r_").append(std::to_string(op_info.py_prim_id_)).append("_");
  }

#ifdef ENABLE_D
  // Ascend special info.
  graph_info = GetGraphInfoForAscendSpecial(op_info, op_prim, graph_info);
#endif

  return graph_info;
}

void OpCompiler::ClearOpCache(const GraphInfo &graph_info) { (void)op_compiler_infos_.erase(graph_info); }

void OpCompiler::ClearAllCache() { op_compiler_infos_.clear(); }

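// For every ref pair in the graph, let the ref-node output reuse the device address of the
// corresponding input, so both sides of the ref pair share the same memory.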
void OpCompiler::UpdateRefNodeOutputDeviceAddress(const KernelGraphPtr &graph) {
  MS_EXCEPTION_IF_NULL(graph);
  auto ref_node_map = graph->GetRefMap();
  for (const auto &[output_pair, input_pair] : ref_node_map) {
    const auto &[ref_node, output_index] = output_pair;
    const auto &[input_node, input_node_output_index] = input_pair;
    if (!AnfAlgo::OutputAddrExist(input_node, input_node_output_index, false)) {
      MS_EXCEPTION_IF_NULL(input_node);
      MS_LOG(WARNING) << "Output address does not exist for node " << input_node->fullname_with_scope() << " index "
                      << input_node_output_index;
      continue;
    }
    auto input_addr = AnfAlgo::GetMutableOutputAddr(input_node, input_node_output_index, false);
    AnfAlgo::SetOutputAddr(input_addr, output_index, ref_node.get());
  }
}
}  // namespace pynative
}  // namespace mindspore