1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_ 18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_ 19 20 #include <map> 21 #include <string> 22 #include <memory> 23 #include <utility> 24 #include <vector> 25 #include <climits> 26 #include <unordered_map> 27 #include "ir/anf.h" 28 #include "ir/tensor.h" 29 #include "utils/hash_map.h" 30 #include "include/transform/graph_ir/types.h" 31 #include "transform/graph_ir/op_adapter_base.h" 32 #include "mindapi/base/shape_vector.h" 33 34 namespace mindspore::transform { 35 using TensorPtr = mindspore::tensor::TensorPtr; 36 37 struct ValuePairHasher { 38 template <typename T> operatorValuePairHasher39 size_t operator()(const std::pair<T, ValuePtr> &p) const { 40 auto hash_value = hash_combine(std::hash<T>()(p.first), PointerHash<ValuePtr>{}(p.second)); 41 return hash_value; 42 } 43 }; 44 45 struct Ms2GeParamInfo { 46 enum ParamType : uint8_t { REQUIRED, OPTIONAL, DYNAMIC }; 47 48 uint32_t index; 49 std::string name; 50 enum ParamType type; 51 bool is_after_dynamic = false; 52 }; 53 54 struct GeTensorInfo { 55 std::string op_type; 56 57 // Attr 58 mindspore::HashMap<std::string, std::string> attr_map; 59 // NOTE: input index starts with 0 60 std::map<uint32_t, std::string> input_attr_map; 61 mindspore::HashMap<size_t, std::string> attr_input_map; 62 63 // Input/Output 64 enum ParamMappingFlag : unsigned int { 65 kDynamicParam = 1 << 0, // has only one dynamic input/output 66 kEmptyParam = 1 << 1, // empty input/output 67 kMultiDynParam = 1 << 2 // has more than one dynamic inputs/outputs 68 }; 69 70 // map input/output indices of operator from MindSpore frontend to GraphEngine backend 71 // K: MindSpore operator input index, V: GE operator input index and type info 72 mindspore::HashMap<int, Ms2GeParamInfo> input_idx_ms2ge; 73 mindspore::HashMap<int, Ms2GeParamInfo> output_idx_ms2ge; 74 std::unordered_map<size_t, size_t> ref_map_; 75 // fields for recording the mapping flags of input/output 76 unsigned int input_mapping_flags = 0; 77 unsigned int output_mapping_flags = 0; 78 // map input/output indices of operator from GraphEngine backend to MindSpore frontend 79 // K: GE operator input index, V: MindSpore operator input index 80 mindspore::HashMap<size_t, int> input_idx_ge2ms; 81 mindspore::HashMap<size_t, int> output_idx_ge2ms; 82 83 // DataType 84 mindspore::HashMap<int, std::vector<enum ::ge::DataType>> input_supported_dtypes; 85 mindspore::HashMap<int, std::vector<enum ::ge::DataType>> output_supported_dtypes; 86 }; 87 88 class GeAdapterInfo { 89 public: GeAdapterInfo(OpAdapterPtr adpt)90 explicit GeAdapterInfo(OpAdapterPtr adpt) : adapter_(std::move(adpt)) {} 91 ~GeAdapterInfo() = default; 92 93 void InitInfo(); 94 op_type()95 const std::string &op_type() const { return info_.op_type; } attr_map()96 const mindspore::HashMap<std::string, std::string> &attr_map() const { return info_.attr_map; } input_attr_map()97 const std::map<uint32_t, std::string> &input_attr_map() const { return info_.input_attr_map; } attr_input_map()98 const mindspore::HashMap<size_t, std::string> &attr_input_map() const { return info_.attr_input_map; } 99 100 // Get number of inputs in mindspore operator prototype, not the real number of inputs GetNumInputsOfMsOpProto()101 size_t GetNumInputsOfMsOpProto() const { 102 // Note: number of ms operator inputs(not real inputs) is equal to size of info_.input_idx_ms2ge 103 return info_.input_idx_ms2ge.size(); 104 } 105 GetMs2GeInputMap()106 const mindspore::HashMap<int, Ms2GeParamInfo> &GetMs2GeInputMap() const { return info_.input_idx_ms2ge; } 107 GetGeInputByMsInputIndex(size_t ms_input_idx)108 const Ms2GeParamInfo &GetGeInputByMsInputIndex(size_t ms_input_idx) const { 109 auto iter = info_.input_idx_ms2ge.find(ms_input_idx); 110 if (iter == info_.input_idx_ms2ge.end()) { 111 MS_LOG(EXCEPTION) << "Find input info from GE operator " << info_.op_type << " for mindspore input index " 112 << ms_input_idx << " fail."; 113 } 114 return iter->second; 115 } 116 GetOptGeInputByMsInputIndex(size_t ms_input_idx)117 const std::optional<Ms2GeParamInfo> GetOptGeInputByMsInputIndex(size_t ms_input_idx) const { 118 auto iter = info_.input_idx_ms2ge.find(ms_input_idx); 119 if (iter != info_.input_idx_ms2ge.end()) { 120 return iter->second; 121 } 122 return std::nullopt; 123 } 124 125 // Get number of outputs in mindspore operator prototype, not the real number of outputs GetNumOutputsOfMsOpProto()126 size_t GetNumOutputsOfMsOpProto() const { 127 // Note: number of ms operator outputs(not real outputs) is equal to size of info_.output_idx_ms2ge 128 return info_.output_idx_ms2ge.size(); 129 } 130 GetNumStaticOutputsOfMsOpProto()131 size_t GetNumStaticOutputsOfMsOpProto() const { 132 // Note: number of ms operator static outputs(not real outputs) 133 return adapter_->getOutputMap().size(); 134 } 135 GetMaxMsProtoIndexOfInputMap()136 int GetMaxMsProtoIndexOfInputMap() { return max_input_ms_proto_idx_; } 137 GetGeOutputByMsOutputIndex(size_t ms_output_idx)138 const Ms2GeParamInfo GetGeOutputByMsOutputIndex(size_t ms_output_idx) const { 139 auto iter = info_.output_idx_ms2ge.find(ms_output_idx); 140 if (iter == info_.output_idx_ms2ge.end()) { 141 MS_LOG(EXCEPTION) << "Find output info from GE operator " << info_.op_type << " for mindspore output index " 142 << ms_output_idx << " fail."; 143 } 144 return iter->second; 145 } 146 GetOptGeOutputByMsOutputIndex(size_t ms_output_idx)147 const std::optional<Ms2GeParamInfo> GetOptGeOutputByMsOutputIndex(size_t ms_output_idx) const { 148 auto iter = info_.output_idx_ms2ge.find(ms_output_idx); 149 if (iter != info_.output_idx_ms2ge.end()) { 150 return iter->second; 151 } 152 return std::nullopt; 153 } 154 GetInputMappingFlags()155 unsigned int GetInputMappingFlags() const { return info_.input_mapping_flags; } 156 GetOutputMappingFlags()157 unsigned int GetOutputMappingFlags() const { return info_.output_mapping_flags; } 158 GetRefMappingInfo()159 const std::unordered_map<size_t, size_t> &GetRefMappingInfo() const { return info_.ref_map_; } 160 input_supported_dtypes()161 mindspore::HashMap<int, std::vector<enum ::ge::DataType>> input_supported_dtypes() const { 162 return info_.input_supported_dtypes; 163 } output_supported_dtypes()164 mindspore::HashMap<int, std::vector<enum ::ge::DataType>> output_supported_dtypes() const { 165 return info_.output_supported_dtypes; 166 } 167 void GetGeAttrValueByMsAttrValue(const std::string &attr_name, ValuePtr *ms_value); 168 void GetGeAttrValueByMsInputValue(const uint32_t &input_idx, ValuePtr *ms_value); 169 170 private: 171 void InitOpType(); 172 173 void InitAclInputsAndOutputs(); 174 void InitRefMap(); 175 template <typename ParamMap, typename DynParamMap> 176 void InitParametersMap(const ParamMap ¶ms, const DynParamMap &dyn_params, bool is_input); 177 178 // attr 179 void InitAttrMap(); 180 void InitInputToAttrMap(); 181 void InitAttrToInputMap(); 182 183 void InitInputSupportedDataType(); 184 void InitOutputSupportedDataType(); 185 186 OpAdapterPtr adapter_{nullptr}; 187 GeTensorInfo info_; 188 // max MindSpore input index in op prototype of INPUT_MAP and DYN_INPUT_MAP 189 int max_input_ms_proto_idx_ = INT_MIN; 190 std::unordered_map<std::pair<std::string, ValuePtr>, ValuePtr, ValuePairHasher> get_attr_cache_; 191 std::unordered_map<std::pair<uint32_t, ValuePtr>, ValuePtr, ValuePairHasher> get_input_attr_cache_; 192 }; 193 194 using GeAdapterInfoPtr = std::shared_ptr<GeAdapterInfo>; 195 196 class GeAdapterManager { 197 public: 198 static GeAdapterManager &GetInstance(); 199 GeAdapterInfoPtr GetInfo(const std::string &prim_name, bool is_training); 200 201 private: 202 GeAdapterManager() = default; 203 ~GeAdapterManager() = default; 204 mindspore::HashMap<std::string, GeAdapterInfoPtr> op_cache_; 205 std::mutex lock_; 206 }; 207 } // namespace mindspore::transform 208 209 #endif // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_ 210