• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/hccl_adapter/all_to_all_v_calc_param.h"
17 #include <functional>
18 #include <map>
19 #include <string>
20 #include "backend/session/anf_runtime_algorithm.h"
21 #include "transform/graph_ir/util.h"
22 #include "runtime/device/memory_manager.h"
23 #include "mindspore/core/utils/convert_utils_base.h"
24 
25 namespace mindspore::hccl {
26 namespace {
// Returns true when the sequence is strictly increasing, i.e. every element
// is greater than its predecessor. Empty and single-element vectors are
// trivially in order.
bool IsInTheOrder(const std::vector<int64_t> &vec) {
  if (vec.empty()) {
    return true;
  }
  int64_t prev = vec.front();
  for (size_t idx = 1; idx < vec.size(); ++idx) {
    if (vec[idx] <= prev) {
      return false;
    }
    prev = vec[idx];
  }
  return true;
}
36 }  // namespace
// Builds the parameter calculator for an AllToAllv communication node.
// The graph node is held as a weak pointer (locked on each use) so this
// object does not extend the node's lifetime. All four per-rank vectors
// (send/recv counts and displacements) are sized to `rank_size` and
// zero-initialized here; CalcOpParam() fills them in.
AllToAllvCalcParam::AllToAllvCalcParam(const CNodeWeakPtr &cnode, uint32_t rank_size)
    : node_(cnode),
      rank_size_(rank_size),
      send_counts_(rank_size, 0),
      sdispls_(rank_size, 0),
      recv_counts_(rank_size, 0),
      rdispls_(rank_size, 0) {}
44 
CalcOpParam()45 void AllToAllvCalcParam::CalcOpParam() {
46   CNodePtr cnode = node_.lock();
47   MS_EXCEPTION_IF_NULL(cnode);
48   size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
49   size_t output_num = AnfAlgo::GetOutputTensorNum(cnode);
50   std::vector<size_t> input_aligned_mem_size(input_num);
51   std::vector<size_t> output_aligned_mem_size(output_num);
52   std::vector<size_t> input_real_mem_size(input_num);
53   std::vector<size_t> output_real_mem_size(output_num);
54   for (size_t i = 0; i < input_num; ++i) {
55     auto ms_shape = AnfAlgo::GetInputDeviceShape(cnode, i);
56     auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetInputDeviceDataType(cnode, i));
57     size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
58     size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
59     input_aligned_mem_size[i] = aligned_mem_size / type_size;
60     input_real_mem_size[i] = origin_mem_size / type_size;
61   }
62   for (size_t i = 0; i < output_num; ++i) {
63     auto ms_shape = AnfAlgo::GetOutputDeviceShape(cnode, i);
64     auto type_size = transform::TransformUtil::GetDataTypeSize(AnfAlgo::GetOutputDeviceDataType(cnode, i));
65     size_t origin_mem_size = std::accumulate(ms_shape.begin(), ms_shape.end(), type_size, std::multiplies<size_t>());
66     size_t aligned_mem_size = device::MemoryManager::GetCommonAlignSize(origin_mem_size);
67     output_aligned_mem_size[i] = aligned_mem_size / type_size;
68     output_real_mem_size[i] = origin_mem_size / type_size;
69   }
70   CalcMemOffset(input_aligned_mem_size, input_real_mem_size, kAttrSendRankIds, &send_counts_, &sdispls_);
71   CalcMemOffset(output_aligned_mem_size, output_real_mem_size, kAttrRecvRankIds, &recv_counts_, &rdispls_);
72 }
73 
CalcMemOffset(const std::vector<size_t> & mem_sizes,const std::vector<size_t> & real_sizes,const std::string & rank_ids_attr,std::vector<int64_t> * counts,std::vector<int64_t> * displs)74 void AllToAllvCalcParam::CalcMemOffset(const std::vector<size_t> &mem_sizes, const std::vector<size_t> &real_sizes,
75                                        const std::string &rank_ids_attr, std::vector<int64_t> *counts,
76                                        std::vector<int64_t> *displs) {
77   CNodePtr cnode = node_.lock();
78   MS_EXCEPTION_IF_NULL(cnode);
79   auto rank_ids = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, rank_ids_attr);
80   if (mem_sizes.size() != rank_ids.size() || real_sizes.size() != rank_ids.size()) {
81     MS_LOG(EXCEPTION) << "Invalid addr num " << mem_sizes.size() << " and " << real_sizes.size()
82                       << " must be equal to rank ids size " << rank_ids.size();
83   }
84 
85   if (!IsInTheOrder(rank_ids)) {
86     std::vector<size_t> mem_offset(mem_sizes.size(), 0);
87     for (size_t i = 1; i < mem_sizes.size(); ++i) {
88       mem_offset[i] = mem_offset[i - 1] + mem_sizes[i];
89     }
90     for (size_t i = 0; i < rank_ids.size(); ++i) {
91       if (rank_ids[i] < 0 || static_cast<size_t>(rank_ids[i]) >= rank_size_) {
92         MS_LOG(EXCEPTION) << "Invalid rank id " << rank_ids[i] << " at index " << i << " as rank size " << rank_size_;
93       }
94       (*counts)[LongToSize(rank_ids[i])] = SizeToLong(real_sizes[i]);
95       (*displs)[LongToSize(rank_ids[i])] = SizeToLong(mem_offset[i]);
96     }
97     return;
98   }
99 
100   std::map<int64_t, size_t> rank_id_map;
101   for (size_t i = 0; i < rank_ids.size(); ++i) {
102     if (rank_ids[i] < 0 || static_cast<size_t>(rank_ids[i]) >= rank_size_) {
103       MS_LOG(EXCEPTION) << "Invalid rank id " << rank_ids[i] << " at index " << i << " as rank size " << rank_size_;
104     }
105     rank_id_map.emplace(rank_ids[i], i);
106   }
107 
108   size_t offset = 0;
109   for (uint32_t i = 0; i < rank_size_; ++i) {
110     (*displs)[i] = SizeToLong(offset);
111     auto iter = rank_id_map.find(i);
112     if (iter != rank_id_map.end()) {
113       (*counts)[i] = SizeToLong(real_sizes[iter->second]);
114       offset += mem_sizes[iter->second];
115     } else {
116       (*counts)[i] = 0;
117     }
118   }
119 }
120 }  // namespace mindspore::hccl
121