1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 19 #include <iostream> 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include "ir/dtype.h" 25 #include "ir/kernel_info_dev.h" 26 #include "backend/kernel_compiler/kernel.h" 27 28 namespace mindspore { 29 namespace kernel { 30 class KernelBuildInfo { 31 public: 32 class KernelBuildInfoBuilder; 33 KernelBuildInfo()34 KernelBuildInfo() { 35 kernel_type_ = TBE_KERNEL; 36 fusion_type_ = OPAQUE; 37 processor_ = AICORE; 38 op_pattern_ = kCommonPattern; 39 input_reshape_type_ = {}; 40 output_reshape_type_ = {}; 41 origin_data_format_ = kOpFormat_DEFAULT; 42 inputs_format_ = {}; 43 outputs_format_ = {}; 44 inputs_device_type_ = {}; 45 outputs_device_type_ = {}; 46 output_data_desc_ = {}; 47 } 48 49 ~KernelBuildInfo() = default; 50 kernel_type()51 KernelType kernel_type() const { return kernel_type_; } 52 53 std::string GetInputFormat(size_t input_index) const; 54 55 std::string GetOutputFormat(size_t output_index) const; 56 57 TypeId GetInputDeviceType(size_t input_index) const; 58 59 TypeId GetOutputDeviceType(size_t output_index) const; 60 61 std::string GetInputReshapeType(size_t input_index) const; 62 63 std::string GetInputValueDepend(size_t input_index) const; 64 65 bool IsInputDefaultPadding() const; 66 67 bool IsOutputDefaultPadding() const; 68 69 std::string GetOutputReshapeType(size_t input_index) const; 70 71 const std::string &GetOriginDataFormat() const; 72 73 const std::vector<std::string> &GetAllInputFormats() const; 74 75 const std::vector<std::string> &GetAllOutputFormats() const; 76 77 const std::vector<TypeId> &GetAllInputDeviceTypes() const; 78 79 const std::vector<TypeId> &GetAllOutputDeviceTypes() const; 80 81 std::vector<std::string> GetAllOutputReshapeType() const; 82 83 std::vector<std::string> GetAllInputReshapeType() const; 84 op_pattern()85 OpPattern op_pattern() const { return op_pattern_; } 86 output_data_desc()87 std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; } 88 fusion_type()89 FusionType fusion_type() const { return fusion_type_; } 90 processor()91 Processor processor() const { return processor_; } 92 93 size_t GetInputNum() const; 94 95 size_t GetOutputNum() const; 96 97 std::string ToString() const; 98 99 bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const; 100 101 bool operator==(const KernelBuildInfo &other) const; 102 103 bool operator!=(const KernelBuildInfo &other) const; 104 105 static auto constexpr kInvalidFormat = "InvalidFormat"; 106 107 private: 108 KernelType kernel_type_; 109 std::string origin_data_format_; 110 std::vector<std::string> inputs_format_; 111 OpPattern op_pattern_; 112 std::vector<std::string> outputs_format_; 113 std::vector<std::string> input_reshape_type_; 114 std::vector<std::string> output_reshape_type_; 115 std::vector<TypeId> inputs_device_type_; 116 std::vector<TypeId> outputs_device_type_; 117 std::vector<nlohmann::json> output_data_desc_; 118 std::vector<std::string> input_value_depend_; 119 FusionType fusion_type_; 120 Processor processor_; 121 }; 122 using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>; 123 124 class KernelBuildInfo::KernelBuildInfoBuilder { 125 public: KernelBuildInfoBuilder()126 KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); } 127 KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)128 explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info) 129 : kernel_build_info_(std::make_shared<KernelBuildInfo>()) { 130 SetKernelType(kernel_build_info->kernel_type()); 131 SetFusionType(kernel_build_info->fusion_type()); 132 SetProcessor(kernel_build_info->processor()); 133 SetOpPattern(kernel_build_info->op_pattern()); 134 for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) { 135 kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index)); 136 kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index)); 137 kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index)); 138 kernel_build_info_->input_value_depend_.emplace_back(kernel_build_info->GetInputValueDepend(index)); 139 } 140 for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) { 141 kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index)); 142 kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index)); 143 kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index)); 144 } 145 } 146 147 ~KernelBuildInfoBuilder() = default; 148 149 void SetKernelType(const KernelType &kernel_type); 150 151 void SetOriginDataFormat(const std::string &origin_data_format); 152 153 void SetInputsFormat(const std::vector<std::string> &inputs_format); 154 155 void SetOutputsFormat(const std::vector<std::string> &outputs_format); 156 157 void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type); 158 159 void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type); 160 161 void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type); 162 163 void SetInputsValueDepend(const std::vector<std::string> &input_value_depend); 164 165 void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type); 166 167 void SetFusionType(FusionType fusion_type); 168 // save prebuild result 169 void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc); 170 171 void SetProcessor(Processor processor); 172 173 void SetOpPattern(OpPattern pattern); 174 175 void SetInputFormat(const std::string &format, size_t index); 176 177 void SetOutputFormat(const std::string &format, size_t index); 178 179 void SetInputReshapeType(const std::string &input_reshape_type, size_t index); 180 181 void SetOutputReshapeType(const std::string &output_reshape_type, size_t index); 182 183 void SetInputDeviceType(const TypeId &input_device_type, size_t index); 184 185 void SetOutputDeviceType(const TypeId &output_device_type, size_t index); 186 187 std::shared_ptr<KernelBuildInfo> Build(); 188 189 private: 190 std::shared_ptr<KernelBuildInfo> kernel_build_info_; 191 }; 192 } // namespace kernel 193 } // namespace mindspore 194 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 195