1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_ 18 #define MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_ 19 20 #include <string> 21 #include <vector> 22 #include <map> 23 #include <memory> 24 #include "utils/hash_map.h" 25 #include "ir/anf.h" 26 #include "ir/dtype.h" 27 #include "ir/tensor.h" 28 29 #include "graph/tensor.h" 30 #include "ge/ge_api.h" 31 32 using GeTensor = ::ge::Tensor; 33 34 namespace mindspore { 35 namespace transform { 36 enum Status : int { SUCCESS = 0, FAILED, INVALID_ARGUMENT, ALREADY_EXISTS, NOT_FOUND }; 37 typedef enum { DEFAULT_MODE, ALLOW_FP32_TO_FP16, FORCE_FP32, MUST_KEEP_ORIGIN_DTYPE } AclPrecisionMode; 38 39 using MeTensor = mindspore::tensor::Tensor; 40 using MeTensorPtr = std::shared_ptr<MeTensor>; 41 using MeDataType = mindspore::TypeId; 42 using GeDataType = ::ge::DataType; 43 using GeFormat = ::ge::Format; 44 using GeShape = ::ge::Shape; 45 using GeTensorPtr = std::shared_ptr<GeTensor>; 46 using GeTensorDesc = ::ge::TensorDesc; 47 using AnfGraph = FuncGraph; 48 using AnfGraphPtr = FuncGraphPtr; 49 using Operator = ::ge::Operator; 50 using OperatorPtr = std::shared_ptr<::ge::Operator>; 51 using DfGraph = ::ge::Graph; 52 using DfGraphPtr = std::shared_ptr<DfGraph>; 53 using TensorMap = mindspore::HashMap<std::string, std::shared_ptr<MeTensor>>; 54 using OptionMap = std::map<std::string, std::string>; 55 using TensorOrderMap = std::map<std::string, std::shared_ptr<tensor::Tensor>>; 56 using GeAllocatorPtr = ::ge::AllocatorPtr; 57 58 static std::map<std::string, GeDataType> ge_str_dtype_map = {{"float", GeDataType::DT_FLOAT}, 59 {"float32", GeDataType::DT_FLOAT}, 60 {"float16", GeDataType::DT_FLOAT16}, 61 {"int8", GeDataType::DT_INT8}, 62 {"int16", GeDataType::DT_INT16}, 63 {"int32", GeDataType::DT_INT32}, 64 {"int64", GeDataType::DT_INT64}, 65 {"uint1", GeDataType::DT_UINT1}, 66 {"uint8", GeDataType::DT_UINT8}, 67 {"uint16", GeDataType::DT_UINT16}, 68 {"uint32", GeDataType::DT_UINT32}, 69 {"uint64", GeDataType::DT_UINT64}, 70 {"bool", GeDataType::DT_BOOL}, 71 {"double", GeDataType::DT_DOUBLE}, 72 {"dual", GeDataType::DT_DUAL}, 73 {"dual_sub_int8", GeDataType::DT_DUAL_SUB_INT8}, 74 {"dual_sub_uint8", GeDataType::DT_DUAL_SUB_UINT8}, 75 {"int4", GeDataType::DT_INT4}, 76 {"bfloat16", GeDataType::DT_BF16}}; 77 static HashMap<GeDataType, std::string> ge_dtype_str_map = {{GeDataType::DT_FLOAT, "float"}, 78 {GeDataType::DT_FLOAT16, "float16"}, 79 {GeDataType::DT_INT8, "int8"}, 80 {GeDataType::DT_INT16, "int16"}, 81 {GeDataType::DT_UINT16, "uint16"}, 82 {GeDataType::DT_UINT8, "uint8"}, 83 {GeDataType::DT_INT32, "int32"}, 84 {GeDataType::DT_INT64, "int64"}, 85 {GeDataType::DT_UINT32, "uint32"}, 86 {GeDataType::DT_UINT64, "uint64"}, 87 {GeDataType::DT_BOOL, "bool"}, 88 {GeDataType::DT_DOUBLE, "double"}, 89 {GeDataType::DT_STRING, "string"}, 90 {GeDataType::DT_DUAL_SUB_INT8, "dual_sub_int8"}, 91 {GeDataType::DT_DUAL_SUB_UINT8, "dual_sub_uint8"}, 92 {GeDataType::DT_COMPLEX64, "complex64"}, 93 {GeDataType::DT_COMPLEX128, "complex128"}, 94 {GeDataType::DT_DUAL, "dual"}, 95 {GeDataType::DT_QINT8, "qint8"}, 96 {GeDataType::DT_QINT16, "qint16"}, 97 {GeDataType::DT_QINT32, "qint32"}, 98 {GeDataType::DT_QUINT8, "quint8"}, 99 {GeDataType::DT_QUINT16, "quint16"}, 100 {GeDataType::DT_RESOURCE, "resource"}, 101 {GeDataType::DT_STRING_REF, "string ref"}, 102 {GeDataType::DT_VARIANT, "dt_variant"}, 103 {GeDataType::DT_UNDEFINED, "undefined"}, 104 {GeDataType::DT_INT4, "int4"}, 105 {GeDataType::DT_UINT1, "uint1"}, 106 {GeDataType::DT_INT2, "int2"}, 107 {GeDataType::DT_UINT2, "uint2"}, 108 {GeDataType::DT_COMPLEX32, "complex32"}, 109 {GeDataType::DT_BF16, "bf16"}}; 110 111 static std::map<AclPrecisionMode, std::string> acl_precision_map = {{ALLOW_FP32_TO_FP16, "allow_fp32_to_fp16"}, 112 {FORCE_FP32, "force_fp32"}, 113 {MUST_KEEP_ORIGIN_DTYPE, "must_keep_origin_dtype"}}; 114 115 struct DfGraphWrapper { 116 public: 117 DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); ~DfGraphWrapperDfGraphWrapper118 ~DfGraphWrapper() {} 119 120 std::string name_; 121 int id_; 122 int times_{}; 123 DfGraphPtr graph_ptr_; 124 OptionMap options_ = {}; 125 bool is_added_to_ge_session_ = false; 126 std::mutex mutex_; 127 }; 128 129 using DfGraphWrapperPtr = std::shared_ptr<DfGraphWrapper>; 130 131 struct OutHandler { 132 OperatorPtr op; 133 std::string out; 134 AnfNodePtr node; OutHandlerOutHandler135 OutHandler() : op(nullptr), out(""), node(nullptr) {} 136 OutHandler(const OperatorPtr &op, const std::string out, const AnfNodePtr &node = nullptr) opOutHandler137 : op(op), out(out), node(node) {} 138 }; 139 140 struct ControlEdge { 141 OperatorPtr src_op; 142 OperatorPtr dest_op; 143 }; 144 145 using SessionOptions = std::map<std::string, std::string>; 146 147 struct GraphRunnerOptions { 148 std::string target{"default_graph_runner"}; 149 SessionOptions options; 150 // if sess_ptr is nullptr, GraphRunner will create a new ge session 151 std::shared_ptr<::ge::Session> sess_ptr{nullptr}; 152 }; 153 154 struct RunOptions { 155 // graph's name 156 std::string name; 157 }; 158 } // namespace transform 159 } // namespace mindspore 160 #endif // MINDSPORE_CCSRC_INCLUDE_TRANSFORM_GRAPH_IR_TYPES_H_ 161