/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_
#define MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_

#include <map>
#include <vector>
#include <string>
#include "kernel/kernel_build_info.h"
#include "include/backend/optimizer/optimizer.h"

namespace mindspore {
namespace opt {
using kernel::KernelBuildInfoPtr;
using kernel::KernelObjectType;
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;

// This attribute indicates that this node's output has already been expanded.
constexpr char kTupleUnfoldExpanded[] = "tuple_unfold_expanded";

static std::map<KernelObjectType, std::string> kObjectTypeToString = {{KernelObjectType::UNKNOWN_TYPE, "unknown"},
                                                                      {KernelObjectType::TENSOR, "tensor"},
                                                                      {KernelObjectType::SCALAR, "scalar"},
                                                                      {KernelObjectType::TUPLE, "tuple"},
                                                                      {KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}};

// The kernel object type pair of:
// 1. The kernel object type currently provided by one of this node's inputs.
// 2. The kernel object type this node's kernel info actually requires for that input.
struct ObjectTypePair {
  KernelObjectType current_input_type;
  KernelObjectType needed_input_type;

  std::string to_string() const {
    if (kObjectTypeToString.count(current_input_type) == 0 || kObjectTypeToString.count(needed_input_type) == 0) {
      MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type "
                        << needed_input_type << " is not valid.";
    }

    return kObjectTypeToString[current_input_type] + "->" + kObjectTypeToString[needed_input_type];
  }

  bool operator<(const ObjectTypePair &t) const { return to_string() < t.to_string(); }

  bool operator==(const ObjectTypePair &t) const { return to_string() == t.to_string(); }
};
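
// Illustrative usage sketch (an assumption added for clarity, not part of the original header): an ObjectTypePair is
// built from the type an input currently provides and the type the kernel needs, and its string form doubles as a
// readable map key or log message, e.g.
//   ObjectTypePair pair{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE};
//   MS_LOG(DEBUG) << "Unmatched type pair: " << pair.to_string();  // Prints "tuple_unfold->tuple".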

// For each unmatched type pair, a processing method is required to correct the types, either by inserting type
// transforming ops or by replacing the origin primitive.
// The method returns a new input list so that a new node can be created to replace the old node.
// If there's no need to change the input, this method returns the old input.
/**
 * @param {FuncGraphPtr} &func_graph: This func_graph.
 * @param {AnfNodePtr} &input: The input which needs to be processed because of its output kernel object type.
 * @param {CNodePtr} &node: The node which uses the input but whose needed type is not satisfied.
 * @param {bool} *new_prim: Whether the origin node's primitive should also be replaced. If true, the new primitive
 * node is returned as the first element of the returned AnfNodePtrList.
 * @return {AnfNodePtrList}: The new input list which replaces 'input' and will be handled by the caller.
 */
using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                              const CNodePtr &node, bool *new_prim)>;
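
// Illustrative sketch (an assumption, not necessarily the actual implementation): because ObjectTypePair defines
// operator<, the pass can dispatch each unmatched pair to its handler through an ordered map. The handler names below
// are hypothetical placeholders.
//   std::map<ObjectTypePair, ProcessTypeTransformFunc> process_map = {
//     {{KernelObjectType::TUPLE_UNFOLD, KernelObjectType::TUPLE}, tuple_unfold_to_tuple_func},
//     {{KernelObjectType::TUPLE, KernelObjectType::TENSOR}, tuple_to_tensor_func}};
//   bool new_prim = false;
//   AnfNodePtrList new_inputs = process_map[{current_type, needed_type}](func_graph, input, node, &new_prim);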

// The SplitTupleInputsForInsertType method refers to the pass ConvertTupleInputToDynamicInput. It unfolds the tuple
// input and returns the unfolded input nodes through 'plant_inputs'.
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
                                      std::vector<AnfNodePtr> *plant_inputs);
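
// Illustrative sketch (an assumption based on the signature): for a tuple input built by MakeTuple(a, b, c), the
// unfolded nodes a, b and c would be appended to 'plant_inputs', with the int64_t return value reflecting how many
// inputs were unfolded, e.g.
//   std::vector<AnfNodePtr> plant_inputs;
//   int64_t unfold_num = SplitTupleInputsForInsertType(graph, make_tuple_node, &plant_inputs);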

// Create the new cnode which will replace the original cnode.
// This method is specifically called at the last step of this pass.
AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_list, const CNodePtr &origin_node);

// The scenario of transforming MakeTuple to RealMakeTuple.
AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple_node);

// The scenario where a node with a TupleUnfold output (not MakeTuple) is connected to a Tuple input.
AnfNodePtr CreateRealMakeTupleByTupleUnfoldInput(const FuncGraphPtr &func_graph,
                                                 const AnfNodePtr &node_with_tuple_unfold_output);

// Judge whether the cnode is a backoff op inserted in this pass.
bool IsBackOffOp(const CNodePtr &cnode);

// Set the kernel info validation flag according to the white list.
void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode);

// Set kernel info for newly created cnodes. The kernel info will be generated from scratch.
// In some cases, there's no need to set the input/output format and type for the node.
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true);
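
// Illustrative sketch (an assumption): a typical call site right after this pass creates a node, e.g.
//   AnfNodePtr new_node = CreateNewNode(func_graph, new_input_list, origin_cnode);
//   SetKernelInfoForNewCNode(new_node->cast<CNodePtr>());  // set_format_type defaults to true.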

// Set kernel info for some value nodes manually.
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);

// Reuse the op infer methods defined under core/ops to generate the abstract of the new cnode.
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive);

// Generate the abstract, format and object type for the newly created node.
// They can be generated in multiple ways because the new node is not processed by the kernel selecting method.
std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode);
void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
                                         std::vector<KernelObjectType> *output_obj_type);

// After the kernel selection phase, a kernel's required input type may not be the same as the actual input type (the
// input node's output type). This pass transforms these types into valid types.
class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
 public:
  explicit InsertTypeTransformOp(bool multigraph = true);
  ~InsertTypeTransformOp() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;

 private:
  // This method checks whether new inputs were generated to replace the old one. If so, the new input node list
  // will be returned by the method 'Process'.
  bool IsInputUpdated(const AnfNodePtr &origin_input, const AnfNodePtrList &new_input_list) const;

  // This scenario is migrated from the pass ConvertTupleInputToDynamicInput. Please refer to
  // convert_tuple_input_to_dynamic_input.h/cc.
  AnfNodePtrList ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                 const CNodePtr &node, bool *new_prim);

  // Convert TupleUnfold output to Tuple, a real tuple with contiguous memory.
  AnfNodePtrList ProcessTupleUnfoldToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                           const CNodePtr &node, bool *new_prim);

  // Convert TupleUnfold output to Tensor. First insert a TupleToTensor op, then transform TupleUnfold to Tuple.
  AnfNodePtrList ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                            const CNodePtr &node, bool *new_prim);

  // Convert Tuple output to TupleUnfold. The user node must be a TupleGetItem op, which is changed to
  // RealTupleGetItem.
  AnfNodePtrList ProcessTupleToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                           const CNodePtr &node, bool *new_prim);
  AnfNodePtrList ProcessTupleToTupleUnfoldForTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                          const CNodePtr &node, bool *new_prim);
  AnfNodePtrList ProcessTupleToTupleUnfoldForSkipOp(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                    const CNodePtr &node, bool *new_prim);
  // Convert Tuple/Scalar output to Tensor by simply inserting a TupleToTensor/ScalarToTensor op.
  AnfNodePtrList ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                      bool *new_prim);
  AnfNodePtrList ProcessScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                       bool *new_prim);

  // Transform Tensor to Tuple/Scalar by simply inserting a TensorToTuple/TensorToScalar op.
  AnfNodePtrList ProcessTensorToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                      bool *new_prim);
  AnfNodePtrList ProcessTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                       bool *new_prim);
};
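
// Illustrative usage sketch (an assumption): how a backend pass manager might register this pass; the surrounding
// pass-manager code is not part of this header.
//   auto pm = std::make_shared<opt::PassManager>();
//   pm->AddPass(std::make_shared<opt::InsertTypeTransformOp>());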
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_