1 /**
2 * Copyright 2019-2024 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/arrangement.h"
18
19 #include <algorithm>
20 #include <utility>
21
22 #include "frontend/parallel/status.h"
23 #include "frontend/parallel/tensor_layout/shape_util.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "utils/log_adapter.h"
26
27 namespace mindspore {
28 namespace parallel {
Init(const Shape & array)29 Status Arrangement::Init(const Shape &array) {
30 Status status = Array::Init(array);
31 if (status != Status::SUCCESS) {
32 return Status::FAILED;
33 }
34 if (!IsValidArrangement()) {
35 MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
36 return Status::FAILED;
37 }
38 ComputeSize();
39 return Status::SUCCESS;
40 }
41
UpdateTensorShape(size_t index,int64_t update_value)42 Status Arrangement::UpdateTensorShape(size_t index, int64_t update_value) {
43 if (index >= this->array_.size()) {
44 return Status::FAILED;
45 }
46 this->array_[index] = update_value;
47 return Status::SUCCESS;
48 }
49
IsValidArrangement()50 bool Arrangement::IsValidArrangement() {
51 return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0 && value != -1; });
52 }
53
ComputeSize()54 void Arrangement::ComputeSize() {
55 size_ = 1;
56 for (auto &value : array_) {
57 size_ *= value;
58 }
59 }
60
61 /*
62 * if GetDimSize() = 0, return []
63 * if value <= array_[0], return [value]
64 * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
65 * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
66 * if value > size_, return []
67 */
GetFrontElementByValue(int64_t value) const68 Shape Arrangement::GetFrontElementByValue(int64_t value) const {
69 Shape out;
70 if (GetDimSize() == 0) {
71 return out;
72 }
73 if (value <= size_) {
74 int64_t size = 1;
75 size_t shape_list_idx = 0;
76 while (size < value) {
77 size *= array_[shape_list_idx];
78 if (size <= value) {
79 out.push_back(array_[shape_list_idx]);
80 } else {
81 if (size == 0) {
82 MS_LOG(ERROR) << "The size is 0";
83 out.clear();
84 return out;
85 }
86 out.push_back(value * array_[shape_list_idx] / size);
87 }
88 shape_list_idx++;
89 }
90 }
91 return out;
92 }
93
GetExpandedShapeByExpandListRemoveLeft(const std::vector<Arrangement> & expand_list) const94 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
95 const std::vector<Arrangement> &expand_list) const {
96 if (expand_list.size() != GetDimSize()) {
97 return nullptr;
98 }
99 Shape new_shape;
100 for (size_t i = 0; i < expand_list.size(); i++) {
101 Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
102 if (expand_shape.empty()) {
103 new_shape.push_back(GetDimByIdx(i));
104 } else {
105 (void)new_shape.insert(new_shape.cend(), expand_shape.cbegin(), expand_shape.cend());
106 }
107 }
108 Arrangement arrangement_new;
109 (void)arrangement_new.Init(new_shape);
110 return std::make_shared<Arrangement>(arrangement_new);
111 }
112
113 /*
114 * example:
115 * expand_shape = [4, 2, 2, 2]
116 * array_ = [8, 4],
117 * arrangement_list = [[4, 2], [2, 2]]
118 */
GetExpandShapeList(const Arrangement & expand_shape) const119 std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
120 int64_t size = 1;
121 size_t ind = 0;
122 std::vector<Arrangement> arrangement_list;
123 Shape shape;
124 for (size_t i = 0; i < expand_shape.GetDimSize(); i++) {
125 size *= expand_shape.GetDimByIdx(i);
126 if (size > GetDimByIdx(ind)) {
127 MS_LOG(INFO) << "invalid expand_shape:" << expand_shape.array();
128 return nullptr;
129 } else if (size < GetDimByIdx(ind)) {
130 shape.push_back(expand_shape.GetDimByIdx(i));
131 continue;
132 } else {
133 shape.push_back(expand_shape.GetDimByIdx(i));
134 Arrangement arrangement;
135 (void)arrangement.Init(shape);
136 arrangement_list.push_back(arrangement);
137 shape.clear();
138 ind++;
139 size = 1;
140 }
141 }
142 if (ind != GetDimSize()) {
143 MS_LOG(INFO) << "invalid expand_shape:" << expand_shape.array();
144 return nullptr;
145 }
146 auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
147 return arrangement_new;
148 }
149
GetExpandShapeListPair(const Arrangement & expand_shape) const150 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
151 const Arrangement &expand_shape) const {
152 std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
153 if (expand_shape_list_ptr == nullptr) {
154 return nullptr;
155 }
156 Shape expand_num_list_shape;
157 (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
158 std::back_inserter(expand_num_list_shape),
159 [](const Arrangement &arr) { return SizeToLong(arr.GetDimSize()); });
160 Arrangement expand_num_list;
161 Status status = expand_num_list.Init(expand_num_list_shape);
162 if (status != Status::SUCCESS) {
163 return nullptr;
164 }
165 auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
166 return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
167 }
168
ComputeReverseAccumulateSumInReverseOrder() const169 Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
170 Shape shape_accum;
171 int64_t size = 0;
172 for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
173 shape_accum.push_back(size);
174 size += *iter;
175 }
176 return shape_accum;
177 }
178
GetExpandedShapeByExpandListReserveLeft(const std::vector<Arrangement> & expand_list) const179 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
180 const std::vector<Arrangement> &expand_list) const {
181 if (expand_list.size() != GetDimSize()) {
182 return nullptr;
183 }
184 Shape new_shape;
185 for (size_t i = 0; i < expand_list.size(); i++) {
186 if (expand_list[i].GetDimSize() >= 1) {
187 int64_t size = 1;
188 for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
189 new_shape.push_back(expand_list[i].GetDimByIdx(k));
190 size *= expand_list[i].GetDimByIdx(k);
191 }
192 new_shape.push_back(GetDimByIdx(i) / size);
193 } else {
194 new_shape.push_back(GetDimByIdx(i));
195 }
196 }
197 Arrangement arrangement_new;
198 (void)arrangement_new.Init(new_shape);
199 return std::make_shared<Arrangement>(arrangement_new);
200 }
201
GetUnifiedShape(const Arrangement & in2) const202 std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
203 std::vector<int64_t> in1_accum;
204 Status status = ShapeToAccumulateProduct(array_, &in1_accum);
205 if (status != Status::SUCCESS) {
206 return nullptr;
207 }
208 std::vector<int64_t> in2_accum;
209 status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
210 if (status != Status::SUCCESS) {
211 return nullptr;
212 }
213 std::vector<int64_t> out_accum;
214 status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
215 if (status != Status::SUCCESS) {
216 return nullptr;
217 }
218 Shape out_shape;
219 status = AccumulateProductToShape(out_accum, &out_shape);
220 if (status != Status::SUCCESS) {
221 return nullptr;
222 }
223 Arrangement out;
224 status = out.Init(out_shape);
225 if (status != Status::SUCCESS) {
226 return nullptr;
227 }
228 return std::make_shared<Arrangement>(out);
229 }
230
GetSqueezeIdx() const231 std::vector<size_t> Arrangement::GetSqueezeIdx() const {
232 std::vector<size_t> out;
233 for (size_t i = 0; i < GetDimSize(); i++) {
234 if (GetDimByIdx(SizeToUlong(i)) == 1) {
235 out.push_back(i);
236 }
237 }
238 return out;
239 }
240
GetSqueezeArrangement() const241 Arrangement Arrangement::GetSqueezeArrangement() const {
242 Shape out_shape(array_.size());
243 auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; });
244 out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
245
246 // if all elements are 1, out_shape = {1}
247 if (out_shape.empty()) {
248 MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
249 out_shape.push_back(1);
250 }
251 Arrangement out;
252 (void)out.Init(out_shape);
253 return out;
254 }
255 } // namespace parallel
256 } // namespace mindspore
257