1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 18 19 #include <string> 20 #include <memory> 21 #include <vector> 22 23 #include "backend/optimizer/graph_kernel/model/lite_graph.h" 24 #include "backend/optimizer/graph_kernel/model/node.h" 25 26 namespace mindspore { 27 namespace opt { 28 namespace expanders { 29 using graphkernel::NodePtrList; 30 using BaseInfoList = std::vector<graphkernel::NodeBase>; 31 class Validator; 32 33 class OpExpander { 34 public: 35 graphkernel::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs, 36 const graphkernel::DAttrs &attrs, const std::string &processor); 37 virtual ~OpExpander() = default; 38 39 protected: CheckInputs()40 virtual bool CheckInputs() { return true; } 41 virtual NodePtrList Expand() = 0; 42 bool CheckOutputs(); 43 44 graphkernel::LiteGraph::GraphBuilder gb; 45 std::string op_; 46 BaseInfoList inputs_info_; 47 BaseInfoList outputs_info_; 48 graphkernel::DAttrs attrs_; 49 std::string processor_; 50 std::vector<std::unique_ptr<Validator>> validators_; 51 52 friend class OpExpanderFactory; 53 friend class CheckAllFormatsSame; 54 friend class CheckAttr; 55 friend class SupportFormat; 56 }; 57 58 class Validator { 59 public: 60 virtual bool Check(const OpExpander &e) = 0; 61 }; 62 63 class CheckAllFormatsSame : public Validator { 64 public: Check(const OpExpander & e)65 bool Check(const OpExpander &e) override { 66 if (e.inputs_info_.empty()) return true; 67 const auto &fmt_0 = e.inputs_info_[0].format; 68 for (size_t i = 1; i < e.inputs_info_.size(); i++) { 69 if (e.inputs_info_[i].format != fmt_0) { 70 MS_LOG(INFO) << "Unmatched format for op " << e.op_; 71 return false; 72 } 73 } 74 return true; 75 } 76 }; 77 78 class CheckAttr : public Validator { 79 public: CheckAttr(std::initializer_list<std::string> l)80 CheckAttr(std::initializer_list<std::string> l) : attrs_(l) {} 81 ~CheckAttr() = default; Check(const OpExpander & e)82 bool Check(const OpExpander &e) override { 83 for (auto &a : attrs_) { 84 if (e.attrs_.count(a) == 0) { 85 MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.op_; 86 return false; 87 } 88 } 89 return true; 90 } 91 92 private: 93 std::vector<std::string> attrs_; 94 }; 95 96 class SupportFormat : public Validator { 97 public: AddFormat(std::initializer_list<std::string> l)98 void AddFormat(std::initializer_list<std::string> l) { formats_.emplace_back(l); } Check(const OpExpander & e)99 bool Check(const OpExpander &e) override { 100 for (auto &formats : formats_) { 101 if (formats.size() != e.inputs_info_.size()) { 102 continue; 103 } 104 bool match = true; 105 for (size_t i = 0; i < formats.size(); i++) { 106 if (e.inputs_info_[i].format != formats[i]) { 107 match = false; 108 break; 109 } 110 } 111 if (match) { 112 return true; 113 } 114 } 115 MS_LOG(INFO) << "unsupported format for op " << e.op_; 116 return false; 117 } 118 119 private: 120 std::vector<std::vector<std::string>> formats_; 121 }; 122 123 std::vector<int64_t> GetAxisList(const ValuePtr &value); 124 ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis); 125 } // namespace expanders 126 } // namespace opt 127 } // namespace mindspore 128 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 129