• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "tools/optimizer/graph/slice_prepose_pass.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <numeric>
#include <set>
#include <vector>
#include "ops/fusion/full_connection.h"
#include "ops/reshape.h"
#include "ops/fusion/slice_fusion.h"
#include "ops/softmax.h"
#include "ops/op_utils.h"
#include "include/errorcode.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "backend/optimizer/common/helper.h"
#include "src/common/log_adapter.h"
#include "nnacl/op_base.h"
31 
32 namespace mindspore::opt {
33 namespace {
// Number of inputs of a binary arithmetic node.
constexpr int kArithmeticInputNum = 2;
// Input positions of the constant `begin` and `size` operands of a SliceFusion cnode
// (input 0 is the primitive, input 1 the sliced data).
constexpr int SliceBeginIndex = 2;
constexpr int SliceSizeIndex = 3;
// Counter appended to the names of nodes/parameters created by this pass to keep them distinct.
int node_name_index = 0;
GetSliceBeginAndSize(const CNodePtr & cnode,const int index)38 std::vector<int> GetSliceBeginAndSize(const CNodePtr &cnode, const int index) {
39   MS_ASSERT(cnode != nullptr);
40   std::vector<int> content;
41   if (index != SliceBeginIndex && index != SliceSizeIndex && cnode->size() != 4) {
42     return content;
43   }
44   auto node = cnode->input(index);
45   if (node == nullptr) {
46     return content;
47   }
48   auto param_node = node->cast<ParameterPtr>();
49   if (param_node == nullptr || !param_node->has_default() || param_node->default_param() == nullptr) {
50     return content;
51   }
52   auto tensor_info = param_node->default_param()->cast<tensor::TensorPtr>();
53   if (tensor_info == nullptr) {
54     return content;
55   }
56   content.resize(tensor_info->DataSize());
57   if (memcpy_s(content.data(), tensor_info->Size(), tensor_info->data_c(), tensor_info->Size()) != EOK) {
58     MS_LOG(ERROR) << "memcpy data failed.";
59     return {};
60   }
61   return content;
62 }
63 
GetCNodeInputShape(const CNodePtr & cnode,size_t index=1)64 std::vector<int64_t> GetCNodeInputShape(const CNodePtr &cnode, size_t index = 1) {
65   MS_ASSERT(cnode != nullptr);
66   std::vector<int64_t> empty_shape;
67   if (index < 1 || cnode->inputs().size() <= index) {
68     MS_LOG(ERROR) << "out of index";
69     return empty_shape;
70   }
71   auto abstract = GetCNodeInputAbstract(cnode, index);
72   if (abstract == nullptr) {
73     MS_LOG(ERROR) << "Abstract of CNode is nullptr";
74     return empty_shape;
75   }
76   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
77     MS_LOG(DEBUG) << "abstract is not AbstractTensor";
78     return empty_shape;
79   }
80   auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract);
81   MS_ASSERT(abstract_tensor != nullptr && abstract_tensor->shape() != nullptr);
82   return abstract_tensor->shape()->shape();
83 }
84 
GetDefaultParamShape(const ParameterPtr & param)85 std::vector<int64_t> GetDefaultParamShape(const ParameterPtr &param) {
86   MS_ASSERT(param != nullptr);
87   MS_ASSERT(param->has_default());
88   std::vector<int64_t> shape_vector;
89   auto default_param = param->default_param();
90   if (default_param == nullptr) {
91     MS_LOG(ERROR) << "default_param is nullptr";
92     return shape_vector;
93   }
94   if (!utils::isa<tensor::TensorPtr>(default_param)) {
95     MS_LOG(ERROR) << "default_param is not tensor::Tensor";
96     return shape_vector;
97   }
98   auto param_value_lite = utils::cast<tensor::TensorPtr>(default_param);
99   MS_ASSERT(param_value != nullptr);
100   auto shape = param_value_lite->shape();
101   std::transform(shape.begin(), shape.end(), std::back_inserter(shape_vector),
102                  [](const int val) { return static_cast<int64_t>(val); });
103   return shape_vector;
104 }
105 
IsScalarNode(const AnfNodePtr & nodePtr)106 bool IsScalarNode(const AnfNodePtr &nodePtr) {
107   MS_ASSERT(nodePtr != nullptr);
108   if (utils::isa<ParameterPtr>(nodePtr) && nodePtr->cast<ParameterPtr>()->has_default()) {
109     auto tensor = utils::cast<tensor::TensorPtr>(utils::cast<ParameterPtr>(nodePtr)->default_param());
110     MS_ASSERT(tensor != nullptr);
111     auto shape = tensor->shape();
112     if (shape.empty() || (shape.size() == 1 && shape[0] == 1)) {
113       return true;
114     }
115   }
116   return false;
117 }
118 
GetSlice(const CNodePtr & cnode)119 std::shared_ptr<mindspore::ops::SliceFusion> GetSlice(const CNodePtr &cnode) {
120   if (cnode == nullptr) {
121     return nullptr;
122   }
123   return GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(cnode->input(0));
124 }
125 
GetSoftmax(const CNodePtr & cnode)126 std::shared_ptr<mindspore::ops::Softmax> GetSoftmax(const CNodePtr &cnode) {
127   if (cnode == nullptr) {
128     return nullptr;
129   }
130   return GetValueNode<std::shared_ptr<mindspore::ops::Softmax>>(cnode->input(0));
131 }
132 
GetReshape(const CNodePtr & cnode)133 std::shared_ptr<mindspore::ops::Reshape> GetReshape(const CNodePtr &cnode) {
134   if (cnode == nullptr) {
135     return nullptr;
136   }
137   return GetValueNode<std::shared_ptr<mindspore::ops::Reshape>>(cnode->input(0));
138 }
139 
GetFc(const CNodePtr & cnode)140 std::shared_ptr<mindspore::ops::FullConnection> GetFc(const CNodePtr &cnode) {
141   if (cnode == nullptr) {
142     return nullptr;
143   }
144   return GetValueNode<std::shared_ptr<mindspore::ops::FullConnection>>(cnode->input(0));
145 }
146 
GetTransposePerm(const CNodePtr & node)147 std::vector<int> GetTransposePerm(const CNodePtr &node) {
148   MS_ASSERT(node != nullptr);
149   std::vector<int> perm;
150   if (!CheckPrimitiveType(node, prim::kPrimTranspose)) {
151     return perm;
152   }
153   if (node->inputs().size() != 3) {
154     return perm;
155   }
156   auto perm_node = node->input(2);
157   if (!utils::isa<ParameterPtr>(perm_node)) {
158     return perm;
159   }
160   auto perm_param = perm_node->cast<ParameterPtr>();
161   MS_ASSERT(perm_param != nullptr);
162   if (!perm_param->has_default() || perm_param->default_param() == nullptr) {
163     return perm;
164   }
165   auto perm_value = perm_param->default_param()->cast<tensor::TensorPtr>();
166   if (perm_value == nullptr) {
167     return perm;
168   }
169   MS_CHECK_TRUE_MSG(perm_value->shape().size() != 0, {}, "shape is empty");
170   perm.resize(perm_value->shape()[0]);
171   if (memcpy_s(perm.data(), perm_value->Size(), perm_value->data_c(), perm_value->Size()) != EOK) {
172     MS_LOG(ERROR) << "memcpy failed.";
173     return {};
174   }
175   return perm;
176 }
177 }  // namespace
178 
ClearCNodeAbstractValue(const CNodePtr & cnode)179 void SlicePreposePass::ClearCNodeAbstractValue(const CNodePtr &cnode) {
180   MS_ASSERT(cnode != nullptr);
181   auto abstract = cnode->abstract();
182   MS_ASSERT(abstract != nullptr);
183   if (!utils::isa<abstract::AbstractTensorPtr>(abstract)) {
184     MS_LOG(DEBUG) << "Abstract of cnode is not abstract tensor, " << cnode->fullname_with_scope();
185   }
186   abstract->set_value(std::make_shared<AnyValue>());
187 }
188 
SwapSliceWithPreceed(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode,const int index,const TransactionPtr & tr)189 STATUS SlicePreposePass::SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
190                                               const CNodePtr &preceed_cnode, const int index,
191                                               const TransactionPtr &tr) {
192   MS_ASSERT(graph != nullptr);
193   MS_ASSERT(slice_cnode != nullptr);
194   MS_ASSERT(preceed_cnode != nullptr);
195   if (slice_cnode->input(1) != preceed_cnode) {
196     MS_LOG(ERROR) << "proceed node must be slice node's direct parent";
197     return RET_ERROR;
198   }
199   if (IsMultiOutputTensors(graph, preceed_cnode)) {
200     MS_LOG(ERROR) << "proceed node referenced by multi nodes not support swap";
201     return RET_ERROR;
202   }
203   auto manager = graph->manager();
204   if (manager == nullptr) {
205     MS_LOG(ERROR) << "manager is nullptr";
206     return RET_ERROR;
207   }
208   auto node_users = manager->node_users()[slice_cnode];
209   if (tr != nullptr) {  // do swap with transaction
210     for (auto &node_user : node_users) {
211       tr->SetEdge(node_user.first, node_user.second, preceed_cnode);
212     }
213     tr->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
214     tr->SetEdge(preceed_cnode, index, slice_cnode);
215   } else {
216     for (auto &node_user : node_users) {
217       manager->SetEdge(node_user.first, node_user.second, preceed_cnode);
218     }
219     manager->SetEdge(slice_cnode, 1, preceed_cnode->input(index));
220     manager->SetEdge(preceed_cnode, index, slice_cnode);
221   }
222   return RET_OK;
223 }
224 
CreateSliceValueNode(const std::vector<int64_t> & axes)225 ValueNodePtr SlicePreposePass::CreateSliceValueNode(const std::vector<int64_t> &axes) {
226   MS_ASSERT(graph != nullptr);
227   MS_ASSERT(slice_cnode != nullptr);
228   auto new_slice = std::make_shared<mindspore::ops::SliceFusion>();
229   MS_CHECK_TRUE_MSG(new_slice != nullptr, nullptr, "new_slice is nullptr");
230   new_slice->set_axes(axes);
231   ValueNodePtr value_node = NewValueNode(new_slice);
232   MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
233   return value_node;
234 }
235 
CopySliceValueNode(const CNodePtr & slice_cnode)236 ValueNodePtr SlicePreposePass::CopySliceValueNode(const CNodePtr &slice_cnode) {
237   MS_ASSERT(graph != nullptr);
238   MS_ASSERT(slice_cnode != nullptr);
239   auto slice_c = GetValueNode<std::shared_ptr<mindspore::ops::SliceFusion>>(slice_cnode->input(0));
240   if (slice_c == nullptr) {
241     MS_LOG(ERROR) << "slice node is nullptr";
242     return nullptr;
243   }
244   auto new_slice_c = std::make_shared<mindspore::ops::SliceFusion>();
245   MS_CHECK_TRUE_MSG(new_slice_c != nullptr, nullptr, "new_slice_c is nullptr");
246   new_slice_c->set_axes(slice_c->get_axes());
247   ValueNodePtr value_node = NewValueNode(new_slice_c);
248   MS_CHECK_TRUE_MSG(value_node != nullptr, nullptr, "NewValueNode Failed");
249   return value_node;
250 }
251 
/**
 * Creates a new cnode from `inputs` and wires it into `preceed_cnode` at input position `index`.
 *
 * The edge is set through `tr` when it is non-null, otherwise directly through the graph's manager
 * (mirroring SwapSliceWithPreceed).
 *
 * @return the new cnode, or nullptr if the cnode cannot be created or no edge setter is available.
 */
