1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ 18 #define PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ 19 20 #include <iostream> 21 #include <string> 22 #include <vector> 23 24 #include "frontend/parallel/auto_parallel/rec_core/rec_strategy.h" 25 #include "frontend/parallel/auto_parallel/rec_core/rec_tensor.h" 26 #include "ir/anf.h" 27 28 namespace mindspore { 29 namespace parallel { 30 enum OperatorType { 31 kRecUnknownType, 32 kRecMatMul, 33 kRecConvolution, 34 kRecPooling, 35 kRecElmWiseOp, 36 kRecReLU, 37 kRecBatchNorm, 38 kRecLayerNorm, 39 kRecReshape, 40 kRecBiasAdd, 41 kRecSoftmax, 42 kRecSparseSoftmaxCrossEntropyWithLogits, 43 kRecSoftmaxCrossEntropyWithLogits, 44 kRecOneHot, 45 kRecLog, 46 kRecExp, 47 kRecAdd, 48 kRecSub, 49 kRecMul, 50 kRecDiv, 51 kRecSqueeze, 52 kRecCast, 53 kRecReduce, 54 kRecPReLU, 55 kRecGatherV2, 56 kRecExpandDims, 57 kRecStridedSlice, 58 kRecArgWithValue, 59 kRecUnsortedSegmentOp, 60 kRecBatchMatMul, 61 kRecFlatten, 62 kRecCum, 63 kRecStandAlone, 64 kRecBatchParallel, 65 kRecPadV3, 66 kRecVirtual, 67 kFlashAttentionScore, 68 kRecRmsNorm 69 }; 70 71 enum InfoType { kApplication, kConstant }; 72 73 struct OperatorRec { 74 OperatorType op_type; 75 TensorParam arguments[MAX_INPUT_NUM]; 76 StrategyRec str; 77 std::vector<StrategyRec> strs; 78 }; 79 80 // Define simplified dataflow Graph for partitioning 81 class Graph { 82 public: 83 struct NodeType { 84 std::string name; 85 // Nodes that point to this node 86 std::vector<size_t> node_in; 87 // Nodes that point from this node 88 std::vector<size_t> node_out; 89 // Nodes that point to this node via auxiliary edges 90 std::vector<size_t> node_in_aux; 91 // Input indices of the nodes that point to this node via auxliary edges 92 std::vector<size_t> node_in_aux_idx; 93 94 // Node Type Info: Application or Constant. Defined in enum <InfoType> . 95 InfoType info; 96 // Operator info. Defined in struct <OperatorRec> . 97 OperatorRec apply; 98 // Tensor info. Defined in tensor.h struct <TensorParam> . 99 TensorParam tensor_parm; 100 101 std::string param_name; 102 }; 103 104 bool dyn_shape_tmp_fix = false; 105 106 int64_t micro_batch_size = 1; 107 108 std::vector<Graph::NodeType> nodes; // Nodes of the graph. Public. 109 }; // Define simplified dataflow Graph for partitioning 110 } // namespace parallel 111 } // namespace mindspore 112 #endif // PARALLEL_AUTO_PARALLEL_REC_GRAPH_H_ 113