1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_ 18 #define MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_ 19 #include <memory> 20 #include <string> 21 #include <unordered_set> 22 #include <functional> 23 #include <vector> 24 25 #include "include/common/expander/core/node.h" 26 #include "include/common/expander/core/emitter.h" 27 28 namespace mindspore { 29 namespace expander { 30 using SelectKernelFunc = std::function<bool(const CNodePtr &)>; 31 class FallbackIRBuilder; 32 using IRBuilderFunc = std::function<NodePtrList(FallbackIRBuilder *)>; 33 struct IRBuilderHandle { 34 IRBuilderFunc func; 35 }; 36 37 class FallbackIRBuilder : public IrEmitter { 38 public: 39 FallbackIRBuilder(const std::string &name, const FuncGraphPtr &fg, const SelectKernelFunc &func); 40 ~FallbackIRBuilder() override = default; 41 42 AnfNodePtr Run(const CNodePtr &cnode, const IRBuilderHandle &handle); 43 ValuePtr GetAttr(const std::string &attr) const; GetDtype(const NodePtr & node)44 TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); } 45 template <typename S> GetAttr(const std::string & attr)46 S GetAttr(const std::string &attr) const { 47 return GetValue<S>(GetAttr(attr)); 48 } GetAttrs()49 const mindspore::HashMap<std::string, ValuePtr> &GetAttrs() const { return *attrs_ptr_; } GetInput(size_t i)50 NodePtr GetInput(size_t i) const { 51 if (i >= inputs_.size()) { 52 MS_LOG(EXCEPTION) << "For " << name_ << ", the index " << i << " is out of range of inputs size " 53 << inputs_.size(); 54 } 55 return inputs_[i]; 56 } GetInputs()57 const NodePtrList &GetInputs() const { return inputs_; } 58 int64_t GetSize(const NodePtr &node) const; GetShape(const NodePtr & node)59 ShapeVector GetShape(const NodePtr &node) const { return node->shape(); } 60 NodePtr DynSize(const NodePtr &node, const TypePtr &type); 61 NodePtr DynSize(const NodePtr &node); 62 NodePtr SequenceToTensor(const NodePtr &node, const TypePtr &dtype = kInt64); 63 std::vector<int64_t> GetIntList(const ValuePtr &value); 64 std::vector<int64_t> GetIntList(const NodePtr &node); 65 66 protected: 67 std::string name_; 68 NodePtrList inputs_; 69 const mindspore::HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr}; 70 bool success_{true}; 71 }; 72 73 class IRBuilderFactory { 74 public: Instance()75 static IRBuilderFactory &Instance() { 76 static IRBuilderFactory instance{}; 77 return instance; 78 } 79 GetBuilder(const std::string & name)80 const IRBuilderHandle *GetBuilder(const std::string &name) const { 81 auto iter = registry_.find(name); 82 return (iter == registry_.end()) ? nullptr : &(iter->second); 83 } 84 85 class RegHelper { 86 public: RegHelper(const std::string & name)87 explicit RegHelper(const std::string &name) : name_(name) {} 88 ~RegHelper() = default; SetBody(const IRBuilderFunc & func)89 const RegHelper &SetBody(const IRBuilderFunc &func) const { 90 IRBuilderFactory::Instance().registry_[name_].func = func; 91 return *this; 92 } 93 94 private: 95 std::string name_; 96 }; 97 98 private: 99 HashMap<std::string, IRBuilderHandle> registry_; 100 }; 101 102 #define EXPANDER_JOIN(x, y) x##y 103 #define EXPANDER_UNIQUE_NAME(prefix, cnt) EXPANDER_JOIN(prefix, cnt) 104 #define REG_FALLBACK_BUILDER(name) \ 105 static const IRBuilderFactory::RegHelper EXPANDER_UNIQUE_NAME(g_fbib, __COUNTER__) = IRBuilderFactory::RegHelper(name) 106 #define BODYFUNC(v) [](FallbackIRBuilder * v) -> NodePtrList 107 } // namespace expander 108 } // namespace mindspore 109 #endif // MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_ 110