• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 #include <set>
25 
26 #include "utils/hash_map.h"
27 #include "mindspore/core/ops/sequence_ops.h"
28 #include "mindspore/core/ops/framework_ops.h"
29 #include "ir/func_graph_cloner.h"
30 #include "frontend/optimizer/irpass.h"
31 #include "frontend/optimizer/optimizer.h"
32 #include "frontend/optimizer/anf_visitor.h"
33 #include "frontend/operator/ops.h"
34 
35 namespace mindspore {
36 namespace opt {
37 namespace irpass {
// Input-count threshold for a call CNode that passes arguments: input 0 is the callee, so a value of 2 means
// "callee plus at least one argument". Not referenced by the code visible in this header.
constexpr auto kMinInputSizeOfCallWithArgs = 2;
39 // {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
40 class PartialEliminater : public AnfVisitor {
41  public:
operator()42   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
43     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
44       return nullptr;
45     }
46     X_ = nullptr;
47     Xs_.clear();
48     auto &inputs = node->cast<CNodePtr>()->inputs();
49     Visit(inputs[0]);
50 
51     if (Xs_.size() == 0) {
52       return nullptr;
53     }
54 
55     // {X, Xs, Ys}
56     std::vector<AnfNodePtr> args{};
57     const auto xs_size = Xs_.size();
58     // Xs_ don't have monad or Ys_ is 0.
59     if (!HasAbstractMonad(Xs_.back()) || inputs.empty()) {
60       args.push_back(X_);
61       (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
62       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
63       TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
64       auto new_node = node->func_graph()->NewCNode(args);
65       new_node->set_abstract(node->abstract());
66       return new_node;
67     }
68     // {X, Ys, Xs} if Xs has monad
69     if (!IsValueNode<FuncGraph>(X_)) {
70       constexpr auto recursive_level = 2;
71       MS_LOG(INTERNAL_EXCEPTION) << "Not support yet as X_ is not a funcgraph. node: "
72                                  << node->DebugString(recursive_level);
73     }
74     auto fg = GetValueNode<FuncGraphPtr>(X_);
75     MS_EXCEPTION_IF_NULL(fg);
76     if (fg->func_graph_cnodes_index().size() != 1) {
77       // If a graph is used by 2 or more partial nodes at the same time, clone the graph.
78       auto new_fg = BasicClone(fg);
79       auto new_fg_node = NewValueNode(new_fg);
80       fg->manager()->Replace(X_, new_fg_node);
81       fg = new_fg;
82       X_ = new_fg_node;
83     }
84     args.push_back(X_);
85     // Ys first;
86     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
87     (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
88     TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
89     auto new_node = node->func_graph()->NewCNode(args);
90     new_node->set_abstract(node->abstract());
91 
92     // reorder the formal parameter of fg.
93     AnfNodePtrList new_params;
94     (void)std::copy(fg->parameters().cbegin() + SizeToLong(xs_size), fg->parameters().cend(),
95                     std::back_inserter(new_params));
96     (void)std::copy(fg->parameters().cbegin(), fg->parameters().cbegin() + SizeToLong(xs_size),
97                     std::back_inserter(new_params));
98     fg->manager()->SetParameters(fg, new_params);
99     return new_node;
100   }
101 
Visit(const AnfNodePtr & node)102   void Visit(const AnfNodePtr &node) override {
103     if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
104       return;
105     }
106 
107     auto &inputs = node->cast<CNodePtr>()->inputs();
108     // {prim::kPrimPartial, X, Xs}
109     if (inputs.size() <= 1) {
110       return;
111     }
112 
113     X_ = inputs[1];
114     // fill Xs
115     // {Partial, Function, Args....}
116     constexpr auto args_index = 2;
117     (void)std::copy(inputs.begin() + args_index, inputs.end(), std::back_inserter(Xs_));
118   }
119 
120  private:
121   AnfNodePtr X_{nullptr};
122   std::vector<AnfNodePtr> Xs_{};
123 };
124 
125 class ChoicePartialEliminater : public AnfVisitor {
126  public:
127   virtual ~ChoicePartialEliminater() = default;
128 
Visit(const AnfNodePtr & node)129   void Visit(const AnfNodePtr &node) override {
130     if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
131       if (IsValueNode<FuncGraph>(node)) {
132         fg_list_.push_back(node);
133         (void)args_list_.emplace_back(AnfNodePtrList{});
134       }
135       return;
136     }
137 
138     auto &inputs = node->cast<CNodePtr>()->inputs();
139     // {prim::kPrimPartial, G}
140     if (inputs.size() < kPartialMinInputSize) {
141       MS_LOG(INTERNAL_EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
142     }
143     if (IsValueNode<FuncGraph>(inputs[1])) {
144       fg_list_.push_back(inputs[1]);
145       AnfNodePtrList args;
146       // {Partial, Function, Args....}
147       constexpr auto args_index = 2;
148       (void)std::copy(inputs.begin() + args_index, inputs.end(), std::back_inserter(args));
149       args_list_.push_back(args);
150     }
151     return;
152   }
153 
154  protected:
155   AnfNodePtrList fg_list_{};
156   std::vector<AnfNodePtrList> args_list_{};
157 
158   // return value: true -- continue replace; false -- return nullptr;
CheckFuncGraphAndArgs()159   bool CheckFuncGraphAndArgs() {
160     // Either one should be {Partial, G, X}
161     auto has_partial_args =
162       std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; });
163     if (!has_partial_args) {
164       return false;
165     }
166 
167     // check funcgraph should be used once only.
168     for (size_t i = 0; i < fg_list_.size(); i++) {
169       auto fg_node = fg_list_[i];
170       auto fg = GetValueNode<FuncGraphPtr>(fg_node);
171       MS_EXCEPTION_IF_NULL(fg);
172       if (fg->func_graph_cnodes_index().size() != 1) {
173         // If a graph is used by 2 or more partial nodes at the same time, clone the graph.
174         // BasicClone should be replaced by TransformableClone to avoid recursive.
175         auto new_fg = TransformableClone(fg);
176         auto manager = fg->manager();
177         MS_EXCEPTION_IF_NULL(manager);
178         manager->AddFuncGraph(new_fg);
179         fg_list_[i] = NewValueNode(new_fg);
180       }
181     }
182     return true;
183   }
184 
185   // Merge partial's args and call's args
186   // branch1: {{primPartial, Xs}, Zs} -> {{primPartial, Xs, Zs}}
187   // branch2: {{primPartial, Ys}, Zs} -> {{primPartial, Ys, Zs}}
MergeArgs(const CNodePtr & call_node)188   void MergeArgs(const CNodePtr &call_node) {
189     for (auto &args : args_list_) {
190       (void)args.insert(args.end(), call_node->inputs().begin() + 1, call_node->inputs().end());
191     }
192   }
193 
194   // f(x1, x2, x3, z1, z2 ,monad1)
195   // g(x4, x2, z1, z2, monad2)
196   // h(x5, x2, x7, x8, z1, z2, monad3)
197   // --> union_args = (x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3)
198   // h(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3)
199   // f(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3)
200   // g(x1, x2, x3, z1, z2, x4, x5, x7 ,x8, monad1, monad2, monad3)
UnifyParameters(const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> args_list)201   static AnfNodePtrList UnifyParameters(const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> args_list) {
202     if (fg_list.empty()) {
203       return {};
204     }
205     auto first_func_graph = GetValueNode<FuncGraphPtr>(fg_list[0]);
206     MS_EXCEPTION_IF_NULL(first_func_graph);
207     const auto manager = first_func_graph->manager();
208     MS_EXCEPTION_IF_NULL(manager);
209     auto txn = manager->Transact();
210     // Get all new args, new args is the union set of old args.
211     auto new_args = ArgsUnion(args_list);
212     auto old_args_index_map = GenOldArgsIndexes(fg_list, args_list);
213     for (size_t branch_index = 0; branch_index < fg_list.size(); ++branch_index) {
214       auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[branch_index]);
215       MS_EXCEPTION_IF_NULL(func_graph);
216       auto new_parameters = GetFuncGraphNewParameters(func_graph, new_args, old_args_index_map);
217       txn.SetParameters(func_graph, new_parameters);
218     }
219     txn.Commit();
220     return new_args;
221   }
222 
223  private:
ArgsUnion(const std::vector<AnfNodePtrList> args_list)224   static std::vector<AnfNodePtr> ArgsUnion(const std::vector<AnfNodePtrList> args_list) {
225     std::vector<AnfNodePtr> no_monad_args;
226     std::vector<AnfNodePtr> monad_args;
227     for (const auto &args : args_list) {
228       for (const auto &arg : args) {
229         if (HasAbstractMonad(arg)) {
230           if (count(monad_args.begin(), monad_args.end(), arg) == 0) {
231             monad_args.push_back(arg);
232           }
233           continue;
234         }
235         if (count(no_monad_args.begin(), no_monad_args.end(), arg) == 0) {
236           no_monad_args.push_back(arg);
237         }
238       }
239     }
240     // Keep monad args after no monad args.
241     (void)no_monad_args.insert(no_monad_args.end(), monad_args.begin(), monad_args.end());
242     return no_monad_args;
243   }
244 
GenOldArgsIndexes(const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> & args_list)245   static HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> GenOldArgsIndexes(
246     const AnfNodePtrList &fg_list, const std::vector<AnfNodePtrList> &args_list) {
247     HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> old_args_indexes;
248     for (size_t i = 0; i < fg_list.size(); ++i) {
249       const auto func_graph = GetValueNode<FuncGraphPtr>(fg_list[i]);
250       MS_EXCEPTION_IF_NULL(func_graph);
251       const auto &args = args_list[i];
252       HashMap<AnfNodePtr, size_t> args_indexes;
253       size_t arg_index = 0;
254       for (const auto &arg : args) {
255         (void)args_indexes.emplace(arg, arg_index++);
256       }
257       old_args_indexes[func_graph] = args_indexes;
258     }
259     return old_args_indexes;
260   }
261 
GetParameterByArg(const HashMap<FuncGraphPtr,HashMap<AnfNodePtr,size_t>> & all_old_args_index_map,const AnfNodePtr & arg)262   static AnfNodePtr GetParameterByArg(const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map,
263                                       const AnfNodePtr &arg) {
264     MS_LOG(DEBUG) << "Get parameter by arg:" << arg->DebugString();
265     for (const auto &[fg, old_args_index] : all_old_args_index_map) {
266       auto it = old_args_index.find(arg);
267       if (it == old_args_index.end()) {
268         continue;
269       }
270       size_t arg_index = it->second;
271       if (arg_index >= fg->parameters().size()) {
272         MS_LOG(INTERNAL_EXCEPTION) << "Index:" << arg_index << " out of range:" << fg->parameters().size();
273       }
274       return fg->parameters()[arg_index];
275     }
276     MS_LOG(INTERNAL_EXCEPTION) << "Can't find parameter of arg:" << arg->DebugString();
277   }
278 
GetFuncGraphNewParameters(const FuncGraphPtr & func_graph,const std::vector<AnfNodePtr> & new_args,const HashMap<FuncGraphPtr,HashMap<AnfNodePtr,size_t>> & all_old_args_index_map)279   static std::vector<AnfNodePtr> GetFuncGraphNewParameters(
280     const FuncGraphPtr &func_graph, const std::vector<AnfNodePtr> &new_args,
281     const HashMap<FuncGraphPtr, HashMap<AnfNodePtr, size_t>> &all_old_args_index_map) {
282     MS_EXCEPTION_IF_NULL(func_graph);
283     const auto &old_parameters = func_graph->parameters();
284     std::vector<AnfNodePtr> new_parameters(new_args.size());
285     const auto &old_args_index_map = all_old_args_index_map.find(func_graph)->second;
286     for (size_t new_arg_index = 0; new_arg_index < new_args.size(); ++new_arg_index) {
287       const auto &new_arg = new_args[new_arg_index];
288       auto arg_old_index_it = old_args_index_map.find(new_arg);
289       // The new_arg is the arg of current func graph.
290       if (arg_old_index_it != old_args_index_map.end()) {
291         auto arg_old_index = arg_old_index_it->second;
292         new_parameters[new_arg_index] = old_parameters[arg_old_index];
293         MS_LOG(DEBUG) << "Find exist parameter:" << new_parameters[new_arg_index]->DebugString()
294                       << ", arg_old_index:" << arg_old_index;
295         continue;
296       }
297       // The new_arg is the arg of other func graph.
298       const auto other_fg_parameter = GetParameterByArg(all_old_args_index_map, new_arg);
299       MS_LOG(DEBUG) << "Get other fg's parameter:" << other_fg_parameter->DebugString();
300       TraceGuard guard(std::make_shared<TraceCopy>(other_fg_parameter->debug_info()));
301       ParameterPtr param = std::make_shared<Parameter>(func_graph);
302       param->set_abstract(other_fg_parameter->abstract());
303       new_parameters[new_arg_index] = param;
304     }
305     return new_parameters;
306   }
307 };
308 
309 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} ->
310 // {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs}
311 // {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union
312 // Zs}
313 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union
314 // Zs}
315 class SwitchPartialEliminater : public ChoicePartialEliminater {
316  public:
operator()317   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
318     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
319       return nullptr;
320     }
321     auto switch_call = node->cast<CNodePtr>();
322     if (!IsPrimitiveCNode(switch_call->input(0), prim::kPrimSwitch)) {
323       return nullptr;
324     }
325     auto switch_node = switch_call->input(0)->cast<CNodePtr>();
326     if (switch_node->size() != kSwitchInputSize) {
327       return nullptr;
328     }
329     fg_list_.clear();
330     args_list_.clear();
331     const auto maybe_partial_1 = switch_node->input(kSwitchTrueBranchIndex);
332     Visit(maybe_partial_1);
333     const auto maybe_partial_2 = switch_node->input(kSwitchFalseBranchIndex);
334     Visit(maybe_partial_2);
335 
336     // Either one should be {Partial, G, X}
337     if (fg_list_.size() != kSwitchBranchesNum && args_list_.size() != kSwitchBranchesNum) {
338       return nullptr;
339     }
340     if (!CheckFuncGraphAndArgs()) {
341       return nullptr;
342     }
343     MergeArgs(switch_call);
344     if (args_list_[0] == args_list_[1]) {
345       return BuildNewSwitchNode(switch_call, args_list_[0]);
346     } else {
347       const auto new_args = UnifyParameters(fg_list_, args_list_);
348       return BuildNewSwitchNode(switch_call, new_args);
349     }
350   }
351 
352  private:
BuildNewSwitchNode(const CNodePtr & switch_call,const std::vector<AnfNodePtr> & new_args)353   AnfNodePtr BuildNewSwitchNode(const CNodePtr &switch_call, const std::vector<AnfNodePtr> &new_args) {
354     auto fg = switch_call->func_graph();
355     MS_EXCEPTION_IF_NULL(fg);
356     const auto input0 = switch_call->input(0);
357     MS_EXCEPTION_IF_NULL(input0);
358     const auto switch_node = input0->cast<CNodePtr>();
359     TraceGuard guard1(std::make_shared<TraceCopy>(switch_node->debug_info()));
360     // {Switch, cond, G1, G2}
361     std::vector<AnfNodePtr> switch_inputs = {switch_node->input(0), switch_node->input(1)};
362     (void)switch_inputs.insert(switch_inputs.end(), fg_list_.begin(), fg_list_.end());
363     const auto new_switch_cnode = fg->NewCNode(std::move(switch_inputs));
364     new_switch_cnode->set_abstract(switch_node->abstract());
365     // Create switch call.
366     TraceGuard guard2(std::make_shared<TraceCopy>(switch_call->debug_info()));
367     AnfNodePtrList switch_call_inputs{new_switch_cnode};
368     (void)switch_call_inputs.insert(switch_call_inputs.end(), new_args.begin(), new_args.end());
369     const auto new_call_node = fg->NewCNode(std::move(switch_call_inputs));
370     new_call_node->set_abstract(switch_call->abstract());
371     return new_call_node;
372   }
373 };
374 
375 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
376 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}, Xs Union Ys Union Zs}
377 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
378 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs}
379 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}{}, Zs} ->
380 // {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs}
381 class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
382  public:
operator()383   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
384     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
385       return nullptr;
386     }
387     auto switch_layer_call = node->cast<CNodePtr>();
388     MS_EXCEPTION_IF_NULL(switch_layer_call);
389     // {SwitchLayer{}, Zs}
390     if (!IsPrimitiveCNode(switch_layer_call->input(0), prim::kPrimSwitchLayer)) {
391       return nullptr;
392     }
393     auto switch_layer_cnode = switch_layer_call->input(0)->cast<CNodePtr>();
394     // {SwitchLayer, cond, MakeTuple{}}
395     if (switch_layer_cnode->size() != kSwitchLayerInputSize) {
396       return nullptr;
397     }
398     if (!IsPrimitiveCNode(switch_layer_cnode->input(kSwitchLayerBranchesIndex), prim::kPrimMakeTuple)) {
399       return nullptr;
400     }
401     auto make_tuple_cnode = switch_layer_cnode->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
402     if (make_tuple_cnode->size() <= 1) {
403       return nullptr;
404     }
405 
406     fg_list_.clear();
407     args_list_.clear();
408     // Build funcgraph list and args list;
409     for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
410       Visit(make_tuple_cnode->input(i));
411     }
412 
413     if (!CheckFuncGraphAndArgs()) {
414       return nullptr;
415     }
416     MergeArgs(switch_layer_call);
417     // All have the same args;
418     auto args_equal =
419       std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; });
420     if (args_equal) {
421       return BuildNewSwitchLayerNode(switch_layer_call, args_list_[0]);
422     } else {
423       const auto new_args = UnifyParameters(fg_list_, args_list_);
424       return BuildNewSwitchLayerNode(switch_layer_call, new_args);
425     }
426   }
427 
428  private:
BuildNewSwitchLayerNode(const CNodePtr & switch_layer_call_node,const AnfNodePtrList & new_args)429   AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &switch_layer_call_node, const AnfNodePtrList &new_args) {
430     const auto switch_layer = switch_layer_call_node->input(0)->cast<CNodePtr>();
431     MS_EXCEPTION_IF_NULL(switch_layer);
432     auto make_tuple_cnode = switch_layer->input(kSwitchLayerBranchesIndex)->cast<CNodePtr>();
433     MS_EXCEPTION_IF_NULL(make_tuple_cnode);
434     // {primMakeTuple, G1, G2, ...}
435     AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
436     (void)make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
437     TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
438     auto new_make_tuple_cnode = make_tuple_cnode->func_graph()->NewCNode(std::move(make_tuple_args));
439     // {primSwitchLayer, cond, MakeTuple{}}
440     TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer->debug_info()));
441     auto new_switch_layer =
442       switch_layer->func_graph()->NewCNode({switch_layer->input(0), switch_layer->input(1), new_make_tuple_cnode});
443     // Create new switch_layer call node.
444     TraceGuard guard3(std::make_shared<TraceCopy>(switch_layer_call_node->debug_info()));
445     AnfNodePtrList switch_layer_call_inputs{new_switch_layer};
446     (void)switch_layer_call_inputs.insert(switch_layer_call_inputs.cend(), new_args.cbegin(), new_args.cend());
447     auto new_node = switch_layer_call_node->func_graph()->NewCNode(std::move(switch_layer_call_inputs));
448     new_node->set_abstract(switch_layer_call_node->abstract());
449     return new_node;
450   }
451 };
452 }  // namespace irpass
453 }  // namespace opt
454 }  // namespace mindspore
455 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
456