• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_
17 #define MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_
#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "mindspore/core/symbolic_shape/symbol.h"
#include "mindspore/core/symbolic_shape/int_symbol.h"
26 
27 namespace mindspore {
28 namespace symshape {
29 /// \brief Operation is the basic class of operators for symbol.
30 class MS_CORE_API Operation : public Base {
31  public:
Operation(SymbolPtrList && inputs)32   explicit Operation(SymbolPtrList &&inputs) : inputs_(inputs) {}
33   virtual ~Operation() = default;
34   MS_DECLARE_PARENT(Operation, Base)
35 
36   bool Build();
Run()37   inline void Run() {
38     MS_LOG(DEBUG) << "Running operation " << ToString();
39     MS_EXCEPTION_IF_NULL(output_);
40     EvalOnRun();
41     MS_LOG(DEBUG) << "Run result of [" << name() << "] : " << output_->ToString();
42   }
43 
44   virtual bool EqualsTo(const OpPtr &other);
45 
inputs()46   const SymbolPtrList &inputs() const { return inputs_; }
input(size_t i)47   const SymbolPtr &input(size_t i) const {
48     if (MS_UNLIKELY(i >= inputs_.size())) {
49       MS_LOG(INTERNAL_EXCEPTION) << "The index " << i << " is out of range of the inputs size " << inputs_.size();
50     }
51     return inputs_[i];
52   }
input_num()53   size_t input_num() const { return inputs_.size(); }
output()54   const SymbolPtr &output() const { return output_; }
55   template <typename T>
input_as(size_t i)56   const T *input_as(size_t i) const {
57     const T *p = input(i)->as_noexcept<T>();
58     if (MS_UNLIKELY(p == nullptr)) {
59       MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for input " << i << " of " << name();
60     }
61     return p;
62   }
63   template <typename T>
input_as_sptr(size_t i)64   std::shared_ptr<T> input_as_sptr(size_t i) const {
65     auto p = input(i)->as_sptr_noexcept<T>();
66     if (MS_UNLIKELY(p == nullptr)) {
67       MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for input " << i << " of " << name();
68     }
69     return p;
70   }
71   template <typename T>
output_as()72   T *output_as() const {
73     if (MS_UNLIKELY(output_ == nullptr)) {
74       MS_LOG(INTERNAL_EXCEPTION) << "The output of " << name() << " is not initialized.";
75     }
76     T *p = output_->as_noexcept<T>();
77     if (MS_UNLIKELY(p == nullptr)) {
78       MS_LOG(INTERNAL_EXCEPTION) << "Convert failed for output of " << name();
79     }
80     return p;
81   }
82 
need_eval()83   bool need_eval() const { return need_eval_; }
84 
name()85   virtual std::string name() const { return type_name(); }
86 
87   // overwrite shared_from_this to get OpPtr directly
shared_from_this()88   OpPtr shared_from_this() { return shared_from_base<Operation>(); }
89   std::string ToString() const override;
90   std::string DumpText() const override;
91 
92   class Emitter {
93    public:
ops_(op_list)94     explicit Emitter(OpPtrList *op_list = nullptr) : ops_(op_list) {}
95     ~Emitter() = default;
96     SymbolPtr Emit(const OpPtr &op) const;
Clean()97     void Clean() { Emitter::cse_cache_.clear(); }
98 
99    private:
100     void Cse(const OpPtr &op);
101     static inline std::map<std::string, OpPtrList> cse_cache_;
102     OpPtrList *ops_;
103   };
104   friend class OperationBuilder;
105 
106  protected:
107   virtual SymbolPtr Eval() = 0;
EvalOnRun()108   virtual void EvalOnRun() { output_->Update(Eval()); }
UpdateMathInfo()109   virtual void UpdateMathInfo() {}
110 
is_building()111   bool is_building() const { return is_building_; }
DoNotEvalOnRun()112   void DoNotEvalOnRun() {
113     if (is_building_) {
114       need_eval_ = false;
115     }
116   }
117 
SetEmitter(const Emitter * e)118   void SetEmitter(const Emitter *e) { emitter_ = e; }
emitter()119   const Emitter &emitter() const {
120     static Emitter e(nullptr);
121     return emitter_ != nullptr ? *emitter_ : e;
122   }
Emit(const OpPtr & op)123   SymbolPtr Emit(const OpPtr &op) const { return emitter().Emit(op); }
124 
ResultIntList(SymbolPtrList && result)125   SymbolPtr ResultIntList(SymbolPtrList &&result) {
126     if (is_building()) {
127       return GenList(result);
128     }
129     output_as<ListSymbol>()->UpdateList(result);
130     return nullptr;
131   }
132 
GenInt(int64_t v)133   SymbolPtr GenInt(int64_t v) { return IntSymbol::Make(v, shared_from_this()); }
GenVInt()134   SymbolPtr GenVInt() { return IntSymbol::Make(shared_from_this()); }
GenList(const SymbolPtrList & list)135   SymbolPtr GenList(const SymbolPtrList &list) { return ListSymbol::Make(list, shared_from_this()); }
GenList(SymbolPtrList && list)136   SymbolPtr GenList(SymbolPtrList &&list) { return ListSymbol::Make(list, shared_from_this()); }
GenList(const std::initializer_list<SymbolPtr> & list)137   SymbolPtr GenList(const std::initializer_list<SymbolPtr> &list) { return ListSymbol::Make(list, shared_from_this()); }
GenVList()138   SymbolPtr GenVList() { return ListSymbol::Make(shared_from_this()); }
139 
GenVIntList(size_t n)140   SymbolPtr GenVIntList(size_t n) {
141     SymbolPtrList list(n);
142     std::generate(list.begin(), list.end(), [this]() { return this->GenVInt(); });
143     return GenList(std::move(list));
144   }
145 
146   // output abstract only can be used on building
SetOutAbstract(const abstract::AbstractBasePtr & abs)147   void SetOutAbstract(const abstract::AbstractBasePtr &abs) { out_abstract_ = abs; }
out_abstract()148   const abstract::AbstractBasePtr &out_abstract() const {
149     MS_EXCEPTION_IF_NULL(out_abstract_);
150     return out_abstract_;
151   }
152 
153   SymbolPtr output_{nullptr};
154   bool support_commutative_law_{false};
155 
156  private:
157   const Emitter *emitter_{nullptr};
158   SymbolPtrList inputs_;
159   bool is_building_{true};
160   bool need_eval_{true};
161   abstract::AbstractBasePtr out_abstract_{nullptr};
162 };
163 using OperationEmitter = Operation::Emitter;
164 }  // namespace symshape
165 }  // namespace mindspore
166 #endif  // MINDSPORE_CORE_SYMBOLIC_SHAPE_OPERATION_H_
167