• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_UTILS_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HCCL_HCOM_UTILS_H_
19 
20 #include <string>
21 #include <map>
22 #include <vector>
23 #include <memory>
24 #include "ir/dtype.h"
25 #include "hccl/base.h"
26 #include "utils/contract.h"
27 #include "hccl/hccl_types.h"
28 
29 namespace mindspore {
30 using std::map;
31 using std::string;
32 using std::vector;
33 
// String names of the HCCL communication primitives referenced by the hcom kernel code.
// Each constant is a `const char *const` pointing at a string literal, so it can be used in
// constant expressions and compared with std::string / strcmp.
// NOTE(review): kHcomSend carries an "Hcom" prefix unlike its siblings, presumably to avoid a
// clash with another kSend constant elsewhere in the project — confirm before renaming.
constexpr const char *kAllGather = "AllGather";
constexpr const char *kAllReduce = "AllReduce";
constexpr const char *kBroadcast = "Broadcast";
constexpr const char *kHcomSend = "Send";
constexpr const char *kReceive = "Receive";
constexpr const char *kReduceScatter = "ReduceScatter";
constexpr const char *kAllToAllv = "AllToAllv";
41 
42 /* Correspondence between data_type and hcom data type in Ascend */
43 static map<int64_t, HcclDataType> kConstOpHcomDataTypeMap = {
44   {TypeId::kNumberTypeFloat32, HCCL_DATA_TYPE_FP32},
45   {TypeId::kNumberTypeFloat16, HCCL_DATA_TYPE_FP16},
46   {TypeId::kNumberTypeInt8, HCCL_DATA_TYPE_INT8},
47   {TypeId::kNumberTypeInt32, HCCL_DATA_TYPE_INT32},
48 };
49 
50 /* Correspondence between data_type and occupied byte size in hcom */
51 static map<HcclDataType, uint32_t> kConstOpHcomDataTypeSizeMap = {
52   {HCCL_DATA_TYPE_FP32, sizeof(float)},
53   {HCCL_DATA_TYPE_FP16, sizeof(float) / 2},
54   {HCCL_DATA_TYPE_INT8, sizeof(int8_t)},
55   {HCCL_DATA_TYPE_INT32, sizeof(int32_t)},
56 };
57 
58 class HcomUtil {
59  public:
60   static bool GetKernelInputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
61   static bool GetKernelOutputShape(const AnfNodePtr &anf_node, vector<vector<size_t>> *hccl_kernel_shape_list);
62   static ::HcclDataType ConvertHcclType(TypeId type_id);
63   static bool GetHcomDataType(const AnfNodePtr &anf_node, vector<HcclDataType> *data_type_list);
64   static bool GetHcclOpSize(const HcclDataType &data_type, const vector<size_t> &shape, size_t *size);
65   static bool GetHcomTypeSize(const HcclDataType &data_type, uint32_t *size);
66   static bool GetHcomCount(const AnfNodePtr &anf_node, const vector<HcclDataType> &data_type_list,
67                            const vector<vector<size_t>> &shape_list, uint64_t *total_count);
68   static bool GetHcomOperationType(const AnfNodePtr &anf_node, HcclReduceOp *op_type);
69   static bool GetHcomRootId(const AnfNodePtr &anf_node, uint32_t *root_id);
70   static bool GetHcomSrcRank(const AnfNodePtr &anf_node, uint32_t *src_rank);
71   static bool GetHcomDestRank(const AnfNodePtr &anf_node, uint32_t *dest_rank);
72   static void GetHcomGroup(NotNull<const AnfNodePtr &> anf_node, NotNull<std::string *> group);
73   static bool GetHcomReceiveType(const AnfNodePtr &anf_node, TypeId *receive_type);
74 };
75 }  // namespace mindspore
76 
77 #endif
78