CNodePtr SlicePreposePass::InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs,
                                       const CNodePtr &preceed_cnode, const int index, const TransactionPtr &tr) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  // Resolve the edge setter before creating any node, so a failure leaves the graph untouched.
  auto manager = (tr == nullptr) ? graph->manager() : nullptr;
  if (tr == nullptr && manager == nullptr) {
    MS_LOG(ERROR) << "manager is nullptr";
    return nullptr;
  }
  auto slice_cnode = graph->NewCNode(inputs);
  MS_CHECK_TRUE_MSG(slice_cnode != nullptr, nullptr, "NewNode Failed");
  slice_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_slice_" +
                                       std::to_string(node_name_index));
  node_name_index += 1;
  if (tr != nullptr) {
    tr->SetEdge(preceed_cnode, index, slice_cnode);
  } else {
    manager->SetEdge(preceed_cnode, index, slice_cnode);
  }
  return slice_cnode;
}
265 
VerifySliceAttrs(const CNodePtr & slice_cnode,const int dim)266 STATUS SlicePreposePass::VerifySliceAttrs(const CNodePtr &slice_cnode, const int dim) {
267   // according to ops/slice.cc, axes >= 0, begin >= 0, size >= -1
268   auto slice = GetSlice(slice_cnode);
269   if (slice == nullptr) {
270     MS_LOG(ERROR) << "Slice is nullptr";
271     return RET_ERROR;
272   }
273   auto axes = slice->get_axes();
274   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
275   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
276 
277   std::set<int64_t> unique_axes(axes.begin(), axes.end());
278   if (axes.empty() || unique_axes.size() != axes.size()) {
279     MS_LOG(DEBUG) << "Invalid slice axe attribute";
280     return RET_ERROR;
281   }
282   MS_CHECK_TRUE_MSG(begin.size() <= axes.size(), RET_ERROR, "begin size is wrong");
283   MS_CHECK_TRUE_MSG(size.size() <= axes.size(), RET_ERROR, "size.size() is wrong");
284   for (size_t i = 0; i < axes.size(); ++i) {
285     auto axe = axes[i];
286     if (dim > -1 && axe >= dim) {
287       MS_LOG(ERROR) << "Invalid slice axe attribute";
288       return RET_ERROR;
289     }
290     if (axe < 0) {
291       MS_LOG(ERROR) << "Invalid slice axe attribute";
292       return RET_ERROR;
293     }
294     if (begin[i] < 0) {  //  we not require begin[i] < ref_shape[axe], cause there may be broadcast
295       MS_LOG(ERROR) << "Invalid begin input! begin[" << i << "]=" << begin[i];
296       return RET_ERROR;
297     }
298     if (size[i] < -1) {
299       MS_LOG(ERROR) << "Invalid size input! size[" << i << "]=" << size[i];
300       return RET_ERROR;
301     }
302   }
303   return RET_OK;
304 }
305 
306 /*
307  * Adjust slice's attr when broadcast happened in Arithmetic
308  */
SliceParamDeBroadcast(const CNodePtr & slice_cnode,const std::vector<int64_t> & ref_shape,std::vector<int64_t> * axes,std::vector<int> * begin,std::vector<int> * size)309 STATUS SlicePreposePass::SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int64_t> &ref_shape,
310                                                std::vector<int64_t> *axes, std::vector<int> *begin,
311                                                std::vector<int> *size) {
312   MS_ASSERT(slice_cnode != nullptr);
313   MS_ASSERT(new_slice_cnode != nullptr);
314   MS_ASSERT(axes != nullptr);
315   MS_ASSERT(begin != nullptr);
316   MS_ASSERT(size != nullptr);
317   auto slice = GetSlice(slice_cnode);
318   if (slice == nullptr) {
319     MS_LOG(ERROR) << "slice is nullptr";
320     return RET_ERROR;
321   }
322   auto origin_axes = slice->get_axes();
323   auto origin_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
324   auto origin_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
325   auto status = VerifySliceAttrs(slice_cnode, ref_shape.size());
326   if (status != RET_OK) {
327     return status;
328   }
329   axes->resize(ref_shape.size());
330   std::iota(axes->begin(), axes->end(), 0);
331   begin->assign(ref_shape.size(), 0);
332   size->assign(ref_shape.size(), -1);
333   bool real_slice = false;  // whether slice happened at this input
334   MS_CHECK_TRUE_MSG(origin_begin.size() >= origin_axes.size(), RET_ERROR, "origin_begin.size() is wrong");
335   MS_CHECK_TRUE_MSG(origin_size.size() >= origin_axes.size(), RET_ERROR, "origin_size.size() is wrong");
336   for (size_t i = 0; i < origin_axes.size(); ++i) {
337     int a = origin_axes[i];
338     int b = origin_begin[i];
339     int s = origin_size[i];
340     MS_CHECK_TRUE_MSG(static_cast<int>(ref_shape.size()) > a, RET_ERROR, "ref_shape.size() is wrong");
341     int ref = ref_shape[a];
342     if (ref == 1) {        // broadcast
343       continue;            // sliced size is 0(such as begin=1,size=-1) is not considered.
344     } else if (ref > 1) {  // not broadcast
345       if (b >= ref) {
346         MS_LOG(ERROR) << "slice begin[" << a << "]=" << b << ", while ref_shape[" << a << "]=" << ref << ", can't fit!";
347         return RET_ERROR;
348       } else {
349         if (b != 0 || (s != -1 && s != ref)) {
350           real_slice = true;
351         }
352         MS_CHECK_TRUE_MSG(static_cast<int>(begin->size()) > a, RET_ERROR, "begin.size() is wrong");
353         MS_CHECK_TRUE_MSG(static_cast<int>(size->size()) > a, RET_ERROR, "size.size() is wrong");
354         begin->at(a) = b;
355         size->at(a) = s;
356       }
357     } else {  // ref == 0, not need slice
358       continue;
359     }
360   }
361   if (real_slice) {
362     return lite::RET_OK;
363   } else {
364     return lite::RET_NO_CHANGE;
365   }
366 }
367 
/**
 * Creates a Reshape cnode that reshapes `preceed_cnode` to `shape_vector`.
 *
 * The target shape is stored as a new int parameter. The cnode gets `abstract` with its value cleared.
 *
 * @param graph graph that owns the new nodes.
 * @param shape_vector target shape (narrowed to int for the shape parameter).
 * @param abstract abstract assigned to the new reshape cnode.
 * @param preceed_cnode node to be reshaped (the reshape's data input).
 * @return the new reshape cnode, or nullptr if any node could not be created.
 */
