1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/slice_prepose_pass.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <string>
#include <vector>
23 #include "mindspore/core/ops/nn_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "ops/fusion/full_connection.h"
27 #include "ops/auto_generate/gen_lite_ops.h"
28 #include "ops/fusion/slice_fusion.h"
29 #include "ops/op_utils.h"
30 #include "include/errorcode.h"
31 #include "tools/optimizer/common/gllo_utils.h"
32 #include "tools/optimizer/common/helper.h"
33 #include "include/backend/optimizer/helper.h"
34 #include "src/common/log_adapter.h"
35 #include "nnacl/op_base.h"
36
37 namespace mindspore::opt {
38 namespace {
39 const int kArithmeticInputNum = 2;
40 const int SliceBeginIndex = 2;
41 const int SliceSizeIndex = 3;
42 int node_name_index = 0;
GetSliceBeginAndSize(const CNodePtr & cnode,const int index)43 std::vector<int> GetSliceBeginAndSize(const CNodePtr &cnode, const int index) {
44 MS_ASSERT(cnode != nullptr);
45 std::vector<int> content;
46 if (index != SliceBeginIndex && index != SliceSizeIndex && cnode->size() != 4) {
47 return content;
48 }
49 auto node = cnode->input(index);
50 if (node == nullptr) {
51 return content;
52 }
53 auto param_node = node->cast<ParameterPtr>();
54 if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) {
55 return content;
56 }
57 auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
58 if (tensor_info == nullptr) {
59 return content;
60 }
61 content.resize(tensor_info->DataSize());
62 if (memcpy_s(content.data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
63 MS_LOG(ERROR) << "memcpy data failed.";
64 return {};
65 }
66 return content;
67 }
68
GetCNodeInputShape(const CNodePtr & cnode,size_t index=1)69 std::vector<int64_t> GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) {
70 MS_ASSERT(cnode != nullptr);
71 std::vector<int64_t> empty_shape;
72 if (index < 1 || cnode->size() <= index) {
73 MS_LOG(ERROR) << "out of index";
74 return empty_shape;
75 }
76 auto abstract = GetCNodeInputAbstract(cnode, index);
77 if (abstract == nullptr) {
78 MS_LOG(ERROR) << "Abstract of CNode is nullptr";
79 return empty_shape;
80 }
81 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
82 MS_LOG(DEBUG) << "abstract is not AbstractTensor";
83 return empty_shape;
84 }
85 auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
86 MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
87 return abstract_tensor->shape()->shape();
88 }
89
GetDefaultParamShape(const ParameterPtr & param)90 std::vector<int64_t> GetDefaultParamShape(const ParameterPtr ¶m) {
91 MS_ASSERT(param != nullptr);
92 MS_ASSERT(param->has_default());
93 std::vector<int64_t> shape_vector;
94 auto default_param = param->default_param();
95 if (default_param == nullptr) {
96 MS_LOG(ERROR) << "default_param is nullptr";
97 return shape_vector;
98 }
99 if (!utils::isa<tensor::TensorPtr>(default_param)) {
100 MS_LOG(ERROR) << "default_param is not tensor::Tensor";
101 return shape_vector;
102 }
103 auto param_value_lite = utils::cast<tensor::TensorPtr>(default_param);
104 MS_ASSERT(param_value != nullptr);
105 auto shape = param_value_lite->shape();
106 std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
107 [](const int val) { return static_cast<int64_t>(val); });
108 return shape_vector;
109 }
110
IsScalarNode(const AnfNodePtr & nodePtr)111 bool IsScalarNode(const AnfNodePtr &nodePtr) {
112 MS_ASSERT(nodePtr != nullptr);
113 if (utils::isa<ParameterPtr>(nodePtr) && nodePtr->cast<ParameterPtr>()->has_default()) {
114 auto tensor = utils::cast<tensor::TensorPtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
115 MS_ASSERT(tensor != nullptr);
116 auto shape = tensor->shape();
117 if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
118 return true;
119 }
120 }
121 return false;
122 }
123
GetSlice(const CNodePtr & cnode)124 api::SharedPtr<mindspore::ops::SliceFusion> GetSlice(const CNodePtr &cnode) {
125 if (cnode == nullptr) {
126 return nullptr;
127 }
128 return ops::GetOperator<mindspore::ops::SliceFusion>(cnode->input(0));
129 }
130
GetSoftmax(const CNodePtr & cnode)131 api::SharedPtr<mindspore::ops::Softmax> GetSoftmax(const CNodePtr &cnode) {
132 if (cnode == nullptr) {
133 return nullptr;
134 }
135 return ops::GetOperator<mindspore::ops::Softmax>(cnode->input(0));
136 }
137
GetReshape(const CNodePtr & cnode)138 api::SharedPtr<mindspore::ops::Reshape> GetReshape(const CNodePtr &cnode) {
139 if (cnode == nullptr) {
140 return nullptr;
141 }
142 return ops::GetOperator<mindspore::ops::Reshape>(cnode->input(0));
143 }
144
GetFc(const CNodePtr & cnode)145 api::SharedPtr<mindspore::ops::FullConnection> GetFc(const CNodePtr &cnode) {
146 if (cnode == nullptr) {
147 return nullptr;
148 }
149 return ops::GetOperator<mindspore::ops::FullConnection>(cnode->input(0));
150 }
151
GetTransposePerm(const CNodePtr & node)152 std::vector<int> GetTransposePerm(const CNodePtr &node) {
153 MS_ASSERT(node != nullptr);
154 std::vector<int> perm;
155 if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
156 return perm;
157 }
158 if (node->size() != 3) {
159 return perm;
160 }
161 auto perm_node = node->input(2);
162 if (!utils::isa<ParameterPtr>(perm_node)) {
163 return perm;
164 }
165 auto perm_param = perm_node->cast<ParameterPtr>();
166 MS_ASSERT(perm_param != nullptr);
167 if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
168 return perm;
169 }
170 auto perm_value = perm_param->default_param()->cast<tensor::TensorPtr>();
171 if (perm_value == nullptr) {
172 return perm;
173 }
174 MS_CHECK_TRUE_MSG(perm_value->shape().size() != 0, {}, "shape is empty");
175 perm.resize(perm_value->shape()[0]);
176 if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) {
177 MS_LOG(ERROR) << "memcpy failed.";
178 return {};
179 }
180 return perm;
181 }
182 } // namespace
183
ClearCNodeAbstractValue(const CNodePtr & cnode)184 void SlicePreposePass::ClearCNodeAbstractValue(const CNodePtr &cnode) {
185 MS_ASSERT(cnode != nullptr);
186 auto abstract = cnode->abstract();
187 MS_ASSERT(abstract != nullptr);
188 if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
189 MS_LOG(DEBUG) << "Abstract of cnode is not abstract tensor, " << cnode->fullname_with_scope();
190 }
191 abstract->set_value(std::make_shared<ValueAny>());
192 }
193
SwapSliceWithPreceed(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode,const int index,const TransactionPtr & tr)194 STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
195 const CNodePtr &preceed_cnode, const int index,
196 const TransactionPtr &tr) {
197 MS_ASSERT(graph != nullptr);
198 MS_ASSERT(slice_cnode != nullptr);
199 MS_ASSERT(preceed_cnode != nullptr);
200 if (slice_cnode->input(1) != preceed_cnode) {
201 MS_LOG(ERROR) << "proceed node must be slice node's direct parent";
202 return RET_ERROR;
203 }
204 if (IsMultiOutputTensors(graph, preceed_cnode)) {
205 MS_LOG(ERROR) << "proceed node referenced by multi nodes not support swap";
206 return RET_ERROR;
207 }
208 auto manager = graph->manager();
209 if (manager == nullptr) {
210 MS_LOG(ERROR) << "manager is nullptr";
211 return RET_ERROR;
212 }
213 auto node_users = manager->node_users()[slice_cnode];
214 if (tr != nullptr) { // do swap with transaction
215 for (auto &node_user : node_users) {
216 tr->SetEdge(node_user.first, node_user.second, preceed_cnode);
217 }
218 tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
219 tr->SetEdge(preceed_cnode, index, slice_cnode);
220 } else {
221 for (auto &node_user : node_users) {
222 manager->SetEdge(node_user.first, node_user.second, preceed_cnode);
223 }
224 manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
225 manager->SetEdge(preceed_cnode, index, slice_cnode);
226 }
227 return RET_OK;
228 }
229
CreateSliceValueNode(const std::vector<int64_t> & axes)230 ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
231 MS_ASSERT(graph != nullptr);
232 MS_ASSERT(slice_cnode != nullptr);
233 auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
234 MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice is nullptr");
235 auto new_slice_c = new_slice->GetPrim();
236 MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
237 new_slice->set_axes(axes);
238 ValueNodePtr value_node = NewValueNode(new_slice_c);
239 MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
240 return value_node;
241 }
242
CopySliceValueNode(const CNodePtr & slice_cnode)243 ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
244 MS_ASSERT(graph != nullptr);
245 MS_ASSERT(slice_cnode != nullptr);
246 auto slice_c = ops::GetOperator<mindspore::ops::SliceFusion>(slice_cnode->input(0));
247 if (slice_c == nullptr) {
248 MS_LOG(ERROR) << "slice node is nullptr";
249 return nullptr;
250 }
251 auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
252 MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice_c is nullptr");
253 auto new_slice_c = new_slice->GetPrim();
254 MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
255 new_slice->set_axes(new_slice->get_axes());
256 ValueNodePtr value_node = NewValueNode(new_slice_c);
257 MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
258 return value_node;
259 }
260
// Creates a cnode from `inputs` (presumably a SliceFusion value node followed by its operands) and names it
// "<preceed name>_slice_<n>". It then rewires input `index` of `preceed_cnode` to the new cnode. The edge is
// recorded in `tr` when given; otherwise it is applied through the graph's manager (mirroring
// SwapSliceWithPreceed). Returns the new cnode, or nullptr on failure.
CNodePtr SlicePreposePass::InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
                                       const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  auto slice_cnode = graph->NewCNode(inputs);
  MS_CHECK_TRUE_MSG(slice_cnode != nullptr, nullptr, "NewNode Failed");
  slice_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_slice_" +
                                       std::to_string(node_name_index));
  node_name_index += 1;
  if (tr != nullptr) {
    tr->SetEdge(preceed_cnode, index, slice_cnode);
    return slice_cnode;
  }
  // No transaction supplied: previously this dereferenced a null `tr`.
  auto manager = graph->manager();
  MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr");
  manager->SetEdge(preceed_cnode, index, slice_cnode);
  return slice_cnode;
}
274
VerifySliceAttrs(const CNodePtr & slice_cnode,const int dim)275 STATUS SlicePreposePass::VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim) {
276 // according to ops/slice.cc, axes >= 0, begin >= 0, size >= -1
277 auto slice = GetSlice(slice_cnode);
278 if (slice == nullptr) {
279 MS_LOG(ERROR) << "Slice is nullptr";
280 return RET_ERROR;
281 }
282 auto axes = slice->get_axes();
283 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
284 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
285
286 std::set<int64_t> unique_axes(axes.begin(), axes.end());
287 if (axes.empty() || unique_axes.size() != axes.size()) {
288 MS_LOG(DEBUG) << "Invalid slice axe attribute";
289 return RET_ERROR;
290 }
291 MS_CHECK_TRUE_MSG(begin.size() <= axes.size(), RET_ERROR, "begin size is wrong");
292 MS_CHECK_TRUE_MSG(size.size() <= axes.size(), RET_ERROR, "size.size() is wrong");
293 for (size_t i = 0; i < axes.size(); ++i) {
294 auto axe = axes[i];
295 if (dim > -1 && axe >= dim) {
296 MS_LOG(ERROR) << "Invalid slice axe attribute";
297 return RET_ERROR;
298 }
299 if (axe < 0) {
300 MS_LOG(ERROR) << "Invalid slice axe attribute";
301 return RET_ERROR;
302 }
303 if (begin[i] < 0) { // we not require begin[i] < ref_shape[axe], cause there may be broadcast
304 MS_LOG(ERROR) << "Invalid begin input! begin[" << i << "]=" << begin[i];
305 return RET_ERROR;
306 }
307 if (size[i] < -1) {
308 MS_LOG(ERROR) << "Invalid size input! size[" << i << "]=" << size[i];
309 return RET_ERROR;
310 }
311 }
312 return RET_OK;
313 }
314
315 /*
316 * Adjust slice's attr when broadcast happened in Arithmetic
317 */
SliceParamDeBroadcast(const CNodePtr & slice_cnode,const std::vector<int64_t> & ref_shape,std::vector<int64_t> * axes,std::vector<int> * begin,std::vector<int> * size)318 STATUS SlicePreposePass::SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int64_t> &ref_shape,
319 std::vector<int64_t> *axes, std::vector<int> *begin,
320 std::vector<int> *size) {
321 MS_ASSERT(slice_cnode != nullptr);
322 MS_ASSERT(new_slice_cnode != nullptr);
323 MS_ASSERT(axes != nullptr);
324 MS_ASSERT(begin != nullptr);
325 MS_ASSERT(size != nullptr);
326 auto slice = GetSlice(slice_cnode);
327 if (slice == nullptr) {
328 MS_LOG(ERROR) << "slice is nullptr";
329 return RET_ERROR;
330 }
331 auto origin_axes = slice->get_axes();
332 auto origin_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
333 auto origin_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
334 auto status = VerifySliceAttrs(slice_cnode, ref_shape.size());
335 if (status != RET_OK) {
336 return status;
337 }
338 axes->resize(ref_shape.size());
339 std::iota(axes->begin(), axes->end(), 0);
340 begin->assign(ref_shape.size(), 0);
341 size->assign(ref_shape.size(), -1);
342 bool real_slice = false; // whether slice happened at this input
343 MS_CHECK_TRUE_MSG(origin_begin.size() >= origin_axes.size(), RET_ERROR, "origin_begin.size() is wrong");
344 MS_CHECK_TRUE_MSG(origin_size.size() >= origin_axes.size(), RET_ERROR, "origin_size.size() is wrong");
345 for (size_t i = 0; i < origin_axes.size(); ++i) {
346 int a = origin_axes[i];
347 int b = origin_begin[i];
348 int s = origin_size[i];
349 MS_CHECK_TRUE_MSG(static_cast<int>(ref_shape.size()) > a, RET_ERROR, "ref_shape.size() is wrong");
350 int ref = ref_shape[a];
351 if (ref == 1) { // broadcast
352 continue; // sliced size is 0(such as begin=1,size=-1) is not considered.
353 } else if (ref > 1) { // not broadcast
354 if (b >= ref) {
355 MS_LOG(ERROR) << "slice begin[" << a << "]=" << b << ", while ref_shape[" << a << "]=" << ref << ", can't fit!";
356 return RET_ERROR;
357 } else {
358 if (b != 0 || (s != -1 && s != ref)) {
359 real_slice = true;
360 }
361 MS_CHECK_TRUE_MSG(static_cast<int>(begin->size()) > a, RET_ERROR, "begin.size() is wrong");
362 MS_CHECK_TRUE_MSG(static_cast<int>(size->size()) > a, RET_ERROR, "size.size() is wrong");
363 begin->at(a) = b;
364 size->at(a) = s;
365 }
366 } else { // ref == 0, not need slice
367 continue;
368 }
369 }
370 if (real_slice) {
371 return lite::RET_OK;
372 } else {
373 return lite::RET_NO_CHANGE;
374 }
375 }
376
// Creates a Reshape cnode that reshapes `preceed_cnode`'s output to `shape_vector`. The target shape goes into a
// new parameter node built from an int vector. The new cnode gets `abstract` (its cached value is then cleared)
// and a name derived from `preceed_cnode`. It is not connected to any consumer here. Returns nullptr on failure.
CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape_vector,
                                              const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  // Checked explicitly: the abstract is attached to the new node and then mutated by ClearCNodeAbstractValue.
  MS_CHECK_TRUE_MSG(abstract != nullptr, nullptr, "abstract is nullptr");
  auto new_reshape = std::make_shared<mindspore::ops::Reshape>();
  if (new_reshape == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return nullptr;
  }
  auto new_reshape_c = new_reshape->GetPrim();
  MS_CHECK_TRUE_MSG(new_reshape_c != nullptr, nullptr, "new_reshape_c is nullptr");
  ValueNodePtr value_node = NewValueNode(new_reshape_c);
  if (value_node == nullptr) {
    return nullptr;
  }
  std::vector<int> shape;
  shape.reserve(shape_vector.size());
  std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                 [](int64_t val) { return static_cast<int>(val); });
  auto shape_node = BuildIntVecParameterNode(
    graph, shape, preceed_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
  node_name_index++;
  if (shape_node == nullptr) {
    MS_LOG(ERROR) << "build parameter node failed.";
    return nullptr;
  }
  auto reshape_cnode = graph->NewCNode({value_node, preceed_cnode, shape_node});
  MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "NewCNode Failed");
  reshape_cnode->set_abstract(abstract);
  reshape_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_reshape_" +
                                         std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(reshape_cnode);
  return reshape_cnode;
}
413
SiblingsAreSameSlice(const NodeUsedListPtr & output_node_list,const std::vector<int64_t> & ref_shape)414 bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
415 const std::vector<int64_t> &ref_shape) {
416 MS_ASSERT(graph != nullptr);
417 MS_ASSERT(output_node_list != nullptr);
418 MS_ASSERT(output_node_list->size() >= 2);
419 std::vector<CNodePtr> slices;
420 for (auto &output_node : *(output_node_list.get())) {
421 auto cnode = output_node.first->cast<CNodePtr>();
422 MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr");
423 if (!CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
424 return false;
425 }
426 auto slice_node = GetSlice(cnode);
427 MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "Slice is nullptr");
428 slices.push_back(cnode);
429 }
430 MS_CHECK_TRUE_MSG(slices.size() > 0, false, "slices.size() is wrong");
431 auto first_slice_cnode = slices.front();
432 auto first_slice_node = GetSlice(first_slice_cnode);
433 MS_CHECK_TRUE_MSG(first_slice_node != nullptr, false, "GetSlice return nullptr");
434 auto first_axes = first_slice_node->get_axes();
435 auto first_begin = GetSliceBeginAndSize(first_slice_cnode, SliceBeginIndex);
436 auto first_size = GetSliceBeginAndSize(first_slice_cnode, SliceSizeIndex);
437 MS_CHECK_TRUE_MSG(first_begin.size() >= first_axes.size(), false, "first_begin.size() is wrong");
438 MS_CHECK_TRUE_MSG(first_size.size() >= first_axes.size(), false, "first_size.size() is wrong");
439 MS_CHECK_TRUE_MSG(slices.size() >= output_node_list->size(), false, "slices.size() is wrong");
440 for (size_t i = 1; i < output_node_list->size(); ++i) {
441 auto slice = GetSlice(slices[i]);
442 if (slice == nullptr) {
443 MS_LOG(WARNING) << "slice is nullptr!";
444 continue;
445 }
446 auto axes = slice->get_axes();
447 auto begin = GetSliceBeginAndSize(slices[i], SliceBeginIndex);
448 auto size = GetSliceBeginAndSize(slices[i], SliceSizeIndex);
449 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
450 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
451 if (axes.size() != first_axes.size()) {
452 return false;
453 }
454 for (size_t j = 0; j < axes.size(); ++j) {
455 auto axe = axes[j];
456 if (!ref_shape.empty() && axe >= static_cast<int>(ref_shape.size())) {
457 return false;
458 }
459 size_t k = 0;
460 for (; k < first_axes.size(); ++k) { // axes may not be [0...n-1], so we use nested loop to find it
461 if (first_axes[k] == axe) {
462 break;
463 }
464 }
465 if (k == first_axes.size()) {
466 return false;
467 }
468 if (begin[j] != first_begin[k]) {
469 return false;
470 }
471 if (size[j] != first_size[k]) {
472 if (ref_shape.empty()) {
473 return false;
474 }
475 auto actual_size = size[j] > 0 ? size[j] : ref_shape[axe] - begin[j];
476 auto actual_first_size = first_size[k] > 0 ? first_size[k] : ref_shape[axe] - first_begin[k];
477 if (actual_size != actual_first_size) {
478 return false;
479 }
480 }
481 }
482 }
483 return true;
484 }
485
GetReshapeAbnormalAxeIn(const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,std::vector<int64_t> * mapped_axe)486 int64_t SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in,
487 const std::vector<int64_t> &shape_out,
488 std::vector<int64_t> *mapped_axe) {
489 // find shape_out's correspond axe in shape_in
490 // when there are such as 3x1x1x4 => 3x1x4, mapped_axe[1] == 2
491 int64_t inner_size_in = 1;
492 int64_t abnormal_axe_in = -1;
493 MS_CHECK_TRUE_MSG(mapped_axe->size() >= shape_out.size(), abnormal_axe_in, "mapped_axe.size() is wrong");
494 for (size_t i = 0; i < shape_in.size(); ++i) {
495 inner_size_in *= shape_in[i];
496 int64_t inner_size_out = 1;
497 size_t j;
498 for (j = 0; j < shape_out.size(); ++j) {
499 inner_size_out *= shape_out[j];
500 if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
501 mapped_axe->at(j) = i;
502 break;
503 }
504 }
505 if (j == shape_out.size() && abnormal_axe_in == -1) {
506 abnormal_axe_in = i;
507 }
508 }
509 return abnormal_axe_in;
510 }
511
GetReshapeAbnormalIndexOut(const CNodePtr & slice_cnode,const std::vector<int64_t> & mapped_axe,const std::vector<int64_t> & shape_out,std::vector<int64_t> * shape_out_copy,bool * is_normal_mode,bool * support_abnormal_mode)512 int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode,
513 const std::vector<int64_t> &mapped_axe,
514 const std::vector<int64_t> &shape_out,
515 std::vector<int64_t> *shape_out_copy, bool *is_normal_mode,
516 bool *support_abnormal_mode) {
517 MS_ASSERT(slice_cnode != nullptr);
518 MS_ASSERT(shape_out_copy != nullptr);
519 MS_ASSERT(is_normal_mode != nullptr);
520 MS_ASSERT(support_abnormal_mode != nullptr);
521 int64_t abnormal_index_out = -1;
522 auto slice_node = GetSlice(slice_cnode);
523 if (slice_node == nullptr) {
524 MS_LOG(ERROR) << "slice is nullptr";
525 return abnormal_index_out;
526 }
527 auto slice_axes = slice_node->get_axes();
528 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
529 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
530 for (size_t j = 0; j < shape_out.size(); ++j) {
531 int index = -1;
532 for (size_t i = 0; i < slice_axes.size(); ++i) {
533 if (slice_axes[i] == static_cast<int64_t>(j)) {
534 index = static_cast<int>(i);
535 break;
536 }
537 }
538 if (index == -1) {
539 continue;
540 }
541 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, abnormal_index_out, "slice_begin.size() is wrong");
542 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, abnormal_index_out, "slice_size.size() is wrong");
543 if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
544 if (mapped_axe[j] == -1) {
545 if (is_normal_mode) {
546 *is_normal_mode = false;
547 abnormal_index_out = static_cast<int64_t>(index);
548 } else {
549 *support_abnormal_mode = false;
550 }
551 } else { // if there is matched axe sliced, not support abnormal mode
552 shape_out_copy->at(j) =
553 (slice_size[index] == -1 ? shape_out[j] - slice_begin[index] : static_cast<int64_t>(slice_size[index]));
554 *support_abnormal_mode = false;
555 }
556 }
557 }
558 return abnormal_index_out;
559 }
560
PreposeWithNormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out_copy,const std::vector<int64_t> & mapped_axe)561 bool SlicePreposePass::PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
562 const CNodePtr &reshape_cnode, const std::vector<int64_t> &shape_in,
563 const std::vector<int64_t> &shape_out_copy,
564 const std::vector<int64_t> &mapped_axe) {
565 MS_ASSERT(graph != nullptr);
566 MS_ASSERT(slice_cnode != nullptr);
567 MS_ASSERT(reshape_cnode != nullptr);
568 auto slice_node = GetSlice(slice_cnode);
569 if (slice_node == nullptr) {
570 MS_LOG(ERROR) << "slice is nullptr";
571 return false;
572 }
573 auto slice_axes = slice_node->get_axes();
574 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
575 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
576 std::vector<int64_t> new_axes(shape_in.size());
577 std::iota(new_axes.begin(), new_axes.end(), 0);
578 std::vector<int> new_begin(shape_in.size(), 0);
579 std::vector<int> new_size(shape_in.size(), -1);
580 MS_CHECK_TRUE_MSG(slice_begin.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
581 MS_CHECK_TRUE_MSG(slice_size.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
582 for (size_t i = 0; i < mapped_axe.size(); ++i) {
583 auto axe_in = mapped_axe[i];
584 if (axe_in == -1) {
585 continue;
586 }
587 new_begin[axe_in] = slice_begin[i];
588 new_size[axe_in] = slice_size[i];
589 }
590
591 auto reshape_node = GetReshape(reshape_cnode);
592 if (reshape_node == nullptr) {
593 MS_LOG(ERROR) << "reshape is nullptr";
594 return false;
595 }
596 std::vector<int> new_shape_out_copy;
597 std::transform(shape_out_copy.begin(), shape_out_copy.end(), std::back_inserter(new_shape_out_copy),
598 [](int64_t val) { return static_cast<int>(val); });
599 auto shape_node = BuildIntVecParameterNode(
600 graph, new_shape_out_copy, reshape_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
601 node_name_index++;
602 if (shape_node == nullptr) {
603 MS_LOG(ERROR) << "build parameter node failed.";
604 return false;
605 }
606 reshape_cnode->set_inputs({reshape_cnode->input(0), reshape_cnode->input(1), shape_node});
607
608 slice_node->set_axes(new_axes);
609 auto new_begin_parameter = BuildIntVecParameterNode(
610 graph, new_begin, slice_cnode->input(SliceBeginIndex)->cast<ParameterPtr>()->fullname_with_scope());
611 auto new_size_parameter = BuildIntVecParameterNode(
612 graph, new_size, slice_cnode->input(SliceSizeIndex)->cast<ParameterPtr>()->fullname_with_scope());
613 MS_CHECK_TRUE_MSG(new_begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
614 MS_CHECK_TRUE_MSG(new_size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
615 slice_cnode->set_input(SliceBeginIndex, new_begin_parameter);
616 slice_cnode->set_input(SliceSizeIndex, new_size_parameter);
617 auto status = SwapSliceWithPreceed(graph, slice_cnode, reshape_cnode, 1);
618 if (status != RET_OK) {
619 return false;
620 }
621 reshape_cnode->set_abstract(slice_cnode->abstract()->Clone());
622 ClearCNodeAbstractValue(slice_cnode);
623 return true;
624 }
625
// Builds the first of two slices used when preposing a slice across a reshape (abnormal mode). It slices
// `matmul_cnode`'s output over every axis of `shape_in`; only axis `abnormal_axe_in` is cut:
//  - slice_at_front: begin = count_sliced_axe_in (drops that many leading elements);
//  - otherwise: size = shape_in[abnormal_axe_in] - count_sliced_axe_in (drops that many trailing elements).
// The new cnode gets a clone of `slice_cnode`'s abstract with its value cleared. It is not connected to any
// consumer. Returns nullptr on failure.
CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &matmul_cnode,
                                                         const std::vector<int64_t> &shape_in,
                                                         const int64_t abnormal_axe_in,
                                                         const int64_t count_sliced_axe_in, const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(matmul_cnode != nullptr);
  // abnormal_axe_in indexes the vectors below; -1 ("no abnormal axis") or larger values would be out of bounds.
  MS_CHECK_TRUE_MSG(abnormal_axe_in >= 0 && abnormal_axe_in < static_cast<int64_t>(shape_in.size()), nullptr,
                    "abnormal_axe_in is out of range");
  auto slice_abstract = slice_cnode->abstract();
  MS_CHECK_TRUE_MSG(slice_abstract != nullptr, nullptr, "abstract of slice is nullptr");
  std::vector<int64_t> new_axes1(shape_in.size());
  std::iota(new_axes1.begin(), new_axes1.end(), 0);
  std::vector<int> new_begin1(shape_in.size(), 0);
  std::vector<int> new_size1(shape_in.size(), -1);
  if (slice_at_front) {
    new_begin1[abnormal_axe_in] = static_cast<int>(count_sliced_axe_in);
  } else {
    new_size1[abnormal_axe_in] = static_cast<int>(shape_in[abnormal_axe_in] - count_sliced_axe_in);
  }
  auto new_slice1 = CreateSliceValueNode(new_axes1);
  if (new_slice1 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin1, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size1, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice1_cnode = graph->NewCNode({new_slice1, matmul_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice1_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice1_cnode->set_abstract(slice_abstract->Clone());
  new_slice1_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice1_cnode);
  return new_slice1_cnode;
}
665
// Builds the second of two slices used when preposing a slice across a reshape (abnormal mode). It slices
// `new_reshape1_cnode`'s output over axes 0..abnormal_axe_in only; axes not listed are presumably left unsliced.
// Only axis `abnormal_axe_in` is cut:
//  - slice_at_front: begin = new_shape1[abnormal_axe_in] - count_sliced2 (keeps the last count_sliced2 elements);
//  - otherwise: size = count_sliced2 (keeps the first count_sliced2 elements).
// The new cnode gets a clone of `slice_cnode`'s abstract with its value cleared. Returns nullptr on failure.
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &new_reshape1_cnode,
                                                         const std::vector<int64_t> &new_shape1,
                                                         const int64_t abnormal_axe_in, const int64_t count_sliced2,
                                                         const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(new_reshape1_cnode != nullptr);
  // abnormal_axe_in sizes and indexes the vectors below; a negative value would make them empty and the writes
  // below out of bounds.
  MS_CHECK_TRUE_MSG(abnormal_axe_in >= 0 && abnormal_axe_in < static_cast<int64_t>(new_shape1.size()), nullptr,
                    "abnormal_axe_in is out of range");
  auto slice_abstract = slice_cnode->abstract();
  MS_CHECK_TRUE_MSG(slice_abstract != nullptr, nullptr, "abstract of slice is nullptr");
  if (count_sliced2 > new_shape1[abnormal_axe_in]) {
    MS_LOG(WARNING) << "calculation error";
    return nullptr;
  }
  const auto axes_count = static_cast<size_t>(abnormal_axe_in + 1);
  std::vector<int64_t> new_axes2(axes_count);
  std::iota(new_axes2.begin(), new_axes2.end(), 0);
  std::vector<int> new_begin2(axes_count, 0);
  std::vector<int> new_size2(axes_count, -1);
  if (slice_at_front) {
    new_begin2[abnormal_axe_in] = static_cast<int>(new_shape1[abnormal_axe_in] - count_sliced2);
  } else {
    new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
  }
  auto new_slice2 = CreateSliceValueNode(new_axes2);
  if (new_slice2 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin2, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size2, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  node_name_index += 1;
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  auto new_slice2_cnode = graph->NewCNode({new_slice2, new_reshape1_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice2_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice2_cnode->set_abstract(slice_abstract->Clone());
  new_slice2_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice2_cnode);
  return new_slice2_cnode;
}
709
PreposeWithAbnormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,const int64_t abnormal_axe_in,const int64_t abnormal_index_out)710 bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
711 const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in,
712 const std::vector<int64_t> &shape_out, const int64_t abnormal_axe_in,
713 const int64_t abnormal_index_out) {
714 MS_ASSERT(graph != nullptr);
715 MS_ASSERT(slice_cnode != nullptr);
716 auto manager = graph->manager();
717 MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr");
718 auto slice_node = GetSlice(slice_cnode);
719 if (slice_node == nullptr) {
720 MS_LOG(ERROR) << "slice is nullptr";
721 return false;
722 }
723 auto slice_axes = slice_node->get_axes();
724 MS_CHECK_TRUE_MSG(static_cast<int>(slice_axes.size()) > abnormal_index_out, false, "slice_axes.size() is wrong");
725 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
726 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
727 auto abnormal_axe_out = slice_axes[abnormal_index_out];
728 MS_ASSERT(abnormal_axe_out + 1 < shape_out.size());
729 int64_t inter_size_in = 1;
730 int64_t inter_size_out = 1;
731 for (auto i = 0; i < abnormal_axe_in; ++i) {
732 inter_size_in *= shape_in[i];
733 }
734 for (auto i = 0; i < abnormal_axe_out; ++i) {
735 inter_size_out *= shape_out[i];
736 }
737 if (inter_size_in != inter_size_out) {
738 MS_LOG(DEBUG) << "not support prepose now";
739 return false;
740 }
741 int64_t outer_size_in = 1;
742 int64_t outer_size_out = 1;
743 for (auto i = abnormal_axe_in + 1; i < static_cast<int>(shape_in.size()); ++i) {
744 outer_size_in *= shape_in[i];
745 }
746 for (auto i = abnormal_axe_out + 1; i < static_cast<int>(shape_out.size()); ++i) {
747 outer_size_out *= shape_out[i];
748 }
749 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > abnormal_index_out, false, "slice_begin.size() is wrong");
750 const int64_t count_sliced_axe_front = slice_begin[abnormal_index_out];
751 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > abnormal_index_out, false, "slice_size.size() is wrong");
752 MS_CHECK_TRUE_MSG(static_cast<int>(shape_out.size()) > abnormal_index_out, false, "shape_out.size() is wrong");
753 const int64_t count_sliced_axe_rear =
754 slice_size[abnormal_index_out] == -1 ? 0 : (shape_out[abnormal_axe_out] - slice_size[abnormal_index_out]);
755 if (count_sliced_axe_front * count_sliced_axe_rear > 0) {
756 MS_LOG(DEBUG) << "not border slice at abnormal axe, prepose with reshape failed";
757 return false;
758 }
759 bool slice_at_front = count_sliced_axe_front > 0;
760 const int64_t count_sliced_out = (count_sliced_axe_front + count_sliced_axe_rear) * outer_size_out;
761 MS_CHECK_TRUE_MSG(outer_size_in != 0, false, "div zero");
762 const int64_t count_sliced_axe_in = count_sliced_out / outer_size_in;
763 MS_CHECK_TRUE_MSG(static_cast<int>(shape_in.size()) > abnormal_axe_in, false, "shape_in.size() is wrong");
764 if (count_sliced_axe_in <= 0 || count_sliced_axe_in > shape_in[abnormal_axe_in]) {
765 MS_LOG(DEBUG) << "amount of sliced out tensor is illegal";
766 return false;
767 }
768 // new_slice1
769 auto new_slice1_cnode = CreateSlice1ForReshapePrepose(graph, slice_cnode, matmul_cnode, shape_in, abnormal_axe_in,
770 count_sliced_axe_in, slice_at_front);
771 if (new_slice1_cnode == nullptr) {
772 return false;
773 }
774 // new_reshape1
775 std::vector<int64_t> new_shape1(abnormal_axe_in + 1);
776 for (int i = 0; i < abnormal_axe_in; ++i) {
777 new_shape1[i] = shape_in[i];
778 }
779 new_shape1[abnormal_axe_in] = outer_size_in * (shape_in[abnormal_axe_in] - count_sliced_axe_in);
780 auto new_reshape1_cnode = CreateReshapeCNode(graph, new_shape1, slice_cnode->abstract()->Clone(), new_slice1_cnode);
781 if (new_reshape1_cnode == nullptr) {
782 return false;
783 }
784 // new_slice2
785 const int64_t count_sliced_abnormal_axe =
786 shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
787 const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
788 auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
789 abnormal_axe_in, count_sliced2, slice_at_front);
790 if (new_slice2_cnode == nullptr) {
791 return false;
792 }
793 // new_reshape2
794 std::vector<int64_t> new_shape2(shape_out.begin(), shape_out.end());
795 new_shape2[abnormal_axe_out] = count_sliced_abnormal_axe;
796 auto new_reshape2_cnode = CreateReshapeCNode(graph, new_shape2, slice_cnode->abstract()->Clone(), new_slice2_cnode);
797 if (new_reshape2_cnode == nullptr) {
798 return false;
799 }
800 new_reshape2_cnode->set_abstract(slice_cnode->abstract()->Clone());
801 auto node_users = manager->node_users()[slice_cnode];
802 for (auto &node_user : node_users) {
803 manager->SetEdge(node_user.first, node_user.second, new_reshape2_cnode);
804 }
805 return true;
806 }
807
GetArithmeticInputInfo(const CNodePtr & arithmetic_cnode,std::vector<AnfNodePtr> * inputs,std::vector<std::vector<int64_t>> * shapes,std::vector<bool> * is_default_params)808 bool SlicePreposePass::GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
809 std::vector<std::vector<int64_t>> *shapes,
810 std::vector<bool> *is_default_params) {
811 MS_ASSERT(inputs != nullptr);
812 MS_ASSERT(shapes != nullptr);
813 MS_ASSERT(is_default_params != nullptr);
814 MS_ASSERT(arithmetic_cnode != nullptr);
815 for (size_t i = 1; i < arithmetic_cnode->size(); ++i) {
816 auto input = arithmetic_cnode->input(i);
817 MS_ASSERT(input != nullptr);
818 std::vector<int64_t> shape;
819 if (utils::isa<ParameterPtr>(input)) {
820 auto parameter = utils::cast<ParameterPtr>(input);
821 MS_ASSERT(parameter != nullptr);
822 if (!parameter->has_default()) { // if one input is input placeholder, we can't change it
823 return false;
824 } else {
825 shape = GetDefaultParamShape(parameter);
826 is_default_params->push_back(true);
827 }
828 } else { // input is CNode
829 if (!utils::isa<CNodePtr>(input)) {
830 MS_LOG(ERROR) << "one of Arithmetic's input is not CNode";
831 return false;
832 }
833 shape = GetCNodeInputShape(arithmetic_cnode, i);
834 is_default_params->push_back(false);
835 }
836 inputs->push_back(input);
837 shapes->push_back(shape);
838 }
839 return true;
840 }
841
842 /*
843 * Prepose condition:
844 * the softmax axis is not sliced
845 */
PreposeWithSoftmax(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & softmax_cnode)846 bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
847 const CNodePtr &softmax_cnode) {
848 MS_ASSERT(graph != nullptr);
849 MS_ASSERT(slice_cnode != nullptr);
850 MS_ASSERT(softmax_cnode != nullptr);
851 auto softmax_node = GetSoftmax(softmax_cnode);
852 if (softmax_node == nullptr) {
853 MS_LOG(ERROR) << "softmax is nullptr";
854 return false;
855 }
856 std::vector<int64_t> softmax_axis{-1};
857 if (softmax_node->GetAttr(ops::kAxis) != nullptr) {
858 softmax_axis = softmax_node->get_axis();
859 }
860 if (softmax_axis.size() != 1) {
861 MS_LOG(ERROR) << "softmax axis is not a value, which don't support.";
862 return false;
863 }
864 auto shape = GetCNodeInputShape(softmax_cnode, 1);
865 if (softmax_axis.front() == -1) {
866 // when softmax axis == -1, shape info is needed to determine whether slice can be preposed
867 if (lite::JudgeDynamicShape(shape)) {
868 return false;
869 }
870 softmax_axis[0] += static_cast<int64_t>(shape.size());
871 }
872
873 auto slice_node = GetSlice(slice_cnode);
874 if (slice_node == nullptr) {
875 return false;
876 }
877 auto slice_axes = slice_node->get_axes();
878 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
879 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
880
881 MS_CHECK_TRUE_MSG(static_cast<int>(softmax_axis.size()) > 0, false, "shape_in.size() is wrong");
882 MS_CHECK_TRUE_MSG(slice_size.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
883 MS_CHECK_TRUE_MSG(slice_begin.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
884 for (size_t i = 0; i < slice_axes.size(); ++i) {
885 if (slice_axes[i] == softmax_axis.front()) {
886 if (slice_begin[i] != 0) {
887 return false;
888 }
889 if (slice_size[i] != -1) {
890 if (lite::JudgeDynamicShape(shape) || slice_axes[i] >= static_cast<int>(shape.size())) {
891 return false;
892 }
893 if (slice_size[i] < shape[slice_axes[i]]) {
894 return false;
895 }
896 }
897 }
898 }
899 auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1);
900 if (status != RET_OK) {
901 return false;
902 }
903 softmax_cnode->set_abstract(slice_cnode->abstract()->Clone());
904 ClearCNodeAbstractValue(slice_cnode);
905 return true;
906 }
907
908 /*
909 * Prepose condition:
910 * require shape info
911 * when reshape is normal(memory view is not changed, such as 4x5 reshaped to 4x1x5), can always prepose
912 * when reshape is abnormal(such as 4x5 reshaped to 5x4), can prepose under some constraint
913 * For abnormal mode:
914 * we only support border(not slice at center) slice at first mismatch axe,
915 * and we only support matmul->reshape->slice => matmul->slice->reshape*->slice*(drop "dead" data)->reshape now,
916 * cause the performance influence introduced by additional (reshape*->slice*) has not been fully evaluated.
917 */
PreposeWithReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode)918 bool SlicePreposePass::PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
919 const CNodePtr &reshape_cnode) {
920 MS_ASSERT(graph != nullptr);
921 MS_ASSERT(slice_cnode != nullptr);
922 MS_ASSERT(reshape_cnode != nullptr);
923 auto shape_in = GetCNodeInputShape(reshape_cnode, 1);
924 auto shape_out = GetCNodeInputShape(slice_cnode, 1);
925 auto shape_out_copy = shape_out;
926 if (shape_in.empty() || shape_out.empty()) {
927 MS_LOG(DEBUG) << "Reshape can't be preposed if either input or output shape is unknown";
928 return false;
929 }
930 if (reshape_cnode->size() == 3 && utils::isa<ParameterPtr>(reshape_cnode->input(2))) {
931 auto reshape_input_shape = utils::cast<ParameterPtr>(reshape_cnode->input(2));
932 MS_ASSERT(reshape_input_shape != nullptr);
933 if (!reshape_input_shape->has_default()) {
934 MS_LOG(ERROR) << "Reshape input shape is not constant";
935 return false;
936 }
937 }
938 std::vector<int64_t> mapped_axe(shape_out.size(), -1);
939 int64_t abnormal_axe_in = GetReshapeAbnormalAxeIn(shape_in, shape_out, &mapped_axe);
940 bool is_normal_mode = true; // if all sliced axe can be found in input shape, normal
941 bool support_abnormal_mode = true; // if first mismatch axe are sliced and no more other axes are sliced, abnormal
942 int64_t abnormal_index_out = GetReshapeAbnormalIndexOut(slice_cnode, mapped_axe, shape_out, &shape_out_copy,
943 &is_normal_mode, &support_abnormal_mode);
944 if (abnormal_index_out == -1) {
945 MS_LOG(ERROR) << "GetReshapeAbnormalIndexOut failed.";
946 return false;
947 }
948 if (is_normal_mode) {
949 return PreposeWithNormalReshape(graph, slice_cnode, reshape_cnode, shape_in, shape_out_copy, mapped_axe);
950 } else if (support_abnormal_mode) {
951 auto matmul_node = reshape_cnode->input(1);
952 MS_ASSERT(matmul_node != nullptr);
953 if (IsMultiOutputTensors(graph, matmul_node) || !utils::isa<CNodePtr>(matmul_node)) {
954 MS_LOG(DEBUG) << "not matmul->reshape->slice";
955 return false;
956 }
957 auto matmul_cnode = matmul_node->cast<CNodePtr>();
958 if (matmul_cnode == nullptr) {
959 MS_LOG(ERROR) << "matmul_cnode is nullptr";
960 return false;
961 }
962 if (!CheckPrimitiveType(matmul_node, prim::kPrimFullConnection) &&
963 !CheckPrimitiveType(matmul_node, prim::kPrimMatMulFusion)) {
964 MS_LOG(DEBUG) << "not matmul->reshape->slice pattern";
965 return false;
966 }
967 return PreposeWithAbnormalReshape(graph, slice_cnode, matmul_cnode, shape_in, shape_out, abnormal_axe_in,
968 abnormal_index_out);
969 }
970 return false;
971 }
972
973 /*
974 * Prepose condition:
975 * require shape info
976 */
PreposeWithMatmul(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode)977 bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
978 const CNodePtr &matmul_cnode) {
979 MS_ASSERT(graph != nullptr && slice_cnode != nullptr && matmul_cnode != nullptr);
980 auto matmul_shape = GetCNodeInputShape(slice_cnode, 1);
981 int dims = static_cast<int>(matmul_shape.size());
982 if (dims == 0) {
983 // if Matmul's output shape is unknown, can't do prepose, cause we can't determine last two axes
984 return false;
985 }
986 auto slice_node = GetSlice(slice_cnode);
987 MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
988 auto axes = slice_node->get_axes();
989 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
990 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
991 // matmul not support broadcast now, it makes things simpler
992 auto manager = graph->manager();
993 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
994 MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
995 auto node_users = manager->node_users()[slice_cnode];
996 bool changed = false;
997 bool prepose_to_left = false; // if only the last axe is sliced, not need prepose to left
998 bool prepose_to_right = false; // if only the second last axe is sliced, not need prepose to right
999 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1000 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1001 for (size_t i = 0; i < axes.size(); ++i) {
1002 if (begin[i] != 0 || (size[i] != -1 && size[i] != matmul_shape[axes[i]])) {
1003 if (axes[i] != dims - 1) {
1004 prepose_to_left = true;
1005 } else if (axes[i] != dims - 2) {
1006 prepose_to_right = true;
1007 }
1008 }
1009 }
1010 if (prepose_to_left) { // left matrix
1011 auto left_axes = axes;
1012 auto left_begin = begin;
1013 auto left_size = size;
1014 MS_CHECK_TRUE_MSG(left_begin.size() >= left_axes.size(), false, "left_begin.size() is wrong");
1015 MS_CHECK_TRUE_MSG(left_size.size() >= left_axes.size(), false, "left_size.size() is wrong");
1016 for (size_t i = 0; i < left_axes.size(); ++i) {
1017 if (left_axes[i] == dims - 1) {
1018 left_begin[i] = 0;
1019 left_size[i] = -1;
1020 }
1021 }
1022 auto left_slice_vnode = CreateSliceValueNode(left_axes);
1023 MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1024 auto begin_parameter = BuildIntVecParameterNode(
1025 graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1026 node_name_index += 1;
1027 auto size_parameter = BuildIntVecParameterNode(
1028 graph, left_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1029 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1030 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1031 node_name_index += 1;
1032
1033 const std::vector<AnfNodePtr> inputs = {left_slice_vnode, matmul_cnode->input(1), begin_parameter, size_parameter};
1034 auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 1, tr);
1035 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1036 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1037 ClearCNodeAbstractValue(new_slice_cnode);
1038 changed = true;
1039 }
1040 if (prepose_to_right) { // right matrix
1041 auto right_axes = axes;
1042 auto right_begin = begin;
1043 auto right_size = size;
1044 MS_CHECK_TRUE_MSG(right_begin.size() >= right_axes.size(), false, "right_begin.size() is wrong");
1045 MS_CHECK_TRUE_MSG(right_size.size() >= right_axes.size(), false, "right_size.size() is wrong");
1046 for (size_t i = 0; i < right_axes.size(); ++i) {
1047 if (right_axes[i] == dims - 2) {
1048 right_begin[i] = 0;
1049 right_size[i] = -1;
1050 }
1051 }
1052 auto begin_parameter = BuildIntVecParameterNode(
1053 graph, right_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1054 node_name_index += 1;
1055 auto size_parameter = BuildIntVecParameterNode(
1056 graph, right_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1057 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1058 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1059 node_name_index += 1;
1060 auto right_slice_vnode = CreateSliceValueNode(right_axes);
1061 MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1062 const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
1063 auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);
1064 MS_ASSERT(new_slice_cnode != nullptr);
1065 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1066 ClearCNodeAbstractValue(new_slice_cnode);
1067 changed = true;
1068 }
1069 if (changed) {
1070 matmul_cnode->set_abstract(slice_cnode->abstract()->Clone());
1071 for (auto &node_user : node_users) {
1072 tr->SetEdge(node_user.first, node_user.second, matmul_cnode);
1073 }
1074 tr->Commit();
1075 // we don't need graph->DropNode(slice_cnode);
1076 }
1077 return changed;
1078 }
1079
1080 /*
1081 * Prepose condition:
1082 * require shape info
1083 * only support slice at first output axe now, and useAxis must be false
1084 */
PreposeWithFullConnection(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & fc_cnode)1085 bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1086 const CNodePtr &fc_cnode) {
1087 MS_ASSERT(graph != nullptr);
1088 MS_ASSERT(slice_cnode != nullptr);
1089 MS_ASSERT(fc_cnode != nullptr);
1090 auto shape_in = GetCNodeInputShape(fc_cnode, 1);
1091 auto shape_out = GetCNodeInputShape(slice_cnode, 1);
1092 if (shape_in.empty() || shape_out.size() != 2) {
1093 MS_LOG(DEBUG) << "FullConnection can't be preposed if input shape is unknown or output shape is illegal";
1094 return false;
1095 }
1096 auto fc_node = GetFc(fc_cnode);
1097 if (fc_node == nullptr || (fc_node->GetAttr(ops::kUseAxis) != nullptr && fc_node->get_use_axis())) {
1098 MS_LOG(DEBUG) << "prepose with fc only support useAxis == false currently";
1099 return false;
1100 }
1101 auto slice_node = GetSlice(slice_cnode);
1102 MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
1103 auto axes = slice_node->get_axes();
1104 auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1105 auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1106 MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1107 MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1108 for (size_t i = 0; i < axes.size(); ++i) {
1109 if (axes[i] == 1) {
1110 if (begin[i] != 0 || (size[i] != -1 && size[i] != shape_out[1])) {
1111 MS_LOG(DEBUG) << "prepose with fc only support first output axe is sliced currently";
1112 return false;
1113 }
1114 }
1115 }
1116
1117 std::vector<int64_t> mapped_axe(shape_out.size(), -1);
1118 int64_t inner_size_in = 1;
1119 for (size_t i = 0; i < shape_in.size(); ++i) {
1120 inner_size_in *= shape_in[i];
1121 int64_t inner_size_out = 1;
1122 for (size_t j = 0; j < shape_out.size(); ++j) {
1123 inner_size_out *= shape_out[j];
1124 if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
1125 mapped_axe[j] = static_cast<int64_t>(i);
1126 break;
1127 }
1128 }
1129 }
1130 if (mapped_axe[0] == -1) {
1131 MS_LOG(DEBUG) << "first axe in output can't find correspond input axe, can't do prepose";
1132 return false;
1133 }
1134
1135 std::vector<int64_t> new_axes(shape_in.size());
1136 std::iota(new_axes.begin(), new_axes.end(), 0);
1137 std::vector<int> new_begin(shape_in.size(), 0);
1138 std::vector<int> new_size(shape_in.size(), -1);
1139 new_begin[mapped_axe[0]] = begin[0];
1140 new_size[mapped_axe[0]] = size[0];
1141 auto new_slice_vnode = CreateSliceValueNode(new_axes);
1142 MS_CHECK_TRUE_MSG(new_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1143
1144 auto manager = graph->manager();
1145 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1146 MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
1147 auto begin_parameter = BuildIntVecParameterNode(
1148 graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1149 node_name_index += 1;
1150 auto size_parameter = BuildIntVecParameterNode(
1151 graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1152 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1153 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1154 node_name_index += 1;
1155 const std::vector<AnfNodePtr> inputs = {new_slice_vnode, fc_cnode->input(1), begin_parameter, size_parameter};
1156 auto new_slice_cnode = InsertSlice(graph, inputs, fc_cnode, 1, tr);
1157 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1158
1159 fc_cnode->set_abstract(slice_cnode->abstract()->Clone());
1160 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1161 ClearCNodeAbstractValue(new_slice_cnode);
1162
1163 auto node_users = manager->node_users()[slice_cnode];
1164 for (auto &node_user : node_users) {
1165 tr->SetEdge(node_user.first, node_user.second, fc_cnode);
1166 }
1167 tr->Commit();
1168 return true;
1169 }
1170
1171 /*
1172 * Prepose condition:
1173 * not require shape info, can always prepose
1174 */
PreposeWithTranspose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & transpose_cnode)1175 bool SlicePreposePass::PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1176 const CNodePtr &transpose_cnode) {
1177 MS_ASSERT(graph != nullptr);
1178 MS_ASSERT(slice_cnode != nullptr);
1179 MS_ASSERT(transpose_cnode != nullptr);
1180 if (transpose_cnode->size() != 3) {
1181 MS_LOG(ERROR) << "transpose inputs size should be 3.";
1182 return false;
1183 }
1184 auto perm = GetTransposePerm(transpose_cnode);
1185 if (perm.empty()) {
1186 return false;
1187 }
1188 auto slice_node = GetSlice(slice_cnode);
1189 if (slice_node == nullptr) {
1190 MS_LOG(ERROR) << "GetSlicT failed";
1191 return false;
1192 }
1193 auto old_axes = slice_node->get_axes();
1194 auto old_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1195 auto old_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1196 auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1197 auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1198 // perm is random shuffle of [0...n-1] according to ops/transpose.cc
1199 for (size_t i = 0; i < perm.size(); ++i) {
1200 if (perm[i] != static_cast<int>(i)) {
1201 for (size_t j = 0; j < old_axes.size(); ++j) {
1202 if (old_axes[j] == static_cast<int>(i)) {
1203 MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > perm[i], false, "slice_begin.size() is wrong");
1204 MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > perm[i], false, "slice_size.size() is wrong");
1205 slice_begin[perm[i]] = old_begin[j];
1206 slice_size[perm[i]] = old_size[j];
1207 break;
1208 }
1209 }
1210 }
1211 }
1212 auto begin_parameter = BuildIntVecParameterNode(
1213 graph, slice_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1214 node_name_index += 1;
1215 auto size_parameter = BuildIntVecParameterNode(
1216 graph, slice_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1217 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1218 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1219 node_name_index += 1;
1220 slice_cnode->set_input(SliceBeginIndex, begin_parameter);
1221 slice_cnode->set_input(SliceSizeIndex, size_parameter);
1222 auto status = SwapSliceWithPreceed(graph, slice_cnode, transpose_cnode, 1);
1223 if (status != RET_OK) {
1224 return false;
1225 }
1226 transpose_cnode->set_abstract(slice_cnode->abstract()->Clone());
1227 ClearCNodeAbstractValue(slice_cnode);
1228 return true;
1229 }
1230 /*
1231 * Prepose condition:
1232 * may or may not require shape info
1233 */
PreposeWithArithmetic(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & arithmetic_cnode)1234 bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1235 const CNodePtr &arithmetic_cnode) {
1236 MS_ASSERT(graph != nullptr);
1237 MS_ASSERT(slice_cnode != nullptr);
1238 MS_ASSERT(arithmetic_cnode != nullptr);
1239 auto manager = graph->manager();
1240 MS_ASSERT(manager != nullptr);
1241 auto node_users = manager->node_users()[slice_cnode];
1242 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1243 if (tr == nullptr) {
1244 MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1245 return false;
1246 }
1247 bool changed = false;
1248 std::vector<AnfNodePtr> inputs;
1249 std::vector<std::vector<int64_t>> shapes;
1250 std::vector<bool> is_default_params;
1251 if (!GetArithmeticInputInfo(arithmetic_cnode, &inputs, &shapes, &is_default_params)) {
1252 return false;
1253 }
1254
1255 for (size_t i = 1; i < arithmetic_cnode->size(); ++i) {
1256 auto &input = inputs[i - 1];
1257 if (IsScalarNode(input)) { // scalar not need prepose
1258 continue;
1259 }
1260 auto &shape = shapes[i - 1];
1261 const size_t another_index = kArithmeticInputNum - i;
1262 auto &another_input = inputs[another_index];
1263 auto &another_shape = shapes[another_index];
1264 if (IsScalarNode(input)) {
1265 continue;
1266 } else if (lite::JudgeDynamicShape(shape)) { // infershape failed at this input
1267 if (IsScalarNode(another_input)) { // if another input is scalar, we can process this one
1268 auto new_slice_vnode = CopySliceValueNode(slice_cnode);
1269 if (new_slice_vnode == nullptr) {
1270 changed = false;
1271 break;
1272 }
1273 std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i),
1274 slice_cnode->input(SliceBeginIndex),
1275 slice_cnode->input(SliceSizeIndex)};
1276 auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1277 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1278
1279 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1280 ClearCNodeAbstractValue(new_slice_cnode);
1281 changed = true;
1282 break;
1283 } else { // if another input's shape is not scalar, can't be processed
1284 changed = false;
1285 break;
1286 }
1287 } else { // shape not empty
1288 if (!another_shape.empty() || IsScalarNode(another_input)) {
1289 std::vector<int64_t> new_axes;
1290 std::vector<int> new_begin;
1291 std::vector<int> new_size;
1292 auto status = SliceParamDeBroadcast(slice_cnode, shape, &new_axes, &new_begin, &new_size);
1293 if (status == lite::RET_NO_CHANGE) {
1294 continue;
1295 }
1296 if (status != lite::RET_OK) {
1297 changed = false;
1298 break;
1299 }
1300 auto new_slice_vnode = CreateSliceValueNode(new_axes);
1301 if (new_slice_vnode == nullptr) {
1302 changed = false;
1303 break;
1304 }
1305 auto begin_parameter = BuildIntVecParameterNode(
1306 graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1307 node_name_index += 1;
1308 auto size_parameter = BuildIntVecParameterNode(
1309 graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1310 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1311 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1312 node_name_index += 1;
1313 std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i), begin_parameter,
1314 size_parameter};
1315 auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1316 MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1317 new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1318 ClearCNodeAbstractValue(new_slice_cnode);
1319 changed = true;
1320 } else {
1321 changed = false;
1322 break;
1323 }
1324 }
1325 }
1326 if (changed) {
1327 arithmetic_cnode->set_abstract(slice_cnode->abstract()->Clone());
1328 for (auto &node_user : node_users) {
1329 tr->SetEdge(node_user.first, node_user.second, arithmetic_cnode);
1330 }
1331 tr->Commit();
1332 // we don't need graph->DropNode(slice_cnode);
1333 }
1334 return changed;
1335 } // namespace mindspore::opt
1336 /*
1337 * Prepose condition:
1338 * not require shape info
1339 */
MergeSequentialSlice(const FuncGraphPtr & graph,const CNodePtr & slice1_cnode,const CNodePtr & slice2_cnode)1340 bool SlicePreposePass::MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode,
1341 const CNodePtr &slice2_cnode) {
1342 if (slice2_cnode->size() != kArithmeticInputNum) {
1343 MS_LOG(INFO) << "Slice read attrs from input is not supported now";
1344 return false;
1345 }
1346 auto slice1_node = GetSlice(slice1_cnode); // bottom node
1347 auto slice2_node = GetSlice(slice2_cnode); // top node
1348 if (slice1_node == nullptr || slice2_node == nullptr) {
1349 MS_LOG(ERROR) << "slice is null";
1350 return false;
1351 }
1352 auto begin_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceBeginIndex);
1353 auto size_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceSizeIndex);
1354 auto axes_slice1 = slice1_node->get_axes();
1355 auto begin_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceBeginIndex);
1356 auto size_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceSizeIndex);
1357 auto axes_slice2 = slice2_node->get_axes();
1358 auto status1 = VerifySliceAttrs(slice1_cnode);
1359 auto status2 = VerifySliceAttrs(slice2_cnode);
1360 if (status1 != RET_OK || status2 != RET_OK) {
1361 return false;
1362 }
1363
1364 auto manager = graph->manager();
1365 MS_ASSERT(manager != nullptr);
1366 auto node_users = manager->node_users()[slice1_cnode];
1367 int64_t axe_max1 = *std::max_element(axes_slice1.begin(), axes_slice1.end());
1368 int64_t axe_max2 = *std::max_element(axes_slice2.begin(), axes_slice2.end());
1369 int64_t axe_max = std::max(axe_max1, axe_max2);
1370 auto begin_new = begin_slice2;
1371 auto size_new = size_slice2;
1372 auto axes_new = slice2_node->get_axes();
1373 axes_new.resize(axe_max + 1);
1374 std::iota(axes_new.begin(), axes_new.end(), 0);
1375 begin_new.assign(axe_max + 1, 0);
1376 size_new.assign(axe_max + 1, -1);
1377 MS_CHECK_TRUE_MSG(begin_slice2.size() >= axes_slice2.size(), false, "begin_slice2.size() is wrong");
1378 MS_CHECK_TRUE_MSG(size_slice2.size() >= axes_slice2.size(), false, "size_slice2.size() is wrong");
1379 MS_CHECK_TRUE_MSG(size_slice1.size() >= axes_slice1.size(), false, "size_slice1.size() is wrong");
1380 MS_CHECK_TRUE_MSG(begin_slice1.size() >= axes_slice1.size(), false, "begin_slice1.size() is wrong");
1381 for (int i = 0; i <= axe_max; ++i) {
1382 for (size_t j = 0; j < axes_slice2.size(); ++j) {
1383 if (axes_slice2[j] == i) {
1384 begin_new[i] = begin_slice2[j];
1385 size_new[i] = size_slice2[j];
1386 break;
1387 }
1388 }
1389 for (size_t j = 0; j < axes_slice1.size(); ++j) {
1390 if (axes_slice1[j] == i) {
1391 begin_new[i] = begin_new[i] + begin_slice1[j];
1392 if (size_new[i] == -1) {
1393 size_new[i] = size_slice1[j];
1394 } else {
1395 if (size_slice1[j] == -1) {
1396 size_new[i] = std::max(size_new[i] - begin_slice1[i], 0); // clip with zero to avoid invalid negative value
1397 } else {
1398 size_new[i] = std::max(std::min(size_new[i] - begin_slice1[j], size_slice1[j]), 0);
1399 }
1400 }
1401 break;
1402 }
1403 }
1404 }
1405 slice2_node->set_axes(axes_new);
1406 auto begin_parameter = BuildIntVecParameterNode(
1407 graph, begin_new, slice2_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1408 node_name_index += 1;
1409 auto size_parameter = BuildIntVecParameterNode(
1410 graph, size_new, slice2_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1411 MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1412 MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1413 node_name_index += 1;
1414 slice2_cnode->set_input(SliceBeginIndex, begin_parameter);
1415 slice2_cnode->set_input(SliceSizeIndex, size_parameter);
1416 slice2_cnode->set_abstract(slice1_cnode->abstract()->Clone());
1417 for (auto &node_user : node_users) {
1418 manager->SetEdge(node_user.first, node_user.second, slice2_cnode);
1419 }
1420 return true;
1421 }
1422
1423 /*
1424 * Prepose condition:
1425 * when all sibling slices do same work
1426 * can be optimize to not require all siblings are slice
1427 */
MergeParallelSlice(const FuncGraphPtr & graph,const NodeUsedListPtr & slices)1428 bool SlicePreposePass::MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices) {
1429 MS_ASSERT(graph != nullptr);
1430 MS_ASSERT(slices->size() >= 2);
1431 auto manager = graph->manager();
1432 MS_ASSERT(manager != nullptr);
1433 auto first_slice = utils::cast<CNodePtr>(slices->at(0).first);
1434 MS_ASSERT(first_slice != nullptr);
1435 if (!CheckPrimitiveType(first_slice, prim::kPrimSliceFusion)) {
1436 MS_LOG(ERROR) << "first node is not Slice";
1437 return false;
1438 }
1439 auto first_parent = first_slice->input(1);
1440 if (first_parent == nullptr) {
1441 MS_LOG(ERROR) << "first slice node's parent is nullptr";
1442 return false;
1443 }
1444 std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1445 if (tr == nullptr) {
1446 MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1447 return false;
1448 }
1449 for (size_t i = 1; i < slices->size(); ++i) {
1450 auto slice = utils::cast<CNodePtr>(slices->at(i).first);
1451 MS_ASSERT(slice != nullptr);
1452 if (!CheckPrimitiveType(slice, prim::kPrimSliceFusion)) {
1453 MS_LOG(ERROR) << "current node is not Slice";
1454 return false;
1455 }
1456 auto parent = slice->input(1);
1457 if (parent == nullptr || parent != first_parent) {
1458 MS_LOG(ERROR) << "not all slices have same parent node";
1459 return false;
1460 }
1461 auto node_users = manager->node_users()[slices->at(i).first];
1462 for (auto &node_user : node_users) {
1463 tr->SetEdge(node_user.first, node_user.second, slices->at(0).first);
1464 }
1465 }
1466 tr->Commit();
1467 return true;
1468 }
1469
DoPrepose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode)1470 bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1471 const CNodePtr &preceed_cnode) {
1472 MS_ASSERT(graph != nullptr);
1473 MS_ASSERT(slice_cnode != nullptr);
1474 MS_ASSERT(preceed_cnode != nullptr);
1475 if (CheckPrimitiveType(preceed_cnode, prim::kPrimSoftmax)) {
1476 return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode);
1477 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimReshape)) {
1478 return PreposeWithReshape(graph, slice_cnode, preceed_cnode);
1479 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimMatMulFusion)) {
1480 return PreposeWithMatmul(graph, slice_cnode, preceed_cnode);
1481 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimFullConnection)) {
1482 return PreposeWithFullConnection(graph, slice_cnode, preceed_cnode);
1483 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimTranspose)) {
1484 return PreposeWithTranspose(graph, slice_cnode, preceed_cnode);
1485 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSubFusion) ||
1486 CheckPrimitiveType(preceed_cnode, prim::kPrimMulFusion) ||
1487 CheckPrimitiveType(preceed_cnode, prim::kPrimAddFusion)) {
1488 return PreposeWithArithmetic(graph, slice_cnode, preceed_cnode);
1489 } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSliceFusion)) {
1490 return MergeSequentialSlice(graph, slice_cnode, preceed_cnode);
1491 }
1492 return false;
1493 }
1494
Run(const FuncGraphPtr & graph)1495 bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
1496 if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
1497 MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
1498 return false;
1499 }
1500 MS_ASSERT(graph != nullptr);
1501 bool changed = false;
1502 while (true) {
1503 bool this_time_changed = false;
1504 auto node_list = TopoSort(graph->get_return());
1505 for (auto &node : node_list) {
1506 if (node->func_graph() != graph || !utils::isa<CNodePtr>(node) ||
1507 !CheckPrimitiveType(node, prim::kPrimSliceFusion)) {
1508 continue;
1509 }
1510 auto slice_cnode = node->cast<CNodePtr>();
1511 MS_ASSERT(slice_cnode != nullptr);
1512 // only support begin and size is const tensor.
1513 if (!CheckIsAllInputsParam(slice_cnode) || GetSlice(slice_cnode)) {
1514 continue;
1515 }
1516 auto preceed_node = slice_cnode->input(1);
1517 if (preceed_node == nullptr) {
1518 MS_LOG(ERROR) << "proceed node is nullptr";
1519 continue;
1520 }
1521 auto output_tensor_num = GetOutputTensorNum(preceed_node);
1522 if (output_tensor_num > 1) {
1523 continue;
1524 }
1525 auto output_node_list = Helper::GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
1526 if (output_node_list->size() > 1) { // referenced by multi nodes
1527 if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
1528 this_time_changed = true;
1529 break;
1530 }
1531 continue;
1532 } else {
1533 if (utils::isa<ParameterPtr>(preceed_node)) {
1534 /*
1535 * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose
1536 * if preceed_node is parameter with default param, constant_folding will process it
1537 */
1538 continue;
1539 }
1540 auto preceed_cnode = preceed_node->cast<CNodePtr>();
1541 if (preceed_cnode == nullptr) {
1542 MS_LOG(ERROR) << "preceed_cnode is nullptr";
1543 continue;
1544 }
1545 if (DoPrepose(graph, slice_cnode, preceed_cnode)) {
1546 this_time_changed = true;
1547 break;
1548 }
1549 }
1550 }
1551 if (this_time_changed) {
1552 changed = true;
1553 } else {
1554 break;
1555 }
1556 }
1557 return changed;
1558 }
1559 } // namespace mindspore::opt
1560