/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "backend/optimizer/graph_kernel/reorder_ops.h"
#include <algorithm>
#include <memory>
#include <vector>
#include <string>
#include <unordered_set>
#include "base/core_ops.h"
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "debug/anf_ir_dump.h"
#include "backend/optimizer/graph_kernel/graph_kernel_helper.h"

namespace mindspore {
namespace opt {
namespace {
bool IsTypeInsensitive(const CNodePtr &node) {
  // Nodes that change the input data type are not treated as type insensitive nodes.
  static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
    prim::kPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
    prim::kPrimSqueeze,   prim::kPrimTile,      prim::kPrimNeg,        prim::kPrimRelu,
    prim::kPrimMaximum,   prim::kPrimMinimum,   prim::kPrimSelect};

  return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
}

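// CastType classifies a Cast node by its device data types: float16 -> float32 is CAST_UP,
// float32 -> float16 is CAST_DOWN, any other combination is CAST_OTHER.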
enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
CastType GetCastType(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
    MS_LOG(EXCEPTION) << "Only process for Cast!";
  }
  TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
    return CAST_UP;
  }
  if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
    return CAST_DOWN;
  }
  return CAST_OTHER;
}

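// Return the data input indexes of a type insensitive node whose dtype this pass may change.
// For Select, data input 0 is the bool condition and is skipped; Maximum/Minimum use both operands;
// all other type insensitive ops use data input 0 only.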
std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
  std::vector<size_t> op_input_indexes;
  if (node == nullptr || !IsTypeInsensitive(node)) {
    return op_input_indexes;
  }

  // Data input index starts from 0.
  if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
    op_input_indexes = {0, 1};
  } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
    op_input_indexes = {1, 2};
  } else {
    op_input_indexes = {0};
  }
  return op_input_indexes;
}

bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
  MS_EXCEPTION_IF_NULL(node);

  // node's inputs at check_indexes should be of type base_type
  for (const auto &index : check_indexes) {
    if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
      return false;
    }
  }
  return true;
}

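// Fill new_node with the abstract, attrs and kernel build info of orig_node, using the input/output
// formats and device types given in node_io_info. For a Cast node the output shape follows its input;
// for other nodes it follows orig_node's abstract shape.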
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
  MS_EXCEPTION_IF_NULL(orig_node);
  MS_EXCEPTION_IF_NULL(new_node);

  auto node_name = AnfAlgo::GetCNodeName(new_node);
  auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
  if (orig_node_name != node_name) {
    MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
  }

  AbstractBasePtr new_abstract{nullptr};
  if (node_io_info.outputs_type.empty()) {
    MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
  }
  if (node_name == "Cast") {
    auto node_input = AnfAlgo::GetInputNode(new_node, 0);
    MS_EXCEPTION_IF_NULL(node_input);
    MS_EXCEPTION_IF_NULL(node_input->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              node_input->abstract()->BuildShape());
  } else {
    MS_EXCEPTION_IF_NULL(orig_node->abstract());
    new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
                                                              orig_node->abstract()->BuildShape());
  }

  // Set abstract info
  new_node->set_abstract(new_abstract);
  // Set attrs
  AnfAlgo::CopyNodeAttrs(orig_node, new_node);
  // Set kernel build info
  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
  kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
  info_builder.SetInputsFormat(node_io_info.inputs_format);
  info_builder.SetInputsDeviceType(node_io_info.inputs_type);
  info_builder.SetOutputsFormat(node_io_info.outputs_format);
  info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
  info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
  info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
  info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
  info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
}  // namespace

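// Build the new input list of a type insensitive node: keep input 0 (the primitive) and all data inputs,
// but replace the data inputs at the given indexes with new_input_at_indexes.
// Note that data input index i corresponds to CNode input i + 1.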
void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
                                              const std::vector<AnfNodePtr> &new_input_at_indexes,
                                              std::vector<AnfNodePtr> *new_inputs) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_inputs);
  if (indexes.size() != new_input_at_indexes.size()) {
    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
                      << new_input_at_indexes.size();
  }

  auto node_inputs_num = node->size();
  if (node_inputs_num == 0) {
    MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
  }

  // node's inputs at indexes change to new_input_at_indexes
  if (!new_inputs->empty()) {
    new_inputs->resize(0);
  }
  new_inputs->push_back(node->input(0));
  std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
  size_t idx = 0;
  for (size_t i = 1; i < node_inputs_num; ++i) {
    size_t data_idx = i - 1;
    if (indexes_set.find(data_idx) == indexes_set.end()) {
      new_inputs->push_back(node->input(i));
    } else {
      new_inputs->push_back(new_input_at_indexes[idx++]);
    }
  }
}

void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
                                                  const std::vector<AnfNodePtr> &input_at_indexes,
                                                  NodeIOInfo *new_inputs_info, bool from_input) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(new_inputs_info);
  if (indexes.size() != input_at_indexes.size()) {
    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to input_at_indexes size "
                      << input_at_indexes.size();
  }

  auto node_inputs_num = node->size();
  if (node_inputs_num == 0) {
    MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
  }

  // Replace node's input info at indexes with the input info (from_input == true) or output info
  // (from_input == false) of the corresponding nodes in input_at_indexes.
  new_inputs_info->inputs_format.resize(0);
  new_inputs_info->inputs_type.resize(0);
  std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
  size_t idx = 0;
  for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
    if (indexes_set.find(data_idx) == indexes_set.end()) {
      new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
      new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
    } else {
      if (from_input) {
        new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
        new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
      } else {
        new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
        new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
      }
      idx++;
    }
  }
}

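// Reorder pattern: TypeInsensitive(fp32 inputs) -> Cast(fp32 -> fp16) is rewritten as
// Cast(fp32 -> fp16) on each data input -> TypeInsensitive(fp16), so the type insensitive op
// computes on the smaller type and the original down cast is removed.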
bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                                const CNodePtr &node) {
  // Limitation:
  //   Current cast node is CAST_DOWN.
  //   Cast node will not change the input format.
  if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
      AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
    return false;
  }

  auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
  auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
  auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);

  auto node_input = AnfAlgo::GetInputNode(node, 0);
  auto type_insens_node = node_input->cast<CNodePtr>();
  // Limitation:
  //   Find type insensitive node before cast node.
  //   Type insensitive node is only used by current cast node.
  if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
      mng->node_users()[type_insens_node].size() > 1) {
    return false;
  }

  auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
  // Limitation: Type insensitive node's inputs are the large type.
  if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
    return false;
  }

  std::vector<AnfNodePtr> new_cast_nodes;
  for (const auto &index : op_input_indexes) {
    auto new_cast_node =
      func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
    NodeIOInfo cast_io_info;
    cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
    cast_io_info.outputs_format = cast_io_info.inputs_format;
    cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
    cast_io_info.outputs_type.push_back(small_type);
    SetNodeInfo(node, new_cast_node, cast_io_info);
    new_cast_nodes.push_back(new_cast_node);
  }

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
  NodeIOInfo type_insens_io_info;
  type_insens_io_info.outputs_format.push_back(pattern_output_format);
  type_insens_io_info.outputs_type.push_back(small_type);
  SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);

  (void)mng->Replace(node, new_type_insens_node);
  return true;
}

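// Reorder pattern: Cast(fp16 -> fp32) on every data input -> TypeInsensitive(fp32) is rewritten as
// TypeInsensitive(fp16) -> Cast(fp16 -> fp32), so the up cast is applied once, after the type
// insensitive op.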
bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
                                              const CNodePtr &node) {
  if (!IsTypeInsensitive(node)) {
    return false;
  }

  // Limitation:
  //   All data inputs of the type insensitive node are cast nodes.
  //   Cast nodes are CAST_UP.
  //   Cast nodes will not change the input format.
  //   All these cast nodes are only used by the current type insensitive node.
  std::vector<AnfNodePtr> cast_nodes;
  std::vector<AnfNodePtr> cast_input_nodes;
  auto op_input_indexes = GetOpDataInputIndexes(node);
  for (const auto &index : op_input_indexes) {
    auto node_input = AnfAlgo::GetInputNode(node, index);
    auto cast_node = node_input->cast<CNodePtr>();
    if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
        AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
        mng->node_users()[cast_node].size() == 1) {
      cast_nodes.push_back(cast_node);
      cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
    }
  }
  if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
    return false;
  }

  auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
  auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
  auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);

  // Limitation: All these cast nodes cast from the same type.
  if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
        return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
      })) {
    return false;
  }
  // Limitation: Type insensitive node's inputs have the same data type.
  if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
    return false;
  }

  std::vector<AnfNodePtr> type_insens_node_new_inputs;
  SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
  NodeIOInfo type_insens_io_info;
  type_insens_io_info.outputs_format.push_back(pattern_output_format);
  type_insens_io_info.outputs_type.push_back(small_type);
  SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
  SetNodeInfo(node, new_type_insens_node, type_insens_io_info);

  auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
  NodeIOInfo cast_io_info;
  cast_io_info.inputs_format.push_back(pattern_output_format);
  cast_io_info.outputs_format = cast_io_info.inputs_format;
  cast_io_info.inputs_type.push_back(small_type);
  cast_io_info.outputs_type.push_back(large_type);
  SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);

  (void)mng->Replace(node, new_cast_node);
  return true;
}

bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
  // Reorder cast nodes and type insensitive nodes in a graph kernel sub-graph. This function has several limitations,
  //   see the comments that start with "Limitation:" in this file.
  // Limitation: Assume the type insensitive node does not change the type of its inputs; otherwise it can be seen
  //   as another cast node in some sense, such as the LessEqual operator, which takes two inputs and outputs a
  //   boolean result.
  auto mng = GetFuncGraphManager(func_graph);
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    if (IsTypeInsensitive(node)) {
      // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
      changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
    } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
      // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
      changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
    }
  }

  return changed;
}

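// Apply the cast/type-insensitive reorder to every graph kernel sub-graph, repeating on each sub-graph
// until no further change is made.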
bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
  bool changed = false;
  auto todos = TopoSort(func_graph->get_return());
  for (const auto &anf_node : todos) {
    auto node = anf_node->cast<CNodePtr>();
    if (node == nullptr) {
      continue;
    }

    if (AnfAlgo::IsGraphKernel(node)) {
      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
      bool need_traverse = true;
      while (need_traverse) {
        need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
        if (need_traverse) {
          changed = true;
        }
      }
    }
  }

  return changed;
}
}  // namespace opt
}  // namespace mindspore