• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
17 #define MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
18 #include <vector>
19 #include <utility>
20 #include <unordered_map>
21 #include <map>
22 #include <string>
23 #include <memory>
24 #include <set>
25 #include <mutex>
26 
27 #include "ir/anf.h"
28 #include "ir/func_graph.h"
29 #include "mindspore/core/symbolic_shape/symbol_engine.h"
30 #include "mindspore/core/symbolic_shape/symbol.h"
31 #include "mindspore/core/symbolic_shape/operation_builder.h"
32 #include "mindspore/core/symbolic_shape/operation.h"
33 #include "include/common/visible.h"
34 
35 namespace mindspore {
36 namespace symshape {
37 struct COMMON_EXPORT DependStatus {
38   bool shape{false};
39   bool value{false};
40 };
41 
42 /// \brief When a CNode's input[0] is also a CNode, it's a SpecialCNode.
43 class COMMON_EXPORT SpecialCNodeHelper {
44  public:
SpecialCNodeHelper(const CNodePtr & cnode)45   explicit SpecialCNodeHelper(const CNodePtr &cnode) : cnode_(cnode) {}
46   virtual ~SpecialCNodeHelper() = default;
47   virtual void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) = 0;
48   virtual std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() = 0;
49 
50  protected:
51   CNodePtr cnode_;
52 };
53 
54 class COMMON_EXPORT SymbolEngineImpl : public SymbolEngine {
55  public:
SymbolEngineImpl(const FuncGraphPtr & fg)56   explicit SymbolEngineImpl(const FuncGraphPtr &fg) : SymbolEngine(fg), name_(fg->ToString()) {}
57   ~SymbolEngineImpl() = default;
58   MS_DECLARE_PARENT(SymbolEngineImpl, SymbolEngine)
59 
60   /// \brief Build SymbolEngine, and set to the FuncGraph.
61   static std::shared_ptr<symshape::SymbolEngineImpl> Build(const FuncGraphPtr &func_graph);
62 
GetInferMutex()63   std::mutex *GetInferMutex() { return &infer_mutex_; }
64   bool Infer(const AbstractBasePtrList &inputs) override;
65   bool IsDependValue(const AnfNodePtr &node) override;
66   bool IsDependShape(const AnfNodePtr &node) override;
SupportInfer()67   bool SupportInfer() override { return support_infer_; }
68   void QuerySymbolExpr(const AnfNodePtr &node, std::unordered_map<std::string, std::string> *symbol_expr_map) override;
69 
ToString()70   std::string ToString() const override { return "SymbolEngine_" + name_; }
71   std::string DumpText() const override;
72 
73   virtual void BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index);
74   virtual void PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
75                                                  size_t begin_input_index);
76 
77  protected:
78   // prebuild of symbol engine, it should be called before BuildImpl
79   void PreBuild();
80   void PreBuildQueryDependStatus(const AnfNodePtrList &cnodes);
81   void PreBuildSpecialNode(const CNodePtr &cnode);
82   void SetInputDependStatus(const CNodePtr &cnode, bool current_depend_value);
83 
84   // build symbol engine
85   void BuildImpl();
86   SymbolPtr BuildCNodeSymbolicShape(OperationBuilder *builder, const PrimitivePtr &prim,
87                                     const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
88                                     const CNodePtr &cnode);
89   SymbolPtr BuildCNodeSymbolicValue(OperationBuilder *builder, const PrimitivePtr &prim,
90                                     const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
91                                     const CNodePtr &cnode);
92   virtual AbstractBasePtrList ExtractInputsAbstract(const CNodePtr &cnode);
93 
94   std::string QuerySymbolExprHelper(const SymbolPtr &s,
95                                     const std::unordered_map<std::string, std::string> &symbol_expr_map);
96 
97   void BuildNodesSymbol(const FuncGraphPtr &fg, const AnfNodePtrList &cnodes);
98   void BuildCNodeSymbol(const CNodePtr &cnode);
99   bool SetParamSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index, size_t visit_cnt);
100   bool HasAbstractAny(const AbstractBasePtrList &inputs, const AbstractBasePtr &output);
101   bool GeneralizeParamShape(const AnfNodePtr &param, const AbstractBasePtr &input_abs);
102   bool GeneralizeParamValue(const AnfNodePtr &param, const AbstractBasePtr &input_abs);
103 
104   std::string name_;
105   AnfNodePtrList cnodes_;
106   OpPtrList ops_;
107   std::unique_ptr<OperationEmitter> emitter_;
108   bool support_infer_{true};
109   std::map<AnfNodePtr, DependStatus> depend_status_map_;
110   std::map<FuncGraph *, size_t> visited_graph_;
111   std::map<AnfNodePtr, std::shared_ptr<SpecialCNodeHelper>> special_cnodes_;
112   std::mutex infer_mutex_;
113   std::set<AnfNodePtr> generalized_shape_;
114   std::set<AnfNodePtr> generalized_value_;
115 };
116 
117 using SymbolEngineImplPtr = std::shared_ptr<symshape::SymbolEngineImpl>;
118 /// \brief nodes have same digital shape may use same abstract object, but their symbolic shape may not same, clone a
119 /// new abstract for symbolic info.
120 COMMON_EXPORT AbstractBasePtr CloneAbstractIfSymbolExists(const AbstractBasePtr &abs);
CloneAbstractIfSymbolExists(const AnfNodePtr & node)121 inline AbstractBasePtr CloneAbstractIfSymbolExists(const AnfNodePtr &node) {
122   node->set_abstract(CloneAbstractIfSymbolExists(node->abstract()));
123   return node->abstract();
124 }
125 
126 COMMON_EXPORT void CleanSymbols(const FuncGraphPtr &func_graph);
127 }  // namespace symshape
128 }  // namespace mindspore
129 #endif  // MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
130