/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "runtime/hccl_adapter/all_to_all_v_calc_param.h"
#include <functional>
#include <map>
#include <numeric>
#include <string>
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "transform/graph_ir/util.h"
#include "runtime/device/memory_manager.h"
#include "mindspore/core/utils/convert_utils_base.h"

namespace mindspore::hccl {
namespace {
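// Returns true if the rank ids are strictly increasing, i.e. already laid out in rank order.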
bool IsInTheOrder(const std::vector<int64_t> &vec) {
  for (size_t i = 1; i < vec.size(); ++i) {
    if (vec[i] <= vec[i - 1]) {
      return false;
    }
  }

  return true;
}
}  // namespace
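// Pre-sizes the send/recv count and displacement vectors to rank_size entries, all initialized to 0.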
AllToAllvCalcParam::AllToAllvCalcParam(const CNodeWeakPtr &cnode, uint32_t rank_size)
    : node_(cnode),
      rank_size_(rank_size),
      send_counts_(rank_size, 0),
      sdispls_(rank_size, 0),
      recv_counts_(rank_size, 0),
      rdispls_(rank_size, 0) {}

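// Computes element counts and displacements for the AllToAllv send (input) and receive (output) buffers.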
void AllToAllvCalcParam::CalcOpParam() {
  CNodePtr cnode = node_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
  std::vector<size_t> input_aligned_mem_size(input_num);
  std::vector<size_t> output_aligned_mem_size(output_num);
  std::vector<size_t> input_real_mem_size(input_num);
  std::vector<size_t> output_real_mem_size(output_num);
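  // For every input/output tensor, convert its byte size into element counts: the aligned size drives the
  // displacement layout in device memory, while the real size is the number of elements actually transferred.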
  for (size_t i = 0; i < input_num; ++i) {
    auto ms_shape = AnfAlgo::GetInputDeviceShape(cnode, i);
    auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetInputDeviceDataType(cnode, i));
    size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
    size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
    input_aligned_mem_size[i] = aligned_mem_size / type_size;
    input_real_mem_size[i] = origin_mem_size / type_size;
  }
  for (size_t i = 0; i < output_num; ++i) {
    auto ms_shape = AnfAlgo::GetOutputDeviceShape(cnode, i);
    auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetOutputDeviceDataType(cnode, i));
    size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
    size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
    output_aligned_mem_size[i] = aligned_mem_size / type_size;
    output_real_mem_size[i] = origin_mem_size / type_size;
  }
  CalcMemOffset(input_aligned_mem_size, input_real_mem_size, kAttrSendRankIds, &send_counts_, &sdispls_);
  CalcMemOffset(output_aligned_mem_size, output_real_mem_size, kAttrRecvRankIds, &recv_counts_, &rdispls_);
}

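// Fills counts/displs (indexed by peer rank id) from the per-tensor sizes, using the rank ids stored in the
// given node attribute to map each tensor to its peer rank.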
void AllToAllvCalcParam::CalcMemOffset(const std::vector<size_t> &mem_sizes, const std::vector<size_t> &real_sizes,
                                       const std::string &rank_ids_attr, std::vector<int64_t> *counts,
                                       std::vector<int64_t> *displs) {
  CNodePtr cnode = node_.lock();
  MS_EXCEPTION_IF_NULL(cnode);
  auto rank_ids = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, rank_ids_attr);
  if (mem_sizes.size() != rank_ids.size() || real_sizes.size() != rank_ids.size()) {
    MS_LOG(EXCEPTION) << "Invalid addr num " << mem_sizes.size() << " and " << real_sizes.size()
                      << ", both must be equal to rank ids size " << rank_ids.size();
  }

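  // When the rank ids are not given in ascending order, scatter counts/displs directly by rank id, using a
  // prefix sum of the aligned sizes (in attribute order) as the displacement of each entry.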
  if (!IsInTheOrder(rank_ids)) {
    std::vector<size_t> mem_offset(mem_sizes.size(), 0);
    for (size_t i = 1; i < mem_sizes.size(); ++i) {
      mem_offset[i] = mem_offset[i - 1] + mem_sizes[i - 1];
    }
    for (size_t i = 0; i < rank_ids.size(); ++i) {
      if (rank_ids[i] < 0 || static_cast<size_t>(rank_ids[i]) >= rank_size_) {
        MS_LOG(EXCEPTION) << "Invalid rank id " << rank_ids[i] << " at index " << i << ", rank size is " << rank_size_;
      }
      (*counts)[LongToSize(rank_ids[i])] = SizeToLong(real_sizes[i]);
      (*displs)[LongToSize(rank_ids[i])] = SizeToLong(mem_offset[i]);
    }
    return;
  }

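  // Rank ids are already sorted: remember which tensor index belongs to each peer rank so that counts and
  // displacements can be emitted in rank order below.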
  std::map<int64_t, size_t> rank_id_map;
  for (size_t i = 0; i < rank_ids.size(); ++i) {
    if (rank_ids[i] < 0 || static_cast<size_t>(rank_ids[i]) >= rank_size_) {
      MS_LOG(EXCEPTION) << "Invalid rank id " << rank_ids[i] << " at index " << i << ", rank size is " << rank_size_;
    }
    rank_id_map.emplace(rank_ids[i], i);
  }

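  // Walk every rank: ranks present in the attribute get their real element count and the running aligned offset;
  // ranks that exchange no data get a count of 0 but still receive the current displacement.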
  size_t offset = 0;
  for (uint32_t i = 0; i < rank_size_; ++i) {
    (*displs)[i] = SizeToLong(offset);
    auto iter = rank_id_map.find(i);
    if (iter != rank_id_map.end()) {
      (*counts)[i] = SizeToLong(real_sizes[iter->second]);
      offset += mem_sizes[iter->second];
    } else {
      (*counts)[i] = 0;
    }
  }
}
}  // namespace mindspore::hccl