• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
18 #define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include <set>
26 
27 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
28 #include "frontend/parallel/ops_info/operator_info.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 static const std::set<OperatorType> EliminateOpType = {
33   OperatorType::kRecReLU,         OperatorType::kRecLog,           OperatorType::kRecExp,
34   OperatorType::kRecAdd,          OperatorType::kRecElmWiseOp,     OperatorType::kRecBiasAdd,
35   OperatorType::kRecSub,          OperatorType::kRecMul,           OperatorType::kRecDiv,
36   OperatorType::kRecSqueeze,      OperatorType::kRecReduce,        OperatorType::kRecCast,
37   OperatorType::kRecReshape,      OperatorType::kRecGatherV2,      OperatorType::kRecArgWithValue,
38   OperatorType::kRecSoftmax,      OperatorType::kRecOneHot,        OperatorType::kRecExpandDims,
39   OperatorType::kRecStridedSlice, OperatorType::kRecCum,           OperatorType::kRecLayerNorm,
40   OperatorType::kRecFlatten,      OperatorType::kRecBatchParallel, OperatorType::kRecStandAlone,
41   OperatorType::kRecPadV3,        OperatorType::kRecBatchMatMul,   OperatorType::kFlashAttentionScore,
42   OperatorType::kRecRmsNorm};
43 
44 const std::map<std::string, OperatorType> DictOpType{
45   {MATMUL, OperatorType::kRecMatMul},
46   {BATCH_MATMUL, OperatorType::kRecBatchMatMul},
47   {CONV2D, OperatorType::kRecConvolution},
48   {CONV2D_TRANSPOSE, OperatorType::kRecConvolution},
49   {MAXPOOL, OperatorType::kRecPooling},
50   {MAXPOOLV2, OperatorType::kRecPooling},
51   {POOLING, OperatorType::kRecPooling},
52   {MAX_POOL_WITH_ARGMAX, OperatorType::kRecPooling},
53   {SIMPLE_MEAN, OperatorType::kRecPooling},
54   {RESHAPE, OperatorType::kRecReshape},
55   {FLATTEN, OperatorType::kRecFlatten},
56   {BIAS_ADD, OperatorType::kRecBiasAdd},
57   {BATCH_NORM, OperatorType::kRecBatchNorm},
58   {LAYER_NORM, OperatorType::kRecLayerNorm},
59   {RMS_NORM, OperatorType::kRecRmsNorm},
60   {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
61   {ONEHOT, OperatorType::kRecOneHot},
62   {SQUEEZE, OperatorType::kRecSqueeze},
63   {CAST, OperatorType::kRecCast},
64   {REDUCE_SUM, OperatorType::kRecReduce},
65   {REDUCE_MAX, OperatorType::kRecReduce},
66   {REDUCE_MIN, OperatorType::kRecReduce},
67   {REDUCE_MEAN, OperatorType::kRecReduce},
68   {STAND_ALONE, OperatorType::kRecStandAlone},
69   {GET_NEXT, OperatorType::kRecUnknownType},
70   {VIRTUAL_DATA_SET, OperatorType::kRecVirtual},
71   {VIRTUAL_OUTPUT, OperatorType::kRecVirtual},
72   {BATCH_PARALLEL, OperatorType::kRecBatchParallel},
73   {GATHERV2, OperatorType::kRecGatherV2},
74   {EXPAND_DIMS, OperatorType::kRecExpandDims},
75   {STRIDEDSLICE, OperatorType::kRecStridedSlice},
76   {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
77   {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
78   {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
79   {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
80   {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
81   // Activation OP
82   {ACTIVATION, OperatorType::kRecReLU},
83   {RELU, OperatorType::kRecReLU},
84   {SILU, OperatorType::kRecReLU},
85   {"ReLU6", OperatorType::kRecReLU},
86   {SIGMOID, OperatorType::kRecReLU},
87   {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU},
88   {"HSigmoid", OperatorType::kRecReLU},
89   {GELU, OperatorType::kRecReLU},
90   {FAST_GELU, OperatorType::kRecReLU},
91   {TANH, OperatorType::kRecReLU},
92   {SOFTPLUS, OperatorType::kRecReLU},
93   {SOFTSIGN, OperatorType::kRecReLU},
94   {PRELU, OperatorType::kRecPReLU},
95   // Elm-wise OP
96   {SPLIT, OperatorType::kRecElmWiseOp},
97   {TRANSPOSE, OperatorType::kRecElmWiseOp},
98   {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
99   {ADD, OperatorType::kRecElmWiseOp},
100   {TENSOR_DOT, OperatorType::kRecElmWiseOp},
101   {SUB, OperatorType::kRecElmWiseOp},
102   {MUL, OperatorType::kRecElmWiseOp},
103   {DIV, OperatorType::kRecElmWiseOp},
104   {REAL_DIV, OperatorType::kRecElmWiseOp},
105   {HYPOT, OperatorType::kRecElmWiseOp},
106   {IGAMMA, OperatorType::kRecElmWiseOp},
107   {IGAMMAC, OperatorType::kRecElmWiseOp},
108   {LEFT_SHIFT, OperatorType::kRecElmWiseOp},
109   {RIGHT_SHIFT, OperatorType::kRecElmWiseOp},
110   {NEXT_AFTER, OperatorType::kRecElmWiseOp},
111   {ZETA, OperatorType::kRecElmWiseOp},
112   {GCD, OperatorType::kRecElmWiseOp},
113   {SOFTMAX, OperatorType::kRecSoftmax},
114   {REVERSEV2, OperatorType::kRecSoftmax},
115   {LOG_SOFTMAX, OperatorType::kRecSoftmax},
116   {CHOLESKY, OperatorType::kRecSoftmax},
117   {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits},
118   {FLATTEN, OperatorType::kRecFlatten},
119   {PAD_V3, OperatorType::kRecPadV3},
120   {CUM_SUM, OperatorType::kRecCum},
121   {SQRT, OperatorType::kRecElmWiseOp},
122   {NEG, OperatorType::kRecElmWiseOp},
123   {POW, OperatorType::kRecElmWiseOp},
124   {EXP, OperatorType::kRecElmWiseOp},
125   {LOG, OperatorType::kRecElmWiseOp},
126   {COS, OperatorType::kRecElmWiseOp},
127   {LGAMMA, OperatorType::kRecElmWiseOp},
128   {TRUNC, OperatorType::kRecElmWiseOp},
129   {ACOS, OperatorType::kRecElmWiseOp},
130   {ASIN, OperatorType::kRecElmWiseOp},
131   {ASINH, OperatorType::kRecElmWiseOp},
132   {ATAN, OperatorType::kRecElmWiseOp},
133   {ATANH, OperatorType::kRecElmWiseOp},
134   {EXPM1, OperatorType::kRecElmWiseOp},
135   {LOG1P, OperatorType::kRecElmWiseOp},
136   {LOGICALNOT, OperatorType::kRecElmWiseOp},
137   {"LogicalAnd", OperatorType::kRecElmWiseOp},
138   {"LogicalOr", OperatorType::kRecElmWiseOp},
139   {SQUARE, OperatorType::kRecElmWiseOp},
140   {"Abs", OperatorType::kRecElmWiseOp},
141   {"Acosh", OperatorType::kRecElmWiseOp},
142   {"AddN", OperatorType::kRecElmWiseOp},
143   {"AccumulateNV2", OperatorType::kRecElmWiseOp},
144   {"Atan2", OperatorType::kRecElmWiseOp},
145   {ELU, OperatorType::kRecElmWiseOp},
146   {ERF, OperatorType::kRecElmWiseOp},
147   {ERFC, OperatorType::kRecElmWiseOp},
148   {MOD, OperatorType::kRecElmWiseOp},
149   {FLOOR, OperatorType::kRecElmWiseOp},
150   {CEIL, OperatorType::kRecElmWiseOp},
151   {FLOORDIV, OperatorType::kRecElmWiseOp},
152   {"FloorMod", OperatorType::kRecElmWiseOp},
153   {GREATER, OperatorType::kRecElmWiseOp},
154   {"GreaterEqual", OperatorType::kRecElmWiseOp},
155   {"HSwish", OperatorType::kRecElmWiseOp},
156   {"Less", OperatorType::kRecElmWiseOp},
157   {"LessEqual", OperatorType::kRecElmWiseOp},
158   {MAXIMUM, OperatorType::kRecElmWiseOp},
159   {MINIMUM, OperatorType::kRecElmWiseOp},
160   {EQUAL, OperatorType::kRecElmWiseOp},
161   {NOT_EQUAL, OperatorType::kRecElmWiseOp},
162   {APPROXIMATEEQUAL, OperatorType::kRecElmWiseOp},
163   {INV, OperatorType::kRecElmWiseOp},
164   {BESSELI0E, OperatorType::kRecElmWiseOp},
165   {BESSELI1E, OperatorType::kRecElmWiseOp},
166   {BESSELI0, OperatorType::kRecElmWiseOp},
167   {BESSELI1, OperatorType::kRecElmWiseOp},
168   {BESSELJ0, OperatorType::kRecElmWiseOp},
169   {BESSELJ1, OperatorType::kRecElmWiseOp},
170   {ZEROSLIKE, OperatorType::kRecElmWiseOp},
171   {ONESLIKE, OperatorType::kRecElmWiseOp},
172   {DIVNONAN, OperatorType::kRecElmWiseOp},
173   {"Reciprocal", OperatorType::kRecElmWiseOp},
174   {"Round", OperatorType::kRecElmWiseOp},
175   {"Rsqrt", OperatorType::kRecElmWiseOp},
176   {"Sign", OperatorType::kRecElmWiseOp},
177   {SIN, OperatorType::kRecElmWiseOp},
178   {SINH, OperatorType::kRecElmWiseOp},
179   {TAN, OperatorType::kRecElmWiseOp},
180   {ASSIGN, OperatorType::kRecElmWiseOp},
181   {ASSIGN_ADD, OperatorType::kRecElmWiseOp},
182   {ASSIGN_SUB, OperatorType::kRecElmWiseOp},
183   {"AssignAdd", OperatorType::kRecElmWiseOp},
184   {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp},
185   {DROPOUT, OperatorType::kRecElmWiseOp},
186   {STACK, OperatorType::kRecElmWiseOp},
187   {"Select", OperatorType::kRecElmWiseOp},
188   {"Concat", OperatorType::kRecElmWiseOp},
189   {"Tile", OperatorType::kRecElmWiseOp},
190   {MASKED_FILL, OperatorType::kRecElmWiseOp},
191   {FILLV2, OperatorType::kRecElmWiseOp},
192   {SCATTER_UPDATE, OperatorType::kRecElmWiseOp},
193   {KV_CACHE_MGR, OperatorType::kRecElmWiseOp},
194   {GATHERD, OperatorType::kRecBatchParallel},
195   {FLASH_ATTENTION_SCORE, OperatorType::kFlashAttentionScore}};
196 
197 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);
198 
199 Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
200 
201 void CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
202                             Graph::NodeType *NewTensor);
203 
204 void Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
205                       const size_t iter_input_tensors, Graph::NodeType *NewTensor);
206 
207 void Complete4DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
208                       const size_t iter_input_tensors, Graph::NodeType *NewTensor);
209 
210 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
211                                   const std::vector<std::vector<std::string>> &input_tensor_names);
212 
213 void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph);
214 
215 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_name,
216                                   const std::string &input_name);
217 
218 void Eliminate_Aux_Outgoing(size_t node_index, const std::shared_ptr<Graph> &graph);
219 void EliminateAuxOutgoingInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i);
220 void EliminateAuxOutgoingAuxInput(size_t node_index, const std::shared_ptr<Graph> &graph, size_t i);
221 
222 void Eliminate_Aux(size_t node_index, const std::shared_ptr<Graph> &graph,
223                    const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
224 
225 std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
226                                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
227                                       const std::shared_ptr<std::vector<size_t>> &index_list,
228                                       const bool dyn_shape_tmp_fix);
229 }  // namespace parallel
230 }  // namespace mindspore
231 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
232