1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"

#include <memory>
#include <mutex>

#include "hccl/hcom.h"
#include "common/opskernel/ge_task_info.h"
#include "utils/log_adapter.h"
#include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
#include "runtime/hccl_adapter/hccl_adapter.h"
25
namespace {
// File name of the HCOM graph adaptor shared library. Only the bare name is given, so it is expected
// to be located through RPATH or LD_LIBRARY_PATH (e.g. /usr/local/Ascend/fwkacllib/lib64/).
// NOTE(review): this constant is not referenced anywhere else in this file — confirm it is still needed.
constexpr char kHcomGraphAdaptorPath[] = "libhcom_graph_adaptor.so";
}  // namespace
30
31 namespace mindspore {
32 namespace device {
33 namespace ascend {
UpdateArgs()34 void HcclDynamicKernel::UpdateArgs() {
35 if (!is_dynamic_shape_) {
36 MS_LOG(INFO) << "Not Dynamic Shape";
37 return;
38 }
39 auto cnode = cnode_ptr_.lock();
40 MS_EXCEPTION_IF_NULL(cnode);
41 MS_LOG(INFO) << "Start to UpdateArgs. Node info: " << cnode->DebugString();
42 auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
43 MS_EXCEPTION_IF_NULL(kernel_mod);
44 // Update input, output, count
45 AddressPtrList kernel_inputs;
46 AddressPtrList kernel_workspaces;
47 AddressPtrList kernel_outputs;
48 KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
49 if (kernel_inputs.empty() || kernel_outputs.empty()) {
50 MS_LOG(EXCEPTION) << "Inputs or outputs is empty. Node info: " << cnode->DebugString();
51 }
52 auto input0 = kernel_inputs.at(0);
53 auto output0 = kernel_outputs.at(0);
54 MS_EXCEPTION_IF_NULL(input0);
55 MS_EXCEPTION_IF_NULL(output0);
56
57 // Update Hccl input and output
58 input_ptr_ = input0->addr;
59 output_ptr_ = output0->addr;
60
61 std::vector<std::vector<size_t>> hccl_kernel_input_shape_list;
62 if (!HcomUtil::GetKernelInputShape(cnode, &hccl_kernel_input_shape_list)) {
63 MS_LOG(EXCEPTION) << "GetKernelInputShape fail! Node info: " << cnode->DebugString();
64 }
65
66 std::vector<HcclDataType> hccl_data_type_list;
67 if (!HcomUtil::GetHcomDataType(cnode, &hccl_data_type_list)) {
68 MS_LOG(EXCEPTION) << "GetHcomDataType fail! Node info: " << cnode->DebugString();
69 }
70
71 // Update Hccl count
72 if (!HcomUtil::GetHcomCount(cnode, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
73 MS_LOG(EXCEPTION) << "GetHcomCount fail! Node info: " << cnode->DebugString();
74 }
75 MS_LOG(INFO) << "Update Hccl count:" << count_;
76 }
77
StaticShapeExecute()78 void HcclDynamicKernel::StaticShapeExecute() {
79 auto cnode = cnode_ptr_.lock();
80 MS_EXCEPTION_IF_NULL(cnode);
81 auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
82 MS_EXCEPTION_IF_NULL(kernel_mod);
83 AddressPtrList kernel_inputs;
84 AddressPtrList kernel_workspaces;
85 AddressPtrList kernel_outputs;
86 KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
87 kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
88 }
89
Execute()90 void HcclDynamicKernel::Execute() {
91 auto cnode = cnode_ptr_.lock();
92 MS_EXCEPTION_IF_NULL(cnode);
93 MS_LOG(INFO) << "Start Execute: " << cnode->DebugString();
94 ::HcomOperation op_info;
95 op_info.hcclType = hccl_type_;
96 op_info.inputPtr = input_ptr_;
97 op_info.outputPtr = output_ptr_;
98 op_info.dataType = static_cast<HcclDataType>(data_type_);
99 op_info.opType = static_cast<HcclReduceOp>(op_type_);
100 op_info.root = IntToUint(root_);
101 op_info.count = count_;
102
103 auto callback = [this](HcclResult status) {
104 if (status != HCCL_SUCCESS) {
105 MS_LOG(ERROR) << "HcomExcutorInitialize failed, ret:" << status;
106 }
107 std::lock_guard<std::mutex> lock(this->hccl_mutex_);
108 this->cond_.notify_all();
109 MS_LOG(INFO) << "hccl callback success.";
110 };
111
112 auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclExecEnqueueOp(op_info, callback);
113 if (hccl_ret != HCCL_SUCCESS) {
114 MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed, node info: " << cnode->DebugString();
115 }
116
117 std::unique_lock<std::mutex> ulock(hccl_mutex_);
118 cond_.wait(ulock);
119 MS_LOG(INFO) << "Execute " << cnode->DebugString() << " success";
120 }
121
PostExecute()122 void HcclDynamicKernel::PostExecute() {}
123 } // namespace ascend
124 } // namespace device
125 } // namespace mindspore
126