• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/device_matrix.h"
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <functional>
22 #include <numeric>
23 #include <utility>
24 #include <vector>
25 #include <set>
26 
27 #include "frontend/parallel/status.h"
28 #include "utils/log_adapter.h"
29 #include "frontend/parallel/tensor_layout/map.h"
30 #include "utils/convert_utils_base.h"
31 
32 namespace mindspore {
33 namespace parallel {
DeviceMatrix(int64_t rank,RankList dev_list,Shape dev_shape)34 DeviceMatrix::DeviceMatrix(int64_t rank, RankList dev_list, Shape dev_shape)
35     : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) {
36   if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int64_t a) { return a == rank; })) {
37     MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!";
38   }
39   int64_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies<int64_t>());
40   if (LongToSize(total) != dev_list_.size()) {
41     MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!";
42   }
43 }
44 
CreateGroupList()45 Status DeviceMatrix::CreateGroupList() {
46   size_t size = dev_shape_.size();
47   RankList group;
48   for (size_t i = 0; i < size; i++) {
49     Status status = GetDevicesAlongDim(SizeToUlong(i), &group);
50     group_list_.push_back(group);
51     if (status == Status::FAILED) {
52       return Status::FAILED;
53     }
54   }
55   return Status::SUCCESS;
56 }
57 
// Fills `devices` with the ranks that lie along dimension `dim` of the device
// matrix and include rank_ (i.e. rank_'s communication group for that
// dimension). Returns FAILED when no candidate group contains rank_.
// The search mutates candidate groups in place, so statement order below is
// significant.
Status DeviceMatrix::GetDevicesAlongDim(const uint64_t &dim, RankList *devices) {
  if (dim >= dev_shape_.size()) {
    MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!";
  }
  // A dimension of size 1 groups the rank only with itself.
  if (dev_shape_[dim] == 1) {
    *devices = {rank_};
    return Status::SUCCESS;
  }

  RankList group;
  std::vector<RankList> local_group_list;

  // lower than dim
  // `step` is the distance between two neighbouring ranks along `dim`:
  // the product of all dimensions below it.
  int64_t step = 1;
  for (uint64_t i = dim + 1; i < dev_shape_.size(); i++) {
    step = step * dev_shape_[i];
  }
  // Seed the first candidate group from the first device number.
  // NOTE(review): this assumes dev_list_ holds consecutive device numbers
  // starting at *dev_list_.begin() — confirm against callers.
  int64_t num = *dev_list_.begin();
  for (int64_t i = 0; i < dev_shape_[dim]; i++) {
    group.push_back(num);
    num += step;
  }

  // Enumerate all `step` groups inside the lowest block of the matrix by
  // shifting the seed group by +1 each iteration.
  for (int64_t i = 0; i < step; i++) {
    local_group_list.push_back(group);
    (void)std::for_each(group.begin(), group.end(), [](int64_t &a) { a++; });
  }

  // higher than dim
  // One block spans `step * dev_shape_[dim]` devices; `len` is the number of
  // such blocks stacked along the higher dimensions.
  step = step * dev_shape_[dim];
  int64_t len = SizeToLong(dev_list_.size()) / step;

  // search rank
  // Slide every candidate group upward block by block (+step per visit) until
  // the group containing rank_ is found.
  int64_t target = rank_;
  for (int64_t i = 0; i < len; i++) {
    for (RankList &temp : local_group_list) {
      if (std::any_of(temp.begin(), temp.end(), [target](int64_t a) { return a == target; })) {
        *devices = temp;
        return Status::SUCCESS;
      }
      (void)std::for_each(temp.begin(), temp.end(), [step](int64_t &a) { a = a + step; });
    }
  }
  MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!";
  return Status::FAILED;
}
104 
GetDevicesAlongMultiDim(const std::vector<int64_t> & dims,RankList * devices)105 Status DeviceMatrix::GetDevicesAlongMultiDim(const std::vector<int64_t> &dims, RankList *devices) {
106   std::set<int64_t> repeated_rank_set;
107   for (const auto &dim : dims) {
108     if (dim != -1) {
109       auto r_dim = LongToUlong(dim);
110       if (repeated_rank_set.empty()) {
111         DeviceMatrix dev_matrix(rank_, dev_list_, dev_shape_);
112         RankList cur_dim_reduce_list;
113         dev_matrix.GetDevicesAlongDim(r_dim, &cur_dim_reduce_list);
114         repeated_rank_set.insert(cur_dim_reduce_list.begin(), cur_dim_reduce_list.end());
115       } else {
116         auto repeated_rank_set_cpy = repeated_rank_set;
117         for (const auto &rank : repeated_rank_set_cpy) {
118           DeviceMatrix dev_matrix(rank, dev_list_, dev_shape_);
119           RankList dim_reduce_list;
120           dev_matrix.GetDevicesAlongDim(r_dim, &dim_reduce_list);
121           repeated_rank_set.insert(dim_reduce_list.begin(), dim_reduce_list.end());
122         }
123       }
124     }
125   }
126   std::copy(repeated_rank_set.begin(), repeated_rank_set.end(), std::back_inserter(*devices));
127   return SUCCESS;
128 }
129 
// Converts a local rank (index into the flattened device matrix) into its
// per-dimension coordinate for the given device shape, with the last
// dimension varying fastest.
// Raises (MS_LOG(EXCEPTION)) when any dimension of `dev_shape` is 0.
// Preallocates and fills from the back: the original inserted at the front of
// the vector on every step, which is O(n^2).
Shape ConvertRankToCoordinate(int64_t rank, const Shape &dev_shape) {
  Shape dev_coordinate(dev_shape.size(), 0);
  for (size_t i = dev_shape.size(); i > 0; --i) {
    int64_t size = dev_shape[i - 1];
    if (size == 0) {
      MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape);
    }
    dev_coordinate[i - 1] = rank % size;
    rank = rank / size;
  }
  return dev_coordinate;
}
144 
GetDevicesByTensorMap(const Shape & tensor_map,RankList * rank_list)145 Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) {
146   for (auto &element : tensor_map) {
147     // -1 means the corresponding dimension is not split.
148     if (element == MAP_NONE) {
149       continue;
150     } else if ((element < 0) || (LongToSize(element) >= dev_shape_.size())) {
151       MS_LOG(ERROR) << "create group by tensor map: the tensor map is invalid";
152       return FAILED;
153     }
154   }
155 
156   // Convert the global rank to the local rank(The index of the array) to compute the coordinate
157   uint32_t local_rank = 0;
158   for (auto &tmp_rank : dev_list_) {
159     if (tmp_rank == rank_) {
160       break;
161     }
162     ++local_rank;
163   }
164   if (local_rank == dev_list_.size()) {
165     MS_LOG(ERROR) << "Rank id: " << local_rank << "is not in the device list.";
166     return FAILED;
167   }
168 
169   Shape current_rank_coordinate = ConvertRankToCoordinate(static_cast<int32_t>(local_rank), dev_shape_);
170   for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) {
171     Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_);
172     bool matched = true;
173     for (auto &map : tensor_map) {
174       if (map == MAP_NONE) {
175         continue;
176       }
177       size_t index = dev_shape_.size() - LongToSize(map) - 1;
178       if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) {
179         matched = false;
180         break;
181       }
182     }
183     if (matched) {
184       rank_list->push_back(dev_list_[loop_local_rank]);
185     }
186   }
187 
188   return SUCCESS;
189 }
190 
ShapeToString(const Shape & shape)191 std::string ShapeToString(const Shape &shape) {
192   std::string str = "[";
193   for (size_t i = 0; i < shape.size(); ++i) {
194     str += std::to_string(shape[i]);
195     if (i < shape.size() - 1) {
196       str += ", ";
197     }
198   }
199   return str + "]";
200 }
201 
ShapesToString(const Shapes & shapes)202 std::string ShapesToString(const Shapes &shapes) {
203   std::string str = "[";
204   for (size_t i = 0; i < shapes.size(); ++i) {
205     str += ShapeToString(shapes[i]);
206     if (i < shapes.size() - 1) {
207       str += ", ";
208     }
209   }
210   return str + "]";
211 }
212 
ListToString(const RankList & list)213 std::string ListToString(const RankList &list) {
214   std::string str = "[";
215   for (auto &element : list) {
216     str += std::to_string(element) + ", ";
217   }
218   return str + "]";
219 }
220 }  // namespace parallel
221 }  // namespace mindspore
222