/**
 * Copyright 2021-2024 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_

#include <string>
#include <vector>
#include <memory>
#include <functional>
#include <set>
#include "include/backend/optimizer/pass.h"
#include "ir/func_graph.h"
#include "backend/common/graph_kernel/model/lite_graph.h"

namespace mindspore::graphkernel {
using inner::NodePtr;
// kFlexFormat: the op is format-flexible and accepts either format.
enum class FormatType { kFlexFormat, kFormatA, kFormatB };
// kTransAB: transform from format A to B; kTransBA: the reverse.
enum class TransOpType { kTransAB, kTransBA };

/**
 * @brief Handle for the transform op interfaces, which are called in the Mutator.
 * @note Subclasses should NOT save the NodePtr in the constructor.
 */
class TransformOp {
 public:
  explicit TransformOp(const NodePtr &node);
  virtual ~TransformOp() = default;

  // get the output format of `node`
  virtual std::string GetFormat(const NodePtr &node) const;
  // check whether the node is a TransAB or TransBA op
  virtual bool IsTransformOp(const NodePtr &node);
  // check whether a new transform op needs to be inserted
  virtual bool NeedInsert(const NodePtr &input_node) const;
  // generate a new transform op of the given trans_type
  virtual NodePtr GenTransformOp(const NodePtr &input_node, TransOpType trans_type) = 0;
  // check whether the input format is kFormatA or kFormatB
  virtual FormatType GetFormatType(const std::string &fmt);
  // set inputs for this transform op
  virtual void SetInput(const NodePtr &node, const NodePtr &input_node);
  // get the hash value for this transform op
  size_t Hash() const;

  friend std::ostream &operator<<(std::ostream &os, const TransformOp &t) {
    return os << t.op_ << "(" << t.format_a_ << " <-> " << t.format_b_ << ")";
  }

 protected:
  std::string op_;
  std::string format_a_;
  std::string format_b_;
};
using TransformOpPtr = std::shared_ptr<TransformOp>;

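/**
 * @example A minimal subclass sketch (the class name `LayoutTransHandle` and
 *          its body are hypothetical, shown only to illustrate the interface;
 *          concrete handles live outside this header):
 *   class LayoutTransHandle : public TransformOp {
 *    public:
 *     using TransformOp::TransformOp;
 *     // build a node that converts A->B (kTransAB) or B->A (kTransBA)
 *     NodePtr GenTransformOp(const NodePtr &input_node, TransOpType trans_type) override;
 *   };
 */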
class TransformOpCreator {
 public:
  using HandleCreateFunc = std::function<TransformOpPtr(const NodePtr &)>;
  TransformOpCreator(const std::string &name, const HandleCreateFunc &func) : op_name_(name), func_(func) {}
  ~TransformOpCreator() = default;

  bool IsTransOp(const NodePtr &node) const;
  std::string Name() const { return op_name_; }
  TransformOpPtr CreateHandle(const NodePtr &node) const { return func_(node); }

 private:
  std::string op_name_;
  HandleCreateFunc func_;
};

#define TRANS_OP_CREATOR(op_name, hd_cls)                                         \
  TransformOpCreator(op_name, [](const NodePtr &node) {                           \
    return std::static_pointer_cast<TransformOp>(std::make_shared<hd_cls>(node)); \
  })

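/**
 * @example The macro binds an op name to a TransformOp subclass, e.g. (the
 *          `TransposeHandle` class is a hypothetical illustration):
 *   supported_ops_.push_back(TRANS_OP_CREATOR("Transpose", TransposeHandle));
 */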
/**
 * @brief Eliminate unnecessary transform ops when the surrounding operators
 *        are format-flexible.
 * @example
 *   %1 = Transpose(p0) // NCHW to NHWC
 *   %2 = Transpose(p1) // NCHW to NHWC
 *   %3 = Add(%1, %2)
 *   return %3
 *  -->
 *   %1 = Add(p0, p1)
 *   %2 = Transpose(%1) // NCHW to NHWC
 *   return %2
 * @example
 *   %1 = Transpose(p0) // NCHW to NHWC
 *   %2 = Transpose(p1) // NCHW to NHWC
 *   %3 = Add(%1, %2)
 *   %4 = Transpose(%3) // NHWC to NCHW
 *   return %4
 *  -->
 *   %1 = Add(p0, p1)
 *   return %1
 * ============================================================================
 * See https://gitee.com/mindspore/mindspore/issues/I3UW79 for more details.
 */
class TransformOpOptimizer : public opt::Pass {
 public:
  TransformOpOptimizer() : Pass("transform_op_optimizer") { Init(); }
  ~TransformOpOptimizer() = default;
  bool Run(const FuncGraphPtr &func_graph) override;

 protected:
  // create a handle for each supported transform op found in the litegraph
  std::vector<TransformOpPtr> CreateOpHandles(const inner::LiteGraphPtr &litegraph) const;
  // try to eliminate the transform ops managed by `op_handle` in the litegraph
  bool Process(const inner::LiteGraphPtr &litegraph, const TransformOpPtr &op_handle) const;
  // re-infer the info of the nodes whose inputs may have changed
  void ReInfer(const inner::LiteGraphPtr &litegraph, const std::set<NodePtr> &nodes_may_change) const;
  // register the creators of the supported transform ops
  void Init();
  std::vector<TransformOpCreator> supported_ops_;
};
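/**
 * @example A usage sketch (hypothetical driver code; in practice the pass is
 *          scheduled by the graph kernel pass manager):
 *   auto pass = std::make_shared<TransformOpOptimizer>();
 *   bool changed = pass->Run(func_graph);  // true if any transform op was eliminated
 */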
}  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_TRANSFORM_OP_OPTIMIZER_H_