1 /**
2 * Copyright 2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "backend/common/graph_kernel/floatstatus_addn_fusion.h"
17 #include <algorithm>
18 #include <memory>
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 #include "backend/common/graph_kernel/adapter/expander.h"
23 #include "backend/common/graph_kernel/core/graph_builder.h"
24 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
25 #include "backend/common/graph_kernel/graph_kernel_flags.h"
26 #include "backend/common/graph_kernel/graph_kernel_helper.h"
27 #include "include/backend/anf_runtime_algorithm.h"
28 #include "include/common/utils/anfalgo.h"
29 #include "mindspore/core/ops/array_ops.h"
30 #include "mindspore/core/ops/framework_ops.h"
31 #include "mindspore/core/ops/math_ops.h"
32 #include "mindspore/core/ops/sequence_ops.h"
33
34 namespace mindspore::graphkernel {
35 namespace {
36 constexpr auto kNameAddN = "AddN";
37 constexpr auto kNameFloatStatus = "FloatStatus";
38
CanConvert()39 bool CanConvert() {
40 const auto &flags = GraphKernelFlags::GetInstance();
41 if (!flags.enable_expand_ops_only.empty()) {
42 std::unordered_set<std::string> all_ops(flags.enable_expand_ops_only.begin(), flags.enable_expand_ops_only.end());
43 return all_ops.find(kNameAddN) != all_ops.end() && all_ops.find(kNameFloatStatus) != all_ops.end();
44 }
45 if (!flags.disable_expand_ops.empty()) {
46 auto find_target = std::find_if(flags.disable_expand_ops.begin(), flags.disable_expand_ops.end(),
47 [](const std::string &op) { return op == kNameAddN || op == kNameFloatStatus; });
48 return find_target == flags.disable_expand_ops.end();
49 }
50 return true;
51 }
52
SubGraphSignleOutput(const AnfNodePtr & anf_node)53 InplaceAssignerInfo SubGraphSignleOutput(const AnfNodePtr &anf_node) {
54 InplaceAssignerInfo new_op_info;
55 auto sub_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(anf_node);
56 auto output = sub_graph->output();
57 new_op_info.op_node = output->cast<CNodePtr>();
58 return new_op_info;
59 }
60 } // namespace
61
// Rewrites `AddN(FloatStatus(x1), ..., FloatStatus(xn))` into a set of expanded
// FloatStatus sub graphs that all inplace-assign into one shared broadcast buffer,
// then replaces the AddN with a Depend(broadcast, MakeTuple(inputs)) so the buffer
// is read only after every FloatStatus has written into it.
// NOTE(review): assumes the caller has verified every AddN input is a FloatStatus
// CNode (see Run's pattern match) — TODO confirm no other call sites exist.
void FloatStatusAddNFusion::ProcessFloatStatusAddN(const FuncGraphPtr &main_graph, const CNodePtr &addn,
                                                   const FuncGraphManagerPtr &mng) {
  mindspore::HashSet<AnfNodePtr> visited_nodes;
  std::unordered_set<size_t> input_not_convert;

  // A node feeding AddN more than once must be converted only for its first
  // occurrence; later duplicate positions are recorded and skipped below.
  for (size_t i = 1; i < addn->size(); i++) {
    if (visited_nodes.find(addn->input(i)) != visited_nodes.end()) {
      (void)input_not_convert.insert(i);
      continue;
    }
    (void)visited_nodes.insert(addn->input(i));
  }

  // Expand each (non-duplicate) FloatStatus input into its own graph-kernel sub graph
  // and swap it into the main graph.
  for (size_t i = 1; i < addn->size(); i++) {
    if (input_not_convert.count(i) > 0) {
      continue;
    }
    auto floatstatus = addn->input(i)->cast<CNodePtr>();
    auto expand_fg = GetCNodeFuncGraph(graphkernel::GetExpander(floatstatus, false)->Run(floatstatus));
    MS_EXCEPTION_IF_NULL(expand_fg);
    expand_fg->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(AnfUtils::GetCNodeName(floatstatus)));
    // Drop the primitive (input 0); the rest become the fused node's real inputs.
    std::vector<AnfNodePtr> inputs(floatstatus->inputs().begin() + 1, floatstatus->inputs().end());
    auto graph_kernel_node = CreateNewFuseCNode(main_graph, expand_fg, inputs);
    (void)mng->Replace(floatstatus, graph_kernel_node);
  }

  // Create the shared broadcast (zero-initialized) buffer node, typed after the
  // first converted sub graph's output element type.
  InplaceAssignerInfo op_info = SubGraphSignleOutput(addn->input(1));
  auto out_type = GetType(op_info.op_node)->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(out_type);
  auto broadcast_to_node = CreateCleanCompositeNode(op_info, main_graph, out_type->element()->type_id());

  // Insert extra input(broadcast node output) to composite node, and make elemany inplace-assign to it.
  for (size_t i = 1; i < addn->size(); i++) {
    if (input_not_convert.count(i) > 0) {
      continue;
    }
    op_info = SubGraphSignleOutput(addn->input(i));
    ProcessOriginCNode(addn->input(i), {{op_info, broadcast_to_node}});
  }

  // Insert MakeTuple gathering all (now converted) AddN inputs, so the Depend below
  // keeps them alive and ordered before the broadcast buffer is consumed.
  AnfNodePtrList maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
  (void)maketuple_inputs.insert(maketuple_inputs.end(), addn->inputs().begin() + 1, addn->inputs().end());
  AbstractBasePtrList out_abs_list;
  (void)std::transform(addn->inputs().begin() + 1, addn->inputs().end(), std::back_inserter(out_abs_list),
                       [](const AnfNodePtr &node) { return node->abstract(); });
  auto maketuple_node = main_graph->NewCNode(maketuple_inputs);
  maketuple_node->set_abstract(std::make_shared<abstract::AbstractTuple>(out_abs_list));
  main_graph->AddNode(maketuple_node);

  // Insert Depend: value is the broadcast buffer, ordering edge is the tuple of writers.
  AnfNodePtrList depend_inputs = {NewValueNode(prim::kPrimDepend), broadcast_to_node, maketuple_node};
  auto depend_node = main_graph->NewCNode(depend_inputs);
  depend_node->set_abstract(broadcast_to_node->abstract());
  main_graph->AddNode(depend_node);

  // Remove AddN: every former consumer now reads the accumulated buffer via Depend.
  (void)mng->Replace(addn, depend_node);
}
123
Run(const FuncGraphPtr & func_graph)124 bool FloatStatusAddNFusion::Run(const FuncGraphPtr &func_graph) {
125 auto mng = func_graph->manager();
126 MS_EXCEPTION_IF_NULL(mng);
127 auto changed = false;
128 if (!CanConvert()) {
129 return changed;
130 }
131 auto nodes = TopoSort(func_graph->get_return());
132 for (auto node : nodes) {
133 if (!IsPrimitiveCNode(node, prim::kPrimAddN) || common::AnfAlgo::IsDynamicShape(node)) {
134 continue;
135 }
136 auto cnode = node->cast<CNodePtr>();
137 MS_EXCEPTION_IF_NULL(cnode);
138 bool pattern_match =
139 std::all_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &anf_node) {
140 return IsPrimitiveCNode(anf_node, prim::kPrimFloatStatus) &&
141 (!common::AnfAlgo::IsDynamicShape(anf_node) ||
142 GraphKernelFlags::GetInstance().kernel_generator == "DVM");
143 });
144 if (!pattern_match) {
145 continue;
146 }
147 ProcessFloatStatusAddN(func_graph, cnode, mng);
148 changed = true;
149 }
150
151 if (changed) {
152 GkUtils::UpdateFuncGraphManager(mng, func_graph);
153 }
154
155 return changed;
156 }
157 } // namespace mindspore::graphkernel
158