/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"

#include "hccl/hcom.h"
#include "common/opskernel/ge_task_info.h"
#include "utils/log_adapter.h"
#include "runtime/device/kernel_runtime.h"
#include "backend/kernel_compiler/hccl/hcom_util.h"
#include "runtime/hccl_adapter/hccl_adapter.h"

namespace {
// Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/)
constexpr auto kHcomGraphAdaptorPath = "libhcom_graph_adaptor.so";
}  // namespace

namespace mindspore {
namespace device {
namespace ascend {
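// Refresh the launch arguments of a dynamic-shape HCCL kernel (device addresses of the first
// input/output and the element count) so that the subsequent Execute() uses the newly inferred shapes.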
void HcclDynamicKernel::UpdateArgs() {
  if (!is_dynamic_shape_) {
    MS_LOG(INFO) << "Not Dynamic Shape";
    return;
  }
  auto cnode = cnode_ptr_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(INFO) << "Start to UpdateArgs. Node info: " << cnode->DebugString();
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  // Update input, output, count
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  if (kernel_inputs.empty() || kernel_outputs.empty()) {
    MS_LOG(EXCEPTION) << "Inputs or outputs is empty. Node info: " << cnode->DebugString();
  }
  auto input0 = kernel_inputs.at(0);
  auto output0 = kernel_outputs.at(0);
  MS_EXCEPTION_IF_NULL(input0);
  MS_EXCEPTION_IF_NULL(output0);

  // Update Hccl input and output
  input_ptr_ = input0->addr;
  output_ptr_ = output0->addr;

  std::vector<std::vector<size_t>> hccl_kernel_input_shape_list;
  if (!HcomUtil::GetKernelInputShape(cnode, &hccl_kernel_input_shape_list)) {
    MS_LOG(EXCEPTION) << "GetKernelInputShape fail! Node info: " << cnode->DebugString();
  }

  std::vector<HcclDataType> hccl_data_type_list;
  if (!HcomUtil::GetHcomDataType(cnode, &hccl_data_type_list)) {
    MS_LOG(EXCEPTION) << "GetHcomDataType fail! Node info: " << cnode->DebugString();
  }

  // Update Hccl count
  if (!HcomUtil::GetHcomCount(cnode, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
    MS_LOG(EXCEPTION) << "GetHcomCount fail! Node info: " << cnode->DebugString();
  }
  MS_LOG(INFO) << "Update Hccl count:" << count_;
}

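// Static-shape HCCL kernels are launched directly through their KernelMod on the given stream,
// bypassing the HCCL executor queue used by the dynamic-shape path below.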
void HcclDynamicKernel::StaticShapeExecute() {
  auto cnode = cnode_ptr_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
  MS_EXCEPTION_IF_NULL(kernel_mod);
  AddressPtrList kernel_inputs;
  AddressPtrList kernel_workspaces;
  AddressPtrList kernel_outputs;
  KernelRuntime::GenLaunchArgs(*kernel_mod, cnode, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
  kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
}

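// Build an HcomOperation from the cached arguments, enqueue it to the HCCL executor, and block
// on a condition variable until the completion callback reports the result.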
void HcclDynamicKernel::Execute() {
  auto cnode = cnode_ptr_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  MS_LOG(INFO) << "Start Execute: " << cnode->DebugString();
  ::HcomOperation op_info;
  op_info.hcclType = hccl_type_;
  op_info.inputPtr = input_ptr_;
  op_info.outputPtr = output_ptr_;
  op_info.dataType = static_cast<HcclDataType>(data_type_);
  op_info.opType = static_cast<HcclReduceOp>(op_type_);
  op_info.root = IntToUint(root_);
  op_info.count = count_;

  // Completion callback: log a failure (if any) and wake up the waiting Execute() thread.
  auto callback = [this](HcclResult status) {
    if (status != HCCL_SUCCESS) {
      MS_LOG(ERROR) << "HcclExecEnqueueOp callback failed, ret: " << status;
    }
    std::lock_guard<std::mutex> lock(this->hccl_mutex_);
    this->cond_.notify_all();
    MS_LOG(INFO) << "Hccl callback success.";
  };

  auto hccl_ret = hccl::HcclAdapter::GetInstance().HcclExecEnqueueOp(op_info, callback);
  if (hccl_ret != HCCL_SUCCESS) {
    MS_LOG(EXCEPTION) << "Call HcclExecEnqueueOp failed, node info: " << cnode->DebugString();
  }

  std::unique_lock<std::mutex> ulock(hccl_mutex_);
  cond_.wait(ulock);
  MS_LOG(INFO) << "Execute " << cnode->DebugString() << " success";
}

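// No post-processing is needed: the collective op has already completed when Execute() returns.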
void HcclDynamicKernel::PostExecute() {}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore