• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_
18 #define MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_
19 #include <memory>
20 #include <string>
21 #include <unordered_set>
22 #include <functional>
23 #include <vector>
24 
25 #include "include/common/expander/core/node.h"
26 #include "include/common/expander/core/emitter.h"
27 
28 namespace mindspore {
29 namespace expander {
30 using SelectKernelFunc = std::function<bool(const CNodePtr &)>;
31 class FallbackIRBuilder;
32 using IRBuilderFunc = std::function<NodePtrList(FallbackIRBuilder *)>;
33 struct IRBuilderHandle {
34   IRBuilderFunc func;
35 };
36 
37 class FallbackIRBuilder : public IrEmitter {
38  public:
39   FallbackIRBuilder(const std::string &name, const FuncGraphPtr &fg, const SelectKernelFunc &func);
40   ~FallbackIRBuilder() override = default;
41 
42   AnfNodePtr Run(const CNodePtr &cnode, const IRBuilderHandle &handle);
43   ValuePtr GetAttr(const std::string &attr) const;
GetDtype(const NodePtr & node)44   TypePtr GetDtype(const NodePtr &node) const { return node->dtype(); }
45   template <typename S>
GetAttr(const std::string & attr)46   S GetAttr(const std::string &attr) const {
47     return GetValue<S>(GetAttr(attr));
48   }
GetAttrs()49   const mindspore::HashMap<std::string, ValuePtr> &GetAttrs() const { return *attrs_ptr_; }
GetInput(size_t i)50   NodePtr GetInput(size_t i) const {
51     if (i >= inputs_.size()) {
52       MS_LOG(EXCEPTION) << "For " << name_ << ", the index " << i << " is out of range of inputs size "
53                         << inputs_.size();
54     }
55     return inputs_[i];
56   }
GetInputs()57   const NodePtrList &GetInputs() const { return inputs_; }
58   int64_t GetSize(const NodePtr &node) const;
GetShape(const NodePtr & node)59   ShapeVector GetShape(const NodePtr &node) const { return node->shape(); }
60   NodePtr DynSize(const NodePtr &node, const TypePtr &type);
61   NodePtr DynSize(const NodePtr &node);
62   NodePtr SequenceToTensor(const NodePtr &node, const TypePtr &dtype = kInt64);
63   std::vector<int64_t> GetIntList(const ValuePtr &value);
64   std::vector<int64_t> GetIntList(const NodePtr &node);
65 
66  protected:
67   std::string name_;
68   NodePtrList inputs_;
69   const mindspore::HashMap<std::string, ValuePtr> *attrs_ptr_{nullptr};
70   bool success_{true};
71 };
72 
73 class IRBuilderFactory {
74  public:
Instance()75   static IRBuilderFactory &Instance() {
76     static IRBuilderFactory instance{};
77     return instance;
78   }
79 
GetBuilder(const std::string & name)80   const IRBuilderHandle *GetBuilder(const std::string &name) const {
81     auto iter = registry_.find(name);
82     return (iter == registry_.end()) ? nullptr : &(iter->second);
83   }
84 
85   class RegHelper {
86    public:
RegHelper(const std::string & name)87     explicit RegHelper(const std::string &name) : name_(name) {}
88     ~RegHelper() = default;
SetBody(const IRBuilderFunc & func)89     const RegHelper &SetBody(const IRBuilderFunc &func) const {
90       IRBuilderFactory::Instance().registry_[name_].func = func;
91       return *this;
92     }
93 
94    private:
95     std::string name_;
96   };
97 
98  private:
99   HashMap<std::string, IRBuilderHandle> registry_;
100 };
101 
// Pastes two tokens into a single identifier.
#define EXPANDER_JOIN(lhs, rhs) lhs##rhs
// Indirection layer: arguments (e.g. __COUNTER__) are macro-expanded here before EXPANDER_JOIN pastes them.
#define EXPANDER_UNIQUE_NAME(prefix, cnt) EXPANDER_JOIN(prefix, cnt)
// Declares a uniquely named static RegHelper for `op_name`; append .SetBody(...) to register a body.
#define REG_FALLBACK_BUILDER(op_name) \
  static const IRBuilderFactory::RegHelper EXPANDER_UNIQUE_NAME(g_fbib, __COUNTER__) = IRBuilderFactory::RegHelper(op_name)
// Opens a builder-body lambda whose FallbackIRBuilder* parameter is named `ib_var`.
#define BODYFUNC(ib_var) [](FallbackIRBuilder * ib_var) -> NodePtrList
107 }  // namespace expander
108 }  // namespace mindspore
109 #endif  // MINDSPORE_CCSRC_BACKEND_COMMON_EXPANDER_FALLBACK_FALLBACK_IRBUILDER_H_
110