• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
19 
20 #include <algorithm>
21 #include <memory>
22 #include <unordered_map>
23 #include <utility>
24 #include <vector>
25 
26 #include "frontend/optimizer/irpass.h"
27 #include "frontend/optimizer/optimizer.h"
28 #include "frontend/optimizer/anf_visitor.h"
29 #include "frontend/operator/ops.h"
30 
31 namespace mindspore {
32 namespace opt {
33 namespace irpass {
34 // {{prim::kPrimPartial, X, Xs}, Ys} -> {X, Xs, Ys} or {X, Ys, Xs}
35 class PartialEliminater : public AnfVisitor {
36  public:
operator()37   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
38     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
39       return nullptr;
40     }
41 
42     X_ = nullptr;
43     Xs_.clear();
44     auto &inputs = node->cast<CNodePtr>()->inputs();
45     Visit(inputs[0]);
46 
47     if (Xs_.size() == 0) {
48       return nullptr;
49     }
50 
51     // {X, Xs, Ys}
52     std::vector<AnfNodePtr> args{};
53     const auto &xs_size = Xs_.size();
54     // Xs_ don't have monad or Ys_ is 0.
55     if (!HasAbstractMonad(Xs_.back()) || inputs.empty()) {
56       args.push_back(X_);
57       (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
58       (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
59       TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
60       auto new_node = node->func_graph()->NewCNode(args);
61       return new_node;
62     }
63     // {X, Ys, Xs} if Xs has monad
64     if (!IsValueNode<FuncGraph>(X_)) {
65       MS_LOG(EXCEPTION) << "not support yet as X_ is not a funcgraph. node: " << node->DebugString(2);
66     }
67     auto fg = GetValueNode<FuncGraphPtr>(X_);
68     MS_EXCEPTION_IF_NULL(fg);
69     if (fg->func_graph_cnodes_index().size() != 1) {
70       // If a graph is used by 2 or more partial nodes at the same time, clone the graph.
71       auto new_fg = BasicClone(fg);
72       auto new_fg_node = NewValueNode(new_fg);
73       fg->manager()->Replace(X_, new_fg_node);
74       fg = new_fg;
75       X_ = new_fg_node;
76     }
77     args.push_back(X_);
78     // Ys first;
79     (void)std::copy(inputs.begin() + 1, inputs.end(), std::back_inserter(args));
80     (void)std::copy(Xs_.begin(), Xs_.end(), std::back_inserter(args));
81     TraceGuard guard(std::make_shared<TracePartialTransform>(node->debug_info()));
82     auto new_node = node->func_graph()->NewCNode(args);
83     new_node->set_abstract(node->abstract());
84 
85     // reorder the formal parameter of fg.
86     AnfNodePtrList new_params;
87     std::copy(fg->parameters().cbegin() + xs_size, fg->parameters().cend(), std::back_inserter(new_params));
88     std::copy(fg->parameters().cbegin(), fg->parameters().cbegin() + xs_size, std::back_inserter(new_params));
89     fg->manager()->SetParameters(fg, new_params);
90     return new_node;
91   }
92 
Visit(const AnfNodePtr & node)93   void Visit(const AnfNodePtr &node) override {
94     if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
95       return;
96     }
97 
98     auto &inputs = node->cast<CNodePtr>()->inputs();
99     // {prim::kPrimPartial, X, Xs}
100     if (inputs.size() < 2) {
101       return;
102     }
103 
104     X_ = inputs[1];
105     // fill Xs
106     (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(Xs_));
107   }
108 
109  private:
110   AnfNodePtr X_{nullptr};
111   std::vector<AnfNodePtr> Xs_{};
112 };
113 
114 class ChoicePartialEliminater : public AnfVisitor {
115  public:
116   virtual ~ChoicePartialEliminater() = default;
117 
118  protected:
119   AnfNodePtrList fg_list_{};
120   std::vector<AnfNodePtrList> args_list_{};
121 
Visit(const AnfNodePtr & node)122   void Visit(const AnfNodePtr &node) override {
123     if (!IsPrimitiveCNode(node, prim::kPrimPartial)) {
124       if (IsValueNode<FuncGraph>(node)) {
125         fg_list_.push_back(node);
126         args_list_.push_back(AnfNodePtrList{});
127       }
128       return;
129     }
130 
131     auto &inputs = node->cast<CNodePtr>()->inputs();
132     // {prim::kPrimPartial, G, Xs}
133     if (inputs.size() < 3) {
134       MS_LOG(EXCEPTION) << "Node should be Partial CNode, but: " << node->DebugString();
135       return;
136     }
137     if (IsValueNode<FuncGraph>(inputs[1])) {
138       fg_list_.push_back(inputs[1]);
139       AnfNodePtrList args;
140       (void)std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(args));
141       args_list_.push_back(args);
142     }
143     return;
144   }
145 
146   // return value: true -- continue replace; false -- return nullptr;
CheckFuncGraphAndArgs()147   bool CheckFuncGraphAndArgs() {
148     // Either one should be {Partial, G, X}
149     auto has_partial_args =
150       std::any_of(args_list_.cbegin(), args_list_.cend(), [](auto &args) { return args.size() != 0; });
151     if (!has_partial_args) {
152       return false;
153     }
154 
155     // check funcgraph should be used once only.
156     for (size_t i = 0; i < fg_list_.size(); i++) {
157       auto fg_node = fg_list_[i];
158       auto fg = GetValueNode<FuncGraphPtr>(fg_node);
159       MS_EXCEPTION_IF_NULL(fg);
160       if (fg->func_graph_cnodes_index().size() != 1) {
161         // If a graph is used by 2 or more partial nodes at the same time, clone the graph.
162         auto new_fg = BasicClone(fg);
163         auto manager = fg->manager();
164         MS_EXCEPTION_IF_NULL(manager);
165         manager->AddFuncGraph(new_fg);
166         fg_node->cast<ValueNodePtr>()->set_value(new_fg);
167       }
168     }
169     return true;
170   }
171 
172   // f(x1, x2, x3, z1, z2)
173   // g(x4, x2, z1, z2)
174   // h(x5, x2, x7, x8, z1, z2)
175   // --> anchor_fg = h
176   // h(x5, x2, x7, x8, x1, x3, x4, z1, z2)
177   // f(x5, x2, x7, x8, x1, x3, x4, z1, z2)
178   // g(x5, x2, x7, x8, x1, x3, x4, z1, z2)
179   // as z1, z2 maybe U or IO monad.
UnifyParameters(const size_t & anchor_index,const AnfNodePtrList & fg_list,const std::vector<AnfNodePtrList> args_list)180   AnfNodePtrList UnifyParameters(const size_t &anchor_index, const AnfNodePtrList &fg_list,
181                                  const std::vector<AnfNodePtrList> args_list) {
182     std::vector<size_t> inputs_index_list[args_list.size()];
183     size_t extra_input_counter = 0;
184     AnfNodePtrList extra_inputs;
185     const auto &anchor_args = args_list[anchor_index];
186     size_t anchor_args_size = anchor_args.size();
187     auto anchor_fg = GetValueNode<FuncGraphPtr>(fg_list[anchor_index]);
188     MS_EXCEPTION_IF_NULL(anchor_fg);
189     // Find the new location of the old_inputs except Zs;
190     for (size_t i = 0; i < args_list.size(); ++i) {
191       if (i == anchor_index) {
192         continue;
193       }
194       const auto &another_args = args_list[i];
195       auto &curr_inputs_index = inputs_index_list[i];
196       for (size_t j = 0; j < another_args.size(); ++j) {
197         size_t k;
198         for (k = 0; k < anchor_args_size; ++k) {
199           if (another_args[j] == anchor_args[k]) {
200             curr_inputs_index.push_back(k);
201             break;
202           }
203         }
204         if (k == anchor_args_size) {
205           // check if used by another func_graph;
206           for (k = 0; k < extra_input_counter; ++k) {
207             if (another_args[j] == extra_inputs[k]) {
208               curr_inputs_index.push_back(anchor_args_size + k);
209               break;
210             }
211           }
212           if (k == extra_input_counter) {
213             extra_inputs.push_back(another_args[j]);
214             curr_inputs_index.push_back(anchor_args_size + extra_input_counter);
215             extra_input_counter++;
216           }
217         }
218       }
219     }
220 
221     auto manager = anchor_fg->manager();
222     MS_EXCEPTION_IF_NULL(manager);
223     auto txn = manager->Transact();
224 
225     size_t anchor_params_size = anchor_fg->parameters().size();
226     const auto &anchor_fg_params = anchor_fg->parameters();
227     for (size_t i = 0; i < args_list.size(); ++i) {
228       if (i == anchor_index) {
229         continue;
230       }
231       AnfNodePtrList new_params;
232       new_params.resize(anchor_params_size + extra_input_counter);
233 
234       const auto &curr_inputs_index = inputs_index_list[i];
235       auto another_fg = GetValueNode<FuncGraphPtr>(fg_list[i]);
236       MS_EXCEPTION_IF_NULL(another_fg);
237       const auto &old_params = another_fg->parameters();
238       const auto &old_args = args_list[i];
239       for (size_t j = 0; j < old_args.size(); j++) {
240         new_params[curr_inputs_index[j]] = old_params[j];
241       }
242       // Zs_
243       for (size_t j = old_args.size(), k = 0; j < old_params.size(); ++j, ++k) {
244         new_params[anchor_args_size + extra_input_counter + k] = old_params[j];
245       }
246       // unused inputs
247       for (size_t j = 0; j < anchor_args_size; ++j) {
248         if (new_params[j] == nullptr) {
249           TraceGuard guard(std::make_shared<TraceCopy>(anchor_fg_params[j]->debug_info()));
250           ParameterPtr param = std::make_shared<Parameter>(another_fg);
251           new_params[j] = param;
252         }
253       }
254       // extra inputs used by another func_graph;
255       for (size_t j = 0; j < extra_inputs.size(); ++j) {
256         if (new_params[anchor_args_size + j] == nullptr) {
257           TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[j]->debug_info()));
258           ParameterPtr param = std::make_shared<Parameter>(another_fg);
259           new_params[anchor_args_size + j] = param;
260         }
261       }
262       // set the parameter for another_fg and replace it's parameters;
263       txn.SetParameters(another_fg, new_params);
264     }
265     // Reorder Zs_ and add extra parameters for anchor_fg;
266     // add extra parameter for anchor_fg;
267     AnfNodePtrList new_params;
268     new_params.reserve(anchor_params_size + extra_input_counter);
269     // reuse parameters for anchor_args;
270     std::copy(anchor_fg_params.cbegin(), anchor_fg_params.cbegin() + anchor_args_size, std::back_inserter(new_params));
271     // Extra parameters;
272     for (size_t i = 0; i < extra_inputs.size(); ++i) {
273       TraceGuard guard(std::make_shared<TraceCopy>(extra_inputs[i]->debug_info()));
274       ParameterPtr param = std::make_shared<Parameter>(anchor_fg);
275       new_params.push_back(param);
276     }
277     // Reorder Zs_ to last;
278     for (size_t i = anchor_args_size; i < anchor_params_size; ++i) {
279       new_params.push_back(anchor_fg_params[i]);
280     }
281     txn.SetParameters(anchor_fg, new_params);
282     txn.Commit();
283 
284     return extra_inputs;
285   }
286 };
287 
288 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}, Zs} ->
289 // {{prim::kPrimSwitch, cond, G1, G2}, Xs Union Ys Union Zs}
290 // {{prim::kPrimSwitch, cond, {G1}, {prim::kPrimPartial, G2, Ys}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Ys Union
291 // Zs}
292 // {{prim::kPrimSwitch, cond, {prim::kPrimPartial, G1, Xs}, {G2}}, Zs} -> {{prim::kPrimSwitch, cond, G1, G2}, Xs Union
293 // Zs}
294 class SwitchPartialEliminater : public ChoicePartialEliminater {
295  public:
operator()296   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
297     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
298       return nullptr;
299     }
300     auto cnode = node->cast<CNodePtr>();
301     if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitch)) {
302       return nullptr;
303     }
304     auto input0_cnode = cnode->input(0)->cast<CNodePtr>();
305     if (input0_cnode->size() != 4) {
306       return nullptr;
307     }
308 
309     fg_list_.clear();
310     args_list_.clear();
311     auto &maybe_partial_1 = input0_cnode->input(2);
312     Visit(maybe_partial_1);
313     auto &maybe_partial_2 = input0_cnode->input(3);
314     Visit(maybe_partial_2);
315 
316     // Either one should be {Partial, G, X}
317     if (fg_list_.size() != 2 && args_list_.size() != 2) {
318       return nullptr;
319     }
320     // Should not continue;
321     if (!CheckFuncGraphAndArgs()) {
322       return nullptr;
323     }
324 
325     if (args_list_[0] == args_list_[1]) {
326       auto new_node =
327         BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[0], AnfNodePtrList{});
328       return new_node;
329     } else {
330       // find partial funcgraph with the longest args as anchor;
331       size_t max_args_pos = 0;
332       if (args_list_[0].size() > args_list_[1].size()) {
333         max_args_pos = 0;
334       } else {
335         max_args_pos = 1;
336       }
337 
338       auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
339       auto new_node =
340         BuildNewSwitchNode(cnode, input0_cnode, fg_list_[0], fg_list_[1], args_list_[max_args_pos], extra_inputs);
341       return new_node;
342     }
343   }
344 
345  private:
BuildNewSwitchNode(const CNodePtr & old_cnode,const CNodePtr input0_cnode,const AnfNodePtr & G1,const AnfNodePtr & G2,const AnfNodePtrList & partial_args,const AnfNodePtrList & extra_args)346   AnfNodePtr BuildNewSwitchNode(const CNodePtr &old_cnode, const CNodePtr input0_cnode, const AnfNodePtr &G1,
347                                 const AnfNodePtr &G2, const AnfNodePtrList &partial_args,
348                                 const AnfNodePtrList &extra_args) {
349     TraceGuard guard1(std::make_shared<TraceCopy>(input0_cnode->debug_info()));
350     // {Switch, cond, G1, G2}
351     auto switch_cnode = old_cnode->func_graph()->NewCNode({input0_cnode->input(0), input0_cnode->input(1), G1, G2});
352     AnfNodePtrList args{switch_cnode};
353     (void)std::copy(partial_args.begin(), partial_args.end(), std::back_inserter(args));
354     (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
355     // Zs
356     if (old_cnode->size() >= 2) {
357       (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
358     }
359     TraceGuard guard2(std::make_shared<TraceCopy>(old_cnode->debug_info()));
360     auto new_node = old_cnode->func_graph()->NewCNode(args);
361     new_node->set_abstract(old_cnode->abstract());
362     return new_node;
363   }
364 };
365 
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{G1}, {prim::kPrimPartial, G2, Ys}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Ys Union Zs}
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{{prim::kPrimPartial, G1, Xs}, {G2}}}, Zs} ->
// {{prim::kPrimSwitchLayer, cond, prim::MakeTuple{G1, G2}}, Xs Union Zs}
372 class SwitchLayerPartialEliminater : public ChoicePartialEliminater {
373  public:
operator()374   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
375     if (!node->isa<CNode>() || node->func_graph() == nullptr) {
376       return nullptr;
377     }
378     auto cnode = node->cast<CNodePtr>();
379     // {SwitchLayer{}, Zs}
380     if (!IsPrimitiveCNode(cnode->input(0), prim::kPrimSwitchLayer)) {
381       return nullptr;
382     }
383     auto switch_layer_cnode = cnode->input(0)->cast<CNodePtr>();
384     // {SwitchLayer, cond, MakeTuple{}}
385     if (switch_layer_cnode->size() != 3) {
386       return nullptr;
387     }
388     if (!IsPrimitiveCNode(switch_layer_cnode->input(2), prim::kPrimMakeTuple)) {
389       return nullptr;
390     }
391     auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
392     if (make_tuple_cnode->size() < 2) {
393       return nullptr;
394     }
395 
396     fg_list_.clear();
397     args_list_.clear();
398     // Build funcgraph list and args list;
399     for (size_t i = 1; i < make_tuple_cnode->size(); ++i) {
400       Visit(make_tuple_cnode->input(i));
401     }
402 
403     if (!CheckFuncGraphAndArgs()) {
404       return nullptr;
405     }
406     // All have the same args;
407     auto args_equal =
408       std::all_of(args_list_.cbegin() + 1, args_list_.cend(), [this](auto &args) { return args == args_list_[0]; });
409     if (args_equal) {
410       auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[0], AnfNodePtrList{});
411       return new_node;
412     } else {
413       // find partial funcgraph with the longest args as anchor;
414       size_t max_args_pos = 0, max_args_len = 0;
415       for (size_t i = 0; i < args_list_.size(); ++i) {
416         if (max_args_len < args_list_[i].size()) {
417           max_args_len = args_list_[i].size();
418           max_args_pos = i;
419         }
420       }
421       auto extra_inputs = UnifyParameters(max_args_pos, fg_list_, args_list_);
422       auto new_node = BuildNewSwitchLayerNode(cnode, switch_layer_cnode, args_list_[max_args_pos], extra_inputs);
423       return new_node;
424     }
425   }
426 
427  private:
BuildNewSwitchLayerNode(const CNodePtr & old_cnode,const CNodePtr switch_layer_cnode,const AnfNodePtrList & anchor_partial_args,const AnfNodePtrList & extra_args)428   AnfNodePtr BuildNewSwitchLayerNode(const CNodePtr &old_cnode, const CNodePtr switch_layer_cnode,
429                                      const AnfNodePtrList &anchor_partial_args, const AnfNodePtrList &extra_args) {
430     auto make_tuple_cnode = switch_layer_cnode->input(2)->cast<CNodePtr>();
431     AnfNodePtrList make_tuple_args{make_tuple_cnode->input(0)};
432     make_tuple_args.insert(make_tuple_args.end(), fg_list_.begin(), fg_list_.end());
433     TraceGuard guard1(std::make_shared<TraceCopy>(make_tuple_cnode->debug_info()));
434     // {MakeTuple, G1, G2, ...}
435     auto new_make_tuple_cnode = old_cnode->func_graph()->NewCNode(make_tuple_args);
436 
437     TraceGuard guard2(std::make_shared<TraceCopy>(switch_layer_cnode->debug_info()));
438     // {SwitchLayer, cond, MakeTuple{}}
439     auto new_switch_layer_cnode = old_cnode->func_graph()->NewCNode(
440       {switch_layer_cnode->input(0), switch_layer_cnode->input(1), new_make_tuple_cnode});
441     AnfNodePtrList args{new_switch_layer_cnode};
442     (void)std::copy(anchor_partial_args.begin(), anchor_partial_args.end(), std::back_inserter(args));
443     (void)std::copy(extra_args.begin(), extra_args.end(), std::back_inserter(args));
444     // Zs
445     if (old_cnode->size() >= 2) {
446       (void)std::copy(old_cnode->inputs().begin() + 1, old_cnode->inputs().end(), std::back_inserter(args));
447     }
448     TraceGuard guard3(std::make_shared<TraceCopy>(old_cnode->debug_info()));
449     auto new_node = old_cnode->func_graph()->NewCNode(args);
450     new_node->set_abstract(old_cnode->abstract());
451     return new_node;
452   }
453 };
454 }  // namespace irpass
455 }  // namespace opt
456 }  // namespace mindspore
457 #endif  // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_PARTIAL_ELIMINATE_H_
458