/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_

#include <algorithm>
#include <memory>
#include <vector>

#include "utils/hash_map.h"
#include "mindspore/core/ops/framework_ops.h"
#include "utils/hash_set.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/graph_transform.h"

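// This pass flattens constant-length sequence (tuple/list) arguments of func
// graph calls and Partial nodes into individual element arguments.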
namespace mindspore {
namespace opt {
namespace irpass {
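// Returns true if `node` is a call CNode whose callee is not a Primitive value
// node, i.e. a call to a func graph (or other non-primitive callable).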
bool IsFuncGraphCallNode(const AnfNodePtr &node) {
  if (!node->isa<CNode>()) {
    return false;
  }
  auto cnode = node->cast<CNodePtr>();
  return !IsValueNode<Primitive>(cnode->input(kAnfPrimitiveIndex));
}

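// Copies args[start_idx..] into *new_args, expanding each expandable
// (constant-length) sequence argument into its elements. Returns true if any
// argument was expanded.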
bool FlattenArgs(const FuncGraphPtr &fg, const AnfNodePtrList &args, size_t start_idx, AnfNodePtrList *new_args) {
  bool change = false;
  for (size_t i = start_idx; i < args.size(); i++) {
    const auto &arg = args[i];
    auto abs = arg->abstract();
    if (abs == nullptr) {
      MS_LOG(INTERNAL_EXCEPTION) << "Null abstract of arg: " << arg->DebugString();
    }
    // A dynamic-length sequence input cannot be flattened.
    if (!IsSequenceExpandable(abs)) {
      new_args->push_back(arg);
      continue;
    }
    auto new_arg = TransformSequenceArgument(fg, arg, abs->cast<abstract::AbstractSequencePtr>());
    (void)new_args->insert(new_args->cend(), new_arg.cbegin(), new_arg.cend());
    change = true;
  }
  return change;
}

// fg(param1_sequence, param2)
// =>
// fg(param1_1, param1_2, ..., param1_n, param2)
// Transform graph call sequence inputs to flat inputs.
class GraphSequenceTransform : public AnfVisitor {
 public:
  GraphSequenceTransform() = default;
  ~GraphSequenceTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!IsValueNode<FuncGraph>(node)) {
      return nullptr;
    }
    auto fg = GetValueNode<FuncGraphPtr>(node);
    if (!FuncGraphHasConstantSequenceInput(fg)) {
      return nullptr;
    }
    fg = graph_transform_(fg, optimizer->manager());
    // Can't set the abstract of the value node; otherwise the renormalize process won't be executed.
    return NewValueNode(fg);
  }

 private:
  GraphSequenceParamTransform graph_transform_;
};

// {kPrimPartial, G, Sequence_Xs}
// =>
// {kPrimPartial, G, TupleGetItem{Sequence_Xs,0}, TupleGetItem{Sequence_Xs,1}, ..., TupleGetItem{Sequence_Xs,n}}
// Transform partial's sequence binding args to flat inputs.
class PartialSequenceArgTransform : public AnfVisitor {
 public:
  PartialSequenceArgTransform() = default;
  ~PartialSequenceArgTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
      return nullptr;
    }
    auto partial = node->cast<CNodePtr>();
    const auto &partial_inputs = partial->inputs();
    const auto &fg = partial->func_graph();
    constexpr auto kPartialFirstArgIndex = 2;
    // Put ValueNode<kPrimPartial> and ValueNode<FuncGraph> into new_inputs.
    auto new_inputs = AnfNodePtrList(partial_inputs.begin(), partial_inputs.begin() + kPartialFirstArgIndex);
    auto change = FlattenArgs(fg, partial_inputs, kPartialFirstArgIndex, &new_inputs);
    if (change) {
      auto new_partial = fg->NewCNode(new_inputs);
      new_partial->set_abstract(partial->abstract());
      return new_partial;
    }
    return nullptr;
  }
};

// {G, Sequence_Xs}
// =>
// {G, TupleGetItem{Sequence_Xs,0}, TupleGetItem{Sequence_Xs,1}, ..., TupleGetItem{Sequence_Xs,n}}
// Transform call's sequence args to flat inputs.
class CallSequenceArgTransform : public AnfVisitor {
 public:
  CallSequenceArgTransform() = default;
  ~CallSequenceArgTransform() override = default;
  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (!IsFuncGraphCallNode(node)) {
      return nullptr;
    }

    auto call_node = node->cast<CNodePtr>();
    const auto &call_inputs = call_node->inputs();
    const auto &fg = call_node->func_graph();
    MS_EXCEPTION_IF_NULL(fg);
    // Put ValueNode<FuncGraph> into inputs.
    auto new_inputs = AnfNodePtrList(call_inputs.begin(), call_inputs.begin() + 1);
    auto change = FlattenArgs(fg, call_inputs, 1, &new_inputs);
    if (change) {
      auto new_call = fg->NewCNode(new_inputs);
      new_call->set_abstract(call_node->abstract());
      return new_call;
    }
    return nullptr;
  }
};

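// Applies the three transforms above in turn. The whole pass is skipped once
// any sparse tensor component is seen among call/partial inputs or graph
// parameters.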
class CallGraphSequenceTransform : public OptimizerCaller {
 public:
  CallGraphSequenceTransform() {
    (void)transformers_.emplace_back(std::make_shared<GraphSequenceTransform>());
    (void)transformers_.emplace_back(std::make_shared<PartialSequenceArgTransform>());
    (void)transformers_.emplace_back(std::make_shared<CallSequenceArgTransform>());
  }
  ~CallGraphSequenceTransform() override = default;

  AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
    if (AlreadyHasSparseComponent(node)) {
      return nullptr;
    }
    for (auto &transform : transformers_) {
      auto new_node = (*transform)(optimizer, node);
      if (new_node != nullptr) {
        return new_node;
      }
    }
    return nullptr;
  }

 private:
  bool has_sparse_tensor_ = false;
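  // Returns true if a sparse tensor component has been seen, checking the
  // inputs of call/partial nodes and the parameters of called func graphs.
  // Once found, the result is cached in has_sparse_tensor_ so that every
  // subsequent node is skipped as well.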
  bool AlreadyHasSparseComponent(const AnfNodePtr &node) {
    if (has_sparse_tensor_) {
      return true;
    }
    if (IsFuncGraphCallNode(node) || IsPrimitiveCNode(node, prim::kPrimPartial)) {
      auto call_node = node->cast<CNodePtr>();
      const auto &call_inputs = call_node->inputs();
      for (const auto &input_node : call_inputs) {
        auto abs = input_node->abstract();
        // If an input is a SparseTensor, Tuple(SparseTensor, ...) or Tuple(..., (..., SparseTensor)),
        // record it and return true so that this pass is skipped.
        if (abs != nullptr && ContainSparseTensor(abs)) {
          has_sparse_tensor_ = true;
          return true;
        }
      }
    } else if (IsValueNode<FuncGraph>(node)) {
      auto fg = GetValueNode<FuncGraphPtr>(node);
      if (std::any_of(fg->parameters().cbegin(), fg->parameters().cend(), ParamContainSparseTensor)) {
        has_sparse_tensor_ = true;
        return true;
      }
    }
    return false;
  }
  std::vector<OptimizerCallerPtr> transformers_{};
};
}  // namespace irpass
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_