1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ 18 #define MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ 19 20 #include <memory> 21 #include <set> 22 #include <string> 23 #include <unordered_map> 24 #include <vector> 25 #include "backend/optimizer/common/pass.h" 26 #include "tools/converter/converter_flags.h" 27 #include "tools/optimizer/common/format_utils.h" 28 #include "tools/optimizer/graph/infershape_pass.h" 29 #include "ops/fusion/conv2d_fusion.h" 30 #include "ops/fusion/conv2d_transpose_fusion.h" 31 #include "ops/adam.h" 32 #include "ops/sgd.h" 33 #include "ops/apply_momentum.h" 34 35 using mindspore::converter::FmkType; 36 namespace mindspore { 37 namespace opt { 38 class ToFormatBase : public Pass { 39 public: 40 explicit ToFormatBase(FmkType fmk_type = converter::kFmkTypeMs, bool train_flag = false, 41 std::string pass_name = "ToFormatBase") Pass(pass_name)42 : Pass(pass_name), fmk_type_(fmk_type), train_flag_(train_flag) {} 43 ~ToFormatBase() override = default; 44 bool Run(const FuncGraphPtr &func_graph) override; IsConvFamilyNode(const AnfNodePtr & node)45 static bool IsConvFamilyNode(const AnfNodePtr &node) { 46 return opt::CheckPrimitiveType(node, prim::kPrimConv2DFusion) || 47 opt::CheckPrimitiveType(node, opt::kPrimConv2DBackpropInputFusion) || 48 opt::CheckPrimitiveType(node, prim::kPrimConv2dTransposeFusion); 49 } IsOptimizerNode(const AnfNodePtr & node)50 static bool IsOptimizerNode(const AnfNodePtr &node) { 51 return opt::CheckPrimitiveType(node, prim::kPrimApplyMomentum) || opt::CheckPrimitiveType(node, prim::kPrimSGD) || 52 opt::CheckPrimitiveType(node, prim::kPrimAdam); 53 } IsWeightNodeSensitive(const AnfNodePtr & node)54 static bool IsWeightNodeSensitive(const AnfNodePtr &node) { return IsConvFamilyNode(node) || IsOptimizerNode(node); } 55 56 private: 57 bool BasicProcess(const FuncGraphPtr &func_graph, bool main_graph); 58 STATUS HandleGraphInput(const FuncGraphPtr &func_graph); 59 STATUS HandleGraphNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode); 60 STATUS InsertPostTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); 61 STATUS InsertPreTransNode(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm); 62 STATUS GenNewInput(const FuncGraphPtr &func_graph, const CNodePtr &cnode, const std::vector<int> &perm, bool before, 63 size_t index = 0); 64 STATUS ModifyCNode(const CNodePtr &cnode); 65 STATUS ConvWeightFormatTrans(const FuncGraphPtr &graph, std::set<AnfNodePtr> *has_visited); 66 67 protected: 68 virtual STATUS GetTransNodeFormatType(const CNodePtr &cnode, opt::TransTypePair *trans_info) = 0; SetSensitiveOps()69 virtual void SetSensitiveOps() { sensitive_ops_ = opt::GetNHWCOpMap(); } 70 virtual bool DecideWhetherHandleGraphInput(const FuncGraphPtr &func_graph, const ParameterPtr &input, 71 const ShapeVector &shape); DecideWhetherInferShapeForNewNode()72 virtual bool DecideWhetherInferShapeForNewNode() { return true; } 73 virtual STATUS DecideConvWeightSrcAndDstFormat(const CNodePtr &cnode, schema::Format *src_format, 74 schema::Format *dst_format) = 0; 75 FmkType fmk_type_{converter::kFmkTypeMs}; 76 bool train_flag_{false}; 77 mindspore::Format format_{mindspore::NHWC}; 78 std::shared_ptr<NodeInferShape> node_infer_shape_{nullptr}; 79 std::unordered_map<std::string, std::vector<size_t>> sensitive_ops_; 80 FuncGraphManagerPtr manager_; 81 }; 82 } // namespace opt 83 } // namespace mindspore 84 85 #endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FORMAT_TO_FORMAT_BASE_H_ 86