• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_FUNCTOR_H_
18 #define MINDSPORE_CORE_IR_FUNCTOR_H_
19 
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "ir/value.h"
#include "utils/hash_set.h"
#include "utils/hash_map.h"
#include "mindapi/base/shape_vector.h"
30 
31 namespace mindspore {
32 /// \brief Functor is a Value object to hold the c++ functors that supports exporting and importing mindir.
33 class MS_CORE_API Functor : public Value {
34  public:
35   /// \brief Constructor of Functor.
36   ///
37   /// \param[in] name The name of functor class
Functor(const std::string & name)38   explicit Functor(const std::string &name) : name_(name) {}
39   /// \brief Destructor of Functor.
40   ~Functor() override = default;
MS_DECLARE_PARENT(Functor,Value)41   MS_DECLARE_PARENT(Functor, Value)
42 
43   /// \brief  Get the name of functor object
44   ///
45   /// \return The name of functor object
46   const std::string &name() const { return name_; }
47 
48   /// \brief Pack member variables to a Value, it's the inverse operation of FromValue.
49   ///
50   /// \return ValuePtr that packed member variables.
51   virtual ValuePtr ToValue() const = 0;
52 
53   /// \brief Unpack member variables from Value, it's the inverse operation of ToValue.
54   virtual void FromValue(const ValuePtr &) = 0;
55 
56   /// \brief The hash value of the Functor object.
57   ///
58   /// \return The hash value.
hash()59   std::size_t hash() const override { return tid(); }
60 
61   /// \brief Check whether the input is the current Value object.
62   ///
63   /// \param[in] rhs The Value object to be compared.
64   /// \return Whether the input is the current Value object.
65   bool operator==(const Value &rhs) const override { return &rhs == this; }
66 
67   /// \brief Get abstract of the Functor object. abstract of Functor is unavailable.
ToAbstract()68   abstract::AbstractBasePtr ToAbstract() override {
69     MS_LOG(INTERNAL_EXCEPTION) << "Functor[" << name() << "] can't be converted to abstract.";
70   }
71 
72   /// \brief Show the Functor object.
73   ///
74   /// \return The description of the Functor object.
ToString()75   std::string ToString() const override {
76     auto value = ToValue();
77     if (value == nullptr) {
78       value = kNone;
79     }
80     return "Functor[" + name() + "]{" + value->ToString() + "}";
81   }
82 
83  protected:
84   std::string name_;
85 };
86 using FunctorPtr = std::shared_ptr<Functor>;
87 
// Result type of ShapeCalcBaseFunctor::Infer:
//   first  -> the length of every shape that Calc will produce;
//   second -> whether the number of Calc outputs is itself dynamic (a dynamic sequence) rather than fixed.
using InferOutputInfo = std::pair<std::vector<int64_t>, bool>;

// Positions of each logical input's items inside a flattened ShapeArray.
// 1. If every input holds exactly one item, ElemPosIdx can be ignored.
// 2. If an input may hold several items (a tuple), ElemPosIdx must be consulted. For example,
//    the inputs (tuple0[item*2], item1, tuple2[item*3]) flatten to the ShapeArray {a, b, c, d, e, f},
//    and ElemPosIdx is {[0, 1], [2], [3, 4, 5]}, i.e. tuple0 -> {a, b}, item1 -> c, tuple2 -> {d, e, f}.
using ElemPosIdx = std::vector<std::vector<std::size_t>>;
99 /// \brief ShapeCalcBaseFunctor is the functor of operator ShapeCalc that encapsulate its Infer and Calc functions. The
100 /// shape-input of ShapeCalcBaseFunctor can be a tuple one, and the number of output can be dynamic.
101 class MS_CORE_API ShapeCalcBaseFunctor : public Functor {
102  public:
103   /// \brief Constructor of ShapeCalcBaseFunctor.
ShapeCalcBaseFunctor(const std::string & name)104   explicit ShapeCalcBaseFunctor(const std::string &name) : Functor(name) {}
105 
106   /// \brief Destructor of ShapeCalcBaseFunctor.
107   ~ShapeCalcBaseFunctor() override = default;
108   MS_DECLARE_PARENT(ShapeCalcBaseFunctor, Functor)
109 
110   /// \brief Calculate shapes. It's the real calculation of ShapeCalc kernel.
111   /// \param[in] inputs The inputs.
112   /// \param[in] pos_idx If input contain tuple cases, pos_idx will tell the real elements' index of it.
113   /// \return Result shapes.
114   virtual ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &pos_idx) const = 0;
115 
116   /// \brief The InferShape implementation of ShapeCalc primitive.
117   /// \param[in] inputs The inputs.
118   /// \param[in] unknown_inputs If i exists in 'unknown_inputs', the shape value of inputs[i] is unknown.
119   /// \param[in] pos_idx If input contain tuple cases, pos_idx will tell the real elements' index of it.
120   /// \return A pair composited with length of each shape that returned by Calc and whether the number of Calc output is
121   /// unknown.
122   virtual InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs,
123                                 const ElemPosIdx &pos_idx) const = 0;
124 };
125 using ShapeCalcBaseFunctorPtr = std::shared_ptr<ShapeCalcBaseFunctor>;
126 
127 /// \brief ShapeCalcFunctor is the functor of operator ShapeCalc that encapsulate its Infer and Calc functions. The
128 /// shape-input of ShapeCalcFunctor should be a scalar or a tensor.
129 class MS_CORE_API ShapeCalcFunctor : public ShapeCalcBaseFunctor {
130  public:
131   /// \brief Constructor of ShapeCalcFunctor.
ShapeCalcFunctor(const std::string & name)132   explicit ShapeCalcFunctor(const std::string &name) : ShapeCalcBaseFunctor(name) {}
133 
134   /// \brief Destructor of ShapeCalcFunctor.
135   ~ShapeCalcFunctor() override = default;
136   MS_DECLARE_PARENT(ShapeCalcFunctor, ShapeCalcBaseFunctor)
137 
138   /// \brief Calculate shapes. It's the real calculation of ShapeCalc kernel.
139   /// \param[in] inputs The inputs.
140   /// \return Result shapes.
141   virtual ShapeArray Calc(const ShapeArray &inputs) const = 0;
142 
143   /// \brief The InferShape implementation of ShapeCalc primitive.
144   /// \param[in] inputs The inputs.
145   /// \param[in] unknown_inputs If i exists in 'unknown_inputs', the shape value of inputs[i] is unknown.
146   /// \return Length of each shape that returned by Calc.
147   virtual std::vector<int64_t> Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs) const = 0;
148 
Calc(const ShapeArray & inputs,const ElemPosIdx &)149   ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &) const final { return Calc(inputs); }
Infer(const ShapeArray & inputs,const HashSet<size_t> & unknown_inputs,const ElemPosIdx & pos_idx)150   InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs,
151                         const ElemPosIdx &pos_idx) const final {
152     auto lengths = Infer(inputs, unknown_inputs);
153     return std::make_pair(lengths, false);
154   }
155 };
156 using ShapeCalcFunctorPtr = std::shared_ptr<ShapeCalcFunctor>;
157 
// Boilerplate for a ShapeCalcFunctor subclass, expanded inside the class body. It declares:
//   - a default constructor that passes `reg_name` to ShapeCalcFunctor,
//   - a defaulted overriding destructor,
//   - MS_DECLARE_PARENT(cls, ShapeCalcFunctor).
// It is typically placed in a public section, since REG_FUNCTOR's creator calls std::make_shared<cls>().
#define DECLARE_SHAPE_CALC(reg_name, cls)   \
  cls() : ShapeCalcFunctor((reg_name)) {} \
  ~cls() override = default;              \
  MS_DECLARE_PARENT(cls, ShapeCalcFunctor)
