1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_ 17 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <utility> 22 #include <string> 23 #include "include/backend/optimizer/pass.h" 24 #include "include/errorcode.h" 25 #include "mindspore/core/ir/manager.h" 26 #include "include/registry/converter_context.h" 27 28 using mindspore::converter::FmkType; 29 namespace mindspore::opt { 30 using lite::RET_ERROR; 31 using lite::RET_OK; 32 using lite::STATUS; 33 using TransactionPtr = std::shared_ptr<mindspore::FuncGraphTransaction>; 34 using NodeUsedListPtr = std::shared_ptr<std::vector<std::pair<AnfNodePtr, int>>>; 35 class SlicePreposePass : public Pass { 36 public: SlicePreposePass()37 SlicePreposePass() : Pass("slice_prepose_pass") {} 38 ~SlicePreposePass() override = default; 39 bool Run(const FuncGraphPtr &graph) override; SetFmkType(FmkType fmkType)40 void SetFmkType(FmkType fmkType) { this->fmk_type = fmkType; } 41 42 private: 43 static void ClearCNodeAbstractValue(const CNodePtr &cnode); 44 static STATUS SwapSliceWithPreceed(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 45 const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr = nullptr); 46 static ValueNodePtr CreateSliceValueNode(const std::vector<int64_t> &axes); 47 static ValueNodePtr CopySliceValueNode(const CNodePtr &slice_cnode); 48 static CNodePtr InsertSlice(const FuncGraphPtr &graph, const std::vector<AnfNodePtr> &inputs, 49 const CNodePtr &preceed_cnode, int index, const TransactionPtr &tr); 50 static STATUS VerifySliceAttrs(const CNodePtr &slice_cnode, int dim = -1); 51 static STATUS SliceParamDeBroadcast(const CNodePtr &slice_cnode, const std::vector<int64_t> &ref_shape, 52 std::vector<int64_t> *axes, std::vector<int> *begin, std::vector<int> *size); 53 static CNodePtr CreateReshapeCNode(const FuncGraphPtr &graph, const std::vector<int64_t> &shape, 54 const AbstractBasePtr &abstract, const CNodePtr &preceed_cnode); 55 static bool SiblingsAreSameSlice(const NodeUsedListPtr &output_node_list, const std::vector<int64_t> &ref_shape = {}); 56 static int64_t GetReshapeAbnormalAxeIn(const std::vector<int64_t> &shape_in, const std::vector<int64_t> &shape_out, 57 std::vector<int64_t> *mapped_axe); 58 static int64_t GetReshapeAbnormalIndexOut(const CNodePtr &slice_cnode, const std::vector<int64_t> &mapped_axe, 59 const std::vector<int64_t> &shape_out, std::vector<int64_t> *shape_out_copy, 60 bool *is_normal_mode, bool *support_abnormal_mode); 61 static bool PreposeWithNormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 62 const CNodePtr &reshape_cnode, const std::vector<int64_t> &shape_in, 63 const std::vector<int64_t> &shape_out_copy, 64 const std::vector<int64_t> &mapped_axe); 65 static CNodePtr CreateSlice1ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 66 const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in, 67 int64_t abnormal_axe_in, int64_t count_sliced_axe_in, 68 bool slice_at_front); 69 static CNodePtr CreateSlice2ForReshapePrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 70 const CNodePtr &new_reshape1_cnode, 71 const std::vector<int64_t> &new_shape1, int64_t abnormal_axe_in, 72 int64_t count_sliced2, bool slice_at_front); 73 static bool PreposeWithAbnormalReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 74 const CNodePtr &matmul_cnode, const std::vector<int64_t> &shape_in, 75 const std::vector<int64_t> &shape_out, int64_t abnormal_axe_in, 76 int64_t abnormal_index_out); 77 static bool GetArithmeticInputInfo(const CNodePtr &arithmetic_cnode, std::vector<AnfNodePtr> *inputs, 78 std::vector<std::vector<int64_t>> *shapes, std::vector<bool> *is_default_params); 79 80 static bool DoPrepose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &preceed_cnode); 81 82 static bool PreposeWithSoftmax(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &softmax_cnode); 83 static bool PreposeWithReshape(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &reshape_cnode); 84 static bool PreposeWithMatmul(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, const CNodePtr &matmul_cnode); 85 static bool PreposeWithFullConnection(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 86 const CNodePtr &fc_cnode); 87 static bool PreposeWithTranspose(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 88 const CNodePtr &transpose_cnode); 89 static bool PreposeWithArithmetic(const FuncGraphPtr &graph, const CNodePtr &slice_cnode, 90 const CNodePtr &arithmetic_cnode); 91 static bool MergeSequentialSlice(const FuncGraphPtr &graph, const CNodePtr &slice1_cnode, 92 const CNodePtr &slice2_cnode); 93 static bool MergeParallelSlice(const FuncGraphPtr &graph, const NodeUsedListPtr &slices); 94 95 private: 96 FmkType fmk_type = converter::kFmkTypeOnnx; 97 }; 98 } // namespace mindspore::opt 99 100 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_GRAPH_SLICE_PREPOSE_PASS_H_ 101