1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_ 17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_ 18 #include <memory> 19 #include <vector> 20 #include <string> 21 #include <utility> 22 #include <map> 23 24 #include "mindspore/core/symbolic_shape/symbol.h" 25 #include "mindspore/core/symbolic_shape/int_symbol.h" 26 27 namespace mindspore { 28 namespace symshape { 29 /// \brief Operation is the basic class of operators for symbol. 30 class MS_CORE_API Operation : public Base { 31 public: Operation(SymbolPtrList && inputs)32 explicit Operation(SymbolPtrList &&inputs) : inputs_(inputs) {} 33 virtual ~Operation() = default; 34 MS_DECLARE_PARENT(Operation, Base) 35 36 bool Build(); Run()37 inline void Run() { 38 MS_LOG(DEBUG) << "Running operation " << ToString(); 39 MS_EXCEPTION_IF_NULL(output_); 40 EvalOnRun(); 41 MS_LOG(DEBUG) << "Run result of [" << name() << "] : " << output_->ToString(); 42 } 43 44 virtual bool EqualsTo(const OpPtr &other); 45 inputs()46 const SymbolPtrList &inputs() const { return inputs_; } input(size_t i)47 const SymbolPtr &input(size_t i) const { 48 if (MS_UNLIKELY(i >= inputs_.size())) { 49 MS_LOG(INTERNAL_EXCEPTION) << "The index " << i << " is out of range of the inputs size " << inputs_.size(); 50 } 51 return inputs_[i]; 52 } input_num()53 size_t input_num() const { return inputs_.size(); } output()54 const SymbolPtr &output() const { return output_; } 55 template <typename T> input_as(size_t i)56 const T *input_as(size_t i) const { 57 const T *p = input(i)->as_noexcept<T>(); 58 if (MS_UNLIKELY(p == nullptr)) { 59 MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for input " << i << " of " << name(); 60 } 61 return p; 62 } 63 template <typename T> input_as_sptr(size_t i)64 std::shared_ptr<T> input_as_sptr(size_t i) const { 65 auto p = input(i)->as_sptr_noexcept<T>(); 66 if (MS_UNLIKELY(p == nullptr)) { 67 MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for input " << i << " of " << name(); 68 } 69 return p; 70 } 71 template <typename T> output_as()72 T *output_as() const { 73 if (MS_UNLIKELY(output_ == nullptr)) { 74 MS_LOG(INTERNAL_EXCEPTION) << "The output of " << name() << " is not initialized."; 75 } 76 T *p = output_->as_noexcept<T>(); 77 if (MS_UNLIKELY(p == nullptr)) { 78 MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for output of " << name(); 79 } 80 return p; 81 } 82 need_eval()83 bool need_eval() const { return need_eval_; } 84 name()85 virtual std::string name() const { return type_name(); } 86 87 // overwrite shared_from_this to get OpPtr directly shared_from_this()88 OpPtr shared_from_this() { return shared_from_base<Operation>(); } 89 std::string ToString() const override; 90 std::string DumpText() const override; 91 92 class Emitter { 93 public: ops_(op_list)94 explicit Emitter(OpPtrList *op_list = nullptr) : ops_(op_list) {} 95 ~Emitter() = default; 96 SymbolPtr Emit(const OpPtr &op) const; Clean()97 void Clean() { Emitter::cse_cache_.clear(); } 98 99 private: 100 void Cse(const OpPtr &op); 101 static inline std::map<std::string, OpPtrList> cse_cache_; 102 OpPtrList *ops_; 103 }; 104 friend class OperationBuilder; 105 106 protected: 107 virtual SymbolPtr Eval() = 0; EvalOnRun()108 virtual void EvalOnRun() { output_->Update(Eval()); } UpdateMathInfo()109 virtual void UpdateMathInfo() {} 110 is_building()111 bool is_building() const { return is_building_; } DoNotEvalOnRun()112 void DoNotEvalOnRun() { 113 if (is_building_) { 114 need_eval_ = false; 115 } 116 } 117 SetEmitter(const Emitter * e)118 void SetEmitter(const Emitter *e) { emitter_ = e; } emitter()119 const Emitter &emitter() const { 120 static Emitter e(nullptr); 121 return emitter_ != nullptr ? *emitter_ : e; 122 } Emit(const OpPtr & op)123 SymbolPtr Emit(const OpPtr &op) const { return emitter().Emit(op); } 124 ResultIntList(SymbolPtrList && result)125 SymbolPtr ResultIntList(SymbolPtrList &&result) { 126 if (is_building()) { 127 return GenList(result); 128 } 129 output_as<ListSymbol>()->UpdateList(result); 130 return nullptr; 131 } 132 GenInt(int64_t v)133 SymbolPtr GenInt(int64_t v) { return IntSymbol::Make(v, shared_from_this()); } GenVInt()134 SymbolPtr GenVInt() { return IntSymbol::Make(shared_from_this()); } GenList(const SymbolPtrList & list)135 SymbolPtr GenList(const SymbolPtrList &list) { return ListSymbol::Make(list, shared_from_this()); } GenList(SymbolPtrList && list)136 SymbolPtr GenList(SymbolPtrList &&list) { return ListSymbol::Make(list, shared_from_this()); } GenList(const std::initializer_list<SymbolPtr> & list)137 SymbolPtr GenList(const std::initializer_list<SymbolPtr> &list) { return ListSymbol::Make(list, shared_from_this()); } GenVList()138 SymbolPtr GenVList() { return ListSymbol::Make(shared_from_this()); } 139 GenVIntList(size_t n)140 SymbolPtr GenVIntList(size_t n) { 141 SymbolPtrList list(n); 142 std::generate(list.begin(), list.end(), [this]() { return this->GenVInt(); }); 143 return GenList(std::move(list)); 144 } 145 146 // output abstract only can be used on building SetOutAbstract(const abstract::AbstractBasePtr & abs)147 void SetOutAbstract(const abstract::AbstractBasePtr &abs) { out_abstract_ = abs; } out_abstract()148 const abstract::AbstractBasePtr &out_abstract() const { 149 MS_EXCEPTION_IF_NULL(out_abstract_); 150 return out_abstract_; 151 } 152 153 SymbolPtr output_{nullptr}; 154 bool support_commutative_law_{false}; 155 156 private: 157 const Emitter *emitter_{nullptr}; 158 SymbolPtrList inputs_; 159 bool is_building_{true}; 160 bool need_eval_{true}; 161 abstract::AbstractBasePtr out_abstract_{nullptr}; 162 }; 163 using OperationEmitter = Operation::Emitter; 164 } // namespace symshape 165 } // namespace mindspore 166 #endif // MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_ 167