/**
 * Copyright 2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_

#include <memory>
#include <string>
#include "include/backend/optimizer/optimizer.h"

namespace mindspore::graphkernel {
/**
 * @brief Fuse IsFinite and its users into a single FloatStatus node.
 * @example
 *   main_graph {
 *     %1 = IsFinite(%0)
 *     %2 = ReduceAll(%1)
 *     %3 = Cast(%2)
 *     %4 = Sub(1, %3)
 *     return %4
 *   }
 *  or
 *   main_graph {
 *     %1 = IsFinite(%0)
 *     %2 = ReduceAll(%1)
 *     %3 = Cast(%2)
 *     %4 = Sub(1, %3)
 *     %5 = Reshape(%4, (1,))
 *     return %5
 *   }
 *   ---------->
 *   main_graph {
 *     %1 = FloatStatus(%0)
 *     return %1
 *   }
 */
49 class FloatStatusBaseFusion : public opt::PatternProcessPass {
50  public:
51   explicit FloatStatusBaseFusion(const std::string &pass_name, bool multigraph = true)
PatternProcessPass(pass_name,multigraph)52       : PatternProcessPass(pass_name, multigraph),
53         input_{std::make_shared<Var>()},
54         axis_{std::make_shared<Var>()},
55         keep_dims_{std::make_shared<Var>()},
56         type_{std::make_shared<Var>()},
57         s_{std::make_shared<Var>()} {}
58   ~FloatStatusBaseFusion() override = default;
59   const BaseRef DefinePattern() const override;
60   const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &node, const EquivPtr &) const override;
61 
62  protected:
63   VarPtr input_;
64   VarPtr axis_;
65   VarPtr keep_dims_;
66   VarPtr type_;
67   VarPtr s_;
68 };
70 class FloatStatusReshapeFusion : public FloatStatusBaseFusion {
71  public:
72   explicit FloatStatusReshapeFusion(const std::string &pass_name, bool multigraph = true)
FloatStatusBaseFusion(pass_name,multigraph)73       : FloatStatusBaseFusion(pass_name, multigraph), to_shape_{std::make_shared<Var>()} {}
74   ~FloatStatusReshapeFusion() override = default;
75   const BaseRef DefinePattern() const override;
76 
77  private:
78   VarPtr to_shape_;
79 };
81 class CastFloatStatusBaseFusion : public FloatStatusBaseFusion {
82  public:
83   explicit CastFloatStatusBaseFusion(const std::string &pass_name, bool multigraph = true)
FloatStatusBaseFusion(pass_name,multigraph)84       : FloatStatusBaseFusion(pass_name, multigraph), type_fp32_{std::make_shared<Var>()} {}
85   ~CastFloatStatusBaseFusion() override = default;
86   const BaseRef DefinePattern() const override;
87 
88  private:
89   VarPtr type_fp32_;
90 };
92 class CastFloatStatusReshapeFusion : public CastFloatStatusBaseFusion {
93  public:
94   explicit CastFloatStatusReshapeFusion(const std::string &pass_name, bool multigraph = true)
CastFloatStatusBaseFusion(pass_name,multigraph)95       : CastFloatStatusBaseFusion(pass_name, multigraph), to_shape_{std::make_shared<Var>()} {}
96   ~CastFloatStatusReshapeFusion() override = default;
97   const BaseRef DefinePattern() const override;
98 
99  private:
100   VarPtr to_shape_;
101 };
103 class FloatStatusFusion : public opt::Pass {
104  public:
FloatStatusFusion()105   FloatStatusFusion() : Pass("floatstatus_fusion") {
106     cast_floatstatus_reshape_ = std::make_shared<CastFloatStatusReshapeFusion>("cast_floatstatus_reshape_fusion");
107     cast_floatstatus_base_ = std::make_shared<CastFloatStatusBaseFusion>("cast_floatstatus_base_fusion");
108     floatstatus_reshape_ = std::make_shared<FloatStatusReshapeFusion>("floatstatus_reshape_fusion");
109     floatstatus_base_ = std::make_shared<FloatStatusBaseFusion>("floatstatus_base_fusion");
110   }
111   ~FloatStatusFusion() override = default;
Run(const FuncGraphPtr & func_graph)112   bool Run(const FuncGraphPtr &func_graph) override {
113     cast_floatstatus_reshape_->Run(func_graph);
114     cast_floatstatus_base_->Run(func_graph);
115     floatstatus_reshape_->Run(func_graph);
116     floatstatus_base_->Run(func_graph);
117     return true;
118   }
119 
120  private:
121   opt::PassPtr cast_floatstatus_reshape_;
122   opt::PassPtr cast_floatstatus_base_;
123   opt::PassPtr floatstatus_reshape_;
124   opt::PassPtr floatstatus_base_;
125 };
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_COMMON_GRAPH_KERNEL_FLOATSTATUS_FUSION__H_