• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "backend/common/graph_kernel/reorder_ops.h"
18 #include <memory>
19 #include <string>
20 #include <vector>
21 #include "backend/common/graph_kernel/core/graph_kernel_utils.h"
22 #include "include/backend/anf_runtime_algorithm.h"
23 #include "include/common/debug/anf_ir_dump.h"
24 #include "include/common/utils/anfalgo.h"
25 #include "mindspore/core/ops/array_ops.h"
26 #include "mindspore/core/ops/framework_ops.h"
27 #include "mindspore/core/ops/math_ops.h"
28 #include "mindspore/core/ops/nn_optimizer_ops.h"
29 #include "mindspore/core/ops/auto_generate/gen_ops_primitive.h"
30 #include "utils/hash_set.h"
31 #include "utils/log_adapter.h"
32 
33 namespace mindspore::graphkernel {
34 namespace {
IsTypeInsensitive(const CNodePtr & node)35 bool IsTypeInsensitive(const CNodePtr &node) {
36   // Nodes that will change the input data type will not seen as type insensitive nodes.
37   static mindspore::HashSet<PrimitivePtr> type_insensitive_op_list{
38     prim::kPrimTransData, prim::kPrimTranspose, prim::kPrimTransposeD, prim::kPrimExpandDims, prim::kPrimReshape,
39     prim::kPrimSqueeze,   prim::kPrimTile,      prim::kPrimNeg,        prim::kPrimReLU,       prim::kPrimRelu,
40     prim::kPrimMaximum,   prim::kPrimMinimum,   prim::kPrimSelect};
41 
42   return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
43                      [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
44 }
45 
46 enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
GetCastType(const CNodePtr & node)47 CastType GetCastType(const CNodePtr &node) {
48   MS_EXCEPTION_IF_NULL(node);
49   if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
50     MS_LOG(EXCEPTION) << "Expect Cast node, but got " << common::AnfAlgo::GetCNodeName(node);
51   }
52   TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
53   TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
54   if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
55     return CAST_UP;
56   }
57   if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
58     return CAST_DOWN;
59   }
60   return CAST_OTHER;
61 }
62 
GetOpDataInputIndexes(const CNodePtr & node)63 std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
64   std::vector<size_t> op_input_indexes;
65   if (node == nullptr || !IsTypeInsensitive(node)) {
66     return op_input_indexes;
67   }
68 
69   // Data input index starts from 0.
70   if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
71     op_input_indexes = {0, 1};
72   } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
73     op_input_indexes = {1, 2};
74   } else {
75     op_input_indexes = {0};
76   }
77   return op_input_indexes;
78 }
79 
CheckInputTypeConsistent(const CNodePtr & node,const std::vector<size_t> & check_indexes,const TypeId & base_type)80 bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
81   MS_EXCEPTION_IF_NULL(node);
82 
83   // node's inputs at check_indexes should be of type base_type
84   for (const auto &index : check_indexes) {
85     if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
86       return false;
87     }
88   }
89   return true;
90 }
91 
SetNodeInfo(const CNodePtr & orig_node,const CNodePtr & new_node,const NodeIOInfo & node_io_info)92 void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const NodeIOInfo &node_io_info) {
93   MS_EXCEPTION_IF_NULL(orig_node);
94   MS_EXCEPTION_IF_NULL(new_node);
95 
96   auto node_name = common::AnfAlgo::GetCNodeName(new_node);
97   auto orig_node_name = common::AnfAlgo::GetCNodeName(orig_node);
98   if (orig_node_name != node_name) {
99     MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
100   }
101 
102   AbstractBasePtr new_abstract{nullptr};
103   if (node_io_info.outputs_type.empty()) {
104     MS_LOG(EXCEPTION) << "Can not set empty output type of new node from " << orig_node->fullname_with_scope();
105   }
106   if (node_name == "Cast") {
107     auto node_input = common::AnfAlgo::GetInputNode(new_node, 0);
108     MS_EXCEPTION_IF_NULL(node_input);
109     MS_EXCEPTION_IF_NULL(node_input->abstract());
110     new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
111                                                               node_input->abstract()->BuildShape());
112   } else {
113     MS_EXCEPTION_IF_NULL(orig_node->abstract());
114     new_abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_io_info.outputs_type[0]),
115                                                               orig_node->abstract()->BuildShape());
116   }
117 
118   // Set abstract info
119   new_node->set_abstract(new_abstract);
120   // Set attrs
121   common::AnfAlgo::CopyNodeAttrs(orig_node, new_node);
122   // Set kernel build info
123   new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
124   kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
125   info_builder.SetInputsFormat(node_io_info.inputs_format);
126   info_builder.SetInputsDeviceType(node_io_info.inputs_type);
127   info_builder.SetOutputsFormat(node_io_info.outputs_format);
128   info_builder.SetOutputsDeviceType(node_io_info.outputs_type);
129   info_builder.SetKernelType(AnfAlgo::GetKernelType(orig_node));
130   info_builder.SetOpPattern(AnfAlgo::GetOpPattern(orig_node));
131   info_builder.SetFusionType(AnfAlgo::GetFusionType(orig_node));
132   info_builder.SetProcessor(AnfAlgo::GetProcessor(orig_node));
133   AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
134 }
135 }  // namespace
136 
SetTypeInsensitiveNodeInputs(const CNodePtr & node,const std::vector<size_t> & indexes,const std::vector<AnfNodePtr> & new_input_at_indexes,std::vector<AnfNodePtr> * new_inputs) const137 void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
138                                               const std::vector<AnfNodePtr> &new_input_at_indexes,
139                                               std::vector<AnfNodePtr> *new_inputs) const {
140   MS_EXCEPTION_IF_NULL(node);
141   MS_EXCEPTION_IF_NULL(new_inputs);
142   if (indexes.size() != new_input_at_indexes.size()) {
143     MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
144                       << new_input_at_indexes.size();
145   }
146 
147   auto node_inputs_num = node->size();
148   if (node_inputs_num == 0) {
149     MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
150   }
151 
152   // node's inputs at indexes change to new_input_at_indexes
153   if (!new_inputs->empty()) {
154     new_inputs->resize(0);
155   }
156   new_inputs->push_back(node->input(0));
157   mindspore::HashSet<size_t> indexes_set(indexes.begin(), indexes.end());
158   size_t idx = 0;
159   for (size_t i = 1; i < node_inputs_num; ++i) {
160     size_t data_idx = i - 1;
161     if (indexes_set.find(data_idx) == indexes_set.end()) {
162       new_inputs->push_back(node->input(i));
163     } else {
164       new_inputs->push_back(new_input_at_indexes[idx++]);
165     }
166   }
167 }
168 
SetTypeInsensitiveNodeInputsInfo(const CNodePtr & node,const std::vector<size_t> & indexes,const std::vector<AnfNodePtr> & input_at_indexes,NodeIOInfo * new_inputs_info,bool from_input) const169 void ReorderOps::SetTypeInsensitiveNodeInputsInfo(const CNodePtr &node, const std::vector<size_t> &indexes,
170                                                   const std::vector<AnfNodePtr> &input_at_indexes,
171                                                   NodeIOInfo *new_inputs_info, bool from_input) const {
172   MS_EXCEPTION_IF_NULL(node);
173   MS_EXCEPTION_IF_NULL(new_inputs_info);
174   if (indexes.size() != input_at_indexes.size()) {
175     MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
176                       << input_at_indexes.size();
177   }
178 
179   auto node_inputs_num = node->size();
180   if (node_inputs_num == 0) {
181     MS_LOG(EXCEPTION) << "Inputs num is 0 in node " << node->fullname_with_scope();
182   }
183 
184   // node's inputs info at indexes change to input_at_indexes's input or output info
185   new_inputs_info->inputs_format.resize(0);
186   new_inputs_info->inputs_type.resize(0);
187   mindspore::HashSet<size_t> indexes_set(indexes.begin(), indexes.end());
188   size_t idx = 0;
189   for (size_t data_idx = 0; data_idx < node_inputs_num - 1; ++data_idx) {
190     if (indexes_set.find(data_idx) == indexes_set.end()) {
191       new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(node, data_idx));
192       new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(node, data_idx));
193     } else {
194       if (from_input) {
195         new_inputs_info->inputs_format.push_back(AnfAlgo::GetInputFormat(input_at_indexes[idx], 0));
196         new_inputs_info->inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(input_at_indexes[idx], 0));
197       } else {
198         new_inputs_info->inputs_format.push_back(AnfAlgo::GetOutputFormat(input_at_indexes[idx], 0));
199         new_inputs_info->inputs_type.push_back(AnfAlgo::GetOutputDeviceDataType(input_at_indexes[idx], 0));
200       }
201       idx++;
202     }
203   }
204 }
205 
ReorderTypeInsensitiveCastDown(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng,const CNodePtr & node) const206 bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
207                                                 const CNodePtr &node) const {
208   // Limitation:
209   //   Current cast node is CAST_DOWN.
210   //   Cast node will not change the input format.
211   if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN ||
212       AnfAlgo::GetInputFormat(node, 0) != AnfAlgo::GetOutputFormat(node, 0)) {
213     return false;
214   }
215 
216   auto large_type = AnfAlgo::GetInputDeviceDataType(node, 0);
217   auto small_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
218   auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
219 
220   auto node_input = common::AnfAlgo::GetInputNode(node, 0);
221   auto type_insens_node = node_input->cast<CNodePtr>();
222   // Limitation:
223   //   Find type insensitive node before cast node.
224   //   Type insensitive node is only used by current cast node.
225   if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
226       mng->node_users()[type_insens_node].size() > 1) {
227     return false;
228   }
229 
230   auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
231   // Limitation: Type insensitive node's inputs are the large type.
232   if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, large_type)) {
233     return false;
234   }
235 
236   std::vector<AnfNodePtr> new_cast_nodes;
237   for (const auto &index : op_input_indexes) {
238     auto new_cast_node = func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())),
239                                                common::AnfAlgo::GetInputNode(type_insens_node, index)});
240     NodeIOInfo cast_io_info;
241     cast_io_info.inputs_format.push_back(AnfAlgo::GetInputFormat(type_insens_node, index));
242     cast_io_info.outputs_format = cast_io_info.inputs_format;
243     cast_io_info.inputs_type.push_back(AnfAlgo::GetInputDeviceDataType(type_insens_node, index));
244     cast_io_info.outputs_type.push_back(small_type);
245     SetNodeInfo(node, new_cast_node, cast_io_info);
246     new_cast_nodes.push_back(new_cast_node);
247   }
248 
249   std::vector<AnfNodePtr> type_insens_node_new_inputs;
250   SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
251   NodeIOInfo type_insens_io_info;
252   type_insens_io_info.outputs_format.push_back(pattern_output_format);
253   type_insens_io_info.outputs_type.push_back(small_type);
254   SetTypeInsensitiveNodeInputsInfo(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_io_info, false);
255   auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
256   SetNodeInfo(type_insens_node, new_type_insens_node, type_insens_io_info);
257 
258   (void)mng->Replace(node, new_type_insens_node);
259   return true;
260 }
261 
ReorderCastUpTypeInsensitive(const FuncGraphPtr & func_graph,const FuncGraphManagerPtr & mng,const CNodePtr & node) const262 bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
263                                               const CNodePtr &node) const {
264   if (!IsTypeInsensitive(node)) {
265     return false;
266   }
267 
268   // Limitation:
269   //   Certain inputs of type insensitive node are cast node.
270   //   Cast nodes are CAST_UP.
271   //   Cast nodes will not change the input format.
272   //   All these cast nodes are only used by current type insensitive node.
273   std::vector<AnfNodePtr> cast_nodes;
274   std::vector<AnfNodePtr> cast_input_nodes;
275   auto op_input_indexes = GetOpDataInputIndexes(node);
276   for (const auto &index : op_input_indexes) {
277     auto node_input = common::AnfAlgo::GetInputNode(node, index);
278     auto cast_node = node_input->cast<CNodePtr>();
279     if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
280         AnfAlgo::GetInputFormat(node, 0) == AnfAlgo::GetOutputFormat(node, 0) &&
281         mng->node_users()[cast_node].size() == 1) {
282       cast_nodes.push_back(cast_node);
283       cast_input_nodes.push_back(common::AnfAlgo::GetInputNode(cast_node, 0));
284     }
285   }
286   if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
287     return false;
288   }
289 
290   auto small_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
291   auto large_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
292   auto pattern_output_format = AnfAlgo::GetOutputFormat(node, 0);
293 
294   // Limitation: All these cast nodes cast same type to another type.
295   if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&small_type](const AnfNodePtr &cast_node) {
296         return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == small_type;
297       })) {
298     return false;
299   }
300   // Limitation: Type insensitive node's inputs have same data type.
301   if (!CheckInputTypeConsistent(node, op_input_indexes, large_type)) {
302     return false;
303   }
304 
305   std::vector<AnfNodePtr> type_insens_node_new_inputs;
306   SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
307   auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
308   NodeIOInfo type_insens_io_info;
309   type_insens_io_info.outputs_format.push_back(pattern_output_format);
310   type_insens_io_info.outputs_type.push_back(small_type);
311   SetTypeInsensitiveNodeInputsInfo(node, op_input_indexes, cast_nodes, &type_insens_io_info, true);
312   SetNodeInfo(node, new_type_insens_node, type_insens_io_info);
313 
314   auto new_cast_node =
315     func_graph->NewCNode({NewValueNode(std::make_shared<Primitive>(prim::kPrimCast->name())), new_type_insens_node});
316   NodeIOInfo cast_io_info;
317   cast_io_info.inputs_format.push_back(pattern_output_format);
318   cast_io_info.outputs_format = cast_io_info.inputs_format;
319   cast_io_info.inputs_type.push_back(small_type);
320   cast_io_info.outputs_type.push_back(large_type);
321   SetNodeInfo(cast_nodes[0]->cast<CNodePtr>(), new_cast_node, cast_io_info);
322 
323   (void)mng->Replace(node, new_cast_node);
324   return true;
325 }
326 
ReorderCastTypeInsensitive(const FuncGraphPtr & func_graph) const327 bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) const {
328   // Reorder cast node and type insensitive node in graph kernel sub-graph, this function has several limitations,
329   //   see the comments that start will "Limitation:" in this file.
330   // Limitation: Assuming the type insensitive node will not change the type of input nodes, otherwise it can be seen
331   //   as another cast node in some sense, such as LessEqual operator, which performs on two inputs and output a
332   //   a boolean result.
333   auto mng = GkUtils::GetFuncGraphManager(func_graph);
334   bool changed = false;
335   auto todos = TopoSort(func_graph->get_return());
336   for (const auto &anf_node : todos) {
337     auto node = anf_node->cast<CNodePtr>();
338     if (node == nullptr) {
339       continue;
340     }
341 
342     if (IsTypeInsensitive(node)) {
343       // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
344       changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
345     } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
346       // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
347       changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
348     }
349   }
350 
351   return changed;
352 }
353 
Run(const FuncGraphPtr & func_graph)354 bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
355   bool changed = false;
356   auto todos = TopoSort(func_graph->get_return());
357   for (const auto &anf_node : todos) {
358     auto node = anf_node->cast<CNodePtr>();
359     if (node == nullptr) {
360       continue;
361     }
362 
363     if (common::AnfAlgo::IsGraphKernel(node)) {
364       auto sub_func_graph = common::AnfAlgo::GetCNodeFuncGraphPtr(node);
365       bool need_traverse = true;
366       while (need_traverse) {
367         need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
368         if (need_traverse) {
369           changed = true;
370         }
371       }
372     }
373   }
374 
375   return changed;
376 }
377 }  // namespace mindspore::graphkernel
378