/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/common/graph_kernel/reorder_ops.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "backend/common/graph_kernel/core/graph_kernel_utils.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/debug/anf_ir_dump.h"
#include "include/common/utils/anfalgo.h"
#include "mindspore/core/ops/array_ops.h"
#include "mindspore/core/ops/framework_ops.h"
#include "mindspore/core/ops/math_ops.h"
#include "mindspore/core/ops/nn_optimizer_ops.h"
#include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
#include "utils/hash_set.h"
#include "utils/log_adapter.h"

namespace mindspore::graphkernel {
namespace {
bool IsTypeInsensitive(const CNodePtr &node) {
  // Nodes that will change the input data type will not be seen as type insensitive nodes.
  static mindspore::HashSet<PrimitivePtr> type_insensitive_op_list{
    prim::kPrimTransData, prim::kPrimTranspose, prim::kPrimTransposeD, prim::kPrimExpandDims, prim::kPrimReshape,
    prim::kPrimSqueeze, prim::kPrimTile, prim::kPrimNeg, prim::kPrimReLU, prim::kPrimRelu,
    prim::kPrimMaximum, prim::kPrimMinimum, prim::kPrimSelect};

  return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

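// CAST_UP is a float16 -> float32 cast, CAST_DOWN is float32 -> float16; every other cast is CAST_OTHER.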
enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
CastType GetCastType(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
    MS_LOG(EXCEPTION) << "Expect Cast node, but got " << common::AnfAlgo::GetCNodeName(node);
  }
  TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
    return CAST_UP;
  }
  if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
    return CAST_DOWN;
  }
  return CAST_OTHER;
}

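// Returns the data input indexes whose dtype drives the reorder: both operands of Maximum/Minimum, the two value
// branches of Select (its boolean condition at index 0 is skipped), and the single data input of the remaining
// type insensitive ops. Returns an empty vector for nodes that are not type insensitive.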
std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
  std::vector<size_t> op_input_indexes;
  if (node == nullptr || !IsTypeInsensitive(node)) {
    return op_input_indexes;
  }

  // Data input index starts from 0.
  if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
    op_input_indexes = {0, 1};
  } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
    op_input_indexes = {1, 2};
  } else {
    op_input_indexes = {0};
  }
  return op_input_indexes;
}

bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes,
                              const TypeId &base_type) {
  MS_EXCEPTION_IF_NULL(node);

  // node's inputs at check_indexes should be of type base_type
  for (const auto &index : check_indexes) {
    if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
      return false;
    }
  }
  return true;
}

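// Copies the attrs of orig_node onto new_node, rebuilds its abstract from the given output type, and attaches a
// kernel build info with the given input/output formats and device types. Both nodes must be the same op.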
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
  MS_EXCEPTION_IF_NULL(orig_node);
  MS_EXCEPTION_IF_NULL(new_node);

  auto node_name = common::AnfAlgo::GetCNodeName(new_node);
  auto orig_node_name = common::AnfAlgo::GetCNodeName(orig_node);
  if (orig_node_name != node_name) {
    MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
  }

  AbstractBasePtr new_abstract{nullptr};
  if (node_io_info.outputs_type.empty()) {
    MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
  }
  if (node_name == "Cast") {
    auto node_input = common::AnfAlgo::GetInputNode(new_node, 0);
    MS_EXCEPTION_IF_NULL(node_input);
    MS_EXCEPTION_IF_NULL(node_input->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              node_input->abstract()->BuildShape());
  } else {
    MS_EXCEPTION_IF_NULL(orig_node->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              orig_node->abstract()->BuildShape());
  }

  // Set abstract info
  new_node->set_abstract(new_abstract);
  // Set attrs
  common::AnfAlgo::CopyNodeAttrs(orig_node, new_node);
  // Set kernel build info
  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(node_io_info.inputs_format);
  info_builder.SetInputsDeviceType(node_io_info.inputs_type);
  info_builder.SetOutputsFormat(node_io_info.outputs_format);
  info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
  info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
  info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
  info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
  info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
}  // namespace

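// Rebuilds the input list of a type insensitive node: data inputs at `indexes` are replaced by the nodes in
// `new_input_at_indexes` (in order), all other inputs are kept unchanged.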
void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
                                              const std::vector<AnfNodePtr> &new_input_at_indexes,
                                              std::vector<AnfNodePtr> *new_inputs) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_inputs);
  if (indexes.size() != new_input_at_indexes.size()) {
    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
                      << new_input_at_indexes.size();
  }

  auto node_inputs_num = node->size();
  if (node_inputs_num == 0) {
    MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
  }

  // node's inputs at indexes change to new_input_at_indexes
  if (!new_inputs->empty()) {
    new_inputs->resize(0);
  }
  new_inputs->push_back(node->input(0));
  mindspore::HashSet<size_t> indexes_set(indexes.begin(), indexes.end());
  size_t idx = 0;
  for (size_t i = 1; i < node_inputs_num; ++i) {
    size_t data_idx = i - 1;
    if (indexes_set.find(data_idx) == indexes_set.end()) {
      new_inputs->push_back(node->input(i));
    } else {
      new_inputs->push_back(new_input_at_indexes[idx++]);
    }
  }
}

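// Fills new_inputs_info with the input formats and device types of a type insensitive node after the reorder:
// entries at `indexes` are taken from the corresponding node in `input_at_indexes` (its first input when from_input
// is true, its first output otherwise), while the remaining entries keep the node's original input info.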
void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
                                                  const std::vector<AnfNodePtr> &input_at_indexes,
                                                  NodeIOInfo *new_inputs_info, bool from_input) const {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_inputs_info);
  if (indexes.size() != input_at_indexes.size()) {
    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to input_at_indexes size "
                      << input_at_indexes.size();
  }

  auto node_inputs_num = node->size();
  if (node_inputs_num == 0) {
    MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
  }

  // node's inputs info at indexes change to input_at_indexes's input or output info
  new_inputs_info->inputs_format.resize(0);
  new_inputs_info->inputs_type.resize(0);
  mindspore::HashSet<size_t> indexes_set(indexes.begin(), indexes.end());
  size_t idx = 0;
  for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
    if (indexes_set.find(data_idx) == indexes_set.end()) {
      new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
      new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
    } else {
      if (from_input) {
        new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
        new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
      } else {
        new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
        new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
      }
      idx++;
    }
  }
}

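// Reorder pattern 2: TypeInsensitive(fp32) -> Cast(fp32->fp16) is rewritten to Cast(fp32->fp16) on each data input
// -> TypeInsensitive(fp16), so the type insensitive op computes in the smaller type. `node` is the CastDown node;
// the match requires its input to be a type insensitive node that is used only by this cast.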
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                                const CNodePtr &node) const {
  // Limitation:
  // Current cast node is CAST_DOWN.
  // Cast node will not change the input format.
  if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
      AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
    return false;
  }

  auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);

  auto node_input = common::AnfAlgo::GetInputNode(node, 0);
  auto type_insens_node = node_input->cast<CNodePtr>();
  // Limitation:
  // Find type insensitive node before cast node.
  // Type insensitive node is only used by current cast node.
  if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
      mng->node_users()[type_insens_node].size() > 1) {
    return false;
  }

  auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
  // Limitation: Type insensitive node's inputs are the large type.
  if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
    return false;
  }

  std::vector<AnfNodePtr> new_cast_nodes;
  for (const auto &index : op_input_indexes) {
    auto new_cast_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
                                               common::AnfAlgo::GetInputNode(type_insens_node, index)});
    NodeIOInfo cast_io_info;
    cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
    cast_io_info.outputs_format = cast_io_info.inputs_format;
    cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
    cast_io_info.outputs_type.push_back(small_type);
    SetNodeInfo(node, new_cast_node, cast_io_info);
    new_cast_nodes.push_back(new_cast_node);
  }

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
  NodeIOInfo type_insens_io_info;
  type_insens_io_info.outputs_format.push_back(pattern_output_format);
  type_insens_io_info.outputs_type.push_back(small_type);
  SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);

  (void)mng->Replace(node, new_type_insens_node);
  return true;
}

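// Reorder pattern 1: Cast(fp16->fp32) on every data input -> TypeInsensitive(fp32) is rewritten to
// TypeInsensitive(fp16) -> Cast(fp16->fp32), so only a single cast remains and it runs after the op. `node` is the
// type insensitive node; the match requires each of its data inputs to be a CastUp used only by this node.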
bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                              const CNodePtr &node) const {
  if (!IsTypeInsensitive(node)) {
    return false;
  }

  // Limitation:
  // Certain inputs of the type insensitive node are cast nodes.
  // Cast nodes are CAST_UP.
  // Cast nodes will not change the input format.
  // All these cast nodes are only used by the current type insensitive node.
  std::vector<AnfNodePtr> cast_nodes;
  std::vector<AnfNodePtr> cast_input_nodes;
  auto op_input_indexes = GetOpDataInputIndexes(node);
  for (const auto &index : op_input_indexes) {
    auto node_input = common::AnfAlgo::GetInputNode(node, index);
    auto cast_node = node_input->cast<CNodePtr>();
    if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
        AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
        mng->node_users()[cast_node].size() == 1) {
      cast_nodes.push_back(cast_node);
      cast_input_nodes.push_back(common::AnfAlgo::GetInputNode(cast_node, 0));
    }
  }
  if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
    return false;
  }

  auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
  auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
  auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);

  // Limitation: All these cast nodes cast same type to another type.
  if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
        return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
      })) {
    return false;
  }
  // Limitation: Type insensitive node's inputs have same data type.
  if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
    return false;
  }

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  NodeIOInfo type_insens_io_info;
  type_insens_io_info.outputs_format.push_back(pattern_output_format);
  type_insens_io_info.outputs_type.push_back(small_type);
  SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
  SetNodeInfo(node, new_type_insens_node, type_insens_io_info);

  auto new_cast_node =
    func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), new_type_insens_node});
  NodeIOInfo cast_io_info;
  cast_io_info.inputs_format.push_back(pattern_output_format);
  cast_io_info.outputs_format = cast_io_info.inputs_format;
  cast_io_info.inputs_type.push_back(small_type);
  cast_io_info.outputs_type.push_back(large_type);
  SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);

  (void)mng->Replace(node, new_cast_node);
  return true;
}

bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) const {
  // Reorder cast nodes and type insensitive nodes in a graph kernel sub-graph. This function has several
  // limitations, see the comments that start with "Limitation:" in this file.
  // Limitation: Assuming the type insensitive node will not change the type of its input nodes, otherwise it can be
  // seen as another cast node in some sense, such as the LessEqual operator, which performs on two inputs and
  // outputs a boolean result.
  auto mng = GkUtils::GetFuncGraphManager(func_graph);
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    if (IsTypeInsensitive(node)) {
      // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
      changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
    } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
      // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
      changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
    }
  }

  return changed;
}

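// Entry point of the pass: for each graph kernel node in the main graph, keep reordering casts inside its sub-graph
// until no further change is made.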
bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    if (common::AnfAlgo::IsGraphKernel(node)) {
      auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
      bool need_traverse = true;
      while (need_traverse) {
        need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
        if (need_traverse) {
          changed = true;
        }
      }
    }
  }

  return changed;
}
}  // namespace mindspore::graphkernel