/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_
#include <cstddef>
#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include <algorithm>
#include "abstract/dshape.h"
#include "abstract/ops/primitive_infer_map.h"
#include "include/api/format.h"
#include "include/backend/visible.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/convert_utils.h"
#include "include/backend/device_synchronizer.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/tensor.h"
#include "ir/kernel_tensor_value.h"
#include "mindspore/core/ops/base_operator.h"
#include "nlohmann/json.hpp"
#include "utils/log_adapter.h"
#include "ops/op_name.h"
#include "kernel/format_utils.h"

#ifdef _MSC_VER
#undef OPAQUE
#endif

#ifdef OPAQUE
#undef OPAQUE
#endif

namespace mindspore {
enum KernelType : int {
  UNKNOWN_KERNEL_TYPE = 0,
  AKG_KERNEL,
  AICPU_KERNEL,
  RT_KERNEL,
  HCCL_KERNEL,
  TBE_KERNEL,
  HOST_KERNEL,
  CPU_KERNEL,
  GPU_KERNEL,
  BISHENG_KERNEL,
  ACL_KERNEL,
  OPAPI_KERNEL,
  INTERNAL_KERNEL,
};

// PointerRefCount encapsulates a pointer and its reference-count related operations, and supports a custom deleter
// to free the resource. In Ref scenarios, the KernelTensor objects of different DeviceAddress may hold the same
// PointerRefCount object.
class PointerRefCount {
 public:
  // The deleter arguments are the pointer and a bool that indicates whether the pointer comes from the memory pool.
  using Deleter = std::function<void(void *, bool)>;

  PointerRefCount() = default;
  explicit PointerRefCount(void *ptr) : ptr_(ptr) {}
  PointerRefCount(void *ptr, const Deleter &deleter) : ptr_(ptr), deleter_(deleter) {}

  PointerRefCount(const PointerRefCount &other)
      : ptr_(other.ptr_),
        original_ref_count_(other.original_ref_count_),
        ref_count_(other.ref_count_.load()),
        dynamic_ref_count_(other.dynamic_ref_count_.load()),
        deleter_(other.deleter_) {}

  ~PointerRefCount() {
    try {
      if (ptr_ != nullptr && deleter_) {
        deleter_(ptr_, from_mem_pool_);
      }
      ptr_ = nullptr;
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "PointerRefCount destruction failed: " << e.what();
    } catch (...) {
      MS_LOG(ERROR) << "PointerRefCount destruction failed.";
    }
  }

  // Get the raw pointer.
  void *ptr() const { return ptr_; }
  // Set the raw pointer.
  void set_ptr(void *ptr) { ptr_ = ptr; }
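
  // A minimal usage sketch (hypothetical caller code, not part of this header): the custom deleter is invoked from
  // ~PointerRefCount() with (ptr_, from_mem_pool_), so a malloc'ed buffer can be released like this:
  //
  //   auto prc = std::make_shared<PointerRefCount>(std::malloc(1024), [](void *p, bool from_mem_pool) {
  //     if (!from_mem_pool) {
  //       std::free(p);
  //     }
  //   });
  //   prc->set_from_mem_pool(false);  // the destructor then frees the buffer through the deleter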

  // Get whether the pointer in PointerRefCount is allocated from the memory pool.
  bool from_mem_pool() const { return from_mem_pool_; }
  // Set whether the pointer in PointerRefCount is allocated from the memory pool.
  void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }

  // Increase the ref count or the dynamic ref count.
  size_t IncreaseCounter() {
    if (ref_count_ != SIZE_MAX) {
      return ++ref_count_;
    } else if (dynamic_ref_count_ != INT32_MAX) {
      return ++dynamic_ref_count_;
    }
    return SIZE_MAX;
  }
  // Decrease the ref count or the dynamic ref count.
  size_t DecreaseCounter() {
    if (ref_count_ != SIZE_MAX) {
      return --ref_count_;
    } else if (dynamic_ref_count_ != INT32_MAX) {
      return --dynamic_ref_count_;
    }
    return SIZE_MAX;
  }

  // The interfaces for static reference count operations.
  void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
  size_t original_ref_count() const { return original_ref_count_; }
  void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
  size_t ref_count() const { return ref_count_.load(); }
  void IncreaseOriginalRefCount() {
    if (original_ref_count_ < SIZE_MAX) {
      original_ref_count_++;
    }
  }
  void DecreaseOriginalRefCount() {
    if ((original_ref_count_ < SIZE_MAX) && (original_ref_count_ > 0)) {
      original_ref_count_--;
    }
  }
  size_t DecreaseRefCount() { return --ref_count_; }
  void ResetRefCount() { ref_count_ = original_ref_count_; }

  // The interfaces for dynamic reference count operations.
  void set_dynamic_ref_count(int32_t dynamic_ref_count) { dynamic_ref_count_ = dynamic_ref_count; }
  int32_t dynamic_ref_count() const { return dynamic_ref_count_; }
  void IncreaseDynamicRefCount(const std::string &op_object) {
    if (dynamic_ref_count_ < INT32_MAX) {
      (void)++dynamic_ref_count_;
      MS_LOG(DEBUG) << op_object << " increases dynamic ref count to:" << dynamic_ref_count_ << " for ptr:" << ptr();
    }
  }
  int32_t DecreaseDynamicRefCount(const std::string &op_object) {
    if (dynamic_ref_count_ <= 0) {
      MS_LOG(EXCEPTION) << "The dynamic reference count has an invalid value:" << dynamic_ref_count_;
    }
    int32_t ref_count = --dynamic_ref_count_;
    MS_LOG(DEBUG) << op_object << " decreases dynamic ref count to:" << ref_count << " for ptr:" << ptr();
    return ref_count;
  }

  // Get the pointer resource destructor.
  Deleter deleter() const { return deleter_; }

  // Set the pointer resource destructor.
  void set_deleter(const Deleter &deleter) { deleter_ = deleter; }
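
  // A usage sketch (hypothetical runtime code, not part of this header): the static counter is planned at graph
  // compile phase, then consumed and reset while running:
  //
  //   prc->set_original_ref_count(2);      // two downstream consumers, known at compile phase
  //   prc->ResetRefCount();
  //   ...
  //   if (prc->DecreaseRefCount() == 0) {  // called once per consumer; at zero the memory can be recycled
  //     prc->ResetRefCount();
  //   }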

 private:
  void *ptr_{nullptr};

  // Whether ptr_ is allocated from the memory pool.
  bool from_mem_pool_{false};

  // The static reference count; its value can be calculated at compile phase.
  size_t original_ref_count_{1};
  // The current reference count; it is decreased while running and reset to original_ref_count_ when it reaches zero.
  std::atomic<size_t> ref_count_{1};

  // The dynamic reference count; its value is calculated while running.
  std::atomic_int32_t dynamic_ref_count_{INT32_MAX};

  // The pointer resource destructor.
  Deleter deleter_;
};
using PointerRefCountPtr = std::shared_ptr<PointerRefCount>;

namespace kernel {

// Backend processor
enum Processor {
  UNKNOWN = -1,
  AICORE = 0,
  AICPU,
  CUDA,
  CPU,
  BISHENG,
};

struct AtomicInitInfo {
  std::vector<std::string> dtype_list;
  std::vector<int64_t> init_value_int64_list;
  std::vector<float> init_value_float_list;
};

/**
 * @brief Base class for autotensor kernels and cce kernels.
 */
struct Address {
  Address() : addr(nullptr), size(0) {}
  Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
  void *addr;
  size_t size;
};
using AddressPtr = std::shared_ptr<Address>;
using AddressPtrList = std::vector<AddressPtr>;
using StreamType = void *;
using abstract::AbstractBase;
using device::DeviceSynchronizerPtr;
// The memory info of kernel launch.
struct KernelLaunchAddr {
  AddressPtrList inputs_;
  AddressPtrList outputs_;
  AddressPtrList workspaces_;
};
struct TensorInfo {
  mindspore::Format format;
  abstract::AbstractTensorPtr base_;
};
struct ScalarInfo {
  abstract::AbstractScalarPtr base_;
};
struct ListInfo {
  abstract::AbstractListPtr base_;
};
struct TupleInfo {
  abstract::AbstractTuplePtr base_;
};
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
using BaseOperatorPtr = std::shared_ptr<ops::BaseOperator>;

class KernelAttr;

// Used to encapsulate the host-side related data structures in KernelTensor.
struct KernelHostInfo {
  KernelHostInfo() = default;

  KernelHostInfo(const KernelHostInfo &other);

  // The shape vector transformed according to `shape_vector_` and `format_`, generally used on the operator side.
  // Operators on different platforms may require different format and shape information.
  ShapeVector shape_vector_after_format_trasform_{};

  // Makes the shape transform related interfaces thread-safe.
  std::mutex shape_transform_mutex_;

  // The object enum type id of the KernelTensor.
  TypeId type_id_{kTypeUnknown};

  // Saves the contents after the value is converted to contiguous memory storage.
  KernelTensorValuePtr kernel_tensor_value_{nullptr};

  // Makes the GetValue related interfaces thread-safe.
  std::mutex value_mutex_;
};

// A template class used to detect whether a type is a valid container.
template <typename T>
struct ValidContainerChecker : std::false_type {};

// A specialization of ValidContainerChecker to detect whether the type is a std::vector whose elements are scalars.
template <typename... Args>
struct ValidContainerChecker<std::vector<Args...>> : std::true_type {};

// A specialization of ValidContainerChecker to detect whether the type is std::string.
template <>
struct ValidContainerChecker<std::string> : std::true_type {};
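
// For example, these traits evaluate as follows (illustrative compile-time checks):
//   static_assert(ValidContainerChecker<std::vector<int64_t>>::value);
//   static_assert(ValidContainerChecker<std::string>::value);
//   static_assert(!ValidContainerChecker<int>::value);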

// A wrapper used to check the types std::string and std::vector.
template <typename T>
struct IsValidContainer {
  static constexpr bool value = ValidContainerChecker<std::decay_t<T>>::value;
};

struct AddressCommon {
  AddressCommon() { pointer_ref_count_ = std::make_shared<PointerRefCount>(); }
  AddressCommon(void *device_ptr, size_t size)
      : pointer_ref_count_(std::make_shared<PointerRefCount>(device_ptr)), size_(size) {}
  AddressCommon(void *device_ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId dtype_id,
                const std::string &device_name, uint32_t device_id, uint32_t stream_id = 0)
      : pointer_ref_count_(std::make_shared<PointerRefCount>(device_ptr)),
        stream_id_(stream_id),
        size_(size),
        format_(format),
        dtype_id_(dtype_id),
        device_name_(device_name),
        device_id_(device_id),
        shape_vector_(shape_vector) {}
  AddressCommon(const AddressCommon &other) {
    pointer_ref_count_ =
      other.pointer_ref_count_ != nullptr
        ? std::make_shared<PointerRefCount>(other.pointer_ref_count_->ptr(), other.pointer_ref_count_->deleter())
        : std::make_shared<PointerRefCount>();
    tensor_storage_info_ = other.tensor_storage_info_;
    stream_id_ = other.stream_id_;
    size_ = other.size_;
    format_ = other.format_;
    dtype_id_ = other.dtype_id_;
    device_id_ = other.device_id_;
    device_name_ = other.device_name_;
    shape_vector_ = other.shape_vector_;
    managed_by_somas_ = other.managed_by_somas_;
  }
  PointerRefCountPtr pointer_ref_count_;
  TensorStorageInfoPtr tensor_storage_info_{nullptr};
  uint32_t stream_id_{0};
  size_t size_{0};
  Format format_{Format::DEFAULT_FORMAT};
  // The data enum type id of the KernelTensor.
  TypeId dtype_id_{kTypeUnknown};
  // The device target name, such as "GPU", "Ascend".
  std::string device_name_;
  // The device card id associated with the KernelTensor.
  uint32_t device_id_{0};
  // The origin flattened shape vector for Tensor/Scalar/Tuple/List:
  // 1. For a Tensor, it is the Tensor's shape. For example, for a Tensor with shape (8, 16), shape_vector_ is
  //    {8, 16}.
  // 2. For a Scalar, shape_vector_ is an empty ShapeVector, i.e. {}.
  // 3. For a Tuple/List (all elements must be Tensors with the same shape, or Scalars), shape_vector_ consists of
  //    the element number followed by the element shape. For example, for a Tuple ((8, 16), (8, 16)) that contains
  //    two Tensors of shape (8, 16), shape_vector_ is {2, 8, 16}, where 2 is the number of elements in the
  //    Tuple/List. For a Tuple such as ((), ()) that contains two Scalars, shape_vector_ is {2}.
  ShapeVector shape_vector_{};
  bool managed_by_somas_{false};
};
using AddressCommonPtr = std::shared_ptr<AddressCommon>;
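
// A construction sketch (hypothetical values, not part of this header): an AddressCommon describing a float32
// tensor of shape (8, 16) in default format on Ascend card 0, default stream:
//
//   AddressCommon addr(device_ptr, 8 * 16 * sizeof(float), ShapeVector{8, 16}, Format::DEFAULT_FORMAT,
//                      kNumberTypeFloat32, "Ascend", 0);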

// KernelTensor is used to express the input and output parameters of kernels.
// KernelTensor has generalized Tensor semantics: it can represent not only a Tensor, but also the meta-information
// of Scalar, Tuple, List and other data structures. It saves the shape, type, value and format information required
// by operator Infer and Launch, and provides the related Get/Set interfaces.
class BACKEND_EXPORT KernelTensor : public AbstractBase {
 public:
  using Deleter = PointerRefCount::Deleter;

  KernelTensor();
  ~KernelTensor() = default;
  explicit KernelTensor(const AddressCommonPtr &address_common) : address_common_(address_common) {}

  // Constructor of KernelTensor by shape, type and value.
  KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value);

  // Constructor of KernelTensor by device info.
  KernelTensor(void *device_ptr, size_t size, Format format, TypeId dtype_id, const ShapeVector &host_shape,
               const string &device_name, uint32_t device_id, const UserDataPtr &user_data = nullptr);

  // Constructor of KernelTensor by shape, type, value and device info.
  KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value, void *device_ptr,
               size_t size, const std::string &format, TypeId dtype_id, const ShapeVector &host_shape,
               const string &device_name, uint32_t device_id, const UserDataPtr &user_data = nullptr);

  // Constructor of KernelTensor by address common, shape, type, value and host shape.
  KernelTensor(const AddressCommonPtr &address_common, const abstract::BaseShapePtr &shape, const TypePtr &type,
               const ValuePtr &value, const ShapeVector &host_shape, const UserDataPtr &user_data = nullptr);

  KernelTensor(const KernelTensor &other);

  MS_DECLARE_PARENT(KernelTensor, AbstractBase);

  // Get the base shape for Tensor/Sequence/Scalar.
  abstract::BaseShapePtr GetShape() const override { return shape_; }

  // Set the base shape for Tensor/Sequence/Scalar.
  // Note: for performance, `SetShape` uses type_id_, so SetType must be called first.
  void SetShape(const abstract::BaseShapePtr &shape);

  // Get the shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetShapeVector() const { return address_common_->shape_vector_; }

  // Set the shape vector for Tensor/Sequence/Scalar.
  void SetShapeVector(const ShapeVector &shape_vector);

  // Set the shape vector for Tensor/Sequence/Scalar with an rvalue.
  void SetShapeVector(ShapeVector &&shape_vector);

  // Get the device shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetDeviceShapeVector() const;

  // Get the host shape of the KernelTensor.
  const ShapeVector &host_shape() const { return host_shape_; }

  // Set the host shape of the KernelTensor.
  void set_host_shape(const ShapeVector &host_shape) { host_shape_ = host_shape; }

  // Get the object type of the KernelTensor.
  TypePtr GetType() const override { return type_; }

  // Set the type for the KernelTensor.
  void SetType(const TypePtr &type);

  // Check whether the host info exists.
  bool host_info_exist() const { return host_info_ != nullptr; }

  // Set the host info after construction.
  void SetHostInfo(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value);

  // Get the object enum type id of the KernelTensor.
  TypeId type_id() const {
    MS_EXCEPTION_IF_NULL(host_info_);
    return host_info_->type_id_;
  }

  // Get the data enum type id of the KernelTensor.
  TypeId dtype_id() const { return address_common_->dtype_id_; }

  // Set the data enum type id of the KernelTensor.
  void set_dtype_id(TypeId dtype_id) { address_common_->dtype_id_ = dtype_id; }
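
  // For example (sketch, hypothetical caller code): refreshing the meta info of a KernelTensor. The type must be
  // set before the shape because SetShape relies on type_id_:
  //
  //   kernel_tensor->SetType(std::make_shared<TensorType>(kFloat32));
  //   kernel_tensor->SetShape(std::make_shared<abstract::Shape>(ShapeVector{8, 16}));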

  // Set the value for the KernelTensor.
  void SetValue(const ValuePtr &value) { value_ = value; }

  // Get the value of the KernelTensor.
  ValuePtr GetValue() const override;

  // Get the address of the value converted to contiguous memory storage.
  const void *GetValuePtr();

  // Get the value in the KernelTensor; return it if there is a specific value, otherwise throw an exception.
  template <typename T>
  T GetValueWithCheck() {
    auto value_opt = GetValue<T>();
    if (!value_opt.has_value()) {
      MS_LOG(EXCEPTION)
        << "Get value failed, there is no value in the KernelTensor. Here are the possible reasons: "
           "1. When the operator KernelMod was registered, the data type was not correct; for example, a Scalar or "
           "Tuple was registered as a Tensor. "
           "2. If the KernelMod is registered correctly, this may be an attempt to GetValue the output of the "
           "previous operator. During compilation, the output of an operator has no value. You can check the ir "
           "file to see whether the value input of the current operator comes from an operator.";
    }
    return value_opt.value();
  }

  // Get the scalar value stored in the KernelTensor if it exists.
  // Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    MS_EXCEPTION_IF_NULL(host_info_);
    std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

    // There may be an origin value in the KernelTensor (e.g. one that comes from a ValueNode).
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid scalar value.";
      return std::nullopt;
    } else if (value_ && !value_->isa<ValueAny>()) {
      if (host_info_->kernel_tensor_value_ == nullptr) {
        host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      }
    } else {
      // Sync the value data from device.
      if (!SyncDataFromDeviceToHost()) {
        MS_LOG(ERROR) << "Sync data from device to host side failed";
        return std::nullopt;
      }
    }

    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    MS_EXCEPTION_IF_CHECK_FAIL((host_info_->kernel_tensor_value_->GetDataSize() == sizeof(T)),
                               "The data size in the kernel tensor value which contains a scalar [" +
                                 std::to_string(host_info_->kernel_tensor_value_->GetDataSize()) +
                                 "] is not equal to the data type size [" + std::to_string(sizeof(T)) + "]");

    const T *data_ptr = reinterpret_cast<const T *>(host_info_->kernel_tensor_value_->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);
    return *data_ptr;
  }
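
  // For example (sketch, hypothetical kernel code): reading a scalar 'axis' input and a vector 'shape' input during
  // Resize, where the input slots kIndex1/kIndex2 are assumed for illustration:
  //
  //   auto axis = inputs[kIndex1]->GetValueWithCheck<int64_t>();
  //   auto shape = inputs[kIndex2]->GetValueWithCheck<std::vector<int64_t>>();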

  // Get the std::vector/std::string value stored in the KernelTensor if it exists.
  // Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<IsValidContainer<T>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    if (!std::is_scalar_v<typename T::value_type>) {
      MS_LOG(EXCEPTION) << "The element type of the std::vector used to get the kernel tensor's value should be a "
                           "scalar type.";
    }
    MS_EXCEPTION_IF_NULL(host_info_);
    std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

    // There may be an origin value in the KernelTensor (e.g. one that comes from a ValueNode).
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid value for vector or string.";
      return std::nullopt;
    } else if (value_ && !value_->isa<ValueAny>()) {
      if (host_info_->kernel_tensor_value_ == nullptr) {
        host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      }
    } else {
      // Sync the value data from device.
      if (!SyncDataFromDeviceToHost()) {
        MS_LOG(ERROR) << "Sync data from device to host side failed";
        return std::nullopt;
      }
    }

    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    size_t element_num = host_info_->kernel_tensor_value_->GetDataSize() / sizeof(typename T::value_type);
    if (element_num == 0) {
      return T();
    }
    const typename T::value_type *data_ptr =
      reinterpret_cast<const typename T::value_type *>(host_info_->kernel_tensor_value_->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);

    return T(data_ptr, data_ptr + element_num);
  }

  // Get the value stored in the KernelTensor for a type which is not scalar, std::vector or std::string, if it
  // exists. Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<!IsValidContainer<T>::value && !std::is_pointer_v<T> &&
                                                !std::is_scalar<std::decay_t<T>>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid value.";
      return std::nullopt;
    }
    if (value_ && !value_->isa<ValueAny>()) {
      return mindspore::GetValue<T>(value_);
    }
    return std::nullopt;
  }

  // Get the value in the KernelTensor like GetValueWithCheck, but return std::nullopt if the value is None
  // (for optional inputs).
  template <typename T>
  std::optional<T> GetOptionalValueWithCheck() {
    if (value_ && value_->isa<None>()) {
      return std::nullopt;
    }
    return GetValueWithCheck<T>();
  }

  // Get the data format.
  mindspore::Format format() const { return address_common_->format_; }

  // Set the data format.
  void set_format(mindspore::Format format) { address_common_->format_ = format; }

  // Get the data format as a string.
  std::string GetStringFormat() const;

  // Set the data format from a string.
  void SetStringFormat(const std::string &format);

  // Get the pointer and reference count.
  const PointerRefCountPtr &pointer_ref_count() const { return address_common_->pointer_ref_count_; }

  // Set the pointer and reference count.
  void set_pointer_ref_count(const PointerRefCountPtr &ptr_ref_cnt) {
    address_common_->pointer_ref_count_ = ptr_ref_cnt;
  }

  // Set the pointer and reference count to nullptr; the resource of the device pointer is released automatically.
  void ReleaseDeviceRes() { address_common_->pointer_ref_count_ = nullptr; }

  // Set the pointer resource destructor.
  void set_deleter(const Deleter &deleter) { address_common_->pointer_ref_count_->set_deleter(deleter); }

  // Get the pointer to the device side that corresponds to the KernelTensor, used in runtime.
  void *device_ptr() const { return address_common_->pointer_ref_count_->ptr(); }

  // Set the pointer to the device side that corresponds to the KernelTensor, used in runtime.
  void set_device_ptr(void *ptr) { address_common_->pointer_ref_count_->set_ptr(ptr); }
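
  // For example (sketch, hypothetical runtime code, 'allocator' is assumed for illustration): handing an externally
  // allocated buffer to a KernelTensor together with the routine that releases it:
  //
  //   kernel_tensor->set_device_ptr(device_buffer);
  //   kernel_tensor->set_deleter([allocator](void *ptr, bool from_mem_pool) { allocator->Free(ptr); });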

  // Get the memory size in bytes of the KernelTensor.
  size_t size() const { return address_common_->size_; }

  // Set the memory size in bytes of the KernelTensor.
  void set_size(size_t size) { address_common_->size_ = size; }

  // Get the device target name, such as "GPU", "Ascend".
  const std::string &device_name() const { return address_common_->device_name_; }

  // Set the device target name, such as "GPU", "Ascend".
  void set_device_name(const std::string &device_name) { address_common_->device_name_ = device_name; }

  // Get the device id.
  uint32_t device_id() const { return address_common_->device_id_; }

  // Set the device id.
  void set_device_id(uint32_t device_id) { address_common_->device_id_ = device_id; }

  // Get the logical stream id.
  uint32_t stream_id() const { return address_common_->stream_id_; }

  // Set the logical stream id.
  void set_stream_id(uint32_t stream_id) { address_common_->stream_id_ = stream_id; }

  // Get the task id on the stream.
  std::shared_ptr<int64_t> task_id_on_stream() const { return task_id_on_stream_; }

  // Set the task id on the stream.
  void set_task_id_on_stream(const std::shared_ptr<int64_t> &task_id_on_stream) {
    task_id_on_stream_ = task_id_on_stream;
  }

  bool managed_by_somas() const { return address_common_->managed_by_somas_; }

  void set_managed_by_somas(bool managed_by_somas) { address_common_->managed_by_somas_ = managed_by_somas; }

  // Get the user data maintained by the KernelTensor.
  const UserDataPtr &user_data() const { return user_data_; }

  // Set the user data to the KernelTensor.
  void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }

  // Set the device synchronizer to the KernelTensor.
  void set_device_synchronizer(const DeviceSynchronizerPtr &device_synchronizer) {
    device_synchronizer_ = device_synchronizer;
  }

  // Clone a new KernelTensor from this one.
  std::shared_ptr<KernelTensor> CloneKernelTensor() { return std::make_shared<KernelTensor>(*this); }

  // Check whether the shape is a dynamic shape (contains a dim which is less than 0).
  bool IsDynamicShape() const;

  // Check whether the KernelTensor comes from a constant variable (such as a ValueNode).
  inline bool IsConstValue() const { return (value_ != nullptr) && !(value_->isa<ValueAny>()); }

  // The following four methods are only used in the Lite framework.
  // Get the device data address (pointer and size).
  AddressPtr GetData() const { return data_; }
  // Set the device data address (pointer and size).
  void SetData(const AddressPtr &data) { data_ = data; }
  // Get the host data address (pointer and size).
  AddressPtr GetHostData() const { return host_data_; }
  // Set the host data address (pointer and size).
  void SetHostData(const AddressPtr &data) { host_data_ = data; }

  // The max shape is only used in compute-depended ops.
  ShapeVector GetMaxShape() const;

  const TensorStorageInfoPtr tensor_storage_info() const { return address_common_->tensor_storage_info_; }
  void set_tensor_storage_info(const TensorStorageInfoPtr &storage_info) {
    address_common_->tensor_storage_info_ = storage_info;
  }

  const AddressCommonPtr address_common() const { return address_common_; }
  void set_address_common(const AddressCommonPtr &address_common) { address_common_ = address_common; }
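
  // A construction sketch (hypothetical values, not part of this header): a host-side KernelTensor that carries a
  // known int64 scalar value, e.g. for a constant operator input:
  //
  //   auto kt = std::make_shared<KernelTensor>(abstract::kNoShape, kInt64, MakeValue(static_cast<int64_t>(5)));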

 private:
  // This is a deprecated function in the base class.
  BaseShapePtr BuildShape() const override {
    MS_LOG(EXCEPTION) << "Call deprecated function: BuildShape. Please use GetShape instead of BuildShape in "
                         "operators' infer functions in the `core/ops` directory.";
  }

  // This is a deprecated function in the base class.
  TypePtr BuildType() const override {
    MS_LOG(EXCEPTION) << "Call deprecated function: BuildType. Please use GetType instead of BuildType in "
                         "operators' infer functions in the `core/ops` directory.";
  }

  // Set the element data type to the KernelTensor for the Sequence type (Tuple or List).
  void SetSequenceDType(const TypePtr &element_type);

  // Synchronize the value data from device to host side.
  bool SyncDataFromDeviceToHost() const;

  // Calculate the memory size needed by the KernelTensor.
  void CalculateMemSize();

  // Check whether the host infer shape needs to be transposed to the device shape.
  bool NeedTransposeToDeviceShape() const noexcept;

  // Transpose the host infer shape to the device shape according to the format.
  const ShapeVector &TransposeToDeviceShape() const;

  // If the host info was not initialized in the constructor, initialize it when it is needed, making sure that the
  // host info is not empty when used.
  void CheckHostInfoValid();

  // The host-side related data in the KernelTensor.
  // Note: to improve the performance of constructing a KernelTensor, some constructors are allowed not to initialize
  // the host info. If the host info is not initialized in the constructor, it can be initialized when it is needed.
  std::unique_ptr<KernelHostInfo> host_info_{nullptr};

  // The launch index on the stream managed by the framework.
  std::shared_ptr<int64_t> task_id_on_stream_{nullptr};

  // The flattened shape (maybe after padding) vector.
  // Note: 'host_shape_' will be replaced by 'shape_vector_' in the future.
  ShapeVector host_shape_{};

  // User data is the extra data required by the kernel or the framework.
  UserDataPtr user_data_{nullptr};

  // For synchronizing data between device and host.
  DeviceSynchronizerPtr device_synchronizer_{nullptr};

  // The following two variables are only used in the Lite framework.
  // Device data address.
  AddressPtr data_{nullptr};
  // Host data address.
  AddressPtr host_data_{nullptr};

  // Address basic info.
  AddressCommonPtr address_common_{nullptr};
};
using KernelTensorPtr = std::shared_ptr<KernelTensor>;

enum class KernelModType {
  Invalid = 0,
  KernelMod,
  GpuKernelMod,
  NativeGpuKernelMod,
  CpuKernelMod,
  NativeCpuKernelMod,
  HostKernelMod,
  DynamicAkgCpuKernelMod,
};

// The info of kernel launch.
struct KernelLaunchInfo {
  std::vector<KernelTensor *> inputs_;
  std::vector<KernelTensor *> outputs_;
  std::vector<KernelTensor *> workspaces_;
};

enum KernelErrorCode : int { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };

class BACKEND_EXPORT KernelMod {
 public:
  KernelMod() = default;
  virtual ~KernelMod() = default;

  virtual std::vector<KernelAttr> GetOpSupport() = 0;

  virtual bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
    MS_LOG(EXCEPTION) << "The KernelMod[" << kernel_name_ << "] doesn't implement the virtual function 'Init'";
  }

  inline bool Init(const PrimitivePtr &primitive, const std::vector<KernelTensor *> &inputs,
                   const std::vector<KernelTensor *> &outputs) {
    primitive_ = primitive;
    MS_EXCEPTION_IF_NULL(primitive_);
    kernel_name_ = primitive_->name();

    return Init(inputs, outputs);
  }

  // Resize() validates the input/output shapes and calculates the workspace size; the framework invokes this routine
  // after infer shape.
  virtual int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs);

  virtual bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                      const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    return true;
  }

  // Some kernels, e.g. Unique, can only get their output shape after the computation has finished.
  virtual bool IsNeedUpdateOutputShapeAndSize() { return false; }
  virtual void UpdateOutputShapeAndSize(const std::vector<KernelTensor *> &inputs,
                                        const std::vector<KernelTensor *> &outputs) {}

  // Some kernels, e.g. Shape/Reshape, don't use some input addresses in the kernel launch.
  virtual std::vector<size_t> GetLaunchIgnoredInputAddressIdx() const { return {}; }

  void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
  virtual enum KernelModType GetKernelModType() const { return KernelModType::KernelMod; }

  virtual void SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; }
  virtual void SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
  virtual void SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; }
  const std::vector<size_t> &GetInputSizeList() const { MS_LOG(EXCEPTION) << "Call deprecated interface."; }
  virtual const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
  virtual const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }

  const PrimitivePtr &primitive() const { return primitive_; }
  const std::string &kernel_name() const { return kernel_name_; }

  virtual std::vector<size_t> GenParameters() { return {}; }
  virtual void GenAtomicInitInfo(AtomicInitInfo *info) {}

  virtual void set_unique_name(const std::string &unique_name) {
    MS_LOG(EXCEPTION) << "Call the method which is not implemented";
  }

  virtual void set_fullname(const std::string &fullname) {
    MS_LOG(EXCEPTION) << "Call the method which is not implemented";
  }

  virtual void set_is_monad(bool is_monad) { MS_LOG(EXCEPTION) << "Call the method which is not implemented"; }
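
  // The typical call flow driven by the framework (sketch, hypothetical driver code):
  //
  //   kernel_mod->Init(primitive, inputs, outputs);             // once, after kernel selection
  //   if (kernel_mod->Resize(inputs, outputs) != KRET_OK) {     // after each infer shape
  //     // handle the resize failure
  //   }
  //   kernel_mod->Launch(inputs, workspaces, outputs, stream_ptr);
  //   if (kernel_mod->IsNeedUpdateOutputShapeAndSize()) {
  //     kernel_mod->UpdateOutputShapeAndSize(inputs, outputs);  // e.g. for Unique
  //   }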

  // If an output of the kernel has user_data, this needs to return true, and the framework will create the user_data
  // for it.
  virtual bool need_user_data() const { return false; }

  int32_t task_id() const { return task_id_; }
  bool use_kernel_tensor() const { return use_kernel_tensor_; }
  void set_use_kernel_tensor(bool use_kernel_tensor) { use_kernel_tensor_ = use_kernel_tensor; }

  uint32_t record_stream_id() const { return record_stream_id_; }
  void set_record_stream_id(uint32_t record_stream_id) { record_stream_id_ = record_stream_id; }

  virtual bool Finalize() { return true; }

 protected:
  bool IsValidShape(const ShapeVector &shape) const {
    return !std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; });
  }

 protected:
  std::string kernel_name_;
  PrimitivePtr primitive_;
  uint32_t device_id_ = 0;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  int32_t task_id_ = -1;
  bool use_kernel_tensor_{false};
  uint32_t record_stream_id_{0};
};
using KernelModPtr = std::shared_ptr<KernelMod>;

template <typename T>
inline T *GetDeviceAddress(const std::vector<KernelTensor *> &addr_list, size_t index) {
  if (index >= addr_list.size()) {
    MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
    return nullptr;
  }

  if (addr_list[index] == nullptr) {
    MS_LOG(ERROR) << "The device address is nullptr, address index: " << index << ", and the length of 'addr_list' is "
                  << addr_list.size();
    return nullptr;
  }

  if (addr_list[index]->device_ptr() == nullptr) {
    MS_LOG(WARNING) << "The memory of the device address is nullptr, address index: " << index
                    << ", and the length of 'addr_list' is " << addr_list.size();
    return nullptr;
  }

  // When the input is an empty tuple, the input size will be 0.
  if (addr_list[index]->size() == 0) {
    MS_LOG(INFO) << "The size of the device address is zero, address index: " << index
                 << ", and the length of 'addr_list' is " << addr_list.size();
  }
  return reinterpret_cast<T *>(addr_list[index]->device_ptr());
}

BACKEND_EXPORT std::vector<std::vector<int64_t>> GetShapes(const std::vector<KernelTensor *> &tensors);

BACKEND_EXPORT void ConvertLaunchInfoToAddr(const KernelLaunchInfo &launch_info, KernelLaunchAddr *mem_info);
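
// For example (sketch, hypothetical kernel code): fetching typed device pointers inside a Launch implementation;
// GetDeviceAddress returns nullptr on any invalid entry, so the result should be checked:
//
//   auto *input_ptr = GetDeviceAddress<float>(inputs, 0);
//   auto *output_ptr = GetDeviceAddress<float>(outputs, 0);
//   if (input_ptr == nullptr || output_ptr == nullptr) {
//     return false;
//   }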

template <typename T>
inline bool CheckNullInput(const std::vector<T> &input_shape) {
  // If input_shape.size() == 0, it means a scalar input; if input_shape.size() != 0 and input_shape contains 0,
  // it means a null input. Just return a null output.
  if (input_shape.size() != 0) {
    if (std::any_of(input_shape.begin(), input_shape.end(), [](T i) { return i == 0; })) {
      return true;
    }
  }
  return false;
}
#define CHECK_NULL_INPUT(input_shape) mindspore::kernel::CheckNullInput(input_shape)

template <typename T>
inline bool CheckShapeNull(const std::vector<T> &shape, const std::string &kernel_name,
                           const std::string &param_name) {
  if (CHECK_NULL_INPUT(shape)) {
    MS_LOG(WARNING) << "For '" << kernel_name << "', the shape of " << param_name << " cannot contain zero, but got "
                    << shape;
    return true;
  }
  return false;
}

#define CHECK_SHAPE_NULL(shape, kernel_name, param_name) \
  mindspore::kernel::CheckShapeNull(shape, kernel_name, param_name)
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_