• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
19 #include <iostream>
20 #include <vector>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include "kernel/oplib/op_info_keys.h"
25 #include "ir/dtype.h"
26 #include "ir/kernel_info_dev.h"
27 #include "kernel/kernel.h"
28 
29 #ifdef OPAQUE
30 #undef OPAQUE
31 #endif
32 
33 namespace mindspore {
34 namespace kernel {
// Default fusion-type string; KernelBuildInfo initializes its fusion_type_ member with it.
constexpr const char *kPatternOpaque = "Opaque";

// Object type of a kernel input or output, as reported by
// KernelBuildInfo::GetInputKernelObjectType / GetOutputKernelObjectType.
// Values are spelled out so that reordering the list cannot silently renumber them.
enum KernelObjectType : int {
  UNKNOWN_TYPE = 0,
  TENSOR = 1,
  SCALAR = 2,
  TUPLE = 3,
  // For outputs of this type, KernelBuildInfo tracks the element object types separately.
  TUPLE_UNFOLD = 4,
};

// Op category stored in KernelBuildInfo::op_type().
// NOTE(review): the exact meaning of DYNAMIC / SKIP is defined by consumers not visible here.
enum OpType : int {
  UNKNOWN_OP_TYPE = 0,
  DYNAMIC = 1,
  SKIP = 2,
};
43 
44 enum OpType : int {
45   UNKNOWN_OP_TYPE = 0,
46   DYNAMIC,
47   SKIP,
48 };
49 
50 std::string KernelObjectTypeLabel(const KernelObjectType &obj_type);
51 BACKEND_EXPORT std::string KernelTypeLabel(const KernelType &kernel_type);
52 std::string OpTypeLabel(const OpType &op_type);
53 
54 class BACKEND_EXPORT KernelBuildInfo {
55  public:
56   class KernelBuildInfoBuilder;
57 
58   KernelBuildInfo() = default;
59 
60   ~KernelBuildInfo() = default;
61 
kernel_type()62   KernelType kernel_type() const { return kernel_type_; }
63 
op_type()64   OpType op_type() const { return op_type_; }
65 
set_kernel_type(KernelType kernel_type)66   void set_kernel_type(KernelType kernel_type) { kernel_type_ = kernel_type; }
67 
set_op_type(OpType op_type)68   void set_op_type(OpType op_type) { op_type_ = op_type; }
69 
70   std::string GetInputFormat(size_t input_index) const;
71 
72   std::string GetOutputFormat(size_t output_index) const;
73 
74   TypeId GetInputDeviceType(size_t input_index) const;
75 
76   TypeId GetOutputDeviceType(size_t output_index) const;
77 
78   KernelObjectType GetInputKernelObjectType(size_t input_index) const;
79 
80   KernelObjectType GetOutputKernelObjectType(size_t output_index) const;
81 
82   std::string GetInputReshapeType(size_t input_index) const;
83 
84   bool IsInputDefaultPadding() const;
85 
86   bool IsOutputDefaultPadding() const;
87 
88   std::string GetOutputReshapeType(size_t output_index) const;
89 
90   const std::string &GetOriginDataFormat() const;
91 
92   const std::vector<std::string> &GetAllInputFormats() const;
93 
94   const std::vector<std::string> &GetAllOutputFormats() const;
95 
96   const std::vector<TypeId> &GetAllInputDeviceTypes() const;
97 
98   const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
99 
100   const std::vector<KernelObjectType> &GetAllInputKernelObjectTypes() const;
101 
102   const std::vector<KernelObjectType> &GetAllOutputKernelObjectTypes() const;
103 
104   const std::vector<KernelObjectType> &GetAllOutputElementsKernelObjectTypes() const;
105 
106   void SetOpType(const OpType &op_type);
107 
108   void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type);
109 
110   void SetInputsKernelObjectType(const std::vector<KernelObjectType> &inputs_kernel_object_type);
111 
112   void SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_elements_kernel_object_type);
113 
114   const std::vector<std::string> &GetAllOutputReshapeType() const;
115 
116   const std::vector<std::string> &GetAllInputReshapeType() const;
117 
core_type()118   std::string core_type() const { return core_type_; }
119 
120   void SetOutputFormat(const std::string &format, size_t index);
121 
122   void SetInputFormat(const std::string &format, size_t index);
123 
124   void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
125 
126   void SetInputsFormat(const std::vector<std::string> &inputs_format);
127 
128   void SetOutputsFormat(const std::vector<std::string> &outputs_format);
129 
130   void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type);
131 
132   void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
133 
134   void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
135 
op_pattern()136   OpPattern op_pattern() const { return op_pattern_; }
137 
output_data_desc()138   std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
139 
fusion_type()140   std::string fusion_type() const { return fusion_type_; }
141 
valid()142   bool valid() const { return valid_; }
set_valid(bool valid)143   void set_valid(bool valid) { valid_ = valid; }
144 
processor()145   Processor processor() const { return processor_; }
set_processor(Processor processor)146   void set_processor(Processor processor) { processor_ = processor; }
147 
148   size_t GetInputNum() const;
149 
150   size_t GetOutputNum() const;
151 
152   size_t GetOutputNumWithoutMonad() const;
153 
154   std::string ToString() const;
155 
156   bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;
157 
158   bool operator==(const KernelBuildInfo &other) const;
159 
160   bool operator!=(const KernelBuildInfo &other) const;
161 
162   static auto constexpr kInvalidFormat = "InvalidFormat";
163 
164  private:
165   KernelType kernel_type_{UNKNOWN_KERNEL_TYPE};
166   OpType op_type_{UNKNOWN_OP_TYPE};
167   std::string origin_data_format_{kOpFormat_DEFAULT};
168   std::string core_type_;
169   std::vector<std::string> inputs_format_;
170   OpPattern op_pattern_{kCommonPattern};
171   std::vector<std::string> outputs_format_;
172   std::vector<std::string> input_reshape_type_;
173   std::vector<std::string> output_reshape_type_;
174   std::vector<TypeId> inputs_device_type_;
175   std::vector<TypeId> outputs_device_type_;
176   std::vector<KernelObjectType> inputs_kernel_object_type_;
177   std::vector<KernelObjectType> outputs_kernel_object_type_;
178   // Indicates kernel object types of elements in TupleUnfold.
179   // Only valid when output kernel object type is TupleUnfold.
180   std::vector<KernelObjectType> output_elements_kernel_object_type_;
181   std::vector<nlohmann::json> output_data_desc_;
182   std::string fusion_type_{kPatternOpaque};
183   Processor processor_{UNKNOWN};
184   // Indicates whether buildinfo is valid, the invalid buildinfo needs to select kernel again.
185   bool valid_{true};
186 };
187 using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
188 
189 class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
190  public:
KernelBuildInfoBuilder()191   KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
192 
KernelBuildInfoBuilder(const KernelBuildInfoPtr & kernel_build_info)193   explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info)
194       : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
195     SetKernelType(kernel_build_info->kernel_type());
196     SetOpType(kernel_build_info->op_type());
197     SetFusionType(kernel_build_info->fusion_type());
198     SetProcessor(kernel_build_info->processor());
199     SetOpPattern(kernel_build_info->op_pattern());
200     SetCoreType(kernel_build_info->core_type());
201     SetOutputDataDesc(kernel_build_info->output_data_desc());
202     for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
203       (void)kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
204       (void)kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
205       (void)kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
206     }
207     kernel_build_info_->inputs_kernel_object_type_ = kernel_build_info->GetAllInputKernelObjectTypes();
208 
209     for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
210       (void)kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
211       (void)kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
212       (void)kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
213     }
214     kernel_build_info_->outputs_kernel_object_type_ = kernel_build_info->GetAllOutputKernelObjectTypes();
215     SetValid(kernel_build_info->valid());
216   }
217 
218   ~KernelBuildInfoBuilder() = default;
219 
220   void SetKernelType(const KernelType &kernel_type);
221 
222   void SetOpType(const OpType &op_type);
223 
224   void SetOriginDataFormat(const std::string &origin_data_format);
225 
226   void SetInputsFormat(const std::vector<std::string> &inputs_format);
227 
228   void SetOutputsFormat(const std::vector<std::string> &outputs_format);
229 
230   void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type);
231 
232   void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
233 
234   void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
235 
236   void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
237 
238   void SetInputsKernelObjectType(const std::vector<KernelObjectType> &input_kernel_object_type);
239 
240   void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &output_kernel_object_type);
241 
242   void SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_elements_kernel_object_type);
243 
244   void SetCoreType(const std::string &core_type);
245 
246   void SetFusionType(const std::string &fusion_type);
247   // save prebuild result
248   void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
249 
250   void SetProcessor(Processor processor);
251 
252   void SetOpPattern(OpPattern pattern);
253 
254   void SetInputFormat(const std::string &format, size_t index);
255 
256   void SetOutputFormat(const std::string &format, size_t index);
257 
258   void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
259 
260   void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
261 
262   void SetInputDeviceType(const TypeId &input_device_type, size_t index);
263 
264   void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
265 
266   void SetValid(bool valid);
267 
268   std::shared_ptr<KernelBuildInfo> Build();
269 
270  private:
271   std::shared_ptr<KernelBuildInfo> kernel_build_info_;
272 };
273 }  // namespace kernel
274 }  // namespace mindspore
275 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
276