1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CORE_IR_FUNCTOR_H_ 18 #define MINDSPORE_CORE_IR_FUNCTOR_H_ 19 20 #include <string> 21 #include <memory> 22 #include <vector> 23 #include <utility> 24 #include <set> 25 26 #include "ir/value.h" 27 #include "utils/hash_set.h" 28 #include "utils/hash_map.h" 29 #include "mindapi/base/shape_vector.h" 30 31 namespace mindspore { 32 /// \brief Functor is a Value object to hold the c++ functors that supports exporting and importing mindir. 33 class MS_CORE_API Functor : public Value { 34 public: 35 /// \brief Constructor of Functor. 36 /// 37 /// \param[in] name The name of functor class Functor(const std::string & name)38 explicit Functor(const std::string &name) : name_(name) {} 39 /// \brief Destructor of Functor. 40 ~Functor() override = default; MS_DECLARE_PARENT(Functor,Value)41 MS_DECLARE_PARENT(Functor, Value) 42 43 /// \brief Get the name of functor object 44 /// 45 /// \return The name of functor object 46 const std::string &name() const { return name_; } 47 48 /// \brief Pack member variables to a Value, it's the inverse operation of FromValue. 49 /// 50 /// \return ValuePtr that packed member variables. 51 virtual ValuePtr ToValue() const = 0; 52 53 /// \brief Unpack member variables from Value, it's the inverse operation of ToValue. 54 virtual void FromValue(const ValuePtr &) = 0; 55 56 /// \brief The hash value of the Functor object. 57 /// 58 /// \return The hash value. hash()59 std::size_t hash() const override { return tid(); } 60 61 /// \brief Check whether the input is the current Value object. 62 /// 63 /// \param[in] rhs The Value object to be compared. 64 /// \return Whether the input is the current Value object. 65 bool operator==(const Value &rhs) const override { return &rhs == this; } 66 67 /// \brief Get abstract of the Functor object. abstract of Functor is unavailable. ToAbstract()68 abstract::AbstractBasePtr ToAbstract() override { 69 MS_LOG(INTERNAL_EXCEPTION) << "Functor[" << name() << "] can't be converted to abstract."; 70 } 71 72 /// \brief Show the Functor object. 73 /// 74 /// \return The description of the Functor object. ToString()75 std::string ToString() const override { 76 auto value = ToValue(); 77 if (value == nullptr) { 78 value = kNone; 79 } 80 return "Functor[" + name() + "]{" + value->ToString() + "}"; 81 } 82 83 protected: 84 std::string name_; 85 }; 86 using FunctorPtr = std::shared_ptr<Functor>; 87 88 // std::vector<int64_t> -> The shapes of Calc output. 89 // bool -> Whether the shapes is a dynamic sequence one or not. 90 using InferOutputInfo = std::pair<std::vector<int64_t>, bool>; 91 // For ShapeArray: 92 // 1. If all input only have one item, ElemPosIdx can be ignored. 93 // 2. If one input may contain more than one item, ElemPosIdx should be considered. For example, 94 // For inputs (tuple0[item*2], item1, tuple2[item*3]), 95 // ShapeArray of inputs is {a, b, c, d, e, f} and 96 // ElemPosIdx is {[0,1], [2], [3,4,5]}, 97 // where tuple[item*2] -> {a, b}, item1 -> c, tuple2 -> {d, e, f}. 98 using ElemPosIdx = std::vector<std::vector<size_t>>; 99 /// \brief ShapeCalcBaseFunctor is the functor of operator ShapeCalc that encapsulate its Infer and Calc functions. The 100 /// shape-input of ShapeCalcBaseFunctor can be a tuple one, and the number of output can be dynamic. 101 class MS_CORE_API ShapeCalcBaseFunctor : public Functor { 102 public: 103 /// \brief Constructor of ShapeCalcBaseFunctor. ShapeCalcBaseFunctor(const std::string & name)104 explicit ShapeCalcBaseFunctor(const std::string &name) : Functor(name) {} 105 106 /// \brief Destructor of ShapeCalcBaseFunctor. 107 ~ShapeCalcBaseFunctor() override = default; 108 MS_DECLARE_PARENT(ShapeCalcBaseFunctor, Functor) 109 110 /// \brief Calculate shapes. It's the real calculation of ShapeCalc kernel. 111 /// \param[in] inputs The inputs. 112 /// \param[in] pos_idx If input contain tuple cases, pos_idx will tell the real elements' index of it. 113 /// \return Result shapes. 114 virtual ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &pos_idx) const = 0; 115 116 /// \brief The InferShape implementation of ShapeCalc primitive. 117 /// \param[in] inputs The inputs. 118 /// \param[in] unknown_inputs If i exists in 'unknown_inputs', the shape value of inputs[i] is unknown. 119 /// \param[in] pos_idx If input contain tuple cases, pos_idx will tell the real elements' index of it. 120 /// \return A pair composited with length of each shape that returned by Calc and whether the number of Calc output is 121 /// unknown. 122 virtual InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs, 123 const ElemPosIdx &pos_idx) const = 0; 124 }; 125 using ShapeCalcBaseFunctorPtr = std::shared_ptr<ShapeCalcBaseFunctor>; 126 127 /// \brief ShapeCalcFunctor is the functor of operator ShapeCalc that encapsulate its Infer and Calc functions. The 128 /// shape-input of ShapeCalcFunctor should be a scalar or a tensor. 129 class MS_CORE_API ShapeCalcFunctor : public ShapeCalcBaseFunctor { 130 public: 131 /// \brief Constructor of ShapeCalcFunctor. ShapeCalcFunctor(const std::string & name)132 explicit ShapeCalcFunctor(const std::string &name) : ShapeCalcBaseFunctor(name) {} 133 134 /// \brief Destructor of ShapeCalcFunctor. 135 ~ShapeCalcFunctor() override = default; 136 MS_DECLARE_PARENT(ShapeCalcFunctor, ShapeCalcBaseFunctor) 137 138 /// \brief Calculate shapes. It's the real calculation of ShapeCalc kernel. 139 /// \param[in] inputs The inputs. 140 /// \return Result shapes. 141 virtual ShapeArray Calc(const ShapeArray &inputs) const = 0; 142 143 /// \brief The InferShape implementation of ShapeCalc primitive. 144 /// \param[in] inputs The inputs. 145 /// \param[in] unknown_inputs If i exists in 'unknown_inputs', the shape value of inputs[i] is unknown. 146 /// \return Length of each shape that returned by Calc. 147 virtual std::vector<int64_t> Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs) const = 0; 148 Calc(const ShapeArray & inputs,const ElemPosIdx &)149 ShapeArray Calc(const ShapeArray &inputs, const ElemPosIdx &) const final { return Calc(inputs); } Infer(const ShapeArray & inputs,const HashSet<size_t> & unknown_inputs,const ElemPosIdx & pos_idx)150 InferOutputInfo Infer(const ShapeArray &inputs, const HashSet<size_t> &unknown_inputs, 151 const ElemPosIdx &pos_idx) const final { 152 auto lengths = Infer(inputs, unknown_inputs); 153 return std::make_pair(lengths, false); 154 } 155 }; 156 using ShapeCalcFunctorPtr = std::shared_ptr<ShapeCalcFunctor>; 157 158 // common code to declare ShapeCalcFunctor 159 #define DECLARE_SHAPE_CALC(reg_name, cls) \ 160 cls() : ShapeCalcFunctor(reg_name) {} \ 161 ~cls() override = default; \ 162 MS_DECLARE_PARENT(cls, ShapeCalcFunctor) 163 164 /// \brief FunctorRegistry is the registry of functors to support importing functor from mindir. 165 class MS_CORE_API FunctorRegistry { 166 public: 167 using Creator = std::function<FunctorPtr()>; Instance()168 static FunctorRegistry &Instance() { 169 static FunctorRegistry ins{}; 170 return ins; 171 } GetCreator(const std::string & name)172 Creator GetCreator(const std::string &name) const { 173 auto iter = reg.find(name); 174 return iter == reg.end() ? nullptr : iter->second; 175 } 176 class RegCls { 177 public: RegCls(const std::string & name,const Creator & creator)178 RegCls(const std::string &name, const Creator &creator) { FunctorRegistry::Instance().Register(name, creator); } 179 ~RegCls() = default; 180 }; 181 Register(const std::string & name,const Creator & creator)182 void Register(const std::string &name, const Creator &creator) { 183 auto ret = reg.insert({name, creator}); 184 if (!ret.second) { 185 MS_LOG(WARNING) << "Duplicated functor is registered. name: " << name; 186 } else { 187 MS_LOG(DEBUG) << "Register functor: " << name; 188 } 189 } 190 191 private: 192 FunctorRegistry() = default; 193 ~FunctorRegistry() = default; 194 HashMap<std::string, Creator> reg; 195 }; 196 197 #define REG_FUNCTOR(name, cls) \ 198 static const FunctorRegistry::RegCls g_functor_##cls((name), []() -> FunctorPtr { return std::make_shared<cls>(); }) 199 } // namespace mindspore 200 #endif // MINDSPORE_CORE_IR_FUNCTOR_H_ 201