1 /**
2 * Copyright 2020-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"
28
29 namespace mindspore {
30 namespace opt {
31 namespace {
IsTypeInsensitive(const CNodePtr & node)32 bool IsTypeInsensitive(const CNodePtr &node) {
33 // Nodes that will change the input data type will not seen as type insensitive nodes.
34 static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
35 prim::kPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
36 prim::kPrimSqueeze, prim::kPrimTile, prim::kPrimNeg, prim::kPrimRelu,
37 prim::kPrimMaximum, prim::kPrimMinimum, prim::kPrimSelect};
38
39 return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
40 [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
41 }
42
43 enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
GetCastType(const CNodePtr & node)44 CastType GetCastType(const CNodePtr &node) {
45 MS_EXCEPTION_IF_NULL(node);
46 if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
47 MS_LOG(EXCEPTION) << "Only process for Cast!";
48 }
49 TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
50 TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
51 if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
52 return CAST_UP;
53 }
54 if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
55 return CAST_DOWN;
56 }
57 return CAST_OTHER;
58 }
59
GetOpDataInputIndexes(const CNodePtr & node)60 std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
61 std::vector<size_t> op_input_indexes;
62 if (node == nullptr || !IsTypeInsensitive(node)) {
63 return op_input_indexes;
64 }
65
66 // Data input index starts from 0.
67 if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
68 op_input_indexes = {0, 1};
69 } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
70 op_input_indexes = {1, 2};
71 } else {
72 op_input_indexes = {0};
73 }
74 return op_input_indexes;
75 }
76
CheckInputTypeConsistent(const CNodePtr & node,const std::vector<size_t> & check_indexes,const TypeId & base_type)77 bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
78 MS_EXCEPTION_IF_NULL(node);
79
80 // node's inputs at check_indexes should be of type base_type
81 for (const auto &index : check_indexes) {
82 if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
83 return false;
84 }
85 }
86 return true;
87 }
88
// Populate `new_node` (a rebuilt copy of `orig_node` with possibly different
// I/O formats/types) with its abstract, attributes and kernel build info.
// Both nodes must wrap the same primitive; `node_io_info` supplies the formats
// and device types used for the new kernel build info. Raises on mismatched
// node names or an empty outputs_type.
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
  MS_EXCEPTION_IF_NULL(orig_node);
  MS_EXCEPTION_IF_NULL(new_node);

  auto node_name = AnfAlgo::GetCNodeName(new_node);
  auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
  if (orig_node_name != node_name) {
    MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
  }

  AbstractBasePtr new_abstract{nullptr};
  if (node_io_info.outputs_type.empty()) {
    MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
  }
  if (node_name == "Cast") {
    // For a Cast node the output shape follows its (possibly new) input, while
    // the element type comes from node_io_info.
    auto node_input = AnfAlgo::GetInputNode(new_node, 0);
    MS_EXCEPTION_IF_NULL(node_input);
    MS_EXCEPTION_IF_NULL(node_input->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              node_input->abstract()->BuildShape());
  } else {
    // Other (type insensitive) nodes keep the original node's output shape.
    MS_EXCEPTION_IF_NULL(orig_node->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              orig_node->abstract()->BuildShape());
  }

  // Set abstract info
  new_node->set_abstract(new_abstract);
  // Set attrs
  AnfAlgo::CopyNodeAttrs(orig_node, new_node);
  // Set kernel build info: formats/types come from node_io_info, everything
  // else (kernel type, op pattern, fusion type, processor) is inherited from
  // the original node.
  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(node_io_info.inputs_format);
  info_builder.SetInputsDeviceType(node_io_info.inputs_type);
  info_builder.SetOutputsFormat(node_io_info.outputs_format);
  info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
  info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
  info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
  info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
  info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
132 } // namespace
133
SetTypeInsensitiveNodeInputs(const CNodePtr & node,const std::vector<size_t> & indexes,const std::vector<AnfNodePtr> & new_input_at_indexes,std::vector<AnfNodePtr> * new_inputs)134 void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
135 const std::vector<AnfNodePtr> &new_input_at_indexes,
136 std::vector<AnfNodePtr> *new_inputs) {
137 MS_EXCEPTION_IF_NULL(node);
138 MS_EXCEPTION_IF_NULL(new_inputs);
139 if (indexes.size() != new_input_at_indexes.size()) {
140 MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
141 << new_input_at_indexes.size();
142 }
143
144 auto node_inputs_num = node->size();
145 if (node_inputs_num == 0) {
146 MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
147 }
148
149 // node's inputs at indexes change to new_input_at_indexes
150 if (!new_inputs->empty()) {
151 new_inputs->resize(0);
152 }
153 new_inputs->push_back(node->input(0));
154 std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
155 size_t idx = 0;
156 for (size_t i = 1; i < node_inputs_num; ++i) {
157 size_t data_idx = i - 1;
158 if (indexes_set.find(data_idx) == indexes_set.end()) {
159 new_inputs->push_back(node->input(i));
160 } else {
161 new_inputs->push_back(new_input_at_indexes[idx++]);
162 }
163 }
164 }
165
SetTypeInsensitiveNodeInputsInfo(const CNodePtr & node,const std::vector<size_t> & indexes,const std::vector<AnfNodePtr> & input_at_indexes,NodeIOInfo * new_inputs_info,bool from_input)166 void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
167 const std::vector<AnfNodePtr> &input_at_indexes,
168 NodeIOInfo *new_inputs_info, bool from_input) {
169 MS_EXCEPTION_IF_NULL(node);
170 MS_EXCEPTION_IF_NULL(new_inputs_info);
171 if (indexes.size() != input_at_indexes.size()) {
172 MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
173 << input_at_indexes.size();
174 }
175
176 auto node_inputs_num = node->size();
177 if (node_inputs_num == 0) {
178 MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
179 }
180
181 // node's inputs info at indexes change to input_at_indexes's input or output info
182 new_inputs_info->inputs_format.resize(0);
183 new_inputs_info->inputs_type.resize(0);
184 std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
185 size_t idx = 0;
186 for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
187 if (indexes_set.find(data_idx) == indexes_set.end()) {
188 new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
189 new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
190 } else {
191 if (from_input) {
192 new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
193 new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
194 } else {
195 new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
196 new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
197 }
198 idx++;
199 }
200 }
201 }
202
// Reorder pattern: TypeInsensitive-CastDown --> CastDown-TypeInsensitive.
// `node` is the CastDown node; its single input must be a type insensitive
// node that is used only by this cast. New CastDown nodes are inserted on the
// type insensitive node's data inputs and a rebuilt type insensitive node
// (operating on the small type) replaces the original pair.
// Returns true if the graph was changed.
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                                const CNodePtr &node) {
  // Limitation:
  // Current cast node is CAST_DOWN.
  // Cast node will not change the input format.
  if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
      AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
    return false;
  }

  auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);

  auto node_input = AnfAlgo::GetInputNode(node, 0);
  auto type_insens_node = node_input->cast<CNodePtr>();
  // Limitation:
  // Find type insensitive node before cast node.
  // Type insensitive node is only used by current cast node.
  if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
      mng->node_users()[type_insens_node].size() > 1) {
    return false;
  }

  auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
  // Limitation: Type insensitive node's inputs are the large type.
  if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
    return false;
  }

  // Create one CastDown (large -> small) per data input of the type
  // insensitive node, keeping each input's original format.
  std::vector<AnfNodePtr> new_cast_nodes;
  for (const auto &index : op_input_indexes) {
    auto new_cast_node =
      func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
    NodeIOInfo cast_io_info;
    cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
    cast_io_info.outputs_format = cast_io_info.inputs_format;
    cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
    cast_io_info.outputs_type.push_back(small_type);
    SetNodeInfo(node, new_cast_node, cast_io_info);
    new_cast_nodes.push_back(new_cast_node);
  }

  // Rebuild the type insensitive node on the small type, fed by the new casts,
  // and let it replace the original cast node in the graph.
  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
  NodeIOInfo type_insens_io_info;
  type_insens_io_info.outputs_format.push_back(pattern_output_format);
  type_insens_io_info.outputs_type.push_back(small_type);
  SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);

  (void)mng->Replace(node, new_type_insens_node);
  return true;
}
258
ReorderCastUpTypeInsensitive(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng,const CNodePtr & node)259 bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
260 const CNodePtr &node) {
261 if (!IsTypeInsensitive(node)) {
262 return false;
263 }
264
265 // Limitation:
266 // Certain inputs of type insensitive node are cast node.
267 // Cast nodes are CAST_UP.
268 // Cast nodes will not change the input format.
269 // All these cast nodes are only used by current type insensitive node.
270 std::vector<AnfNodePtr> cast_nodes;
271 std::vector<AnfNodePtr> cast_input_nodes;
272 auto op_input_indexes = GetOpDataInputIndexes(node);
273 for (const auto &index : op_input_indexes) {
274 auto node_input = AnfAlgo::GetInputNode(node, index);
275 auto cast_node = node_input->cast<CNodePtr>();
276 if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
277 AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
278 mng->node_users()[cast_node].size() == 1) {
279 cast_nodes.push_back(cast_node);
280 cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
281 }
282 }
283 if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
284 return false;
285 }
286
287 auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
288 auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
289 auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
290
291 // Limitation: All these cast nodes cast same type to another type.
292 if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
293 return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
294 })) {
295 return false;
296 }
297 // Limitation: Type insensitive node's inputs have same data type.
298 if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
299 return false;
300 }
301
302 std::vector<AnfNodePtr> type_insens_node_new_inputs;
303 SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
304 auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
305 NodeIOInfo type_insens_io_info;
306 type_insens_io_info.outputs_format.push_back(pattern_output_format);
307 type_insens_io_info.outputs_type.push_back(small_type);
308 SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
309 SetNodeInfo(node, new_type_insens_node, type_insens_io_info);
310
311 auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
312 NodeIOInfo cast_io_info;
313 cast_io_info.inputs_format.push_back(pattern_output_format);
314 cast_io_info.outputs_format = cast_io_info.inputs_format;
315 cast_io_info.inputs_type.push_back(small_type);
316 cast_io_info.outputs_type.push_back(large_type);
317 SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);
318
319 (void)mng->Replace(node, new_cast_node);
320 return true;
321 }
322
ReorderCastTypeInsensitive(const FuncGraphPtr & func_graph)323 bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
324 // Reorder cast node and type insensitive node in graph kernel sub-graph, this function has several limitations,
325 // see the comments that start will "Limitation:" in this file.
326 // Limitation: Assuming the type insensitive node will not change the type of input nodes, otherwise it can be seen
327 // as another cast node in some sense, such as LessEqual operator, which performs on two inputs and output a
328 // a boolean result.
329 auto mng = GetFuncGraphManager(func_graph);
330 bool changed = false;
331 auto todos = TopoSort(func_graph->get_return());
332 for (const auto &anf_node : todos) {
333 auto node = anf_node->cast<CNodePtr>();
334 if (node == nullptr) {
335 continue;
336 }
337
338 if (IsTypeInsensitive(node)) {
339 // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
340 changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
341 } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
342 // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
343 changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
344 }
345 }
346
347 return changed;
348 }
349
Run(const FuncGraphPtr & func_graph)350 bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
351 bool changed = false;
352 auto todos = TopoSort(func_graph->get_return());
353 for (const auto &anf_node : todos) {
354 auto node = anf_node->cast<CNodePtr>();
355 if (node == nullptr) {
356 continue;
357 }
358
359 if (AnfAlgo::IsGraphKernel(node)) {
360 auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
361 bool need_traverse = true;
362 while (need_traverse) {
363 need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
364 if (need_traverse) {
365 changed = true;
366 }
367 }
368 }
369 }
370
371 return changed;
372 }
373 } // namespace opt
374 } // namespace mindspore
375