1 /**
2 * Copyright 2019-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/ascend/kernel/hccl/hcom_util.h"
18 #include <algorithm>
19 #include <memory>
20 #include "include/backend/anf_runtime_algorithm.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "include/common/utils/utils.h"
23 #include "ops/ascend_op_name.h"
24 #include "ops/framework_op_name.h"
25 #include "ops/other_op_name.h"
26 #include "utils/ms_context.h"
27 #include "utils/trace_base.h"
28 #include "ir/dtype/type.h"
29
30 namespace mindspore {
ConvertHcclType(TypeId type_id)31 ::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
32 auto iter = kConstOpHcomDataTypeMap.find(type_id);
33 if (iter == kConstOpHcomDataTypeMap.end()) {
34 if (type_id == TypeId::kNumberTypeComplex64) {
35 MS_LOG(INFO) << "HcomDataType Can't support Current Ascend Data Type : Complex64, Convert it to Float32";
36 return HCCL_DATA_TYPE_FP32;
37 }
38 MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << TypeIdLabel(type_id);
39 }
40 return iter->second;
41 }
42
GetHcomDataType(const std::string & kernel_name,const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs,vector<HcclDataType> * data_type_list)43 bool HcomUtil::GetHcomDataType(const std::string &kernel_name, const std::vector<KernelTensor *> &inputs,
44 const std::vector<KernelTensor *> &outputs, vector<HcclDataType> *data_type_list) {
45 MS_EXCEPTION_IF_NULL(data_type_list);
46
47 data_type_list->clear();
48 const std::vector<KernelTensor *> &tensors = HcomUtil::IsReceiveOp(kernel_name) ? outputs : inputs;
49 std::transform(tensors.begin(), tensors.end(), std::back_inserter(*data_type_list),
50 [](KernelTensor *tensor_ptr) { return ConvertHcclType(tensor_ptr->dtype_id()); });
51
52 if (!data_type_list->empty()) {
53 if (std::any_of(data_type_list->begin(), data_type_list->end(),
54 [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
55 MS_LOG(ERROR) << "hccl kernel " << kernel_name << " have different data type";
56 return false;
57 }
58 }
59 return true;
60 }
61
GetHcclOpSize(const HcclDataType & data_type,const ShapeVector & shape,size_t * size)62 bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const ShapeVector &shape, size_t *size) {
63 MS_EXCEPTION_IF_NULL(size);
64 int64_t tmp_size = 1;
65 uint32_t type_size = 4;
66 for (size_t i = 0; i < shape.size(); i++) {
67 tmp_size = LongMulWithOverflowCheck(tmp_size, shape[i]);
68 }
69
70 if (!GetHcomTypeSize(data_type, &type_size)) {
71 return false;
72 }
73
74 *size = SizetMulWithOverflowCheck(LongToSizeClipNeg(tmp_size), type_size);
75
76 MS_LOG(DEBUG) << "size[" << *size << "]";
77 return true;
78 }
79
GetHcomTypeSize(const HcclDataType & data_type,uint32_t * size)80 bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) {
81 MS_EXCEPTION_IF_NULL(size);
82 auto iter = kConstOpHcomDataTypeSizeMap.find(data_type);
83 if (iter == kConstOpHcomDataTypeSizeMap.end()) {
84 MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!";
85 return false;
86 }
87 *size = iter->second;
88 return true;
89 }
90
GetHcomCount(const PrimitivePtr & primitive,const vector<HcclDataType> & data_type_list,const vector<ShapeVector> & shape_list,const size_t input_tensor_num,uint64_t * total_count)91 bool HcomUtil::GetHcomCount(const PrimitivePtr &primitive, const vector<HcclDataType> &data_type_list,
92 const vector<ShapeVector> &shape_list, const size_t input_tensor_num,
93 uint64_t *total_count) {
94 MS_EXCEPTION_IF_NULL(primitive);
95 MS_EXCEPTION_IF_NULL(total_count);
96
97 const uint32_t align_size = 512;
98 const uint32_t filled_size = 32;
99 uint64_t total_size = 0;
100 size_t input_size;
101 uint32_t type_size = 4;
102
103 MS_EXCEPTION_IF_CHECK_FAIL(data_type_list.size() == shape_list.size(),
104 "Size of data_type_list must be equal to size of shape_list");
105
106 for (size_t i = 0; i < data_type_list.size(); ++i) {
107 if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
108 return false;
109 }
110
111 if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) {
112 MS_LOG(ERROR) << "Get GetHcclOpSize failed";
113 return false;
114 }
115
116 if (input_tensor_num > 1) {
117 // communication operator with dynamic input should have continuous memory.
118 MS_LOG(INFO) << "Communication operator " << primitive->name() << " has dynamic input.";
119 input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
120 }
121 if (primitive->name() == kReduceScatterOpName) {
122 int64_t rank_size;
123 if (!HcomUtil::GetHcomAttr<int64_t>(primitive, kAttrRankSize, &rank_size)) {
124 return false;
125 }
126 input_size = static_cast<uint64_t>(input_size / LongToSize(rank_size));
127 }
128 bool all_dynamic = std::all_of(shape_list[i].begin(), shape_list[i].end(), [](int64_t x) { return x == -1; });
129 if (!all_dynamic && (type_size == 0 || input_size % type_size != 0)) {
130 MS_LOG(ERROR) << "primitive=" << primitive->name() << ", Input_size[" << input_size << "],Type_size[" << type_size
131 << "] != 0, fail!"
132 << " shape_list[i]=" << shape_list[i];
133 return false;
134 }
135 total_size += input_size / type_size;
136 }
137 *total_count = total_size;
138 return true;
139 }
140
GetHcomOperationType(const PrimitivePtr & primitive,HcclReduceOp * op_type)141 bool HcomUtil::GetHcomOperationType(const PrimitivePtr &primitive, HcclReduceOp *op_type) {
142 MS_EXCEPTION_IF_NULL(primitive);
143 MS_EXCEPTION_IF_NULL(op_type);
144
145 std::string hcom_op_type;
146 if (!GetHcomAttr<std::string>(primitive, kAttrOp, &hcom_op_type)) {
147 return false;
148 }
149 if (hcom_op_type == "min") {
150 *op_type = HCCL_REDUCE_MIN;
151 } else if (hcom_op_type == "max") {
152 *op_type = HCCL_REDUCE_MAX;
153 } else if (hcom_op_type == "prod") {
154 *op_type = HCCL_REDUCE_PROD;
155 } else if (hcom_op_type == "sum") {
156 *op_type = HCCL_REDUCE_SUM;
157 } else {
158 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
159 return false;
160 }
161 return true;
162 }
163
GetHcomReceiveType(const AnfNodePtr & anf_node,TypeId * receive_type)164 bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) {
165 MS_EXCEPTION_IF_NULL(anf_node);
166 MS_EXCEPTION_IF_NULL(receive_type);
167 auto primitive = common::AnfAlgo::GetCNodePrimitive(anf_node);
168 MS_EXCEPTION_IF_NULL(primitive);
169 if (primitive->GetAttr("dtype") != nullptr) {
170 *receive_type = GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id();
171 } else {
172 MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!";
173 return false;
174 }
175 return true;
176 }
177
GetHcomGroup(NotNull<const AnfNodePtr &> anf_node,NotNull<std::string * > group)178 void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
179 auto primitive = common::AnfAlgo::GetCNodePrimitive(anf_node);
180 MS_EXCEPTION_IF_NULL(primitive);
181 auto attr = primitive->GetAttr(kAttrGroup);
182 if (attr != nullptr) {
183 *group = GetValue<std::string>(attr);
184 } else {
185 MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed."
186 << trace::DumpSourceLines(anf_node);
187 }
188 }
189 } // namespace mindspore
190