• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
18 
19 #include <string>
20 #include <memory>
21 #include <vector>
22 
23 #include "backend/optimizer/graph_kernel/model/lite_graph.h"
24 #include "backend/optimizer/graph_kernel/model/node.h"
25 
26 namespace mindspore {
27 namespace opt {
28 namespace expanders {
// Node list type re-exported from graphkernel for use by the expanders.
using graphkernel::NodePtrList;
// Per-tensor descriptions of an op's inputs/outputs (element type graphkernel::NodeBase).
using BaseInfoList = std::vector<graphkernel::NodeBase>;
// Forward declaration: OpExpander owns validators before Validator is defined below.
class Validator;
32 
33 class OpExpander {
34  public:
35   graphkernel::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs,
36                                 const graphkernel::DAttrs &attrs, const std::string &processor);
37   virtual ~OpExpander() = default;
38 
39  protected:
CheckInputs()40   virtual bool CheckInputs() { return true; }
41   virtual NodePtrList Expand() = 0;
42   bool CheckOutputs();
43 
44   graphkernel::LiteGraph::GraphBuilder gb;
45   std::string op_;
46   BaseInfoList inputs_info_;
47   BaseInfoList outputs_info_;
48   graphkernel::DAttrs attrs_;
49   std::string processor_;
50   std::vector<std::unique_ptr<Validator>> validators_;
51 
52   friend class OpExpanderFactory;
53   friend class CheckAllFormatsSame;
54   friend class CheckAttr;
55   friend class SupportFormat;
56 };
57 
58 class Validator {
59  public:
60   virtual bool Check(const OpExpander &e) = 0;
61 };
62 
63 class CheckAllFormatsSame : public Validator {
64  public:
Check(const OpExpander & e)65   bool Check(const OpExpander &e) override {
66     if (e.inputs_info_.empty()) return true;
67     const auto &fmt_0 = e.inputs_info_[0].format;
68     for (size_t i = 1; i < e.inputs_info_.size(); i++) {
69       if (e.inputs_info_[i].format != fmt_0) {
70         MS_LOG(INFO) << "Unmatched format for op " << e.op_;
71         return false;
72       }
73     }
74     return true;
75   }
76 };
77 
78 class CheckAttr : public Validator {
79  public:
CheckAttr(std::initializer_list<std::string> l)80   CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {}
81   ~CheckAttr() = default;
Check(const OpExpander & e)82   bool Check(const OpExpander &e) override {
83     for (auto &a : attrs_) {
84       if (e.attrs_.count(a) == 0) {
85         MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_;
86         return false;
87       }
88     }
89     return true;
90   }
91 
92  private:
93   std::vector<std::string> attrs_;
94 };
95 
96 class SupportFormat : public Validator {
97  public:
AddFormat(std::initializer_list<std::string> l)98   void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); }
Check(const OpExpander & e)99   bool Check(const OpExpander &e) override {
100     for (auto &formats : formats_) {
101       if (formats.size() != e.inputs_info_.size()) {
102         continue;
103       }
104       bool match = true;
105       for (size_t i = 0; i < formats.size(); i++) {
106         if (e.inputs_info_[i].format != formats[i]) {
107           match = false;
108           break;
109         }
110       }
111       if (match) {
112         return true;
113       }
114     }
115     MS_LOG(INFO) << "unsupported format for op " << e.op_;
116     return false;
117   }
118 
119  private:
120   std::vector<std::vector<std::string>> formats_;
121 };
122 
// Helpers shared by expanders; both are defined out of line.
// NOTE(review): presumably converts an axis attribute value (scalar or
// sequence) into a list of axes — confirm against the implementation.
std::vector<int64_t> GetAxisList(const ValuePtr &value);
// NOTE(review): presumably returns `shape` with size-1 dimensions inserted at
// the positions in `axis` (ExpandDims semantics) — confirm against the implementation.
ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
125 }  // namespace expanders
126 }  // namespace opt
127 }  // namespace mindspore
128 #endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_
129