• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
18 #define PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
19 
20 #include <map>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 #include <set>
26 
27 #include "frontend/parallel/auto_parallel/rec_core/rec_graph.h"
28 #include "frontend/parallel/ops_info/operator_info.h"
29 
30 namespace mindspore {
31 namespace parallel {
32 static const std::set<OperatorType> ElementWiseOpType = {
33   OperatorType::kRecReLU,      OperatorType::kRecLog,      OperatorType::kRecExp,         OperatorType::kRecAdd,
34   OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd,  OperatorType::kRecSub,         OperatorType::kRecMul,
35   OperatorType::kRecDiv,       OperatorType::kRecSqueeze,  OperatorType::kRecReduce,      OperatorType::kRecCast,
36   OperatorType::kRecReshape,   OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
37 
38 const std::map<std::string, OperatorType> DictOpType{
39   {MATMUL, OperatorType::kRecMatMul},
40   {CONV2D, OperatorType::kRecConvolution},
41   {MAXPOOL, OperatorType::kRecPooling},
42   {MAXPOOLV2, OperatorType::kRecPooling},
43   {POOLING, OperatorType::kRecPooling},
44   {MAX_POOL_WITH_ARGMAX, OperatorType::kRecPooling},
45   {SIMPLE_MEAN, OperatorType::kRecPooling},
46   {RESHAPE, OperatorType::kRecReshape},
47   {BIAS_ADD, OperatorType::kRecBiasAdd},
48   {BATCH_NORM, OperatorType::kRecBatchNorm},
49   {LAYER_NORM, OperatorType::kRecBatchNorm},
50   {SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits},
51   {ONEHOT, OperatorType::kRecOneHot},
52   {SQUEEZE, OperatorType::kRecSqueeze},
53   {CAST, OperatorType::kRecCast},
54   {REDUCE_SUM, OperatorType::kRecReduce},
55   {REDUCE_MAX, OperatorType::kRecReduce},
56   {REDUCE_MIN, OperatorType::kRecReduce},
57   {REDUCE_MEAN, OperatorType::kRecReduce},
58   {GATHERV2, OperatorType::kRecGatherV2},
59   {ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
60   {ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
61   {UNSORTED_SEGMENT_SUM, OperatorType::kRecUnsortedSegmentOp},
62   {UNSORTED_SEGMENT_MAX, OperatorType::kRecUnsortedSegmentOp},
63   {UNSORTED_SEGMENT_MIN, OperatorType::kRecUnsortedSegmentOp},
64   // Activation OP
65   {ACTIVATION, OperatorType::kRecReLU},
66   {RELU, OperatorType::kRecReLU},
67   {"ReLU6", OperatorType::kRecReLU},
68   {"ReLUV2", OperatorType::kRecReLU},
69   {SIGMOID, OperatorType::kRecReLU},
70   {SIGMOID_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecReLU},
71   {"HSigmoid", OperatorType::kRecReLU},
72   {GELU, OperatorType::kRecReLU},
73   {TANH, OperatorType::kRecReLU},
74   {SOFTPLUS, OperatorType::kRecReLU},
75   {SOFTSIGN, OperatorType::kRecReLU},
76   {PRELU, OperatorType::kRecPReLU},
77   // Elm-wise OP
78   {TRANSPOSE, OperatorType::kRecElmWiseOp},
79   {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
80   {ADD, OperatorType::kRecElmWiseOp},
81   {TENSOR_DOT, OperatorType::kRecElmWiseOp},
82   {SUB, OperatorType::kRecElmWiseOp},
83   {MUL, OperatorType::kRecElmWiseOp},
84   {DIV, OperatorType::kRecElmWiseOp},
85   {REAL_DIV, OperatorType::kRecElmWiseOp},
86   {SOFTMAX, OperatorType::kRecSoftmax},
87   {LOG_SOFTMAX, OperatorType::kRecSoftmax},
88   {SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, OperatorType::kRecSoftmaxCrossEntropyWithLogits},
89   {SQRT, OperatorType::kRecElmWiseOp},
90   {NEG, OperatorType::kRecElmWiseOp},
91   {POW, OperatorType::kRecElmWiseOp},
92   {EXP, OperatorType::kRecElmWiseOp},
93   {LOG, OperatorType::kRecElmWiseOp},
94   {COS, OperatorType::kRecElmWiseOp},
95   {ACOS, OperatorType::kRecElmWiseOp},
96   {ASIN, OperatorType::kRecElmWiseOp},
97   {ASINH, OperatorType::kRecElmWiseOp},
98   {ATAN, OperatorType::kRecElmWiseOp},
99   {ATANH, OperatorType::kRecElmWiseOp},
100   {EXPM1, OperatorType::kRecElmWiseOp},
101   {LOG1P, OperatorType::kRecElmWiseOp},
102   {LOGICALNOT, OperatorType::kRecElmWiseOp},
103   {"LogicalAnd", OperatorType::kRecElmWiseOp},
104   {"LogicalOr", OperatorType::kRecElmWiseOp},
105   {SQUARE, OperatorType::kRecElmWiseOp},
106   {"Abs", OperatorType::kRecElmWiseOp},
107   {"Acosh", OperatorType::kRecElmWiseOp},
108   {"AddN", OperatorType::kRecElmWiseOp},
109   {"AccumulateNV2", OperatorType::kRecElmWiseOp},
110   {"Atan2", OperatorType::kRecElmWiseOp},
111   {ELU, OperatorType::kRecElmWiseOp},
112   {ERF, OperatorType::kRecElmWiseOp},
113   {ERFC, OperatorType::kRecElmWiseOp},
114   {MOD, OperatorType::kRecElmWiseOp},
115   {FLOOR, OperatorType::kRecElmWiseOp},
116   {CEIL, OperatorType::kRecElmWiseOp},
117   {FLOORDIV, OperatorType::kRecElmWiseOp},
118   {"FloorMod", OperatorType::kRecElmWiseOp},
119   {GREATER, OperatorType::kRecElmWiseOp},
120   {"GreaterEqual", OperatorType::kRecElmWiseOp},
121   {"HSwish", OperatorType::kRecElmWiseOp},
122   {"Less", OperatorType::kRecElmWiseOp},
123   {"LessEqual", OperatorType::kRecElmWiseOp},
124   {MAXIMUM, OperatorType::kRecElmWiseOp},
125   {MINIMUM, OperatorType::kRecElmWiseOp},
126   {EQUAL, OperatorType::kRecElmWiseOp},
127   {NOT_EQUAL, OperatorType::kRecElmWiseOp},
128   {APPROXIMATEEQUAL, OperatorType::kRecElmWiseOp},
129   {INV, OperatorType::kRecElmWiseOp},
130   {BESSELI0E, OperatorType::kRecElmWiseOp},
131   {BESSELI1E, OperatorType::kRecElmWiseOp},
132   {ZEROSLIKE, OperatorType::kRecElmWiseOp},
133   {ONESLIKE, OperatorType::kRecElmWiseOp},
134   {DIVNONAN, OperatorType::kRecElmWiseOp},
135   {"Reciprocal", OperatorType::kRecElmWiseOp},
136   {"Round", OperatorType::kRecElmWiseOp},
137   {"Rsqrt", OperatorType::kRecElmWiseOp},
138   {"Sign", OperatorType::kRecElmWiseOp},
139   {SIN, OperatorType::kRecElmWiseOp},
140   {SINH, OperatorType::kRecElmWiseOp},
141   {TAN, OperatorType::kRecElmWiseOp},
142   {ASSIGN, OperatorType::kRecElmWiseOp},
143   {ASSIGN_ADD, OperatorType::kRecElmWiseOp},
144   {ASSIGN_SUB, OperatorType::kRecElmWiseOp},
145   {"AssignAdd", OperatorType::kRecElmWiseOp},
146   {DROPOUT_DO_MASK, OperatorType::kRecElmWiseOp},
147   {STACK, OperatorType::kRecElmWiseOp}};
148 
149 const TensorParam MakeTensor(int64_t n, int64_t c, int64_t h, int64_t w);
150 
151 Graph::NodeType MakeNewOperator(const std::vector<std::shared_ptr<OperatorInfo>> &ops, size_t iter_ops);
152 
153 OperatorRec CompleteOperatorInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
154                                    Graph::NodeType NewTensor);
155 
156 TensorParam Complete2DInputs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
157                              const size_t iter_input_tensor, Graph::NodeType NewTensor);
158 
159 std::shared_ptr<Graph> ParseGraph(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
160                                   const std::vector<std::vector<std::string>> &input_tensor_names);
161 
162 void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, const std::shared_ptr<Graph> &graph);
163 
164 size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
165                                   const std::string &input_name);
166 
167 void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
168                    const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
169 
170 std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
171                                       const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
172                                       const std::shared_ptr<std::vector<size_t>> &index_list);
173 }  // namespace parallel
174 }  // namespace mindspore
175 #endif  // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_
176