• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
#include "frontend/parallel/tensor_layout/shape_util.h"

#include <limits>

#include "frontend/parallel/status.h"
#include "utils/log_adapter.h"
20 
21 namespace mindspore {
22 namespace parallel {
23 /*
24  * example:
25  * shape = [2, 8, 32]
26  * shape_accum = [2, 2 * 8, 2 * 8 * 32]
27  */
ShapeToAccumulateProduct(const Shape & shape,Shape * shape_accum)28 Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum) {
29   MS_EXCEPTION_IF_NULL(shape_accum);
30   shape_accum->clear();
31   int64_t size = 1;
32   for (auto iter = shape.begin(); iter < shape.end(); ++iter) {
33     size *= *iter;
34     if (size <= 0) {
35       MS_LOG(ERROR) << "element of shape should not be zero";
36       return Status::FAILED;
37     }
38     shape_accum->push_back(size);
39   }
40   return Status::SUCCESS;
41 }
42 
43 /*
44  * example:
45  * shape = [2, 8, 32]
46  * shape_accum = [2 * 8 * 32, 8 * 32, 32]
47  *
48  */
ShapeToAccumulateProductReverse(const Shape & shape,Shape * shape_accum)49 Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum) {
50   MS_EXCEPTION_IF_NULL(shape_accum);
51   shape_accum->clear();
52   int64_t size = 1;
53   for (auto iter = shape.end() - 1; iter >= shape.begin(); --iter) {
54     size *= *iter;
55     if (size <= 0) {
56       MS_LOG(ERROR) << "element of shape should not be zero";
57       return Status::FAILED;
58     }
59     (void)shape_accum->insert(shape_accum->begin(), size);
60   }
61   return Status::SUCCESS;
62 }
63 
64 /*
65  * example:
66  * shape_accum = [2, 2 * 8, 2 * 8 * 32]
67  * shape = [2, 8, 32]
68  *
69  */
AccumulateProductToShape(const Shape & shape_accum,Shape * shape)70 Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape) {
71   MS_EXCEPTION_IF_NULL(shape);
72   shape->clear();
73   int64_t value = 1;
74   for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) {
75     if ((*iter) == 0) {
76       MS_LOG(ERROR) << "element of shape_accum should not be zero";
77       return Status::FAILED;
78     }
79     if ((*iter) % value != 0) {
80       MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
81       return Status::FAILED;
82     }
83     shape->push_back(static_cast<int64_t>((*iter) / value));
84     value = (*iter);
85   }
86   return Status::SUCCESS;
87 }
88 
89 /*
90  * example:
91  * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
92  * shape = [2, 8, 32]
93  */
AccumulateProductReverseToShape(const Shape & shape_accum_reverse,Shape * shape)94 Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape) {
95   MS_EXCEPTION_IF_NULL(shape);
96   shape->clear();
97   int64_t value = 1;
98   for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) {
99     if (*iter == 0) {
100       MS_LOG(WARNING) << "element of shape_accum should not be zero";
101       return Status::FAILED;
102     }
103     if ((*iter) % value != 0) {
104       MS_LOG(DEBUG) << "shape_accum is not a accumulate product in ascending order";
105       return Status::FAILED;
106     }
107     (void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
108     value = *iter;
109   }
110   return Status::SUCCESS;
111 }
112 
113 /*
114  * example1:
115  * in1 = [2, 8]
116  * in2 = [4, 8]
117  * *out = [2, 4, 8]
118  *
119  * example2:
120  * in1 = [2, 4, 16]
121  * in2 = [8, 16]
122  * *out = [2, 4, 8, 16]
123  */
UnifyAccumulateProduct(const Shape & in1_accum,const Shape & in2_accum,Shape * out_accum)124 Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum) {
125   MS_EXCEPTION_IF_NULL(out_accum);
126   out_accum->clear();
127   auto in1_iter = in1_accum.begin();
128   auto in2_iter = in2_accum.begin();
129   while ((in1_iter < in1_accum.end()) || (in2_iter < in2_accum.end())) {
130     if ((*in1_iter <= 0) || (*in2_iter <= 0)) {
131       MS_LOG(ERROR) << "element of in1 and in2 must be larger than zero";
132       return Status::FAILED;
133     }
134     if (*in1_iter < *in2_iter) {
135       out_accum->push_back(*in1_iter);
136       ++in1_iter;
137       continue;
138     } else if (*in1_iter == *in2_iter) {
139       out_accum->push_back(*in1_iter);
140       ++in1_iter;
141       ++in2_iter;
142     } else {
143       out_accum->push_back(*in2_iter);
144       ++in2_iter;
145     }
146   }
147   if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) {
148     MS_LOG(ERROR) << "last element of in1 and in2 must be equal";
149     return Status::FAILED;
150   }
151   return Status::SUCCESS;
152 }
153 
154 /*
155  * example:
156  * in1 = [8, 4]
157  * in2 = [2, 16]
158  * out = [2, 4, 4]
159  */
UnifyShape(const Shape & in1,const Shape & in2,Shape * out)160 Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out) {
161   MS_EXCEPTION_IF_NULL(out);
162   Shape in1_accum;
163   Status status = ShapeToAccumulateProduct(in1, &in1_accum);
164   if (status != Status::SUCCESS) {
165     return status;
166   }
167   Shape in2_accum;
168   status = ShapeToAccumulateProduct(in2, &in2_accum);
169   if (status != Status::SUCCESS) {
170     return status;
171   }
172   Shape out_accum;
173   status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
174   if (status != Status::SUCCESS) {
175     return status;
176   }
177   status = AccumulateProductToShape(out_accum, out);
178   if (status != Status::SUCCESS) {
179     return status;
180   }
181   return status;
182 }
183 
184 /*
185  * example1:
186  * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
187  * expand_accum_reverse = [2 * 8 * 32, 32, 8]
188  * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8]
189  *
190  * example2:
191  * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
192  * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
193  * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
194  */
ExpandAccumulateProduct(const Shape & in_accum_reverse,const Shape & expand_accum_reverse,Shape * out_accum_reverse)195 Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse,
196                                Shape *out_accum_reverse) {
197   MS_EXCEPTION_IF_NULL(out_accum_reverse);
198   out_accum_reverse->clear();
199   auto in_riter = in_accum_reverse.rbegin();
200   auto expand_riter = expand_accum_reverse.rbegin();
201   while (expand_riter != expand_accum_reverse.rend()) {
202     if (in_riter == in_accum_reverse.rend()) {
203       MS_LOG(ERROR) << "invalid ExpandAccumProd inputs";
204       return Status::FAILED;
205     }
206     if (*in_riter > *expand_riter) {
207       (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
208       ++expand_riter;
209     } else if (*in_riter == *expand_riter) {
210       (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
211       ++in_riter;
212       ++expand_riter;
213     } else {
214       (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
215       ++in_riter;
216     }
217   }
218   while (in_riter != in_accum_reverse.rend()) {
219     (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
220     ++in_riter;
221   }
222   return Status::SUCCESS;
223 }
224 
225 /*
226  * example1:
227  * in = [2, 8, 32]
228  * expand = [16, 4, 8]
229  * out = [2, 8, 4, 8]
230  *
231  * example2:
232  * in = [2, 8, 32]
233  * expand = [2, 4, 8]
234  * out = [2, 4, 2, 4, 8]
235  */
ExpandShape(const Shape & in,const Shape & expand,Shape * out)236 Status ExpandShape(const Shape &in, const Shape &expand, Shape *out) {
237   MS_EXCEPTION_IF_NULL(out);
238   Shape in_accum_reverse;
239   Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse);
240   if (status != Status::SUCCESS) {
241     return status;
242   }
243   Shape expand_accum_reverse;
244   status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse);
245   if (status != Status::SUCCESS) {
246     return status;
247   }
248   Shape out_accum_reverse;
249   status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse);
250   if (status != Status::SUCCESS) {
251     return status;
252   }
253   status = AccumulateProductReverseToShape(out_accum_reverse, out);
254   if (status != Status::SUCCESS) {
255     return status;
256   }
257   return status;
258 }
259 }  // namespace parallel
260 }  // namespace mindspore
261