1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 19 #include <vector> 20 #include <string> 21 #include <memory> 22 #include <map> 23 #include <utility> 24 #include <unordered_map> 25 #include "ir/dtype.h" 26 #include "kernel/kernel.h" 27 #include "kernel/oplib/op_info_keys.h" 28 29 namespace mindspore::kernel { 30 class OpAttr { 31 public: 32 OpAttr() = default; 33 ~OpAttr() = default; 34 name()35 std::string name() const { return name_; } param_type()36 std::string param_type() const { return param_type_; } type()37 std::string type() const { return type_; } value()38 std::string value() const { return value_; } default_value()39 std::string default_value() const { return default_value_; } 40 set_name(const std::string & name)41 void set_name(const std::string &name) { name_ = name; } set_param_type(const std::string & param_type)42 void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } set_type(const std::string & type)43 void set_type(const std::string &type) { type_ = type; } set_value(const std::string & value)44 void set_value(const std::string &value) { value_ = value; } set_default_value(const std::string & default_value)45 void set_default_value(const std::string &default_value) { default_value_ = default_value; } 46 47 private: 48 std::string name_; 49 std::string param_type_; 50 std::string type_; 51 std::string value_; 52 std::string default_value_; 53 }; 54 55 class OpIOInfo { 56 public: 57 OpIOInfo() = default; 58 ~OpIOInfo() = default; 59 index()60 int index() const { return index_; } name()61 const std::string &name() const { return name_; } need_compile()62 bool need_compile() const { return need_compile_; } param_type()63 const std::string ¶m_type() const { return param_type_; } reshape_type()64 const std::string &reshape_type() const { return reshape_type_; } shape()65 const std::string &shape() const { return shape_; } dtypes()66 const std::vector<std::string> &dtypes() const { return dtypes_; } formats()67 const std::vector<std::string> &formats() const { return formats_; } unknown_shape_formats()68 const std::vector<std::string> &unknown_shape_formats() const { return unknown_shape_formats_; } object_types()69 const std::vector<std::string> &object_types() const { return object_types_; } value_depend()70 const std::string &value_depend() const { return value_depend_; } shapes_type()71 const std::string &shapes_type() const { return shapes_type_; } 72 set_index(const int index)73 void set_index(const int index) { index_ = index; } set_name(const std::string & name)74 void set_name(const std::string &name) { name_ = name; } set_need_compile(const bool need_compile)75 void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } set_param_type(const std::string & param_type)76 void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } set_reshape_type(const std::string & reshape_type)77 void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } set_shape(const std::string & shape)78 void set_shape(const std::string &shape) { shape_ = shape; } set_dtypes(const std::vector<std::string> & dtype)79 void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; } set_formats(const std::vector<std::string> & formats)80 void set_formats(const std::vector<std::string> &formats) { formats_ = formats; } set_unknown_shape_formats(const std::vector<std::string> & unknown_shape_formats)81 void set_unknown_shape_formats(const std::vector<std::string> &unknown_shape_formats) { 82 unknown_shape_formats_ = unknown_shape_formats; 83 } set_object_types(const std::vector<std::string> & object_types)84 void set_object_types(const std::vector<std::string> &object_types) { object_types_ = object_types; } set_value_depend(const std::string & value_depend)85 void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; } set_shapes_type(const std::string & shapes_type)86 void set_shapes_type(const std::string &shapes_type) { shapes_type_ = shapes_type; } 87 88 private: 89 int index_ = 0; 90 std::string name_; 91 bool need_compile_ = false; 92 std::string param_type_; 93 std::string reshape_type_; 94 std::string shape_; 95 std::string shapes_type_; 96 std::string value_depend_ = kIgnored; 97 std::vector<std::string> dtypes_; 98 std::vector<std::string> formats_; 99 std::vector<std::string> unknown_shape_formats_; 100 std::vector<std::string> object_types_; 101 }; 102 103 class OpInfo { 104 public: 105 OpInfo() = default; 106 ~OpInfo() = default; op_name()107 std::string op_name() const { return op_name_; } imply_type()108 OpImplyType imply_type() const { return imply_type_; } async()109 bool async() const { return async_; } bin_file()110 std::string bin_file() const { return bin_file_; } compute()111 int compute() const { return compute_; } cube_op()112 bool cube_op() const { return cube_op_; } dynamic_compile_static()113 bool dynamic_compile_static() const { return dynamic_compile_static_; } dynamic_format()114 bool dynamic_format() const { return dynamic_format_; } dynamic_rank_support()115 bool dynamic_rank_support() const { return dynamic_rank_support_; } dynamic_shape_support()116 bool dynamic_shape_support() const { return dynamic_shape_support_; } heavy_op()117 bool heavy_op() const { return heavy_op_; } jit_compile()118 bool jit_compile() const { return jit_compile_; } kernel()119 std::string kernel() const { return kernel_; } need_check_support()120 bool need_check_support() const { return need_check_support_; } op_pattern()121 OpPattern op_pattern() const { return op_pattern_; } op_file()122 std::string op_file() const { return op_file_; } op_interface()123 std::string op_interface() const { return op_interface_; } partial()124 bool partial() const { return partial_; } precision_reduce()125 bool precision_reduce() const { return precision_reduce_; } range_limit()126 std::string range_limit() const { return range_limit_; } sagt_key_attrs()127 const std::vector<std::string> &sagt_key_attrs() const { return sagt_key_attrs_; } slice_pattern()128 std::string slice_pattern() const { return slice_pattern_; } prebuild_pattern()129 std::string prebuild_pattern() const { return prebuild_pattern_; } 130 // Attr impl_path()131 std::string impl_path() const { return impl_path_; } processor()132 std::string processor() const { return processor_; } input_to_attr_index()133 const std::vector<size_t> &input_to_attr_index() const { return input_to_attr_index_; } real_input_index()134 const std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> &real_input_index() const { 135 return real_input_index_; 136 } ref_infos()137 const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; } 138 attrs_ptr()139 std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; } inputs_ptr()140 std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; } outputs_ptr()141 std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; } 142 set_op_name(const std::string & op_name)143 void set_op_name(const std::string &op_name) { op_name_ = op_name; } set_imply_type(OpImplyType imply_type)144 void set_imply_type(OpImplyType imply_type) { imply_type_ = imply_type; } set_async(const bool async)145 void set_async(const bool async) { async_ = async; } set_bin_file(const std::string & bin_file)146 void set_bin_file(const std::string &bin_file) { bin_file_ = bin_file; } set_compute(const int compute)147 void set_compute(const int compute) { compute_ = compute; } set_cube_op(bool cube_op)148 void set_cube_op(bool cube_op) { cube_op_ = cube_op; } set_dynamic_compile_static(bool dynamic_compile_static)149 void set_dynamic_compile_static(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; } set_dynamic_format(bool dynamic_format)150 void set_dynamic_format(bool dynamic_format) { dynamic_format_ = dynamic_format; } set_dynamic_rank_support(bool dynamic_rank_support)151 void set_dynamic_rank_support(bool dynamic_rank_support) { dynamic_rank_support_ = dynamic_rank_support; } set_dynamic_shape_support(bool flag)152 void set_dynamic_shape_support(bool flag) { dynamic_shape_support_ = flag; } set_heavy_op(bool heavy_op)153 void set_heavy_op(bool heavy_op) { heavy_op_ = heavy_op; } set_jit_compile(bool jit_compile)154 void set_jit_compile(bool jit_compile) { jit_compile_ = jit_compile; } set_soft_sync(bool soft_sync)155 void set_soft_sync(bool soft_sync) { soft_sync_ = soft_sync; } set_op_impl_switch(const std::string & op_impl_switch)156 void set_op_impl_switch(const std::string &op_impl_switch) { op_impl_switch_ = op_impl_switch; } set_kernel(const std::string & kernel_name)157 void set_kernel(const std::string &kernel_name) { kernel_ = kernel_name; } set_need_check_supported(bool need_check_supported)158 void set_need_check_supported(bool need_check_supported) { need_check_support_ = need_check_supported; } set_op_pattern(const OpPattern op_pattern)159 void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; } set_op_file(const std::string & op_file)160 void set_op_file(const std::string &op_file) { op_file_ = op_file; } set_op_interface(const std::string & op_interface)161 void set_op_interface(const std::string &op_interface) { op_interface_ = op_interface; } set_partial(const bool partial_flag)162 void set_partial(const bool partial_flag) { partial_ = partial_flag; } set_precision_reduce(bool precision_reduce)163 void set_precision_reduce(bool precision_reduce) { precision_reduce_ = precision_reduce; } set_range_limit(const std::string & range_limit)164 void set_range_limit(const std::string &range_limit) { range_limit_ = range_limit; } set_sagt_key_attrs(const std::vector<std::string> & sagt_key_attrs)165 void set_sagt_key_attrs(const std::vector<std::string> &sagt_key_attrs) { sagt_key_attrs_ = sagt_key_attrs; } set_slice_pattern(const std::string & slice_pattern)166 void set_slice_pattern(const std::string &slice_pattern) { slice_pattern_ = slice_pattern; } set_prebuild_pattern(const std::string & prebuild_pattern)167 void set_prebuild_pattern(const std::string &prebuild_pattern) { prebuild_pattern_ = prebuild_pattern; } 168 set_impl_path(const std::string & impl_path)169 void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } set_processor(const std::string & processor)170 void set_processor(const std::string &processor) { processor_ = processor; } set_input_to_attr_index(const std::vector<size_t> & input_to_attr_index)171 void set_input_to_attr_index(const std::vector<size_t> &input_to_attr_index) { 172 input_to_attr_index_ = input_to_attr_index; 173 } set_real_input_index(const std::pair<std::map<size_t,size_t>,std::map<size_t,size_t>> & real_input_index)174 void set_real_input_index(const std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> &real_input_index) { 175 real_input_index_ = real_input_index; 176 } is_ref()177 bool is_ref() const { return !ref_infos_.empty(); } has_ref_index(size_t out_index)178 bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } add_ref_pair(size_t out_index,size_t in_index)179 void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } add_attrs_ptr(const std::shared_ptr<OpAttr> & attr)180 void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); } add_inputs_ptr(const std::shared_ptr<OpIOInfo> & input)181 void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); } set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> & inputs)182 void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; } add_outputs_ptr(const std::shared_ptr<OpIOInfo> & output)183 void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); } 184 equals_to(const std::shared_ptr<OpInfo> & other_info)185 bool equals_to(const std::shared_ptr<OpInfo> &other_info) const { 186 return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ && 187 this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ && 188 this->dynamic_shape_support_ == other_info->dynamic_shape_support_ && 189 this->dynamic_compile_static_ == other_info->dynamic_compile_static_; 190 } 191 192 private: 193 std::string op_name_; 194 OpImplyType imply_type_ = kImplyTBE; 195 bool async_ = false; 196 std::string bin_file_; 197 int compute_ = 0; 198 bool cube_op_ = false; 199 bool dynamic_compile_static_ = false; 200 bool dynamic_format_ = false; 201 bool dynamic_rank_support_ = false; 202 bool dynamic_shape_support_ = false; 203 bool heavy_op_ = false; 204 bool jit_compile_ = false; 205 bool soft_sync_ = false; 206 std::string op_impl_switch_ = ""; 207 std::string kernel_; 208 bool need_check_support_ = false; 209 OpPattern op_pattern_ = kCommonPattern; 210 std::string op_file_; 211 std::string op_interface_; 212 bool partial_ = false; 213 bool precision_reduce_ = false; 214 std::string range_limit_; 215 std::vector<std::string> sagt_key_attrs_ = {}; 216 std::string slice_pattern_; 217 std::string prebuild_pattern_; 218 // Attr info 219 std::vector<std::shared_ptr<OpAttr>> attrs_ptr_; 220 // Input/Output info 221 std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_; 222 std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_; 223 224 // Attr not in the json 225 std::string impl_path_; 226 std::string processor_; 227 std::vector<size_t> input_to_attr_index_{}; 228 std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> real_input_index_{{}, {}}; 229 std::unordered_map<size_t, size_t> ref_infos_; 230 }; 231 232 using OpAttrPtr = std::shared_ptr<OpAttr>; 233 using OpIOInfoPtr = std::shared_ptr<OpIOInfo>; 234 using OpInfoPtr = std::shared_ptr<OpInfo>; 235 } // namespace mindspore::kernel 236 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_ 237