1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_ 18 #define MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_ 19 20 #include <map> 21 #include <vector> 22 #include <string> 23 #include "kernel/kernel_build_info.h" 24 #include "include/backend/optimizer/optimizer.h" 25 26 namespace mindspore { 27 namespace opt { 28 using kernel::KernelBuildInfoPtr; 29 using kernel::KernelObjectType; 30 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; 31 32 // This attribute represents this node's output is already expanded. 33 constexpr char kTupleUnfoldExpanded[] = "tuple_unfold_expanded"; 34 35 static std::map<KernelObjectType, std::string> kObjectTypeToString = {{KernelObjectType::UNKNOWN_TYPE, "unknown"}, 36 {KernelObjectType::TENSOR, "tensor"}, 37 {KernelObjectType::SCALAR, "scalar"}, 38 {KernelObjectType::TUPLE, "tuple"}, 39 {KernelObjectType::TUPLE_UNFOLD, "tuple_unfold"}}; 40 41 // Kernel object type pair of: 42 // 1. One node's input kernel object type. 43 // 2. The actual kernel object type this node's kernel info stores. 
44 struct ObjectTypePair { 45 KernelObjectType current_input_type; 46 KernelObjectType needed_input_type; 47 to_stringObjectTypePair48 std::string to_string() const { 49 if (kObjectTypeToString.count(current_input_type) == 0 || kObjectTypeToString.count(needed_input_type) == 0) { 50 MS_LOG(EXCEPTION) << "The current input object type " << current_input_type << " or needed input object type " 51 << needed_input_type << " is not valid."; 52 } 53 54 return kObjectTypeToString[current_input_type] + "->" + kObjectTypeToString[needed_input_type]; 55 } 56 57 bool operator<(const ObjectTypePair &t) const { return to_string() < t.to_string(); } 58 59 bool operator==(const ObjectTypePair &t) const { return to_string() == t.to_string(); } 60 }; 61 62 // For each unmatched type pair, a processing method is required to correct the types by inserting type transforming 63 // ops or replace origin primitive. 64 // The method returns new input list so a new node could be created and replace the old node. 65 // If there's no need to change the input, this method returns the old input. 66 /** 67 * @param {FuncGraphPtr} &func_graph: This func_graph. 68 * @param {AnfNodePtr} &input: The input which needs to be processed because of its output kernel object type. 69 * @param {CNodePtr} &node: The node which uses input but the type is not satisfied. 70 * @param {bool} *new_prim: Whether the origin node's primitive should also be replaced. If true, new primitive node is 71 * returned as the first element in returned AnfNodePtrList. 72 * @return {AnfNodePtrList}: New input list which replaces 'input' and will be handled by caller. 73 */ 74 using ProcessTypeTransformFunc = std::function<AnfNodePtrList(const FuncGraphPtr &func_graph, const AnfNodePtr &input, 75 const CNodePtr &node, bool *new_prim)>; 76 77 // SplitTupleInputs methods refer to the pass ConvertTupleInputToDynamicInput. It unfolds tuple inputs and returns the 78 // unfolded inputs nodes. 
// The unfolded input nodes are handed back through the out-parameter 'plant_inputs'.
// NOTE(review): the meaning of the int64_t return value is not visible in this header (presumably the number of
// unfolded inputs) - confirm against the definition.
int64_t SplitTupleInputsForInsertType(const FuncGraphPtr &graph, const AnfNodePtr &tuple_input,
                                      std::vector<AnfNodePtr> *plant_inputs);

// Create the new cnode which will replace the original cnode 'origin_node', using 'input_list' as its inputs.
// This method is called specifically at the last step of this pass.
AnfNodePtr CreateNewNode(const FuncGraphPtr &func_graph, const AnfNodePtrList &input_list, const CNodePtr &origin_node);

// Scenario: transforming a MakeTuple node into a RealMakeTuple node.
AnfNodePtr CreateRealMakeTupleByMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &make_tuple_node);

// Scenario: a node with TupleUnfold output (not MakeTuple) is connected to a Tuple input.
AnfNodePtr CreateRealMakeTupleByTupleUnfoldInput(const FuncGraphPtr &func_graph,
                                                 const AnfNodePtr &node_with_tuple_unfold_output);

// Backoff ops are ops which are inserted in this pass.
// NOTE(review): presumably returns true when 'cnode' is such an op - confirm against the definition.
bool IsBackOffOp(const CNodePtr &cnode);

