• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
19 #include <iostream>
20 #include <vector>
21 #include <memory>
22 #include <string>
23 #include <utility>
24 #include "ir/dtype.h"
25 #include "ir/kernel_info_dev.h"
26 #include "backend/kernel_compiler/kernel.h"
27 
28 namespace mindspore {
29 namespace kernel {
30 class KernelBuildInfo {
31  public:
32   class KernelBuildInfoBuilder;
33 
KernelBuildInfo()34   KernelBuildInfo() {
35     kernel_type_ = TBE_KERNEL;
36     fusion_type_ = OPAQUE;
37     processor_ = AICORE;
38     op_pattern_ = kCommonPattern;
39     input_reshape_type_ = {};
40     output_reshape_type_ = {};
41     origin_data_format_ = kOpFormat_DEFAULT;
42     inputs_format_ = {};
43     outputs_format_ = {};
44     inputs_device_type_ = {};
45     outputs_device_type_ = {};
46     output_data_desc_ = {};
47   }
48 
49   ~KernelBuildInfo() = default;
50 
kernel_type()51   KernelType kernel_type() const { return kernel_type_; }
52 
53   std::string GetInputFormat(size_t input_index) const;
54 
55   std::string GetOutputFormat(size_t output_index) const;
56 
57   TypeId GetInputDeviceType(size_t input_index) const;
58 
59   TypeId GetOutputDeviceType(size_t output_index) const;
60 
61   std::string GetInputReshapeType(size_t input_index) const;
62 
63   std::string GetInputValueDepend(size_t input_index) const;
64 
65   bool IsInputDefaultPadding() const;
66 
67   bool IsOutputDefaultPadding() const;
68 
69   std::string GetOutputReshapeType(size_t input_index) const;
70 
71   const std::string &GetOriginDataFormat() const;
72 
73   const std::vector<std::string> &GetAllInputFormats() const;
74 
75   const std::vector<std::string> &GetAllOutputFormats() const;
76 
77   const std::vector<TypeId> &GetAllInputDeviceTypes() const;
78 
79   const std::vector<TypeId> &GetAllOutputDeviceTypes() const;
80 
81   std::vector<std::string> GetAllOutputReshapeType() const;
82 
83   std::vector<std::string> GetAllInputReshapeType() const;
84 
op_pattern()85   OpPattern op_pattern() const { return op_pattern_; }
86 
output_data_desc()87   std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
88 
fusion_type()89   FusionType fusion_type() const { return fusion_type_; }
90 
processor()91   Processor processor() const { return processor_; }
92 
93   size_t GetInputNum() const;
94 
95   size_t GetOutputNum() const;
96 
97   std::string ToString() const;
98 
99   bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const;
100 
101   bool operator==(const KernelBuildInfo &other) const;
102 
103   bool operator!=(const KernelBuildInfo &other) const;
104 
105   static auto constexpr kInvalidFormat = "InvalidFormat";
106 
107  private:
108   KernelType kernel_type_;
109   std::string origin_data_format_;
110   std::vector<std::string> inputs_format_;
111   OpPattern op_pattern_;
112   std::vector<std::string> outputs_format_;
113   std::vector<std::string> input_reshape_type_;
114   std::vector<std::string> output_reshape_type_;
115   std::vector<TypeId> inputs_device_type_;
116   std::vector<TypeId> outputs_device_type_;
117   std::vector<nlohmann::json> output_data_desc_;
118   std::vector<std::string> input_value_depend_;
119   FusionType fusion_type_;
120   Processor processor_;
121 };
122 using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
123 
124 class KernelBuildInfo::KernelBuildInfoBuilder {
125  public:
KernelBuildInfoBuilder()126   KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
127 
KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)128   explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
129       : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
130     SetKernelType(kernel_build_info->kernel_type());
131     SetFusionType(kernel_build_info->fusion_type());
132     SetProcessor(kernel_build_info->processor());
133     SetOpPattern(kernel_build_info->op_pattern());
134     for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
135       kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
136       kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
137       kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
138       kernel_build_info_->input_value_depend_.emplace_back(kernel_build_info->GetInputValueDepend(index));
139     }
140     for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
141       kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
142       kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
143       kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
144     }
145   }
146 
147   ~KernelBuildInfoBuilder() = default;
148 
149   void SetKernelType(const KernelType &kernel_type);
150 
151   void SetOriginDataFormat(const std::string &origin_data_format);
152 
153   void SetInputsFormat(const std::vector<std::string> &inputs_format);
154 
155   void SetOutputsFormat(const std::vector<std::string> &outputs_format);
156 
157   void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type);
158 
159   void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
160 
161   void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type);
162 
163   void SetInputsValueDepend(const std::vector<std::string> &input_value_depend);
164 
165   void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type);
166 
167   void SetFusionType(FusionType fusion_type);
168   // save prebuild result
169   void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
170 
171   void SetProcessor(Processor processor);
172 
173   void SetOpPattern(OpPattern pattern);
174 
175   void SetInputFormat(const std::string &format, size_t index);
176 
177   void SetOutputFormat(const std::string &format, size_t index);
178 
179   void SetInputReshapeType(const std::string &input_reshape_type, size_t index);
180 
181   void SetOutputReshapeType(const std::string &output_reshape_type, size_t index);
182 
183   void SetInputDeviceType(const TypeId &input_device_type, size_t index);
184 
185   void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
186 
187   std::shared_ptr<KernelBuildInfo> Build();
188 
189  private:
190   std::shared_ptr<KernelBuildInfo> kernel_build_info_;
191 };
192 }  // namespace kernel
193 }  // namespace mindspore
194 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_
195