• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define USE_DEPRECATED_API
18 #include "tools/optimizer/graph/slice_prepose_pass.h"
19 #include <vector>
20 #include <memory>
21 #include <set>
22 #include <algorithm>
23 #include "mindspore/core/ops/nn_ops.h"
24 #include "mindspore/core/ops/lite_ops.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "ops/fusion/full_connection.h"
27 #include "ops/auto_generate/gen_lite_ops.h"
28 #include "ops/fusion/slice_fusion.h"
29 #include "ops/op_utils.h"
30 #include "include/errorcode.h"
31 #include "tools/optimizer/common/gllo_utils.h"
32 #include "tools/optimizer/common/helper.h"
33 #include "include/backend/optimizer/helper.h"
34 #include "src/common/log_adapter.h"
35 #include "nnacl/op_base.h"
36 
37 namespace mindspore::opt {
38 namespace {
// Number of inputs of a binary arithmetic op (not referenced in this part of the file; presumably used further down).
constexpr int kArithmeticInputNum{2};
// Positions of the begin / size inputs on a SliceFusion CNode (input 0 holds the primitive, input 1 the sliced data).
constexpr int SliceBeginIndex{2};
constexpr int SliceSizeIndex{3};
// Running counter appended to the names of nodes created by this pass to disambiguate them; mutated by the pass.
int node_name_index{0};
GetSliceBeginAndSize(const CNodePtr & cnode,const int index)43 std::vector<int> GetSliceBeginAndSize(const CNodePtr &cnode, const int index) {
44   MS_ASSERT(cnode != nullptr);
45   std::vector<int> content;
46   if (index != SliceBeginIndex && index != SliceSizeIndex && cnode->size() != 4) {
47     return content;
48   }
49   auto node = cnode->input(index);
50   if (node == nullptr) {
51     return content;
52   }
53   auto param_node = node->cast<ParameterPtr>();
54   if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) {
55     return content;
56   }
57   auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
58   if (tensor_info == nullptr) {
59     return content;
60   }
61   content.resize(tensor_info->DataSize());
62   if (memcpy_s(content.data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
63     MS_LOG(ERROR) << "memcpy data failed.";
64     return {};
65   }
66   return content;
67 }
68 
GetCNodeInputShape(const CNodePtr & cnode,size_t index=1)69 std::vector<int64_t> GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) {
70   MS_ASSERT(cnode != nullptr);
71   std::vector<int64_t> empty_shape;
72   if (index < 1 || cnode->size() <= index) {
73     MS_LOG(ERROR) << "out of index";
74     return empty_shape;
75   }
76   auto abstract = GetCNodeInputAbstract(cnode, index);
77   if (abstract == nullptr) {
78     MS_LOG(ERROR) << "Abstract of CNode is nullptr";
79     return empty_shape;
80   }
81   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
82     MS_LOG(DEBUG) << "abstract is not AbstractTensor";
83     return empty_shape;
84   }
85   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
86   MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
87   return abstract_tensor->shape()->shape();
88 }
89 
GetDefaultParamShape(const ParameterPtr & param)90 std::vector<int64_t> GetDefaultParamShape(const ParameterPtr &param) {
91   MS_ASSERT(param != nullptr);
92   MS_ASSERT(param->has_default());
93   std::vector<int64_t> shape_vector;
94   auto default_param = param->default_param();
95   if (default_param == nullptr) {
96     MS_LOG(ERROR) << "default_param is nullptr";
97     return shape_vector;
98   }
99   if (!utils::isa<tensor::TensorPtr>(default_param)) {
100     MS_LOG(ERROR) << "default_param is not tensor::Tensor";
101     return shape_vector;
102   }
103   auto param_value_lite = utils::cast<tensor::TensorPtr>(default_param);
104   MS_ASSERT(param_value != nullptr);
105   auto shape = param_value_lite->shape();
106   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
107                  [](const int val) { return static_cast<int64_t>(val); });
108   return shape_vector;
109 }
110 
IsScalarNode(const AnfNodePtr & nodePtr)111 bool IsScalarNode(const AnfNodePtr &nodePtr) {
112   MS_ASSERT(nodePtr != nullptr);
113   if (utils::isa<ParameterPtr>(nodePtr) && nodePtr->cast<ParameterPtr>()->has_default()) {
114     auto tensor = utils::cast<tensor::TensorPtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
115     MS_ASSERT(tensor != nullptr);
116     auto shape = tensor->shape();
117     if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
118       return true;
119     }
120   }
121   return false;
122 }
123 
GetSlice(const CNodePtr & cnode)124 api::SharedPtr<mindspore::ops::SliceFusion> GetSlice(const CNodePtr &cnode) {
125   if (cnode == nullptr) {
126     return nullptr;
127   }
128   return ops::GetOperator<mindspore::ops::SliceFusion>(cnode->input(0));
129 }
130 
GetSoftmax(const CNodePtr & cnode)131 api::SharedPtr<mindspore::ops::Softmax> GetSoftmax(const CNodePtr &cnode) {
132   if (cnode == nullptr) {
133     return nullptr;
134   }
135   return ops::GetOperator<mindspore::ops::Softmax>(cnode->input(0));
136 }
137 
GetReshape(const CNodePtr & cnode)138 api::SharedPtr<mindspore::ops::Reshape> GetReshape(const CNodePtr &cnode) {
139   if (cnode == nullptr) {
140     return nullptr;
141   }
142   return ops::GetOperator<mindspore::ops::Reshape>(cnode->input(0));
143 }
144 
GetFc(const CNodePtr & cnode)145 api::SharedPtr<mindspore::ops::FullConnection> GetFc(const CNodePtr &cnode) {
146   if (cnode == nullptr) {
147     return nullptr;
148   }
149   return ops::GetOperator<mindspore::ops::FullConnection>(cnode->input(0));
150 }
151 
GetTransposePerm(const CNodePtr & node)152 std::vector<int> GetTransposePerm(const CNodePtr &node) {
153   MS_ASSERT(node != nullptr);
154   std::vector<int> perm;
155   if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
156     return perm;
157   }
158   if (node->size() != 3) {
159     return perm;
160   }
161   auto perm_node = node->input(2);
162   if (!utils::isa<ParameterPtr>(perm_node)) {
163     return perm;
164   }
165   auto perm_param = perm_node->cast<ParameterPtr>();
166   MS_ASSERT(perm_param != nullptr);
167   if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
168     return perm;
169   }
170   auto perm_value = perm_param->default_param()->cast<tensor::TensorPtr>();
171   if (perm_value == nullptr) {
172     return perm;
173   }
174   MS_CHECK_TRUE_MSG(perm_value->shape().size() != 0, {}, "shape is empty");
175   perm.resize(perm_value->shape()[0]);
176   if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) {
177     MS_LOG(ERROR) << "memcpy failed.";
178     return {};
179   }
180   return perm;
181 }
182 }  // namespace
183 
ClearCNodeAbstractValue(const CNodePtr & cnode)184 void SlicePreposePass::ClearCNodeAbstractValue(const CNodePtr &cnode) {
185   MS_ASSERT(cnode != nullptr);
186   auto abstract = cnode->abstract();
187   MS_ASSERT(abstract != nullptr);
188   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
189     MS_LOG(DEBUG) << "Abstract of cnode is not abstract tensor, " << cnode->fullname_with_scope();
190   }
191   abstract->set_value(std::make_shared<ValueAny>());
192 }
193 
SwapSliceWithPreceed(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode,const int index,const TransactionPtr & tr)194 STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
195                                               const CNodePtr &preceed_cnode, const int index,
196                                               const TransactionPtr &tr) {
197   MS_ASSERT(graph != nullptr);
198   MS_ASSERT(slice_cnode != nullptr);
199   MS_ASSERT(preceed_cnode != nullptr);
200   if (slice_cnode->input(1) != preceed_cnode) {
201     MS_LOG(ERROR) << "proceed node must be slice node's direct parent";
202     return RET_ERROR;
203   }
204   if (IsMultiOutputTensors(graph, preceed_cnode)) {
205     MS_LOG(ERROR) << "proceed node referenced by multi nodes not support swap";
206     return RET_ERROR;
207   }
208   auto manager = graph->manager();
209   if (manager == nullptr) {
210     MS_LOG(ERROR) << "manager is nullptr";
211     return RET_ERROR;
212   }
213   auto node_users = manager->node_users()[slice_cnode];
214   if (tr != nullptr) {  // do swap with transaction
215     for (auto &node_user : node_users) {
216       tr->SetEdge(node_user.first, node_user.second, preceed_cnode);
217     }
218     tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
219     tr->SetEdge(preceed_cnode, index, slice_cnode);
220   } else {
221     for (auto &node_user : node_users) {
222       manager->SetEdge(node_user.first, node_user.second, preceed_cnode);
223     }
224     manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
225     manager->SetEdge(preceed_cnode, index, slice_cnode);
226   }
227   return RET_OK;
228 }
229 
CreateSliceValueNode(const std::vector<int64_t> & axes)230 ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
231   MS_ASSERT(graph != nullptr);
232   MS_ASSERT(slice_cnode != nullptr);
233   auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
234   MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice is nullptr");
235   auto new_slice_c = new_slice->GetPrim();
236   MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
237   new_slice->set_axes(axes);
238   ValueNodePtr value_node = NewValueNode(new_slice_c);
239   MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
240   return value_node;
241 }
242 
CopySliceValueNode(const CNodePtr & slice_cnode)243 ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
244   MS_ASSERT(graph != nullptr);
245   MS_ASSERT(slice_cnode != nullptr);
246   auto slice_c = ops::GetOperator<mindspore::ops::SliceFusion>(slice_cnode->input(0));
247   if (slice_c == nullptr) {
248     MS_LOG(ERROR) << "slice node is nullptr";
249     return nullptr;
250   }
251   auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
252   MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice_c is nullptr");
253   auto new_slice_c = new_slice->GetPrim();
254   MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
255   new_slice->set_axes(new_slice->get_axes());
256   ValueNodePtr value_node = NewValueNode(new_slice_c);
257   MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
258   return value_node;
259 }
260 
/**
 * Creates a CNode from inputs (expected to describe a slice) and plugs it into input slot index of
 * preceed_cnode.
 *
 * The edge is set through tr when one is given, otherwise through the graph's manager. This mirrors
 * SwapSliceWithPreceed; previously a null tr was dereferenced.
 *
 * @return the new CNode, or nullptr when it cannot be created or no edge-setter is available.
 */