163 
164 /// \brief FunctorRegistry is the registry of functors to support importing functor from mindir.
165 class MS_CORE_API FunctorRegistry {
166  public:
167   using Creator = std::function<FunctorPtr()>;
Instance()168   static FunctorRegistry &Instance() {
169     static FunctorRegistry ins{};
170     return ins;
171   }
GetCreator(const std::string & name)172   Creator GetCreator(const std::string &name) const {
173     auto iter = reg.find(name);
174     return iter == reg.end() ? nullptr : iter->second;
175   }
176   class RegCls {
177    public:
RegCls(const std::string & name,const Creator & creator)178     RegCls(const std::string &name, const Creator &creator) { FunctorRegistry::Instance().Register(name, creator); }
179     ~RegCls() = default;
180   };
181 
Register(const std::string & name,const Creator & creator)182   void Register(const std::string &name, const Creator &creator) {
183     auto ret = reg.insert({name, creator});
184     if (!ret.second) {
185       MS_LOG(WARNING) << "Duplicated functor is registered. name: " << name;
186     } else {
187       MS_LOG(DEBUG) << "Register functor: " << name;
188     }
189   }
190 
191  private:
192   FunctorRegistry() = default;
193   ~FunctorRegistry() = default;
194   HashMap<std::string, Creator> reg;
195 };
196 
// Register functor class `cls` under `name` by declaring a static RegCls object named g_functor_<cls>. Its
// constructor adds a creator that returns std::make_shared<cls>(). Requirements:
//   - `cls` must be default-constructible;
//   - `cls` must be a plain identifier, because it is token-pasted into the variable name.
#define REG_FUNCTOR(name, cls)                  \
  static const FunctorRegistry::RegCls g_functor_##cls( \
    (name), []() -> FunctorPtr { return std::make_shared<cls>(); })
199 }  // namespace mindspore
200 #endif  // MINDSPORE_CORE_IR_FUNCTOR_H_
201