• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "backend/common/graph_kernel/floatstatus_addn_fusion.h"
17 #include <algorithm>
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 #include "backend/common/graph_kernel/adapter/expander.h"
23 #include "backend/common/graph_kernel/core/graph_builder.h"
24 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
25 #include "backend/common/graph_kernel/graph_kernel_flags.h"
26 #include "backend/common/graph_kernel/graph_kernel_helper.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "mindspore/core/ops/math_ops.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33 
34 namespace mindspore::graphkernel {
35 namespace {
36 constexpr auto kNameAddN = "AddN";
37 constexpr auto kNameFloatStatus = "FloatStatus";
38 
CanConvert()39 bool CanConvert() {
40   const auto &flags = GraphKernelFlags::GetInstance();
41   if (!flags.enable_expand_ops_only.empty()) {
42     std::unordered_set<std::string> all_ops(flags.enable_expand_ops_only.begin(), flags.enable_expand_ops_only.end());
43     return all_ops.find(kNameAddN) != all_ops.end() && all_ops.find(kNameFloatStatus) != all_ops.end();
44   }
45   if (!flags.disable_expand_ops.empty()) {
46     auto find_target = std::find_if(flags.disable_expand_ops.begin(), flags.disable_expand_ops.end(),
47                                     [](const std::string &op) { return op == kNameAddN || op == kNameFloatStatus; });
48     return find_target == flags.disable_expand_ops.end();
49   }
50   return true;
51 }
52 
SubGraphSignleOutput(const AnfNodePtr & anf_node)53 InplaceAssignerInfo SubGraphSignleOutput(const AnfNodePtr &anf_node) {
54   InplaceAssignerInfo new_op_info;
55   auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
56   auto output = sub_graph->output();
57   new_op_info.op_node = output->cast<CNodePtr>();
58   return new_op_info;
59 }
60 }  // namespace
61 
// Fuse a pattern of AddN whose inputs are all FloatStatus nodes:
// each FloatStatus is expanded into a graph-kernel subgraph that inplace-assigns
// its result into one shared broadcast buffer, and the AddN itself is replaced by
// a Depend(broadcast, MakeTuple(subgraphs)) so consumers read the accumulated buffer.
void FloatStatusAddNFusion::ProcessFloatStatusAddN(const FuncGraphPtr &main_graph, const CNodePtr &addn,
                                                   const FuncGraphManagerPtr &mng) {
  mindspore::HashSet<AnfNodePtr> visited_nodes;
  std::unordered_set<size_t> input_not_convert;

  // An input node that appears more than once among AddN's inputs is only
  // converted for its first occurrence; later duplicates are skipped.
  for (size_t i = 1; i < addn->size(); i++) {
    if (visited_nodes.find(addn->input(i)) != visited_nodes.end()) {
      (void)input_not_convert.insert(i);
      continue;
    }
    (void)visited_nodes.insert(addn->input(i));
  }

  // Expand floatstatus to subgraph: each FloatStatus input is expanded by the
  // expander, wrapped as a graph-kernel composite node, and substituted in place.
  for (size_t i = 1; i < addn->size(); i++) {
    if (input_not_convert.count(i) > 0) {
      continue;
    }
    auto floatstatus = addn->input(i)->cast<CNodePtr>();
    auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(floatstatus, false)->Run(floatstatus));
    MS_EXCEPTION_IF_NULL(expand_fg);
    // Mark the expanded graph as a graph kernel, named after the original op.
    expand_fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfUtils::GetCNodeName(floatstatus)));
    std::vector<AnfNodePtr> inputs(floatstatus->inputs().begin() + 1, floatstatus->inputs().end());
    auto graph_kernel_node = CreateNewFuseCNode(main_graph, expand_fg, inputs);
    (void)mng->Replace(floatstatus, graph_kernel_node);
  }

  // Create broadcast node. Its element type is taken from the first converted
  // subgraph's output (all inputs matched the same FloatStatus pattern).
  InplaceAssignerInfo op_info = SubGraphSignleOutput(addn->input(1));
  auto out_type = GetType(op_info.op_node)->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(out_type);
  auto broadcast_to_node = CreateCleanCompositeNode(op_info, main_graph, out_type->element()->type_id());

  // Insert extra input(broadcast node output) to composite node, and make elemany inplace-assign to it.
  for (size_t i = 1; i < addn->size(); i++) {
    if (input_not_convert.count(i) > 0) {
      continue;
    }
    op_info = SubGraphSignleOutput(addn->input(i));
    ProcessOriginCNode(addn->input(i), {{op_info, broadcast_to_node}});
  }

  // Insert MakeTuple gathering all (converted) AddN inputs, so they stay alive
  // and ordered before the broadcast buffer is read.
  AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  (void)maketuple_inputs.insert(maketuple_inputs.end(), addn->inputs().begin() + 1, addn->inputs().end());
  AbstractBasePtrList out_abs_list;
  (void)std::transform(addn->inputs().begin() + 1, addn->inputs().end(), std::back_inserter(out_abs_list),
                       [](const AnfNodePtr &node) { return node->abstract(); });
  auto maketuple_node = main_graph->NewCNode(maketuple_inputs);
  maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
  main_graph->AddNode(maketuple_node);

  // Insert Depend: value comes from the broadcast buffer, control-depends on the
  // MakeTuple so every inplace-assign finishes first.
  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), broadcast_to_node, maketuple_node};
  auto depend_node = main_graph->NewCNode(depend_inputs);
  depend_node->set_abstract(broadcast_to_node->abstract());
  main_graph->AddNode(depend_node);

  // Remove AddN: all users of AddN now read the depend (i.e. the broadcast buffer).
  (void)mng->Replace(addn, depend_node);
}
123 
Run(const FuncGraphPtr & func_graph)124 bool FloatStatusAddNFusion::Run(const FuncGraphPtr &func_graph) {
125   auto mng = func_graph->manager();
126   MS_EXCEPTION_IF_NULL(mng);
127   auto changed = false;
128   if (!CanConvert()) {
129     return changed;
130   }
131   auto nodes = TopoSort(func_graph->get_return());
132   for (auto node : nodes) {
133     if (!IsPrimitiveCNode(node, prim::kPrimAddN) || common::AnfAlgo::IsDynamicShape(node)) {
134       continue;
135     }
136     auto cnode = node->cast<CNodePtr>();
137     MS_EXCEPTION_IF_NULL(cnode);
138     bool pattern_match =
139       std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &anf_node) {
140         return IsPrimitiveCNode(anf_node, prim::kPrimFloatStatus) &&
141                (!common::AnfAlgo::IsDynamicShape(anf_node) ||
142                 GraphKernelFlags::GetInstance().kernel_generator == "DVM");
143       });
144     if (!pattern_match) {
145       continue;
146     }
147     ProcessFloatStatusAddN(func_graph, cnode, mng);
148     changed = true;
149   }
150 
151   if (changed) {
152     GkUtils::UpdateFuncGraphManager(mng, func_graph);
153   }
154 
155   return changed;
156 }
157 }  // namespace mindspore::graphkernel
158