1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 18 #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 19 #include <iostream> 20 #include <vector> 21 #include <memory> 22 #include <string> 23 #include <utility> 24 #include "kernel/oplib/op_info_keys.h" 25 #include "ir/dtype.h" 26 #include "ir/kernel_info_dev.h" 27 #include "kernel/kernel.h" 28 29 #ifdef OPAQUE 30 #undef OPAQUE 31 #endif 32 33 namespace mindspore { 34 namespace kernel { 35 constexpr auto kPatternOpaque = "Opaque"; 36 enum KernelObjectType : int { 37 UNKNOWN_TYPE = 0, 38 TENSOR, 39 SCALAR, 40 TUPLE, 41 TUPLE_UNFOLD, 42 }; 43 44 enum OpType : int { 45 UNKNOWN_OP_TYPE = 0, 46 DYNAMIC, 47 SKIP, 48 }; 49 50 std::string KernelObjectTypeLabel(const KernelObjectType &obj_type); 51 BACKEND_EXPORT std::string KernelTypeLabel(const KernelType &kernel_type); 52 std::string OpTypeLabel(const OpType &op_type); 53 54 class BACKEND_EXPORT KernelBuildInfo { 55 public: 56 class KernelBuildInfoBuilder; 57 58 KernelBuildInfo() = default; 59 60 ~KernelBuildInfo() = default; 61 kernel_type()62 KernelType kernel_type() const { return kernel_type_; } 63 op_type()64 OpType op_type() const { return op_type_; } 65 set_kernel_type(KernelType kernel_type)66 void set_kernel_type(KernelType kernel_type) { kernel_type_ = kernel_type; } 67 set_op_type(OpType op_type)68 void set_op_type(OpType op_type) { op_type_ = op_type; } 69 70 std::string GetInputFormat(size_t input_index) const; 71 72 std::string GetOutputFormat(size_t output_index) const; 73 74 TypeId GetInputDeviceType(size_t input_index) const; 75 76 TypeId GetOutputDeviceType(size_t output_index) const; 77 78 KernelObjectType GetInputKernelObjectType(size_t input_index) const; 79 80 KernelObjectType GetOutputKernelObjectType(size_t output_index) const; 81 82 std::string GetInputReshapeType(size_t input_index) const; 83 84 bool IsInputDefaultPadding() const; 85 86 bool IsOutputDefaultPadding() const; 87 88 std::string GetOutputReshapeType(size_t output_index) const; 89 90 const std::string &GetOriginDataFormat() const; 91 92 const std::vector<std::string> &GetAllInputFormats() const; 93 94 const std::vector<std::string> &GetAllOutputFormats() const; 95 96 const std::vector<TypeId> &GetAllInputDeviceTypes() const; 97 98 const std::vector<TypeId> &GetAllOutputDeviceTypes() const; 99 100 const std::vector<KernelObjectType> &GetAllInputKernelObjectTypes() const; 101 102 const std::vector<KernelObjectType> &GetAllOutputKernelObjectTypes() const; 103 104 const std::vector<KernelObjectType> &GetAllOutputElementsKernelObjectTypes() const; 105 106 void SetOpType(const OpType &op_type); 107 108 void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &outputs_kernel_object_type); 109 110 void SetInputsKernelObjectType(const std::vector<KernelObjectType> &inputs_kernel_object_type); 111 112 void SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_elements_kernel_object_type); 113 114 const std::vector<std::string> &GetAllOutputReshapeType() const; 115 116 const std::vector<std::string> &GetAllInputReshapeType() const; 117 core_type()118 std::string core_type() const { return core_type_; } 119 120 void SetOutputFormat(const std::string &format, size_t index); 121 122 void SetInputFormat(const std::string &format, size_t index); 123 124 void SetOutputDeviceType(const TypeId &output_device_type, size_t index); 125 126 void SetInputsFormat(const std::vector<std::string> &inputs_format); 127 128 void SetOutputsFormat(const std::vector<std::string> &outputs_format); 129 130 void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type); 131 132 void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type); 133 134 void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type); 135 op_pattern()136 OpPattern op_pattern() const { return op_pattern_; } 137 output_data_desc()138 std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; } 139 fusion_type()140 std::string fusion_type() const { return fusion_type_; } 141 valid()142 bool valid() const { return valid_; } set_valid(bool valid)143 void set_valid(bool valid) { valid_ = valid; } 144 processor()145 Processor processor() const { return processor_; } set_processor(Processor processor)146 void set_processor(Processor processor) { processor_ = processor; } 147 148 size_t GetInputNum() const; 149 150 size_t GetOutputNum() const; 151 152 size_t GetOutputNumWithoutMonad() const; 153 154 std::string ToString() const; 155 156 bool IsSimilarityKernelBuildInfo(const KernelBuildInfo &other) const; 157 158 bool operator==(const KernelBuildInfo &other) const; 159 160 bool operator!=(const KernelBuildInfo &other) const; 161 162 static auto constexpr kInvalidFormat = "InvalidFormat"; 163 164 private: 165 KernelType kernel_type_{UNKNOWN_KERNEL_TYPE}; 166 OpType op_type_{UNKNOWN_OP_TYPE}; 167 std::string origin_data_format_{kOpFormat_DEFAULT}; 168 std::string core_type_; 169 std::vector<std::string> inputs_format_; 170 OpPattern op_pattern_{kCommonPattern}; 171 std::vector<std::string> outputs_format_; 172 std::vector<std::string> input_reshape_type_; 173 std::vector<std::string> output_reshape_type_; 174 std::vector<TypeId> inputs_device_type_; 175 std::vector<TypeId> outputs_device_type_; 176 std::vector<KernelObjectType> inputs_kernel_object_type_; 177 std::vector<KernelObjectType> outputs_kernel_object_type_; 178 // Indicates kernel object types of elements in TupleUnfold. 179 // Only valid when output kernel object type is TupleUnfold. 180 std::vector<KernelObjectType> output_elements_kernel_object_type_; 181 std::vector<nlohmann::json> output_data_desc_; 182 std::string fusion_type_{kPatternOpaque}; 183 Processor processor_{UNKNOWN}; 184 // Indicates whether buildinfo is valid, the invalid buildinfo needs to select kernel again. 185 bool valid_{true}; 186 }; 187 using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>; 188 189 class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder { 190 public: KernelBuildInfoBuilder()191 KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); } 192 KernelBuildInfoBuilder(const KernelBuildInfoPtr & kernel_build_info)193 explicit KernelBuildInfoBuilder(const KernelBuildInfoPtr &kernel_build_info) 194 : kernel_build_info_(std::make_shared<KernelBuildInfo>()) { 195 SetKernelType(kernel_build_info->kernel_type()); 196 SetOpType(kernel_build_info->op_type()); 197 SetFusionType(kernel_build_info->fusion_type()); 198 SetProcessor(kernel_build_info->processor()); 199 SetOpPattern(kernel_build_info->op_pattern()); 200 SetCoreType(kernel_build_info->core_type()); 201 SetOutputDataDesc(kernel_build_info->output_data_desc()); 202 for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) { 203 (void)kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index)); 204 (void)kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index)); 205 (void)kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index)); 206 } 207 kernel_build_info_->inputs_kernel_object_type_ = kernel_build_info->GetAllInputKernelObjectTypes(); 208 209 for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) { 210 (void)kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index)); 211 (void)kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index)); 212 (void)kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index)); 213 } 214 kernel_build_info_->outputs_kernel_object_type_ = kernel_build_info->GetAllOutputKernelObjectTypes(); 215 SetValid(kernel_build_info->valid()); 216 } 217 218 ~KernelBuildInfoBuilder() = default; 219 220 void SetKernelType(const KernelType &kernel_type); 221 222 void SetOpType(const OpType &op_type); 223 224 void SetOriginDataFormat(const std::string &origin_data_format); 225 226 void SetInputsFormat(const std::vector<std::string> &inputs_format); 227 228 void SetOutputsFormat(const std::vector<std::string> &outputs_format); 229 230 void SetInputsDeviceType(const std::vector<TypeId> &inputs_device_type); 231 232 void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type); 233 234 void SetInputsReshapeType(const std::vector<std::string> &input_reshape_type); 235 236 void SetOutputsReshapeType(const std::vector<std::string> &output_reshape_type); 237 238 void SetInputsKernelObjectType(const std::vector<KernelObjectType> &input_kernel_object_type); 239 240 void SetOutputsKernelObjectType(const std::vector<KernelObjectType> &output_kernel_object_type); 241 242 void SetOutputElementsKernelObjectType(const std::vector<KernelObjectType> &output_elements_kernel_object_type); 243 244 void SetCoreType(const std::string &core_type); 245 246 void SetFusionType(const std::string &fusion_type); 247 // save prebuild result 248 void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc); 249 250 void SetProcessor(Processor processor); 251 252 void SetOpPattern(OpPattern pattern); 253 254 void SetInputFormat(const std::string &format, size_t index); 255 256 void SetOutputFormat(const std::string &format, size_t index); 257 258 void SetInputReshapeType(const std::string &input_reshape_type, size_t index); 259 260 void SetOutputReshapeType(const std::string &output_reshape_type, size_t index); 261 262 void SetInputDeviceType(const TypeId &input_device_type, size_t index); 263 264 void SetOutputDeviceType(const TypeId &output_device_type, size_t index); 265 266 void SetValid(bool valid); 267 268 std::shared_ptr<KernelBuildInfo> Build(); 269 270 private: 271 std::shared_ptr<KernelBuildInfo> kernel_build_info_; 272 }; 273 } // namespace kernel 274 } // namespace mindspore 275 #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_BUILD_INFO_H_ 276