1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
17 #define MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
18 #include <vector>
19 #include <utility>
20 #include <unordered_map>
21 #include <map>
22 #include <string>
23 #include <memory>
24 #include <set>
25 #include <mutex>
26
27 #include "ir/anf.h"
28 #include "ir/func_graph.h"
29 #include "mindspore/core/symbolic_shape/symbol_engine.h"
30 #include "mindspore/core/symbolic_shape/symbol.h"
31 #include "mindspore/core/symbolic_shape/operation_builder.h"
32 #include "mindspore/core/symbolic_shape/operation.h"
33 #include "include/common/visible.h"
34
35 namespace mindspore {
36 namespace symshape {
37 struct COMMON_EXPORT DependStatus {
38 bool shape{false};
39 bool value{false};
40 };
41
42 /// \brief When a CNode's input[0] is also a CNode, it's a SpecialCNode.
43 class COMMON_EXPORT SpecialCNodeHelper {
44 public:
SpecialCNodeHelper(const CNodePtr & cnode)45 explicit SpecialCNodeHelper(const CNodePtr &cnode) : cnode_(cnode) {}
46 virtual ~SpecialCNodeHelper() = default;
47 virtual void SetDependStatus(std::map<AnfNodePtr, DependStatus> *depend_status_map) = 0;
48 virtual std::pair<PrimitivePtr, AbstractBasePtrList> ExtractInputs() = 0;
49
50 protected:
51 CNodePtr cnode_;
52 };
53
54 class COMMON_EXPORT SymbolEngineImpl : public SymbolEngine {
55 public:
SymbolEngineImpl(const FuncGraphPtr & fg)56 explicit SymbolEngineImpl(const FuncGraphPtr &fg) : SymbolEngine(fg), name_(fg->ToString()) {}
57 ~SymbolEngineImpl() = default;
58 MS_DECLARE_PARENT(SymbolEngineImpl, SymbolEngine)
59
60 /// \brief Build SymbolEngine, and set to the FuncGraph.
61 static std::shared_ptr<symshape::SymbolEngineImpl> Build(const FuncGraphPtr &func_graph);
62
GetInferMutex()63 std::mutex *GetInferMutex() { return &infer_mutex_; }
64 bool Infer(const AbstractBasePtrList &inputs) override;
65 bool IsDependValue(const AnfNodePtr &node) override;
66 bool IsDependShape(const AnfNodePtr &node) override;
SupportInfer()67 bool SupportInfer() override { return support_infer_; }
68 void QuerySymbolExpr(const AnfNodePtr &node, std::unordered_map<std::string, std::string> *symbol_expr_map) override;
69
ToString()70 std::string ToString() const override { return "SymbolEngine_" + name_; }
71 std::string DumpText() const override;
72
73 virtual void BuildSubgraphImpl(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index);
74 virtual void PreBuildQuerySubgraphDependStatus(const CNodePtr &cnode, const FuncGraphPtr &sub_fg,
75 size_t begin_input_index);
76
77 protected:
78 // prebuild of symbol engine, it should be called before BuildImpl
79 void PreBuild();
80 void PreBuildQueryDependStatus(const AnfNodePtrList &cnodes);
81 void PreBuildSpecialNode(const CNodePtr &cnode);
82 void SetInputDependStatus(const CNodePtr &cnode, bool current_depend_value);
83
84 // build symbol engine
85 void BuildImpl();
86 SymbolPtr BuildCNodeSymbolicShape(OperationBuilder *builder, const PrimitivePtr &prim,
87 const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
88 const CNodePtr &cnode);
89 SymbolPtr BuildCNodeSymbolicValue(OperationBuilder *builder, const PrimitivePtr &prim,
90 const AbstractBasePtrList &inputs, const AbstractBasePtr &abstract,
91 const CNodePtr &cnode);
92 virtual AbstractBasePtrList ExtractInputsAbstract(const CNodePtr &cnode);
93
94 std::string QuerySymbolExprHelper(const SymbolPtr &s,
95 const std::unordered_map<std::string, std::string> &symbol_expr_map);
96
97 void BuildNodesSymbol(const FuncGraphPtr &fg, const AnfNodePtrList &cnodes);
98 void BuildCNodeSymbol(const CNodePtr &cnode);
99 bool SetParamSymbols(const CNodePtr &cnode, const FuncGraphPtr &sub_fg, size_t begin_input_index, size_t visit_cnt);
100 bool HasAbstractAny(const AbstractBasePtrList &inputs, const AbstractBasePtr &output);
101 bool GeneralizeParamShape(const AnfNodePtr ¶m, const AbstractBasePtr &input_abs);
102 bool GeneralizeParamValue(const AnfNodePtr ¶m, const AbstractBasePtr &input_abs);
103
104 std::string name_;
105 AnfNodePtrList cnodes_;
106 OpPtrList ops_;
107 std::unique_ptr<OperationEmitter> emitter_;
108 bool support_infer_{true};
109 std::map<AnfNodePtr, DependStatus> depend_status_map_;
110 std::map<FuncGraph *, size_t> visited_graph_;
111 std::map<AnfNodePtr, std::shared_ptr<SpecialCNodeHelper>> special_cnodes_;
112 std::mutex infer_mutex_;
113 std::set<AnfNodePtr> generalized_shape_;
114 std::set<AnfNodePtr> generalized_value_;
115 };
116
117 using SymbolEngineImplPtr = std::shared_ptr<symshape::SymbolEngineImpl>;
118 /// \brief nodes have same digital shape may use same abstract object, but their symbolic shape may not same, clone a
119 /// new abstract for symbolic info.
120 COMMON_EXPORT AbstractBasePtr CloneAbstractIfSymbolExists(const AbstractBasePtr &abs);
CloneAbstractIfSymbolExists(const AnfNodePtr & node)121 inline AbstractBasePtr CloneAbstractIfSymbolExists(const AnfNodePtr &node) {
122 node->set_abstract(CloneAbstractIfSymbolExists(node->abstract()));
123 return node->abstract();
124 }
125
126 COMMON_EXPORT void CleanSymbols(const FuncGraphPtr &func_graph);
127 } // namespace symshape
128 } // namespace mindspore
129 #endif // MINDSPORE_CCSRC_INCLUDE_COMMON_SYMBOL_ENGINE_SYMBOL_ENGINE_IMPL_H_
130