• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_
19 
#include <climits>
#include <cstdint>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "ir/tensor.h"
#include "utils/hash_map.h"
#include "include/transform/graph_ir/types.h"
#include "transform/graph_ir/op_adapter_base.h"
#include "mindapi/base/shape_vector.h"
33 
34 namespace mindspore::transform {
35 using TensorPtr = mindspore::tensor::TensorPtr;
36 
37 struct ValuePairHasher {
38   template <typename T>
operatorValuePairHasher39   size_t operator()(const std::pair<T, ValuePtr> &p) const {
40     auto hash_value = hash_combine(std::hash<T>()(p.first), PointerHash<ValuePtr>{}(p.second));
41     return hash_value;
42   }
43 };
44 
// Describes the GE-side parameter that one MindSpore input/output index maps to.
// Every member has a default initializer so a default-constructed instance never holds
// indeterminate values. The struct remains an aggregate, so brace initialization such as
// Ms2GeParamInfo{idx, name, type} or Ms2GeParamInfo{idx, name, type, after_dyn} still works.
struct Ms2GeParamInfo {
  enum ParamType : uint8_t { REQUIRED, OPTIONAL, DYNAMIC };

  // GE operator parameter index.
  uint32_t index = 0;
  // GE operator parameter name.
  std::string name;
  // Whether the GE parameter is required, optional or dynamic.
  ParamType type = REQUIRED;
  // NOTE(review): presumably true when this parameter follows a dynamic one -- confirm in InitParametersMap.
  bool is_after_dynamic = false;
};
53 
54 struct GeTensorInfo {
55   std::string op_type;
56 
57   // Attr
58   mindspore::HashMap<std::string, std::string> attr_map;
59   // NOTE: input index starts with 0
60   std::map<uint32_t, std::string> input_attr_map;
61   mindspore::HashMap<size_t, std::string> attr_input_map;
62 
63   // Input/Output
64   enum ParamMappingFlag : unsigned int {
65     kDynamicParam = 1 << 0,  // has only one dynamic input/output
66     kEmptyParam = 1 << 1,    // empty input/output
67     kMultiDynParam = 1 << 2  // has more than one dynamic inputs/outputs
68   };
69 
70   // map input/output indices of operator from MindSpore frontend to GraphEngine backend
71   // K: MindSpore operator input index, V: GE operator input index and type info
72   mindspore::HashMap<int, Ms2GeParamInfo> input_idx_ms2ge;
73   mindspore::HashMap<int, Ms2GeParamInfo> output_idx_ms2ge;
74   std::unordered_map<size_t, size_t> ref_map_;
75   // fields for recording the mapping flags of input/output
76   unsigned int input_mapping_flags = 0;
77   unsigned int output_mapping_flags = 0;
78   // map input/output indices of operator from GraphEngine backend to MindSpore frontend
79   // K: GE operator input index, V: MindSpore operator input index
80   mindspore::HashMap<size_t, int> input_idx_ge2ms;
81   mindspore::HashMap<size_t, int> output_idx_ge2ms;
82 
83   // DataType
84   mindspore::HashMap<int, std::vector<enum ::ge::DataType>> input_supported_dtypes;
85   mindspore::HashMap<int, std::vector<enum ::ge::DataType>> output_supported_dtypes;
86 };
87 
88 class GeAdapterInfo {
89  public:
GeAdapterInfo(OpAdapterPtr adpt)90   explicit GeAdapterInfo(OpAdapterPtr adpt) : adapter_(std::move(adpt)) {}
91   ~GeAdapterInfo() = default;
92 
93   void InitInfo();
94 
op_type()95   const std::string &op_type() const { return info_.op_type; }
attr_map()96   const mindspore::HashMap<std::string, std::string> &attr_map() const { return info_.attr_map; }
input_attr_map()97   const std::map<uint32_t, std::string> &input_attr_map() const { return info_.input_attr_map; }
attr_input_map()98   const mindspore::HashMap<size_t, std::string> &attr_input_map() const { return info_.attr_input_map; }
99 
100   // Get number of inputs in mindspore operator prototype, not the real number of inputs
GetNumInputsOfMsOpProto()101   size_t GetNumInputsOfMsOpProto() const {
102     // Note: number of ms operator inputs(not real inputs) is equal to size of info_.input_idx_ms2ge
103     return info_.input_idx_ms2ge.size();
104   }
105 
GetMs2GeInputMap()106   const mindspore::HashMap<int, Ms2GeParamInfo> &GetMs2GeInputMap() const { return info_.input_idx_ms2ge; }
107 
GetGeInputByMsInputIndex(size_t ms_input_idx)108   const Ms2GeParamInfo &GetGeInputByMsInputIndex(size_t ms_input_idx) const {
109     auto iter = info_.input_idx_ms2ge.find(ms_input_idx);
110     if (iter == info_.input_idx_ms2ge.end()) {
111       MS_LOG(EXCEPTION) << "Find input info from GE operator " << info_.op_type << " for mindspore input index "
112                         << ms_input_idx << " fail.";
113     }
114     return iter->second;
115   }
116 
GetOptGeInputByMsInputIndex(size_t ms_input_idx)117   const std::optional<Ms2GeParamInfo> GetOptGeInputByMsInputIndex(size_t ms_input_idx) const {
118     auto iter = info_.input_idx_ms2ge.find(ms_input_idx);
119     if (iter != info_.input_idx_ms2ge.end()) {
120       return iter->second;
121     }
122     return std::nullopt;
123   }
124 
125   // Get number of outputs in mindspore operator prototype, not the real number of outputs
GetNumOutputsOfMsOpProto()126   size_t GetNumOutputsOfMsOpProto() const {
127     // Note: number of ms operator outputs(not real outputs) is equal to size of info_.output_idx_ms2ge
128     return info_.output_idx_ms2ge.size();
129   }
130 
GetNumStaticOutputsOfMsOpProto()131   size_t GetNumStaticOutputsOfMsOpProto() const {
132     // Note: number of ms operator static outputs(not real outputs)
133     return adapter_->getOutputMap().size();
134   }
135 
GetMaxMsProtoIndexOfInputMap()136   int GetMaxMsProtoIndexOfInputMap() { return max_input_ms_proto_idx_; }
137 
GetGeOutputByMsOutputIndex(size_t ms_output_idx)138   const Ms2GeParamInfo GetGeOutputByMsOutputIndex(size_t ms_output_idx) const {
139     auto iter = info_.output_idx_ms2ge.find(ms_output_idx);
140     if (iter == info_.output_idx_ms2ge.end()) {
141       MS_LOG(EXCEPTION) << "Find output info from GE operator " << info_.op_type << " for mindspore output index "
142                         << ms_output_idx << " fail.";
143     }
144     return iter->second;
145   }
146 
GetOptGeOutputByMsOutputIndex(size_t ms_output_idx)147   const std::optional<Ms2GeParamInfo> GetOptGeOutputByMsOutputIndex(size_t ms_output_idx) const {
148     auto iter = info_.output_idx_ms2ge.find(ms_output_idx);
149     if (iter != info_.output_idx_ms2ge.end()) {
150       return iter->second;
151     }
152     return std::nullopt;
153   }
154 
GetInputMappingFlags()155   unsigned int GetInputMappingFlags() const { return info_.input_mapping_flags; }
156 
GetOutputMappingFlags()157   unsigned int GetOutputMappingFlags() const { return info_.output_mapping_flags; }
158 
GetRefMappingInfo()159   const std::unordered_map<size_t, size_t> &GetRefMappingInfo() const { return info_.ref_map_; }
160 
input_supported_dtypes()161   mindspore::HashMap<int, std::vector<enum ::ge::DataType>> input_supported_dtypes() const {
162     return info_.input_supported_dtypes;
163   }
output_supported_dtypes()164   mindspore::HashMap<int, std::vector<enum ::ge::DataType>> output_supported_dtypes() const {
165     return info_.output_supported_dtypes;
166   }
167   void GetGeAttrValueByMsAttrValue(const std::string &attr_name, ValuePtr *ms_value);
168   void GetGeAttrValueByMsInputValue(const uint32_t &input_idx, ValuePtr *ms_value);
169 
170  private:
171   void InitOpType();
172 
173   void InitAclInputsAndOutputs();
174   void InitRefMap();
175   template <typename ParamMap, typename DynParamMap>
176   void InitParametersMap(const ParamMap &params, const DynParamMap &dyn_params, bool is_input);
177 
178   // attr
179   void InitAttrMap();
180   void InitInputToAttrMap();
181   void InitAttrToInputMap();
182 
183   void InitInputSupportedDataType();
184   void InitOutputSupportedDataType();
185 
186   OpAdapterPtr adapter_{nullptr};
187   GeTensorInfo info_;
188   // max MindSpore input index in op prototype of INPUT_MAP and DYN_INPUT_MAP
189   int max_input_ms_proto_idx_ = INT_MIN;
190   std::unordered_map<std::pair<std::string, ValuePtr>, ValuePtr, ValuePairHasher> get_attr_cache_;
191   std::unordered_map<std::pair<uint32_t, ValuePtr>, ValuePtr, ValuePairHasher> get_input_attr_cache_;
192 };
193 
194 using GeAdapterInfoPtr = std::shared_ptr<GeAdapterInfo>;
195 
196 class GeAdapterManager {
197  public:
198   static GeAdapterManager &GetInstance();
199   GeAdapterInfoPtr GetInfo(const std::string &prim_name, bool is_training);
200 
201  private:
202   GeAdapterManager() = default;
203   ~GeAdapterManager() = default;
204   mindspore::HashMap<std::string, GeAdapterInfoPtr> op_cache_;
205   std::mutex lock_;
206 };
207 }  // namespace mindspore::transform
208 
209 #endif  // MINDSPORE_CCSRC_TRANSFORM_ACL_IR_GE_ADAPTER_INFO_H_
210