1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_ 17 #define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_ 18 19 #include <memory> 20 #include <string> 21 #include "include/backend/optimizer/optimizer.h" 22 23 namespace mindspore::graphkernel { 24 /** 25 * @brief Fuse IsFinite and its user to FloatStatus 26 * @example 27 * main_graph { 28 * %1 = IsFinite(%0) 29 * %2 = ReduceAll(%1) 30 * %3 = Cast(%2) 31 * %4 = Sub(1, %3) 32 * return %4 33 * } 34 * or 35 * main_graph { 36 * %1 = IsFinite(%0) 37 * %2 = ReduceAll(%1) 38 * %3 = Cast(%2) 39 * %4 = Sub(1, %3) 40 * %5 = Reshape(%4, (1,)) 41 * return %5 42 * } 43 * ----------> 44 * main_graph { 45 * %1 = FloatStatus(%0) 46 * return %1 47 * } 48 */ 49 class FloatStatusBaseFusion : public opt::PatternProcessPass { 50 public: 51 explicit FloatStatusBaseFusion(const std::string &pass_name, bool multigraph = true) PatternProcessPass(pass_name,multigraph)52 : PatternProcessPass(pass_name, multigraph), 53 input_{std::make_shared<Var>()}, 54 axis_{std::make_shared<Var>()}, 55 keep_dims_{std::make_shared<Var>()}, 56 type_{std::make_shared<Var>()}, 57 s_{std::make_shared<Var>()} {} 58 ~FloatStatusBaseFusion() override = default; 59 const BaseRef DefinePattern() const override; 60 const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override; 61 62 protected: 63 VarPtr input_; 64 VarPtr axis_; 65 VarPtr keep_dims_; 66 VarPtr type_; 67 VarPtr s_; 68 }; 69 70 class FloatStatusReshapeFusion : public FloatStatusBaseFusion { 71 public: 72 explicit FloatStatusReshapeFusion(const std::string &pass_name, bool multigraph = true) FloatStatusBaseFusion(pass_name,multigraph)73 : FloatStatusBaseFusion(pass_name, multigraph), to_shape_{std::make_shared<Var>()} {} 74 ~FloatStatusReshapeFusion() override = default; 75 const BaseRef DefinePattern() const override; 76 77 private: 78 VarPtr to_shape_; 79 }; 80 81 class CastFloatStatusBaseFusion : public FloatStatusBaseFusion { 82 public: 83 explicit CastFloatStatusBaseFusion(const std::string &pass_name, bool multigraph = true) FloatStatusBaseFusion(pass_name,multigraph)84 : FloatStatusBaseFusion(pass_name, multigraph), type_fp32_{std::make_shared<Var>()} {} 85 ~CastFloatStatusBaseFusion() override = default; 86 const BaseRef DefinePattern() const override; 87 88 private: 89 VarPtr type_fp32_; 90 }; 91 92 class CastFloatStatusReshapeFusion : public CastFloatStatusBaseFusion { 93 public: 94 explicit CastFloatStatusReshapeFusion(const std::string &pass_name, bool multigraph = true) CastFloatStatusBaseFusion(pass_name,multigraph)95 : CastFloatStatusBaseFusion(pass_name, multigraph), to_shape_{std::make_shared<Var>()} {} 96 ~CastFloatStatusReshapeFusion() override = default; 97 const BaseRef DefinePattern() const override; 98 99 private: 100 VarPtr to_shape_; 101 }; 102 103 class FloatStatusFusion : public opt::Pass { 104 public: FloatStatusFusion()105 FloatStatusFusion() : Pass("floatstatus_fusion") { 106 cast_floatstatus_reshape_ = std::make_shared<CastFloatStatusReshapeFusion>("cast_floatstatus_reshape_fusion"); 107 cast_floatstatus_base_ = std::make_shared<CastFloatStatusBaseFusion>("cast_floatstatus_base_fusion"); 108 floatstatus_reshape_ = std::make_shared<FloatStatusReshapeFusion>("floatstatus_reshape_fusion"); 109 floatstatus_base_ = std::make_shared<FloatStatusBaseFusion>("floatstatus_base_fusion"); 110 } 111 ~FloatStatusFusion() override = default; Run(const FuncGraphPtr & func_graph)112 bool Run(const FuncGraphPtr &func_graph) override { 113 cast_floatstatus_reshape_->Run(func_graph); 114 cast_floatstatus_base_->Run(func_graph); 115 floatstatus_reshape_->Run(func_graph); 116 floatstatus_base_->Run(func_graph); 117 return true; 118 } 119 120 private: 121 opt::PassPtr cast_floatstatus_reshape_; 122 opt::PassPtr cast_floatstatus_base_; 123 opt::PassPtr floatstatus_reshape_; 124 opt::PassPtr floatstatus_base_; 125 }; 126 } // namespace mindspore::graphkernel 127 #endif // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_ 128