// Set kernel info validation flag according to white list.
void SetBackOffFlag(const KernelBuildInfoPtr &build_info, const CNodePtr &cnode);

// Set kernel info for newly created cnodes. The kernel info will be generated from scratch.
// In some cases, there's no need to set input/output format and type for the node; pass 'set_format_type' as false
// for those (it defaults to true).
void SetKernelInfoForNewCNode(const CNodePtr &cnode, bool set_format_type = true);

// Set kernel info for some value nodes manually.
void SetKernelInfoForValueNode(const ValueNodePtr &value_node);

// Reuse the op infer methods defined under core/ops to generate the abstract of a new cnode.
abstract::AbstractBasePtr GenerateAbsByOpInfer(const PrimitivePtr &primitive);

// Generate abstract, format and object type for newly created node.
// They can be generated in multiple ways because new node is not processed by kernel selecting method.
// Returns the output format to use for the newly created 'cnode' as a string.
std::string GenerateOutputFormatForNewCNode(const CNodePtr &cnode);
// Fills the out-parameters 'input_obj_type' and 'output_obj_type' with the kernel object types for the newly
// created 'cnode'.
void GenerateKernelObjectTypeForNewCNode(const CNodePtr &cnode, std::vector<KernelObjectType> *input_obj_type,
                                         std::vector<KernelObjectType> *output_obj_type);

// After kernel selection phase, one kernel's acquired input type may not be the same as the actual input type (the
// input node's output type). We need this pass to transform these types to valid types.
class BACKEND_EXPORT InsertTypeTransformOp : public PatternProcessPass {
 public:
  // 'multigraph' defaults to true.
  // NOTE(review): presumably forwarded to the PatternProcessPass base class - confirm against the definition.
  explicit InsertTypeTransformOp(bool multigraph = true);
  ~InsertTypeTransformOp() override = default;
  // Entry point of the pass, overriding the base class method. The EquivPtr argument is unnamed in this declaration.
  const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;

 private:
  // This method checks whether new inputs are generated to replace the old one. If so, new input node list will be
  // returned by method 'Process'.
  bool IsInputUpdated(const AnfNodePtr &origin_input, const AnfNodePtrList &new_input_list) const;

  // All 'Process*' methods below take the same parameters and return type as ProcessTypeTransformFunc:
  // (func_graph, input, node, new_prim) -> AnfNodePtrList.

  // This scenario is migrated from the pass ConvertTupleInputToDynamicInput. Please refer to
  // convert_tuple_input_to_dynamic_input.h/cc
  AnfNodePtrList ProcessTupleUnfoldToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                 const CNodePtr &node, bool *new_prim);

  // Convert TupleUnfold output to tuple, real tuple with continuous memory.
  AnfNodePtrList ProcessTupleUnfoldToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                           const CNodePtr &node, bool *new_prim);

  // Convert TupleUnfold output to Tensor. Firstly insert TupleToTensor op. Then transform TupleUnfold to Tuple.
  AnfNodePtrList ProcessTupleUnfoldToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                            const CNodePtr &node, bool *new_prim);

  // Convert Tuple output to TupleUnfold. User must be TupleGetItem op and change it to RealTupleGetItem.
  AnfNodePtrList ProcessTupleToTupleUnfold(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                           const CNodePtr &node, bool *new_prim);
  // NOTE(review): judging by the names, the two methods below are the Tuple-to-TupleUnfold handling for a
  // TupleGetItem user and for ops that are skipped, respectively - confirm against the definitions.
  AnfNodePtrList ProcessTupleToTupleUnfoldForTupleGetItem(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                          const CNodePtr &node, bool *new_prim);
  AnfNodePtrList ProcessTupleToTupleUnfoldForSkipOp(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
                                                    const CNodePtr &node, bool *new_prim);
  // Convert Tuple/Scalar output to Tensor. Simply insert TupleToTensor/ScalarToTensor op.
  AnfNodePtrList ProcessTupleToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                      bool *new_prim);
  AnfNodePtrList ProcessScalarToTensor(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                       bool *new_prim);

  // Transform Tensor to Tuple/Scalar. Simply insert TensorToTuple/TensorToScalar op.
  AnfNodePtrList ProcessTensorToTuple(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                      bool *new_prim);
  AnfNodePtrList ProcessTensorToScalar(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const CNodePtr &node,
                                       bool *new_prim);
};
}  // namespace opt
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_COMMON_PASS_INSERT_TYPE_TRANSFORM_OP_H_