CNodePtr SlicePreposePass::CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape_vector,
                                              const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(abstract != nullptr);
  MS_ASSERT(preceed_cnode != nullptr);
  auto new_reshape = std::make_shared<mindspore::ops::Reshape>();
  if (new_reshape == nullptr) {
    MS_LOG(ERROR) << "primitive_c is nullptr";
    return nullptr;
  }
  ValueNodePtr value_node = NewValueNode(new_reshape);
  if (value_node == nullptr) {
    return nullptr;
  }
  std::vector<int> shape;
  std::transform(shape_vector.begin(), shape_vector.end(), std::back_inserter(shape),
                 [](int64_t val) { return static_cast<int>(val); });
  auto shape_node = BuildIntVecParameterNode(
    graph, shape, preceed_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
  node_name_index++;
  if (shape_node == nullptr) {
    MS_LOG(ERROR) << "build parameter node failed.";
    return nullptr;
  }
  auto reshape_cnode = graph->NewCNode({value_node, preceed_cnode, shape_node});
  MS_CHECK_TRUE_MSG(reshape_cnode != nullptr, nullptr, "NewCNode Failed");
  reshape_cnode->set_abstract(abstract);
  reshape_cnode->set_fullname_with_scope(preceed_cnode->fullname_with_scope() + "_reshape_" +
                                         std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(reshape_cnode);
  return reshape_cnode;
}
402 
SiblingsAreSameSlice(const NodeUsedListPtr & output_node_list,const std::vector<int64_t> & ref_shape)403 bool SlicePreposePass::SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list,
404                                             const std::vector<int64_t> &ref_shape) {
405   MS_ASSERT(graph != nullptr);
406   MS_ASSERT(output_node_list != nullptr);
407   MS_ASSERT(output_node_list->size() >= 2);
408   std::vector<CNodePtr> slices;
409   for (auto &output_node : *(output_node_list.get())) {
410     auto cnode = output_node.first->cast<CNodePtr>();
411     if (cnode == nullptr) {
412       MS_LOG(ERROR) << "cnode is nullptr";
413       return false;
414     }
415     if (!CheckPrimitiveType(cnode, prim::kPrimSliceFusion)) {
416       return false;
417     }
418     auto slice_node = GetSlice(cnode);
419     if (slice_node == nullptr) {
420       MS_LOG(ERROR) << "Slice is nullptr";
421       return false;
422     }
423     slices.push_back(cnode);
424   }
425   MS_CHECK_TRUE_MSG(slices.size() > 0, RET_ERROR, "slices.size() is wrong");
426   auto first_slice_cnode = slices.front();
427   auto first_slice_node = GetSlice(first_slice_cnode);
428   if (first_slice_node == nullptr) {
429     MS_LOG(ERROR) << "GetSlice return nullptr";
430     return false;
431   }
432   auto first_axes = first_slice_node->get_axes();
433   auto first_begin = GetSliceBeginAndSize(first_slice_cnode, SliceBeginIndex);
434   auto first_size = GetSliceBeginAndSize(first_slice_cnode, SliceSizeIndex);
435   MS_CHECK_TRUE_MSG(first_begin.size() >= first_axes.size(), RET_ERROR, "first_begin.size() is wrong");
436   MS_CHECK_TRUE_MSG(first_size.size() >= first_axes.size(), RET_ERROR, "first_size.size() is wrong");
437   MS_CHECK_TRUE_MSG(slices.size() >= output_node_list->size(), RET_ERROR, "slices.size() is wrong");
438   for (size_t i = 1; i < output_node_list->size(); ++i) {
439     auto slice = GetSlice(slices[i]);
440     auto axes = slice->get_axes();
441     auto begin = GetSliceBeginAndSize(slices[i], SliceBeginIndex);
442     auto size = GetSliceBeginAndSize(slices[i], SliceSizeIndex);
443     MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), RET_ERROR, "begin.size() is wrong");
444     MS_CHECK_TRUE_MSG(size.size() >= axes.size(), RET_ERROR, "size.size() is wrong");
445     if (axes.size() != first_axes.size()) {
446       return false;
447     }
448     for (size_t j = 0; j < axes.size(); ++j) {
449       auto axe = axes[j];
450       if (!ref_shape.empty() && axe >= static_cast<int>(ref_shape.size())) {
451         return false;
452       }
453       size_t k = 0;
454       for (; k < first_axes.size(); ++k) {  // axes may not be [0...n-1], so we use nested loop to find it
455         if (first_axes[k] == axe) {
456           break;
457         }
458       }
459       if (k == first_axes.size()) {
460         return false;
461       }
462       if (begin[j] != first_begin[k]) {
463         return false;
464       }
465       if (size[j] != first_size[k]) {
466         if (ref_shape.empty()) {
467           return false;
468         }
469         auto actual_size = size[j] > 0 ? size[j] : ref_shape[axe] - begin[j];
470         auto actual_first_size = first_size[k] > 0 ? first_size[k] : ref_shape[axe] - first_begin[k];
471         if (actual_size != actual_first_size) {
472           return false;
473         }
474       }
475     }
476   }
477   return true;
478 }
479 
GetReshapeAbnormalAxeIn(const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,std::vector<int64_t> * mapped_axe)480 int64_t SlicePreposePass::GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in,
481                                                   const std::vector<int64_t> &shape_out,
482                                                   std::vector<int64_t> *mapped_axe) {
483   // find shape_out's correspond axe in shape_in
484   // when there are such as 3x1x1x4 => 3x1x4, mapped_axe[1] == 2
485   int64_t inner_size_in = 1;
486   int64_t abnormal_axe_in = -1;
487   MS_CHECK_TRUE_MSG(mapped_axe->size() >= shape_out.size(), abnormal_axe_in, "mapped_axe.size() is wrong");
488   for (size_t i = 0; i < shape_in.size(); ++i) {
489     inner_size_in *= shape_in[i];
490     int64_t inner_size_out = 1;
491     size_t j;
492     for (j = 0; j < shape_out.size(); ++j) {
493       inner_size_out *= shape_out[j];
494       if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
495         mapped_axe->at(j) = i;
496         break;
497       }
498     }
499     if (j == shape_out.size() && abnormal_axe_in == -1) {
500       abnormal_axe_in = i;
501     }
502   }
503   return abnormal_axe_in;
504 }
505 
GetReshapeAbnormalIndexOut(const CNodePtr & slice_cnode,const std::vector<int64_t> & mapped_axe,const std::vector<int64_t> & shape_out,std::vector<int64_t> * shape_out_copy,bool * is_normal_mode,bool * support_abnormal_mode)506 int64_t SlicePreposePass::GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode,
507                                                      const std::vector<int64_t> &mapped_axe,
508                                                      const std::vector<int64_t> &shape_out,
509                                                      std::vector<int64_t> *shape_out_copy, bool *is_normal_mode,
510                                                      bool *support_abnormal_mode) {
511   MS_ASSERT(slice_cnode != nullptr);
512   MS_ASSERT(shape_out_copy != nullptr);
513   MS_ASSERT(is_normal_mode != nullptr);
514   MS_ASSERT(support_abnormal_mode != nullptr);
515   int64_t abnormal_index_out = -1;
516   auto slice_node = GetSlice(slice_cnode);
517   if (slice_node == nullptr) {
518     MS_LOG(ERROR) << "slice is nullptr";
519     return abnormal_index_out;
520   }
521   auto slice_axes = slice_node->get_axes();
522   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
523   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
524   for (size_t j = 0; j < shape_out.size(); ++j) {
525     int index = -1;
526     for (size_t i = 0; i < slice_axes.size(); ++i) {
527       if (slice_axes[i] == static_cast<int64_t>(j)) {
528         index = i;
529         break;
530       }
531     }
532     if (index == -1) {
533       continue;
534     }
535     MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > index, false, "slice_begin.size() is wrong");
536     MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > index, false, "slice_size.size() is wrong");
537     if (slice_begin[index] != 0 || (slice_size[index] != -1 && slice_size[index] != shape_out[j])) {
538       if (mapped_axe[j] == -1) {
539         if (is_normal_mode) {
540           *is_normal_mode = false;
541           abnormal_index_out = index;
542         } else {
543           *support_abnormal_mode = false;
544         }
545       } else {  // if there is matched axe sliced, not support abnormal mode
546         shape_out_copy->at(j) =
547           (slice_size[index] == -1 ? shape_out[j] - slice_begin[index] : static_cast<int64_t>(slice_size[index]));
548         *support_abnormal_mode = false;
549       }
550     }
551   }
552   return abnormal_index_out;
553 }
554 
PreposeWithNormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out_copy,const std::vector<int64_t> & mapped_axe)555 bool SlicePreposePass::PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
556                                                 const CNodePtr &reshape_cnode, const std::vector<int64_t> &shape_in,
557                                                 const std::vector<int64_t> &shape_out_copy,
558                                                 const std::vector<int64_t> &mapped_axe) {
559   MS_ASSERT(graph != nullptr);
560   MS_ASSERT(slice_cnode != nullptr);
561   MS_ASSERT(reshape_cnode != nullptr);
562   auto slice_node = GetSlice(slice_cnode);
563   if (slice_node == nullptr) {
564     MS_LOG(ERROR) << "slice is nullptr";
565     return false;
566   }
567   auto slice_axes = slice_node->get_axes();
568   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
569   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
570   std::vector<int64_t> new_axes(shape_in.size());
571   std::iota(new_axes.begin(), new_axes.end(), 0);
572   std::vector<int> new_begin(shape_in.size(), 0);
573   std::vector<int> new_size(shape_in.size(), -1);
574   MS_CHECK_TRUE_MSG(slice_begin.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
575   MS_CHECK_TRUE_MSG(slice_size.size() >= mapped_axe.size(), false, "slice_begin.size() is wrong");
576   for (size_t i = 0; i < mapped_axe.size(); ++i) {
577     auto axe_in = mapped_axe[i];
578     if (axe_in == -1) {
579       continue;
580     }
581     new_begin[axe_in] = slice_begin[i];
582     new_size[axe_in] = slice_size[i];
583   }
584 
585   auto reshape_node = GetReshape(reshape_cnode);
586   if (reshape_node == nullptr) {
587     MS_LOG(ERROR) << "reshape is nullptr";
588     return false;
589   }
590   std::vector<int> new_shape_out_copy;
591   std::transform(shape_out_copy.begin(), shape_out_copy.end(), std::back_inserter(new_shape_out_copy),
592                  [](int64_t val) { return static_cast<int>(val); });
593   auto shape_node = BuildIntVecParameterNode(
594     graph, new_shape_out_copy, reshape_cnode->fullname_with_scope() + "_shape_" + std::to_string(node_name_index));
595   node_name_index++;
596   if (shape_node == nullptr) {
597     MS_LOG(ERROR) << "build parameter node failed.";
598     return false;
599   }
600   reshape_cnode->set_inputs({reshape_cnode->input(0), reshape_cnode->input(1), shape_node});
601 
602   slice_node->set_axes(new_axes);
603   auto new_begin_parameter = BuildIntVecParameterNode(
604     graph, new_begin, slice_cnode->input(SliceBeginIndex)->cast<ParameterPtr>()->fullname_with_scope());
605   auto new_size_parameter = BuildIntVecParameterNode(
606     graph, new_size, slice_cnode->input(SliceSizeIndex)->cast<ParameterPtr>()->fullname_with_scope());
607   MS_CHECK_TRUE_MSG(new_begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
608   MS_CHECK_TRUE_MSG(new_size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
609   slice_cnode->set_input(SliceBeginIndex, new_begin_parameter);
610   slice_cnode->set_input(SliceSizeIndex, new_size_parameter);
611   auto status = SwapSliceWithPreceed(graph, slice_cnode, reshape_cnode, 1);
612   if (status != RET_OK) {
613     return false;
614   }
615   reshape_cnode->set_abstract(slice_cnode->abstract()->Clone());
616   ClearCNodeAbstractValue(slice_cnode);
617   return true;
618 }
619 
/**
 * Creates the first of the two slices used when preposing a slice over a reshape in "abnormal" mode.
 *
 * The new slice covers all axes of `shape_in` and reads from `matmul_cnode`. It cuts only along
 * `abnormal_axe_in`: when `slice_at_front` it starts at `count_sliced_axe_in`, otherwise it keeps
 * shape_in[abnormal_axe_in] - count_sliced_axe_in elements. It gets a clone of the original slice's
 * abstract, with the value cleared.
 *
 * @return the new slice cnode, or nullptr on invalid arguments or node-creation failure.
 */
CNodePtr SlicePreposePass::CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &matmul_cnode,
                                                         const std::vector<int64_t> &shape_in,
                                                         const int64_t abnormal_axe_in,
                                                         const int64_t count_sliced_axe_in, const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(matmul_cnode != nullptr);
  // abnormal_axe_in indexes the per-axis vectors below; -1 ("not found") or an out-of-range axis is invalid.
  MS_CHECK_TRUE_MSG(abnormal_axe_in >= 0 && abnormal_axe_in < static_cast<int64_t>(shape_in.size()), nullptr,
                    "abnormal_axe_in is out of range");
  MS_CHECK_TRUE_MSG(slice_cnode->abstract() != nullptr, nullptr, "abstract of slice is nullptr");
  std::vector<int64_t> new_axes1(shape_in.size());
  std::iota(new_axes1.begin(), new_axes1.end(), 0);
  std::vector<int> new_begin1(shape_in.size(), 0);
  std::vector<int> new_size1(shape_in.size(), -1);
  if (slice_at_front) {
    new_begin1[abnormal_axe_in] = static_cast<int>(count_sliced_axe_in);
  } else {
    new_size1[abnormal_axe_in] = static_cast<int>(shape_in[abnormal_axe_in] - count_sliced_axe_in);
  }
  auto new_slice1 = CreateSliceValueNode(new_axes1);
  if (new_slice1 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin1, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size1, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice1_cnode = graph->NewCNode({new_slice1, matmul_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice1_cnode != nullptr, nullptr, "NewNode Failed");
  new_slice1_cnode->set_abstract(slice_cnode->abstract()->Clone());
  new_slice1_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice1_cnode);
  return new_slice1_cnode;
}
659 
// Builds the second slice of the abnormal-reshape rewrite (matmul->slice->reshape*->slice*->reshape).
// Its data input is new_reshape1_cnode, whose output shape is new_shape1. Axes 0..abnormal_axe_in are listed in
// the slice, but only axis abnormal_axe_in is actually cut; all others keep begin 0 / size -1.
// count_sliced2 is the number of elements kept on that axis: when slice_at_front is true the leading elements are
// dropped (begin is moved forward), otherwise the trailing ones are dropped (size is limited).
// Returns nullptr on failure.
CNodePtr SlicePreposePass::CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
                                                         const CNodePtr &new_reshape1_cnode,
                                                         const std::vector<int64_t> &new_shape1,
                                                         const int64_t abnormal_axe_in, const int64_t count_sliced2,
                                                         const bool slice_at_front) {
  MS_ASSERT(graph != nullptr);
  MS_ASSERT(slice_cnode != nullptr);
  MS_ASSERT(new_reshape1_cnode != nullptr);
  // abnormal_axe_in indexes new_shape1 below, validate it first.
  if (abnormal_axe_in < 0 || abnormal_axe_in >= static_cast<int64_t>(new_shape1.size())) {
    MS_LOG(ERROR) << "abnormal_axe_in is out of range of new_shape1";
    return nullptr;
  }
  const auto axes_num = static_cast<size_t>(abnormal_axe_in + 1);
  std::vector<int64_t> new_axes2(axes_num);
  std::iota(new_axes2.begin(), new_axes2.end(), 0);
  std::vector<int> new_begin2(axes_num, 0);
  std::vector<int> new_size2(axes_num, -1);
  if (count_sliced2 < 0 || count_sliced2 > new_shape1[abnormal_axe_in]) {
    MS_LOG(WARNING) << "calculation error";
    return nullptr;
  }
  if (slice_at_front) {
    // keep the last count_sliced2 elements
    new_begin2[abnormal_axe_in] = static_cast<int>(new_shape1[abnormal_axe_in] - count_sliced2);
  } else {
    // keep the first count_sliced2 elements
    new_size2[abnormal_axe_in] = static_cast<int>(count_sliced2);
  }
  auto new_slice2 = CreateSliceValueNode(new_axes2);
  if (new_slice2 == nullptr) {
    MS_LOG(ERROR) << "CreateSliceValueNode failed";
    return nullptr;
  }
  auto begin_parameter = BuildIntVecParameterNode(
    graph, new_begin2, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(begin_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto size_parameter = BuildIntVecParameterNode(
    graph, new_size2, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
  MS_CHECK_TRUE_MSG(size_parameter != nullptr, nullptr, "BuildIntVecParameterNode Failed");
  node_name_index += 1;
  auto new_slice2_cnode = graph->NewCNode({new_slice2, new_reshape1_cnode, begin_parameter, size_parameter});
  MS_CHECK_TRUE_MSG(new_slice2_cnode != nullptr, nullptr, "NewNode Failed");
  auto slice_abstract = slice_cnode->abstract();
  MS_CHECK_TRUE_MSG(slice_abstract != nullptr, nullptr, "slice abstract is nullptr");
  new_slice2_cnode->set_abstract(slice_abstract->Clone());
  new_slice2_cnode->set_fullname_with_scope(slice_cnode->fullname_with_scope() + "_slice_" +
                                            std::to_string(node_name_index));
  node_name_index++;
  ClearCNodeAbstractValue(new_slice2_cnode);
  return new_slice2_cnode;
}
703 
PreposeWithAbnormalReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode,const std::vector<int64_t> & shape_in,const std::vector<int64_t> & shape_out,const int64_t abnormal_axe_in,const int64_t abnormal_index_out)704 bool SlicePreposePass::PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
705                                                   const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in,
706                                                   const std::vector<int64_t> &shape_out, const int64_t abnormal_axe_in,
707                                                   const int64_t abnormal_index_out) {
708   MS_ASSERT(graph != nullptr);
709   MS_ASSERT(slice_cnode != nullptr);
710   auto manager = graph->manager();
711   MS_CHECK_TRUE_MSG(manager != nullptr, false, "manager is nullptr");
712   auto slice_node = GetSlice(slice_cnode);
713   if (slice_node == nullptr) {
714     MS_LOG(ERROR) << "slice is nullptr";
715     return false;
716   }
717   auto slice_axes = slice_node->get_axes();
718   MS_CHECK_TRUE_MSG(static_cast<int>(slice_axes.size()) > abnormal_index_out, false, "slice_axes.size() is wrong");
719   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
720   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
721   auto abnormal_axe_out = slice_axes[abnormal_index_out];
722   MS_ASSERT(abnormal_axe_out + 1 < shape_out.size());
723   int64_t inter_size_in = 1;
724   int64_t inter_size_out = 1;
725   for (auto i = 0; i < abnormal_axe_in; ++i) {
726     inter_size_in *= shape_in[i];
727   }
728   for (auto i = 0; i < abnormal_axe_out; ++i) {
729     inter_size_out *= shape_out[i];
730   }
731   if (inter_size_in != inter_size_out) {
732     MS_LOG(DEBUG) << "not support prepose now";
733     return false;
734   }
735   int64_t outer_size_in = 1;
736   int64_t outer_size_out = 1;
737   for (auto i = abnormal_axe_in + 1; i < static_cast<int>(shape_in.size()); ++i) {
738     outer_size_in *= shape_in[i];
739   }
740   for (auto i = abnormal_axe_out + 1; i < static_cast<int>(shape_out.size()); ++i) {
741     outer_size_out *= shape_out[i];
742   }
743   MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > abnormal_index_out, false, "slice_begin.size() is wrong");
744   const int64_t count_sliced_axe_front = slice_begin[abnormal_index_out];
745   MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > abnormal_index_out, false, "slice_size.size() is wrong");
746   MS_CHECK_TRUE_MSG(static_cast<int>(shape_out.size()) > abnormal_index_out, false, "shape_out.size() is wrong");
747   const int64_t count_sliced_axe_rear =
748     slice_size[abnormal_index_out] == -1 ? 0 : (shape_out[abnormal_axe_out] - slice_size[abnormal_index_out]);
749   if (count_sliced_axe_front * count_sliced_axe_rear > 0) {
750     MS_LOG(DEBUG) << "not border slice at abnormal axe, prepose with reshape failed";
751     return false;
752   }
753   bool slice_at_front = count_sliced_axe_front > 0;
754   const int64_t count_sliced_out = (count_sliced_axe_front + count_sliced_axe_rear) * outer_size_out;
755   MS_CHECK_TRUE_MSG(outer_size_in != 0, false, "div zero");
756   const int64_t count_sliced_axe_in = count_sliced_out / outer_size_in;
757   MS_CHECK_TRUE_MSG(static_cast<int>(shape_in.size()) > abnormal_axe_in, false, "shape_in.size() is wrong");
758   if (count_sliced_axe_in <= 0 || count_sliced_axe_in > shape_in[abnormal_axe_in]) {
759     MS_LOG(DEBUG) << "amount of sliced out tensor is illegal";
760     return false;
761   }
762   // new_slice1
763   auto new_slice1_cnode = CreateSlice1ForReshapePrepose(graph, slice_cnode, matmul_cnode, shape_in, abnormal_axe_in,
764                                                         count_sliced_axe_in, slice_at_front);
765   if (new_slice1_cnode == nullptr) {
766     return false;
767   }
768   // new_reshape1
769   std::vector<int64_t> new_shape1(abnormal_axe_in + 1);
770   for (int i = 0; i < abnormal_axe_in; ++i) {
771     new_shape1[i] = shape_in[i];
772   }
773   new_shape1[abnormal_axe_in] = outer_size_in * (shape_in[abnormal_axe_in] - count_sliced_axe_in);
774   auto new_reshape1_cnode = CreateReshapeCNode(graph, new_shape1, slice_cnode->abstract()->Clone(), new_slice1_cnode);
775   if (new_reshape1_cnode == nullptr) {
776     return false;
777   }
778   // new_slice2
779   const int64_t count_sliced_abnormal_axe =
780     shape_out[abnormal_axe_out] - (count_sliced_axe_front + count_sliced_axe_rear);
781   const int64_t count_sliced2 = count_sliced_abnormal_axe * outer_size_out;
782   auto new_slice2_cnode = CreateSlice2ForReshapePrepose(graph, slice_cnode, new_reshape1_cnode, new_shape1,
783                                                         abnormal_axe_in, count_sliced2, slice_at_front);
784   if (new_slice2_cnode == nullptr) {
785     return false;
786   }
787   // new_reshape2
788   std::vector<int64_t> new_shape2(shape_out.begin(), shape_out.end());
789   new_shape2[abnormal_axe_out] = count_sliced_abnormal_axe;
790   auto new_reshape2_cnode = CreateReshapeCNode(graph, new_shape2, slice_cnode->abstract()->Clone(), new_slice2_cnode);
791   if (new_reshape2_cnode == nullptr) {
792     return false;
793   }
794   new_reshape2_cnode->set_abstract(slice_cnode->abstract()->Clone());
795   auto node_users = manager->node_users()[slice_cnode];
796   for (auto &node_user : node_users) {
797     manager->SetEdge(node_user.first, node_user.second, new_reshape2_cnode);
798   }
799   return true;
800 }
801 
GetArithmeticInputInfo(const CNodePtr & arithmetic_cnode,std::vector<AnfNodePtr> * inputs,std::vector<std::vector<int64_t>> * shapes,std::vector<bool> * is_default_params)802 bool SlicePreposePass::GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs,
803                                               std::vector<std::vector<int64_t>> *shapes,
804                                               std::vector<bool> *is_default_params) {
805   MS_ASSERT(inputs != nullptr);
806   MS_ASSERT(shapes != nullptr);
807   MS_ASSERT(is_default_params != nullptr);
808   MS_ASSERT(arithmetic_cnode != nullptr);
809   for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) {
810     auto input = arithmetic_cnode->input(i);
811     MS_ASSERT(input != nullptr);
812     std::vector<int64_t> shape;
813     if (utils::isa<ParameterPtr>(input)) {
814       auto parameter = utils::cast<ParameterPtr>(input);
815       MS_ASSERT(parameter != nullptr);
816       if (!parameter->has_default()) {  // if one input is input placeholder, we can't change it
817         return false;
818       } else {
819         shape = GetDefaultParamShape(parameter);
820         is_default_params->push_back(true);
821       }
822     } else {  // input is CNode
823       if (!utils::isa<CNodePtr>(input)) {
824         MS_LOG(ERROR) << "one of Arithmetic's input is not CNode";
825         return false;
826       }
827       shape = GetCNodeInputShape(arithmetic_cnode, i);
828       is_default_params->push_back(false);
829     }
830     inputs->push_back(input);
831     shapes->push_back(shape);
832   }
833   return true;
834 }
835 
836 /*
837  * Prepose condition:
838  *  the softmax axis is not sliced
839  */
PreposeWithSoftmax(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & softmax_cnode)840 bool SlicePreposePass::PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
841                                           const CNodePtr &softmax_cnode) {
842   MS_ASSERT(graph != nullptr);
843   MS_ASSERT(slice_cnode != nullptr);
844   MS_ASSERT(softmax_cnode != nullptr);
845   auto softmax_node = GetSoftmax(softmax_cnode);
846   if (softmax_node == nullptr) {
847     MS_LOG(ERROR) << "softmax is nullptr";
848     return false;
849   }
850   std::vector<int64_t> softmax_axis{-1};
851   if (softmax_node->GetAttr(ops::kAxis) != nullptr) {
852     softmax_axis = softmax_node->get_axis();
853   }
854   if (softmax_axis.size() != 1) {
855     MS_LOG(ERROR) << "softmax axis is not a value, which don't support.";
856     return false;
857   }
858   auto shape = GetCNodeInputShape(softmax_cnode, 1);
859   if (softmax_axis.front() == -1) {
860     if (shape.empty()) {  // when softmax axis == -1, shape info is needed to determine whether slice can be preposed
861       return false;
862     }
863     softmax_axis[0] += shape.size();
864   }
865 
866   auto slice_node = GetSlice(slice_cnode);
867   if (slice_node == nullptr) {
868     return false;
869   }
870   auto slice_axes = slice_node->get_axes();
871   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
872   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
873 
874   MS_CHECK_TRUE_MSG(static_cast<int>(softmax_axis.size()) > 0, false, "shape_in.size() is wrong");
875   MS_CHECK_TRUE_MSG(slice_size.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
876   MS_CHECK_TRUE_MSG(slice_begin.size() >= slice_axes.size(), false, "shape_in.size() is wrong");
877   for (size_t i = 0; i < slice_axes.size(); ++i) {
878     if (slice_axes[i] == softmax_axis.front()) {
879       if (slice_begin[i] != 0) {
880         return false;
881       }
882       if (slice_size[i] != -1) {
883         if (shape.empty() || slice_axes[i] >= static_cast<int>(shape.size())) {
884           return false;
885         }
886         if (slice_size[i] < shape[slice_axes[i]]) {
887           return false;
888         }
889       }
890     }
891   }
892   auto status = SwapSliceWithPreceed(graph, slice_cnode, softmax_cnode, 1);
893   if (status != RET_OK) {
894     return false;
895   }
896   softmax_cnode->set_abstract(slice_cnode->abstract()->Clone());
897   ClearCNodeAbstractValue(slice_cnode);
898   return true;
899 }
900 
901 /*
902  * Prepose condition:
903  *  require shape info
904  *  when reshape is normal(memory view is not changed, such as 4x5 reshaped to 4x1x5), can always prepose
905  *  when reshape is abnormal(such as 4x5 reshaped to 5x4), can prepose under some constraint
906  * For abnormal mode:
907  *  we only support border(not slice at center) slice at first mismatch axe,
908  *  and we only support matmul->reshape->slice => matmul->slice->reshape*->slice*(drop "dead" data)->reshape now,
909  *  cause the performance influence introduced by additional (reshape*->slice*) has not been fully evaluated.
910  */
PreposeWithReshape(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & reshape_cnode)911 bool SlicePreposePass::PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
912                                           const CNodePtr &reshape_cnode) {
913   MS_ASSERT(graph != nullptr);
914   MS_ASSERT(slice_cnode != nullptr);
915   MS_ASSERT(reshape_cnode != nullptr);
916   auto shape_in = GetCNodeInputShape(reshape_cnode, 1);
917   auto shape_out = GetCNodeInputShape(slice_cnode, 1);
918   auto shape_out_copy = shape_out;
919   if (shape_in.empty() || shape_out.empty()) {
920     MS_LOG(DEBUG) << "Reshape can't be preposed if either input or output shape is unknown";
921     return false;
922   }
923   if (reshape_cnode->inputs().size() == 3 && utils::isa<ParameterPtr>(reshape_cnode->input(2))) {
924     auto reshape_input_shape = utils::cast<ParameterPtr>(reshape_cnode->input(2));
925     MS_ASSERT(reshape_input_shape != nullptr);
926     if (!reshape_input_shape->has_default()) {
927       MS_LOG(ERROR) << "Reshape input shape is not constant";
928       return false;
929     }
930   }
931   std::vector<int64_t> mapped_axe(shape_out.size(), -1);
932   int64_t abnormal_axe_in = GetReshapeAbnormalAxeIn(shape_in, shape_out, &mapped_axe);
933   bool is_normal_mode = true;         // if all sliced axe can be found in input shape, normal
934   bool support_abnormal_mode = true;  // if first mismatch axe are sliced and no more other axes are sliced, abnormal
935   int64_t abnormal_index_out = GetReshapeAbnormalIndexOut(slice_cnode, mapped_axe, shape_out, &shape_out_copy,
936                                                           &is_normal_mode, &support_abnormal_mode);
937   if (abnormal_index_out == -1) {
938     MS_LOG(ERROR) << "GetReshapeAbnormalIndexOut failed.";
939     return false;
940   }
941   if (is_normal_mode) {
942     return PreposeWithNormalReshape(graph, slice_cnode, reshape_cnode, shape_in, shape_out_copy, mapped_axe);
943   } else if (support_abnormal_mode) {
944     auto matmul_node = reshape_cnode->input(1);
945     MS_ASSERT(matmul_node != nullptr);
946     if (IsMultiOutputTensors(graph, matmul_node) || !utils::isa<CNodePtr>(matmul_node)) {
947       MS_LOG(DEBUG) << "not matmul->reshape->slice";
948       return false;
949     }
950     auto matmul_cnode = matmul_node->cast<CNodePtr>();
951     if (matmul_cnode == nullptr) {
952       MS_LOG(ERROR) << "matmul_cnode is nullptr";
953       return false;
954     }
955     if (!CheckPrimitiveType(matmul_node, prim::kPrimFullConnection) &&
956         !CheckPrimitiveType(matmul_node, prim::kPrimMatMul)) {
957       MS_LOG(DEBUG) << "not matmul->reshape->slice pattern";
958       return false;
959     }
960     return PreposeWithAbnormalReshape(graph, slice_cnode, matmul_cnode, shape_in, shape_out, abnormal_axe_in,
961                                       abnormal_index_out);
962   }
963   return false;
964 }
965 
966 /*
967  * Prepose condition:
968  *  require shape info
969  */
PreposeWithMatmul(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & matmul_cnode)970 bool SlicePreposePass::PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
971                                          const CNodePtr &matmul_cnode) {
972   MS_ASSERT(graph != nullptr && slice_cnode != nullptr && matmul_cnode != nullptr);
973   auto matmul_shape = GetCNodeInputShape(slice_cnode, 1);
974   const int dims = matmul_shape.size();
975   if (dims == 0) {
976     // if Matmul's output shape is unknown, can't do prepose, cause we can't determine last two axes
977     return false;
978   }
979   auto slice_node = GetSlice(slice_cnode);
980   MS_CHECK_TRUE_MSG(slice_node != nullptr, false, "slice is nullptr");
981   auto axes = slice_node->get_axes();
982   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
983   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
984   // matmul not support broadcast now, it makes things simpler
985   auto manager = graph->manager();
986   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
987   MS_CHECK_TRUE_MSG(tr != nullptr, false, "create FuncGraphTransaction failed");
988   auto node_users = manager->node_users()[slice_cnode];
989   bool changed = false;
990   bool prepose_to_left = false;   // if only the last axe is sliced, not need prepose to left
991   bool prepose_to_right = false;  // if only the second last axe is sliced, not need prepose to right
992   MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
993   MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
994   for (size_t i = 0; i < axes.size(); ++i) {
995     if (begin[i] != 0 || (size[i] != -1 && size[i] != matmul_shape[axes[i]])) {
996       if (axes[i] != dims - 1) {
997         prepose_to_left = true;
998       } else if (axes[i] != dims - 2) {
999         prepose_to_right = true;
1000       }
1001     }
1002   }
1003   if (prepose_to_left) {  //  left matrix
1004     auto left_axes = axes;
1005     auto left_begin = begin;
1006     auto left_size = size;
1007     MS_CHECK_TRUE_MSG(left_begin.size() >= left_axes.size(), false, "left_begin.size() is wrong");
1008     MS_CHECK_TRUE_MSG(left_size.size() >= left_axes.size(), false, "left_size.size() is wrong");
1009     for (size_t i = 0; i < left_axes.size(); ++i) {
1010       if (left_axes[i] == dims - 1) {
1011         left_begin[i] = 0;
1012         left_size[i] = -1;
1013       }
1014     }
1015     auto left_slice_vnode = CreateSliceValueNode(left_axes);
1016     MS_CHECK_TRUE_MSG(left_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1017     auto begin_parameter = BuildIntVecParameterNode(
1018       graph, left_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1019     node_name_index += 1;
1020     auto size_parameter = BuildIntVecParameterNode(
1021       graph, left_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1022     MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1023     MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1024     node_name_index += 1;
1025 
1026     const std::vector<AnfNodePtr> inputs = {left_slice_vnode, matmul_cnode->input(1), begin_parameter, size_parameter};
1027     auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 1, tr);
1028     MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1029     new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1030     ClearCNodeAbstractValue(new_slice_cnode);
1031     changed = true;
1032   }
1033   if (prepose_to_right) {  //  right matrix
1034     auto right_axes = axes;
1035     auto right_begin = begin;
1036     auto right_size = size;
1037     MS_CHECK_TRUE_MSG(right_begin.size() >= right_axes.size(), false, "right_begin.size() is wrong");
1038     MS_CHECK_TRUE_MSG(right_size.size() >= right_axes.size(), false, "right_size.size() is wrong");
1039     for (size_t i = 0; i < right_axes.size(); ++i) {
1040       if (right_axes[i] == dims - 2) {
1041         right_begin[i] = 0;
1042         right_size[i] = -1;
1043       }
1044     }
1045     auto begin_parameter = BuildIntVecParameterNode(
1046       graph, right_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1047     node_name_index += 1;
1048     auto size_parameter = BuildIntVecParameterNode(
1049       graph, right_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1050     MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1051     MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1052     node_name_index += 1;
1053     auto right_slice_vnode = CreateSliceValueNode(right_axes);
1054     MS_CHECK_TRUE_MSG(right_slice_vnode != nullptr, false, "CreateSliceValueNode failed");
1055     const std::vector<AnfNodePtr> inputs = {right_slice_vnode, matmul_cnode->input(2), begin_parameter, size_parameter};
1056     auto new_slice_cnode = InsertSlice(graph, inputs, matmul_cnode, 2, tr);
1057     MS_ASSERT(new_slice_cnode != nullptr);
1058     new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1059     ClearCNodeAbstractValue(new_slice_cnode);
1060     changed = true;
1061   }
1062   if (changed) {
1063     matmul_cnode->set_abstract(slice_cnode->abstract()->Clone());
1064     for (auto &node_user : node_users) {
1065       tr->SetEdge(node_user.first, node_user.second, matmul_cnode);
1066     }
1067     tr->Commit();
1068     // we don't need graph->DropNode(slice_cnode);
1069   }
1070   return changed;
1071 }
1072 
1073 /*
1074  * Prepose condition:
1075  *  require shape info
1076  *  only support slice at first output axe now, and useAxis must be false
1077  */
PreposeWithFullConnection(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & fc_cnode)1078 bool SlicePreposePass::PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1079                                                  const CNodePtr &fc_cnode) {
1080   MS_ASSERT(graph != nullptr);
1081   MS_ASSERT(slice_cnode != nullptr);
1082   MS_ASSERT(fc_cnode != nullptr);
1083   auto shape_in = GetCNodeInputShape(fc_cnode, 1);
1084   auto shape_out = GetCNodeInputShape(slice_cnode, 1);
1085   if (shape_in.empty() || shape_out.size() != 2) {
1086     MS_LOG(DEBUG) << "FullConnection can't be preposed if input shape is unknown or output shape is illegal";
1087     return false;
1088   }
1089   auto fc_node = GetFc(fc_cnode);
1090   if (fc_node == nullptr || (fc_node->GetAttr(ops::kUseAxis) != nullptr && fc_node->get_use_axis())) {
1091     MS_LOG(DEBUG) << "prepose with fc only support useAxis == false currently";
1092     return false;
1093   }
1094   auto slice_node = GetSlice(slice_cnode);
1095   if (slice_node == nullptr) {
1096     MS_LOG(ERROR) << "slice is nullptr";
1097     return false;
1098   }
1099   auto axes = slice_node->get_axes();
1100   auto begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1101   auto size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1102   MS_CHECK_TRUE_MSG(begin.size() >= axes.size(), false, "begin.size() is wrong");
1103   MS_CHECK_TRUE_MSG(size.size() >= axes.size(), false, "size.size() is wrong");
1104   for (size_t i = 0; i < axes.size(); ++i) {
1105     if (axes[i] == 1) {
1106       if (begin[i] != 0 || (size[i] != -1 && size[i] != shape_out[1])) {
1107         MS_LOG(DEBUG) << "prepose with fc only support first output axe is sliced currently";
1108         return false;
1109       }
1110     }
1111   }
1112 
1113   std::vector<int64_t> mapped_axe(shape_out.size(), -1);
1114   int64_t inner_size_in = 1;
1115   for (size_t i = 0; i < shape_in.size(); ++i) {
1116     inner_size_in *= shape_in[i];
1117     int64_t inner_size_out = 1;
1118     for (size_t j = 0; j < shape_out.size(); ++j) {
1119       inner_size_out *= shape_out[j];
1120       if (shape_out[j] == shape_in[i] && inner_size_out == inner_size_in) {
1121         mapped_axe[j] = i;
1122         break;
1123       }
1124     }
1125   }
1126   if (mapped_axe[0] == -1) {
1127     MS_LOG(DEBUG) << "first axe in output can't find correspond input axe, can't do prepose";
1128     return false;
1129   }
1130 
1131   std::vector<int64_t> new_axes(shape_in.size());
1132   std::iota(new_axes.begin(), new_axes.end(), 0);
1133   std::vector<int> new_begin(shape_in.size(), 0);
1134   std::vector<int> new_size(shape_in.size(), -1);
1135   new_begin[mapped_axe[0]] = begin[0];
1136   new_size[mapped_axe[0]] = size[0];
1137   auto new_slice_vnode = CreateSliceValueNode(new_axes);
1138   if (new_slice_vnode == nullptr) {
1139     MS_LOG(ERROR) << "CreateSliceValueNode failed";
1140     return false;
1141   }
1142 
1143   auto manager = graph->manager();
1144   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1145   if (tr == nullptr) {
1146     MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1147     return false;
1148   }
1149   auto begin_parameter = BuildIntVecParameterNode(
1150     graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1151   node_name_index += 1;
1152   auto size_parameter = BuildIntVecParameterNode(
1153     graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1154   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1155   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1156   node_name_index += 1;
1157   const std::vector<AnfNodePtr> inputs = {new_slice_vnode, fc_cnode->input(1), begin_parameter, size_parameter};
1158   auto new_slice_cnode = InsertSlice(graph, inputs, fc_cnode, 1, tr);
1159   MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1160 
1161   fc_cnode->set_abstract(slice_cnode->abstract()->Clone());
1162   new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1163   ClearCNodeAbstractValue(new_slice_cnode);
1164 
1165   auto node_users = manager->node_users()[slice_cnode];
1166   for (auto &node_user : node_users) {
1167     tr->SetEdge(node_user.first, node_user.second, fc_cnode);
1168   }
1169   tr->Commit();
1170   return true;
1171 }
1172 
1173 /*
1174  * Prepose condition:
1175  *  not require shape info, can always prepose
1176  */
PreposeWithTranspose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & transpose_cnode)1177 bool SlicePreposePass::PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1178                                             const CNodePtr &transpose_cnode) {
1179   MS_ASSERT(graph != nullptr);
1180   MS_ASSERT(slice_cnode != nullptr);
1181   MS_ASSERT(transpose_cnode != nullptr);
1182   if (transpose_cnode->inputs().size() != 3) {
1183     MS_LOG(ERROR) << "transpose inputs size should be 3.";
1184     return false;
1185   }
1186   auto perm = GetTransposePerm(transpose_cnode);
1187   if (perm.empty()) {
1188     return false;
1189   }
1190   auto slice_node = GetSlice(slice_cnode);
1191   if (slice_node == nullptr) {
1192     MS_LOG(ERROR) << "GetSlicT failed";
1193     return false;
1194   }
1195   auto old_axes = slice_node->get_axes();
1196   auto old_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1197   auto old_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1198   auto slice_begin = GetSliceBeginAndSize(slice_cnode, SliceBeginIndex);
1199   auto slice_size = GetSliceBeginAndSize(slice_cnode, SliceSizeIndex);
1200   // perm is random shuffle of [0...n-1] according to ops/transpose.cc
1201   for (size_t i = 0; i < perm.size(); ++i) {
1202     if (perm[i] != static_cast<int>(i)) {
1203       for (size_t j = 0; j < old_axes.size(); ++j) {
1204         if (old_axes[j] == static_cast<int>(i)) {
1205           MS_CHECK_TRUE_MSG(static_cast<int>(slice_begin.size()) > perm[i], false, "slice_begin.size() is wrong");
1206           MS_CHECK_TRUE_MSG(static_cast<int>(slice_size.size()) > perm[i], false, "slice_size.size() is wrong");
1207           slice_begin[perm[i]] = old_begin[j];
1208           slice_size[perm[i]] = old_size[j];
1209           break;
1210         }
1211       }
1212     }
1213   }
1214   auto begin_parameter = BuildIntVecParameterNode(
1215     graph, slice_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1216   node_name_index += 1;
1217   auto size_parameter = BuildIntVecParameterNode(
1218     graph, slice_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1219   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1220   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1221   node_name_index += 1;
1222   slice_cnode->set_input(SliceBeginIndex, begin_parameter);
1223   slice_cnode->set_input(SliceSizeIndex, size_parameter);
1224   auto status = SwapSliceWithPreceed(graph, slice_cnode, transpose_cnode, 1);
1225   if (status != RET_OK) {
1226     return false;
1227   }
1228   transpose_cnode->set_abstract(slice_cnode->abstract()->Clone());
1229   ClearCNodeAbstractValue(slice_cnode);
1230   return true;
1231 }
1232 /*
1233  * Prepose condition:
1234  *  may or may not require shape info
1235  */
PreposeWithArithmetic(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & arithmetic_cnode)1236 bool SlicePreposePass::PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1237                                              const CNodePtr &arithmetic_cnode) {
1238   MS_ASSERT(graph != nullptr);
1239   MS_ASSERT(slice_cnode != nullptr);
1240   MS_ASSERT(arithmetic_cnode != nullptr);
1241   auto manager = graph->manager();
1242   MS_ASSERT(manager != nullptr);
1243   auto node_users = manager->node_users()[slice_cnode];
1244   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1245   if (tr == nullptr) {
1246     MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1247     return false;
1248   }
1249   bool changed = false;
1250   std::vector<AnfNodePtr> inputs;
1251   std::vector<std::vector<int64_t>> shapes;
1252   std::vector<bool> is_default_params;
1253   if (!GetArithmeticInputInfo(arithmetic_cnode, &inputs, &shapes, &is_default_params)) {
1254     return false;
1255   }
1256 
1257   for (size_t i = 1; i < arithmetic_cnode->inputs().size(); ++i) {
1258     auto &input = inputs[i - 1];
1259     if (IsScalarNode(input)) {  // scalar not need prepose
1260       continue;
1261     }
1262     auto &shape = shapes[i - 1];
1263     const size_t another_index = kArithmeticInputNum - i;
1264     auto &another_input = inputs[another_index];
1265     auto &another_shape = shapes[another_index];
1266     if (IsScalarNode(input)) {
1267       continue;
1268     } else if (shape.empty()) {           // infershape failed at this input
1269       if (IsScalarNode(another_input)) {  // if another input is scalar, we can process this one
1270         auto new_slice_vnode = CopySliceValueNode(slice_cnode);
1271         if (new_slice_vnode == nullptr) {
1272           changed = false;
1273           break;
1274         }
1275         std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i),
1276                                                 slice_cnode->input(SliceBeginIndex),
1277                                                 slice_cnode->input(SliceSizeIndex)};
1278         auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1279         MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1280 
1281         new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1282         ClearCNodeAbstractValue(new_slice_cnode);
1283         changed = true;
1284         break;
1285       } else {  // if another input's shape is not scalar, can't be processed
1286         changed = false;
1287         break;
1288       }
1289     } else {  // shape not empty
1290       if (!another_shape.empty() || IsScalarNode(another_input)) {
1291         std::vector<int64_t> new_axes;
1292         std::vector<int> new_begin;
1293         std::vector<int> new_size;
1294         auto status = SliceParamDeBroadcast(slice_cnode, shape, &new_axes, &new_begin, &new_size);
1295         if (status == lite::RET_NO_CHANGE) {
1296           continue;
1297         }
1298         if (status != lite::RET_OK) {
1299           changed = false;
1300           break;
1301         }
1302         auto new_slice_vnode = CreateSliceValueNode(new_axes);
1303         if (new_slice_vnode == nullptr) {
1304           changed = false;
1305           break;
1306         }
1307         auto begin_parameter = BuildIntVecParameterNode(
1308           graph, new_begin, slice_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1309         node_name_index += 1;
1310         auto size_parameter = BuildIntVecParameterNode(
1311           graph, new_size, slice_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1312         MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1313         MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1314         node_name_index += 1;
1315         std::vector<AnfNodePtr> slice_inputs = {new_slice_vnode, arithmetic_cnode->input(i), begin_parameter,
1316                                                 size_parameter};
1317         auto new_slice_cnode = InsertSlice(graph, slice_inputs, arithmetic_cnode, i, tr);
1318         MS_CHECK_TRUE_MSG(new_slice_cnode != nullptr, false, "InsertSlice Failed");
1319         new_slice_cnode->set_abstract(slice_cnode->abstract()->Clone());
1320         ClearCNodeAbstractValue(new_slice_cnode);
1321         changed = true;
1322       } else {
1323         changed = false;
1324         break;
1325       }
1326     }
1327   }
1328   if (changed) {
1329     arithmetic_cnode->set_abstract(slice_cnode->abstract()->Clone());
1330     for (auto &node_user : node_users) {
1331       tr->SetEdge(node_user.first, node_user.second, arithmetic_cnode);
1332     }
1333     tr->Commit();
1334     // we don't need graph->DropNode(slice_cnode);
1335   }
1336   return changed;
1337 }  // namespace mindspore::opt
1338 /*
1339  * Prepose condition:
1340  *  not require shape info
1341  */
MergeSequentialSlice(const FuncGraphPtr & graph,const CNodePtr & slice1_cnode,const CNodePtr & slice2_cnode)1342 bool SlicePreposePass::MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode,
1343                                             const CNodePtr &slice2_cnode) {
1344   if (slice2_cnode->inputs().size() != kArithmeticInputNum) {
1345     MS_LOG(INFO) << "Slice read attrs from input is not supported now";
1346     return false;
1347   }
1348   auto slice1_node = GetSlice(slice1_cnode);  // bottom node
1349   auto slice2_node = GetSlice(slice2_cnode);  // top node
1350   if (slice1_node == nullptr || slice2_node == nullptr) {
1351     MS_LOG(ERROR) << "slice is null";
1352     return false;
1353   }
1354   auto begin_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceBeginIndex);
1355   auto size_slice1 = GetSliceBeginAndSize(slice1_cnode, SliceSizeIndex);
1356   auto axes_slice1 = slice1_node->get_axes();
1357   auto begin_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceBeginIndex);
1358   auto size_slice2 = GetSliceBeginAndSize(slice2_cnode, SliceSizeIndex);
1359   auto axes_slice2 = slice2_node->get_axes();
1360   auto status1 = VerifySliceAttrs(slice1_cnode);
1361   auto status2 = VerifySliceAttrs(slice2_cnode);
1362   if (status1 != RET_OK || status2 != RET_OK) {
1363     return false;
1364   }
1365 
1366   auto manager = graph->manager();
1367   MS_ASSERT(manager != nullptr);
1368   auto node_users = manager->node_users()[slice1_cnode];
1369   int64_t axe_max1 = *std::max_element(axes_slice1.begin(), axes_slice1.end());
1370   int64_t axe_max2 = *std::max_element(axes_slice2.begin(), axes_slice2.end());
1371   int64_t axe_max = std::max(axe_max1, axe_max2);
1372   auto begin_new = begin_slice2;
1373   auto size_new = size_slice2;
1374   auto axes_new = slice2_node->get_axes();
1375   axes_new.resize(axe_max + 1);
1376   std::iota(axes_new.begin(), axes_new.end(), 0);
1377   begin_new.assign(axe_max + 1, 0);
1378   size_new.assign(axe_max + 1, -1);
1379   MS_CHECK_TRUE_MSG(begin_slice2.size() >= axes_slice2.size(), false, "begin_slice2.size() is wrong");
1380   MS_CHECK_TRUE_MSG(size_slice2.size() >= axes_slice2.size(), false, "size_slice2.size() is wrong");
1381   MS_CHECK_TRUE_MSG(size_slice1.size() >= axes_slice1.size(), false, "size_slice1.size() is wrong");
1382   MS_CHECK_TRUE_MSG(begin_slice1.size() >= axes_slice1.size(), false, "begin_slice1.size() is wrong");
1383   for (int i = 0; i <= axe_max; ++i) {
1384     for (size_t j = 0; j < axes_slice2.size(); ++j) {
1385       if (axes_slice2[j] == i) {
1386         begin_new[i] = begin_slice2[j];
1387         size_new[i] = size_slice2[j];
1388         break;
1389       }
1390     }
1391     for (size_t j = 0; j < axes_slice1.size(); ++j) {
1392       if (axes_slice1[j] == i) {
1393         begin_new[i] = begin_new[i] + begin_slice1[j];
1394         if (size_new[i] == -1) {
1395           size_new[i] = size_slice1[j];
1396         } else {
1397           if (size_slice1[j] == -1) {
1398             size_new[i] = std::max(size_new[i] - begin_slice1[i], 0);  // clip with zero to avoid invalid negative value
1399           } else {
1400             size_new[i] = std::max(std::min(size_new[i] - begin_slice1[j], size_slice1[j]), 0);
1401           }
1402         }
1403         break;
1404       }
1405     }
1406   }
1407   slice2_node->set_axes(axes_new);
1408   auto begin_parameter = BuildIntVecParameterNode(
1409     graph, begin_new, slice2_cnode->fullname_with_scope() + "_begin_" + std::to_string(node_name_index));
1410   node_name_index += 1;
1411   auto size_parameter = BuildIntVecParameterNode(
1412     graph, size_new, slice2_cnode->fullname_with_scope() + "_size_" + std::to_string(node_name_index));
1413   MS_CHECK_TRUE_MSG(begin_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1414   MS_CHECK_TRUE_MSG(size_parameter != nullptr, false, "BuildIntVecParameterNode Failed");
1415   node_name_index += 1;
1416   slice2_cnode->set_input(SliceBeginIndex, begin_parameter);
1417   slice2_cnode->set_input(SliceSizeIndex, size_parameter);
1418   slice2_cnode->set_abstract(slice1_cnode->abstract()->Clone());
1419   for (auto &node_user : node_users) {
1420     manager->SetEdge(node_user.first, node_user.second, slice2_cnode);
1421   }
1422   return true;
1423 }
1424 
1425 /*
1426  * Prepose condition:
1427  *  when all sibling slices do same work
1428  *  can be optimize to not require all siblings are slice
1429  */
MergeParallelSlice(const FuncGraphPtr & graph,const NodeUsedListPtr & slices)1430 bool SlicePreposePass::MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices) {
1431   MS_ASSERT(graph != nullptr);
1432   MS_ASSERT(slices->size() >= 2);
1433   auto manager = graph->manager();
1434   MS_ASSERT(manager != nullptr);
1435   auto first_slice = utils::cast<CNodePtr>(slices->at(0).first);
1436   MS_ASSERT(first_slice == nullptr);
1437   if (!CheckPrimitiveType(first_slice, prim::kPrimSliceFusion)) {
1438     MS_LOG(ERROR) << "first node is not Slice";
1439     return false;
1440   }
1441   auto first_parent = first_slice->input(1);
1442   if (first_parent == nullptr) {
1443     MS_LOG(ERROR) << "first slice node's parent is nullptr";
1444     return false;
1445   }
1446   std::shared_ptr<FuncGraphTransaction> tr = std::make_shared<FuncGraphTransaction>(manager.get());
1447   if (tr == nullptr) {
1448     MS_LOG(ERROR) << "create FuncGraphTransaction failed";
1449     return false;
1450   }
1451   for (size_t i = 1; i < slices->size(); ++i) {
1452     auto slice = utils::cast<CNodePtr>(slices->at(i).first);
1453     MS_ASSERT(slice == nullptr);
1454     if (!CheckPrimitiveType(slice, prim::kPrimSliceFusion)) {
1455       MS_LOG(ERROR) << "current node is not Slice";
1456       return false;
1457     }
1458     auto parent = slice->input(1);
1459     if (parent == nullptr || parent != first_parent) {
1460       MS_LOG(ERROR) << "not all slices have same parent node";
1461       return false;
1462     }
1463     auto node_users = manager->node_users()[slices->at(i).first];
1464     for (auto &node_user : node_users) {
1465       tr->SetEdge(node_user.first, node_user.second, slices->at(0).first);
1466     }
1467   }
1468   tr->Commit();
1469   return true;
1470 }
1471 
DoPrepose(const FuncGraphPtr & graph,const CNodePtr & slice_cnode,const CNodePtr & preceed_cnode)1472 bool SlicePreposePass::DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode,
1473                                  const CNodePtr &preceed_cnode) {
1474   MS_ASSERT(graph != nullptr);
1475   MS_ASSERT(slice_cnode != nullptr);
1476   MS_ASSERT(preceed_cnode != nullptr);
1477   if (CheckPrimitiveType(preceed_cnode, prim::kPrimSoftmax)) {
1478     return PreposeWithSoftmax(graph, slice_cnode, preceed_cnode);
1479   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimReshape)) {
1480     return PreposeWithReshape(graph, slice_cnode, preceed_cnode);
1481   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimMatMul)) {
1482     return PreposeWithMatmul(graph, slice_cnode, preceed_cnode);
1483   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimFullConnection)) {
1484     return PreposeWithFullConnection(graph, slice_cnode, preceed_cnode);
1485   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimTranspose)) {
1486     return PreposeWithTranspose(graph, slice_cnode, preceed_cnode);
1487   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSubFusion) ||
1488              CheckPrimitiveType(preceed_cnode, prim::kPrimMulFusion) ||
1489              CheckPrimitiveType(preceed_cnode, prim::kPrimAddFusion)) {
1490     return PreposeWithArithmetic(graph, slice_cnode, preceed_cnode);
1491   } else if (CheckPrimitiveType(preceed_cnode, prim::kPrimSliceFusion)) {
1492     return MergeSequentialSlice(graph, slice_cnode, preceed_cnode);
1493   }
1494   return false;
1495 }
1496 
Run(const FuncGraphPtr & graph)1497 bool SlicePreposePass::Run(const FuncGraphPtr &graph) {
1498   if (fmk_type != converter::kFmkTypeTf && fmk_type != converter::kFmkTypeTflite) {
1499     MS_LOG(INFO) << "The framework type of model should be tf/tflite.";
1500     return false;
1501   }
1502   MS_ASSERT(graph != nullptr);
1503   bool changed = false;
1504   while (true) {
1505     bool this_time_changed = false;
1506     auto node_list = TopoSort(graph->get_return());
1507     for (auto &node : node_list) {
1508       if (node->func_graph() != graph || !utils::isa<CNodePtr>(node) ||
1509           !CheckPrimitiveType(node, prim::kPrimSliceFusion)) {
1510         continue;
1511       }
1512       auto slice_cnode = node->cast<CNodePtr>();
1513       MS_ASSERT(slice_cnode != nullptr);
1514       // only support begin and size is const tensor.
1515       if (!CheckIsAllInputsParam(slice_cnode) || GetSlice(slice_cnode)) {
1516         continue;
1517       }
1518       auto preceed_node = slice_cnode->input(1);
1519       if (preceed_node == nullptr) {
1520         MS_LOG(ERROR) << "proceed node is nullptr";
1521         continue;
1522       }
1523       auto output_tensor_num = GetOutputTensorNum(preceed_node);
1524       if (output_tensor_num > 1) {
1525         continue;
1526       }
1527       auto output_node_list = GetRealNodeUsedList(graph, utils::cast<AnfNodePtr>(preceed_node));
1528       if (output_node_list->size() > 1) {  // referenced by multi nodes
1529         if (SiblingsAreSameSlice(output_node_list) && MergeParallelSlice(graph, output_node_list)) {
1530           this_time_changed = true;
1531           break;
1532         }
1533         continue;
1534       } else {
1535         if (utils::isa<ParameterPtr>(preceed_node)) {
1536           /*
1537            * if preceed_node is parameter without default param, it's input placeholder, so we can't prepose
1538            * if preceed_node is parameter with default param, constant_folding will process it
1539            */
1540           continue;
1541         }
1542         auto preceed_cnode = preceed_node->cast<CNodePtr>();
1543         if (preceed_cnode == nullptr) {
1544           MS_LOG(ERROR) << "preceed_cnode is nullptr";
1545           continue;
1546         }
1547         if (DoPrepose(graph, slice_cnode, preceed_cnode)) {
1548           this_time_changed = true;
1549           break;
1550         }
1551       }
1552     }
1553     if (this_time_changed) {
1554       changed = true;
1555     } else {
1556       break;
1557     }
1558   }
1559   return changed;
1560 }
1561 }  // namespace mindspore::opt
1562