1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "frontend/parallel/tensor_layout/shape_util.h"
18 #include "frontend/parallel/status.h"
19 #include "utils/log_adapter.h"
20
21 namespace mindspore {
22 namespace parallel {
23 /*
24 * example:
25 * shape = [2, 8, 32]
26 * shape_accum = [2, 2 * 8, 2 * 8 * 32]
27 */
ShapeToAccumulateProduct(const Shape & shape,Shape * shape_accum)28 Status ShapeToAccumulateProduct(const Shape &shape, Shape *shape_accum) {
29 MS_EXCEPTION_IF_NULL(shape_accum);
30 shape_accum->clear();
31 int64_t size = 1;
32 for (auto iter = shape.begin(); iter < shape.end(); ++iter) {
33 size *= *iter;
34 if (size <= 0) {
35 MS_LOG(ERROR) << "element of shape should not be zero";
36 return Status::FAILED;
37 }
38 shape_accum->push_back(size);
39 }
40 return Status::SUCCESS;
41 }
42
43 /*
44 * example:
45 * shape = [2, 8, 32]
46 * shape_accum = [2 * 8 * 32, 8 * 32, 32]
47 *
48 */
ShapeToAccumulateProductReverse(const Shape & shape,Shape * shape_accum)49 Status ShapeToAccumulateProductReverse(const Shape &shape, Shape *shape_accum) {
50 MS_EXCEPTION_IF_NULL(shape_accum);
51 shape_accum->clear();
52 int64_t size = 1;
53 for (auto iter = shape.end() - 1; iter >= shape.begin(); --iter) {
54 size *= *iter;
55 if (size <= 0) {
56 MS_LOG(ERROR) << "element of shape should not be zero";
57 return Status::FAILED;
58 }
59 (void)shape_accum->insert(shape_accum->begin(), size);
60 }
61 return Status::SUCCESS;
62 }
63
64 /*
65 * example:
66 * shape_accum = [2, 2 * 8, 2 * 8 * 32]
67 * shape = [2, 8, 32]
68 *
69 */
AccumulateProductToShape(const Shape & shape_accum,Shape * shape)70 Status AccumulateProductToShape(const Shape &shape_accum, Shape *shape) {
71 MS_EXCEPTION_IF_NULL(shape);
72 shape->clear();
73 int64_t value = 1;
74 for (auto iter = shape_accum.begin(); iter < shape_accum.end(); ++iter) {
75 if ((*iter) == 0) {
76 MS_LOG(ERROR) << "element of shape_accum should not be zero";
77 return Status::FAILED;
78 }
79 if ((*iter) % value != 0) {
80 MS_LOG(ERROR) << "shape_accum is not a accumulate product in ascending order";
81 return Status::FAILED;
82 }
83 shape->push_back(static_cast<int64_t>((*iter) / value));
84 value = (*iter);
85 }
86 return Status::SUCCESS;
87 }
88
89 /*
90 * example:
91 * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
92 * shape = [2, 8, 32]
93 */
AccumulateProductReverseToShape(const Shape & shape_accum_reverse,Shape * shape)94 Status AccumulateProductReverseToShape(const Shape &shape_accum_reverse, Shape *shape) {
95 MS_EXCEPTION_IF_NULL(shape);
96 shape->clear();
97 int64_t value = 1;
98 for (auto iter = shape_accum_reverse.end() - 1; iter >= shape_accum_reverse.begin(); --iter) {
99 if (*iter == 0) {
100 MS_LOG(WARNING) << "element of shape_accum should not be zero";
101 return Status::FAILED;
102 }
103 if ((*iter) % value != 0) {
104 MS_LOG(DEBUG) << "shape_accum is not a accumulate product in ascending order";
105 return Status::FAILED;
106 }
107 (void)shape->insert(shape->begin(), static_cast<int64_t>((*iter) / value));
108 value = *iter;
109 }
110 return Status::SUCCESS;
111 }
112
113 /*
114 * example1:
115 * in1 = [2, 8]
116 * in2 = [4, 8]
117 * *out = [2, 4, 8]
118 *
119 * example2:
120 * in1 = [2, 4, 16]
121 * in2 = [8, 16]
122 * *out = [2, 4, 8, 16]
123 */
UnifyAccumulateProduct(const Shape & in1_accum,const Shape & in2_accum,Shape * out_accum)124 Status UnifyAccumulateProduct(const Shape &in1_accum, const Shape &in2_accum, Shape *out_accum) {
125 MS_EXCEPTION_IF_NULL(out_accum);
126 out_accum->clear();
127 auto in1_iter = in1_accum.begin();
128 auto in2_iter = in2_accum.begin();
129 while ((in1_iter < in1_accum.end()) || (in2_iter < in2_accum.end())) {
130 if ((*in1_iter <= 0) || (*in2_iter <= 0)) {
131 MS_LOG(ERROR) << "element of in1 and in2 must be larger than zero";
132 return Status::FAILED;
133 }
134 if (*in1_iter < *in2_iter) {
135 out_accum->push_back(*in1_iter);
136 ++in1_iter;
137 continue;
138 } else if (*in1_iter == *in2_iter) {
139 out_accum->push_back(*in1_iter);
140 ++in1_iter;
141 ++in2_iter;
142 } else {
143 out_accum->push_back(*in2_iter);
144 ++in2_iter;
145 }
146 }
147 if ((in1_iter != in1_accum.end()) || (in2_iter != in2_accum.end())) {
148 MS_LOG(ERROR) << "last element of in1 and in2 must be equal";
149 return Status::FAILED;
150 }
151 return Status::SUCCESS;
152 }
153
154 /*
155 * example:
156 * in1 = [8, 4]
157 * in2 = [2, 16]
158 * out = [2, 4, 4]
159 */
UnifyShape(const Shape & in1,const Shape & in2,Shape * out)160 Status UnifyShape(const Shape &in1, const Shape &in2, Shape *out) {
161 MS_EXCEPTION_IF_NULL(out);
162 Shape in1_accum;
163 Status status = ShapeToAccumulateProduct(in1, &in1_accum);
164 if (status != Status::SUCCESS) {
165 return status;
166 }
167 Shape in2_accum;
168 status = ShapeToAccumulateProduct(in2, &in2_accum);
169 if (status != Status::SUCCESS) {
170 return status;
171 }
172 Shape out_accum;
173 status = UnifyAccumulateProduct(in1_accum, in2_accum, &out_accum);
174 if (status != Status::SUCCESS) {
175 return status;
176 }
177 status = AccumulateProductToShape(out_accum, out);
178 if (status != Status::SUCCESS) {
179 return status;
180 }
181 return status;
182 }
183
184 /*
185 * example1:
186 * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
187 * expand_accum_reverse = [2 * 8 * 32, 32, 8]
188 * out_accum_reverse = [2 * 8 * 4 * 8, 8 * 4 * 8, 4 * 8, 8]
189 *
190 * example2:
191 * in_accum_reverse = [2 * 8 * 32, 8 * 32, 32]
192 * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8]
193 * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8]
194 */
ExpandAccumulateProduct(const Shape & in_accum_reverse,const Shape & expand_accum_reverse,Shape * out_accum_reverse)195 Status ExpandAccumulateProduct(const Shape &in_accum_reverse, const Shape &expand_accum_reverse,
196 Shape *out_accum_reverse) {
197 MS_EXCEPTION_IF_NULL(out_accum_reverse);
198 out_accum_reverse->clear();
199 auto in_riter = in_accum_reverse.rbegin();
200 auto expand_riter = expand_accum_reverse.rbegin();
201 while (expand_riter != expand_accum_reverse.rend()) {
202 if (in_riter == in_accum_reverse.rend()) {
203 MS_LOG(ERROR) << "invalid ExpandAccumProd inputs";
204 return Status::FAILED;
205 }
206 if (*in_riter > *expand_riter) {
207 (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
208 ++expand_riter;
209 } else if (*in_riter == *expand_riter) {
210 (void)out_accum_reverse->insert(out_accum_reverse->begin(), *expand_riter);
211 ++in_riter;
212 ++expand_riter;
213 } else {
214 (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
215 ++in_riter;
216 }
217 }
218 while (in_riter != in_accum_reverse.rend()) {
219 (void)out_accum_reverse->insert(out_accum_reverse->begin(), *in_riter);
220 ++in_riter;
221 }
222 return Status::SUCCESS;
223 }
224
225 /*
226 * example1:
227 * in = [2, 8, 32]
228 * expand = [16, 4, 8]
229 * out = [2, 8, 4, 8]
230 *
231 * example2:
232 * in = [2, 8, 32]
233 * expand = [2, 4, 8]
234 * out = [2, 4, 2, 4, 8]
235 */
ExpandShape(const Shape & in,const Shape & expand,Shape * out)236 Status ExpandShape(const Shape &in, const Shape &expand, Shape *out) {
237 MS_EXCEPTION_IF_NULL(out);
238 Shape in_accum_reverse;
239 Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse);
240 if (status != Status::SUCCESS) {
241 return status;
242 }
243 Shape expand_accum_reverse;
244 status = ShapeToAccumulateProductReverse(expand, &expand_accum_reverse);
245 if (status != Status::SUCCESS) {
246 return status;
247 }
248 Shape out_accum_reverse;
249 status = ExpandAccumulateProduct(in_accum_reverse, expand_accum_reverse, &out_accum_reverse);
250 if (status != Status::SUCCESS) {
251 return status;
252 }
253 status = AccumulateProductReverseToShape(out_accum_reverse, out);
254 if (status != Status::SUCCESS) {
255 return status;
256 }
257 return status;
258 }
259 } // namespace parallel
260 } // namespace mindspore
261