1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ 18 19 #include <vector> 20 #include <memory> 21 #include <string> 22 #include "backend/optimizer/common/optimizer.h" 23 #include "utils/utils.h" 24 25 namespace mindspore { 26 namespace opt { 27 constexpr size_t kAdamApplyOneInputVarNum = 5; 28 constexpr size_t kAdamApplyOneMulInputVarNum = 4; 29 30 class AdamApplyOneFusion : public PatternProcessPass { 31 public: 32 explicit AdamApplyOneFusion(const std::string &name = "adam_apply_one_fusion", bool multigraph = true) PatternProcessPass(name,multigraph)33 : PatternProcessPass(name, multigraph) { 34 for (size_t i = 0; i < kAdamApplyOneInputVarNum; ++i) { 35 input_vars_.push_back(std::make_shared<Var>()); 36 } 37 for (size_t i = 0; i < kAdamApplyOneMulInputVarNum; ++i) { 38 mul_x_input_vars_.push_back(std::make_shared<Var>()); 39 } 40 add2_y_ = std::make_shared<Var>(); 41 add0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 42 add1_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimAdd->name())); 43 sub0_var_ = std::make_shared<Var>(std::make_shared<Primitive>(prim::kPrimSub->name())); 44 } 45 46 ~AdamApplyOneFusion() override = default; 47 const BaseRef DefinePattern() const override; 48 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; 49 50 protected: 51 AnfNodePtr CreateAdamApplyOneNode(const FuncGraphPtr &func_graph, const EquivPtr &equiv, 52 const AnfNodePtr &final_node) const; 53 std::vector<VarPtr> input_vars_; 54 std::vector<VarPtr> mul_x_input_vars_; 55 VarPtr add2_y_; 56 VarPtr add0_var_; 57 VarPtr add1_var_; 58 VarPtr sub0_var_; 59 }; 60 61 class AdamApplyOneCond1Fusion : public AdamApplyOneFusion { 62 public: 63 explicit AdamApplyOneCond1Fusion(bool multigraph = true) 64 : AdamApplyOneFusion("adam_apply_one_cond1_fusion", multigraph) {} 65 66 ~AdamApplyOneCond1Fusion() override = default; 67 const BaseRef DefinePattern() const override; 68 }; 69 70 class AdamApplyOneCond2Fusion : public AdamApplyOneFusion { 71 public: 72 explicit AdamApplyOneCond2Fusion(bool multigraph = true) 73 : AdamApplyOneFusion("adam_apply_one_cond2_fusion", multigraph) {} 74 75 ~AdamApplyOneCond2Fusion() override = default; 76 const BaseRef DefinePattern() const override; 77 }; 78 79 class AdamApplyOneCond3Fusion : public AdamApplyOneFusion { 80 public: 81 explicit AdamApplyOneCond3Fusion(bool multigraph = true) 82 : AdamApplyOneFusion("adam_apply_one_cond3_fusion", multigraph) {} 83 84 ~AdamApplyOneCond3Fusion() override = default; 85 const BaseRef DefinePattern() const override; 86 }; 87 88 class AdamApplyOneCond4Fusion : public AdamApplyOneFusion { 89 public: 90 explicit AdamApplyOneCond4Fusion(bool multigraph = true) 91 : AdamApplyOneFusion("adam_apply_one_cond4_fusion", multigraph) {} 92 93 ~AdamApplyOneCond4Fusion() override = default; 94 const BaseRef DefinePattern() const override; 95 }; 96 97 class AdamApplyOneAssignFusion : public AdamApplyOneFusion { 98 public: 99 explicit AdamApplyOneAssignFusion(bool multigraph = true) 100 : AdamApplyOneFusion("adam_apply_one_assign_fusion", multigraph) {} 101 102 ~AdamApplyOneAssignFusion() override = default; 103 const BaseRef DefinePattern() const override; 104 }; 105 106 class AdamApplyOneAssignCond1Fusion : public AdamApplyOneFusion { 107 public: 108 explicit AdamApplyOneAssignCond1Fusion(bool multigraph = true) 109 : AdamApplyOneFusion("adam_apply_one_assign_cond1_fusion", multigraph) {} 110 111 ~AdamApplyOneAssignCond1Fusion() override = default; 112 const BaseRef DefinePattern() const override; 113 }; 114 115 class AdamApplyOneAssignCond2Fusion : public AdamApplyOneFusion { 116 public: 117 explicit AdamApplyOneAssignCond2Fusion(bool multigraph = true) 118 : AdamApplyOneFusion("adam_apply_one_assign_cond2_fusion", multigraph) {} 119 120 ~AdamApplyOneAssignCond2Fusion() override = default; 121 const BaseRef DefinePattern() const override; 122 }; 123 124 class AdamApplyOneAssignCond3Fusion : public AdamApplyOneFusion { 125 public: 126 explicit AdamApplyOneAssignCond3Fusion(bool multigraph = true) 127 : AdamApplyOneFusion("adam_apply_one_assign_cond3_fusion", multigraph) {} 128 129 ~AdamApplyOneAssignCond3Fusion() override = default; 130 const BaseRef DefinePattern() const override; 131 }; 132 133 class AdamApplyOneAssignCond4Fusion : public AdamApplyOneFusion { 134 public: 135 explicit AdamApplyOneAssignCond4Fusion(bool multigraph = true) 136 : AdamApplyOneFusion("adam_apply_one_assign_cond4_fusion", multigraph) {} 137 138 ~AdamApplyOneAssignCond4Fusion() override = default; 139 const BaseRef DefinePattern() const override; 140 }; 141 } // namespace opt 142 } // namespace mindspore 143 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_IR_FUSION_ADAM_APPLY_ONE_FUSION_H_ 144