1 /** 2 * Copyright 2019-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ 18 #define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ 19 20 #include <map> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include <vector> 25 #include <set> 26 27 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h" 28 #include "frontend/parallel/ops_info/operator_info.h" 29 30 namespace mindspore { 31 namespace parallel { 32 static const std::set<OperatorType> EliminateOpType = { 33 OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, 34 OperatorType::kRecAdd, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, 35 OperatorType::kRecSub, OperatorType::kRecMul, OperatorType::kRecDiv, 36 OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, 37 OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue, 38 OperatorType::kRecSoftmax, OperatorType::kRecOneHot, OperatorType::kRecExpandDims, 39 OperatorType::kRecStridedSlice, OperatorType::kRecCum, OperatorType::kRecLayerNorm, 40 OperatorType::kRecFlatten, OperatorType::kRecBatchParallel, OperatorType::kRecStandAlone, 41 OperatorType::kRecPadV3, OperatorType::kRecBatchMatMul, OperatorType::kFlashAttentionScore, 42 OperatorType::kRecRmsNorm}; 43 44 const std::map<std::string, OperatorType> DictOpType{ 45 {MATMUL, OperatorType::kRecMatMul}, 46 {BATCH_MATMUL, OperatorType::kRecBatchMatMul}, 47 {CONV2D, OperatorType::kRecConvolution}, 48 {CONV2D_TRANSPOSE, OperatorType::kRecConvolution}, 49 {MAXPOOL, OperatorType::kRecPooling}, 50 {MAXPOOLV2, OperatorType::kRecPooling}, 51 {POOLING, OperatorType::kRecPooling}, 52 {MAX_POOL_WITH_ARGMAX, OperatorType::kRecPooling}, 53 {SIMPLE_MEAN, OperatorType::kRecPooling}, 54 {RESHAPE, OperatorType::kRecReshape}, 55 {FLATTEN, OperatorType::kRecFlatten}, 56 {BIAS_ADD, OperatorType::kRecBiasAdd}, 57 {BATCH_NORM, OperatorType::kRecBatchNorm}, 58 {LAYER_NORM, OperatorType::kRecLayerNorm}, 59 {RMS_NORM, OperatorType::kRecRmsNorm}, 60 {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits}, 61 {ONEHOT, OperatorType::kRecOneHot}, 62 {SQUEEZE, OperatorType::kRecSqueeze}, 63 {CAST, OperatorType::kRecCast}, 64 {REDUCE_SUM, OperatorType::kRecReduce}, 65 {REDUCE_MAX, OperatorType::kRecReduce}, 66 {REDUCE_MIN, OperatorType::kRecReduce}, 67 {REDUCE_MEAN, OperatorType::kRecReduce}, 68 {STAND_ALONE, OperatorType::kRecStandAlone}, 69 {GET_NEXT, OperatorType::kRecUnknownType}, 70 {VIRTUAL_DATA_SET, OperatorType::kRecVirtual}, 71 {VIRTUAL_OUTPUT, OperatorType::kRecVirtual}, 72 {BATCH_PARALLEL, OperatorType::kRecBatchParallel}, 73 {GATHERV2, OperatorType::kRecGatherV2}, 74 {EXPAND_DIMS, OperatorType::kRecExpandDims}, 75 {STRIDEDSLICE, OperatorType::kRecStridedSlice}, 76 {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue}, 77 {ARGMINWITHVALUE, OperatorType::kRecArgWithValue}, 78 {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp}, 79 {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp}, 80 {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp}, 81 // Activation OP 82 {ACTIVATION, OperatorType::kRecReLU}, 83 {RELU, OperatorType::kRecReLU}, 84 {SILU, OperatorType::kRecReLU}, 85 {"ReLU6", OperatorType::kRecReLU}, 86 {SIGMOID, OperatorType::kRecReLU}, 87 {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU}, 88 {"HSigmoid", OperatorType::kRecReLU}, 89 {GELU, OperatorType::kRecReLU}, 90 {FAST_GELU, OperatorType::kRecReLU}, 91 {TANH, OperatorType::kRecReLU}, 92 {SOFTPLUS, OperatorType::kRecReLU}, 93 {SOFTSIGN, OperatorType::kRecReLU}, 94 {PRELU, OperatorType::kRecPReLU}, 95 // Elm-wise OP 96 {SPLIT, OperatorType::kRecElmWiseOp}, 97 {TRANSPOSE, OperatorType::kRecElmWiseOp}, 98 {L2_NORMALIZE, OperatorType::kRecElmWiseOp}, 99 {ADD, OperatorType::kRecElmWiseOp}, 100 {TENSOR_DOT, OperatorType::kRecElmWiseOp}, 101 {SUB, OperatorType::kRecElmWiseOp}, 102 {MUL, OperatorType::kRecElmWiseOp}, 103 {DIV, OperatorType::kRecElmWiseOp}, 104 {REAL_DIV, OperatorType::kRecElmWiseOp}, 105 {HYPOT, OperatorType::kRecElmWiseOp}, 106 {IGAMMA, OperatorType::kRecElmWiseOp}, 107 {IGAMMAC, OperatorType::kRecElmWiseOp}, 108 {LEFT_SHIFT, OperatorType::kRecElmWiseOp}, 109 {RIGHT_SHIFT, OperatorType::kRecElmWiseOp}, 110 {NEXT_AFTER, OperatorType::kRecElmWiseOp}, 111 {ZETA, OperatorType::kRecElmWiseOp}, 112 {GCD, OperatorType::kRecElmWiseOp}, 113 {SOFTMAX, OperatorType::kRecSoftmax}, 114 {REVERSEV2, OperatorType::kRecSoftmax}, 115 {LOG_SOFTMAX, OperatorType::kRecSoftmax}, 116 {CHOLESKY, OperatorType::kRecSoftmax}, 117 {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits}, 118 {FLATTEN, OperatorType::kRecFlatten}, 119 {PAD_V3, OperatorType::kRecPadV3}, 120 {CUM_SUM, OperatorType::kRecCum}, 121 {SQRT, OperatorType::kRecElmWiseOp}, 122 {NEG, OperatorType::kRecElmWiseOp}, 123 {POW, OperatorType::kRecElmWiseOp}, 124 {EXP, OperatorType::kRecElmWiseOp}, 125 {LOG, OperatorType::kRecElmWiseOp}, 126 {COS, OperatorType::kRecElmWiseOp}, 127 {LGAMMA, OperatorType::kRecElmWiseOp}, 128 {TRUNC, OperatorType::kRecElmWiseOp}, 129 {ACOS, OperatorType::kRecElmWiseOp}, 130 {ASIN, OperatorType::kRecElmWiseOp}, 131 {ASINH, OperatorType::kRecElmWiseOp}, 132 {ATAN, OperatorType::kRecElmWiseOp}, 133 {ATANH, OperatorType::kRecElmWiseOp}, 134 {EXPM1, OperatorType::kRecElmWiseOp}, 135 {LOG1P, OperatorType::kRecElmWiseOp}, 136 {LOGICALNOT, OperatorType::kRecElmWiseOp}, 137 {"LogicalAnd", OperatorType::kRecElmWiseOp}, 138 {"LogicalOr", OperatorType::kRecElmWiseOp}, 139 {SQUARE, OperatorType::kRecElmWiseOp}, 140 {"Abs", OperatorType::kRecElmWiseOp}, 141 {"Acosh", OperatorType::kRecElmWiseOp}, 142 {"AddN", OperatorType::kRecElmWiseOp}, 143 {"AccumulateNV2", OperatorType::kRecElmWiseOp}, 144 {"Atan2", OperatorType::kRecElmWiseOp}, 145 {ELU, OperatorType::kRecElmWiseOp}, 146 {ERF, OperatorType::kRecElmWiseOp}, 147 {ERFC, OperatorType::kRecElmWiseOp}, 148 {MOD, OperatorType::kRecElmWiseOp}, 149 {FLOOR, OperatorType::kRecElmWiseOp}, 150 {CEIL, OperatorType::kRecElmWiseOp}, 151 {FLOORDIV, OperatorType::kRecElmWiseOp}, 152 {"FloorMod", OperatorType::kRecElmWiseOp}, 153 {GREATER, OperatorType::kRecElmWiseOp}, 154 {"GreaterEqual", OperatorType::kRecElmWiseOp}, 155 {"HSwish", OperatorType::kRecElmWiseOp}, 156 {"Less", OperatorType::kRecElmWiseOp}, 157 {"LessEqual", OperatorType::kRecElmWiseOp}, 158 {MAXIMUM, OperatorType::kRecElmWiseOp}, 159 {MINIMUM, OperatorType::kRecElmWiseOp}, 160 {EQUAL, OperatorType::kRecElmWiseOp}, 161 {NOT_EQUAL, OperatorType::kRecElmWiseOp}, 162 {APPROXIMATEEQUAL, OperatorType::kRecElmWiseOp}, 163 {INV, OperatorType::kRecElmWiseOp}, 164 {BESSELI0E, OperatorType::kRecElmWiseOp}, 165 {BESSELI1E, OperatorType::kRecElmWiseOp}, 166 {BESSELI0, OperatorType::kRecElmWiseOp}, 167 {BESSELI1, OperatorType::kRecElmWiseOp}, 168 {BESSELJ0, OperatorType::kRecElmWiseOp}, 169 {BESSELJ1, OperatorType::kRecElmWiseOp}, 170 {ZEROSLIKE, OperatorType::kRecElmWiseOp}, 171 {ONESLIKE, OperatorType::kRecElmWiseOp}, 172 {DIVNONAN, OperatorType::kRecElmWiseOp}, 173 {"Reciprocal", OperatorType::kRecElmWiseOp}, 174 {"Round", OperatorType::kRecElmWiseOp}, 175 {"Rsqrt", OperatorType::kRecElmWiseOp}, 176 {"Sign", OperatorType::kRecElmWiseOp}, 177 {SIN, OperatorType::kRecElmWiseOp}, 178 {SINH, OperatorType::kRecElmWiseOp}, 179 {TAN, OperatorType::kRecElmWiseOp}, 180 {ASSIGN, OperatorType::kRecElmWiseOp}, 181 {ASSIGN_ADD, OperatorType::kRecElmWiseOp}, 182 {ASSIGN_SUB, OperatorType::kRecElmWiseOp}, 183 {"AssignAdd", OperatorType::kRecElmWiseOp}, 184 {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp}, 185 {DROPOUT, OperatorType::kRecElmWiseOp}, 186 {STACK, OperatorType::kRecElmWiseOp}, 187 {"Select", OperatorType::kRecElmWiseOp}, 188 {"Concat", OperatorType::kRecElmWiseOp}, 189 {"Tile", OperatorType::kRecElmWiseOp}, 190 {MASKED_FILL, OperatorType::kRecElmWiseOp}, 191 {FILLV2, OperatorType::kRecElmWiseOp}, 192 {SCATTER_UPDATE, OperatorType::kRecElmWiseOp}, 193 {KV_CACHE_MGR, OperatorType::kRecElmWiseOp}, 194 {GATHERD, OperatorType::kRecBatchParallel}, 195 {FLASH_ATTENTION_SCORE, OperatorType::kFlashAttentionScore}}; 196 197 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w); 198 199 Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops); 200 201 void CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, 202 Graph::NodeType *NewTensor); 203 204 void Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, 205 const size_t iter_input_tensors, Graph::NodeType *NewTensor); 206 207 void Complete4DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, 208 const size_t iter_input_tensors, Graph::NodeType *NewTensor); 209 210 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops, 211 const std::vector<std::vector<std::string>> &input_tensor_names); 212 213 void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph); 214 215 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name, 216 const std::string &input_name); 217 218 void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr<Graph> &graph); 219 void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i); 220 void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i); 221 222 void Eliminate_Aux(size_t node_index, const std::shared_ptr<Graph> &graph, 223 const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list); 224 225 std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph, 226 const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list, 227 const std::shared_ptr<std::vector<size_t>> &index_list, 228 const bool dyn_shape_tmp_fix); 229 } // namespace parallel 230 } // namespace mindspore 231 #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ 232