• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <unordered_map>
23 #include "ir/dtype.h"
24 #include "backend/kernel_compiler/kernel.h"
25 
26 namespace mindspore {
27 namespace kernel {
// Which backend family provides an operator implementation (values are fixed explicitly).
enum OpImplyType { kAKG = 0, kTBE = 1, kAICPU = 2, kCPU = 3 };
// Whether an I/O descriptor refers to an operator input or an operator output.
enum OpIOType { kInput = 0, kOutput = 1 };
// Sentinel string value; e.g. the default for OpIOInfo's value_depend field.
constexpr const char *kIgnored = "ignored";
31 
/// One attribute entry of an operator registration.
///
/// Every field (name, parameter type, value type, value, default value) is stored verbatim
/// as a string; no parsing or validation happens here.
class OpAttr {
 public:
  OpAttr() = default;
  // Deliberately no user-declared destructor (Rule of Zero): declaring even
  // `~OpAttr() = default;` suppresses the implicit move constructor / move assignment,
  // which silently turns every move of an OpAttr into a full copy of all its strings.

  // Getters return const references (like OpIOInfo) so reading a field does not copy it.
  const std::string &name() const { return name_; }
  const std::string &param_type() const { return param_type_; }
  const std::string &type() const { return type_; }
  const std::string &value() const { return value_; }
  const std::string &default_value() const { return default_value_; }

  void set_name(const std::string &name) { name_ = name; }
  void set_param_type(const std::string &param_type) { param_type_ = param_type; }
  void set_type(const std::string &type) { type_ = type; }
  void set_value(const std::string &value) { value_ = value; }
  void set_default_value(const std::string &default_value) { default_value_ = default_value; }

 private:
  std::string name_;
  std::string param_type_;
  std::string type_;
  std::string value_;
  std::string default_value_;
};
56 
57 class OpIOInfo {
58  public:
59   OpIOInfo() = default;
60   ~OpIOInfo() = default;
61 
index()62   int index() const { return index_; }
name()63   const std::string &name() const { return name_; }
need_compile()64   bool need_compile() const { return need_compile_; }
param_type()65   const std::string &param_type() const { return param_type_; }
reshape_type()66   const std::string &reshape_type() const { return reshape_type_; }
shape()67   const std::string &shape() const { return shape_; }
dtypes()68   const std::vector<std::string> &dtypes() const { return dtypes_; }
formats()69   const std::vector<std::string> &formats() const { return formats_; }
value_depend()70   const std::string &value_depend() const { return value_depend_; }
71 
set_index(const int index)72   void set_index(const int index) { index_ = index; }
set_name(const std::string & name)73   void set_name(const std::string &name) { name_ = name; }
set_need_compile(const bool need_compile)74   void set_need_compile(const bool need_compile) { need_compile_ = need_compile; }
set_param_type(const std::string & param_type)75   void set_param_type(const std::string &param_type) { param_type_ = param_type; }
set_reshape_type(const std::string & reshape_type)76   void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; }
set_shape(const std::string & shape)77   void set_shape(const std::string &shape) { shape_ = shape; }
set_dtypes(const std::vector<std::string> & dtype)78   void set_dtypes(const std::vector<std::string> &dtype) { dtypes_ = dtype; }
set_formats(const std::vector<std::string> & formats)79   void set_formats(const std::vector<std::string> &formats) { formats_ = formats; }
set_value_depend(const std::string & value_depend)80   void set_value_depend(const std::string &value_depend) { value_depend_ = value_depend; }
81 
82  private:
83   int index_ = 0;
84   std::string name_;
85   bool need_compile_ = false;
86   std::string param_type_;
87   std::string reshape_type_;
88   std::string shape_;
89   std::vector<std::string> dtypes_;
90   std::vector<std::string> formats_;
91   std::string value_depend_ = kIgnored;
92 };
93 
94 class OpInfo {
95  public:
96   OpInfo() = default;
OpInfo(const OpInfo & opinfo)97   OpInfo(const OpInfo &opinfo) {
98     op_name_ = opinfo.op_name();
99     imply_type_ = opinfo.imply_type();
100 
101     impl_path_ = opinfo.impl_path();
102     fusion_type_ = opinfo.fusion_type();
103     async_flag_ = opinfo.async_flag_;
104     binfile_name_ = opinfo.binfile_name_;
105     compute_cost_ = opinfo.compute_cost_;
106     kernel_name_ = opinfo.kernel_name();
107     partial_flag_ = opinfo.partial_flag_;
108     dynamic_shape_ = opinfo.dynamic_shape_;
109     dynamic_compile_static_ = opinfo.dynamic_compile_static_;
110     op_pattern_ = opinfo.op_pattern();
111     processor_ = opinfo.processor_;
112     need_check_supported_ = opinfo.need_check_supported();
113     is_dynamic_format_ = opinfo.is_dynamic_format();
114     for (const auto &attr : opinfo.attrs_ptr()) {
115       attrs_ptr_.push_back(std::make_shared<OpAttr>(*attr));
116     }
117     for (const auto &input : opinfo.inputs_ptr()) {
118       inputs_ptr_.push_back(std::make_shared<OpIOInfo>(*input));
119     }
120     for (const auto &output : opinfo.outputs_ptr()) {
121       outputs_ptr_.push_back(std::make_shared<OpIOInfo>(*output));
122     }
123     ref_infos_ = opinfo.ref_infos();
124   }
125   ~OpInfo() = default;
op_name()126   std::string op_name() const { return op_name_; }
imply_type()127   OpImplyType imply_type() const { return imply_type_; }
impl_path()128   std::string impl_path() const { return impl_path_; }
fusion_type()129   std::string fusion_type() const { return fusion_type_; }
kernel_name()130   std::string kernel_name() const { return kernel_name_; }
op_pattern()131   OpPattern op_pattern() const { return op_pattern_; }
dynamic_shape()132   bool dynamic_shape() const { return dynamic_shape_; }
dynamic_compile_static()133   bool dynamic_compile_static() const { return dynamic_compile_static_; }
processor()134   std::string processor() const { return processor_; }
need_check_supported()135   bool need_check_supported() const { return need_check_supported_; }
is_dynamic_format()136   bool is_dynamic_format() const { return is_dynamic_format_; }
attrs_ptr()137   std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
inputs_ptr()138   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
outputs_ptr()139   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
ref_infos()140   const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }
141 
set_dynamic_shape(bool dynamic_shape)142   void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
set_dynamic_compile_static_(bool dynamic_compile_static)143   void set_dynamic_compile_static_(bool dynamic_compile_static) { dynamic_compile_static_ = dynamic_compile_static; }
set_op_name(const std::string & op_name)144   void set_op_name(const std::string &op_name) { op_name_ = op_name; }
set_imply_type(const OpImplyType imply_type)145   void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
set_impl_path(const std::string & impl_path)146   void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
set_fusion_type(const std::string & fusion_type)147   void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; }
set_async_flag(const bool async_flag)148   void set_async_flag(const bool async_flag) { async_flag_ = async_flag; }
set_binfile_name(const std::string & binfile_name)149   void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; }
set_compute_cost(const int compute_cost)150   void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; }
set_kernel_name(const std::string & kernel_name)151   void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; }
set_partial_flag(const bool partial_flag)152   void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; }
set_op_pattern(const OpPattern op_pattern)153   void set_op_pattern(const OpPattern op_pattern) { op_pattern_ = op_pattern; }
set_processor(const std::string & processor)154   void set_processor(const std::string &processor) { processor_ = processor; }
set_need_check_supported(bool need_check_supported)155   void set_need_check_supported(bool need_check_supported) { need_check_supported_ = need_check_supported; }
set_is_dynamic_format(bool is_dynamic_format)156   void set_is_dynamic_format(bool is_dynamic_format) { is_dynamic_format_ = is_dynamic_format; }
add_attrs_ptr(const std::shared_ptr<OpAttr> & attr)157   void add_attrs_ptr(const std::shared_ptr<OpAttr> &attr) { attrs_ptr_.push_back(attr); }
add_inputs_ptr(const std::shared_ptr<OpIOInfo> & input)158   void add_inputs_ptr(const std::shared_ptr<OpIOInfo> &input) { inputs_ptr_.push_back(input); }
add_outputs_ptr(const std::shared_ptr<OpIOInfo> & output)159   void add_outputs_ptr(const std::shared_ptr<OpIOInfo> &output) { outputs_ptr_.push_back(output); }
is_ref()160   bool is_ref() const { return !ref_infos_.empty(); }
has_ref_index(size_t out_index)161   bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); }
add_ref_pair(size_t out_index,size_t in_index)162   void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); }
ClearInputs()163   void ClearInputs() { (void)inputs_ptr_.clear(); }
ClearOutputs()164   void ClearOutputs() { (void)outputs_ptr_.clear(); }
equals_to(const std::shared_ptr<OpInfo> & other_info)165   bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
166     return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
167            this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
168            this->dynamic_shape_ == other_info->dynamic_shape_ &&
169            this->dynamic_compile_static_ == other_info->dynamic_compile_static_;
170   }
171 
172  private:
173   std::string op_name_;
174   OpImplyType imply_type_ = kTBE;
175   std::string impl_path_;
176   std::string fusion_type_;
177   bool async_flag_ = false;
178   std::string binfile_name_;
179   int compute_cost_ = 0;
180   std::string kernel_name_;
181   bool partial_flag_ = false;
182   bool dynamic_shape_ = false;
183   bool dynamic_compile_static_ = false;
184   bool need_check_supported_ = false;
185   bool is_dynamic_format_ = false;
186   OpPattern op_pattern_ = kCommonPattern;
187   std::string processor_;
188   std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
189   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr_;
190   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr_;
191   std::unordered_map<size_t, size_t> ref_infos_;
192 };
193 
194 using OpAttrPtr = std::shared_ptr<OpAttr>;
195 using OpIOInfoPtr = std::shared_ptr<OpIOInfo>;
196 using OpInfoPtr = std::shared_ptr<OpInfo>;
197 }  // namespace kernel
198 }  // namespace mindspore
199 #endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_OPLIB_OPINFO_H_
200