1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/arrangement.h"
18 #include <algorithm>
19 #include <utility>
20 #include "utils/ms_utils.h"
21 #include "frontend/parallel/status.h"
22 #include "frontend/parallel/tensor_layout/shape_util.h"
23 #include "utils/convert_utils.h"
24 #include "utils/log_adapter.h"
25
26 namespace mindspore {
27 namespace parallel {
Init(const Shape & array)28 Status Arrangement::Init(const Shape &array) {
29 Status status = Array::Init(array);
30 if (status != Status::SUCCESS) {
31 return Status::FAILED;
32 }
33 if (!IsValidArrangement()) {
34 MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
35 return Status::FAILED;
36 }
37 ComputeSize();
38 return Status::SUCCESS;
39 }
40
IsValidArrangement()41 bool Arrangement::IsValidArrangement() {
42 return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0 && value != -1; });
43 }
44
ComputeSize()45 void Arrangement::ComputeSize() {
46 size_ = 1;
47 for (auto &value : array_) {
48 size_ *= value;
49 }
50 }
51
52 /*
53 * if GetDimSize() = 0, return []
54 * if value <= array_[0], return [value]
55 * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
56 * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
57 * if value > size_, return []
58 */
GetFrontElementByValue(int64_t value) const59 Shape Arrangement::GetFrontElementByValue(int64_t value) const {
60 Shape out;
61 if (GetDimSize() == 0) {
62 return out;
63 }
64 if (value <= size_) {
65 int64_t size = 1;
66 size_t shape_list_idx = 0;
67 while (size < value) {
68 size *= array_[shape_list_idx];
69 if (size <= value) {
70 out.push_back(array_[shape_list_idx]);
71 } else {
72 if (size == 0) {
73 MS_LOG(ERROR) << "The size is 0";
74 out.clear();
75 return out;
76 }
77 out.push_back(value * array_[shape_list_idx] / size);
78 }
79 shape_list_idx++;
80 }
81 }
82 return out;
83 }
84
GetExpandedShapeByExpandListRemoveLeft(const std::vector<Arrangement> & expand_list) const85 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
86 const std::vector<Arrangement> &expand_list) const {
87 if (expand_list.size() != GetDimSize()) {
88 return nullptr;
89 }
90 Shape new_shape;
91 for (size_t i = 0; i < expand_list.size(); i++) {
92 Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
93 if (expand_shape.empty()) {
94 new_shape.push_back(GetDimByIdx(i));
95 } else {
96 (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end());
97 }
98 }
99 Arrangement arrangement_new;
100 (void)arrangement_new.Init(new_shape);
101 return std::make_shared<Arrangement>(arrangement_new);
102 }
103
104 /*
105 * example:
106 * expand_shape = [4, 2, 2, 2]
107 * array_ = [8, 4],
108 * arrangement_list = [[4, 2], [2, 2]]
109 */
GetExpandShapeList(const Arrangement & expand_shape) const110 std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
111 int64_t size = 1;
112 size_t ind = 0;
113 std::vector<Arrangement> arrangement_list;
114 Shape shape;
115 for (size_t i = 0; i < expand_shape.GetDimSize(); i++) {
116 size *= expand_shape.GetDimByIdx(i);
117 if (size > GetDimByIdx(ind)) {
118 MS_LOG(ERROR) << "invalid expand_shape";
119 return nullptr;
120 } else if (size < GetDimByIdx(ind)) {
121 shape.push_back(expand_shape.GetDimByIdx(i));
122 continue;
123 } else {
124 shape.push_back(expand_shape.GetDimByIdx(i));
125 Arrangement arrangement;
126 (void)arrangement.Init(shape);
127 arrangement_list.push_back(arrangement);
128 shape.clear();
129 ind++;
130 size = 1;
131 }
132 }
133 if (ind != GetDimSize()) {
134 MS_LOG(ERROR) << "invalid expand_shape";
135 return nullptr;
136 }
137 auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
138 return arrangement_new;
139 }
140
GetExpandShapeListPair(const Arrangement & expand_shape) const141 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
142 const Arrangement &expand_shape) const {
143 std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
144 if (expand_shape_list_ptr == nullptr) {
145 return nullptr;
146 }
147 Shape expand_num_list_shape;
148 (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
149 std::back_inserter(expand_num_list_shape),
150 [](const Arrangement &arr) { return SizeToLong(arr.GetDimSize()); });
151 Arrangement expand_num_list;
152 Status status = expand_num_list.Init(expand_num_list_shape);
153 if (status != Status::SUCCESS) {
154 return nullptr;
155 }
156 auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
157 return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
158 }
159
ComputeReverseAccumulateSumInReverseOrder() const160 Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
161 Shape shape_accum;
162 int64_t size = 0;
163 for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
164 shape_accum.push_back(size);
165 size += *iter;
166 }
167 return shape_accum;
168 }
169
GetExpandedShapeByExpandListReserveLeft(const std::vector<Arrangement> & expand_list) const170 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
171 const std::vector<Arrangement> &expand_list) const {
172 if (expand_list.size() != GetDimSize()) {
173 return nullptr;
174 }
175 Shape new_shape;
176 for (size_t i = 0; i < expand_list.size(); i++) {
177 if (expand_list[i].GetDimSize() >= 1) {
178 int64_t size = 1;
179 for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
180 new_shape.push_back(expand_list[i].GetDimByIdx(k));
181 size *= expand_list[i].GetDimByIdx(k);
182 }
183 new_shape.push_back(GetDimByIdx(i) / size);
184 } else {
185 new_shape.push_back(GetDimByIdx(i));
186 }
187 }
188 Arrangement arrangement_new;
189 (void)arrangement_new.Init(new_shape);
190 return std::make_shared<Arrangement>(arrangement_new);
191 }
192
GetUnifiedShape(const Arrangement & in2) const193 std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
194 std::vector<int64_t> in1_accum;
195 Status status = ShapeToAccumulateProduct(array_, &in1_accum);
196 if (status != Status::SUCCESS) {
197 return nullptr;
198 }
199 std::vector<int64_t> in2_accum;
200 status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
201 if (status != Status::SUCCESS) {
202 return nullptr;
203 }
204 std::vector<int64_t> out_accum;
205 status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
206 if (status != Status::SUCCESS) {
207 return nullptr;
208 }
209 Shape out_shape;
210 status = AccumulateProductToShape(out_accum, &out_shape);
211 if (status != Status::SUCCESS) {
212 return nullptr;
213 }
214 Arrangement out;
215 status = out.Init(out_shape);
216 if (status != Status::SUCCESS) {
217 return nullptr;
218 }
219 return std::make_shared<Arrangement>(out);
220 }
221
GetSqueezeIdx() const222 std::vector<size_t> Arrangement::GetSqueezeIdx() const {
223 std::vector<size_t> out;
224 for (size_t i = 0; i < GetDimSize(); i++) {
225 if (GetDimByIdx(SizeToUlong(i)) == 1) {
226 out.push_back(i);
227 }
228 }
229 return out;
230 }
231
GetSqueezeArrangement() const232 Arrangement Arrangement::GetSqueezeArrangement() const {
233 Shape out_shape(array_.size());
234 auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; });
235 out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
236
237 // if all elements are 1, out_shape = {1}
238 if (out_shape.empty()) {
239 MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
240 out_shape.push_back(1);
241 }
242 Arrangement out;
243 (void)out.Init(out_shape);
244 return out;
245 }
246 } // namespace parallel
247 } // namespace mindspore
248