• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/device_matrix.h"
18 
19 #include <algorithm>
20 #include <cstdint>
21 #include <functional>
22 #include <numeric>
23 #include <utility>
24 #include <vector>
25 
26 #include "frontend/parallel/ops_info/operator_info.h"
27 #include "frontend/parallel/status.h"
28 #include "utils/log_adapter.h"
29 
30 namespace mindspore {
31 namespace parallel {
DeviceMatrix(int64_t rank,RankList dev_list,Shape dev_shape)32 DeviceMatrix::DeviceMatrix(int64_t rank, RankList dev_list, Shape dev_shape)
33     : rank_(rank), dev_list_(std::move(dev_list)), dev_shape_(std::move(dev_shape)) {
34   if (!std::any_of(dev_list_.begin(), dev_list_.end(), [rank](int64_t a) { return a == rank; })) {
35     MS_LOG(EXCEPTION) << "Rank " << rank << " is not in the current stage!";
36   }
37   int64_t total = std::accumulate(dev_shape_.begin(), dev_shape_.end(), 1, std::multiplies<int64_t>());
38   if (LongToSize(total) != dev_list_.size()) {
39     MS_LOG(EXCEPTION) << "Device shape does not match the size of the device list!";
40   }
41 }
42 
CreateGroupList()43 Status DeviceMatrix::CreateGroupList() {
44   size_t size = dev_shape_.size();
45   RankList group;
46   for (size_t i = 0; i < size; i++) {
47     Status status = GetDevicesAlongDim(SizeToUlong(i), &group);
48     group_list_.push_back(group);
49     if (status == Status::FAILED) {
50       return Status::FAILED;
51     }
52   }
53   return Status::SUCCESS;
54 }
55 
// Collect into *devices the group of device ranks lying along dimension `dim`
// of the device matrix and containing rank_ (i.e. all devices whose
// coordinates equal rank_'s on every dimension except `dim`).
// Raises via MS_LOG(EXCEPTION) if `dim` is out of range; returns FAILED if
// rank_ cannot be located in any candidate group.
Status DeviceMatrix::GetDevicesAlongDim(const uint64_t &dim, RankList *devices) {
  if (dim >= dev_shape_.size()) {
    MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!";
  }
  // A dimension of size 1 yields a singleton group: just this rank.
  if (dev_shape_[dim] == 1) {
    *devices = {rank_};
    return Status::SUCCESS;
  }

  RankList group;
  std::vector<RankList> local_group_list;

  // lower than dim: `step` is the stride (in rank ids) between two consecutive
  // coordinates along `dim` — the product of all dimensions after `dim`.
  int64_t step = 1;
  for (uint64_t i = dim + 1; i < dev_shape_.size(); i++) {
    step = step * dev_shape_[i];
  }
  // Build the first candidate group along `dim`, anchored at the first device
  // id of the stage.
  // NOTE(review): this assumes dev_list_ holds consecutive rank ids starting
  // at *dev_list_.begin() — confirm against how stages assign device lists.
  int64_t num = *dev_list_.begin();
  for (int64_t i = 0; i < dev_shape_[dim]; i++) {
    group.push_back(num);
    num += step;
  }

  // The `step` sibling groups (differing only in dimensions after `dim`) are
  // produced by shifting every member of `group` by +1, `step` times.
  for (int64_t i = 0; i < step; i++) {
    local_group_list.push_back(group);
    (void)std::for_each(group.begin(), group.end(), [](int64_t &a) { a++; });
  }

  // higher than dim: advancing one unit in any dimension before `dim` shifts
  // every rank by step * dev_shape_[dim]; `len` is how many such shifts exist.
  step = step * dev_shape_[dim];
  int64_t len = SizeToLong(dev_list_.size()) / step;

  // search rank: walk every shifted copy of the candidate groups until one
  // contains rank_, mutating the candidates in place by +step per outer pass.
  int64_t target = rank_;
  for (int64_t i = 0; i < len; i++) {
    for (RankList &temp : local_group_list) {
      if (std::any_of(temp.begin(), temp.end(), [target](int64_t a) { return a == target; })) {
        *devices = temp;
        return Status::SUCCESS;
      }
      (void)std::for_each(temp.begin(), temp.end(), [step](int64_t &a) { a = a + step; });
    }
  }
  MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!";
  return Status::FAILED;
}
102 
// Translate a linear (local) rank into its coordinate inside `dev_shape`,
// with the last dimension varying fastest (row-major order).
// Raises via MS_LOG(EXCEPTION) if any dimension of the shape is zero.
Shape ConvertRankToCoordinate(int64_t rank, const Shape &dev_shape) {
  const size_t dims = dev_shape.size();
  Shape dev_coordinate(dims);
  // Peel off coordinates from the fastest-varying (last) dimension backwards.
  for (size_t pos = dims; pos > 0; --pos) {
    const int64_t dim_size = dev_shape[pos - 1];
    if (dim_size == 0) {
      MS_LOG(EXCEPTION) << "Invalid dev shape: " << ShapeToString(dev_shape);
    }
    dev_coordinate[pos - 1] = rank % dim_size;
    rank = rank / dim_size;
  }
  return dev_coordinate;
}
117 
GetDevicesByTensorMap(const Shape & tensor_map,RankList * rank_list)118 Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) {
119   for (auto &element : tensor_map) {
120     // -1 means the corresponding dimension is not split.
121     if (element == MAP_NONE) {
122       continue;
123     } else if ((element < 0) || (LongToSize(element) >= dev_shape_.size())) {
124       MS_LOG(ERROR) << "create group by tensor map: the tensor map is invalid";
125       return FAILED;
126     }
127   }
128 
129   // Convert the global rank to the local rank(The index of the array) to compute the coordinate
130   uint32_t local_rank = 0;
131   for (auto &tmp_rank : dev_list_) {
132     if (tmp_rank == rank_) {
133       break;
134     }
135     ++local_rank;
136   }
137   if (local_rank == dev_list_.size()) {
138     MS_LOG(ERROR) << "Rank id: " << local_rank << "is not in the device list.";
139     return FAILED;
140   }
141 
142   Shape current_rank_coordinate = ConvertRankToCoordinate((int32_t)local_rank, dev_shape_);
143   for (uint32_t loop_local_rank = 0; loop_local_rank < dev_list_.size(); ++loop_local_rank) {
144     Shape tmp_rank_coordinate = ConvertRankToCoordinate(loop_local_rank, dev_shape_);
145     bool matched = true;
146     for (auto &map : tensor_map) {
147       if (map == MAP_NONE) {
148         continue;
149       }
150       size_t index = dev_shape_.size() - LongToSize(map) - 1;
151       if (current_rank_coordinate[index] != tmp_rank_coordinate[index]) {
152         matched = false;
153         break;
154       }
155     }
156     if (matched) {
157       rank_list->push_back(dev_list_[loop_local_rank]);
158     }
159   }
160 
161   return SUCCESS;
162 }
163 
ShapeToString(const Shape & shape)164 std::string ShapeToString(const Shape &shape) {
165   std::string str = "[";
166   for (size_t i = 0; i < shape.size(); ++i) {
167     str += std::to_string(shape[i]);
168     if (i < shape.size() - 1) {
169       str += ", ";
170     }
171   }
172   return str + "]";
173 }
174 
ListToString(const RankList & list)175 std::string ListToString(const RankList &list) {
176   std::string str = "[";
177   for (auto &element : list) {
178     str += std::to_string(element) + ", ";
179   }
180   return str + "]";
181 }
182 }  // namespace parallel
183 }  // namespace mindspore
184