1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <vector>
#include "ops/fusion/full_connection.h"
#include "ops/reshape.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/softmax.h"
#include "ops/op_utils.h"
#include "include/errorcode.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "backend/optimizer/common/helper.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
31
32 namespace mindspore::opt {
33 namespace {
// Input slot of a SliceFusion cnode holding the constant `begin` operand.
constexpr int SliceBeginIndex = 2;
// Input slot of a SliceFusion cnode holding the constant `size` operand.
constexpr int SliceSizeIndex = 3;
// Presumably the input count of binary arithmetic ops (name-based; not used in the visible part of the file).
constexpr int kArithmeticInputNum = 2;
// Counter appended to the names of nodes created by this pass; bumped after every use to keep names unique.
int node_name_index = 0;
GetSliceBeginAndSize(const CNodePtr & cnode,const int index)38 std::vector<int> GetSliceBeginAndSize(const CNodePtr &cnode, const int index) {
39 MS_ASSERT(cnode != nullptr);
40 std::vector<int> content;
41 if (index != SliceBeginIndex && index != SliceSizeIndex && cnode->size() != 4) {
42 return content;
43 }
44 auto node = cnode->input(index);
45 if (node == nullptr) {
46 return content;
47 }
48 auto param_node = node->cast<ParameterPtr>();
49 if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) {
50 return content;
51 }
52 auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
53 if (tensor_info == nullptr) {
54 return content;
55 }
56 content.resize(tensor_info->DataSize());
57 if (memcpy_s(content.data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
58 MS_LOG(ERROR) << "memcpy data failed.";
59 return {};
60 }
61 return content;
62 }
63
GetCNodeInputShape(const CNodePtr & cnode,size_t index=1)64 std::vector<int64_t> GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) {
65 MS_ASSERT(cnode != nullptr);
66 std::vector<int64_t> empty_shape;
67 if (index < 1 || cnode->inputs().size() <= index) {
68 MS_LOG(ERROR) << "out of index";
69 return empty_shape;
70 }
71 auto abstract = GetCNodeInputAbstract(cnode, index);
72 if (abstract == nullptr) {
73 MS_LOG(ERROR) << "Abstract of CNode is nullptr";
74 return empty_shape;
75 }
76 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
77 MS_LOG(DEBUG) << "abstract is not AbstractTensor";
78 return empty_shape;
79 }
80 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
81 MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
82 return abstract_tensor->shape()->shape();
83 }
84
GetDefaultParamShape(const ParameterPtr & param)85 std::vector<int64_t> GetDefaultParamShape(const ParameterPtr ¶m) {
86 MS_ASSERT(param != nullptr);
87 MS_ASSERT(param->has_default());
88 std::vector<int64_t> shape_vector;
89 auto default_param = param->default_param();
90 if (default_param == nullptr) {
91 MS_LOG(ERROR) << "default_param is nullptr";
92 return shape_vector;
93 }
94 if (!utils::isa<tensor::TensorPtr>(default_param)) {
95 MS_LOG(ERROR) << "default_param is not tensor::Tensor";
96 return shape_vector;
97 }
98 auto param_value_lite = utils::cast<tensor::TensorPtr>(default_param);
99 MS_ASSERT(param_value != nullptr);
100 auto shape = param_value_lite->shape();
101 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
102 [](const int val) { return static_cast<int64_t>(val); });
103 return shape_vector;
104 }
105
IsScalarNode(const AnfNodePtr & nodePtr)106 bool IsScalarNode(const AnfNodePtr &nodePtr) {
107 MS_ASSERT(nodePtr != nullptr);
108 if (utils::isa<ParameterPtr>(nodePtr) && nodePtr->cast<ParameterPtr>()->has_default()) {
109 auto tensor = utils::cast<tensor::TensorPtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
110 MS_ASSERT(tensor != nullptr);
111 auto shape = tensor->shape();
112 if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
113 return true;
114 }
115 }
116 return false;
117 }
118
GetSlice(const CNodePtr & cnode)119 std::shared_ptr<mindspore::ops::SliceFusion> GetSlice(const CNodePtr &cnode) {
120 if (cnode == nullptr) {
121 return nullptr;
122 }
123 return GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(cnode->input(0));
124 }
125
GetSoftmax(const CNodePtr & cnode)126 std::shared_ptr<mindspore::ops::Softmax> GetSoftmax(const CNodePtr &cnode) {
127 if (cnode == nullptr) {
128 return nullptr;
129 }
130 return GetValueNode<std::shared_ptr<mindspore::ops::Softmax>>(cnode->input(0));
131 }
132
GetReshape(const CNodePtr & cnode)133 std::shared_ptr<mindspore::ops::Reshape> GetReshape(const CNodePtr &cnode) {
134 if (cnode == nullptr) {
135 return nullptr;
136 }
137 return GetValueNode<std::shared_ptr<mindspore::ops::Reshape>>(cnode->input(0));
138 }
139
GetFc(const CNodePtr & cnode)140 std::shared_ptr<mindspore::ops::FullConnection> GetFc(const CNodePtr &cnode) {
141 if (cnode == nullptr) {
142 return nullptr;
143 }
144 return GetValueNode<std::shared_ptr<mindspore::ops::FullConnection>>(cnode->input(0));
145 }
146
GetTransposePerm(const CNodePtr & node)147 std::vector<int> GetTransposePerm(const CNodePtr &node) {
148 MS_ASSERT(node != nullptr);
149 std::vector<int> perm;
150 if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
151 return perm;
152 }
153 if (node->inputs().size() != 3) {
154 return perm;
155 }
156 auto perm_node = node->input(2);
157 if (!utils::isa<ParameterPtr>(perm_node)) {
158 return perm;
159 }
160 auto perm_param = perm_node->cast<ParameterPtr>();
161 MS_ASSERT(perm_param != nullptr);
162 if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
163 return perm;
164 }
165 auto perm_value = perm_param->default_param()->cast<tensor::TensorPtr>();
166 if (perm_value == nullptr) {
167 return perm;
168 }
169 MS_CHECK_TRUE_MSG(perm_value->shape().size() != 0, {}, "shape is empty");
170 perm.resize(perm_value->shape()[0]);
171 if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) {
172 MS_LOG(ERROR) << "memcpy failed.";
173 return {};
174 }
175 return perm;
176 }
177 } // namespace
178
ClearCNodeAbstractValue(const CNodePtr & cnode)179 void SlicePreposePass::ClearCNodeAbstractValue(const CNodePtr &cnode) {
180 MS_ASSERT(cnode != nullptr);
181 auto abstract = cnode->abstract();
182 MS_ASSERT(abstract != nullptr);
183 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
184 MS_LOG(DEBUG) << "Abstract of cnode is not abstract tensor, " << cnode->fullname_with_scope();
185 }
186 abstract->set_value(std::make_shared<AnyValue>());
187 }
188
SwapSliceWithPreceed(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode,const int index,const TransactionPtr & tr)189 STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
190 const CNodePtr &preceed_cnode, const int index,
191 const TransactionPtr &tr) {
192 MS_ASSERT(graph != nullptr);
193 MS_ASSERT(slice_cnode != nullptr);
194 MS_ASSERT(preceed_cnode != nullptr);
195 if (slice_cnode->input(1) != preceed_cnode) {
196 MS_LOG(ERROR) << "proceed node must be slice node's direct parent";
197 return RET_ERROR;
198 }
199 if (IsMultiOutputTensors(graph, preceed_cnode)) {
200 MS_LOG(ERROR) << "proceed node referenced by multi nodes not support swap";
201 return RET_ERROR;
202 }
203 auto manager = graph->manager();
204 if (manager == nullptr) {
205 MS_LOG(ERROR) << "manager is nullptr";
206 return RET_ERROR;
207 }
208 auto node_users = manager->node_users()[slice_cnode];
209 if (tr != nullptr) { // do swap with transaction
210 for (auto &node_user : node_users) {
211 tr->SetEdge(node_user.first, node_user.second, preceed_cnode);
212 }
213 tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
214 tr->SetEdge(preceed_cnode, index, slice_cnode);
215 } else {
216 for (auto &node_user : node_users) {
217 manager->SetEdge(node_user.first, node_user.second, preceed_cnode);
218 }
219 manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
220 manager->SetEdge(preceed_cnode, index, slice_cnode);
221 }
222 return RET_OK;
223 }
224
CreateSliceValueNode(const std::vector<int64_t> & axes)225 ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
226 MS_ASSERT(graph != nullptr);
227 MS_ASSERT(slice_cnode != nullptr);
228 auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
229 MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice is nullptr");
230 new_slice->set_axes(axes);
231 ValueNodePtr value_node = NewValueNode(new_slice);
232 MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
233 return value_node;
234 }
235
CopySliceValueNode(const CNodePtr & slice_cnode)236 ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
237 MS_ASSERT(graph != nullptr);
238 MS_ASSERT(slice_cnode != nullptr);
239 auto slice_c = GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(slice_cnode->input(0));
240 if (slice_c == nullptr) {
241 MS_LOG(ERROR) << "slice node is nullptr";
242 return nullptr;
243 }
244 auto new_slice_c = std::make_shared<mindspore::ops::SliceFusion>();
245 MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
246 new_slice_c->set_axes(slice_c->get_axes());
247 ValueNodePtr value_node = NewValueNode(new_slice_c);
248 MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
249 return value_node;
250 }
251
/**
 * Creates a new cnode from `inputs` and redirects `preceed_cnode`'s `index`-th input to it.
 *
 * @param graph graph that owns the new cnode.
 * @param inputs inputs of the new (slice) cnode.
 * @param preceed_cnode node whose `index`-th input is redirected to the new cnode.
 * @param index input slot of `preceed_cnode` to redirect.
 * @param tr transaction used for the edge edit; when nullptr the edit is applied directly through
 *           the graph's manager.
 * @return the new cnode, or nullptr when it cannot be created or no transaction/manager is
 *         available.
 */
