• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #include "backend/kernel_compiler/kernel_build_info.h"
18 #include "utils/log_adapter.h"
19 #include "debug/anf_ir_dump.h"
20 namespace mindspore {
21 namespace kernel {
GetInputFormat(size_t input_index) const22 std::string KernelBuildInfo::GetInputFormat(size_t input_index) const {
23   if (input_index >= inputs_format_.size()) {
24     MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input node";
25     return kInvalidFormat;
26   }
27   return inputs_format_[input_index];
28 }
29 
GetOutputFormat(size_t output_index) const30 std::string KernelBuildInfo::GetOutputFormat(size_t output_index) const {
31   if (output_index >= outputs_format_.size()) {
32     MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
33     return kInvalidFormat;
34   }
35   return outputs_format_[output_index];
36 }
37 
GetInputDeviceType(size_t input_index) const38 TypeId KernelBuildInfo::GetInputDeviceType(size_t input_index) const {
39   if (input_index >= inputs_device_type_.size()) {
40     MS_LOG(ERROR) << "The index [" << input_index << "] is exceed the number of input";
41     return TypeId::kNumberTypeEnd;
42   }
43   return inputs_device_type_[input_index];
44 }
45 
GetOutputDeviceType(size_t output_index) const46 TypeId KernelBuildInfo::GetOutputDeviceType(size_t output_index) const {
47   if (output_index >= outputs_device_type_.size()) {
48     MS_LOG(ERROR) << "The index [" << output_index << "] is exceed the number of output";
49     return TypeId::kNumberTypeEnd;
50   }
51   return outputs_device_type_[output_index];
52 }
53 
GetOriginDataFormat() const54 const std::string &KernelBuildInfo::GetOriginDataFormat() const { return origin_data_format_; }
55 
GetAllInputFormats() const56 const std::vector<std::string> &KernelBuildInfo::GetAllInputFormats() const { return inputs_format_; }
57 
GetAllOutputFormats() const58 const std::vector<std::string> &KernelBuildInfo::GetAllOutputFormats() const { return outputs_format_; }
59 
GetAllInputDeviceTypes() const60 const std::vector<TypeId> &KernelBuildInfo::GetAllInputDeviceTypes() const { return inputs_device_type_; }
61 
GetAllOutputDeviceTypes() const62 const std::vector<TypeId> &KernelBuildInfo::GetAllOutputDeviceTypes() const { return outputs_device_type_; }
63 
GetInputNum() const64 size_t KernelBuildInfo::GetInputNum() const { return inputs_format_.size(); }
65 
GetOutputNum() const66 size_t KernelBuildInfo::GetOutputNum() const { return outputs_format_.size(); }
67 
GetInputReshapeType(size_t input_index) const68 std::string KernelBuildInfo::GetInputReshapeType(size_t input_index) const {
69   if (input_reshape_type_.empty()) {
70     return "";
71   }
72   if (input_index >= input_reshape_type_.size()) {
73     MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
74                       << input_reshape_type_.size();
75   }
76   return input_reshape_type_[input_index];
77 }
78 
GetInputValueDepend(size_t input_index) const79 std::string KernelBuildInfo::GetInputValueDepend(size_t input_index) const {
80   if (input_value_depend_.empty()) {
81     return "";
82   }
83   if (input_index >= input_value_depend_.size()) {
84     MS_LOG(EXCEPTION) << "The index [" << input_index << "] is exceed the number of input node size "
85                       << input_value_depend_.size();
86   }
87   return input_value_depend_[input_index];
88 }
89 
GetOutputReshapeType(size_t output_index) const90 std::string KernelBuildInfo::GetOutputReshapeType(size_t output_index) const {
91   if (output_reshape_type_.empty()) {
92     return "";
93   }
94   if (output_index >= output_reshape_type_.size()) {
95     MS_LOG(EXCEPTION) << "The index [" << output_index << "] is exceed the number of output node size "
96                       << output_reshape_type_.size();
97   }
98   return output_reshape_type_[output_index];
99 }
100 
ToString() const101 std::string KernelBuildInfo::ToString() const {
102   std::ostringstream output_buffer;
103   output_buffer << "(";
104   for (size_t index = 0; index < GetInputNum(); ++index) {
105     if (index != 0) {
106       output_buffer << ", ";
107     }
108     output_buffer << "<" << ToShortString(GetInputDeviceType(index)) << "x" << GetInputFormat(index) << ">";
109   }
110   output_buffer << ") -> (";
111   for (size_t index = 0; index < GetOutputNum(); ++index) {
112     if (index != 0) {
113       output_buffer << ", ";
114     }
115     output_buffer << "<" << ToShortString(GetOutputDeviceType(index)) << "x" << GetOutputFormat(index) << ">";
116   }
117   output_buffer << ")";
118   return output_buffer.str();
119 }
120 
IsSimilarityKernelBuildInfo(const KernelBuildInfo & other) const121 bool KernelBuildInfo::IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const {
122   if (inputs_format_ != other.inputs_format_ || outputs_format_ != other.outputs_format_) {
123     if (op_pattern_ != kFormatAgnosticPattern) {
124       return false;
125     } else {
126       MS_LOG(INFO) << "This kernel build info:" << this->ToString()
127                    << ", other kernel build info: " << other.ToString();
128     }
129   }
130   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
131 }
132 
operator ==(const KernelBuildInfo & other) const133 bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
134   if (kernel_type_ != other.kernel_type_ || fusion_type_ != other.fusion_type_ || processor_ != other.processor_) {
135     return false;
136   }
137   return IsSimilarityKernelBuildInfo(other);
138 }
139 
IsInputDefaultPadding() const140 bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
141 
IsOutputDefaultPadding() const142 bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
143 
operator !=(const KernelBuildInfo & other) const144 bool KernelBuildInfo::operator!=(const KernelBuildInfo &other) const { return !((*this) == other); }
145 
SetKernelType(const KernelType & kernel_type)146 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
147   MS_EXCEPTION_IF_NULL(kernel_build_info_);
148   kernel_build_info_->kernel_type_ = kernel_type;
149 }
150 
SetOriginDataFormat(const std::string & origin_data_format)151 void KernelBuildInfo::KernelBuildInfoBuilder::SetOriginDataFormat(const std::string &origin_data_format) {
152   MS_EXCEPTION_IF_NULL(kernel_build_info_);
153   kernel_build_info_->origin_data_format_ = origin_data_format;
154 }
155 
SetInputsFormat(const std::vector<std::string> & inputs_format)156 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsFormat(const std::vector<std::string> &inputs_format) {
157   MS_EXCEPTION_IF_NULL(kernel_build_info_);
158   kernel_build_info_->inputs_format_ = inputs_format;
159 }
160 
SetOutputsFormat(const std::vector<std::string> & outputs_format)161 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsFormat(const std::vector<std::string> &outputs_format) {
162   MS_EXCEPTION_IF_NULL(kernel_build_info_);
163   kernel_build_info_->outputs_format_ = outputs_format;
164 }
165 
SetInputsDeviceType(const std::vector<TypeId> & inputs_device_type)166 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type) {
167   MS_EXCEPTION_IF_NULL(kernel_build_info_);
168   kernel_build_info_->inputs_device_type_ = inputs_device_type;
169 }
170 
SetOutputsDeviceType(const std::vector<TypeId> & outputs_device_type)171 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type) {
172   MS_EXCEPTION_IF_NULL(kernel_build_info_);
173   kernel_build_info_->outputs_device_type_ = outputs_device_type;
174 }
175 
SetFusionType(FusionType fusion_type)176 void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
177   MS_EXCEPTION_IF_NULL(kernel_build_info_);
178   kernel_build_info_->fusion_type_ = fusion_type;
179 }
180 
SetOutputDataDesc(const std::vector<nlohmann::json> & data_desc)181 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc) {
182   MS_EXCEPTION_IF_NULL(kernel_build_info_);
183   kernel_build_info_->output_data_desc_ = data_desc;
184 }
185 
SetProcessor(Processor processor)186 void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor) {
187   MS_EXCEPTION_IF_NULL(kernel_build_info_);
188   kernel_build_info_->processor_ = processor;
189 }
190 
Build()191 std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
192 
SetInputsReshapeType(const std::vector<std::string> & input_reshape_type)193 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(const std::vector<std::string> &input_reshape_type) {
194   MS_EXCEPTION_IF_NULL(kernel_build_info_);
195   kernel_build_info_->input_reshape_type_ = input_reshape_type;
196 }
197 
SetInputsValueDepend(const std::vector<std::string> & input_value_depend)198 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsValueDepend(const std::vector<std::string> &input_value_depend) {
199   MS_EXCEPTION_IF_NULL(kernel_build_info_);
200   kernel_build_info_->input_value_depend_ = input_value_depend;
201 }
202 
SetOutputsReshapeType(const std::vector<std::string> & output_reshape_type)203 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
204   const std::vector<std::string> &output_reshape_type) {
205   MS_EXCEPTION_IF_NULL(kernel_build_info_);
206   kernel_build_info_->output_reshape_type_ = output_reshape_type;
207 }
208 
SetOpPattern(OpPattern pattern)209 void KernelBuildInfo::KernelBuildInfoBuilder::SetOpPattern(OpPattern pattern) {
210   MS_EXCEPTION_IF_NULL(kernel_build_info_);
211   kernel_build_info_->op_pattern_ = pattern;
212 }
SetInputFormat(const std::string & format,size_t index)213 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputFormat(const std::string &format, size_t index) {
214   MS_EXCEPTION_IF_NULL(kernel_build_info_);
215   auto index_limit = kernel_build_info_->inputs_format_.size();
216   if (index >= index_limit) {
217     MS_LOG(EXCEPTION) << "Index of input format out of range! The value should be less than: " << index_limit
218                       << ", but got: " << index;
219   }
220   kernel_build_info_->inputs_format_[index] = format;
221 }
222 
SetOutputFormat(const std::string & format,size_t index)223 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string &format, size_t index) {
224   MS_EXCEPTION_IF_NULL(kernel_build_info_);
225   auto index_limit = kernel_build_info_->outputs_format_.size();
226   if (index >= index_limit) {
227     MS_LOG(EXCEPTION) << "Index of output format out of range! The value should be less than: " << index_limit
228                       << ", but got: " << index;
229   }
230   kernel_build_info_->outputs_format_[index] = format;
231 }
SetInputReshapeType(const std::string & input_reshape_type,size_t index)232 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::string &input_reshape_type, size_t index) {
233   MS_EXCEPTION_IF_NULL(kernel_build_info_);
234   auto index_limit = kernel_build_info_->input_reshape_type_.size();
235   if (index >= index_limit) {
236     MS_LOG(EXCEPTION) << "Index of input_reshape_type out of range! The value should be less than: " << index_limit
237                       << ", but got: " << index;
238   }
239   (void)std::copy(input_reshape_type.begin(), input_reshape_type.end(),
240                   std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
241 }
242 
SetOutputReshapeType(const std::string & output_reshape_type,size_t index)243 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::string &output_reshape_type,
244                                                                    size_t index) {
245   MS_EXCEPTION_IF_NULL(kernel_build_info_);
246   auto index_limit = kernel_build_info_->output_reshape_type_.size();
247   if (index >= index_limit) {
248     MS_LOG(EXCEPTION) << "Index of output_reshape_type out of range! The value should be less than: " << index_limit
249                       << ", but got: " << index;
250   }
251   (void)std::copy(output_reshape_type.begin(), output_reshape_type.end(),
252                   std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
253 }
254 
SetOutputDeviceType(const TypeId & output_device_type,size_t index)255 void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
256   MS_EXCEPTION_IF_NULL(kernel_build_info_);
257   auto index_limit = kernel_build_info_->outputs_device_type_.size();
258   if (index >= index_limit) {
259     MS_LOG(EXCEPTION) << "Index of output_device_type out of range! The value should be less than: " << index_limit
260                       << ", but got: " << index;
261   }
262   kernel_build_info_->outputs_device_type_[index] = output_device_type;
263 }
264 
SetInputDeviceType(const TypeId & input_device_type,size_t index)265 void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
266   MS_EXCEPTION_IF_NULL(kernel_build_info_);
267   auto index_limit = kernel_build_info_->inputs_device_type_.size();
268   if (index >= index_limit) {
269     MS_LOG(EXCEPTION) << "Index of input_device_type out of range! The value should be less than: " << index_limit
270                       << ", but got: " << index;
271   }
272   kernel_build_info_->inputs_device_type_[index] = input_device_type;
273 }
274 }  // namespace kernel
275 }  // namespace mindspore
276