1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/hccl/hcom_util.h"
18 #include <memory>
19 #include "backend/kernel_compiler/common_utils.h"
20 #include "backend/session/anf_runtime_algorithm.h"
21 #include "utils/ms_context.h"
22 #include "utils/utils.h"
23
24 namespace mindspore {
25 namespace {
IsPyNativeMode()26 bool IsPyNativeMode() {
27 auto ms_context = MsContext::GetInstance();
28 MS_EXCEPTION_IF_NULL(ms_context);
29 return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
30 }
31 } // namespace
32
GetKernelInputShape(const AnfNodePtr & anf_node,vector<vector<size_t>> * hccl_kernel_intput_shape_list)33 bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
34 MS_EXCEPTION_IF_NULL(anf_node);
35 MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
36 size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
37 for (size_t i = 0; i < input_num; ++i) {
38 std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
39 hccl_kernel_intput_shape_list->emplace_back(shape_i);
40 }
41
42 return true;
43 }
44
GetKernelOutputShape(const AnfNodePtr & anf_node,vector<vector<size_t>> * hccl_kernel_output_shape_list)45 bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
46 MS_EXCEPTION_IF_NULL(anf_node);
47 MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
48 size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
49 for (size_t i = 0; i < output_num; ++i) {
50 std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
51 hccl_kernel_output_shape_list->emplace_back(shape_i);
52 }
53
54 return true;
55 }
56
ConvertHcclType(TypeId type_id)57 ::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
58 auto iter = kConstOpHcomDataTypeMap.find(type_id);
59 if (iter == kConstOpHcomDataTypeMap.end()) {
60 MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_id;
61 }
62 return iter->second;
63 }
64
GetHcomDataType(const AnfNodePtr & anf_node,vector<HcclDataType> * data_type_list)65 bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
66 MS_EXCEPTION_IF_NULL(anf_node);
67 MS_EXCEPTION_IF_NULL(data_type_list);
68 size_t tensor_num = AnfAlgo::GetInputTensorNum(anf_node);
69 auto op_name = AnfAlgo::GetCNodeName(anf_node);
70 if (op_name == kReceiveOpName) {
71 tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
72 }
73 for (size_t i = 0; i < tensor_num; ++i) {
74 TypeId type_ptr;
75 if (op_name == kReceiveOpName) {
76 type_ptr = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
77 } else {
78 type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i);
79 }
80 data_type_list->emplace_back(ConvertHcclType(type_ptr));
81 }
82 if (!data_type_list->empty()) {
83 if (std::any_of(data_type_list->begin(), data_type_list->end(),
84 [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
85 MS_LOG(ERROR) << "hccl have different data type";
86 return false;
87 }
88 }
89 return true;
90 }
91
GetHcclOpSize(const HcclDataType & data_type,const vector<size_t> & shape,size_t * size)92 bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size) {
93 MS_EXCEPTION_IF_NULL(size);
94 size_t tmp_size = 1;
95 uint32_t type_size = 4;
96 for (size_t i = 0; i < shape.size(); i++) {
97 tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]);
98 }
99
100 if (!GetHcomTypeSize(data_type, &type_size)) {
101 return false;
102 }
103
104 *size = SizetMulWithOverflowCheck(tmp_size, type_size);
105
106 MS_LOG(INFO) << "size[" << *size << "]";
107 return true;
108 }
109
GetHcomTypeSize(const HcclDataType & data_type,uint32_t * size)110 bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) {
111 MS_EXCEPTION_IF_NULL(size);
112 auto iter = kConstOpHcomDataTypeSizeMap.find(data_type);
113 if (iter == kConstOpHcomDataTypeSizeMap.end()) {
114 MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!";
115 return false;
116 }
117 *size = iter->second;
118 return true;
119 }
120
GetHcomCount(const AnfNodePtr & anf_node,const vector<HcclDataType> & data_type_list,const vector<vector<size_t>> & shape_list,uint64_t * total_count)121 bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list,
122 const vector<vector<size_t>> &shape_list, uint64_t *total_count) {
123 MS_EXCEPTION_IF_NULL(anf_node);
124 MS_EXCEPTION_IF_NULL(total_count);
125 const uint32_t align_size = 512;
126 const uint32_t filled_size = 32;
127 uint64_t total_size = 0;
128 uint64_t block_size;
129 size_t input_size;
130 uint32_t type_size = 4;
131 size_t size = AnfAlgo::GetInputTensorNum(anf_node);
132 auto cnode = anf_node->cast<CNodePtr>();
133 MS_EXCEPTION_IF_NULL(cnode);
134 if (AnfAlgo::GetCNodeName(anf_node) == kReceiveOpName) {
135 size = AnfAlgo::GetOutputTensorNum(anf_node);
136 }
137 for (size_t i = 0; i < size; ++i) {
138 if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
139 return false;
140 }
141
142 if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) {
143 MS_LOG(ERROR) << "Get GetHcclOpSize failed";
144 return false;
145 }
146
147 if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) {
148 int64_t rank_size;
149 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
150 MS_EXCEPTION_IF_NULL(primitive);
151 if (primitive->GetAttr(kAttrRankSize) != nullptr) {
152 rank_size = GetValue<int64_t>(primitive->GetAttr(kAttrRankSize));
153 } else {
154 MS_LOG(ERROR) << "Get rank size failed";
155 return false;
156 }
157 size_t actual_input_size = input_size;
158 if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
159 actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
160 }
161 block_size = static_cast<uint64_t>(actual_input_size / LongToSize(rank_size));
162 total_size = total_size + block_size;
163 } else {
164 if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
165 if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion) &&
166 AnfAlgo::GetInputTensorNum(anf_node) > 1) {
167 block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
168 } else {
169 block_size = input_size;
170 }
171 } else {
172 block_size =
173 IsPyNativeMode() ? input_size : (input_size + align_size - 1 + filled_size) / align_size * align_size;
174 }
175 total_size = total_size + block_size;
176 }
177 }
178
179 if (type_size == 0 || total_size % type_size != 0) {
180 MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!";
181 return false;
182 }
183 *total_count = total_size / type_size;
184 return true;
185 }
186
GetHcomOperationType(const AnfNodePtr & anf_node,HcclReduceOp * op_type)187 bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type) {
188 MS_EXCEPTION_IF_NULL(anf_node);
189 MS_EXCEPTION_IF_NULL(op_type);
190 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
191 MS_EXCEPTION_IF_NULL(primitive);
192 if (primitive->GetAttr(kAttrOp) == nullptr) {
193 MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!";
194 return false;
195 }
196 auto hcom_op_type = GetValue<std::string>(primitive->GetAttr(kAttrOp));
197 if (hcom_op_type == "min") {
198 *op_type = HCCL_REDUCE_MIN;
199 } else if (hcom_op_type == "max") {
200 *op_type = HCCL_REDUCE_MAX;
201 } else if (hcom_op_type == "prod") {
202 *op_type = HCCL_REDUCE_PROD;
203 } else if (hcom_op_type == "sum") {
204 *op_type = HCCL_REDUCE_SUM;
205 } else {
206 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
207 return false;
208 }
209 return true;
210 }
211
GetHcomRootId(const AnfNodePtr & anf_node,uint32_t * root_id)212 bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
213 MS_EXCEPTION_IF_NULL(anf_node);
214 MS_EXCEPTION_IF_NULL(root_id);
215 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
216 MS_EXCEPTION_IF_NULL(primitive);
217 if (primitive->GetAttr(kAttrRootRank) != nullptr) {
218 *root_id = (uint32_t)GetValue<int64_t>(primitive->GetAttr(kAttrRootRank));
219 } else {
220 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
221 return false;
222 }
223 return true;
224 }
225
GetHcomSrcRank(const AnfNodePtr & anf_node,uint32_t * src_rank)226 bool HcomUtil::GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank) {
227 MS_EXCEPTION_IF_NULL(anf_node);
228 MS_EXCEPTION_IF_NULL(src_rank);
229 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
230 MS_EXCEPTION_IF_NULL(primitive);
231 if (primitive->GetAttr("src_rank") != nullptr) {
232 *src_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("src_rank")));
233 } else {
234 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRC_RANK fail, not support!";
235 return false;
236 }
237 return true;
238 }
239
GetHcomDestRank(const AnfNodePtr & anf_node,uint32_t * dest_rank)240 bool HcomUtil::GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank) {
241 MS_EXCEPTION_IF_NULL(anf_node);
242 MS_EXCEPTION_IF_NULL(dest_rank);
243 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
244 MS_EXCEPTION_IF_NULL(primitive);
245 if (primitive->GetAttr("dest_rank") != nullptr) {
246 *dest_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("dest_rank")));
247 } else {
248 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_DEST_RANK fail, not support!";
249 return false;
250 }
251 return true;
252 }
253
GetHcomReceiveType(const AnfNodePtr & anf_node,TypeId * receive_type)254 bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) {
255 MS_EXCEPTION_IF_NULL(anf_node);
256 MS_EXCEPTION_IF_NULL(receive_type);
257 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
258 MS_EXCEPTION_IF_NULL(primitive);
259 if (primitive->GetAttr("dtype") != nullptr) {
260 *receive_type = GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id();
261 } else {
262 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!";
263 return false;
264 }
265 return true;
266 }
267
GetHcomGroup(NotNull<const AnfNodePtr &> anf_node,NotNull<std::string * > group)268 void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
269 auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
270 MS_EXCEPTION_IF_NULL(primitive);
271 auto attr = primitive->GetAttr(kAttrGroup);
272 if (attr != nullptr) {
273 *group = GetValue<std::string>(attr);
274 } else {
275 MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
276 }
277 }
278 } // namespace mindspore
279