• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_
18 #define MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_
19 
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#include "utils/hash_map.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/tensor.h"

#include "graph/tensor.h"
#include "ge/ge_api.h"
31 
32 using GeTensor = ::ge::Tensor;
33 
34 namespace mindspore {
35 namespace transform {
// Status codes used by the transform layer; SUCCESS is 0 and the failure codes follow.
enum Status : int {
  SUCCESS = 0,
  FAILED = 1,
  INVALID_ARGUMENT = 2,
  ALREADY_EXISTS = 3,
  NOT_FOUND = 4,
};

// ACL precision modes. The string spelling of every mode except DEFAULT_MODE is
// listed in acl_precision_map.
enum AclPrecisionMode {
  DEFAULT_MODE = 0,
  ALLOW_FP32_TO_FP16 = 1,
  FORCE_FP32 = 2,
  MUST_KEEP_ORIGIN_DTYPE = 3,
};
38 
39 using MeTensor = mindspore::tensor::Tensor;
40 using MeTensorPtr = std::shared_ptr<MeTensor>;
41 using MeDataType = mindspore::TypeId;
42 using GeDataType = ::ge::DataType;
43 using GeFormat = ::ge::Format;
44 using GeShape = ::ge::Shape;
45 using GeTensorPtr = std::shared_ptr<GeTensor>;
46 using GeTensorDesc = ::ge::TensorDesc;
47 using AnfGraph = FuncGraph;
48 using AnfGraphPtr = FuncGraphPtr;
49 using Operator = ::ge::Operator;
50 using OperatorPtr = std::shared_ptr<::ge::Operator>;
51 using DfGraph = ::ge::Graph;
52 using DfGraphPtr = std::shared_ptr<DfGraph>;
53 using TensorMap = mindspore::HashMap<std::string, std::shared_ptr<MeTensor>>;
54 using OptionMap = std::map<std::string, std::string>;
55 using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>;
56 using GeAllocatorPtr = ::ge::AllocatorPtr;
57 
58 static std::map<std::string, GeDataType> ge_str_dtype_map = {{"float", GeDataType::DT_FLOAT},
59                                                              {"float32", GeDataType::DT_FLOAT},
60                                                              {"float16", GeDataType::DT_FLOAT16},
61                                                              {"int8", GeDataType::DT_INT8},
62                                                              {"int16", GeDataType::DT_INT16},
63                                                              {"int32", GeDataType::DT_INT32},
64                                                              {"int64", GeDataType::DT_INT64},
65                                                              {"uint1", GeDataType::DT_UINT1},
66                                                              {"uint8", GeDataType::DT_UINT8},
67                                                              {"uint16", GeDataType::DT_UINT16},
68                                                              {"uint32", GeDataType::DT_UINT32},
69                                                              {"uint64", GeDataType::DT_UINT64},
70                                                              {"bool", GeDataType::DT_BOOL},
71                                                              {"double", GeDataType::DT_DOUBLE},
72                                                              {"dual", GeDataType::DT_DUAL},
73                                                              {"dual_sub_int8", GeDataType::DT_DUAL_SUB_INT8},
74                                                              {"dual_sub_uint8", GeDataType::DT_DUAL_SUB_UINT8},
75                                                              {"int4", GeDataType::DT_INT4},
76                                                              {"bfloat16", GeDataType::DT_BF16}};
77 static HashMap<GeDataType, std::string> ge_dtype_str_map = {{GeDataType::DT_FLOAT, "float"},
78                                                             {GeDataType::DT_FLOAT16, "float16"},
79                                                             {GeDataType::DT_INT8, "int8"},
80                                                             {GeDataType::DT_INT16, "int16"},
81                                                             {GeDataType::DT_UINT16, "uint16"},
82                                                             {GeDataType::DT_UINT8, "uint8"},
83                                                             {GeDataType::DT_INT32, "int32"},
84                                                             {GeDataType::DT_INT64, "int64"},
85                                                             {GeDataType::DT_UINT32, "uint32"},
86                                                             {GeDataType::DT_UINT64, "uint64"},
87                                                             {GeDataType::DT_BOOL, "bool"},
88                                                             {GeDataType::DT_DOUBLE, "double"},
89                                                             {GeDataType::DT_STRING, "string"},
90                                                             {GeDataType::DT_DUAL_SUB_INT8, "dual_sub_int8"},
91                                                             {GeDataType::DT_DUAL_SUB_UINT8, "dual_sub_uint8"},
92                                                             {GeDataType::DT_COMPLEX64, "complex64"},
93                                                             {GeDataType::DT_COMPLEX128, "complex128"},
94                                                             {GeDataType::DT_DUAL, "dual"},
95                                                             {GeDataType::DT_QINT8, "qint8"},
96                                                             {GeDataType::DT_QINT16, "qint16"},
97                                                             {GeDataType::DT_QINT32, "qint32"},
98                                                             {GeDataType::DT_QUINT8, "quint8"},
99                                                             {GeDataType::DT_QUINT16, "quint16"},
100                                                             {GeDataType::DT_RESOURCE, "resource"},
101                                                             {GeDataType::DT_STRING_REF, "string ref"},
102                                                             {GeDataType::DT_VARIANT, "dt_variant"},
103                                                             {GeDataType::DT_UNDEFINED, "undefined"},
104                                                             {GeDataType::DT_INT4, "int4"},
105                                                             {GeDataType::DT_UINT1, "uint1"},
106                                                             {GeDataType::DT_INT2, "int2"},
107                                                             {GeDataType::DT_UINT2, "uint2"},
108                                                             {GeDataType::DT_COMPLEX32, "complex32"},
109                                                             {GeDataType::DT_BF16, "bf16"}};
110 
111 static std::map<AclPrecisionMode, std::string> acl_precision_map = {{ALLOW_FP32_TO_FP16, "allow_fp32_to_fp16"},
112                                                                     {FORCE_FP32, "force_fp32"},
113                                                                     {MUST_KEEP_ORIGIN_DTYPE, "must_keep_origin_dtype"}};
114 
115 struct DfGraphWrapper {
116  public:
117   DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options);
~DfGraphWrapperDfGraphWrapper118   ~DfGraphWrapper() {}
119 
120   std::string name_;
121   int id_;
122   int times_{};
123   DfGraphPtr graph_ptr_;
124   OptionMap options_ = {};
125   bool is_added_to_ge_session_ = false;
126   std::mutex mutex_;
127 };
128 
129 using DfGraphWrapperPtr = std::shared_ptr<DfGraphWrapper>;
130 
131 struct OutHandler {
132   OperatorPtr op;
133   std::string out;
134   AnfNodePtr node;
OutHandlerOutHandler135   OutHandler() : op(nullptr), out(""), node(nullptr) {}
136   OutHandler(const OperatorPtr &op, const std::string out, const AnfNodePtr &node = nullptr)
opOutHandler137       : op(op), out(out), node(node) {}
138 };
139 
140 struct ControlEdge {
141   OperatorPtr src_op;
142   OperatorPtr dest_op;
143 };
144 
145 using SessionOptions = std::map<std::string, std::string>;
146 
147 struct GraphRunnerOptions {
148   std::string target{"default_graph_runner"};
149   SessionOptions options;
150   // if sess_ptr is nullptr, GraphRunner will create a new ge session
151   std::shared_ptr<::ge::Session> sess_ptr{nullptr};
152 };
153 
// Options for a graph run; currently only the graph's name.
struct RunOptions {
  std::string name;  // the graph's name
};
158 }  // namespace transform
159 }  // namespace mindspore
160 #endif  // MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_
161