1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "kernel/kernel_build_info.h"
18
19 #include <algorithm>
20 #include <unordered_map>
21 #include "utils/log_adapter.h"
22 #include "include/common/debug/anf_dump_utils.h"
23
24 namespace mindspore {
25 namespace kernel {
KernelObjectTypeLabel(const KernelObjectType & obj_type)26 std::string KernelObjectTypeLabel(const KernelObjectType &obj_type) {
27 std::unordered_map<KernelObjectType, std::string> trans_map{{KernelObjectType::TUPLE, "Tuple"},
28 {KernelObjectType::SCALAR, "Scalar"},
29 {KernelObjectType::TENSOR, "Tensor"},
30 {KernelObjectType::UNKNOWN_TYPE, "Unknown"},
31 {KernelObjectType::TUPLE_UNFOLD, "TupleUnfold"}};
32 if (trans_map.find(obj_type) == trans_map.end()) {
33 return "Unknown";
34 }
35 return trans_map[obj_type];
36 }
37
KernelTypeLabel(const KernelType & kernel_type)38 std::string KernelTypeLabel(const KernelType &kernel_type) {
39 std::unordered_map<KernelType, std::string> trans_map{{KernelType::UNKNOWN_KERNEL_TYPE, "UNKNOWN_KERNEL_TYPE"},
40 {KernelType::AKG_KERNEL, "AKG_KERNEL"},
41 {KernelType::AICPU_KERNEL, "AICPU_KERNEL"},
42 {KernelType::RT_KERNEL, "RT_KERNEL"},
43 {KernelType::HCCL_KERNEL, "HCCL_KERNEL"},
44 {KernelType::TBE_KERNEL, "TBE_KERNEL"},
45 {KernelType::HOST_KERNEL, "HOST_KERNEL"},
46 {KernelType::CPU_KERNEL, "CPU_KERNEL"},
47 {KernelType::GPU_KERNEL, "GPU_KERNEL"},
48 {KernelType::BISHENG_KERNEL, "BISHENG_KERNEL"},
49 {KernelType::ACL_KERNEL, "ACL_KERNEL"},
50 {KernelType::OPAPI_KERNEL, "OPAPI_KERNEL"}};
51 if (trans_map.find(kernel_type) == trans_map.end()) {
52 return "UNKNOWN_KERNEL_TYPE";
53 }
54 return trans_map[kernel_type];
55 }
56
OpTypeLabel(const OpType & op_type)57 std::string OpTypeLabel(const OpType &op_type) {
58 std::unordered_map<OpType, std::string> trans_map{
59 {OpType::UNKNOWN_OP_TYPE, "UNKNOWN_OP_TYPE"}, {OpType::DYNAMIC, "DYNAMIC"}, {OpType::SKIP, "SKIP"}};
60 if (trans_map.find(op_type) == trans_map.end()) {
61 return "UNKNOWN_OP_TYPE";
62 }
63 return trans_map[op_type];
64 }
65
GetInputFormat(size_t input_index) const66 std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
67 if (input_index >= inputs_format_.size()) {
68 MS_LOG(ERROR) << "The index [" << input_index
69 << "] is exceed the number of input node size:" << inputs_format_.size();
70 return kInvalidFormat;
71 }
72 return inputs_format_[input_index];
73 }
74
GetOutputFormat(size_t output_index) const75 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
76 if (output_index >= outputs_format_.size()) {
77 MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
78 return kInvalidFormat;
79 }
80 return outputs_format_[output_index];
81 }
82
GetInputDeviceType(size_t input_index) const83 TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
84 if (input_index >= inputs_device_type_.size()) {
85 MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
86 return TypeId::kNumberTypeEnd;
87 }
88 return inputs_device_type_[input_index];
89 }
90
GetOutputDeviceType(size_t output_index) const91 TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
92 if (output_index >= outputs_device_type_.size()) {
93 MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
94 return TypeId::kNumberTypeEnd;
95 }
96 return outputs_device_type_[output_index];
97 }
98
GetInputKernelObjectType(size_t input_index) const99 KernelObjectType KernelBuildInfo::GetInputKernelObjectType(size_t input_index) const {
100 if (inputs_kernel_object_type_.empty()) {
101 return KernelObjectType::UNKNOWN_TYPE;
102 }
103 if (input_index >= inputs_kernel_object_type_.size()) {
104 bool has_tuple_unfold =
105 std::any_of(inputs_kernel_object_type_.begin(), inputs_kernel_object_type_.end(),
106 [](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
107 // tuple unfold may correspond to many formats or dtypes
108 if (!has_tuple_unfold) {
109 MS_LOG(ERROR) << "The input index [" << input_index
110 << "] is exceed the number of input:" << inputs_kernel_object_type_.size();
111 }
112 return KernelObjectType::UNKNOWN_TYPE;
113 }
114 return inputs_kernel_object_type_[input_index];
115 }
116
GetOutputKernelObjectType(size_t output_index) const117 KernelObjectType KernelBuildInfo::GetOutputKernelObjectType(size_t output_index) const {
118 if (outputs_kernel_object_type_.empty()) {
119 return KernelObjectType::UNKNOWN_TYPE;
120 }
121
122 // tuple unfold may correspond to many formats or dtypes
123 bool has_tuple_unfold =
124 std::any_of(outputs_kernel_object_type_.begin(), outputs_kernel_object_type_.end(),
125 [](const KernelObjectType &obj_type) { return obj_type == KernelObjectType::TUPLE_UNFOLD; });
126 if (has_tuple_unfold) {
127 return KernelObjectType::UNKNOWN_TYPE;
128 }
129
130 if (output_index >= outputs_kernel_object_type_.size()) {
131 MS_LOG(ERROR) << "The output index [" << output_index
132 << "] is exceed the number of output:" << outputs_kernel_object_type_.size();
133 return KernelObjectType::UNKNOWN_TYPE;
134 }
135 return outputs_kernel_object_type_[output_index];
136 }
137
GetAllOutputElementsKernelObjectTypes() const138 const std::vector<KernelObjectType> &KernelBuildInfo::GetAllOutputElementsKernelObjectTypes() const {
139 return output_elements_kernel_object_type_;
140 }
141
GetOriginDataFormat() const142 const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }
143
GetAllInputFormats() const144 const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
145
GetAllOutputFormats() const146 const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
147
GetAllInputReshapeType() const148 const std::vector<std::string> &KernelBuildInfo::GetAllInputReshapeType() const { return input_reshape_type_; }
149
GetAllOutputReshapeType() const150 const std::vector<std::string> &KernelBuildInfo::GetAllOutputReshapeType() const { return output_reshape_type_; }
151
GetAllInputDeviceTypes() const152 const std::vector<TypeId> &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
153
GetAllOutputDeviceTypes() const154 const std::vector<TypeId> &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
155
GetAllOutputKernelObjectTypes() const156 const std::vector<KernelObjectType> &KernelBuildInfo::GetAllOutputKernelObjectTypes() const {
157 return outputs_kernel_object_type_;
158 }
159
GetAllInputKernelObjectTypes() const160 const std::vector<KernelObjectType> &KernelBuildInfo::GetAllInputKernelObjectTypes() const {
161 return inputs_kernel_object_type_;
162 }
163
SetOpType(const OpType & op_type)164 void KernelBuildInfo::SetOpType(const OpType &op_type) { op_type_ = op_type; }
165
SetOutputsKernelObjectType(const std::vector<KernelObjectType> & outputs_kernel_object_type)166 void KernelBuildInfo::SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type) {
167 outputs_kernel_object_type_ = outputs_kernel_object_type;
168 }
169
SetInputsKernelObjectType(const std::vector<KernelObjectType> & inputs_kernel_object_type)170 void KernelBuildInfo::SetInputsKernelObjectType(const std::vector<KernelObjectType> &inputs_kernel_object_type) {
171 inputs_kernel_object_type_ = inputs_kernel_object_type;
172 }
173
SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> & output_elements_kernel_object_type)174 void KernelBuildInfo::SetOutputElementsKernelObjectType(
175 const std::vector<KernelObjectType> &output_elements_kernel_object_type) {
176 output_elements_kernel_object_type_ = output_elements_kernel_object_type;
177 }
178
SetInputsFormat(const std::vector<std::string> & inputs_format)179 void KernelBuildInfo::SetInputsFormat(const std::vector<std::string> &inputs_format) { inputs_format_ = inputs_format; }
180
SetInputsReshapeType(const std::vector<std::string> & input_reshape_type)181 void KernelBuildInfo::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
182 input_reshape_type_ = input_reshape_type;
183 }
184
SetInputsDeviceType(const std::vector<TypeId> & inputs_device_type)185 void KernelBuildInfo::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
186 inputs_device_type_ = inputs_device_type;
187 }
188
SetOutputFormat(const std::string & format,size_t index)189 void KernelBuildInfo::SetOutputFormat(const std::string &format, size_t index) {
190 if (index >= outputs_format_.size()) {
191 MS_LOG(EXCEPTION) << "The index [" << index
192 << "] is exceed the length of output formats list, total size:" << outputs_format_.size();
193 }
194 outputs_format_[index] = format;
195 }
196
SetInputFormat(const std::string & format,size_t index)197 void KernelBuildInfo::SetInputFormat(const std::string &format, size_t index) {
198 if (index >= inputs_format_.size()) {
199 MS_LOG(EXCEPTION) << "The index [" << index
200 << "] is exceed the length of input formats list, total size:" << inputs_format_.size();
201 }
202 inputs_format_[index] = format;
203 }
204
SetOutputsFormat(const std::vector<std::string> & outputs_format)205 void KernelBuildInfo::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
206 outputs_format_ = outputs_format;
207 }
208
SetOutputDeviceType(const TypeId & output_device_type,size_t index)209 void KernelBuildInfo::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
210 if (index >= outputs_device_type_.size()) {
211 MS_LOG(EXCEPTION) << "The index [" << index << "] is exceed the number of output";
212 }
213 outputs_device_type_[index] = output_device_type;
214 }
215
SetOutputsDeviceType(const std::vector<TypeId> & outputs_device_type)216 void KernelBuildInfo::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
217 outputs_device_type_ = outputs_device_type;
218 }
219
GetInputNum() const220 size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
221
GetOutputNum() const222 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
223
GetOutputNumWithoutMonad() const224 size_t KernelBuildInfo::GetOutputNumWithoutMonad() const {
225 const auto count = std::count_if(outputs_device_type_.begin(), outputs_device_type_.end(),
226 [](TypeId type) { return type != TypeId::kObjectTypeUMonad; });
227 return static_cast<size_t>(count);
228 }
229
GetInputReshapeType(size_t input_index) const230 std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
231 if (input_reshape_type_.empty()) {
232 return "";
233 }
234 if (input_index >= input_reshape_type_.size()) {
235 MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
236 << input_reshape_type_.size();
237 }
238 return input_reshape_type_[input_index];
239 }
240
GetOutputReshapeType(size_t output_index) const241 std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
242 if (output_reshape_type_.empty()) {
243 return "";
244 }
245 if (output_index >= output_reshape_type_.size()) {
246 MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
247 << output_reshape_type_.size();
248 }
249 return output_reshape_type_[output_index];
250 }
251
ToString() const252 std::string KernelBuildInfo::ToString() const {
253 std::ostringstream output_buffer;
254 output_buffer << "(";
255 for (size_t index = 0; index < GetInputNum(); ++index) {
256 if (index != 0) {
257 output_buffer << ", ";
258 }
259 output_buffer << "<" << TypeIdLabel(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
260 }
261 output_buffer << ", object_type: [";
262 auto input_object_types = GetAllInputKernelObjectTypes();
263 for (size_t index = 0; index < input_object_types.size(); ++index) {
264 if (index != 0) {
265 output_buffer << ",";
266 }
267 output_buffer << KernelObjectTypeLabel(input_object_types[index]);
268 }
269
270 output_buffer << "]) -> (";
271 for (size_t index = 0; index < GetOutputNum(); ++index) {
272 if (index != 0) {
273 output_buffer << ",";
274 }
275 output_buffer << "<" << TypeIdLabel(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
276 }
277 output_buffer << ", object_type: [";
278 auto output_object_types = GetAllOutputKernelObjectTypes();
279 for (size_t index = 0; index < output_object_types.size(); ++index) {
280 if (index != 0) {
281 output_buffer << ", ";
282 }
283 output_buffer << KernelObjectTypeLabel(output_object_types[index]);
284 }
285 output_buffer << "], kernel_type: " << KernelTypeLabel(kernel_type());
286 output_buffer << ", op_type: " << OpTypeLabel(op_type());
287 output_buffer << ")";
288 return output_buffer.str();
289 }
290
IsSimilarityKernelBuildInfo(const KernelBuildInfo & other) const291 bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
292 if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
293 if (op_pattern_ != kFormatAgnosticPattern) {
294 return false;
295 } else {
296 MS_LOG(INFO) << "This kernel build info:" << this->ToString()
297 << ", other kernel build info: " << other.ToString();
298 }
299 }
300 return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
301 }
302
operator ==(const KernelBuildInfo & other) const303 bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
304 if (kernel_type_ != other.kernel_type_ || processor_ != other.processor_) {
305 return false;
306 }
307 return IsSimilarityKernelBuildInfo(other);
308 }
309
IsInputDefaultPadding() const310 bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
311
IsOutputDefaultPadding() const312 bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
313
operator !=(const KernelBuildInfo & other) const314 bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
315
SetKernelType(const KernelType & kernel_type)316 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
317 MS_EXCEPTION_IF_NULL(kernel_build_info_);
318 kernel_build_info_->kernel_type_ = kernel_type;
319 }
320
SetOpType(const OpType & op_type)321 void KernelBuildInfo::KernelBuildInfoBuilder::SetOpType(const OpType &op_type) {
322 MS_EXCEPTION_IF_NULL(kernel_build_info_);
323 kernel_build_info_->op_type_ = op_type;
324 }
325
SetOriginDataFormat(const std::string & origin_data_format)326 void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
327 MS_EXCEPTION_IF_NULL(kernel_build_info_);
328 kernel_build_info_->origin_data_format_ = origin_data_format;
329 }
330
SetInputsFormat(const std::vector<std::string> & inputs_format)331 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
332 MS_EXCEPTION_IF_NULL(kernel_build_info_);
333 kernel_build_info_->inputs_format_ = inputs_format;
334 }
335
SetOutputsFormat(const std::vector<std::string> & outputs_format)336 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
337 MS_EXCEPTION_IF_NULL(kernel_build_info_);
338 kernel_build_info_->outputs_format_ = outputs_format;
339 }
340
SetInputsDeviceType(const std::vector<TypeId> & inputs_device_type)341 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
342 MS_EXCEPTION_IF_NULL(kernel_build_info_);
343 kernel_build_info_->inputs_device_type_ = inputs_device_type;
344 }
345
SetOutputsDeviceType(const std::vector<TypeId> & outputs_device_type)346 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
347 MS_EXCEPTION_IF_NULL(kernel_build_info_);
348 kernel_build_info_->outputs_device_type_ = outputs_device_type;
349 }
350
SetFusionType(const std::string & fusion_type)351 void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(const std::string &fusion_type) {
352 MS_EXCEPTION_IF_NULL(kernel_build_info_);
353 kernel_build_info_->fusion_type_ = fusion_type;
354 }
355
SetCoreType(const std::string & core_type)356 void KernelBuildInfo::KernelBuildInfoBuilder::SetCoreType(const std::string &core_type) {
357 MS_EXCEPTION_IF_NULL(kernel_build_info_);
358 kernel_build_info_->core_type_ = core_type;
359 }
360
SetOutputDataDesc(const std::vector<nlohmann::json> & data_desc)361 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
362 MS_EXCEPTION_IF_NULL(kernel_build_info_);
363 kernel_build_info_->output_data_desc_ = data_desc;
364 }
365
SetProcessor(Processor processor)366 void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) {
367 MS_EXCEPTION_IF_NULL(kernel_build_info_);
368 kernel_build_info_->processor_ = processor;
369 }
370
SetInputsKernelObjectType(const std::vector<KernelObjectType> & inputs_kernel_object_type)371 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsKernelObjectType(
372 const std::vector<KernelObjectType> &inputs_kernel_object_type) {
373 MS_EXCEPTION_IF_NULL(kernel_build_info_);
374 kernel_build_info_->inputs_kernel_object_type_ = inputs_kernel_object_type;
375 }
376
SetOutputsKernelObjectType(const std::vector<KernelObjectType> & outputs_kernel_object_type)377 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsKernelObjectType(
378 const std::vector<KernelObjectType> &outputs_kernel_object_type) {
379 MS_EXCEPTION_IF_NULL(kernel_build_info_);
380 kernel_build_info_->outputs_kernel_object_type_ = outputs_kernel_object_type;
381 }
382
SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> & output_elements_kernel_object_type)383 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputElementsKernelObjectType(
384 const std::vector<KernelObjectType> &output_elements_kernel_object_type) {
385 MS_EXCEPTION_IF_NULL(kernel_build_info_);
386 kernel_build_info_->output_elements_kernel_object_type_ = output_elements_kernel_object_type;
387 }
388
Build()389 std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
390
SetInputsReshapeType(const std::vector<std::string> & input_reshape_type)391 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
392 MS_EXCEPTION_IF_NULL(kernel_build_info_);
393 kernel_build_info_->input_reshape_type_ = input_reshape_type;
394 }
395
SetOutputsReshapeType(const std::vector<std::string> & output_reshape_type)396 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
397 const std::vector<std::string> &output_reshape_type) {
398 MS_EXCEPTION_IF_NULL(kernel_build_info_);
399 kernel_build_info_->output_reshape_type_ = output_reshape_type;
400 }
401
SetOpPattern(OpPattern pattern)402 void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
403 MS_EXCEPTION_IF_NULL(kernel_build_info_);
404 kernel_build_info_->op_pattern_ = pattern;
405 }
SetInputFormat(const std::string & format,size_t index)406 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
407 MS_EXCEPTION_IF_NULL(kernel_build_info_);
408 auto index_limit = kernel_build_info_->inputs_format_.size();
409 if (index >= index_limit) {
410 MS_LOG(EXCEPTION) << "Index of input format out of range! The value should be less than: " << index_limit
411 << ", but got: " << index;
412 }
413 kernel_build_info_->inputs_format_[index] = format;
414 }
415
SetOutputFormat(const std::string & format,size_t index)416 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
417 MS_EXCEPTION_IF_NULL(kernel_build_info_);
418 auto index_limit = kernel_build_info_->outputs_format_.size();
419 if (index >= index_limit) {
420 MS_LOG(EXCEPTION) << "Index of output format out of range! The value should be less than: " << index_limit
421 << ", but got: " << index;
422 }
423 kernel_build_info_->outputs_format_[index] = format;
424 }
SetInputReshapeType(const std::string & input_reshape_type,size_t index)425 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
426 MS_EXCEPTION_IF_NULL(kernel_build_info_);
427 auto index_limit = kernel_build_info_->input_reshape_type_.size();
428 if (index >= index_limit) {
429 MS_LOG(EXCEPTION) << "Index of input_reshape_type out of range! The value should be less than: " << index_limit
430 << ", but got: " << index;
431 }
432 (void)std::copy(input_reshape_type.begin(), input_reshape_type.end(),
433 std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
434 }
435
SetOutputReshapeType(const std::string & output_reshape_type,size_t index)436 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
437 size_t index) {
438 MS_EXCEPTION_IF_NULL(kernel_build_info_);
439 auto index_limit = kernel_build_info_->output_reshape_type_.size();
440 if (index >= index_limit) {
441 MS_LOG(EXCEPTION) << "Index of output_reshape_type out of range! The value should be less than: " << index_limit
442 << ", but got: " << index;
443 }
444 (void)std::copy(output_reshape_type.begin(), output_reshape_type.end(),
445 std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
446 }
447
SetOutputDeviceType(const TypeId & output_device_type,size_t index)448 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
449 MS_EXCEPTION_IF_NULL(kernel_build_info_);
450 auto index_limit = kernel_build_info_->outputs_device_type_.size();
451 if (index >= index_limit) {
452 MS_LOG(EXCEPTION) << "Index of output_device_type out of range! The value should be less than: " << index_limit
453 << ", but got: " << index;
454 }
455 kernel_build_info_->outputs_device_type_[index] = output_device_type;
456 }
457
SetInputDeviceType(const TypeId & input_device_type,size_t index)458 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
459 MS_EXCEPTION_IF_NULL(kernel_build_info_);
460 auto index_limit = kernel_build_info_->inputs_device_type_.size();
461 if (index >= index_limit) {
462 MS_LOG(EXCEPTION) << "Index of input_device_type out of range! The value should be less than: " << index_limit
463 << ", but got: " << index;
464 }
465 kernel_build_info_->inputs_device_type_[index] = input_device_type;
466 }
467
SetValid(bool valid)468 void KernelBuildInfo::KernelBuildInfoBuilder::SetValid(bool valid) {
469 MS_EXCEPTION_IF_NULL(kernel_build_info_);
470 kernel_build_info_->valid_ = valid;
471 }
472 } // namespace kernel
473 } // namespace mindspore
474