• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <map>
23 #include <utility>
24 #include <unordered_map>
25 #include "ir/dtype.h"
26 #include "kernel/kernel.h"
27 #include "kernel/oplib/op_info_keys.h"
28 
29 namespace mindspore::kernel {
/// Describes one attribute of an operator: its name, parameter kind, type, value and
/// default value. Every field is stored as a plain string; interpreting the text is left
/// to the code that consumes it.
class OpAttr {
 public:
  OpAttr() = default;
  ~OpAttr() = default;

  // Accessors return const references (matching OpIOInfo) so that reading a field does
  // not copy the string. A returned reference stays valid until the field is reassigned
  // or this object is destroyed.
  const std::string &name() const { return name_; }
  const std::string &param_type() const { return param_type_; }
  const std::string &type() const { return type_; }
  const std::string &value() const { return value_; }
  const std::string &default_value() const { return default_value_; }

  void set_name(const std::string &name) { name_ = name; }
  void set_param_type(const std::string &param_type) { param_type_ = param_type; }
  void set_type(const std::string &type) { type_ = type; }
  void set_value(const std::string &value) { value_ = value; }
  void set_default_value(const std::string &default_value) { default_value_ = default_value; }

 private:
  std::string name_;
  std::string param_type_;
  std::string type_;
  std::string value_;
  std::string default_value_;
};
54 
55 class OpIOInfo {
56  public:
57   OpIOInfo() = default;
58   ~OpIOInfo() = default;
59 
index()60   int index() const { return index_; }
name()61   const std::string &name() const { return name_; }
need_compile()62   bool need_compile() const { return need_compile_; }
param_type()63   const std::string &param_type() const { return param_type_; }
reshape_type()64   const std::string &reshape_type() const { return reshape_type_; }
shape()65   const std::string &shape() const { return shape_; }
dtypes()66   const std::vector<std::string> &dtypes() const { return dtypes_; }
formats()67   const std::vector<std::string> &formats() const { return formats_; }
unknown_shape_formats()68   const std::vector<std::string> &unknown_shape_formats() const { return unknown_shape_formats_; }
object_types()69   const std::vector<std::string> &object_types() const { return object_types_; }
value_depend()70   const std::string &value_depend() const { return value_depend_; }
shapes_type()71   const std::string &shapes_type() const { return shapes_type_; }
72 
set_index(const int index)73   void set_index(const int index) { index_ = index; }
set_name(const std::string & name)74   void set_name(const std::string &name) { name_ = name; }
set_need_compile(const bool need_compile)75   void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
set_param_type(const std::string & param_type)76   void set_param_type(const std::string &param_type) { param_type_ = param_type; }
set_reshape_type(const std::string & reshape_type)77   void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
set_shape(const std::string & shape)78   void set_shape(const std::string &shape) { shape_ = shape; }
set_dtypes(const std::vector<std::string> & dtype)79   void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
set_formats(const std::vector<std::string> & formats)80   void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
set_unknown_shape_formats(const std::vector<std::string> & unknown_shape_formats)81   void set_unknown_shape_formats(const std::vector<std::string> &unknown_shape_formats) {
82     unknown_shape_formats_ = unknown_shape_formats;
83   }
set_object_types(const std::vector<std::string> & object_types)84   void set_object_types(const std::vector<std::string> &object_types) { object_types_ = object_types; }
set_value_depend(const std::string & value_depend)85   void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; }
set_shapes_type(const std::string & shapes_type)86   void set_shapes_type(const std::string &shapes_type) { shapes_type_ = shapes_type; }
87 
88  private:
89   int index_ = 0;
90   std::string name_;
91   bool need_compile_ = false;
92   std::string param_type_;
93   std::string reshape_type_;
94   std::string shape_;
95   std::string shapes_type_;
96   std::string value_depend_ = kIgnored;
97   std::vector<std::string> dtypes_;
98   std::vector<std::string> formats_;
99   std::vector<std::string> unknown_shape_formats_;
100   std::vector<std::string> object_types_;
101 };
102 
103 class OpInfo {
104  public:
105   OpInfo() = default;
106   ~OpInfo() = default;
op_name()107   std::string op_name() const { return op_name_; }
imply_type()108   OpImplyType imply_type() const { return imply_type_; }
async()109   bool async() const { return async_; }
bin_file()110   std::string bin_file() const { return bin_file_; }
compute()111   int compute() const { return compute_; }
cube_op()112   bool cube_op() const { return cube_op_; }
dynamic_compile_static()113   bool dynamic_compile_static() const { return dynamic_compile_static_; }
dynamic_format()114   bool dynamic_format() const { return dynamic_format_; }
dynamic_rank_support()115   bool dynamic_rank_support() const { return dynamic_rank_support_; }
dynamic_shape_support()116   bool dynamic_shape_support() const { return dynamic_shape_support_; }
heavy_op()117   bool heavy_op() const { return heavy_op_; }
jit_compile()118   bool jit_compile() const { return jit_compile_; }
kernel()119   std::string kernel() const { return kernel_; }
need_check_support()120   bool need_check_support() const { return need_check_support_; }
op_pattern()121   OpPattern op_pattern() const { return op_pattern_; }
op_file()122   std::string op_file() const { return op_file_; }
op_interface()123   std::string op_interface() const { return op_interface_; }
partial()124   bool partial() const { return partial_; }
precision_reduce()125   bool precision_reduce() const { return precision_reduce_; }
range_limit()126   std::string range_limit() const { return range_limit_; }
sagt_key_attrs()127   const std::vector<std::string> &sagt_key_attrs() const { return sagt_key_attrs_; }
slice_pattern()128   std::string slice_pattern() const { return slice_pattern_; }
prebuild_pattern()129   std::string prebuild_pattern() const { return prebuild_pattern_; }
130   // Attr
impl_path()131   std::string impl_path() const { return impl_path_; }
processor()132   std::string processor() const { return processor_; }
input_to_attr_index()133   const std::vector<size_t> &input_to_attr_index() const { return input_to_attr_index_; }
real_input_index()134   const std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> &real_input_index() const {
135     return real_input_index_;
136   }
ref_infos()137   const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
138 
attrs_ptr()139   std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
inputs_ptr()140   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
outputs_ptr()141   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
142 
set_op_name(const std::string & op_name)143   void set_op_name(const std::string &op_name) { op_name_ = op_name; }
set_imply_type(OpImplyType imply_type)144   void set_imply_type(OpImplyType imply_type) { imply_type_ = imply_type; }
set_async(const bool async)145   void set_async(const bool async) { async_ = async; }
set_bin_file(const std::string & bin_file)146   void set_bin_file(const std::string &bin_file) { bin_file_ = bin_file; }
set_compute(const int compute)147   void set_compute(const int compute) { compute_ = compute; }
set_cube_op(bool cube_op)148   void set_cube_op(bool cube_op) { cube_op_ = cube_op; }
set_dynamic_compile_static(bool dynamic_compile_static)149   void set_dynamic_compile_static(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; }
set_dynamic_format(bool dynamic_format)150   void set_dynamic_format(bool dynamic_format) { dynamic_format_ = dynamic_format; }
set_dynamic_rank_support(bool dynamic_rank_support)151   void set_dynamic_rank_support(bool dynamic_rank_support) { dynamic_rank_support_ = dynamic_rank_support; }
set_dynamic_shape_support(bool flag)152   void set_dynamic_shape_support(bool flag) { dynamic_shape_support_ = flag; }
set_heavy_op(bool heavy_op)153   void set_heavy_op(bool heavy_op) { heavy_op_ = heavy_op; }
set_jit_compile(bool jit_compile)154   void set_jit_compile(bool jit_compile) { jit_compile_ = jit_compile; }
set_soft_sync(bool soft_sync)155   void set_soft_sync(bool soft_sync) { soft_sync_ = soft_sync; }
set_op_impl_switch(const std::string & op_impl_switch)156   void set_op_impl_switch(const std::string &op_impl_switch) { op_impl_switch_ = op_impl_switch; }
set_kernel(const std::string & kernel_name)157   void set_kernel(const std::string &kernel_name) { kernel_ = kernel_name; }
set_need_check_supported(bool need_check_supported)158   void set_need_check_supported(bool need_check_supported) { need_check_support_ = need_check_supported; }
set_op_pattern(const OpPattern op_pattern)159   void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
set_op_file(const std::string & op_file)160   void set_op_file(const std::string &op_file) { op_file_ = op_file; }
set_op_interface(const std::string & op_interface)161   void set_op_interface(const std::string &op_interface) { op_interface_ = op_interface; }
set_partial(const bool partial_flag)162   void set_partial(const bool partial_flag) { partial_ = partial_flag; }
set_precision_reduce(bool precision_reduce)163   void set_precision_reduce(bool precision_reduce) { precision_reduce_ = precision_reduce; }
set_range_limit(const std::string & range_limit)164   void set_range_limit(const std::string &range_limit) { range_limit_ = range_limit; }
set_sagt_key_attrs(const std::vector<std::string> & sagt_key_attrs)165   void set_sagt_key_attrs(const std::vector<std::string> &sagt_key_attrs) { sagt_key_attrs_ = sagt_key_attrs; }
set_slice_pattern(const std::string & slice_pattern)166   void set_slice_pattern(const std::string &slice_pattern) { slice_pattern_ = slice_pattern; }
set_prebuild_pattern(const std::string & prebuild_pattern)167   void set_prebuild_pattern(const std::string &prebuild_pattern) { prebuild_pattern_ = prebuild_pattern; }
168 
set_impl_path(const std::string & impl_path)169   void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
set_processor(const std::string & processor)170   void set_processor(const std::string &processor) { processor_ = processor; }
set_input_to_attr_index(const std::vector<size_t> & input_to_attr_index)171   void set_input_to_attr_index(const std::vector<size_t> &input_to_attr_index) {
172     input_to_attr_index_ = input_to_attr_index;
173   }
set_real_input_index(const std::pair<std::map<size_t,size_t>,std::map<size_t,size_t>> & real_input_index)174   void set_real_input_index(const std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> &real_input_index) {
175     real_input_index_ = real_input_index;
176   }
is_ref()177   bool is_ref() const { return !ref_infos_.empty(); }
has_ref_index(size_t out_index)178   bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
add_ref_pair(size_t out_index,size_t in_index)179   void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
add_attrs_ptr(const std::shared_ptr<OpAttr> & attr)180   void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
add_inputs_ptr(const std::shared_ptr<OpIOInfo> & input)181   void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> & inputs)182   void set_inputs_ptr(const std::vector<std::shared_ptr<OpIOInfo>> &inputs) { inputs_ptr_ = inputs; }
add_outputs_ptr(const std::shared_ptr<OpIOInfo> & output)183   void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
184 
equals_to(const std::shared_ptr<OpInfo> & other_info)185   bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
186     return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
187            this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
188            this->dynamic_shape_support_ == other_info->dynamic_shape_support_ &&
189            this->dynamic_compile_static_ == other_info->dynamic_compile_static_;
190   }
191 
192  private:
193   std::string op_name_;
194   OpImplyType imply_type_ = kImplyTBE;
195   bool async_ = false;
196   std::string bin_file_;
197   int compute_ = 0;
198   bool cube_op_ = false;
199   bool dynamic_compile_static_ = false;
200   bool dynamic_format_ = false;
201   bool dynamic_rank_support_ = false;
202   bool dynamic_shape_support_ = false;
203   bool heavy_op_ = false;
204   bool jit_compile_ = false;
205   bool soft_sync_ = false;
206   std::string op_impl_switch_ = "";
207   std::string kernel_;
208   bool need_check_support_ = false;
209   OpPattern op_pattern_ = kCommonPattern;
210   std::string op_file_;
211   std::string op_interface_;
212   bool partial_ = false;
213   bool precision_reduce_ = false;
214   std::string range_limit_;
215   std::vector<std::string> sagt_key_attrs_ = {};
216   std::string slice_pattern_;
217   std::string prebuild_pattern_;
218   // Attr info
219   std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
220   // Input/Output info
221   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
222   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
223 
224   // Attr not in the json
225   std::string impl_path_;
226   std::string processor_;
227   std::vector<size_t> input_to_attr_index_{};
228   std::pair<std::map<size_t, size_t>, std::map<size_t, size_t>> real_input_index_{{}, {}};
229   std::unordered_map<size_t, size_t> ref_infos_;
230 };
231 
232 using OpAttrPtr = std::shared_ptr<OpAttr>;
233 using OpIOInfoPtr = std::shared_ptr<OpIOInfo>;
234 using OpInfoPtr = std::shared_ptr<OpInfo>;
235 }  // namespace mindspore::kernel
236 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
237