• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/kernel_compiler/hccl/hcom_util.h"
18 #include <memory>
19 #include "backend/kernel_compiler/common_utils.h"
20 #include "backend/session/anf_runtime_algorithm.h"
21 #include "utils/ms_context.h"
22 #include "utils/utils.h"
23 
24 namespace mindspore {
25 namespace {
IsPyNativeMode()26 bool IsPyNativeMode() {
27   auto ms_context = MsContext::GetInstance();
28   MS_EXCEPTION_IF_NULL(ms_context);
29   return ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode;
30 }
31 }  // namespace
32 
GetKernelInputShape(const AnfNodePtr & anf_node,vector<vector<size_t>> * hccl_kernel_intput_shape_list)33 bool HcomUtil::GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_intput_shape_list) {
34   MS_EXCEPTION_IF_NULL(anf_node);
35   MS_EXCEPTION_IF_NULL(hccl_kernel_intput_shape_list);
36   size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
37   for (size_t i = 0; i < input_num; ++i) {
38     std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
39     hccl_kernel_intput_shape_list->emplace_back(shape_i);
40   }
41 
42   return true;
43 }
44 
GetKernelOutputShape(const AnfNodePtr & anf_node,vector<vector<size_t>> * hccl_kernel_output_shape_list)45 bool HcomUtil::GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_output_shape_list) {
46   MS_EXCEPTION_IF_NULL(anf_node);
47   MS_EXCEPTION_IF_NULL(hccl_kernel_output_shape_list);
48   size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
49   for (size_t i = 0; i < output_num; ++i) {
50     std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
51     hccl_kernel_output_shape_list->emplace_back(shape_i);
52   }
53 
54   return true;
55 }
56 
ConvertHcclType(TypeId type_id)57 ::HcclDataType HcomUtil::ConvertHcclType(TypeId type_id) {
58   auto iter = kConstOpHcomDataTypeMap.find(type_id);
59   if (iter == kConstOpHcomDataTypeMap.end()) {
60     MS_LOG(EXCEPTION) << "HcomDataType can't support Current Ascend Data Type : " << type_id;
61   }
62   return iter->second;
63 }
64 
GetHcomDataType(const AnfNodePtr & anf_node,vector<HcclDataType> * data_type_list)65 bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list) {
66   MS_EXCEPTION_IF_NULL(anf_node);
67   MS_EXCEPTION_IF_NULL(data_type_list);
68   size_t tensor_num = AnfAlgo::GetInputTensorNum(anf_node);
69   auto op_name = AnfAlgo::GetCNodeName(anf_node);
70   if (op_name == kReceiveOpName) {
71     tensor_num = AnfAlgo::GetOutputTensorNum(anf_node);
72   }
73   for (size_t i = 0; i < tensor_num; ++i) {
74     TypeId type_ptr;
75     if (op_name == kReceiveOpName) {
76       type_ptr = AnfAlgo::GetOutputDeviceDataType(anf_node, i);
77     } else {
78       type_ptr = AnfAlgo::GetInputDeviceDataType(anf_node, i);
79     }
80     data_type_list->emplace_back(ConvertHcclType(type_ptr));
81   }
82   if (!data_type_list->empty()) {
83     if (std::any_of(data_type_list->begin(), data_type_list->end(),
84                     [&data_type_list](HcclDataType type) { return type != *(data_type_list->begin()); })) {
85       MS_LOG(ERROR) << "hccl have different data type";
86       return false;
87     }
88   }
89   return true;
90 }
91 
GetHcclOpSize(const HcclDataType & data_type,const vector<size_t> & shape,size_t * size)92 bool HcomUtil::GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size) {
93   MS_EXCEPTION_IF_NULL(size);
94   size_t tmp_size = 1;
95   uint32_t type_size = 4;
96   for (size_t i = 0; i < shape.size(); i++) {
97     tmp_size = SizetMulWithOverflowCheck(tmp_size, shape[i]);
98   }
99 
100   if (!GetHcomTypeSize(data_type, &type_size)) {
101     return false;
102   }
103 
104   *size = SizetMulWithOverflowCheck(tmp_size, type_size);
105 
106   MS_LOG(INFO) << "size[" << *size << "]";
107   return true;
108 }
109 
GetHcomTypeSize(const HcclDataType & data_type,uint32_t * size)110 bool HcomUtil::GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size) {
111   MS_EXCEPTION_IF_NULL(size);
112   auto iter = kConstOpHcomDataTypeSizeMap.find(data_type);
113   if (iter == kConstOpHcomDataTypeSizeMap.end()) {
114     MS_LOG(ERROR) << "HcomUtil::HcomDataTypeSize, No DataTypeSize!";
115     return false;
116   }
117   *size = iter->second;
118   return true;
119 }
120 
GetHcomCount(const AnfNodePtr & anf_node,const vector<HcclDataType> & data_type_list,const vector<vector<size_t>> & shape_list,uint64_t * total_count)121 bool HcomUtil::GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list,
122                             const vector<vector<size_t>> &shape_list, uint64_t *total_count) {
123   MS_EXCEPTION_IF_NULL(anf_node);
124   MS_EXCEPTION_IF_NULL(total_count);
125   const uint32_t align_size = 512;
126   const uint32_t filled_size = 32;
127   uint64_t total_size = 0;
128   uint64_t block_size;
129   size_t input_size;
130   uint32_t type_size = 4;
131   size_t size = AnfAlgo::GetInputTensorNum(anf_node);
132   auto cnode = anf_node->cast<CNodePtr>();
133   MS_EXCEPTION_IF_NULL(cnode);
134   if (AnfAlgo::GetCNodeName(anf_node) == kReceiveOpName) {
135     size = AnfAlgo::GetOutputTensorNum(anf_node);
136   }
137   for (size_t i = 0; i < size; ++i) {
138     if (!GetHcomTypeSize(data_type_list[i], &type_size)) {
139       return false;
140     }
141 
142     if (!GetHcclOpSize(data_type_list[i], shape_list[i], &input_size)) {
143       MS_LOG(ERROR) << "Get GetHcclOpSize failed";
144       return false;
145     }
146 
147     if (AnfAlgo::GetCNodeName(anf_node) == kReduceScatterOpName) {
148       int64_t rank_size;
149       auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
150       MS_EXCEPTION_IF_NULL(primitive);
151       if (primitive->GetAttr(kAttrRankSize) != nullptr) {
152         rank_size = GetValue<int64_t>(primitive->GetAttr(kAttrRankSize));
153       } else {
154         MS_LOG(ERROR) << "Get rank size failed";
155         return false;
156       }
157       size_t actual_input_size = input_size;
158       if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion)) {
159         actual_input_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
160       }
161       block_size = static_cast<uint64_t>(actual_input_size / LongToSize(rank_size));
162       total_size = total_size + block_size;
163     } else {
164       if (AnfAlgo::GetCNodeName(anf_node) == kAllGatherOpName) {
165         if (AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && AnfAlgo::GetNodeAttr<int64_t>(anf_node, kAttrFusion) &&
166             AnfAlgo::GetInputTensorNum(anf_node) > 1) {
167           block_size = (input_size + align_size - 1 + filled_size) / align_size * align_size;
168         } else {
169           block_size = input_size;
170         }
171       } else {
172         block_size =
173           IsPyNativeMode() ? input_size : (input_size + align_size - 1 + filled_size) / align_size * align_size;
174       }
175       total_size = total_size + block_size;
176     }
177   }
178 
179   if (type_size == 0 || total_size % type_size != 0) {
180     MS_LOG(ERROR) << "Total_size[" << total_size << "],Type_size[" << type_size << "] != 0, fail!";
181     return false;
182   }
183   *total_count = total_size / type_size;
184   return true;
185 }
186 
GetHcomOperationType(const AnfNodePtr & anf_node,HcclReduceOp * op_type)187 bool HcomUtil::GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type) {
188   MS_EXCEPTION_IF_NULL(anf_node);
189   MS_EXCEPTION_IF_NULL(op_type);
190   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
191   MS_EXCEPTION_IF_NULL(primitive);
192   if (primitive->GetAttr(kAttrOp) == nullptr) {
193     MS_LOG(ERROR) << "Get HCOM_ATTR_REDUCE_TYPE fail, not support!";
194     return false;
195   }
196   auto hcom_op_type = GetValue<std::string>(primitive->GetAttr(kAttrOp));
197   if (hcom_op_type == "min") {
198     *op_type = HCCL_REDUCE_MIN;
199   } else if (hcom_op_type == "max") {
200     *op_type = HCCL_REDUCE_MAX;
201   } else if (hcom_op_type == "prod") {
202     *op_type = HCCL_REDUCE_PROD;
203   } else if (hcom_op_type == "sum") {
204     *op_type = HCCL_REDUCE_SUM;
205   } else {
206     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [" << hcom_op_type << "] not support!";
207     return false;
208   }
209   return true;
210 }
211 
GetHcomRootId(const AnfNodePtr & anf_node,uint32_t * root_id)212 bool HcomUtil::GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id) {
213   MS_EXCEPTION_IF_NULL(anf_node);
214   MS_EXCEPTION_IF_NULL(root_id);
215   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
216   MS_EXCEPTION_IF_NULL(primitive);
217   if (primitive->GetAttr(kAttrRootRank) != nullptr) {
218     *root_id = (uint32_t)GetValue<int64_t>(primitive->GetAttr(kAttrRootRank));
219   } else {
220     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!";
221     return false;
222   }
223   return true;
224 }
225 
GetHcomSrcRank(const AnfNodePtr & anf_node,uint32_t * src_rank)226 bool HcomUtil::GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank) {
227   MS_EXCEPTION_IF_NULL(anf_node);
228   MS_EXCEPTION_IF_NULL(src_rank);
229   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
230   MS_EXCEPTION_IF_NULL(primitive);
231   if (primitive->GetAttr("src_rank") != nullptr) {
232     *src_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("src_rank")));
233   } else {
234     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRC_RANK fail, not support!";
235     return false;
236   }
237   return true;
238 }
239 
GetHcomDestRank(const AnfNodePtr & anf_node,uint32_t * dest_rank)240 bool HcomUtil::GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank) {
241   MS_EXCEPTION_IF_NULL(anf_node);
242   MS_EXCEPTION_IF_NULL(dest_rank);
243   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
244   MS_EXCEPTION_IF_NULL(primitive);
245   if (primitive->GetAttr("dest_rank") != nullptr) {
246     *dest_rank = static_cast<uint32_t>(GetValue<int64_t>(primitive->GetAttr("dest_rank")));
247   } else {
248     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_DEST_RANK fail, not support!";
249     return false;
250   }
251   return true;
252 }
253 
GetHcomReceiveType(const AnfNodePtr & anf_node,TypeId * receive_type)254 bool HcomUtil::GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type) {
255   MS_EXCEPTION_IF_NULL(anf_node);
256   MS_EXCEPTION_IF_NULL(receive_type);
257   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
258   MS_EXCEPTION_IF_NULL(primitive);
259   if (primitive->GetAttr("dtype") != nullptr) {
260     *receive_type = GetValue<NumberPtr>(primitive->GetAttr("dtype"))->type_id();
261   } else {
262     MS_LOG(ERROR) << "HcomUtil::Get HCOM_ATTR_SRTAG_INDEX fail, not support!";
263     return false;
264   }
265   return true;
266 }
267 
GetHcomGroup(NotNull<const AnfNodePtr &> anf_node,NotNull<std::string * > group)268 void HcomUtil::GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group) {
269   auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
270   MS_EXCEPTION_IF_NULL(primitive);
271   auto attr = primitive->GetAttr(kAttrGroup);
272   if (attr != nullptr) {
273     *group = GetValue<std::string>(attr);
274   } else {
275     MS_LOG(EXCEPTION) << "Get Hcom Group Attr of Op:" << anf_node->fullname_with_scope() << " failed";
276   }
277 }
278 }  // namespace mindspore
279