1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "backend/kernel_compiler/kernel_build_info.h"
18 #include "utils/log_adapter.h"
19 #include "debug/anf_ir_dump.h"
20 namespace mindspore {
21 namespace kernel {
GetInputFormat(size_t input_index) const22 std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
23 if (input_index >= inputs_format_.size()) {
24 MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
25 return kInvalidFormat;
26 }
27 return inputs_format_[input_index];
28 }
29
GetOutputFormat(size_t output_index) const30 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
31 if (output_index >= outputs_format_.size()) {
32 MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
33 return kInvalidFormat;
34 }
35 return outputs_format_[output_index];
36 }
37
GetInputDeviceType(size_t input_index) const38 TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
39 if (input_index >= inputs_device_type_.size()) {
40 MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
41 return TypeId::kNumberTypeEnd;
42 }
43 return inputs_device_type_[input_index];
44 }
45
GetOutputDeviceType(size_t output_index) const46 TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
47 if (output_index >= outputs_device_type_.size()) {
48 MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
49 return TypeId::kNumberTypeEnd;
50 }
51 return outputs_device_type_[output_index];
52 }
53
GetOriginDataFormat() const54 const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }
55
GetAllInputFormats() const56 const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
57
GetAllOutputFormats() const58 const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
59
GetAllInputDeviceTypes() const60 const std::vector<TypeId> &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
61
GetAllOutputDeviceTypes() const62 const std::vector<TypeId> &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
63
GetInputNum() const64 size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
65
GetOutputNum() const66 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
67
GetInputReshapeType(size_t input_index) const68 std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
69 if (input_reshape_type_.empty()) {
70 return "";
71 }
72 if (input_index >= input_reshape_type_.size()) {
73 MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
74 << input_reshape_type_.size();
75 }
76 return input_reshape_type_[input_index];
77 }
78
GetInputValueDepend(size_t input_index) const79 std::string KernelBuildInfo::GetInputValueDepend(size_t input_index) const {
80 if (input_value_depend_.empty()) {
81 return "";
82 }
83 if (input_index >= input_value_depend_.size()) {
84 MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
85 << input_value_depend_.size();
86 }
87 return input_value_depend_[input_index];
88 }
89
GetOutputReshapeType(size_t output_index) const90 std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
91 if (output_reshape_type_.empty()) {
92 return "";
93 }
94 if (output_index >= output_reshape_type_.size()) {
95 MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
96 << output_reshape_type_.size();
97 }
98 return output_reshape_type_[output_index];
99 }
100
ToString() const101 std::string KernelBuildInfo::ToString() const {
102 std::ostringstream output_buffer;
103 output_buffer << "(";
104 for (size_t index = 0; index < GetInputNum(); ++index) {
105 if (index != 0) {
106 output_buffer << ", ";
107 }
108 output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
109 }
110 output_buffer << ") -> (";
111 for (size_t index = 0; index < GetOutputNum(); ++index) {
112 if (index != 0) {
113 output_buffer << ", ";
114 }
115 output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
116 }
117 output_buffer << ")";
118 return output_buffer.str();
119 }
120
IsSimilarityKernelBuildInfo(const KernelBuildInfo & other) const121 bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
122 if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
123 if (op_pattern_ != kFormatAgnosticPattern) {
124 return false;
125 } else {
126 MS_LOG(INFO) << "This kernel build info:" << this->ToString()
127 << ", other kernel build info: " << other.ToString();
128 }
129 }
130 return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
131 }
132
operator ==(const KernelBuildInfo & other) const133 bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
134 if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
135 return false;
136 }
137 return IsSimilarityKernelBuildInfo(other);
138 }
139
IsInputDefaultPadding() const140 bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
141
IsOutputDefaultPadding() const142 bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
143
operator !=(const KernelBuildInfo & other) const144 bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
145
SetKernelType(const KernelType & kernel_type)146 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
147 MS_EXCEPTION_IF_NULL(kernel_build_info_);
148 kernel_build_info_->kernel_type_ = kernel_type;
149 }
150
SetOriginDataFormat(const std::string & origin_data_format)151 void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
152 MS_EXCEPTION_IF_NULL(kernel_build_info_);
153 kernel_build_info_->origin_data_format_ = origin_data_format;
154 }
155
SetInputsFormat(const std::vector<std::string> & inputs_format)156 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
157 MS_EXCEPTION_IF_NULL(kernel_build_info_);
158 kernel_build_info_->inputs_format_ = inputs_format;
159 }
160
SetOutputsFormat(const std::vector<std::string> & outputs_format)161 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
162 MS_EXCEPTION_IF_NULL(kernel_build_info_);
163 kernel_build_info_->outputs_format_ = outputs_format;
164 }
165
SetInputsDeviceType(const std::vector<TypeId> & inputs_device_type)166 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
167 MS_EXCEPTION_IF_NULL(kernel_build_info_);
168 kernel_build_info_->inputs_device_type_ = inputs_device_type;
169 }
170
SetOutputsDeviceType(const std::vector<TypeId> & outputs_device_type)171 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
172 MS_EXCEPTION_IF_NULL(kernel_build_info_);
173 kernel_build_info_->outputs_device_type_ = outputs_device_type;
174 }
175
SetFusionType(FusionType fusion_type)176 void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
177 MS_EXCEPTION_IF_NULL(kernel_build_info_);
178 kernel_build_info_->fusion_type_ = fusion_type;
179 }
180
SetOutputDataDesc(const std::vector<nlohmann::json> & data_desc)181 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
182 MS_EXCEPTION_IF_NULL(kernel_build_info_);
183 kernel_build_info_->output_data_desc_ = data_desc;
184 }
185
SetProcessor(Processor processor)186 void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) {
187 MS_EXCEPTION_IF_NULL(kernel_build_info_);
188 kernel_build_info_->processor_ = processor;
189 }
190
Build()191 std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
192
SetInputsReshapeType(const std::vector<std::string> & input_reshape_type)193 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
194 MS_EXCEPTION_IF_NULL(kernel_build_info_);
195 kernel_build_info_->input_reshape_type_ = input_reshape_type;
196 }
197
SetInputsValueDepend(const std::vector<std::string> & input_value_depend)198 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsValueDepend(const std::vector<std::string> &input_value_depend) {
199 MS_EXCEPTION_IF_NULL(kernel_build_info_);
200 kernel_build_info_->input_value_depend_ = input_value_depend;
201 }
202
SetOutputsReshapeType(const std::vector<std::string> & output_reshape_type)203 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
204 const std::vector<std::string> &output_reshape_type) {
205 MS_EXCEPTION_IF_NULL(kernel_build_info_);
206 kernel_build_info_->output_reshape_type_ = output_reshape_type;
207 }
208
SetOpPattern(OpPattern pattern)209 void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
210 MS_EXCEPTION_IF_NULL(kernel_build_info_);
211 kernel_build_info_->op_pattern_ = pattern;
212 }
SetInputFormat(const std::string & format,size_t index)213 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
214 MS_EXCEPTION_IF_NULL(kernel_build_info_);
215 auto index_limit = kernel_build_info_->inputs_format_.size();
216 if (index >= index_limit) {
217 MS_LOG(EXCEPTION) << "Index of input format out of range! The value should be less than: " << index_limit
218 << ", but got: " << index;
219 }
220 kernel_build_info_->inputs_format_[index] = format;
221 }
222
SetOutputFormat(const std::string & format,size_t index)223 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
224 MS_EXCEPTION_IF_NULL(kernel_build_info_);
225 auto index_limit = kernel_build_info_->outputs_format_.size();
226 if (index >= index_limit) {
227 MS_LOG(EXCEPTION) << "Index of output format out of range! The value should be less than: " << index_limit
228 << ", but got: " << index;
229 }
230 kernel_build_info_->outputs_format_[index] = format;
231 }
SetInputReshapeType(const std::string & input_reshape_type,size_t index)232 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
233 MS_EXCEPTION_IF_NULL(kernel_build_info_);
234 auto index_limit = kernel_build_info_->input_reshape_type_.size();
235 if (index >= index_limit) {
236 MS_LOG(EXCEPTION) << "Index of input_reshape_type out of range! The value should be less than: " << index_limit
237 << ", but got: " << index;
238 }
239 (void)std::copy(input_reshape_type.begin(), input_reshape_type.end(),
240 std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
241 }
242
SetOutputReshapeType(const std::string & output_reshape_type,size_t index)243 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
244 size_t index) {
245 MS_EXCEPTION_IF_NULL(kernel_build_info_);
246 auto index_limit = kernel_build_info_->output_reshape_type_.size();
247 if (index >= index_limit) {
248 MS_LOG(EXCEPTION) << "Index of output_reshape_type out of range! The value should be less than: " << index_limit
249 << ", but got: " << index;
250 }
251 (void)std::copy(output_reshape_type.begin(), output_reshape_type.end(),
252 std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
253 }
254
SetOutputDeviceType(const TypeId & output_device_type,size_t index)255 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
256 MS_EXCEPTION_IF_NULL(kernel_build_info_);
257 auto index_limit = kernel_build_info_->outputs_device_type_.size();
258 if (index >= index_limit) {
259 MS_LOG(EXCEPTION) << "Index of output_device_type out of range! The value should be less than: " << index_limit
260 << ", but got: " << index;
261 }
262 kernel_build_info_->outputs_device_type_[index] = output_device_type;
263 }
264
SetInputDeviceType(const TypeId & input_device_type,size_t index)265 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
266 MS_EXCEPTION_IF_NULL(kernel_build_info_);
267 auto index_limit = kernel_build_info_->inputs_device_type_.size();
268 if (index >= index_limit) {
269 MS_LOG(EXCEPTION) << "Index of input_device_type out of range! The value should be less than: " << index_limit
270 << ", but got: " << index;
271 }
272 kernel_build_info_->inputs_device_type_[index] = input_device_type;
273 }
274 } // namespace kernel
275 } // namespace mindspore
276