• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/ascend/kernel/hccl/hcom_util.h"
18 #include <algorithm>
19 #include <memory>
20 #include "include/backend/anf_runtime_algorithm.h"
21 #include "include/common/utils/anfalgo.h"
22 #include "include/common/utils/utils.h"
23 #include "ops/ascend_op_name.h"
24 #include "ops/framework_op_name.h"
25 #include "ops/other_op_name.h"
26 #include "utils/ms_context.h"
27 #include "utils/trace_base.h"
28 #include "ir/dtype/type.h"
29 
30 namespace mindspore {
ConvertHcclType(TypeId type_id)31 ::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
32   auto iter = kConstOpHcomDataTypeMap.find(type_id);
33   if (iter == kConstOpHcomDataTypeMap.end()) {
34     if (type_id == TypeId::kNumberTypeComplex64) {
35       MS_LOG(INFO) << "HcomDataType Can't support Current Ascend Data Type : Complex64, Convert it to Float32";
36       return HCCL_DATA_TYPE_FP32;
37     }
38     MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << TypeIdLabel(type_id);
39   }
40   return iter->second;
41 }
42 
GetHcomDataType(const std::string & kernel_name,const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs,vector<HcclDataType> * data_type_list)43 bool HcomUtil::GetHcomDataType(const std::string &kernel_name, const std::vector<KernelTensor *> &inputs,
44                                const std::vector<KernelTensor *> &outputs, vector<HcclDataType> *data_type_list) {
45   MS_EXCEPTION_IF_NULL(data_type_list);
46 
47   data_type_list->clear();
48   const std::vector<KernelTensor *> &tensors = HcomUtil::IsReceiveOp(kernel_name) ? outputs : inputs;
49   std::transform(tensors.begin(), tensors.end(), std::back_inserter(*data_type_list),
50                  [](KernelTensor *tensor_ptr) { return ConvertHcclType(tensor_ptr->dtype_id()); });
51 
52   if (!data_type_list->empty()) {
53     if (std::any_of(data_type_list->begin(), data_type_list->end(),
54                     [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
55       MS_LOG(ERROR) << "hccl kernel " << kernel_name << " have different data type";
56       return false;
57     }
58   }
59   return true;
60 }
61 
GetHcclOpSize(const HcclDataType & data_type,const ShapeVector & shape,size_t * size)62 bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const ShapeVector &shape, size_t *size) {
63   MS_EXCEPTION_IF_NULL(size);
64   int64_t tmp_size = 1;
65   uint32_t type_size = 4;
66   for (size_t i = 0; i < shape.size(); i++) {
67     tmp_size = LongMulWithOverflowCheck(tmp_size, shape[i]);
68   }
69 
70   if (!GetHcomTypeSize(data_type, &type_size)) {
71     return false;
72   }
73 
74   *size = SizetMulWithOverflowCheck(LongToSizeClipNeg(tmp_size), type_size);
75 
76   MS_LOG(DEBUG) << "size[" << *size << "]";
77   return true;
78 }
79 
GetHcomTypeSize(const HcclDataType & data_type,uint32_t * size)80 bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) {
81   MS_EXCEPTION_IF_NULL(size);
82   auto iter = kConstOpHcomDataTypeSizeMap.find(data_type);
83   if (iter == kConstOpHcomDataTypeSizeMap.end()) {
84     MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!";
85     return false;
86   }
87   *size = iter->second;
88   return true;
89 }
90 
GetHcomCount(const PrimitivePtr & primitive,const vector<HcclDataType> & data_type_list,const vector<ShapeVector> & shape_list,const size_t input_tensor_num,uint64_t * total_count)91 bool HcomUtil::GetHcomCount(const PrimitivePtr &primitive, const vector<HcclDataType> &data_type_list,
92                             const vector<ShapeVector> &shape_list, const size_t input_tensor_num,
93                             uint64_t *total_count) {
94   MS_EXCEPTION_IF_NULL(primitive);
95   MS_EXCEPTION_IF_NULL(total_count);
96 
97   const uint32_t align_size = 512;
98   const uint32_t filled_size = 32;
99   uint64_t total_size = 0;
100   size_t input_size;
101   uint32_t type_size = 4;
102 
103   MS_EXCEPTION_IF_CHECK_FAIL(data_type_list.size() == shape_list.size(),
104                              "Size of data_type_list must be equal to size of shape_list");
105 
106   for (size_t i = 0; i < data_type_list.size(); ++i) {
107     if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
108       return false;
109     }
110 
111     if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) {
112       MS_LOG(ERROR) << "Get GetHcclOpSize failed";
113       return false;
114     }
115 
116     if (input_tensor_num > 1) {
117       // communication operator with dynamic input should have continuous memory.
118       MS_LOG(INFO) << "Communication operator " << primitive->name() << " has dynamic input.";
119       input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
120     }
121     if (primitive->name() == kReduceScatterOpName) {
122       int64_t rank_size;
123       if (!HcomUtil::GetHcomAttr<int64_t>(primitive, kAttrRankSize, &rank_size)) {
124         return false;
125       }
126       input_size = static_cast<uint64_t>(input_size / LongToSize(rank_size));
127     }
128     bool all_dynamic = std::all_of(shape_list[i].begin(), shape_list[i].end(), [](int64_t x) { return x == -1; });
129     if (!all_dynamic && (type_size == 0 || input_size % type_size != 0)) {
130       MS_LOG(ERROR) << "primitive=" << primitive->name() << ", Input_size[" << input_size << "],Type_size[" << type_size
131                     << "] != 0, fail!"
132                     << " shape_list[i]=" << shape_list[i];
133       return false;
134     }
135     total_size += input_size / type_size;
136   }
137   *total_count = total_size;
138   return true;
139 }
140 
GetHcomOperationType(const PrimitivePtr & primitive,HcclReduceOp * op_type)141 bool HcomUtil::GetHcomOperationType(const PrimitivePtr &primitive, HcclReduceOp *op_type) {
142   MS_EXCEPTION_IF_NULL(primitive);
143   MS_EXCEPTION_IF_NULL(op_type);
144 
145   std::string hcom_op_type;
146   if (!GetHcomAttr<std::string>(primitive, kAttrOp, &hcom_op_type)) {
147     return false;
148   }
149   if (hcom_op_type == "min") {
150     *op_type = HCCL_REDUCE_MIN;
151   } else if (hcom_op_type == "max") {
152     *op_type = HCCL_REDUCE_MAX;
153   } else if (hcom_op_type == "prod") {
154     *op_type = HCCL_REDUCE_PROD;
155   } else if (hcom_op_type == "sum") {
156     *op_type = HCCL_REDUCE_SUM;
157   } else {
158     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
159     return false;
160   }
161   return true;
162 }
163 
GetHcomReceiveType(const AnfNodePtr & anf_node,TypeId * receive_type)164 bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) {
165   MS_EXCEPTION_IF_NULL(anf_node);
166   MS_EXCEPTION_IF_NULL(receive_type);
167   auto primitive = common::AnfAlgo::GetCNodePrimitive(anf_node);
168   MS_EXCEPTION_IF_NULL(primitive);
169   if (primitive->GetAttr("dtype") != nullptr) {
170     *receive_type = GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id();
171   } else {
172     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!";
173     return false;
174   }
175   return true;
176 }
177 
GetHcomGroup(NotNull<const AnfNodePtr &> anf_node,NotNull<std::string * > group)178 void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
179   auto primitive = common::AnfAlgo::GetCNodePrimitive(anf_node);
180   MS_EXCEPTION_IF_NULL(primitive);
181   auto attr = primitive->GetAttr(kAttrGroup);
182   if (attr != nullptr) {
183     *group = GetValue<std::string>(attr);
184   } else {
185     MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed."
186                       << trace::DumpSourceLines(anf_node);
187   }
188 }
189 }  // namespace mindspore
190