1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 19 #include <vector> 20 #include <string> 21 #include <memory> 22 #include <unordered_map> 23 #include "ir/dtype.h" 24 #include "backend/kernel_compiler/kernel.h" 25 26 namespace mindspore { 27 namespace kernel { 28 enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU = 2, kCPU }; 29 enum OpIOType { kInput = 0, kOutput }; 30 constexpr auto kIgnored = "ignored"; 31 32 class OpAttr { 33 public: 34 OpAttr() = default; 35 ~OpAttr() = default; 36 name()37 std::string name() const { return name_; } param_type()38 std::string param_type() const { return param_type_; } type()39 std::string type() const { return type_; } value()40 std::string value() const { return value_; } default_value()41 std::string default_value() const { return default_value_; } 42 set_name(const std::string & name)43 void set_name(const std::string &name) { name_ = name; } set_param_type(const std::string & param_type)44 void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } set_type(const std::string & type)45 void set_type(const std::string &type) { type_ = type; } set_value(const std::string & value)46 void set_value(const std::string &value) { value_ = value; } set_default_value(const std::string & default_value)47 void set_default_value(const std::string &default_value) { default_value_ = default_value; } 48 49 private: 50 std::string name_; 51 std::string param_type_; 52 std::string type_; 53 std::string value_; 54 std::string default_value_; 55 }; 56 57 class OpIOInfo { 58 public: 59 OpIOInfo() = default; 60 ~OpIOInfo() = default; 61 index()62 int index() const { return index_; } name()63 const std::string &name() const { return name_; } need_compile()64 bool need_compile() const { return need_compile_; } param_type()65 const std::string ¶m_type() const { return param_type_; } reshape_type()66 const std::string &reshape_type() const { return reshape_type_; } shape()67 const std::string &shape() const { return shape_; } dtypes()68 const std::vector<std::string> &dtypes() const { return dtypes_; } formats()69 const std::vector<std::string> &formats() const { return formats_; } value_depend()70 const std::string &value_depend() const { return value_depend_; } 71 set_index(const int index)72 void set_index(const int index) { index_ = index; } set_name(const std::string & name)73 void set_name(const std::string &name) { name_ = name; } set_need_compile(const bool need_compile)74 void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } set_param_type(const std::string & param_type)75 void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } set_reshape_type(const std::string & reshape_type)76 void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } set_shape(const std::string & shape)77 void set_shape(const std::string &shape) { shape_ = shape; } set_dtypes(const std::vector<std::string> & dtype)78 void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; } set_formats(const std::vector<std::string> & formats)79 void set_formats(const std::vector<std::string> &formats) { formats_ = formats; } set_value_depend(const std::string & value_depend)80 void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; } 81 82 private: 83 int index_ = 0; 84 std::string name_; 85 bool need_compile_ = false; 86 std::string param_type_; 87 std::string reshape_type_; 88 std::string shape_; 89 std::vector<std::string> dtypes_; 90 std::vector<std::string> formats_; 91 std::string value_depend_ = kIgnored; 92 }; 93 94 class OpInfo { 95 public: 96 OpInfo() = default; OpInfo(const OpInfo & opinfo)97 OpInfo(const OpInfo &opinfo) { 98 op_name_ = opinfo.op_name(); 99 imply_type_ = opinfo.imply_type(); 100 101 impl_path_ = opinfo.impl_path(); 102 fusion_type_ = opinfo.fusion_type(); 103 async_flag_ = opinfo.async_flag_; 104 binfile_name_ = opinfo.binfile_name_; 105 compute_cost_ = opinfo.compute_cost_; 106 kernel_name_ = opinfo.kernel_name(); 107 partial_flag_ = opinfo.partial_flag_; 108 dynamic_shape_ = opinfo.dynamic_shape_; 109 dynamic_compile_static_ = opinfo.dynamic_compile_static_; 110 op_pattern_ = opinfo.op_pattern(); 111 processor_ = opinfo.processor_; 112 need_check_supported_ = opinfo.need_check_supported(); 113 is_dynamic_format_ = opinfo.is_dynamic_format(); 114 for (const auto &attr : opinfo.attrs_ptr()) { 115 attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr)); 116 } 117 for (const auto &input : opinfo.inputs_ptr()) { 118 inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input)); 119 } 120 for (const auto &output : opinfo.outputs_ptr()) { 121 outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output)); 122 } 123 ref_infos_ = opinfo.ref_infos(); 124 } 125 ~OpInfo() = default; op_name()126 std::string op_name() const { return op_name_; } imply_type()127 OpImplyType imply_type() const { return imply_type_; } impl_path()128 std::string impl_path() const { return impl_path_; } fusion_type()129 std::string fusion_type() const { return fusion_type_; } kernel_name()130 std::string kernel_name() const { return kernel_name_; } op_pattern()131 OpPattern op_pattern() const { return op_pattern_; } dynamic_shape()132 bool dynamic_shape() const { return dynamic_shape_; } dynamic_compile_static()133 bool dynamic_compile_static() const { return dynamic_compile_static_; } processor()134 std::string processor() const { return processor_; } need_check_supported()135 bool need_check_supported() const { return need_check_supported_; } is_dynamic_format()136 bool is_dynamic_format() const { return is_dynamic_format_; } attrs_ptr()137 std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } inputs_ptr()138 std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } outputs_ptr()139 std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } ref_infos()140 const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; } 141 set_dynamic_shape(bool dynamic_shape)142 void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; } set_dynamic_compile_static_(bool dynamic_compile_static)143 void set_dynamic_compile_static_(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; } set_op_name(const std::string & op_name)144 void set_op_name(const std::string &op_name) { op_name_ = op_name; } set_imply_type(const OpImplyType imply_type)145 void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } set_impl_path(const std::string & impl_path)146 void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } set_fusion_type(const std::string & fusion_type)147 void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } set_async_flag(const bool async_flag)148 void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } set_binfile_name(const std::string & binfile_name)149 void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } set_compute_cost(const int compute_cost)150 void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } set_kernel_name(const std::string & kernel_name)151 void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } set_partial_flag(const bool partial_flag)152 void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } set_op_pattern(const OpPattern op_pattern)153 void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } set_processor(const std::string & processor)154 void set_processor(const std::string &processor) { processor_ = processor; } set_need_check_supported(bool need_check_supported)155 void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; } set_is_dynamic_format(bool is_dynamic_format)156 void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; } add_attrs_ptr(const std::shared_ptr<OpAttr> & attr)157 void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } add_inputs_ptr(const std::shared_ptr<OpIOInfo> & input)158 void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } add_outputs_ptr(const std::shared_ptr<OpIOInfo> & output)159 void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } is_ref()160 bool is_ref() const { return !ref_infos_.empty(); } has_ref_index(size_t out_index)161 bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } add_ref_pair(size_t out_index,size_t in_index)162 void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } ClearInputs()163 void ClearInputs() { (void)inputs_ptr_.clear(); } ClearOutputs()164 void ClearOutputs() { (void)outputs_ptr_.clear(); } equals_to(const std::shared_ptr<OpInfo> & other_info)165 bool equals_to(const std::shared_ptr<OpInfo> &other_info) const { 166 return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && 167 this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ && 168 this->dynamic_shape_ == other_info->dynamic_shape_ && 169 this->dynamic_compile_static_ == other_info->dynamic_compile_static_; 170 } 171 172 private: 173 std::string op_name_; 174 OpImplyType imply_type_ = kTBE; 175 std::string impl_path_; 176 std::string fusion_type_; 177 bool async_flag_ = false; 178 std::string binfile_name_; 179 int compute_cost_ = 0; 180 std::string kernel_name_; 181 bool partial_flag_ = false; 182 bool dynamic_shape_ = false; 183 bool dynamic_compile_static_ = false; 184 bool need_check_supported_ = false; 185 bool is_dynamic_format_ = false; 186 OpPattern op_pattern_ = kCommonPattern; 187 std::string processor_; 188 std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; 189 std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; 190 std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; 191 std::unordered_map<size_t, size_t> ref_infos_; 192 }; 193 194 using OpAttrPtr = std::shared_ptr<OpAttr>; 195 using OpIOInfoPtr = std::shared_ptr<OpIOInfo>; 196 using OpInfoPtr = std::shared_ptr<OpInfo>; 197 } // namespace kernel 198 } // namespace mindspore 199 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 200