• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
18 #define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
19 
20 #include <iostream>
21 #include <string>
22 #include <vector>
23 
24 #include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h"
25 #include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h"
26 #include "ir/anf.h"
27 
28 namespace mindspore {
29 namespace parallel {
// Operator categories recognised by the recursive-programming (rec) partitioner.
// NOTE: this is an unscoped enum whose enumerator values follow declaration
// order, starting at 0. Append new kinds at the end rather than inserting them,
// so existing numeric values stay stable.
// NOTE(review): kFlashAttentionScore lacks the kRec prefix used by every other
// enumerator; it is kept as-is because renaming would break existing callers.
enum OperatorType {
  kRecUnknownType = 0,
  kRecMatMul,
  kRecConvolution,
  kRecPooling,
  kRecElmWiseOp,
  kRecReLU,
  kRecBatchNorm,
  kRecLayerNorm,
  kRecReshape,
  kRecBiasAdd,
  kRecSoftmax,
  kRecSparseSoftmaxCrossEntropyWithLogits,
  kRecSoftmaxCrossEntropyWithLogits,
  kRecOneHot,
  kRecLog,
  kRecExp,
  kRecAdd,
  kRecSub,
  kRecMul,
  kRecDiv,
  kRecSqueeze,
  kRecCast,
  kRecReduce,
  kRecPReLU,
  kRecGatherV2,
  kRecExpandDims,
  kRecStridedSlice,
  kRecArgWithValue,
  kRecUnsortedSegmentOp,
  kRecBatchMatMul,
  kRecFlatten,
  kRecCum,
  kRecStandAlone,
  kRecBatchParallel,
  kRecPadV3,
  kRecVirtual,
  kFlashAttentionScore,
  kRecRmsNorm,
};
70 
// Kind of a node in the partitioning graph.
enum InfoType {
  kApplication,  // Node represents an operator application.
  kConstant,     // Node represents a constant.
};
72 
73 struct OperatorRec {
74   OperatorType op_type;
75   TensorParam arguments[MAX_INPUT_NUM];
76   StrategyRec str;
77   std::vector<StrategyRec> strs;
78 };
79 
80 // Define simplified dataflow Graph for partitioning
81 class Graph {
82  public:
83   struct NodeType {
84     std::string name;
85     // Nodes that point to this node
86     std::vector<size_t> node_in;
87     // Nodes that point from this node
88     std::vector<size_t> node_out;
89     // Nodes that point to this node via auxiliary edges
90     std::vector<size_t> node_in_aux;
91     // Input indices of the nodes that point to this node via auxliary edges
92     std::vector<size_t> node_in_aux_idx;
93 
94     // Node Type Info: Application or Constant. Defined in enum <InfoType> .
95     InfoType info;
96     // Operator info. Defined in struct <OperatorRec> .
97     OperatorRec apply;
98     // Tensor info. Defined in tensor.h struct <TensorParam> .
99     TensorParam tensor_parm;
100 
101     std::string param_name;
102   };
103 
104   bool dyn_shape_tmp_fix = false;
105 
106   int64_t micro_batch_size = 1;
107 
108   std::vector<Graph::NodeType> nodes;  // Nodes of the graph. Public.
109 };                                     // Define simplified dataflow Graph for partitioning
110 }  // namespace parallel
111 }  // namespace mindspore
112 #endif  // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_
113