• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "frontend/parallel/tensor_layout/arrangement.h"
18 
19 #include <algorithm>
20 #include <utility>
21 
22 #include "frontend/parallel/status.h"
23 #include "frontend/parallel/tensor_layout/shape_util.h"
24 #include "include/common/utils/convert_utils.h"
25 #include "utils/log_adapter.h"
26 
27 namespace mindspore {
28 namespace parallel {
Init(const Shape & array)29 Status Arrangement::Init(const Shape &array) {
30   Status status = Array::Init(array);
31   if (status != Status::SUCCESS) {
32     return Status::FAILED;
33   }
34   if (!IsValidArrangement()) {
35     MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
36     return Status::FAILED;
37   }
38   ComputeSize();
39   return Status::SUCCESS;
40 }
41 
UpdateTensorShape(size_t index,int64_t update_value)42 Status Arrangement::UpdateTensorShape(size_t index, int64_t update_value) {
43   if (index >= this->array_.size()) {
44     return Status::FAILED;
45   }
46   this->array_[index] = update_value;
47   return Status::SUCCESS;
48 }
49 
IsValidArrangement()50 bool Arrangement::IsValidArrangement() {
51   return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0 && value != -1; });
52 }
53 
ComputeSize()54 void Arrangement::ComputeSize() {
55   size_ = 1;
56   for (auto &value : array_) {
57     size_ *= value;
58   }
59 }
60 
61 /*
62  * if GetDimSize() = 0, return []
63  * if value <= array_[0], return [value]
64  * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
65  * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
66  * if value > size_, return []
67  */
GetFrontElementByValue(int64_t value) const68 Shape Arrangement::GetFrontElementByValue(int64_t value) const {
69   Shape out;
70   if (GetDimSize() == 0) {
71     return out;
72   }
73   if (value <= size_) {
74     int64_t size = 1;
75     size_t shape_list_idx = 0;
76     while (size < value) {
77       size *= array_[shape_list_idx];
78       if (size <= value) {
79         out.push_back(array_[shape_list_idx]);
80       } else {
81         if (size == 0) {
82           MS_LOG(ERROR) << "The size is 0";
83           out.clear();
84           return out;
85         }
86         out.push_back(value * array_[shape_list_idx] / size);
87       }
88       shape_list_idx++;
89     }
90   }
91   return out;
92 }
93 
GetExpandedShapeByExpandListRemoveLeft(const std::vector<Arrangement> & expand_list) const94 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
95   const std::vector<Arrangement> &expand_list) const {
96   if (expand_list.size() != GetDimSize()) {
97     return nullptr;
98   }
99   Shape new_shape;
100   for (size_t i = 0; i < expand_list.size(); i++) {
101     Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
102     if (expand_shape.empty()) {
103       new_shape.push_back(GetDimByIdx(i));
104     } else {
105       (void)new_shape.insert(new_shape.cend(), expand_shape.cbegin(), expand_shape.cend());
106     }
107   }
108   Arrangement arrangement_new;
109   (void)arrangement_new.Init(new_shape);
110   return std::make_shared<Arrangement>(arrangement_new);
111 }
112 
113 /*
114  *  example:
115  *    expand_shape = [4, 2, 2, 2]
116  *    array_ = [8, 4],
117  *    arrangement_list = [[4, 2], [2, 2]]
118  */
GetExpandShapeList(const Arrangement & expand_shape) const119 std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
120   int64_t size = 1;
121   size_t ind = 0;
122   std::vector<Arrangement> arrangement_list;
123   Shape shape;
124   for (size_t i = 0; i < expand_shape.GetDimSize(); i++) {
125     size *= expand_shape.GetDimByIdx(i);
126     if (size > GetDimByIdx(ind)) {
127       MS_LOG(INFO) << "invalid expand_shape:" << expand_shape.array();
128       return nullptr;
129     } else if (size < GetDimByIdx(ind)) {
130       shape.push_back(expand_shape.GetDimByIdx(i));
131       continue;
132     } else {
133       shape.push_back(expand_shape.GetDimByIdx(i));
134       Arrangement arrangement;
135       (void)arrangement.Init(shape);
136       arrangement_list.push_back(arrangement);
137       shape.clear();
138       ind++;
139       size = 1;
140     }
141   }
142   if (ind != GetDimSize()) {
143     MS_LOG(INFO) << "invalid expand_shape:" << expand_shape.array();
144     return nullptr;
145   }
146   auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
147   return arrangement_new;
148 }
149 
GetExpandShapeListPair(const Arrangement & expand_shape) const150 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
151   const Arrangement &expand_shape) const {
152   std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
153   if (expand_shape_list_ptr == nullptr) {
154     return nullptr;
155   }
156   Shape expand_num_list_shape;
157   (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
158                        std::back_inserter(expand_num_list_shape),
159                        [](const Arrangement &arr) { return SizeToLong(arr.GetDimSize()); });
160   Arrangement expand_num_list;
161   Status status = expand_num_list.Init(expand_num_list_shape);
162   if (status != Status::SUCCESS) {
163     return nullptr;
164   }
165   auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
166   return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
167 }
168 
ComputeReverseAccumulateSumInReverseOrder() const169 Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
170   Shape shape_accum;
171   int64_t size = 0;
172   for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
173     shape_accum.push_back(size);
174     size += *iter;
175   }
176   return shape_accum;
177 }
178 
GetExpandedShapeByExpandListReserveLeft(const std::vector<Arrangement> & expand_list) const179 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
180   const std::vector<Arrangement> &expand_list) const {
181   if (expand_list.size() != GetDimSize()) {
182     return nullptr;
183   }
184   Shape new_shape;
185   for (size_t i = 0; i < expand_list.size(); i++) {
186     if (expand_list[i].GetDimSize() >= 1) {
187       int64_t size = 1;
188       for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
189         new_shape.push_back(expand_list[i].GetDimByIdx(k));
190         size *= expand_list[i].GetDimByIdx(k);
191       }
192       new_shape.push_back(GetDimByIdx(i) / size);
193     } else {
194       new_shape.push_back(GetDimByIdx(i));
195     }
196   }
197   Arrangement arrangement_new;
198   (void)arrangement_new.Init(new_shape);
199   return std::make_shared<Arrangement>(arrangement_new);
200 }
201 
GetUnifiedShape(const Arrangement & in2) const202 std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
203   std::vector<int64_t> in1_accum;
204   Status status = ShapeToAccumulateProduct(array_, &in1_accum);
205   if (status != Status::SUCCESS) {
206     return nullptr;
207   }
208   std::vector<int64_t> in2_accum;
209   status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
210   if (status != Status::SUCCESS) {
211     return nullptr;
212   }
213   std::vector<int64_t> out_accum;
214   status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
215   if (status != Status::SUCCESS) {
216     return nullptr;
217   }
218   Shape out_shape;
219   status = AccumulateProductToShape(out_accum, &out_shape);
220   if (status != Status::SUCCESS) {
221     return nullptr;
222   }
223   Arrangement out;
224   status = out.Init(out_shape);
225   if (status != Status::SUCCESS) {
226     return nullptr;
227   }
228   return std::make_shared<Arrangement>(out);
229 }
230 
GetSqueezeIdx() const231 std::vector<size_t> Arrangement::GetSqueezeIdx() const {
232   std::vector<size_t> out;
233   for (size_t i = 0; i < GetDimSize(); i++) {
234     if (GetDimByIdx(SizeToUlong(i)) == 1) {
235       out.push_back(i);
236     }
237   }
238   return out;
239 }
240 
GetSqueezeArrangement() const241 Arrangement Arrangement::GetSqueezeArrangement() const {
242   Shape out_shape(array_.size());
243   auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; });
244   out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
245 
246   // if all elements are 1, out_shape = {1}
247   if (out_shape.empty()) {
248     MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
249     out_shape.push_back(1);
250   }
251   Arrangement out;
252   (void)out.Init(out_shape);
253   return out;
254 }
255 }  // namespace parallel
256 }  // namespace mindspore
257