1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 17 #define MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 18 19 #include <string> 20 #include <memory> 21 #include <vector> 22 #include <utility> 23 24 #include "backend/common/graph_kernel/model/lite_graph.h" 25 #include "backend/common/graph_kernel/model/node.h" 26 #include "backend/common/graph_kernel/model/graph_builder.h" 27 28 namespace mindspore::graphkernel::expanders { 29 using inner::NodePtr; 30 using inner::NodePtrList; 31 using BaseInfoList = std::vector<inner::NodeBase>; 32 class Validator; 33 34 class OpDesc { 35 public: 36 inner::LiteGraphPtr Run(const BaseInfoList &inputs, const BaseInfoList &outputs, const inner::DAttrs &attrs, 37 const std::string &processor); Name()38 const std::string &Name() const { return name_; } InputsInfo()39 const BaseInfoList &InputsInfo() const { return inputs_info_; } OutputsInfo()40 const BaseInfoList &OutputsInfo() const { return outputs_info_; } Attrs()41 const inner::DAttrs &Attrs() const { return attrs_; } Processor()42 const std::string &Processor() const { return processor_; } 43 virtual ~OpDesc() = default; 44 45 protected: Init()46 virtual void Init() {} CheckInputs()47 virtual bool CheckInputs() { return true; } 48 virtual NodePtrList Expand(const NodePtrList &inputs) = 0; 49 bool CheckOutputs(); 50 51 inner::GraphBuilder gb; 52 std::string name_; 53 BaseInfoList inputs_info_; 54 BaseInfoList outputs_info_; 55 inner::DAttrs attrs_; 56 std::string processor_; 57 std::vector<std::unique_ptr<Validator>> validators_; 58 59 friend class OpDescFactory; 60 }; 61 62 class Validator { 63 public: 64 virtual bool Check(const OpDesc &e) = 0; 65 virtual ~Validator() = default; 66 }; 67 68 class CheckAllFormatsSame : public Validator { 69 public: Check(const OpDesc & e)70 bool Check(const OpDesc &e) override { 71 const auto &inputs_info = e.InputsInfo(); 72 if (inputs_info.empty()) { 73 return true; 74 } 75 const auto &fmt_0 = inputs_info[0].format; 76 for (size_t i = 1; i < inputs_info.size(); i++) { 77 if (inputs_info[i].format != fmt_0) { 78 MS_LOG(INFO) << "Unmatched format for op " << e.Name(); 79 return false; 80 } 81 } 82 return true; 83 } 84 }; 85 86 class CheckAttr : public Validator { 87 public: CheckAttr(std::initializer_list<std::string> l)88 CheckAttr(std::initializer_list<std::string> l) : attrs_(std::move(l)) {} 89 virtual ~CheckAttr() = default; Check(const OpDesc & e)90 bool Check(const OpDesc &e) override { 91 for (auto &a : attrs_) { 92 if (e.Attrs().count(a) == 0) { 93 MS_LOG(INFO) << "attr " << a << " does not exist. op " << e.Name(); 94 return false; 95 } 96 } 97 return true; 98 } 99 100 private: 101 std::vector<std::string> attrs_; 102 }; 103 104 class SupportFormat : public Validator { 105 public: AddFormat(std::initializer_list<std::string> l)106 void AddFormat(std::initializer_list<std::string> l) { (void)formats_.emplace_back(l); } Check(const OpDesc & e)107 bool Check(const OpDesc &e) override { 108 for (auto &formats : formats_) { 109 if (formats.size() != e.InputsInfo().size()) { 110 continue; 111 } 112 bool match = true; 113 for (size_t i = 0; i < formats.size(); i++) { 114 if (e.InputsInfo()[i].format != formats[i]) { 115 match = false; 116 break; 117 } 118 } 119 if (match) { 120 return true; 121 } 122 } 123 MS_LOG(INFO) << "unsupported format for op " << e.Name(); 124 return false; 125 } 126 virtual ~SupportFormat() = default; 127 128 private: 129 std::vector<std::vector<std::string>> formats_; 130 }; 131 132 std::vector<int64_t> GetAxisList(const ValuePtr &value); 133 ShapeVector ExpandDimsInferShape(const ShapeVector &shape, const std::vector<int64_t> &axis); 134 NodePtr ReluExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs); 135 NodePtr SigmoidExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs); 136 NodePtr GeluExpand(const inner::GraphBuilder &gb, const NodePtrList &inputs); 137 std::vector<int64_t> InferShapeFromFractalnz(const std::vector<int64_t> &fractal); 138 std::vector<int64_t> GetReducedOriShape(const std::vector<int64_t> &shape, const std::vector<int64_t> &axis); 139 std::vector<int64_t> ToFracZAxis(const std::vector<int64_t> &ori_shape, const std::vector<int64_t> &ori_axis); 140 } // namespace mindspore::graphkernel::expanders 141 #endif // MINDSPORE_CCSRC_COMMON_GRAPH_KERNEL_EXPANDERS_UTILS_H_ 142