• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_
17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_
18 
19 #include <string>
20 #include <memory>
21 #include <vector>
22 #include <utility>
23 
24 #include "backend/common/graph_kernel/model/lite_graph.h"
25 #include "backend/common/graph_kernel/model/node.h"
26 #include "backend/common/graph_kernel/model/graph_builder.h"
27 
28 namespace mindspore::graphkernel::expanders {
29 using inner::NodePtr;
30 using inner::NodePtrList;
31 using BaseInfoList = std::vector<inner::NodeBase>;
32 class Validator;
33 
34 class OpDesc {
35  public:
36   inner::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs,
37                           const std::string &processor);
Name()38   const std::string &Name() const { return name_; }
InputsInfo()39   const BaseInfoList &InputsInfo() const { return inputs_info_; }
OutputsInfo()40   const BaseInfoList &OutputsInfo() const { return outputs_info_; }
Attrs()41   const inner::DAttrs &Attrs() const { return attrs_; }
Processor()42   const std::string &Processor() const { return processor_; }
43   virtual ~OpDesc() = default;
44 
45  protected:
Init()46   virtual void Init() {}
CheckInputs()47   virtual bool CheckInputs() { return true; }
48   virtual NodePtrList Expand(const NodePtrList &inputs) = 0;
49   bool CheckOutputs();
50 
51   inner::GraphBuilder gb;
52   std::string name_;
53   BaseInfoList inputs_info_;
54   BaseInfoList outputs_info_;
55   inner::DAttrs attrs_;
56   std::string processor_;
57   std::vector<std::unique_ptr<Validator>> validators_;
58 
59   friend class OpDescFactory;
60 };
61 
62 class Validator {
63  public:
64   virtual bool Check(const OpDesc &e) = 0;
65   virtual ~Validator() = default;
66 };
67 
68 class CheckAllFormatsSame : public Validator {
69  public:
Check(const OpDesc & e)70   bool Check(const OpDesc &e) override {
71     const auto &inputs_info = e.InputsInfo();
72     if (inputs_info.empty()) {
73       return true;
74     }
75     const auto &fmt_0 = inputs_info[0].format;
76     for (size_t i = 1; i < inputs_info.size(); i++) {
77       if (inputs_info[i].format != fmt_0) {
78         MS_LOG(INFO) << "Unmatched format for op " << e.Name();
79         return false;
80       }
81     }
82     return true;
83   }
84 };
85 
86 class CheckAttr : public Validator {
87  public:
CheckAttr(std::initializer_list<std::string> l)88   CheckAttr(std::initializer_list<std::string> l) : attrs_(std::move(l)) {}
89   virtual ~CheckAttr() = default;
Check(const OpDesc & e)90   bool Check(const OpDesc &e) override {
91     for (auto &a : attrs_) {
92       if (e.Attrs().count(a) == 0) {
93         MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.Name();
94         return false;
95       }
96     }
97     return true;
98   }
99 
100  private:
101   std::vector<std::string> attrs_;
102 };
103 
104 class SupportFormat : public Validator {
105  public:
AddFormat(std::initializer_list<std::string> l)106   void AddFormat(std::initializer_list<std::string> l) { (void)formats_.emplace_back(l); }
Check(const OpDesc & e)107   bool Check(const OpDesc &e) override {
108     for (auto &formats : formats_) {
109       if (formats.size() != e.InputsInfo().size()) {
110         continue;
111       }
112       bool match = true;
113       for (size_t i = 0; i < formats.size(); i++) {
114         if (e.InputsInfo()[i].format != formats[i]) {
115           match = false;
116           break;
117         }
118       }
119       if (match) {
120         return true;
121       }
122     }
123     MS_LOG(INFO) << "unsupported format for op " << e.Name();
124     return false;
125   }
126   virtual ~SupportFormat() = default;
127 
128  private:
129   std::vector<std::vector<std::string>> formats_;
130 };
131 
132 std::vector<int64_t> GetAxisList(const ValuePtr &value);
133 ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis);
134 NodePtr ReluExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs);
135 NodePtr SigmoidExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs);
136 NodePtr GeluExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs);
137 std::vector<int64_t> InferShapeFromFractalnz(const std::vector<int64_t> &fractal);
138 std::vector<int64_t> GetReducedOriShape(const std::vector<int64_t> &shape, const std::vector<int64_t> &axis);
139 std::vector<int64_t> ToFracZAxis(const std::vector<int64_t> &ori_shape, const std::vector<int64_t> &ori_axis);
140 }  // namespace mindspore::graphkernel::expanders
141 #endif  // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_
142