CNodePtr SlicePreposePass::InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
                                       const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  auto slice_cnode = graph->NewCNode(inputs);
  MS_CHECK_TRUE_MSG(slice_cnode != nullptr, nullptr, "NewNode Failed");
  slice_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_slice_" +
                                       std::to_string(node_name_index));
  node_name_index += 1;
  if (tr != nullptr) {
    tr->SetEdge(preceed_cnode, index, slice_cnode);
  } else {
    auto manager = graph->manager();
    MS_CHECK_TRUE_MSG(manager != nullptr, nullptr, "manager is nullptr");
    manager->SetEdge(preceed_cnode, index, slice_cnode);
  }
  return slice_cnode;
}
274 
VerifySliceAttrs(const CNodePtr & slice_cnode,const int dim)275 STATUS SlicePreposePass::VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim) {
276   // according to ops/slice.cc, axes >= 0, begin >= 0, size >= -1
277   auto slice = GetSlice(slice_cnode);
278   if (slice == nullptr) {
279     MS_LOG(ERROR) << "Slice is nullptr";
280     return RET_ERROR;
281   }
282   auto axes = slice->get_axes();
283   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
284   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
285 
286   std::set<int64_t> unique_axes(axes.begin(), axes.end());
287   if (axes.empty() || unique_axes.size() != axes.size()) {
288     MS_LOG(DEBUG) << "Invalid slice axe attribute";
289     return RET_ERROR;
290   }
291   MS_CHECK_TRUE_MSG(begin.size() <= axes.size(), RET_ERROR, "begin size is wrong");
292   MS_CHECK_TRUE_MSG(size.size() <= axes.size(), RET_ERROR, "size.size() is wrong");
293   for (size_t i = 0; i < axes.size(); ++i) {
294     auto axe = axes[i];
295     if (dim > -1 && axe >= dim) {
296       MS_LOG(ERROR) << "Invalid slice axe attribute";
297       return RET_ERROR;
298     }
299     if (axe < 0) {
300       MS_LOG(ERROR) << "Invalid slice axe attribute";
301       return RET_ERROR;
302     }
303     if (begin[i] < 0) {  //  we not require begin[i] < ref_shape[axe], cause there may be broadcast
304       MS_LOG(ERROR) << "Invalid begin input! begin[" << i << "]=" << begin[i];
305       return RET_ERROR;
306     }
307     if (size[i] < -1) {
308       MS_LOG(ERROR) << "Invalid size input! size[" << i << "]=" << size[i];
309       return RET_ERROR;
310     }
311   }
312   return RET_OK;
313 }
314 
315 /*
316  * Adjust slice's attr when broadcast happened in Arithmetic
317  */
SliceParamDeBroadcast(const CNodePtr & slice_cnode,const std::vector<int64_t> & ref_shape,std::vector<int64_t> * axes,std::vector<int> * begin,std::vector<int> * size)318 STATUS SlicePreposePass::SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int64_t> &ref_shape,
319                                                std::vector<int64_t> *axes, std::vector<int> *begin,
320                                                std::vector<int> *size) {
321   MS_ASSERT(slice_cnode != nullptr);
322   MS_ASSERT(new_slice_cnode != nullptr);
323   MS_ASSERT(axes != nullptr);
324   MS_ASSERT(begin != nullptr);
325   MS_ASSERT(size != nullptr);
326   auto slice = GetSlice(slice_cnode);
327   if (slice == nullptr) {
328     MS_LOG(ERROR) << "slice is nullptr";
329     return RET_ERROR;
330   }
331   auto origin_axes = slice->get_axes();
332   auto origin_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
333   auto origin_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
334   auto status = VerifySliceAttrs(slice_cnode, ref_shape.size());
335   if (status != RET_OK) {
336     return status;
337   }
338   axes->resize(ref_shape.size());
339   std::iota(axes->begin(), axes->end(), 0);
340   begin->assign(ref_shape.size(), 0);
341   size->assign(ref_shape.size(), -1);
342   bool real_slice = false;  // whether slice happened at this input
343   MS_CHECK_TRUE_MSG(origin_begin.size() >= origin_axes.size(), RET_ERROR, "origin_begin.size() is wrong");
344   MS_CHECK_TRUE_MSG(origin_size.size() >= origin_axes.size(), RET_ERROR, "origin_size.size() is wrong");
345   for (size_t i = 0; i < origin_axes.size(); ++i) {
346     int a = origin_axes[i];
347     int b = origin_begin[i];
348     int s = origin_size[i];
349     MS_CHECK_TRUE_MSG(static_cast<int>(ref_shape.size()) > a, RET_ERROR, "ref_shape.size() is wrong");
350     int ref = ref_shape[a];
351     if (ref == 1) {        // broadcast
352       continue;            // sliced size is 0(such as begin=1,size=-1) is not considered.
353     } else if (ref > 1) {  // not broadcast
354       if (b >= ref) {
355         MS_LOG(ERROR) << "slice begin[" << a << "]=" << b << ", while ref_shape[" << a << "]=" << ref << ", can't fit!";
356         return RET_ERROR;
357       } else {
358         if (b != 0 || (s != -1 && s != ref)) {
359           real_slice = true;
360         }
361         MS_CHECK_TRUE_MSG(static_cast<int>(begin->size()) > a, RET_ERROR, "begin.size() is wrong");
362         MS_CHECK_TRUE_MSG(static_cast<int>(size->size()) > a, RET_ERROR, "size.size() is wrong");
363         begin->at(a) = b;
364         size->at(a) = s;
365       }
366     } else {  // ref == 0, not need slice
367       continue;
368     }
369   }
370   if (real_slice) {
371     return lite::RET_OK;
372   } else {
373     return lite::RET_NO_CHANGE;
374   }
375 }
376 
/**
 * Builds a Reshape CNode `Reshape(preceed_cnode, shape_vector)`.
 *
 * The target shape is stored as a new int32 parameter. The node gets abstract, and the constant value cached
 * in that abstract is then cleared.
 *
 * NOTE(review): abstract is attached as-is (not cloned), so clearing its value also affects any other node
 * sharing this abstract object — confirm callers pass a clone.
 *
 * @return the new CNode, or nullptr on failure (including a null abstract).
 */
CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape_vector,
                                              const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  // ClearCNodeAbstractValue below works on this abstract, so a null one must be rejected up front.
  MS_CHECK_TRUE_MSG(abstract != nullptr, nullptr, "abstract is nullptr");
  auto new_reshape = std::make_shared<mindspore::ops::Reshape>();
  if (new_reshape == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return nullptr;
  }
  auto new_reshape_c = new_reshape->GetPrim();
  MS_CHECK_TRUE_MSG(new_reshape_c != nullptr, nullptr, "new_reshape_c is nullptr");
  ValueNodePtr value_node = NewValueNode(new_reshape_c);
  if (value_node == nullptr) {
    return nullptr;
  }
  std::vector<int> shape;
  std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                 [](int64_t val) { return static_cast<int>(val); });
  auto shape_node = BuildIntVecParameterNode(
    graph, shape, preceed_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
  node_name_index++;
  if (shape_node == nullptr) {
    MS_LOG(ERROR) << "build parameter node failed.";
    return nullptr;
  }
  auto reshape_cnode = graph->NewCNode({value_node, preceed_cnode, shape_node});
  MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "NewCNode Failed");
  reshape_cnode->set_abstract(abstract);
  reshape_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_reshape_" +
                                         std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(reshape_cnode);
  return reshape_cnode;
}
413 
SiblingsAreSameSlice(const NodeUsedListPtr & output_node_list,const std::vector<int64_t> & ref_shape)414 bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
415                                             const std::vector<int64_t> &ref_shape) {
416   MS_ASSERT(graph != nullptr);
417   MS_ASSERT(output_node_list != nullptr);
418   MS_ASSERT(output_node_list->size() >= 2);
419   std::vector<CNodePtr> slices;
420   for (auto &output_node : *(output_node_list.get())) {
421     auto cnode = output_node.first->cast<CNodePtr>();
422     MS_CHECK_TRUE_MSG(cnode != nullptr, false, "cnode is nullptr");
423     if (!CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
424       return false;
425     }
426     auto slice_node = GetSlice(cnode);
427     MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "Slice is nullptr");
428     slices.push_back(cnode);
429   }
430   MS_CHECK_TRUE_MSG(slices.size() > 0, false, "slices.size() is wrong");
431   auto first_slice_cnode = slices.front();
432   auto first_slice_node = GetSlice(first_slice_cnode);
433   MS_CHECK_TRUE_MSG(first_slice_node != nullptr, false, "GetSlice return nullptr");
434   auto first_axes = first_slice_node->get_axes();
435   auto first_begin = GetSliceBeginAndSize(first_slice_cnode, SliceBeginIndex);
436   auto first_size = GetSliceBeginAndSize(first_slice_cnode, SliceSizeIndex);
437   MS_CHECK_TRUE_MSG(first_begin.size() >= first_axes.size(), false, "first_begin.size() is wrong");
438   MS_CHECK_TRUE_MSG(first_size.size() >= first_axes.size(), false, "first_size.size() is wrong");
439   MS_CHECK_TRUE_MSG(slices.size() >= output_node_list->size(), false, "slices.size() is wrong");
440   for (size_t i = 1; i < output_node_list->size(); ++i) {
441     auto slice = GetSlice(slices[i]);
442     if (slice == nullptr) {
443       MS_LOG(WARNING) << "slice is nullptr!";
444       continue;
445     }
446     auto axes = slice->get_axes();
447     auto begin = GetSliceBeginAndSize(slices[i], SliceBeginIndex);
448     auto size = GetSliceBeginAndSize(slices[i], SliceSizeIndex);
449     MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
450     MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
451     if (axes.size() != first_axes.size()) {
452       return false;
453     }
454     for (size_t j = 0; j < axes.size(); ++j) {
455       auto axe = axes[j];
456       if (!ref_shape.empty() && axe >= static_cast<int>(ref_shape.size())) {
457         return false;
458       }
459       size_t k = 0;
460       for (; k < first_axes.size(); ++k) {  // axes may not be [0...n-1], so we use nested loop to find it
461         if (first_axes[k] == axe) {
462           break;
463         }
464       }
465       if (k == first_axes.size()) {
466         return false;
467       }
468       if (begin[j] != first_begin[k]) {
469         return false;
470       }
471       if (size[j] != first_size[k]) {
472         if (ref_shape.empty()) {
473           return false;
474         }
475         auto actual_size = size[j] > 0 ? size[j] : ref_shape[axe] - begin[j];
476         auto actual_first_size = first_size[k] > 0 ? first_size[k] : ref_shape[axe] - first_begin[k];
477         if (actual_size != actual_first_size) {
478           return false;
479         }
480       }
481     }
482   }
483   return true;
484 }
485 
GetReshapeAbnormalAxeIn(const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,std::vector<int64_t> * mapped_axe)486 int64_t SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in,
487                                                   const std::vector<int64_t> &shape_out,
488                                                   std::vector<int64_t> *mapped_axe) {
489   // find shape_out's correspond axe in shape_in
490   // when there are such as 3x1x1x4 => 3x1x4, mapped_axe[1] == 2
491   int64_t inner_size_in = 1;
492   int64_t abnormal_axe_in = -1;
493   MS_CHECK_TRUE_MSG(mapped_axe->size() >= shape_out.size(), abnormal_axe_in, "mapped_axe.size() is wrong");
494   for (size_t i = 0; i < shape_in.size(); ++i) {
495     inner_size_in *= shape_in[i];
496     int64_t inner_size_out = 1;
497     size_t j;
498     for (j = 0; j < shape_out.size(); ++j) {
499       inner_size_out *= shape_out[j];
500       if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
501         mapped_axe->at(j) = i;
502         break;
503       }
504     }
505     if (j == shape_out.size() && abnormal_axe_in == -1) {
506       abnormal_axe_in = i;
507     }
508   }
509   return abnormal_axe_in;
510 }
511 
GetReshapeAbnormalIndexOut(const CNodePtr & slice_cnode,const std::vector<int64_t> & mapped_axe,const std::vector<int64_t> & shape_out,std::vector<int64_t> * shape_out_copy,bool * is_normal_mode,bool * support_abnormal_mode)512 int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode,
513                                                      const std::vector<int64_t> &mapped_axe,
514                                                      const std::vector<int64_t> &shape_out,
515                                                      std::vector<int64_t> *shape_out_copy, bool *is_normal_mode,
516                                                      bool *support_abnormal_mode) {
517   MS_ASSERT(slice_cnode != nullptr);
518   MS_ASSERT(shape_out_copy != nullptr);
519   MS_ASSERT(is_normal_mode != nullptr);
520   MS_ASSERT(support_abnormal_mode != nullptr);
521   int64_t abnormal_index_out = -1;
522   auto slice_node = GetSlice(slice_cnode);
523   if (slice_node == nullptr) {
524     MS_LOG(ERROR) << "slice is nullptr";
525     return abnormal_index_out;
526   }
527   auto slice_axes = slice_node->get_axes();
528   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
529   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
530   for (size_t j = 0; j < shape_out.size(); ++j) {
531     int index = -1;
532     for (size_t i = 0; i < slice_axes.size(); ++i) {
533       if (slice_axes[i] == static_cast<int64_t>(j)) {
534         index = static_cast<int>(i);
535         break;
536       }
537     }
538     if (index == -1) {
539       continue;
540     }
541     MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, abnormal_index_out, "slice_begin.size() is wrong");
542     MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, abnormal_index_out, "slice_size.size() is wrong");
543     if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
544       if (mapped_axe[j] == -1) {
545         if (is_normal_mode) {
546           *is_normal_mode = false;
547           abnormal_index_out = static_cast<int64_t>(index);
548         } else {
549           *support_abnormal_mode = false;
550         }
551       } else {  // if there is matched axe sliced, not support abnormal mode
552         shape_out_copy->at(j) =
553           (slice_size[index] == -1 ? shape_out[j] - slice_begin[index] : static_cast<int64_t>(slice_size[index]));
554         *support_abnormal_mode = false;
555       }
556     }
557   }
558   return abnormal_index_out;
559 }
560 
PreposeWithNormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out_copy,const std::vector<int64_t> & mapped_axe)561 bool SlicePreposePass::PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
562                                                 const CNodePtr &reshape_cnode, const std::vector<int64_t> &shape_in,
563                                                 const std::vector<int64_t> &shape_out_copy,
564                                                 const std::vector<int64_t> &mapped_axe) {
565   MS_ASSERT(graph != nullptr);
566   MS_ASSERT(slice_cnode != nullptr);
567   MS_ASSERT(reshape_cnode != nullptr);
568   auto slice_node = GetSlice(slice_cnode);
569   if (slice_node == nullptr) {
570     MS_LOG(ERROR) << "slice is nullptr";
571     return false;
572   }
573   auto slice_axes = slice_node->get_axes();
574   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
575   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
576   std::vector<int64_t> new_axes(shape_in.size());
577   std::iota(new_axes.begin(), new_axes.end(), 0);
578   std::vector<int> new_begin(shape_in.size(), 0);
579   std::vector<int> new_size(shape_in.size(), -1);
580   MS_CHECK_TRUE_MSG(slice_begin.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
581   MS_CHECK_TRUE_MSG(slice_size.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
582   for (size_t i = 0; i < mapped_axe.size(); ++i) {
583     auto axe_in = mapped_axe[i];
584     if (axe_in == -1) {
585       continue;
586     }
587     new_begin[axe_in] = slice_begin[i];
588     new_size[axe_in] = slice_size[i];
589   }
590 
591   auto reshape_node = GetReshape(reshape_cnode);
592   if (reshape_node == nullptr) {
593     MS_LOG(ERROR) << "reshape is nullptr";
594     return false;
595   }
596   std::vector<int> new_shape_out_copy;
597   std::transform(shape_out_copy.begin(), shape_out_copy.end(), std::back_inserter(new_shape_out_copy),
598                  [](int64_t val) { return static_cast<int>(val); });
599   auto shape_node = BuildIntVecParameterNode(
600     graph, new_shape_out_copy, reshape_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
601   node_name_index++;
602   if (shape_node == nullptr) {
603     MS_LOG(ERROR) << "build parameter node failed.";
604     return false;
605   }
606   reshape_cnode->set_inputs({reshape_cnode->input(0), reshape_cnode->input(1), shape_node});
607 
608   slice_node->set_axes(new_axes);
609   auto new_begin_parameter = BuildIntVecParameterNode(
610     graph, new_begin, slice_cnode->input(SliceBeginIndex)->cast<ParameterPtr>()->fullname_with_scope());
611   auto new_size_parameter = BuildIntVecParameterNode(
612     graph, new_size, slice_cnode->input(SliceSizeIndex)->cast<ParameterPtr>()->fullname_with_scope());
613   MS_CHECK_TRUE_MSG(new_begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
614   MS_CHECK_TRUE_MSG(new_size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
615   slice_cnode->set_input(SliceBeginIndex, new_begin_parameter);
616   slice_cnode->set_input(SliceSizeIndex, new_size_parameter);
617   auto status = SwapSliceWithPreceed(graph, slice_cnode, reshape_cnode, 1);
618   if (status != RET_OK) {
619     return false;
620   }
621   reshape_cnode->set_abstract(slice_cnode->abstract()->Clone());
622   ClearCNodeAbstractValue(slice_cnode);
623   return true;
624 }
625 
/**
 * Builds the first slice used by the "abnormal" reshape prepose.
 *
 * The result is a SliceFusion over matmul_cnode covering every axis of shape_in. Only along abnormal_axe_in
 * does it cut count_sliced_axe_in elements:
 * - from the front when slice_at_front (begin = count);
 * - otherwise from the back (size = shape_in[axe] - count).
 * The node takes a clone of slice_cnode's abstract, with the cached value cleared.
 *
 * @return the new CNode, or nullptr when abnormal_axe_in is outside shape_in, count_sliced_axe_in does not
 *         fit the axis, slice_cnode has no abstract, or a node cannot be created.
 */
CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &matmul_cnode,
                                                         const std::vector<int64_t> &shape_in,
                                                         const int64_t abnormal_axe_in,
                                                         const int64_t count_sliced_axe_in, const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(matmul_cnode != nullptr);
  // abnormal_axe_in indexes shape_in and the new begin/size vectors; -1 ("none found") must not get here.
  if (abnormal_axe_in < 0 || static_cast<size_t>(abnormal_axe_in) >= shape_in.size()) {
    MS_LOG(ERROR) << "abnormal axe " << abnormal_axe_in << " is out of range of shape_in";
    return nullptr;
  }
  const auto axe = static_cast<size_t>(abnormal_axe_in);
  // Cannot cut a negative amount or more than the axis holds (this also rejects unknown, negative extents).
  if (count_sliced_axe_in < 0 || count_sliced_axe_in > shape_in[axe]) {
    MS_LOG(WARNING) << "calculation error";
    return nullptr;
  }
  auto slice_abstract = slice_cnode->abstract();
  MS_CHECK_TRUE_MSG(slice_abstract != nullptr, nullptr, "slice abstract is nullptr");
  std::vector<int64_t> new_axes1(shape_in.size());
  std::iota(new_axes1.begin(), new_axes1.end(), 0);
  std::vector<int> new_begin1(shape_in.size(), 0);
  std::vector<int> new_size1(shape_in.size(), -1);
  if (slice_at_front) {
    new_begin1[axe] = static_cast<int>(count_sliced_axe_in);
  } else {
    new_size1[axe] = static_cast<int>(shape_in[axe] - count_sliced_axe_in);
  }
  auto new_slice1 = CreateSliceValueNode(new_axes1);
  if (new_slice1 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin1, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size1, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice1_cnode = graph->NewCNode({new_slice1, matmul_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice1_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice1_cnode->set_abstract(slice_abstract->Clone());
  new_slice1_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice1_cnode);
  return new_slice1_cnode;
}
665 
/*
 * Build the second Slice of the abnormal-reshape rewrite (matmul->slice->reshape->slice*->reshape).
 * The slice reads new_reshape1_cnode, whose target shape is new_shape1, and keeps `count_sliced2` elements of axis
 * `abnormal_axe_in`:
 *  - the trailing ones when the original slice cut the front of the abnormal axis (slice_at_front),
 *  - the leading ones otherwise.
 * All axes before `abnormal_axe_in` are kept whole.
 * Returns the new Slice CNode, or nullptr on failure.
 */
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &new_reshape1_cnode,
                                                         const std::vector<int64_t> &new_shape1,
                                                         const int64_t abnormal_axe_in, const int64_t count_sliced2,
                                                         const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(new_reshape1_cnode != nullptr);
  // Guard the indexing below: new_shape1 must contain the abnormal axis.
  if (abnormal_axe_in < 0 || abnormal_axe_in >= static_cast<int64_t>(new_shape1.size())) {
    MS_LOG(ERROR) << "abnormal_axe_in is out of the range of new_shape1";
    return nullptr;
  }
  if (count_sliced2 < 0 || count_sliced2 > new_shape1[abnormal_axe_in]) {
    MS_LOG(WARNING) << "calculation error";
    return nullptr;
  }
  const size_t new_rank = static_cast<size_t>(abnormal_axe_in) + 1;
  std::vector<int64_t> new_axes2(new_rank);
  std::iota(new_axes2.begin(), new_axes2.end(), 0);
  std::vector<int> new_begin2(new_rank, 0);
  std::vector<int> new_size2(new_rank, -1);  // -1 keeps the whole axis
  if (slice_at_front) {
    // Data was dropped at the front, so keep the tail of the flattened abnormal axis.
    new_begin2[abnormal_axe_in] = static_cast<int>(new_shape1[abnormal_axe_in] - count_sliced2);
  } else {
    new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
  }
  auto new_slice2 = CreateSliceValueNode(new_axes2);
  if (new_slice2 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin2, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size2, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice2_cnode = graph->NewCNode({new_slice2, new_reshape1_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice2_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice2_cnode->set_abstract(slice_cnode->abstract()->Clone());
  new_slice2_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice2_cnode);
  return new_slice2_cnode;
}
709 
PreposeWithAbnormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,const int64_t abnormal_axe_in,const int64_t abnormal_index_out)710 bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
711                                                   const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in,
712                                                   const std::vector<int64_t> &shape_out, const int64_t abnormal_axe_in,
713                                                   const int64_t abnormal_index_out) {
714   MS_ASSERT(graph != nullptr);
715   MS_ASSERT(slice_cnode != nullptr);
716   auto manager = graph->manager();
717   MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr");
718   auto slice_node = GetSlice(slice_cnode);
719   if (slice_node == nullptr) {
720     MS_LOG(ERROR) << "slice is nullptr";
721     return false;
722   }
723   auto slice_axes = slice_node->get_axes();
724   MS_CHECK_TRUE_MSG(static_cast<int>(slice_axes.size()) > abnormal_index_out, false, "slice_axes.size() is wrong");
725   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
726   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
727   auto abnormal_axe_out = slice_axes[abnormal_index_out];
728   MS_ASSERT(abnormal_axe_out + 1 < shape_out.size());
729   int64_t inter_size_in = 1;
730   int64_t inter_size_out = 1;
731   for (auto i = 0; i < abnormal_axe_in; ++i) {
732     inter_size_in *= shape_in[i];
733   }
734   for (auto i = 0; i < abnormal_axe_out; ++i) {
735     inter_size_out *= shape_out[i];
736   }
737   if (inter_size_in != inter_size_out) {
738     MS_LOG(DEBUG) << "not support prepose now";
739     return false;
740   }
741   int64_t outer_size_in = 1;
742   int64_t outer_size_out = 1;
743   for (auto i = abnormal_axe_in + 1; i < static_cast<int>(shape_in.size()); ++i) {
744     outer_size_in *= shape_in[i];
745   }
746   for (auto i = abnormal_axe_out + 1; i < static_cast<int>(shape_out.size()); ++i) {
747     outer_size_out *= shape_out[i];
748   }
749   MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > abnormal_index_out, false, "slice_begin.size() is wrong");
750   const int64_t count_sliced_axe_front = slice_begin[abnormal_index_out];
751   MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > abnormal_index_out, false, "slice_size.size() is wrong");
752   MS_CHECK_TRUE_MSG(static_cast<int>(shape_out.size()) > abnormal_index_out, false, "shape_out.size() is wrong");
753   const int64_t count_sliced_axe_rear =
754     slice_size[abnormal_index_out] == -1 ? 0 : (shape_out[abnormal_axe_out] - slice_size[abnormal_index_out]);
755   if (count_sliced_axe_front * count_sliced_axe_rear > 0) {
756     MS_LOG(DEBUG) << "not border slice at abnormal axe, prepose with reshape failed";
757     return false;
758   }
759   bool slice_at_front = count_sliced_axe_front > 0;
760   const int64_t count_sliced_out = (count_sliced_axe_front + count_sliced_axe_rear) * outer_size_out;
761   MS_CHECK_TRUE_MSG(outer_size_in != 0, false, "div zero");
762   const int64_t count_sliced_axe_in = count_sliced_out / outer_size_in;
763   MS_CHECK_TRUE_MSG(static_cast<int>(shape_in.size()) > abnormal_axe_in, false, "shape_in.size() is wrong");
764   if (count_sliced_axe_in <= 0 || count_sliced_axe_in > shape_in[abnormal_axe_in]) {
765     MS_LOG(DEBUG) << "amount of sliced out tensor is illegal";
766     return false;
767   }
768   // new_slice1
769   auto new_slice1_cnode = CreateSlice1ForReshapePrepose(graph, slice_cnode, matmul_cnode, shape_in, abnormal_axe_in,
770                                                         count_sliced_axe_in, slice_at_front);
771   if (new_slice1_cnode == nullptr) {
772     return false;
773   }
774   // new_reshape1
775   std::vector<int64_t> new_shape1(abnormal_axe_in + 1);
776   for (int i = 0; i < abnormal_axe_in; ++i) {
777     new_shape1[i] = shape_in[i];
778   }
779   new_shape1[abnormal_axe_in] = outer_size_in * (shape_in[abnormal_axe_in] - count_sliced_axe_in);
780   auto new_reshape1_cnode = CreateReshapeCNode(graph, new_shape1, slice_cnode->abstract()->Clone(), new_slice1_cnode);
781   if (new_reshape1_cnode == nullptr) {
782     return false;
783   }
784   // new_slice2
785   const int64_t count_sliced_abnormal_axe =
786     shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
787   const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
788   auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
789                                                         abnormal_axe_in, count_sliced2, slice_at_front);
790   if (new_slice2_cnode == nullptr) {
791     return false;
792   }
793   // new_reshape2
794   std::vector<int64_t> new_shape2(shape_out.begin(), shape_out.end());
795   new_shape2[abnormal_axe_out] = count_sliced_abnormal_axe;
796   auto new_reshape2_cnode = CreateReshapeCNode(graph, new_shape2, slice_cnode->abstract()->Clone(), new_slice2_cnode);
797   if (new_reshape2_cnode == nullptr) {
798     return false;
799   }
800   new_reshape2_cnode->set_abstract(slice_cnode->abstract()->Clone());
801   auto node_users = manager->node_users()[slice_cnode];
802   for (auto &node_user : node_users) {
803     manager->SetEdge(node_user.first, node_user.second, new_reshape2_cnode);
804   }
805   return true;
806 }
807 
GetArithmeticInputInfo(const CNodePtr & arithmetic_cnode,std::vector<AnfNodePtr> * inputs,std::vector<std::vector<int64_t>> * shapes,std::vector<bool> * is_default_params)808 bool SlicePreposePass::GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
809                                               std::vector<std::vector<int64_t>> *shapes,
810                                               std::vector<bool> *is_default_params) {
811   MS_ASSERT(inputs != nullptr);
812   MS_ASSERT(shapes != nullptr);
813   MS_ASSERT(is_default_params != nullptr);
814   MS_ASSERT(arithmetic_cnode != nullptr);
815   for (size_t i = 1; i < arithmetic_cnode->size(); ++i) {
816     auto input = arithmetic_cnode->input(i);
817     MS_ASSERT(input != nullptr);
818     std::vector<int64_t> shape;
819     if (utils::isa<ParameterPtr>(input)) {
820       auto parameter = utils::cast<ParameterPtr>(input);
821       MS_ASSERT(parameter != nullptr);
822       if (!parameter->has_default()) {  // if one input is input placeholder, we can't change it
823         return false;
824       } else {
825         shape = GetDefaultParamShape(parameter);
826         is_default_params->push_back(true);
827       }
828     } else {  // input is CNode
829       if (!utils::isa<CNodePtr>(input)) {
830         MS_LOG(ERROR) << "one of Arithmetic's input is not CNode";
831         return false;
832       }
833       shape = GetCNodeInputShape(arithmetic_cnode, i);
834       is_default_params->push_back(false);
835     }
836     inputs->push_back(input);
837     shapes->push_back(shape);
838   }
839   return true;
840 }
841 
842 /*
843  * Prepose condition:
844  *  the softmax axis is not sliced
845  */
PreposeWithSoftmax(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & softmax_cnode)846 bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
847                                           const CNodePtr &softmax_cnode) {
848   MS_ASSERT(graph != nullptr);
849   MS_ASSERT(slice_cnode != nullptr);
850   MS_ASSERT(softmax_cnode != nullptr);
851   auto softmax_node = GetSoftmax(softmax_cnode);
852   if (softmax_node == nullptr) {
853     MS_LOG(ERROR) << "softmax is nullptr";
854     return false;
855   }
856   std::vector<int64_t> softmax_axis{-1};
857   if (softmax_node->GetAttr(ops::kAxis) != nullptr) {
858     softmax_axis = softmax_node->get_axis();
859   }
860   if (softmax_axis.size() != 1) {
861     MS_LOG(ERROR) << "softmax axis is not a value, which don't support.";
862     return false;
863   }
864   auto shape = GetCNodeInputShape(softmax_cnode, 1);
865   if (softmax_axis.front() == -1) {
866     // when softmax axis == -1, shape info is needed to determine whether slice can be preposed
867     if (lite::JudgeDynamicShape(shape)) {
868       return false;
869     }
870     softmax_axis[0] += static_cast<int64_t>(shape.size());
871   }
872 
873   auto slice_node = GetSlice(slice_cnode);
874   if (slice_node == nullptr) {
875     return false;
876   }
877   auto slice_axes = slice_node->get_axes();
878   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
879   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
880 
881   MS_CHECK_TRUE_MSG(static_cast<int>(softmax_axis.size()) > 0, false, "shape_in.size() is wrong");
882   MS_CHECK_TRUE_MSG(slice_size.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
883   MS_CHECK_TRUE_MSG(slice_begin.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
884   for (size_t i = 0; i < slice_axes.size(); ++i) {
885     if (slice_axes[i] == softmax_axis.front()) {
886       if (slice_begin[i] != 0) {
887         return false;
888       }
889       if (slice_size[i] != -1) {
890         if (lite::JudgeDynamicShape(shape) || slice_axes[i] >= static_cast<int>(shape.size())) {
891           return false;
892         }
893         if (slice_size[i] < shape[slice_axes[i]]) {
894           return false;
895         }
896       }
897     }
898   }
899   auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1);
900   if (status != RET_OK) {
901     return false;
902   }
903   softmax_cnode->set_abstract(slice_cnode->abstract()->Clone());
904   ClearCNodeAbstractValue(slice_cnode);
905   return true;
906 }
907 
908 /*
909  * Prepose condition:
910  *  require shape info
911  *  when reshape is normal(memory view is not changed, such as 4x5 reshaped to 4x1x5), can always prepose
912  *  when reshape is abnormal(such as 4x5 reshaped to 5x4), can prepose under some constraint
913  * For abnormal mode:
914  *  we only support border(not slice at center) slice at first mismatch axe,
915  *  and we only support matmul->reshape->slice => matmul->slice->reshape*->slice*(drop "dead" data)->reshape now,
916  *  cause the performance influence introduced by additional (reshape*->slice*) has not been fully evaluated.
917  */
PreposeWithReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode)918 bool SlicePreposePass::PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
919                                           const CNodePtr &reshape_cnode) {
920   MS_ASSERT(graph != nullptr);
921   MS_ASSERT(slice_cnode != nullptr);
922   MS_ASSERT(reshape_cnode != nullptr);
923   auto shape_in = GetCNodeInputShape(reshape_cnode, 1);
924   auto shape_out = GetCNodeInputShape(slice_cnode, 1);
925   auto shape_out_copy = shape_out;
926   if (shape_in.empty() || shape_out.empty()) {
927     MS_LOG(DEBUG) << "Reshape can't be preposed if either input or output shape is unknown";
928     return false;
929   }
930   if (reshape_cnode->size() == 3 && utils::isa<ParameterPtr>(reshape_cnode->input(2))) {
931     auto reshape_input_shape = utils::cast<ParameterPtr>(reshape_cnode->input(2));
932     MS_ASSERT(reshape_input_shape != nullptr);
933     if (!reshape_input_shape->has_default()) {
934       MS_LOG(ERROR) << "Reshape input shape is not constant";
935       return false;
936     }
937   }
938   std::vector<int64_t> mapped_axe(shape_out.size(), -1);
939   int64_t abnormal_axe_in = GetReshapeAbnormalAxeIn(shape_in, shape_out, &mapped_axe);
940   bool is_normal_mode = true;         // if all sliced axe can be found in input shape, normal
941   bool support_abnormal_mode = true;  // if first mismatch axe are sliced and no more other axes are sliced, abnormal
942   int64_t abnormal_index_out = GetReshapeAbnormalIndexOut(slice_cnode, mapped_axe, shape_out, &shape_out_copy,
943                                                           &is_normal_mode, &support_abnormal_mode);
944   if (abnormal_index_out == -1) {
945     MS_LOG(ERROR) << "GetReshapeAbnormalIndexOut failed.";
946     return false;
947   }
948   if (is_normal_mode) {
949     return PreposeWithNormalReshape(graph, slice_cnode, reshape_cnode, shape_in, shape_out_copy, mapped_axe);
950   } else if (support_abnormal_mode) {
951     auto matmul_node = reshape_cnode->input(1);
952     MS_ASSERT(matmul_node != nullptr);
953     if (IsMultiOutputTensors(graph, matmul_node) || !utils::isa<CNodePtr>(matmul_node)) {
954       MS_LOG(DEBUG) << "not matmul->reshape->slice";
955       return false;
956     }
957     auto matmul_cnode = matmul_node->cast<CNodePtr>();
958     if (matmul_cnode == nullptr) {
959       MS_LOG(ERROR) << "matmul_cnode is nullptr";
960       return false;
961     }
962     if (!CheckPrimitiveType(matmul_node, prim::kPrimFullConnection) &&
963         !CheckPrimitiveType(matmul_node, prim::kPrimMatMulFusion)) {
964       MS_LOG(DEBUG) << "not matmul->reshape->slice pattern";
965       return false;
966     }
967     return PreposeWithAbnormalReshape(graph, slice_cnode, matmul_cnode, shape_in, shape_out, abnormal_axe_in,
968                                       abnormal_index_out);
969   }
970   return false;
971 }
972 
973 /*
974  * Prepose condition:
975  *  require shape info
976  */
PreposeWithMatmul(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode)977 bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
978                                          const CNodePtr &matmul_cnode) {
979   MS_ASSERT(graph != nullptr && slice_cnode != nullptr && matmul_cnode != nullptr);
980   auto matmul_shape = GetCNodeInputShape(slice_cnode, 1);
981   int dims = static_cast<int>(matmul_shape.size());
982   if (dims == 0) {
983     // if Matmul's output shape is unknown, can't do prepose, cause we can't determine last two axes
984     return false;
985   }
986   auto slice_node = GetSlice(slice_cnode);
987   MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
988   auto axes = slice_node->get_axes();
989   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
990   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
991   // matmul not support broadcast now, it makes things simpler
992   auto manager = graph->manager();
993   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
994   MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
995   auto node_users = manager->node_users()[slice_cnode];
996   bool changed = false;
997   bool prepose_to_left = false;   // if only the last axe is sliced, not need prepose to left
998   bool prepose_to_right = false;  // if only the second last axe is sliced, not need prepose to right
999   MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1000   MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1001   for (size_t i = 0; i < axes.size(); ++i) {
1002     if (begin[i] != 0 || (size[i] != -1 && size[i] != matmul_shape[axes[i]])) {
1003       if (axes[i] != dims - 1) {
1004         prepose_to_left = true;
1005       } else if (axes[i] != dims - 2) {
1006         prepose_to_right = true;
1007       }
1008     }
1009   }
1010   if (prepose_to_left) {  //  left matrix
1011     auto left_axes = axes;
1012     auto left_begin = begin;
1013     auto left_size = size;
1014     MS_CHECK_TRUE_MSG(left_begin.size() >= left_axes.size(), false, "left_begin.size() is wrong");
1015     MS_CHECK_TRUE_MSG(left_size.size() >= left_axes.size(), false, "left_size.size() is wrong");
1016     for (size_t i = 0; i < left_axes.size(); ++i) {
1017       if (left_axes[i] == dims - 1) {
1018         left_begin[i] = 0;
1019         left_size[i] = -1;
1020       }
1021     }
1022     auto left_slice_vnode = CreateSliceValueNode(left_axes);
1023     MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1024     auto begin_parameter = BuildIntVecParameterNode(
1025       graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1026     node_name_index += 1;
1027     auto size_parameter = BuildIntVecParameterNode(
1028       graph, left_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1029     MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1030     MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1031     node_name_index += 1;
1032 
1033     const std::vector<AnfNodePtr> inputs = {left_slice_vnode, matmul_cnode->input(1), begin_parameter, size_parameter};
1034     auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 1, tr);
1035     MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1036     new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1037     ClearCNodeAbstractValue(new_slice_cnode);
1038     changed = true;
1039   }
1040   if (prepose_to_right) {  //  right matrix
1041     auto right_axes = axes;
1042     auto right_begin = begin;
1043     auto right_size = size;
1044     MS_CHECK_TRUE_MSG(right_begin.size() >= right_axes.size(), false, "right_begin.size() is wrong");
1045     MS_CHECK_TRUE_MSG(right_size.size() >= right_axes.size(), false, "right_size.size() is wrong");
1046     for (size_t i = 0; i < right_axes.size(); ++i) {
1047       if (right_axes[i] == dims - 2) {
1048         right_begin[i] = 0;
1049         right_size[i] = -1;
1050       }
1051     }
1052     auto begin_parameter = BuildIntVecParameterNode(
1053       graph, right_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1054     node_name_index += 1;
1055     auto size_parameter = BuildIntVecParameterNode(
1056       graph, right_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1057     MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1058     MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1059     node_name_index += 1;
1060     auto right_slice_vnode = CreateSliceValueNode(right_axes);
1061     MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1062     const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
1063     auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);
1064     MS_ASSERT(new_slice_cnode != nullptr);
1065     new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1066     ClearCNodeAbstractValue(new_slice_cnode);
1067     changed = true;
1068   }
1069   if (changed) {
1070     matmul_cnode->set_abstract(slice_cnode->abstract()->Clone());
1071     for (auto &node_user : node_users) {
1072       tr->SetEdge(node_user.first, node_user.second, matmul_cnode);
1073     }
1074     tr->Commit();
1075     // we don't need graph->DropNode(slice_cnode);
1076   }
1077   return changed;
1078 }
1079 
1080 /*
1081  * Prepose condition:
1082  *  require shape info
1083  *  only support slice at first output axe now, and useAxis must be false
1084  */
PreposeWithFullConnection(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & fc_cnode)1085 bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1086                                                  const CNodePtr &fc_cnode) {
1087   MS_ASSERT(graph != nullptr);
1088   MS_ASSERT(slice_cnode != nullptr);
1089   MS_ASSERT(fc_cnode != nullptr);
1090   auto shape_in = GetCNodeInputShape(fc_cnode, 1);
1091   auto shape_out = GetCNodeInputShape(slice_cnode, 1);
1092   if (shape_in.empty() || shape_out.size() != 2) {
1093     MS_LOG(DEBUG) << "FullConnection can't be preposed if input shape is unknown or output shape is illegal";
1094     return false;
1095   }
1096   auto fc_node = GetFc(fc_cnode);
1097   if (fc_node == nullptr || (fc_node->GetAttr(ops::kUseAxis) != nullptr && fc_node->get_use_axis())) {
1098     MS_LOG(DEBUG) << "prepose with fc only support useAxis == false currently";
1099     return false;
1100   }
1101   auto slice_node = GetSlice(slice_cnode);
1102   MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
1103   auto axes = slice_node->get_axes();
1104   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1105   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1106   MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1107   MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1108   for (size_t i = 0; i < axes.size(); ++i) {
1109     if (axes[i] == 1) {
1110       if (begin[i] != 0 || (size[i] != -1 && size[i] != shape_out[1])) {
1111         MS_LOG(DEBUG) << "prepose with fc only support first output axe is sliced currently";
1112         return false;
1113       }
1114     }
1115   }
1116 
1117   std::vector<int64_t> mapped_axe(shape_out.size(), -1);
1118   int64_t inner_size_in = 1;
1119   for (size_t i = 0; i < shape_in.size(); ++i) {
1120     inner_size_in *= shape_in[i];
1121     int64_t inner_size_out = 1;
1122     for (size_t j = 0; j < shape_out.size(); ++j) {
1123       inner_size_out *= shape_out[j];
1124       if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
1125         mapped_axe[j] = static_cast<int64_t>(i);
1126         break;
1127       }
1128     }
1129   }
1130   if (mapped_axe[0] == -1) {
1131     MS_LOG(DEBUG) << "first axe in output can't find correspond input axe, can't do prepose";
1132     return false;
1133   }
1134 
1135   std::vector<int64_t> new_axes(shape_in.size());
1136   std::iota(new_axes.begin(), new_axes.end(), 0);
1137   std::vector<int> new_begin(shape_in.size(), 0);
1138   std::vector<int> new_size(shape_in.size(), -1);
1139   new_begin[mapped_axe[0]] = begin[0];
1140   new_size[mapped_axe[0]] = size[0];
1141   auto new_slice_vnode = CreateSliceValueNode(new_axes);
1142   MS_CHECK_TRUE_MSG(new_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1143 
1144   auto manager = graph->manager();
1145   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1146   MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
1147   auto begin_parameter = BuildIntVecParameterNode(
1148     graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1149   node_name_index += 1;
1150   auto size_parameter = BuildIntVecParameterNode(
1151     graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1152   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1153   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1154   node_name_index += 1;
1155   const std::vector<AnfNodePtr> inputs = {new_slice_vnode, fc_cnode->input(1), begin_parameter, size_parameter};
1156   auto new_slice_cnode = InsertSlice(graph, inputs, fc_cnode, 1, tr);
1157   MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1158 
1159   fc_cnode->set_abstract(slice_cnode->abstract()->Clone());
1160   new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1161   ClearCNodeAbstractValue(new_slice_cnode);
1162 
1163   auto node_users = manager->node_users()[slice_cnode];
1164   for (auto &node_user : node_users) {
1165     tr->SetEdge(node_user.first, node_user.second, fc_cnode);
1166   }
1167   tr->Commit();
1168   return true;
1169 }
1170 
1171 /*
1172  * Prepose condition:
1173  *  not require shape info, can always prepose
1174  */
PreposeWithTranspose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & transpose_cnode)1175 bool SlicePreposePass::PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1176                                             const CNodePtr &transpose_cnode) {
1177   MS_ASSERT(graph != nullptr);
1178   MS_ASSERT(slice_cnode != nullptr);
1179   MS_ASSERT(transpose_cnode != nullptr);
1180   if (transpose_cnode->size() != 3) {
1181     MS_LOG(ERROR) << "transpose inputs size should be 3.";
1182     return false;
1183   }
1184   auto perm = GetTransposePerm(transpose_cnode);
1185   if (perm.empty()) {
1186     return false;
1187   }
1188   auto slice_node = GetSlice(slice_cnode);
1189   if (slice_node == nullptr) {
1190     MS_LOG(ERROR) << "GetSlicT failed";
1191     return false;
1192   }
1193   auto old_axes = slice_node->get_axes();
1194   auto old_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1195   auto old_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1196   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1197   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1198   // perm is random shuffle of [0...n-1] according to ops/transpose.cc
1199   for (size_t i = 0; i < perm.size(); ++i) {
1200     if (perm[i] != static_cast<int>(i)) {
1201       for (size_t j = 0; j < old_axes.size(); ++j) {
1202         if (old_axes[j] == static_cast<int>(i)) {
1203           MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > perm[i], false, "slice_begin.size() is wrong");
1204           MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > perm[i], false, "slice_size.size() is wrong");
1205           slice_begin[perm[i]] = old_begin[j];
1206           slice_size[perm[i]] = old_size[j];
1207           break;
1208         }
1209       }
1210     }
1211   }
1212   auto begin_parameter = BuildIntVecParameterNode(
1213     graph, slice_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1214   node_name_index += 1;
1215   auto size_parameter = BuildIntVecParameterNode(
1216     graph, slice_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1217   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1218   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1219   node_name_index += 1;
1220   slice_cnode->set_input(SliceBeginIndex, begin_parameter);
1221   slice_cnode->set_input(SliceSizeIndex, size_parameter);
1222   auto status = SwapSliceWithPreceed(graph, slice_cnode, transpose_cnode, 1);
1223   if (status != RET_OK) {
1224     return false;
1225   }
1226   transpose_cnode->set_abstract(slice_cnode->abstract()->Clone());
1227   ClearCNodeAbstractValue(slice_cnode);
1228   return true;
1229 }
1230 /*
1231  * Prepose condition:
1232  *  may or may not require shape info
1233  */
PreposeWithArithmetic(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & arithmetic_cnode)1234 bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1235                                              const CNodePtr &arithmetic_cnode) {
1236   MS_ASSERT(graph != nullptr);
1237   MS_ASSERT(slice_cnode != nullptr);
1238   MS_ASSERT(arithmetic_cnode != nullptr);
1239   auto manager = graph->manager();
1240   MS_ASSERT(manager != nullptr);
1241   auto node_users = manager->node_users()[slice_cnode];
1242   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1243   if (tr == nullptr) {
1244     MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1245     return false;
1246   }
1247   bool changed = false;
1248   std::vector<AnfNodePtr> inputs;
1249   std::vector<std::vector<int64_t>> shapes;
1250   std::vector<bool> is_default_params;
1251   if (!GetArithmeticInputInfo(arithmetic_cnode, &inputs, &shapes, &is_default_params)) {
1252     return false;
1253   }
1254 
1255   for (size_t i = 1; i < arithmetic_cnode->size(); ++i) {
1256     auto &input = inputs[i - 1];
1257     if (IsScalarNode(input)) {  // scalar not need prepose
1258       continue;
1259     }
1260     auto &shape = shapes[i - 1];
1261     const size_t another_index = kArithmeticInputNum - i;
1262     auto &another_input = inputs[another_index];
1263     auto &another_shape = shapes[another_index];
1264     if (IsScalarNode(input)) {
1265       continue;
1266     } else if (lite::JudgeDynamicShape(shape)) {  // infershape failed at this input
1267       if (IsScalarNode(another_input)) {          // if another input is scalar, we can process this one
1268         auto new_slice_vnode = CopySliceValueNode(slice_cnode);
1269         if (new_slice_vnode == nullptr) {
1270           changed = false;
1271           break;
1272         }
1273         std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i),
1274                                                 slice_cnode->input(SliceBeginIndex),
1275                                                 slice_cnode->input(SliceSizeIndex)};
1276         auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1277         MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1278 
1279         new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1280         ClearCNodeAbstractValue(new_slice_cnode);
1281         changed = true;
1282         break;
1283       } else {  // if another input's shape is not scalar, can't be processed
1284         changed = false;
1285         break;
1286       }
1287     } else {  // shape not empty
1288       if (!another_shape.empty() || IsScalarNode(another_input)) {
1289         std::vector<int64_t> new_axes;
1290         std::vector<int> new_begin;
1291         std::vector<int> new_size;
1292         auto status = SliceParamDeBroadcast(slice_cnode, shape, &new_axes, &new_begin, &new_size);
1293         if (status == lite::RET_NO_CHANGE) {
1294           continue;
1295         }
1296         if (status != lite::RET_OK) {
1297           changed = false;
1298           break;
1299         }
1300         auto new_slice_vnode = CreateSliceValueNode(new_axes);
1301         if (new_slice_vnode == nullptr) {
1302           changed = false;
1303           break;
1304         }
1305         auto begin_parameter = BuildIntVecParameterNode(
1306           graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1307         node_name_index += 1;
1308         auto size_parameter = BuildIntVecParameterNode(
1309           graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1310         MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1311         MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1312         node_name_index += 1;
1313         std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i), begin_parameter,
1314                                                 size_parameter};
1315         auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1316         MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1317         new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1318         ClearCNodeAbstractValue(new_slice_cnode);
1319         changed = true;
1320       } else {
1321         changed = false;
1322         break;
1323       }
1324     }
1325   }
1326   if (changed) {
1327     arithmetic_cnode->set_abstract(slice_cnode->abstract()->Clone());
1328     for (auto &node_user : node_users) {
1329       tr->SetEdge(node_user.first, node_user.second, arithmetic_cnode);
1330     }
1331     tr->Commit();
1332     // we don't need graph->DropNode(slice_cnode);
1333   }
1334   return changed;
1335 }  // namespace mindspore::opt
1336 /*
1337  * Prepose condition:
1338  *  not require shape info
1339  */
MergeSequentialSlice(const FuncGraphPtr & graph,const CNodePtr & slice1_cnode,const CNodePtr & slice2_cnode)1340 bool SlicePreposePass::MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode,
1341                                             const CNodePtr &slice2_cnode) {
1342   if (slice2_cnode->size() != kArithmeticInputNum) {
1343     MS_LOG(INFO) << "Slice read attrs from input is not supported now";
1344     return false;
1345   }
1346   auto slice1_node = GetSlice(slice1_cnode);  // bottom node
1347   auto slice2_node = GetSlice(slice2_cnode);  // top node
1348   if (slice1_node == nullptr || slice2_node == nullptr) {
1349     MS_LOG(ERROR) << "slice is null";
1350     return false;
1351   }
1352   auto begin_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceBeginIndex);
1353   auto size_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceSizeIndex);
1354   auto axes_slice1 = slice1_node->get_axes();
1355   auto begin_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceBeginIndex);
1356   auto size_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceSizeIndex);
1357   auto axes_slice2 = slice2_node->get_axes();
1358   auto status1 = VerifySliceAttrs(slice1_cnode);
1359   auto status2 = VerifySliceAttrs(slice2_cnode);
1360   if (status1 != RET_OK || status2 != RET_OK) {
1361     return false;
1362   }
1363 
1364   auto manager = graph->manager();
1365   MS_ASSERT(manager != nullptr);
1366   auto node_users = manager->node_users()[slice1_cnode];
1367   int64_t axe_max1 = *std::max_element(axes_slice1.begin(), axes_slice1.end());
1368   int64_t axe_max2 = *std::max_element(axes_slice2.begin(), axes_slice2.end());
1369   int64_t axe_max = std::max(axe_max1, axe_max2);
1370   auto begin_new = begin_slice2;
1371   auto size_new = size_slice2;
1372   auto axes_new = slice2_node->get_axes();
1373   axes_new.resize(axe_max + 1);
1374   std::iota(axes_new.begin(), axes_new.end(), 0);
1375   begin_new.assign(axe_max + 1, 0);
1376   size_new.assign(axe_max + 1, -1);
1377   MS_CHECK_TRUE_MSG(begin_slice2.size() >= axes_slice2.size(), false, "begin_slice2.size() is wrong");
1378   MS_CHECK_TRUE_MSG(size_slice2.size() >= axes_slice2.size(), false, "size_slice2.size() is wrong");
1379   MS_CHECK_TRUE_MSG(size_slice1.size() >= axes_slice1.size(), false, "size_slice1.size() is wrong");
1380   MS_CHECK_TRUE_MSG(begin_slice1.size() >= axes_slice1.size(), false, "begin_slice1.size() is wrong");
1381   for (int i = 0; i <= axe_max; ++i) {
1382     for (size_t j = 0; j < axes_slice2.size(); ++j) {
1383       if (axes_slice2[j] == i) {
1384         begin_new[i] = begin_slice2[j];
1385         size_new[i] = size_slice2[j];
1386         break;
1387       }
1388     }
1389     for (size_t j = 0; j < axes_slice1.size(); ++j) {
1390       if (axes_slice1[j] == i) {
1391         begin_new[i] = begin_new[i] + begin_slice1[j];
1392         if (size_new[i] == -1) {
1393           size_new[i] = size_slice1[j];
1394         } else {
1395           if (size_slice1[j] == -1) {
1396             size_new[i] = std::max(size_new[i] - begin_slice1[i], 0);  // clip with zero to avoid invalid negative value
1397           } else {
1398             size_new[i] = std::max(std::min(size_new[i] - begin_slice1[j], size_slice1[j]), 0);
1399           }
1400         }
1401         break;
1402       }
1403     }
1404   }
1405   slice2_node->set_axes(axes_new);
1406   auto begin_parameter = BuildIntVecParameterNode(
1407     graph, begin_new, slice2_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1408   node_name_index += 1;
1409   auto size_parameter = BuildIntVecParameterNode(
1410     graph, size_new, slice2_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1411   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1412   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1413   node_name_index += 1;
1414   slice2_cnode->set_input(SliceBeginIndex, begin_parameter);
1415   slice2_cnode->set_input(SliceSizeIndex, size_parameter);
1416   slice2_cnode->set_abstract(slice1_cnode->abstract()->Clone());
1417   for (auto &node_user : node_users) {
1418     manager->SetEdge(node_user.first, node_user.second, slice2_cnode);
1419   }
1420   return true;
1421 }
1422 
1423 /*
1424  * Prepose condition:
1425  *  when all sibling slices do same work
1426  *  can be optimize to not require all siblings are slice
1427  */
MergeParallelSlice(const FuncGraphPtr & graph,const NodeUsedListPtr & slices)1428 bool SlicePreposePass::MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices) {
1429   MS_ASSERT(graph != nullptr);
1430   MS_ASSERT(slices->size() >= 2);
1431   auto manager = graph->manager();
1432   MS_ASSERT(manager != nullptr);
1433   auto first_slice = utils::cast<CNodePtr>(slices->at(0).first);
1434   MS_ASSERT(first_slice != nullptr);
1435   if (!CheckPrimitiveType(first_slice, prim::kPrimSliceFusion)) {
1436     MS_LOG(ERROR) << "first node is not Slice";
1437     return false;
1438   }
1439   auto first_parent = first_slice->input(1);
1440   if (first_parent == nullptr) {
1441     MS_LOG(ERROR) << "first slice node's parent is nullptr";
1442     return false;
1443   }
1444   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1445   if (tr == nullptr) {
1446     MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1447     return false;
1448   }
1449   for (size_t i = 1; i < slices->size(); ++i) {
1450     auto slice = utils::cast<CNodePtr>(slices->at(i).first);
1451     MS_ASSERT(slice != nullptr);
1452     if (!CheckPrimitiveType(slice, prim::kPrimSliceFusion)) {
1453       MS_LOG(ERROR) << "current node is not Slice";
1454       return false;
1455     }
1456     auto parent = slice->input(1);
1457     if (parent == nullptr || parent != first_parent) {
1458       MS_LOG(ERROR) << "not all slices have same parent node";
1459       return false;
1460     }
1461     auto node_users = manager->node_users()[slices->at(i).first];
1462     for (auto &node_user : node_users) {
1463       tr->SetEdge(node_user.first, node_user.second, slices->at(0).first);
1464     }
1465   }
1466   tr->Commit();
1467   return true;
1468 }
1469 
DoPrepose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode)1470 bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1471                                  const CNodePtr &preceed_cnode) {
1472   MS_ASSERT(graph != nullptr);
1473   MS_ASSERT(slice_cnode != nullptr);
1474   MS_ASSERT(preceed_cnode != nullptr);
1475   if (CheckPrimitiveType(preceed_cnode, prim::kPrimSoftmax)) {
1476     return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode);
1477   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimReshape)) {
1478     return PreposeWithReshape(graph, slice_cnode, preceed_cnode);
1479   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimMatMulFusion)) {
1480     return PreposeWithMatmul(graph, slice_cnode, preceed_cnode);
1481   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimFullConnection)) {
1482     return PreposeWithFullConnection(graph, slice_cnode, preceed_cnode);
1483   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimTranspose)) {
1484     return PreposeWithTranspose(graph, slice_cnode, preceed_cnode);
1485   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSubFusion) ||
1486              CheckPrimitiveType(preceed_cnode, prim::kPrimMulFusion) ||
1487              CheckPrimitiveType(preceed_cnode, prim::kPrimAddFusion)) {
1488     return PreposeWithArithmetic(graph, slice_cnode, preceed_cnode);
1489   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSliceFusion)) {
1490     return MergeSequentialSlice(graph, slice_cnode, preceed_cnode);
1491   }
1492   return false;
1493 }
1494 
Run(const FuncGraphPtr & graph)1495 bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
1496   if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
1497     MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
1498     return false;
1499   }
1500   MS_ASSERT(graph != nullptr);
1501   bool changed = false;
1502   while (true) {
1503     bool this_time_changed = false;
1504     auto node_list = TopoSort(graph->get_return());
1505     for (auto &node : node_list) {
1506       if (node->func_graph() != graph || !utils::isa<CNodePtr>(node) ||
1507           !CheckPrimitiveType(node, prim::kPrimSliceFusion)) {
1508         continue;
1509       }
1510       auto slice_cnode = node->cast<CNodePtr>();
1511       MS_ASSERT(slice_cnode != nullptr);
1512       // only support begin and size is const tensor.
1513       if (!CheckIsAllInputsParam(slice_cnode) || GetSlice(slice_cnode)) {
1514         continue;
1515       }
1516       auto preceed_node = slice_cnode->input(1);
1517       if (preceed_node == nullptr) {
1518         MS_LOG(ERROR) << "proceed node is nullptr";
1519         continue;
1520       }
1521       auto output_tensor_num = GetOutputTensorNum(preceed_node);
1522       if (output_tensor_num > 1) {
1523         continue;
1524       }
1525       auto output_node_list = Helper::GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
1526       if (output_node_list->size() > 1) {  // referenced by multi nodes
1527         if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
1528           this_time_changed = true;
1529           break;
1530         }
1531         continue;
1532       } else {
1533         if (utils::isa<ParameterPtr>(preceed_node)) {
1534           /*
1535            * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose
1536            * if preceed_node is parameter with default param, constant_folding will process it
1537            */
1538           continue;
1539         }
1540         auto preceed_cnode = preceed_node->cast<CNodePtr>();
1541         if (preceed_cnode == nullptr) {
1542           MS_LOG(ERROR) << "preceed_cnode is nullptr";
1543           continue;
1544         }
1545         if (DoPrepose(graph, slice_cnode, preceed_cnode)) {
1546           this_time_changed = true;
1547           break;
1548         }
1549       }
1550     }
1551     if (this_time_changed) {
1552       changed = true;
1553     } else {
1554       break;
1555     }
1556   }
1557   return changed;
1558 }
1559 }  // namespace mindspore::opt
1560