• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "frontend/parallel/tensor_layout/arrangement.h"
18 #include <algorithm>
19 #include <utility>
20 #include "utils/ms_utils.h"
21 #include "frontend/parallel/status.h"
22 #include "frontend/parallel/tensor_layout/shape_util.h"
23 #include "utils/convert_utils.h"
24 #include "utils/log_adapter.h"
25 
26 namespace mindspore {
27 namespace parallel {
Init(const Shape & array)28 Status Arrangement::Init(const Shape &array) {
29   Status status = Array::Init(array);
30   if (status != Status::SUCCESS) {
31     return Status::FAILED;
32   }
33   if (!IsValidArrangement()) {
34     MS_LOG(ERROR) << "invalid arrangement " << this->ToString();
35     return Status::FAILED;
36   }
37   ComputeSize();
38   return Status::SUCCESS;
39 }
40 
IsValidArrangement()41 bool Arrangement::IsValidArrangement() {
42   return !std::any_of(array_.begin(), array_.end(), [](int64_t value) { return value <= 0 && value != -1; });
43 }
44 
ComputeSize()45 void Arrangement::ComputeSize() {
46   size_ = 1;
47   for (auto &value : array_) {
48     size_ *= value;
49   }
50 }
51 
52 /*
53  * if GetDimSize() = 0, return []
54  * if value <= array_[0], return [value]
55  * if array_[0] < value <= size_[i], return [shape[0], shape[1], ..., shape[i-1], value/size_[i-1]],
56  * where size_[i-1] = shape[0] * shape[1] * ... * shape[i-1],
57  * if value > size_, return []
58  */
GetFrontElementByValue(int64_t value) const59 Shape Arrangement::GetFrontElementByValue(int64_t value) const {
60   Shape out;
61   if (GetDimSize() == 0) {
62     return out;
63   }
64   if (value <= size_) {
65     int64_t size = 1;
66     size_t shape_list_idx = 0;
67     while (size < value) {
68       size *= array_[shape_list_idx];
69       if (size <= value) {
70         out.push_back(array_[shape_list_idx]);
71       } else {
72         if (size == 0) {
73           MS_LOG(ERROR) << "The size is 0";
74           out.clear();
75           return out;
76         }
77         out.push_back(value * array_[shape_list_idx] / size);
78       }
79       shape_list_idx++;
80     }
81   }
82   return out;
83 }
84 
GetExpandedShapeByExpandListRemoveLeft(const std::vector<Arrangement> & expand_list) const85 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListRemoveLeft(
86   const std::vector<Arrangement> &expand_list) const {
87   if (expand_list.size() != GetDimSize()) {
88     return nullptr;
89   }
90   Shape new_shape;
91   for (size_t i = 0; i < expand_list.size(); i++) {
92     Shape expand_shape = expand_list[i].GetFrontElementByValue(GetDimByIdx(i));
93     if (expand_shape.empty()) {
94       new_shape.push_back(GetDimByIdx(i));
95     } else {
96       (void)new_shape.insert(new_shape.end(), expand_shape.begin(), expand_shape.end());
97     }
98   }
99   Arrangement arrangement_new;
100   (void)arrangement_new.Init(new_shape);
101   return std::make_shared<Arrangement>(arrangement_new);
102 }
103 
104 /*
105  *  example:
106  *    expand_shape = [4, 2, 2, 2]
107  *    array_ = [8, 4],
108  *    arrangement_list = [[4, 2], [2, 2]]
109  */
GetExpandShapeList(const Arrangement & expand_shape) const110 std::shared_ptr<std::vector<Arrangement>> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const {
111   int64_t size = 1;
112   size_t ind = 0;
113   std::vector<Arrangement> arrangement_list;
114   Shape shape;
115   for (size_t i = 0; i < expand_shape.GetDimSize(); i++) {
116     size *= expand_shape.GetDimByIdx(i);
117     if (size > GetDimByIdx(ind)) {
118       MS_LOG(ERROR) << "invalid expand_shape";
119       return nullptr;
120     } else if (size < GetDimByIdx(ind)) {
121       shape.push_back(expand_shape.GetDimByIdx(i));
122       continue;
123     } else {
124       shape.push_back(expand_shape.GetDimByIdx(i));
125       Arrangement arrangement;
126       (void)arrangement.Init(shape);
127       arrangement_list.push_back(arrangement);
128       shape.clear();
129       ind++;
130       size = 1;
131     }
132   }
133   if (ind != GetDimSize()) {
134     MS_LOG(ERROR) << "invalid expand_shape";
135     return nullptr;
136   }
137   auto arrangement_new = std::make_shared<std::vector<Arrangement>>(arrangement_list);
138   return arrangement_new;
139 }
140 
GetExpandShapeListPair(const Arrangement & expand_shape) const141 std::shared_ptr<std::pair<std::vector<Arrangement>, Arrangement>> Arrangement::GetExpandShapeListPair(
142   const Arrangement &expand_shape) const {
143   std::shared_ptr<std::vector<Arrangement>> expand_shape_list_ptr = GetExpandShapeList(expand_shape);
144   if (expand_shape_list_ptr == nullptr) {
145     return nullptr;
146   }
147   Shape expand_num_list_shape;
148   (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(),
149                        std::back_inserter(expand_num_list_shape),
150                        [](const Arrangement &arr) { return SizeToLong(arr.GetDimSize()); });
151   Arrangement expand_num_list;
152   Status status = expand_num_list.Init(expand_num_list_shape);
153   if (status != Status::SUCCESS) {
154     return nullptr;
155   }
156   auto out_value = std::make_pair(*expand_shape_list_ptr, expand_num_list);
157   return std::make_shared<std::pair<std::vector<Arrangement>, Arrangement>>(out_value);
158 }
159 
ComputeReverseAccumulateSumInReverseOrder() const160 Shape Arrangement::ComputeReverseAccumulateSumInReverseOrder() const {
161   Shape shape_accum;
162   int64_t size = 0;
163   for (auto iter = array_.end() - 1; iter >= array_.begin(); --iter) {
164     shape_accum.push_back(size);
165     size += *iter;
166   }
167   return shape_accum;
168 }
169 
GetExpandedShapeByExpandListReserveLeft(const std::vector<Arrangement> & expand_list) const170 std::shared_ptr<Arrangement> Arrangement::GetExpandedShapeByExpandListReserveLeft(
171   const std::vector<Arrangement> &expand_list) const {
172   if (expand_list.size() != GetDimSize()) {
173     return nullptr;
174   }
175   Shape new_shape;
176   for (size_t i = 0; i < expand_list.size(); i++) {
177     if (expand_list[i].GetDimSize() >= 1) {
178       int64_t size = 1;
179       for (size_t k = 0; k < expand_list[i].GetDimSize() - 1; k++) {
180         new_shape.push_back(expand_list[i].GetDimByIdx(k));
181         size *= expand_list[i].GetDimByIdx(k);
182       }
183       new_shape.push_back(GetDimByIdx(i) / size);
184     } else {
185       new_shape.push_back(GetDimByIdx(i));
186     }
187   }
188   Arrangement arrangement_new;
189   (void)arrangement_new.Init(new_shape);
190   return std::make_shared<Arrangement>(arrangement_new);
191 }
192 
GetUnifiedShape(const Arrangement & in2) const193 std::shared_ptr<Arrangement> Arrangement::GetUnifiedShape(const Arrangement &in2) const {
194   std::vector<int64_t> in1_accum;
195   Status status = ShapeToAccumulateProduct(array_, &in1_accum);
196   if (status != Status::SUCCESS) {
197     return nullptr;
198   }
199   std::vector<int64_t> in2_accum;
200   status = ShapeToAccumulateProduct(in2.array(), &in2_accum);
201   if (status != Status::SUCCESS) {
202     return nullptr;
203   }
204   std::vector<int64_t> out_accum;
205   status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
206   if (status != Status::SUCCESS) {
207     return nullptr;
208   }
209   Shape out_shape;
210   status = AccumulateProductToShape(out_accum, &out_shape);
211   if (status != Status::SUCCESS) {
212     return nullptr;
213   }
214   Arrangement out;
215   status = out.Init(out_shape);
216   if (status != Status::SUCCESS) {
217     return nullptr;
218   }
219   return std::make_shared<Arrangement>(out);
220 }
221 
GetSqueezeIdx() const222 std::vector<size_t> Arrangement::GetSqueezeIdx() const {
223   std::vector<size_t> out;
224   for (size_t i = 0; i < GetDimSize(); i++) {
225     if (GetDimByIdx(SizeToUlong(i)) == 1) {
226       out.push_back(i);
227     }
228   }
229   return out;
230 }
231 
GetSqueezeArrangement() const232 Arrangement Arrangement::GetSqueezeArrangement() const {
233   Shape out_shape(array_.size());
234   auto it = std::copy_if(array_.begin(), array_.end(), out_shape.begin(), [](int64_t value) { return value != 1; });
235   out_shape.resize(LongToSize(std::distance(out_shape.begin(), it)));
236 
237   // if all elements are 1, out_shape = {1}
238   if (out_shape.empty()) {
239     MS_LOG(ERROR) << "out_shape size is 0, this may not happen under current situation";
240     out_shape.push_back(1);
241   }
242   Arrangement out;
243   (void)out.Init(out_shape);
244   return out;
245 }
246 }  // namespace parallel
247 }  // namespace mindspore
248