CNodePtr SlicePreposePass::InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
                                       const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  auto manager = graph->manager();
  // Check the edge-editing route before creating anything, so a failure leaves the graph untouched.
  if (tr == nullptr && manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return nullptr;
  }
  auto slice_cnode = graph->NewCNode(inputs);
  MS_CHECK_TRUE_MSG(slice_cnode != nullptr, nullptr, "NewNode Failed");
  slice_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_slice_" +
                                       std::to_string(node_name_index));
  node_name_index += 1;
  if (tr != nullptr) {
    tr->SetEdge(preceed_cnode, index, slice_cnode);
  } else {
    manager->SetEdge(preceed_cnode, index, slice_cnode);
  }
  return slice_cnode;
}
265
VerifySliceAttrs(const CNodePtr & slice_cnode,const int dim)266 STATUS SlicePreposePass::VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim) {
267 // according to ops/slice.cc, axes >= 0, begin >= 0, size >= -1
268 auto slice = GetSlice(slice_cnode);
269 if (slice == nullptr) {
270 MS_LOG(ERROR) << "Slice is nullptr";
271 return RET_ERROR;
272 }
273 auto axes = slice->get_axes();
274 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
275 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
276
277 std::set<int64_t> unique_axes(axes.begin(), axes.end());
278 if (axes.empty() || unique_axes.size() != axes.size()) {
279 MS_LOG(DEBUG) << "Invalid slice axe attribute";
280 return RET_ERROR;
281 }
282 MS_CHECK_TRUE_MSG(begin.size() <= axes.size(), RET_ERROR, "begin size is wrong");
283 MS_CHECK_TRUE_MSG(size.size() <= axes.size(), RET_ERROR, "size.size() is wrong");
284 for (size_t i = 0; i < axes.size(); ++i) {
285 auto axe = axes[i];
286 if (dim > -1 && axe >= dim) {
287 MS_LOG(ERROR) << "Invalid slice axe attribute";
288 return RET_ERROR;
289 }
290 if (axe < 0) {
291 MS_LOG(ERROR) << "Invalid slice axe attribute";
292 return RET_ERROR;
293 }
294 if (begin[i] < 0) { // we not require begin[i] < ref_shape[axe], cause there may be broadcast
295 MS_LOG(ERROR) << "Invalid begin input! begin[" << i << "]=" << begin[i];
296 return RET_ERROR;
297 }
298 if (size[i] < -1) {
299 MS_LOG(ERROR) << "Invalid size input! size[" << i << "]=" << size[i];
300 return RET_ERROR;
301 }
302 }
303 return RET_OK;
304 }
305
306 /*
307 * Adjust slice's attr when broadcast happened in Arithmetic
308 */
SliceParamDeBroadcast(const CNodePtr & slice_cnode,const std::vector<int64_t> & ref_shape,std::vector<int64_t> * axes,std::vector<int> * begin,std::vector<int> * size)309 STATUS SlicePreposePass::SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int64_t> &ref_shape,
310 std::vector<int64_t> *axes, std::vector<int> *begin,
311 std::vector<int> *size) {
312 MS_ASSERT(slice_cnode != nullptr);
313 MS_ASSERT(new_slice_cnode != nullptr);
314 MS_ASSERT(axes != nullptr);
315 MS_ASSERT(begin != nullptr);
316 MS_ASSERT(size != nullptr);
317 auto slice = GetSlice(slice_cnode);
318 if (slice == nullptr) {
319 MS_LOG(ERROR) << "slice is nullptr";
320 return RET_ERROR;
321 }
322 auto origin_axes = slice->get_axes();
323 auto origin_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
324 auto origin_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
325 auto status = VerifySliceAttrs(slice_cnode, ref_shape.size());
326 if (status != RET_OK) {
327 return status;
328 }
329 axes->resize(ref_shape.size());
330 std::iota(axes->begin(), axes->end(), 0);
331 begin->assign(ref_shape.size(), 0);
332 size->assign(ref_shape.size(), -1);
333 bool real_slice = false; // whether slice happened at this input
334 MS_CHECK_TRUE_MSG(origin_begin.size() >= origin_axes.size(), RET_ERROR, "origin_begin.size() is wrong");
335 MS_CHECK_TRUE_MSG(origin_size.size() >= origin_axes.size(), RET_ERROR, "origin_size.size() is wrong");
336 for (size_t i = 0; i < origin_axes.size(); ++i) {
337 int a = origin_axes[i];
338 int b = origin_begin[i];
339 int s = origin_size[i];
340 MS_CHECK_TRUE_MSG(static_cast<int>(ref_shape.size()) > a, RET_ERROR, "ref_shape.size() is wrong");
341 int ref = ref_shape[a];
342 if (ref == 1) { // broadcast
343 continue; // sliced size is 0(such as begin=1,size=-1) is not considered.
344 } else if (ref > 1) { // not broadcast
345 if (b >= ref) {
346 MS_LOG(ERROR) << "slice begin[" << a << "]=" << b << ", while ref_shape[" << a << "]=" << ref << ", can't fit!";
347 return RET_ERROR;
348 } else {
349 if (b != 0 || (s != -1 && s != ref)) {
350 real_slice = true;
351 }
352 MS_CHECK_TRUE_MSG(static_cast<int>(begin->size()) > a, RET_ERROR, "begin.size() is wrong");
353 MS_CHECK_TRUE_MSG(static_cast<int>(size->size()) > a, RET_ERROR, "size.size() is wrong");
354 begin->at(a) = b;
355 size->at(a) = s;
356 }
357 } else { // ref == 0, not need slice
358 continue;
359 }
360 }
361 if (real_slice) {
362 return lite::RET_OK;
363 } else {
364 return lite::RET_NO_CHANGE;
365 }
366 }
367
/**
 * Creates `Reshape(preceed_cnode, shape_vector)` in `graph`.
 *
 * The target shape is stored in a new int parameter node (dims narrowed to int). `abstract` is
 * attached to the new cnode, and its value is then reset via ClearCNodeAbstractValue.
 * NOTE(review): `abstract` is attached as-is, not cloned, so resetting its value also affects any
 * other node that shares the same abstract object. Confirm callers pass a dedicated (e.g. cloned)
 * abstract.
 *
 * @return the reshape cnode, or nullptr on failure.
 */
CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape_vector,
                                              const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(abstract != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  auto new_reshape = std::make_shared<mindspore::ops::Reshape>();
  if (new_reshape == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return nullptr;
  }
  ValueNodePtr value_node = NewValueNode(new_reshape);
  if (value_node == nullptr) {
    return nullptr;
  }
  std::vector<int> shape;
  std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                 [](int64_t val) { return static_cast<int>(val); });
  auto shape_node = BuildIntVecParameterNode(
    graph, shape, preceed_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
  node_name_index++;
  if (shape_node == nullptr) {
    MS_LOG(ERROR) << "build parameter node failed.";
    return nullptr;
  }
  auto reshape_cnode = graph->NewCNode({value_node, preceed_cnode, shape_node});
  MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "NewCNode Failed");
  reshape_cnode->set_abstract(abstract);
  reshape_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_reshape_" +
                                         std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(reshape_cnode);
  return reshape_cnode;
}
402
SiblingsAreSameSlice(const NodeUsedListPtr & output_node_list,const std::vector<int64_t> & ref_shape)403 bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
404 const std::vector<int64_t> &ref_shape) {
405 MS_ASSERT(graph != nullptr);
406 MS_ASSERT(output_node_list != nullptr);
407 MS_ASSERT(output_node_list->size() >= 2);
408 std::vector<CNodePtr> slices;
409 for (auto &output_node : *(output_node_list.get())) {
410 auto cnode = output_node.first->cast<CNodePtr>();
411 if (cnode == nullptr) {
412 MS_LOG(ERROR) << "cnode is nullptr";
413 return false;
414 }
415 if (!CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
416 return false;
417 }
418 auto slice_node = GetSlice(cnode);
419 if (slice_node == nullptr) {
420 MS_LOG(ERROR) << "Slice is nullptr";
421 return false;
422 }
423 slices.push_back(cnode);
424 }
425 MS_CHECK_TRUE_MSG(slices.size() > 0, RET_ERROR, "slices.size() is wrong");
426 auto first_slice_cnode = slices.front();
427 auto first_slice_node = GetSlice(first_slice_cnode);
428 if (first_slice_node == nullptr) {
429 MS_LOG(ERROR) << "GetSlice return nullptr";
430 return false;
431 }
432 auto first_axes = first_slice_node->get_axes();
433 auto first_begin = GetSliceBeginAndSize(first_slice_cnode, SliceBeginIndex);
434 auto first_size = GetSliceBeginAndSize(first_slice_cnode, SliceSizeIndex);
435 MS_CHECK_TRUE_MSG(first_begin.size() >= first_axes.size(), RET_ERROR, "first_begin.size() is wrong");
436 MS_CHECK_TRUE_MSG(first_size.size() >= first_axes.size(), RET_ERROR, "first_size.size() is wrong");
437 MS_CHECK_TRUE_MSG(slices.size() >= output_node_list->size(), RET_ERROR, "slices.size() is wrong");
438 for (size_t i = 1; i < output_node_list->size(); ++i) {
439 auto slice = GetSlice(slices[i]);
440 auto axes = slice->get_axes();
441 auto begin = GetSliceBeginAndSize(slices[i], SliceBeginIndex);
442 auto size = GetSliceBeginAndSize(slices[i], SliceSizeIndex);
443 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), RET_ERROR, "begin.size() is wrong");
444 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), RET_ERROR, "size.size() is wrong");
445 if (axes.size() != first_axes.size()) {
446 return false;
447 }
448 for (size_t j = 0; j < axes.size(); ++j) {
449 auto axe = axes[j];
450 if (!ref_shape.empty() && axe >= static_cast<int>(ref_shape.size())) {
451 return false;
452 }
453 size_t k = 0;
454 for (; k < first_axes.size(); ++k) { // axes may not be [0...n-1], so we use nested loop to find it
455 if (first_axes[k] == axe) {
456 break;
457 }
458 }
459 if (k == first_axes.size()) {
460 return false;
461 }
462 if (begin[j] != first_begin[k]) {
463 return false;
464 }
465 if (size[j] != first_size[k]) {
466 if (ref_shape.empty()) {
467 return false;
468 }
469 auto actual_size = size[j] > 0 ? size[j] : ref_shape[axe] - begin[j];
470 auto actual_first_size = first_size[k] > 0 ? first_size[k] : ref_shape[axe] - first_begin[k];
471 if (actual_size != actual_first_size) {
472 return false;
473 }
474 }
475 }
476 }
477 return true;
478 }
479
GetReshapeAbnormalAxeIn(const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,std::vector<int64_t> * mapped_axe)480 int64_t SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in,
481 const std::vector<int64_t> &shape_out,
482 std::vector<int64_t> *mapped_axe) {
483 // find shape_out's correspond axe in shape_in
484 // when there are such as 3x1x1x4 => 3x1x4, mapped_axe[1] == 2
485 int64_t inner_size_in = 1;
486 int64_t abnormal_axe_in = -1;
487 MS_CHECK_TRUE_MSG(mapped_axe->size() >= shape_out.size(), abnormal_axe_in, "mapped_axe.size() is wrong");
488 for (size_t i = 0; i < shape_in.size(); ++i) {
489 inner_size_in *= shape_in[i];
490 int64_t inner_size_out = 1;
491 size_t j;
492 for (j = 0; j < shape_out.size(); ++j) {
493 inner_size_out *= shape_out[j];
494 if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
495 mapped_axe->at(j) = i;
496 break;
497 }
498 }
499 if (j == shape_out.size() && abnormal_axe_in == -1) {
500 abnormal_axe_in = i;
501 }
502 }
503 return abnormal_axe_in;
504 }
505
GetReshapeAbnormalIndexOut(const CNodePtr & slice_cnode,const std::vector<int64_t> & mapped_axe,const std::vector<int64_t> & shape_out,std::vector<int64_t> * shape_out_copy,bool * is_normal_mode,bool * support_abnormal_mode)506 int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode,
507 const std::vector<int64_t> &mapped_axe,
508 const std::vector<int64_t> &shape_out,
509 std::vector<int64_t> *shape_out_copy, bool *is_normal_mode,
510 bool *support_abnormal_mode) {
511 MS_ASSERT(slice_cnode != nullptr);
512 MS_ASSERT(shape_out_copy != nullptr);
513 MS_ASSERT(is_normal_mode != nullptr);
514 MS_ASSERT(support_abnormal_mode != nullptr);
515 int64_t abnormal_index_out = -1;
516 auto slice_node = GetSlice(slice_cnode);
517 if (slice_node == nullptr) {
518 MS_LOG(ERROR) << "slice is nullptr";
519 return abnormal_index_out;
520 }
521 auto slice_axes = slice_node->get_axes();
522 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
523 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
524 for (size_t j = 0; j < shape_out.size(); ++j) {
525 int index = -1;
526 for (size_t i = 0; i < slice_axes.size(); ++i) {
527 if (slice_axes[i] == static_cast<int64_t>(j)) {
528 index = i;
529 break;
530 }
531 }
532 if (index == -1) {
533 continue;
534 }
535 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, false, "slice_begin.size() is wrong");
536 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, false, "slice_size.size() is wrong");
537 if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
538 if (mapped_axe[j] == -1) {
539 if (is_normal_mode) {
540 *is_normal_mode = false;
541 abnormal_index_out = index;
542 } else {
543 *support_abnormal_mode = false;
544 }
545 } else { // if there is matched axe sliced, not support abnormal mode
546 shape_out_copy->at(j) =
547 (slice_size[index] == -1 ? shape_out[j] - slice_begin[index] : static_cast<int64_t>(slice_size[index]));
548 *support_abnormal_mode = false;
549 }
550 }
551 }
552 return abnormal_index_out;
553 }
554
PreposeWithNormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out_copy,const std::vector<int64_t> & mapped_axe)555 bool SlicePreposePass::PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
556 const CNodePtr &reshape_cnode, const std::vector<int64_t> &shape_in,
557 const std::vector<int64_t> &shape_out_copy,
558 const std::vector<int64_t> &mapped_axe) {
559 MS_ASSERT(graph != nullptr);
560 MS_ASSERT(slice_cnode != nullptr);
561 MS_ASSERT(reshape_cnode != nullptr);
562 auto slice_node = GetSlice(slice_cnode);
563 if (slice_node == nullptr) {
564 MS_LOG(ERROR) << "slice is nullptr";
565 return false;
566 }
567 auto slice_axes = slice_node->get_axes();
568 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
569 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
570 std::vector<int64_t> new_axes(shape_in.size());
571 std::iota(new_axes.begin(), new_axes.end(), 0);
572 std::vector<int> new_begin(shape_in.size(), 0);
573 std::vector<int> new_size(shape_in.size(), -1);
574 MS_CHECK_TRUE_MSG(slice_begin.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
575 MS_CHECK_TRUE_MSG(slice_size.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
576 for (size_t i = 0; i < mapped_axe.size(); ++i) {
577 auto axe_in = mapped_axe[i];
578 if (axe_in == -1) {
579 continue;
580 }
581 new_begin[axe_in] = slice_begin[i];
582 new_size[axe_in] = slice_size[i];
583 }
584
585 auto reshape_node = GetReshape(reshape_cnode);
586 if (reshape_node == nullptr) {
587 MS_LOG(ERROR) << "reshape is nullptr";
588 return false;
589 }
590 std::vector<int> new_shape_out_copy;
591 std::transform(shape_out_copy.begin(), shape_out_copy.end(), std::back_inserter(new_shape_out_copy),
592 [](int64_t val) { return static_cast<int>(val); });
593 auto shape_node = BuildIntVecParameterNode(
594 graph, new_shape_out_copy, reshape_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
595 node_name_index++;
596 if (shape_node == nullptr) {
597 MS_LOG(ERROR) << "build parameter node failed.";
598 return false;
599 }
600 reshape_cnode->set_inputs({reshape_cnode->input(0), reshape_cnode->input(1), shape_node});
601
602 slice_node->set_axes(new_axes);
603 auto new_begin_parameter = BuildIntVecParameterNode(
604 graph, new_begin, slice_cnode->input(SliceBeginIndex)->cast<ParameterPtr>()->fullname_with_scope());
605 auto new_size_parameter = BuildIntVecParameterNode(
606 graph, new_size, slice_cnode->input(SliceSizeIndex)->cast<ParameterPtr>()->fullname_with_scope());
607 MS_CHECK_TRUE_MSG(new_begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
608 MS_CHECK_TRUE_MSG(new_size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
609 slice_cnode->set_input(SliceBeginIndex, new_begin_parameter);
610 slice_cnode->set_input(SliceSizeIndex, new_size_parameter);
611 auto status = SwapSliceWithPreceed(graph, slice_cnode, reshape_cnode, 1);
612 if (status != RET_OK) {
613 return false;
614 }
615 reshape_cnode->set_abstract(slice_cnode->abstract()->Clone());
616 ClearCNodeAbstractValue(slice_cnode);
617 return true;
618 }
619
/**
 * Creates the first slice used when preposing a slice across a reshape in abnormal mode.
 *
 * The new slice covers every axis of `shape_in` and takes `matmul_cnode`'s output as its data
 * input. On axis `abnormal_axe_in`:
 * - when `slice_at_front`, it skips the first `count_sliced_axe_in` elements (begin = count);
 * - otherwise it keeps the first shape_in[abnormal_axe_in] - count_sliced_axe_in elements
 *   (size = dim - count).
 * All other axes are left whole.
 *
 * @return the new slice cnode, whose abstract is cloned from `slice_cnode` with the value cleared,
 *         or nullptr on failure (including an out-of-range `abnormal_axe_in`).
 */
CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &matmul_cnode,
                                                         const std::vector<int64_t> &shape_in,
                                                         const int64_t abnormal_axe_in,
                                                         const int64_t count_sliced_axe_in, const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(matmul_cnode != nullptr);
  MS_CHECK_TRUE_MSG(abnormal_axe_in >= 0 && abnormal_axe_in < static_cast<int64_t>(shape_in.size()), nullptr,
                    "abnormal_axe_in is out of range");
  MS_CHECK_TRUE_MSG(slice_cnode->abstract() != nullptr, nullptr, "abstract of slice is nullptr");
  std::vector<int64_t> new_axes1(shape_in.size());
  std::iota(new_axes1.begin(), new_axes1.end(), 0);
  std::vector<int> new_begin1(shape_in.size(), 0);
  std::vector<int> new_size1(shape_in.size(), -1);
  if (slice_at_front) {
    new_begin1[abnormal_axe_in] = static_cast<int>(count_sliced_axe_in);
  } else {
    new_size1[abnormal_axe_in] = static_cast<int>(shape_in[abnormal_axe_in] - count_sliced_axe_in);
  }
  auto new_slice1 = CreateSliceValueNode(new_axes1);
  if (new_slice1 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin1, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size1, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice1_cnode = graph->NewCNode({new_slice1, matmul_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice1_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice1_cnode->set_abstract(slice_cnode->abstract()->Clone());
  new_slice1_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice1_cnode);
  return new_slice1_cnode;
}
659
/**
 * Creates the second slice used when preposing a slice across a reshape in abnormal mode.
 *
 * The new slice covers axes [0, abnormal_axe_in] of `new_reshape1_cnode`'s output. With
 * dim = new_shape1[abnormal_axe_in], on axis `abnormal_axe_in`:
 * - when `slice_at_front`, it keeps the last `count_sliced2` elements (begin = dim - count_sliced2);
 * - otherwise it keeps the first `count_sliced2` elements (size = count_sliced2).
 *
 * @return the new slice cnode, whose abstract is cloned from `slice_cnode` with the value cleared,
 *         or nullptr on failure (including an out-of-range `abnormal_axe_in` or count_sliced2 > dim).
 */
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &new_reshape1_cnode,
                                                         const std::vector<int64_t> &new_shape1,
                                                         const int64_t abnormal_axe_in, const int64_t count_sliced2,
                                                         const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(new_reshape1_cnode != nullptr);
  MS_CHECK_TRUE_MSG(abnormal_axe_in >= 0 && abnormal_axe_in < static_cast<int64_t>(new_shape1.size()), nullptr,
                    "abnormal_axe_in is out of range");
  MS_CHECK_TRUE_MSG(slice_cnode->abstract() != nullptr, nullptr, "abstract of slice is nullptr");
  std::vector<int64_t> new_axes2(abnormal_axe_in + 1);
  std::iota(new_axes2.begin(), new_axes2.end(), 0);
  std::vector<int> new_begin2(abnormal_axe_in + 1, 0);
  std::vector<int> new_size2(abnormal_axe_in + 1, -1);
  if (count_sliced2 > new_shape1[abnormal_axe_in]) {
    MS_LOG(WARNING) << "calculation error";
    return nullptr;
  }
  if (slice_at_front) {
    new_begin2[abnormal_axe_in] = static_cast<int>(new_shape1[abnormal_axe_in] - count_sliced2);
  } else {
    new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
  }
  auto new_slice2 = CreateSliceValueNode(new_axes2);
  if (new_slice2 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin2, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size2, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice2_cnode = graph->NewCNode({new_slice2, new_reshape1_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice2_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice2_cnode->set_abstract(slice_cnode->abstract()->Clone());
  new_slice2_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice2_cnode);
  return new_slice2_cnode;
}
703
PreposeWithAbnormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,const int64_t abnormal_axe_in,const int64_t abnormal_index_out)704 bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
705 const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in,
706 const std::vector<int64_t> &shape_out, const int64_t abnormal_axe_in,
707 const int64_t abnormal_index_out) {
708 MS_ASSERT(graph != nullptr);
709 MS_ASSERT(slice_cnode != nullptr);
710 auto manager = graph->manager();
711 MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr");
712 auto slice_node = GetSlice(slice_cnode);
713 if (slice_node == nullptr) {
714 MS_LOG(ERROR) << "slice is nullptr";
715 return false;
716 }
717 auto slice_axes = slice_node->get_axes();
718 MS_CHECK_TRUE_MSG(static_cast<int>(slice_axes.size()) > abnormal_index_out, false, "slice_axes.size() is wrong");
719 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
720 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
721 auto abnormal_axe_out = slice_axes[abnormal_index_out];
722 MS_ASSERT(abnormal_axe_out + 1 < shape_out.size());
723 int64_t inter_size_in = 1;
724 int64_t inter_size_out = 1;
725 for (auto i = 0; i < abnormal_axe_in; ++i) {
726 inter_size_in *= shape_in[i];
727 }
728 for (auto i = 0; i < abnormal_axe_out; ++i) {
729 inter_size_out *= shape_out[i];
730 }
731 if (inter_size_in != inter_size_out) {
732 MS_LOG(DEBUG) << "not support prepose now";
733 return false;
734 }
735 int64_t outer_size_in = 1;
736 int64_t outer_size_out = 1;
737 for (auto i = abnormal_axe_in + 1; i < static_cast<int>(shape_in.size()); ++i) {
738 outer_size_in *= shape_in[i];
739 }
740 for (auto i = abnormal_axe_out + 1; i < static_cast<int>(shape_out.size()); ++i) {
741 outer_size_out *= shape_out[i];
742 }
743 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > abnormal_index_out, false, "slice_begin.size() is wrong");
744 const int64_t count_sliced_axe_front = slice_begin[abnormal_index_out];
745 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > abnormal_index_out, false, "slice_size.size() is wrong");
746 MS_CHECK_TRUE_MSG(static_cast<int>(shape_out.size()) > abnormal_index_out, false, "shape_out.size() is wrong");
747 const int64_t count_sliced_axe_rear =
748 slice_size[abnormal_index_out] == -1 ? 0 : (shape_out[abnormal_axe_out] - slice_size[abnormal_index_out]);
749 if (count_sliced_axe_front * count_sliced_axe_rear > 0) {
750 MS_LOG(DEBUG) << "not border slice at abnormal axe, prepose with reshape failed";
751 return false;
752 }
753 bool slice_at_front = count_sliced_axe_front > 0;
754 const int64_t count_sliced_out = (count_sliced_axe_front + count_sliced_axe_rear) * outer_size_out;
755 MS_CHECK_TRUE_MSG(outer_size_in != 0, false, "div zero");
756 const int64_t count_sliced_axe_in = count_sliced_out / outer_size_in;
757 MS_CHECK_TRUE_MSG(static_cast<int>(shape_in.size()) > abnormal_axe_in, false, "shape_in.size() is wrong");
758 if (count_sliced_axe_in <= 0 || count_sliced_axe_in > shape_in[abnormal_axe_in]) {
759 MS_LOG(DEBUG) << "amount of sliced out tensor is illegal";
760 return false;
761 }
762 // new_slice1
763 auto new_slice1_cnode = CreateSlice1ForReshapePrepose(graph, slice_cnode, matmul_cnode, shape_in, abnormal_axe_in,
764 count_sliced_axe_in, slice_at_front);
765 if (new_slice1_cnode == nullptr) {
766 return false;
767 }
768 // new_reshape1
769 std::vector<int64_t> new_shape1(abnormal_axe_in + 1);
770 for (int i = 0; i < abnormal_axe_in; ++i) {
771 new_shape1[i] = shape_in[i];
772 }
773 new_shape1[abnormal_axe_in] = outer_size_in * (shape_in[abnormal_axe_in] - count_sliced_axe_in);
774 auto new_reshape1_cnode = CreateReshapeCNode(graph, new_shape1, slice_cnode->abstract()->Clone(), new_slice1_cnode);
775 if (new_reshape1_cnode == nullptr) {
776 return false;
777 }
778 // new_slice2
779 const int64_t count_sliced_abnormal_axe =
780 shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
781 const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
782 auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
783 abnormal_axe_in, count_sliced2, slice_at_front);
784 if (new_slice2_cnode == nullptr) {
785 return false;
786 }
787 // new_reshape2
788 std::vector<int64_t> new_shape2(shape_out.begin(), shape_out.end());
789 new_shape2[abnormal_axe_out] = count_sliced_abnormal_axe;
790 auto new_reshape2_cnode = CreateReshapeCNode(graph, new_shape2, slice_cnode->abstract()->Clone(), new_slice2_cnode);
791 if (new_reshape2_cnode == nullptr) {
792 return false;
793 }
794 new_reshape2_cnode->set_abstract(slice_cnode->abstract()->Clone());
795 auto node_users = manager->node_users()[slice_cnode];
796 for (auto &node_user : node_users) {
797 manager->SetEdge(node_user.first, node_user.second, new_reshape2_cnode);
798 }
799 return true;
800 }
801
GetArithmeticInputInfo(const CNodePtr & arithmetic_cnode,std::vector<AnfNodePtr> * inputs,std::vector<std::vector<int64_t>> * shapes,std::vector<bool> * is_default_params)802 bool SlicePreposePass::GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
803 std::vector<std::vector<int64_t>> *shapes,
804 std::vector<bool> *is_default_params) {
805 MS_ASSERT(inputs != nullptr);
806 MS_ASSERT(shapes != nullptr);
807 MS_ASSERT(is_default_params != nullptr);
808 MS_ASSERT(arithmetic_cnode != nullptr);
809 for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) {
810 auto input = arithmetic_cnode->input(i);
811 MS_ASSERT(input != nullptr);
812 std::vector<int64_t> shape;
813 if (utils::isa<ParameterPtr>(input)) {
814 auto parameter = utils::cast<ParameterPtr>(input);
815 MS_ASSERT(parameter != nullptr);
816 if (!parameter->has_default()) { // if one input is input placeholder, we can't change it
817 return false;
818 } else {
819 shape = GetDefaultParamShape(parameter);
820 is_default_params->push_back(true);
821 }
822 } else { // input is CNode
823 if (!utils::isa<CNodePtr>(input)) {
824 MS_LOG(ERROR) << "one of Arithmetic's input is not CNode";
825 return false;
826 }
827 shape = GetCNodeInputShape(arithmetic_cnode, i);
828 is_default_params->push_back(false);
829 }
830 inputs->push_back(input);
831 shapes->push_back(shape);
832 }
833 return true;
834 }
835
836 /*
837 * Prepose condition:
838 * the softmax axis is not sliced
839 */
PreposeWithSoftmax(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & softmax_cnode)840 bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
841 const CNodePtr &softmax_cnode) {
842 MS_ASSERT(graph != nullptr);
843 MS_ASSERT(slice_cnode != nullptr);
844 MS_ASSERT(softmax_cnode != nullptr);
845 auto softmax_node = GetSoftmax(softmax_cnode);
846 if (softmax_node == nullptr) {
847 MS_LOG(ERROR) << "softmax is nullptr";
848 return false;
849 }
850 std::vector<int64_t> softmax_axis{-1};
851 if (softmax_node->GetAttr(ops::kAxis) != nullptr) {
852 softmax_axis = softmax_node->get_axis();
853 }
854 if (softmax_axis.size() != 1) {
855 MS_LOG(ERROR) << "softmax axis is not a value, which don't support.";
856 return false;
857 }
858 auto shape = GetCNodeInputShape(softmax_cnode, 1);
859 if (softmax_axis.front() == -1) {
860 if (shape.empty()) { // when softmax axis == -1, shape info is needed to determine whether slice can be preposed
861 return false;
862 }
863 softmax_axis[0] += shape.size();
864 }
865
866 auto slice_node = GetSlice(slice_cnode);
867 if (slice_node == nullptr) {
868 return false;
869 }
870 auto slice_axes = slice_node->get_axes();
871 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
872 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
873
874 MS_CHECK_TRUE_MSG(static_cast<int>(softmax_axis.size()) > 0, false, "shape_in.size() is wrong");
875 MS_CHECK_TRUE_MSG(slice_size.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
876 MS_CHECK_TRUE_MSG(slice_begin.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
877 for (size_t i = 0; i < slice_axes.size(); ++i) {
878 if (slice_axes[i] == softmax_axis.front()) {
879 if (slice_begin[i] != 0) {
880 return false;
881 }
882 if (slice_size[i] != -1) {
883 if (shape.empty() || slice_axes[i] >= static_cast<int>(shape.size())) {
884 return false;
885 }
886 if (slice_size[i] < shape[slice_axes[i]]) {
887 return false;
888 }
889 }
890 }
891 }
892 auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1);
893 if (status != RET_OK) {
894 return false;
895 }
896 softmax_cnode->set_abstract(slice_cnode->abstract()->Clone());
897 ClearCNodeAbstractValue(slice_cnode);
898 return true;
899 }
900
901 /*
902 * Prepose condition:
903 * require shape info
904 * when reshape is normal(memory view is not changed, such as 4x5 reshaped to 4x1x5), can always prepose
905 * when reshape is abnormal(such as 4x5 reshaped to 5x4), can prepose under some constraint
906 * For abnormal mode:
907 * we only support border(not slice at center) slice at first mismatch axe,
908 * and we only support matmul->reshape->slice => matmul->slice->reshape*->slice*(drop "dead" data)->reshape now,
909 * cause the performance influence introduced by additional (reshape*->slice*) has not been fully evaluated.
910 */
PreposeWithReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode)911 bool SlicePreposePass::PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
912 const CNodePtr &reshape_cnode) {
913 MS_ASSERT(graph != nullptr);
914 MS_ASSERT(slice_cnode != nullptr);
915 MS_ASSERT(reshape_cnode != nullptr);
916 auto shape_in = GetCNodeInputShape(reshape_cnode, 1);
917 auto shape_out = GetCNodeInputShape(slice_cnode, 1);
918 auto shape_out_copy = shape_out;
919 if (shape_in.empty() || shape_out.empty()) {
920 MS_LOG(DEBUG) << "Reshape can't be preposed if either input or output shape is unknown";
921 return false;
922 }
923 if (reshape_cnode->inputs().size() == 3 && utils::isa<ParameterPtr>(reshape_cnode->input(2))) {
924 auto reshape_input_shape = utils::cast<ParameterPtr>(reshape_cnode->input(2));
925 MS_ASSERT(reshape_input_shape != nullptr);
926 if (!reshape_input_shape->has_default()) {
927 MS_LOG(ERROR) << "Reshape input shape is not constant";
928 return false;
929 }
930 }
931 std::vector<int64_t> mapped_axe(shape_out.size(), -1);
932 int64_t abnormal_axe_in = GetReshapeAbnormalAxeIn(shape_in, shape_out, &mapped_axe);
933 bool is_normal_mode = true; // if all sliced axe can be found in input shape, normal
934 bool support_abnormal_mode = true; // if first mismatch axe are sliced and no more other axes are sliced, abnormal
935 int64_t abnormal_index_out = GetReshapeAbnormalIndexOut(slice_cnode, mapped_axe, shape_out, &shape_out_copy,
936 &is_normal_mode, &support_abnormal_mode);
937 if (abnormal_index_out == -1) {
938 MS_LOG(ERROR) << "GetReshapeAbnormalIndexOut failed.";
939 return false;
940 }
941 if (is_normal_mode) {
942 return PreposeWithNormalReshape(graph, slice_cnode, reshape_cnode, shape_in, shape_out_copy, mapped_axe);
943 } else if (support_abnormal_mode) {
944 auto matmul_node = reshape_cnode->input(1);
945 MS_ASSERT(matmul_node != nullptr);
946 if (IsMultiOutputTensors(graph, matmul_node) || !utils::isa<CNodePtr>(matmul_node)) {
947 MS_LOG(DEBUG) << "not matmul->reshape->slice";
948 return false;
949 }
950 auto matmul_cnode = matmul_node->cast<CNodePtr>();
951 if (matmul_cnode == nullptr) {
952 MS_LOG(ERROR) << "matmul_cnode is nullptr";
953 return false;
954 }
955 if (!CheckPrimitiveType(matmul_node, prim::kPrimFullConnection) &&
956 !CheckPrimitiveType(matmul_node, prim::kPrimMatMul)) {
957 MS_LOG(DEBUG) << "not matmul->reshape->slice pattern";
958 return false;
959 }
960 return PreposeWithAbnormalReshape(graph, slice_cnode, matmul_cnode, shape_in, shape_out, abnormal_axe_in,
961 abnormal_index_out);
962 }
963 return false;
964 }
965
966 /*
967 * Prepose condition:
968 * require shape info
969 */
PreposeWithMatmul(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode)970 bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
971 const CNodePtr &matmul_cnode) {
972 MS_ASSERT(graph != nullptr && slice_cnode != nullptr && matmul_cnode != nullptr);
973 auto matmul_shape = GetCNodeInputShape(slice_cnode, 1);
974 const int dims = matmul_shape.size();
975 if (dims == 0) {
976 // if Matmul's output shape is unknown, can't do prepose, cause we can't determine last two axes
977 return false;
978 }
979 auto slice_node = GetSlice(slice_cnode);
980 MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
981 auto axes = slice_node->get_axes();
982 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
983 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
984 // matmul not support broadcast now, it makes things simpler
985 auto manager = graph->manager();
986 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
987 MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
988 auto node_users = manager->node_users()[slice_cnode];
989 bool changed = false;
990 bool prepose_to_left = false; // if only the last axe is sliced, not need prepose to left
991 bool prepose_to_right = false; // if only the second last axe is sliced, not need prepose to right
992 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
993 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
994 for (size_t i = 0; i < axes.size(); ++i) {
995 if (begin[i] != 0 || (size[i] != -1 && size[i] != matmul_shape[axes[i]])) {
996 if (axes[i] != dims - 1) {
997 prepose_to_left = true;
998 } else if (axes[i] != dims - 2) {
999 prepose_to_right = true;
1000 }
1001 }
1002 }
1003 if (prepose_to_left) { // left matrix
1004 auto left_axes = axes;
1005 auto left_begin = begin;
1006 auto left_size = size;
1007 MS_CHECK_TRUE_MSG(left_begin.size() >= left_axes.size(), false, "left_begin.size() is wrong");
1008 MS_CHECK_TRUE_MSG(left_size.size() >= left_axes.size(), false, "left_size.size() is wrong");
1009 for (size_t i = 0; i < left_axes.size(); ++i) {
1010 if (left_axes[i] == dims - 1) {
1011 left_begin[i] = 0;
1012 left_size[i] = -1;
1013 }
1014 }
1015 auto left_slice_vnode = CreateSliceValueNode(left_axes);
1016 MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1017 auto begin_parameter = BuildIntVecParameterNode(
1018 graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1019 node_name_index += 1;
1020 auto size_parameter = BuildIntVecParameterNode(
1021 graph, left_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1022 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1023 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1024 node_name_index += 1;
1025
1026 const std::vector<AnfNodePtr> inputs = {left_slice_vnode, matmul_cnode->input(1), begin_parameter, size_parameter};
1027 auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 1, tr);
1028 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1029 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1030 ClearCNodeAbstractValue(new_slice_cnode);
1031 changed = true;
1032 }
1033 if (prepose_to_right) { // right matrix
1034 auto right_axes = axes;
1035 auto right_begin = begin;
1036 auto right_size = size;
1037 MS_CHECK_TRUE_MSG(right_begin.size() >= right_axes.size(), false, "right_begin.size() is wrong");
1038 MS_CHECK_TRUE_MSG(right_size.size() >= right_axes.size(), false, "right_size.size() is wrong");
1039 for (size_t i = 0; i < right_axes.size(); ++i) {
1040 if (right_axes[i] == dims - 2) {
1041 right_begin[i] = 0;
1042 right_size[i] = -1;
1043 }
1044 }
1045 auto begin_parameter = BuildIntVecParameterNode(
1046 graph, right_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1047 node_name_index += 1;
1048 auto size_parameter = BuildIntVecParameterNode(
1049 graph, right_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1050 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1051 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1052 node_name_index += 1;
1053 auto right_slice_vnode = CreateSliceValueNode(right_axes);
1054 MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1055 const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
1056 auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);
1057 MS_ASSERT(new_slice_cnode != nullptr);
1058 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1059 ClearCNodeAbstractValue(new_slice_cnode);
1060 changed = true;
1061 }
1062 if (changed) {
1063 matmul_cnode->set_abstract(slice_cnode->abstract()->Clone());
1064 for (auto &node_user : node_users) {
1065 tr->SetEdge(node_user.first, node_user.second, matmul_cnode);
1066 }
1067 tr->Commit();
1068 // we don't need graph->DropNode(slice_cnode);
1069 }
1070 return changed;
1071 }
1072
1073 /*
1074 * Prepose condition:
1075 * require shape info
1076 * only support slice at first output axe now, and useAxis must be false
1077 */
PreposeWithFullConnection(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & fc_cnode)1078 bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1079 const CNodePtr &fc_cnode) {
1080 MS_ASSERT(graph != nullptr);
1081 MS_ASSERT(slice_cnode != nullptr);
1082 MS_ASSERT(fc_cnode != nullptr);
1083 auto shape_in = GetCNodeInputShape(fc_cnode, 1);
1084 auto shape_out = GetCNodeInputShape(slice_cnode, 1);
1085 if (shape_in.empty() || shape_out.size() != 2) {
1086 MS_LOG(DEBUG) << "FullConnection can't be preposed if input shape is unknown or output shape is illegal";
1087 return false;
1088 }
1089 auto fc_node = GetFc(fc_cnode);
1090 if (fc_node == nullptr || (fc_node->GetAttr(ops::kUseAxis) != nullptr && fc_node->get_use_axis())) {
1091 MS_LOG(DEBUG) << "prepose with fc only support useAxis == false currently";
1092 return false;
1093 }
1094 auto slice_node = GetSlice(slice_cnode);
1095 if (slice_node == nullptr) {
1096 MS_LOG(ERROR) << "slice is nullptr";
1097 return false;
1098 }
1099 auto axes = slice_node->get_axes();
1100 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1101 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1102 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1103 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1104 for (size_t i = 0; i < axes.size(); ++i) {
1105 if (axes[i] == 1) {
1106 if (begin[i] != 0 || (size[i] != -1 && size[i] != shape_out[1])) {
1107 MS_LOG(DEBUG) << "prepose with fc only support first output axe is sliced currently";
1108 return false;
1109 }
1110 }
1111 }
1112
1113 std::vector<int64_t> mapped_axe(shape_out.size(), -1);
1114 int64_t inner_size_in = 1;
1115 for (size_t i = 0; i < shape_in.size(); ++i) {
1116 inner_size_in *= shape_in[i];
1117 int64_t inner_size_out = 1;
1118 for (size_t j = 0; j < shape_out.size(); ++j) {
1119 inner_size_out *= shape_out[j];
1120 if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
1121 mapped_axe[j] = i;
1122 break;
1123 }
1124 }
1125 }
1126 if (mapped_axe[0] == -1) {
1127 MS_LOG(DEBUG) << "first axe in output can't find correspond input axe, can't do prepose";
1128 return false;
1129 }
1130
1131 std::vector<int64_t> new_axes(shape_in.size());
1132 std::iota(new_axes.begin(), new_axes.end(), 0);
1133 std::vector<int> new_begin(shape_in.size(), 0);
1134 std::vector<int> new_size(shape_in.size(), -1);
1135 new_begin[mapped_axe[0]] = begin[0];
1136 new_size[mapped_axe[0]] = size[0];
1137 auto new_slice_vnode = CreateSliceValueNode(new_axes);
1138 if (new_slice_vnode == nullptr) {
1139 MS_LOG(ERROR) << "CreateSliceValueNode failed";
1140 return false;
1141 }
1142
1143 auto manager = graph->manager();
1144 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1145 if (tr == nullptr) {
1146 MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1147 return false;
1148 }
1149 auto begin_parameter = BuildIntVecParameterNode(
1150 graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1151 node_name_index += 1;
1152 auto size_parameter = BuildIntVecParameterNode(
1153 graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1154 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1155 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1156 node_name_index += 1;
1157 const std::vector<AnfNodePtr> inputs = {new_slice_vnode, fc_cnode->input(1), begin_parameter, size_parameter};
1158 auto new_slice_cnode = InsertSlice(graph, inputs, fc_cnode, 1, tr);
1159 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1160
1161 fc_cnode->set_abstract(slice_cnode->abstract()->Clone());
1162 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1163 ClearCNodeAbstractValue(new_slice_cnode);
1164
1165 auto node_users = manager->node_users()[slice_cnode];
1166 for (auto &node_user : node_users) {
1167 tr->SetEdge(node_user.first, node_user.second, fc_cnode);
1168 }
1169 tr->Commit();
1170 return true;
1171 }
1172
1173 /*
1174 * Prepose condition:
1175 * not require shape info, can always prepose
1176 */
PreposeWithTranspose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & transpose_cnode)1177 bool SlicePreposePass::PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1178 const CNodePtr &transpose_cnode) {
1179 MS_ASSERT(graph != nullptr);
1180 MS_ASSERT(slice_cnode != nullptr);
1181 MS_ASSERT(transpose_cnode != nullptr);
1182 if (transpose_cnode->inputs().size() != 3) {
1183 MS_LOG(ERROR) << "transpose inputs size should be 3.";
1184 return false;
1185 }
1186 auto perm = GetTransposePerm(transpose_cnode);
1187 if (perm.empty()) {
1188 return false;
1189 }
1190 auto slice_node = GetSlice(slice_cnode);
1191 if (slice_node == nullptr) {
1192 MS_LOG(ERROR) << "GetSlicT failed";
1193 return false;
1194 }
1195 auto old_axes = slice_node->get_axes();
1196 auto old_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1197 auto old_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1198 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1199 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1200 // perm is random shuffle of [0...n-1] according to ops/transpose.cc
1201 for (size_t i = 0; i < perm.size(); ++i) {
1202 if (perm[i] != static_cast<int>(i)) {
1203 for (size_t j = 0; j < old_axes.size(); ++j) {
1204 if (old_axes[j] == static_cast<int>(i)) {
1205 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > perm[i], false, "slice_begin.size() is wrong");
1206 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > perm[i], false, "slice_size.size() is wrong");
1207 slice_begin[perm[i]] = old_begin[j];
1208 slice_size[perm[i]] = old_size[j];
1209 break;
1210 }
1211 }
1212 }
1213 }
1214 auto begin_parameter = BuildIntVecParameterNode(
1215 graph, slice_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1216 node_name_index += 1;
1217 auto size_parameter = BuildIntVecParameterNode(
1218 graph, slice_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1219 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1220 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1221 node_name_index += 1;
1222 slice_cnode->set_input(SliceBeginIndex, begin_parameter);
1223 slice_cnode->set_input(SliceSizeIndex, size_parameter);
1224 auto status = SwapSliceWithPreceed(graph, slice_cnode, transpose_cnode, 1);
1225 if (status != RET_OK) {
1226 return false;
1227 }
1228 transpose_cnode->set_abstract(slice_cnode->abstract()->Clone());
1229 ClearCNodeAbstractValue(slice_cnode);
1230 return true;
1231 }
1232 /*
1233 * Prepose condition:
1234 * may or may not require shape info
1235 */
PreposeWithArithmetic(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & arithmetic_cnode)1236 bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1237 const CNodePtr &arithmetic_cnode) {
1238 MS_ASSERT(graph != nullptr);
1239 MS_ASSERT(slice_cnode != nullptr);
1240 MS_ASSERT(arithmetic_cnode != nullptr);
1241 auto manager = graph->manager();
1242 MS_ASSERT(manager != nullptr);
1243 auto node_users = manager->node_users()[slice_cnode];
1244 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1245 if (tr == nullptr) {
1246 MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1247 return false;
1248 }
1249 bool changed = false;
1250 std::vector<AnfNodePtr> inputs;
1251 std::vector<std::vector<int64_t>> shapes;
1252 std::vector<bool> is_default_params;
1253 if (!GetArithmeticInputInfo(arithmetic_cnode, &inputs, &shapes, &is_default_params)) {
1254 return false;
1255 }
1256
1257 for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) {
1258 auto &input = inputs[i - 1];
1259 if (IsScalarNode(input)) { // scalar not need prepose
1260 continue;
1261 }
1262 auto &shape = shapes[i - 1];
1263 const size_t another_index = kArithmeticInputNum - i;
1264 auto &another_input = inputs[another_index];
1265 auto &another_shape = shapes[another_index];
1266 if (IsScalarNode(input)) {
1267 continue;
1268 } else if (shape.empty()) { // infershape failed at this input
1269 if (IsScalarNode(another_input)) { // if another input is scalar, we can process this one
1270 auto new_slice_vnode = CopySliceValueNode(slice_cnode);
1271 if (new_slice_vnode == nullptr) {
1272 changed = false;
1273 break;
1274 }
1275 std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i),
1276 slice_cnode->input(SliceBeginIndex),
1277 slice_cnode->input(SliceSizeIndex)};
1278 auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1279 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1280
1281 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1282 ClearCNodeAbstractValue(new_slice_cnode);
1283 changed = true;
1284 break;
1285 } else { // if another input's shape is not scalar, can't be processed
1286 changed = false;
1287 break;
1288 }
1289 } else { // shape not empty
1290 if (!another_shape.empty() || IsScalarNode(another_input)) {
1291 std::vector<int64_t> new_axes;
1292 std::vector<int> new_begin;
1293 std::vector<int> new_size;
1294 auto status = SliceParamDeBroadcast(slice_cnode, shape, &new_axes, &new_begin, &new_size);
1295 if (status == lite::RET_NO_CHANGE) {
1296 continue;
1297 }
1298 if (status != lite::RET_OK) {
1299 changed = false;
1300 break;
1301 }
1302 auto new_slice_vnode = CreateSliceValueNode(new_axes);
1303 if (new_slice_vnode == nullptr) {
1304 changed = false;
1305 break;
1306 }
1307 auto begin_parameter = BuildIntVecParameterNode(
1308 graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1309 node_name_index += 1;
1310 auto size_parameter = BuildIntVecParameterNode(
1311 graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1312 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1313 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1314 node_name_index += 1;
1315 std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i), begin_parameter,
1316 size_parameter};
1317 auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1318 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1319 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1320 ClearCNodeAbstractValue(new_slice_cnode);
1321 changed = true;
1322 } else {
1323 changed = false;
1324 break;
1325 }
1326 }
1327 }
1328 if (changed) {
1329 arithmetic_cnode->set_abstract(slice_cnode->abstract()->Clone());
1330 for (auto &node_user : node_users) {
1331 tr->SetEdge(node_user.first, node_user.second, arithmetic_cnode);
1332 }
1333 tr->Commit();
1334 // we don't need graph->DropNode(slice_cnode);
1335 }
1336 return changed;
1337 } // namespace mindspore::opt
1338 /*
1339 * Prepose condition:
1340 * not require shape info
1341 */
MergeSequentialSlice(const FuncGraphPtr & graph,const CNodePtr & slice1_cnode,const CNodePtr & slice2_cnode)1342 bool SlicePreposePass::MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode,
1343 const CNodePtr &slice2_cnode) {
1344 if (slice2_cnode->inputs().size() != kArithmeticInputNum) {
1345 MS_LOG(INFO) << "Slice read attrs from input is not supported now";
1346 return false;
1347 }
1348 auto slice1_node = GetSlice(slice1_cnode); // bottom node
1349 auto slice2_node = GetSlice(slice2_cnode); // top node
1350 if (slice1_node == nullptr || slice2_node == nullptr) {
1351 MS_LOG(ERROR) << "slice is null";
1352 return false;
1353 }
1354 auto begin_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceBeginIndex);
1355 auto size_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceSizeIndex);
1356 auto axes_slice1 = slice1_node->get_axes();
1357 auto begin_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceBeginIndex);
1358 auto size_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceSizeIndex);
1359 auto axes_slice2 = slice2_node->get_axes();
1360 auto status1 = VerifySliceAttrs(slice1_cnode);
1361 auto status2 = VerifySliceAttrs(slice2_cnode);
1362 if (status1 != RET_OK || status2 != RET_OK) {
1363 return false;
1364 }
1365
1366 auto manager = graph->manager();
1367 MS_ASSERT(manager != nullptr);
1368 auto node_users = manager->node_users()[slice1_cnode];
1369 int64_t axe_max1 = *std::max_element(axes_slice1.begin(), axes_slice1.end());
1370 int64_t axe_max2 = *std::max_element(axes_slice2.begin(), axes_slice2.end());
1371 int64_t axe_max = std::max(axe_max1, axe_max2);
1372 auto begin_new = begin_slice2;
1373 auto size_new = size_slice2;
1374 auto axes_new = slice2_node->get_axes();
1375 axes_new.resize(axe_max + 1);
1376 std::iota(axes_new.begin(), axes_new.end(), 0);
1377 begin_new.assign(axe_max + 1, 0);
1378 size_new.assign(axe_max + 1, -1);
1379 MS_CHECK_TRUE_MSG(begin_slice2.size() >= axes_slice2.size(), false, "begin_slice2.size() is wrong");
1380 MS_CHECK_TRUE_MSG(size_slice2.size() >= axes_slice2.size(), false, "size_slice2.size() is wrong");
1381 MS_CHECK_TRUE_MSG(size_slice1.size() >= axes_slice1.size(), false, "size_slice1.size() is wrong");
1382 MS_CHECK_TRUE_MSG(begin_slice1.size() >= axes_slice1.size(), false, "begin_slice1.size() is wrong");
1383 for (int i = 0; i <= axe_max; ++i) {
1384 for (size_t j = 0; j < axes_slice2.size(); ++j) {
1385 if (axes_slice2[j] == i) {
1386 begin_new[i] = begin_slice2[j];
1387 size_new[i] = size_slice2[j];
1388 break;
1389 }
1390 }
1391 for (size_t j = 0; j < axes_slice1.size(); ++j) {
1392 if (axes_slice1[j] == i) {
1393 begin_new[i] = begin_new[i] + begin_slice1[j];
1394 if (size_new[i] == -1) {
1395 size_new[i] = size_slice1[j];
1396 } else {
1397 if (size_slice1[j] == -1) {
1398 size_new[i] = std::max(size_new[i] - begin_slice1[i], 0); // clip with zero to avoid invalid negative value
1399 } else {
1400 size_new[i] = std::max(std::min(size_new[i] - begin_slice1[j], size_slice1[j]), 0);
1401 }
1402 }
1403 break;
1404 }
1405 }
1406 }
1407 slice2_node->set_axes(axes_new);
1408 auto begin_parameter = BuildIntVecParameterNode(
1409 graph, begin_new, slice2_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1410 node_name_index += 1;
1411 auto size_parameter = BuildIntVecParameterNode(
1412 graph, size_new, slice2_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1413 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1414 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1415 node_name_index += 1;
1416 slice2_cnode->set_input(SliceBeginIndex, begin_parameter);
1417 slice2_cnode->set_input(SliceSizeIndex, size_parameter);
1418 slice2_cnode->set_abstract(slice1_cnode->abstract()->Clone());
1419 for (auto &node_user : node_users) {
1420 manager->SetEdge(node_user.first, node_user.second, slice2_cnode);
1421 }
1422 return true;
1423 }
1424
1425 /*
1426 * Prepose condition:
1427 * when all sibling slices do same work
1428 * can be optimize to not require all siblings are slice
1429 */
MergeParallelSlice(const FuncGraphPtr & graph,const NodeUsedListPtr & slices)1430 bool SlicePreposePass::MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices) {
1431 MS_ASSERT(graph != nullptr);
1432 MS_ASSERT(slices->size() >= 2);
1433 auto manager = graph->manager();
1434 MS_ASSERT(manager != nullptr);
1435 auto first_slice = utils::cast<CNodePtr>(slices->at(0).first);
1436 MS_ASSERT(first_slice == nullptr);
1437 if (!CheckPrimitiveType(first_slice, prim::kPrimSliceFusion)) {
1438 MS_LOG(ERROR) << "first node is not Slice";
1439 return false;
1440 }
1441 auto first_parent = first_slice->input(1);
1442 if (first_parent == nullptr) {
1443 MS_LOG(ERROR) << "first slice node's parent is nullptr";
1444 return false;
1445 }
1446 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1447 if (tr == nullptr) {
1448 MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1449 return false;
1450 }
1451 for (size_t i = 1; i < slices->size(); ++i) {
1452 auto slice = utils::cast<CNodePtr>(slices->at(i).first);
1453 MS_ASSERT(slice == nullptr);
1454 if (!CheckPrimitiveType(slice, prim::kPrimSliceFusion)) {
1455 MS_LOG(ERROR) << "current node is not Slice";
1456 return false;
1457 }
1458 auto parent = slice->input(1);
1459 if (parent == nullptr || parent != first_parent) {
1460 MS_LOG(ERROR) << "not all slices have same parent node";
1461 return false;
1462 }
1463 auto node_users = manager->node_users()[slices->at(i).first];
1464 for (auto &node_user : node_users) {
1465 tr->SetEdge(node_user.first, node_user.second, slices->at(0).first);
1466 }
1467 }
1468 tr->Commit();
1469 return true;
1470 }
1471
DoPrepose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode)1472 bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1473 const CNodePtr &preceed_cnode) {
1474 MS_ASSERT(graph != nullptr);
1475 MS_ASSERT(slice_cnode != nullptr);
1476 MS_ASSERT(preceed_cnode != nullptr);
1477 if (CheckPrimitiveType(preceed_cnode, prim::kPrimSoftmax)) {
1478 return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode);
1479 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimReshape)) {
1480 return PreposeWithReshape(graph, slice_cnode, preceed_cnode);
1481 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimMatMul)) {
1482 return PreposeWithMatmul(graph, slice_cnode, preceed_cnode);
1483 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimFullConnection)) {
1484 return PreposeWithFullConnection(graph, slice_cnode, preceed_cnode);
1485 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimTranspose)) {
1486 return PreposeWithTranspose(graph, slice_cnode, preceed_cnode);
1487 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSubFusion) ||
1488 CheckPrimitiveType(preceed_cnode, prim::kPrimMulFusion) ||
1489 CheckPrimitiveType(preceed_cnode, prim::kPrimAddFusion)) {
1490 return PreposeWithArithmetic(graph, slice_cnode, preceed_cnode);
1491 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSliceFusion)) {
1492 return MergeSequentialSlice(graph, slice_cnode, preceed_cnode);
1493 }
1494 return false;
1495 }
1496
Run(const FuncGraphPtr & graph)1497 bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
1498 if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
1499 MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
1500 return false;
1501 }
1502 MS_ASSERT(graph != nullptr);
1503 bool changed = false;
1504 while (true) {
1505 bool this_time_changed = false;
1506 auto node_list = TopoSort(graph->get_return());
1507 for (auto &node : node_list) {
1508 if (node->func_graph() != graph || !utils::isa<CNodePtr>(node) ||
1509 !CheckPrimitiveType(node, prim::kPrimSliceFusion)) {
1510 continue;
1511 }
1512 auto slice_cnode = node->cast<CNodePtr>();
1513 MS_ASSERT(slice_cnode != nullptr);
1514 // only support begin and size is const tensor.
1515 if (!CheckIsAllInputsParam(slice_cnode) || GetSlice(slice_cnode)) {
1516 continue;
1517 }
1518 auto preceed_node = slice_cnode->input(1);
1519 if (preceed_node == nullptr) {
1520 MS_LOG(ERROR) << "proceed node is nullptr";
1521 continue;
1522 }
1523 auto output_tensor_num = GetOutputTensorNum(preceed_node);
1524 if (output_tensor_num > 1) {
1525 continue;
1526 }
1527 auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
1528 if (output_node_list->size() > 1) { // referenced by multi nodes
1529 if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
1530 this_time_changed = true;
1531 break;
1532 }
1533 continue;
1534 } else {
1535 if (utils::isa<ParameterPtr>(preceed_node)) {
1536 /*
1537 * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose
1538 * if preceed_node is parameter with default param, constant_folding will process it
1539 */
1540 continue;
1541 }
1542 auto preceed_cnode = preceed_node->cast<CNodePtr>();
1543 if (preceed_cnode == nullptr) {
1544 MS_LOG(ERROR) << "preceed_cnode is nullptr";
1545 continue;
1546 }
1547 if (DoPrepose(graph, slice_cnode, preceed_cnode)) {
1548 this_time_changed = true;
1549 break;
1550 }
1551 }
1552 }
1553 if (this_time_changed) {
1554 changed = true;
1555 } else {
1556 break;
1557 }
1558 }
1559 return changed;
1560 }
1561 } // namespace mindspore::opt
1562