/**
 * Copyright 2019-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_
#include <cstddef>
#include <cstdint>
#include <atomic>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <variant>
#include <vector>
#include <algorithm>
#include "abstract/dshape.h"
#include "abstract/ops/primitive_infer_map.h"
#include "include/api/format.h"
#include "include/backend/visible.h"
#include "include/common/utils/utils.h"
#include "include/common/utils/convert_utils.h"
#include "include/backend/device_synchronizer.h"
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/tensor.h"
#include "ir/kernel_tensor_value.h"
#include "mindspore/core/ops/base_operator.h"
#include "nlohmann/json.hpp"
#include "utils/log_adapter.h"
#include "ops/op_name.h"
#include "kernel/format_utils.h"

#ifdef _MSC_VER
#undef OPAQUE
#endif

#ifdef OPAQUE
#undef OPAQUE
#endif
namespace mindspore {
enum KernelType : int {
  UNKNOWN_KERNEL_TYPE = 0,
  AKG_KERNEL,
  AICPU_KERNEL,
  RT_KERNEL,
  HCCL_KERNEL,
  TBE_KERNEL,
  HOST_KERNEL,
  CPU_KERNEL,
  GPU_KERNEL,
  BISHENG_KERNEL,
  ACL_KERNEL,
  OPAPI_KERNEL,
  INTERNAL_KERNEL,
};

// PointerRefCount encapsulates a pointer and its reference count-related operations, and supports a custom deleter
// to free resources. In Ref scenarios, the KernelTensor of different DeviceAddress objects may hold the same
// PointerRefCount object.
class PointerRefCount {
 public:
  // The arguments are the pointer and a bool that indicates whether the pointer comes from the memory pool.
  using Deleter = std::function<void(void *, bool)>;

  PointerRefCount() = default;
  explicit PointerRefCount(void *ptr) : ptr_(ptr) {}
  PointerRefCount(void *ptr, const Deleter &deleter) : ptr_(ptr), deleter_(deleter) {}

  PointerRefCount(const PointerRefCount &other)
      : ptr_(other.ptr_),
        original_ref_count_(other.original_ref_count_),
        ref_count_(other.ref_count_.load()),
        dynamic_ref_count_(other.dynamic_ref_count_.load()),
        deleter_(other.deleter_) {}

  ~PointerRefCount() {
    try {
      if (ptr_ != nullptr && deleter_) {
        deleter_(ptr_, from_mem_pool_);
      }
      ptr_ = nullptr;
    } catch (const std::exception &e) {
      MS_LOG(ERROR) << "PointerRefCount destruction failed: " << e.what();
    } catch (...) {
      MS_LOG(ERROR) << "PointerRefCount destruction failed.";
    }
  }

  // Get the raw pointer.
  void *ptr() const { return ptr_; }
  // Set the raw pointer.
  void set_ptr(void *ptr) { ptr_ = ptr; }

  // Get whether the pointer in PointerRefCount is allocated from the memory pool.
  bool from_mem_pool() const { return from_mem_pool_; }
  // Set whether the pointer in PointerRefCount is allocated from the memory pool.
  void set_from_mem_pool(bool from_mem_pool) { from_mem_pool_ = from_mem_pool; }

  // Increase the ref count or dynamic ref count.
  size_t IncreaseCounter() {
    if (ref_count_ != SIZE_MAX) {
      return ++ref_count_;
    } else if (dynamic_ref_count_ != INT32_MAX) {
      return ++dynamic_ref_count_;
    }
    return SIZE_MAX;
  }
  // Decrease the ref count or dynamic ref count.
  size_t DecreaseCounter() {
    if (ref_count_ != SIZE_MAX) {
      return --ref_count_;
    } else if (dynamic_ref_count_ != INT32_MAX) {
      return --dynamic_ref_count_;
    }
    return SIZE_MAX;
  }

  // The related interfaces of static reference count operations.
  void set_original_ref_count(size_t original_ref_count) { original_ref_count_ = original_ref_count; }
  size_t original_ref_count() const { return original_ref_count_; }
  void set_ref_count(size_t ref_count) { ref_count_ = ref_count; }
  size_t ref_count() const { return ref_count_.load(); }
  void IncreaseOriginalRefCount() {
    if (original_ref_count_ < SIZE_MAX) {
      original_ref_count_++;
    }
  }
  void DecreaseOriginalRefCount() {
    if ((original_ref_count_ < SIZE_MAX) && (original_ref_count_ > 0)) {
      original_ref_count_--;
    }
  }
  size_t DecreaseRefCount() { return --ref_count_; }
  void ResetRefCount() { ref_count_ = original_ref_count_; }

  // The related interfaces of dynamic reference count operations.
  void set_dynamic_ref_count(int32_t dynamic_ref_count) { dynamic_ref_count_ = dynamic_ref_count; }
  int32_t dynamic_ref_count() const { return dynamic_ref_count_; }
  void IncreaseDynamicRefCount(const std::string &op_object) {
    if (dynamic_ref_count_ < INT32_MAX) {
      (void)++dynamic_ref_count_;
      MS_LOG(DEBUG) << op_object << " increases dynamic ref count to:" << dynamic_ref_count_ << " for ptr:" << ptr();
    }
  }
  int32_t DecreaseDynamicRefCount(const std::string &op_object) {
    if (dynamic_ref_count_ <= 0) {
      MS_LOG(EXCEPTION) << "The dynamic reference count is an invalid value:" << dynamic_ref_count_;
    }
    MS_LOG(DEBUG) << op_object << " decreases dynamic ref count to:" << dynamic_ref_count_ << " for ptr:" << ptr();
    return --dynamic_ref_count_;
  }

  // Get the pointer resource destructor.
  Deleter deleter() const { return deleter_; }

  // Set the pointer resource destructor.
  void set_deleter(const Deleter &deleter) { deleter_ = deleter; }

 private:
  void *ptr_{nullptr};

  // Whether ptr_ is allocated from the memory pool.
  bool from_mem_pool_{false};

  // The static reference count; the value can be calculated at compile phase.
  size_t original_ref_count_{1};
  // The current reference count value; it is decreased during running and reset to original_ref_count_ when it
  // reaches zero.
  std::atomic<size_t> ref_count_{1};

  // The dynamic reference count; the value cannot be calculated at compile phase.
  std::atomic_int32_t dynamic_ref_count_{INT32_MAX};

  // The pointer resource destructor.
  Deleter deleter_;
};
using PointerRefCountPtr = std::shared_ptr<PointerRefCount>;
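
// Usage sketch (illustrative only, not part of the API): a PointerRefCount can carry a custom deleter that the
// destructor invokes exactly once; the plain malloc/free allocator below is an assumption for demonstration.
//
//   void *buf = malloc(1024);
//   auto prc = std::make_shared<PointerRefCount>(buf, [](void *p, bool from_mem_pool) {
//     if (!from_mem_pool) {
//       free(p);  // A pool-backed pointer would be returned to its pool instead.
//     }
//   });
//   prc->set_original_ref_count(2);
//   prc->ResetRefCount();           // ref_count_ restarts from original_ref_count_.
//   (void)prc->DecreaseRefCount();  // 2 -> 1
//   (void)prc->DecreaseRefCount();  // 1 -> 0; the memory itself is freed when prc is destructed.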

namespace kernel {

// Backend processor.
enum Processor {
  UNKNOWN = -1,
  AICORE = 0,
  AICPU,
  CUDA,
  CPU,
  BISHENG,
};

struct AtomicInitInfo {
  std::vector<std::string> dtype_list;
  std::vector<int64_t> init_value_int64_list;
  std::vector<float> init_value_float_list;
};

/**
 * @brief Base class for autotensor kernel and cce kernel.
 */
struct Address {
  Address() : addr(nullptr), size(0) {}
  Address(void *address_addr, size_t address_size) : addr(address_addr), size(address_size) {}
  void *addr;
  size_t size;
};
using AddressPtr = std::shared_ptr<Address>;
using AddressPtrList = std::vector<AddressPtr>;
using StreamType = void *;
using abstract::AbstractBase;
using device::DeviceSynchronizerPtr;
// The memory info of kernel launch.
struct KernelLaunchAddr {
  AddressPtrList inputs_;
  AddressPtrList outputs_;
  AddressPtrList workspaces_;
};
struct TensorInfo {
  mindspore::Format format;
  abstract::AbstractTensorPtr base_;
};
struct ScalarInfo {
  abstract::AbstractScalarPtr base_;
};
struct ListInfo {
  abstract::AbstractListPtr base_;
};
struct TupleInfo {
  abstract::AbstractTuplePtr base_;
};
using TensorInfoPtr = std::shared_ptr<TensorInfo>;
using BaseOperatorPtr = std::shared_ptr<ops::BaseOperator>;

class KernelAttr;

// Used to encapsulate the host-side related data structures in KernelTensor.
struct KernelHostInfo {
  KernelHostInfo() = default;

  KernelHostInfo(const KernelHostInfo &other);

  // The shape vector transformed according to `shape_vector_` and `format_`, generally used on the operator side.
  // Operators on different platforms may require different format and shape information.
  ShapeVector shape_vector_after_format_trasform_{};

  // Make the shape transform related interfaces thread-safe.
  std::mutex shape_transform_mutex_;

  // The object enum type id of the KernelTensor.
  TypeId type_id_{kTypeUnknown};

  // Saves the contents after the value is converted to continuous memory storage.
  KernelTensorValuePtr kernel_tensor_value_{nullptr};

  // Make the GetValue related interfaces thread-safe.
  std::mutex value_mutex_;
};

// A template class used to detect whether a type is a valid container.
template <typename T>
struct ValidContainerChecker : std::false_type {};

// A ValidContainerChecker specialization to detect whether the type is a std::vector whose elements are scalar.
template <typename... Args>
struct ValidContainerChecker<std::vector<Args...>> : std::true_type {};

// A ValidContainerChecker specialization to detect whether the type is std::string.
template <>
struct ValidContainerChecker<std::string> : std::true_type {};

// A wrapper used to check the types std::string and std::vector.
template <typename T>
struct IsValidContainer {
  static constexpr bool value = ValidContainerChecker<std::decay_t<T>>::value;
};

struct AddressCommon {
  AddressCommon() { pointer_ref_count_ = std::make_shared<PointerRefCount>(); }
  AddressCommon(void *device_ptr, size_t size)
      : pointer_ref_count_(std::make_shared<PointerRefCount>(device_ptr)), size_(size) {}
  AddressCommon(void *device_ptr, size_t size, const ShapeVector &shape_vector, const Format &format, TypeId dtype_id,
                const std::string &device_name, uint32_t device_id, uint32_t stream_id = 0)
      : pointer_ref_count_(std::make_shared<PointerRefCount>(device_ptr)),
        stream_id_(stream_id),
        size_(size),
        format_(format),
        dtype_id_(dtype_id),
        device_name_(device_name),
        device_id_(device_id),
        shape_vector_(shape_vector) {}
  AddressCommon(const AddressCommon &other) {
    pointer_ref_count_ =
      other.pointer_ref_count_ != nullptr
        ? std::make_shared<PointerRefCount>(other.pointer_ref_count_->ptr(), other.pointer_ref_count_->deleter())
        : std::make_shared<PointerRefCount>();
    tensor_storage_info_ = other.tensor_storage_info_;
    stream_id_ = other.stream_id_;
    size_ = other.size_;
    format_ = other.format_;
    dtype_id_ = other.dtype_id_;
    device_id_ = other.device_id_;
    device_name_ = other.device_name_;
    shape_vector_ = other.shape_vector_;
    managed_by_somas_ = other.managed_by_somas_;
  }
  PointerRefCountPtr pointer_ref_count_;
  TensorStorageInfoPtr tensor_storage_info_{nullptr};
  uint32_t stream_id_{0};
  size_t size_{0};
  Format format_{Format::DEFAULT_FORMAT};
  // The data enum type id of the KernelTensor.
  TypeId dtype_id_{kTypeUnknown};
  // The device target name, such as "GPU", "Ascend".
  std::string device_name_;
  // Represents the device card id associated with the KernelTensor.
  uint32_t device_id_{0};
  // The origin flatten shape vector for Tensor/Scalar/Tuple/List.
  // 1. For the Tensor type, it means the Tensor's shape. For example, for a Tensor with shape (8, 16), shape_vector_
  // is {8, 16}.
  // 2. For the Scalar type, shape_vector_ is an empty ShapeVector, i.e. {}.
  // 3. For the Tuple/List type (all elements must be Tensors with the same shape, or Scalars), shape_vector_ consists
  // of the element number and the shape of the elements in the Tuple/List. For example, if a Tuple of the structure
  // ((8, 16), (8, 16)) contains two Tensors of shape (8, 16), then shape_vector_ is {2, 8, 16}, where 2 is the number
  // of elements in the Tuple/List. For a Tuple such as ((), ()) that contains two Scalars, shape_vector_ is {2}.
  ShapeVector shape_vector_{};
  bool managed_by_somas_{false};
};
using AddressCommonPtr = std::shared_ptr<AddressCommon>;
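
// Construction sketch (illustrative only; the allocator and device target are assumptions): an AddressCommon
// describing a device buffer that holds a 2x3 float32 Tensor on card 0 of the "Ascend" target.
//
//   void *dev_ptr = /* obtained from the device allocator */ nullptr;
//   auto addr = std::make_shared<AddressCommon>(dev_ptr, 2 * 3 * sizeof(float), ShapeVector{2, 3},
//                                               Format::DEFAULT_FORMAT, kNumberTypeFloat32, "Ascend", 0);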

// KernelTensor is used to express the input and output parameters of kernels.
// KernelTensor provides generalized Tensor semantics: it can represent not only a Tensor, but also the
// meta-information of Scalar, Tuple, List and other data structures. It saves the shape, type, value and format
// information required by operator Infer and Launch, and provides the related Get/Set interfaces.
class BACKEND_EXPORT KernelTensor : public AbstractBase {
 public:
  using Deleter = PointerRefCount::Deleter;

  KernelTensor();
  ~KernelTensor() = default;
  explicit KernelTensor(const AddressCommonPtr &address_common) : address_common_(address_common) {}

  // Constructor of KernelTensor by shape, type and value.
  KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value);

  // Constructor of KernelTensor by device info.
  KernelTensor(void *device_ptr, size_t size, Format format, TypeId dtype_id, const ShapeVector &host_shape,
               const string &device_name, uint32_t device_id, const UserDataPtr &user_data = nullptr);

  // Constructor of KernelTensor by shape, type, value and device info.
  KernelTensor(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value, void *device_ptr,
               size_t size, const std::string &format, TypeId dtype_id, const ShapeVector &host_shape,
               const string &device_name, uint32_t device_id, const UserDataPtr &user_data = nullptr);

  // Constructor of KernelTensor by shape, type, value and device info.
  KernelTensor(const AddressCommonPtr &address_common, const abstract::BaseShapePtr &shape, const TypePtr &type,
               const ValuePtr &value, const ShapeVector &host_shape, const UserDataPtr &user_data = nullptr);

  KernelTensor(const KernelTensor &other);

  MS_DECLARE_PARENT(KernelTensor, AbstractBase);

  // Get the base shape for Tensor/Sequence/Scalar.
  abstract::BaseShapePtr GetShape() const override { return shape_; }

  // Set the base shape for Tensor/Sequence/Scalar.
  // Note: for performance, `SetShape` uses type_id_, so SetType must be called first.
  void SetShape(const abstract::BaseShapePtr &shape);

  // Get the shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetShapeVector() const { return address_common_->shape_vector_; }

  // Set the shape vector for Tensor/Sequence/Scalar.
  void SetShapeVector(const ShapeVector &shape_vector);

  // Set the shape vector for Tensor/Sequence/Scalar with an rvalue.
  void SetShapeVector(ShapeVector &&shape_vector);

  // Get the device shape vector for Tensor/Sequence/Scalar.
  const ShapeVector &GetDeviceShapeVector() const;

  // Get the host shape of the KernelTensor.
  const ShapeVector &host_shape() const { return host_shape_; }

  // Set the host shape of the KernelTensor.
  void set_host_shape(const ShapeVector &host_shape) { host_shape_ = host_shape; }

  // Get the object type of the KernelTensor.
  TypePtr GetType() const override { return type_; }

  // Set the type for the KernelTensor.
  void SetType(const TypePtr &type);

  // Check whether the host info exists.
  bool host_info_exist() const { return host_info_ != nullptr; }

  // Set the host info after construction.
  void SetHostInfo(const abstract::BaseShapePtr &shape, const TypePtr &type, const ValuePtr &value);

  // Get the object enum type id of the KernelTensor.
  TypeId type_id() const {
    MS_EXCEPTION_IF_NULL(host_info_);
    return host_info_->type_id_;
  }

  // Get the data enum type id of the KernelTensor.
  TypeId dtype_id() const { return address_common_->dtype_id_; }

  // Set the data enum type id of the KernelTensor.
  void set_dtype_id(TypeId dtype_id) { address_common_->dtype_id_ = dtype_id; }

  // Set the value for the KernelTensor.
  void SetValue(const ValuePtr &value) { value_ = value; }

  // Get the value of the KernelTensor.
  ValuePtr GetValue() const override;

  // Get the address of the value converted to continuous memory storage.
  const void *GetValuePtr();

  // Get the value in the KernelTensor; return it if there is a specific value, otherwise throw an exception.
  template <typename T>
  T GetValueWithCheck() {
    auto value_opt = GetValue<T>();
    if (!value_opt.has_value()) {
      MS_LOG(EXCEPTION)
        << "Get value failed, there is no value in the KernelTensor. Here are the possible reasons: "
           "1. When the operator KernelMod was registered, the data type was not correct, e.g. the type is Scalar or "
           "Tuple, but it was registered as Tensor. "
           "2. If the KernelMod is registered correctly, this may be an attempt to GetValue on the output of the "
           "previous operator. During compilation, the output of an operator has no value. You can check the ir "
           "file to see whether the current operator's input value comes from an operator.";
    }
    return value_opt.value();
  }

  // Get the scalar value stored in the KernelTensor if it exists.
  // Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<std::is_scalar<std::decay_t<T>>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    MS_EXCEPTION_IF_NULL(host_info_);
    std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

    // There is an origin value in the KernelTensor (it may come from a ValueNode).
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid scalar value.";
      return std::nullopt;
    } else if (value_ && !value_->isa<ValueAny>()) {
      if (host_info_->kernel_tensor_value_ == nullptr) {
        host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      }
    } else {
      // Sync the value data from the device.
      if (!SyncDataFromDeviceToHost()) {
        MS_LOG(ERROR) << "Sync data from device to host side failed";
        return std::nullopt;
      }
    }

    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    MS_EXCEPTION_IF_CHECK_FAIL((host_info_->kernel_tensor_value_->GetDataSize() == sizeof(T)),
                               "The data size in the kernel tensor value which contains a scalar [" +
                                 std::to_string(host_info_->kernel_tensor_value_->GetDataSize()) +
                                 "] is not equal to the data type size [" + std::to_string(sizeof(T)) + "]");

    const T *data_ptr = reinterpret_cast<const T *>(host_info_->kernel_tensor_value_->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);
    return *data_ptr;
  }

  // Get the std::vector/std::string value stored in the KernelTensor if it exists.
  // Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<IsValidContainer<T>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    if (!std::is_scalar_v<typename T::value_type>) {
      MS_LOG(EXCEPTION) << "The elements of the std::vector used to get a kernel tensor's value should be of scalar "
                           "type.";
    }
    MS_EXCEPTION_IF_NULL(host_info_);
    std::lock_guard<std::mutex> lock(host_info_->value_mutex_);

    // There is an origin value in the KernelTensor (it may come from a ValueNode).
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid value for vector or string.";
      return std::nullopt;
    } else if (value_ && !value_->isa<ValueAny>()) {
      if (host_info_->kernel_tensor_value_ == nullptr) {
        host_info_->kernel_tensor_value_ = ConvertValueToKernelTensorValue(value_);
      }
    } else {
      // Sync the value data from the device.
      if (!SyncDataFromDeviceToHost()) {
        MS_LOG(ERROR) << "Sync data from device to host side failed";
        return std::nullopt;
      }
    }

    MS_EXCEPTION_IF_NULL(host_info_->kernel_tensor_value_);
    size_t element_num = host_info_->kernel_tensor_value_->GetDataSize() / sizeof(typename T::value_type);
    if (element_num == 0) {
      return T();
    }
    const typename T::value_type *data_ptr =
      reinterpret_cast<const typename T::value_type *>(host_info_->kernel_tensor_value_->GetDataPtr());
    MS_EXCEPTION_IF_NULL(data_ptr);

    return T(data_ptr, data_ptr + element_num);
  }

  // Get the value stored in the KernelTensor for a type which is not scalar, std::vector or std::string, if it
  // exists. Return an optional containing the value if the KernelTensor has a value, otherwise std::nullopt.
  template <typename T, typename std::enable_if<!IsValidContainer<T>::value && !std::is_pointer_v<T> &&
                                                !std::is_scalar<std::decay_t<T>>::value>::type * = nullptr>
  std::optional<T> GetValue() {
    if (address_common_->dtype_id_ == kMetaTypeNone) {
      MS_LOG(DEBUG) << "None type has no valid value.";
      return std::nullopt;
    }
    if (value_ && !value_->isa<ValueAny>()) {
      return mindspore::GetValue<T>(value_);
    }
    return std::nullopt;
  }

  // Get the value in the KernelTensor; return std::nullopt if the value is None, otherwise behave like
  // GetValueWithCheck and throw an exception when no specific value exists.
  template <typename T>
  std::optional<T> GetOptionalValueWithCheck() {
    if (value_ && value_->isa<None>()) {
      return std::nullopt;
    }
    return GetValueWithCheck<T>();
  }

  // Get the data format.
  mindspore::Format format() const { return address_common_->format_; }

  // Set the data format.
  void set_format(mindspore::Format format) { address_common_->format_ = format; }

  // Get the data format as a string.
  std::string GetStringFormat() const;

  // Set the data format from a string.
  void SetStringFormat(const std::string &format);

  // Get the pointer and reference count.
  const PointerRefCountPtr &pointer_ref_count() const { return address_common_->pointer_ref_count_; }

  // Set the pointer and reference count.
  void set_pointer_ref_count(const PointerRefCountPtr &ptr_ref_cnt) {
    address_common_->pointer_ref_count_ = ptr_ref_cnt;
  }

  // Set the pointer and reference count to nullptr; the resources of the device pointer are then released
  // automatically.
  void ReleaseDeviceRes() { address_common_->pointer_ref_count_ = nullptr; }

  // Set the pointer resource destructor.
  void set_deleter(const Deleter &deleter) { address_common_->pointer_ref_count_->set_deleter(deleter); }

  // Get the pointer to the device side that corresponds to the KernelTensor, used in runtime.
  void *device_ptr() const { return address_common_->pointer_ref_count_->ptr(); }

  // Set the pointer to the device side that corresponds to the KernelTensor, used in runtime.
  void set_device_ptr(void *ptr) { address_common_->pointer_ref_count_->set_ptr(ptr); }

  // Get the memory size in bytes of the KernelTensor.
  size_t size() const { return address_common_->size_; }

  // Set the memory size in bytes of the KernelTensor.
  void set_size(size_t size) { address_common_->size_ = size; }

  // Get the device target name, such as "GPU", "Ascend".
  const std::string &device_name() const { return address_common_->device_name_; }

  // Set the device target name, such as "GPU", "Ascend".
  void set_device_name(const std::string &device_name) { address_common_->device_name_ = device_name; }

  // Get the device id.
  uint32_t device_id() const { return address_common_->device_id_; }

  // Set the device id.
  void set_device_id(uint32_t device_id) { address_common_->device_id_ = device_id; }

  // Get the logical stream id.
  uint32_t stream_id() const { return address_common_->stream_id_; }

  // Set the logical stream id.
  void set_stream_id(uint32_t stream_id) { address_common_->stream_id_ = stream_id; }

  // Get the task id on the stream.
  std::shared_ptr<int64_t> task_id_on_stream() const { return task_id_on_stream_; }

  // Set the task id on the stream.
  void set_task_id_on_stream(const std::shared_ptr<int64_t> &task_id_on_stream) {
    task_id_on_stream_ = task_id_on_stream;
  }

  bool managed_by_somas() const { return address_common_->managed_by_somas_; }

  void set_managed_by_somas(bool managed_by_somas) { address_common_->managed_by_somas_ = managed_by_somas; }

  // Get the user data maintained by the KernelTensor.
  const UserDataPtr &user_data() const { return user_data_; }

  // Set the user data to the KernelTensor.
  void set_user_data(const UserDataPtr &user_data) { user_data_ = user_data; }

  // Set the device synchronizer to the KernelTensor.
  void set_device_synchronizer(const DeviceSynchronizerPtr &device_synchronizer) {
    device_synchronizer_ = device_synchronizer;
  }

  // Clone a new KernelTensor from this one.
  std::shared_ptr<KernelTensor> CloneKernelTensor() { return std::make_shared<KernelTensor>(*this); }

  // Check whether the shape is a dynamic shape (contains a dim which is less than 0).
  bool IsDynamicShape() const;

  // Check whether the KernelTensor comes from a constant variable (such as a ValueNode).
  inline bool IsConstValue() const { return (value_ != nullptr) && !(value_->isa<ValueAny>()); }

  // The following four methods are only used in the Lite framework.
  // Get the device data address (pointer and size).
  AddressPtr GetData() const { return data_; }
  // Set the device data address (pointer and size).
  void SetData(const AddressPtr &data) { data_ = data; }
  // Get the host data address (pointer and size).
  AddressPtr GetHostData() const { return host_data_; }
  // Set the host data address (pointer and size).
  void SetHostData(const AddressPtr &data) { host_data_ = data; }

  // The max shape is only used in compute-depend ops.
  ShapeVector GetMaxShape() const;

  const TensorStorageInfoPtr tensor_storage_info() const { return address_common_->tensor_storage_info_; }
  void set_tensor_storage_info(const TensorStorageInfoPtr &storage_info) {
    address_common_->tensor_storage_info_ = storage_info;
  }

  const AddressCommonPtr address_common() const { return address_common_; }
  void set_address_common(const AddressCommonPtr &address_common) { address_common_ = address_common; }

 private:
  // This is a deprecated function in the base class.
  BaseShapePtr BuildShape() const override {
    MS_LOG(EXCEPTION) << "Call deprecated function: BuildShape. Please use GetShape instead of BuildShape in "
                         "operators' infer functions in the `core/ops` directory.";
  }

  // This is a deprecated function in the base class.
  TypePtr BuildType() const override {
    MS_LOG(EXCEPTION) << "Call deprecated function: BuildType. Please use GetType instead of BuildType in "
                         "operators' infer functions in the `core/ops` directory.";
  }

  // Set the element data type to the KernelTensor for the Sequence type (Tuple or List).
  void SetSequenceDType(const TypePtr &element_type);

  // Synchronize the value data from device to host side.
  bool SyncDataFromDeviceToHost() const;

  // Calculate the memory size needed by the KernelTensor.
  void CalculateMemSize();

  // Check whether the host infer shape needs to be transposed to the device shape.
  bool NeedTransposeToDeviceShape() const noexcept;

  // Transpose the host infer shape to the device shape according to the format.
  const ShapeVector &TransposeToDeviceShape() const;

  // If the host info is not initialized in the constructor, initialize it when it is needed, making sure that the
  // host info is not empty when used.
  void CheckHostInfoValid();

  // The host-side related data in the KernelTensor.
  // Note: to improve the performance of constructing a KernelTensor, some constructors are allowed not to initialize
  // the host info. If the host info is not initialized in the constructor, it can be initialized when it is needed.
  std::unique_ptr<KernelHostInfo> host_info_{nullptr};

  // The launch index on the stream managed by the framework.
  std::shared_ptr<int64_t> task_id_on_stream_{nullptr};

  // The flattened shape vector (maybe after padding).
  // Note: 'host_shape_' will be replaced by 'shape_vector_' in the future.
  ShapeVector host_shape_{};

  // User data is the extra data required by the kernel or the framework.
  UserDataPtr user_data_{nullptr};

  // For synchronizing data between device and host.
  DeviceSynchronizerPtr device_synchronizer_{nullptr};

  // The following two variables are only used in the Lite framework.
  // Device data address.
  AddressPtr data_{nullptr};
  // Host data address.
  AddressPtr host_data_{nullptr};

  // Address basic info.
  AddressCommonPtr address_common_{nullptr};
};
using KernelTensorPtr = std::shared_ptr<KernelTensor>;
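
// Usage sketch (illustrative only; the input indices and names are assumptions): reading typed values out of a
// KernelTensor inside a kernel's Resize/Launch.
//
//   KernelTensor *axis_tensor = inputs[1];                     // Hypothetical "axis" input.
//   int64_t axis = axis_tensor->GetValueWithCheck<int64_t>();  // Throws if no value is available.
//   auto perm = inputs[2]->GetValue<std::vector<int64_t>>();   // std::optional; nullopt if the value is unknown.
//   if (perm.has_value()) {
//     // Use perm.value() ...
//   }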

enum class KernelModType {
  Invalid = 0,
  KernelMod,
  GpuKernelMod,
  NativeGpuKernelMod,
  CpuKernelMod,
  NativeCpuKernelMod,
  HostKernelMod,
  DynamicAkgCpuKernelMod,
};

// The info of kernel launch.
struct KernelLaunchInfo {
  std::vector<KernelTensor *> inputs_;
  std::vector<KernelTensor *> outputs_;
  std::vector<KernelTensor *> workspaces_;
};

enum KernelErrorCode : int { KRET_OK = 0, KRET_RESIZE_FAILED = 1, KRET_UNKNOWN_SHAPE = 2, KRET_UNKNOWN_OUT_SHAPE = 3 };

class BACKEND_EXPORT KernelMod {
 public:
  KernelMod() = default;
  virtual ~KernelMod() = default;

  virtual std::vector<KernelAttr> GetOpSupport() = 0;

  virtual bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) {
    MS_LOG(EXCEPTION) << "The KernelMod[" << kernel_name_ << "] doesn't implement the virtual function 'Init'";
  }

  inline bool Init(const PrimitivePtr &primitive, const std::vector<KernelTensor *> &inputs,
                   const std::vector<KernelTensor *> &outputs) {
    primitive_ = primitive;
    MS_EXCEPTION_IF_NULL(primitive_);
    kernel_name_ = primitive_->name();

    return Init(inputs, outputs);
  }

  // Resize() validates the input/output shapes and calculates the workspace size; the framework invokes this routine
  // after shape inference.
  virtual int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs);

  virtual bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
                      const std::vector<KernelTensor *> &outputs, void *stream_ptr) {
    return true;
  }

  // Some kernels, e.g., Unique, can only get their output shape after the computation has finished.
  virtual bool IsNeedUpdateOutputShapeAndSize() { return false; }
  virtual void UpdateOutputShapeAndSize(const std::vector<KernelTensor *> &inputs,
                                        const std::vector<KernelTensor *> &outputs) {}

  // Some kernels, e.g., Shape/Reshape, don't use some of the input addresses in the kernel launch.
  virtual std::vector<size_t> GetLaunchIgnoredInputAddressIdx() const { return {}; }

  void SetDevicedId(uint32_t device_id) { device_id_ = device_id; }
  virtual enum KernelModType GetKernelModType() const { return KernelModType::KernelMod; }

  virtual void SetInputSizeList(const std::vector<size_t> &size_list) { input_size_list_ = size_list; }
  virtual void SetOutputSizeList(const std::vector<size_t> &size_list) { output_size_list_ = size_list; }
  virtual void SetWorkspaceSizeList(const std::vector<size_t> &size_list) { workspace_size_list_ = size_list; }
  const std::vector<size_t> &GetInputSizeList() const { MS_LOG(EXCEPTION) << "Call deprecated interface."; }
  virtual const std::vector<size_t> &GetOutputSizeList() const { return output_size_list_; }
  virtual const std::vector<size_t> &GetWorkspaceSizeList() const { return workspace_size_list_; }

  const PrimitivePtr &primitive() const { return primitive_; }
  const std::string &kernel_name() const { return kernel_name_; }

  virtual std::vector<size_t> GenParameters() { return {}; }
  virtual void GenAtomicInitInfo(AtomicInitInfo *info) {}

  virtual void set_unique_name(const std::string &unique_name) {
    MS_LOG(EXCEPTION) << "Called a method which is not implemented";
  }

  virtual void set_fullname(const std::string &fullname) {
    MS_LOG(EXCEPTION) << "Called a method which is not implemented";
  }

  virtual void set_is_monad(bool is_monad) { MS_LOG(EXCEPTION) << "Called a method which is not implemented"; }

  // If an output of the kernel has user_data, this needs to return true, and the framework will create the user_data
  // for it.
  virtual bool need_user_data() const { return false; }

  int32_t task_id() const { return task_id_; }
  bool use_kernel_tensor() const { return use_kernel_tensor_; }
  void set_use_kernel_tensor(bool use_kernel_tensor) { use_kernel_tensor_ = use_kernel_tensor; }

  uint32_t record_stream_id() const { return record_stream_id_; }
  void set_record_stream_id(uint32_t record_stream_id) { record_stream_id_ = record_stream_id; }

  virtual bool Finalize() { return true; }

 protected:
  bool IsValidShape(const ShapeVector &shape) const {
    if (std::any_of(shape.begin(), shape.end(), [](int64_t dim) { return dim < 0; })) {
      return false;
    }
    return true;
  }

 protected:
  std::string kernel_name_;
  PrimitivePtr primitive_;
  uint32_t device_id_ = 0;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  int32_t task_id_ = -1;
  bool use_kernel_tensor_{false};
  uint32_t record_stream_id_{0};
};
using KernelModPtr = std::shared_ptr<KernelMod>;
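
// Subclass sketch (illustrative only; "HypotheticalAddKernelMod" is an assumption, not part of this header): the
// typical override set for a concrete kernel.
//
//   class HypotheticalAddKernelMod : public KernelMod {
//    public:
//     std::vector<KernelAttr> GetOpSupport() override { return {}; }
//
//     bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
//       return true;  // Parse primitive_ attributes here.
//     }
//
//     int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override {
//       // Validate shapes and fill workspace_size_list_ after shape inference.
//       return KernelMod::Resize(inputs, outputs);
//     }
//
//     bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
//                 const std::vector<KernelTensor *> &outputs, void *stream_ptr) override {
//       // Read the device_ptr()s and enqueue the computation on stream_ptr.
//       return true;
//     }
//   };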

template <typename T>
inline T *GetDeviceAddress(const std::vector<KernelTensor *> &addr_list, size_t index) {
  if (index >= addr_list.size()) {
    MS_LOG(ERROR) << "Address index(" << index << ") out of range(" << addr_list.size() << ")";
    return nullptr;
  }

  if (addr_list[index] == nullptr) {
    MS_LOG(ERROR) << "The device address is nullptr, address index: " << index << ", and the length of 'addr_list' is "
                  << addr_list.size();
    return nullptr;
  }

  if (addr_list[index]->device_ptr() == nullptr) {
    MS_LOG(WARNING) << "The memory of the device address is nullptr, address index: " << index
                    << ", and the length of 'addr_list' is " << addr_list.size();
    return nullptr;
  }

  // When the input is an empty tuple, the input size will be 0.
  if (addr_list[index]->size() == 0) {
    MS_LOG(INFO) << "The size of the device address is zero, address index: " << index
                 << ", and the length of 'addr_list' is " << addr_list.size();
  }
  return reinterpret_cast<T *>(addr_list[index]->device_ptr());
}
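
// Usage sketch (illustrative only; the element type and indices are assumptions): fetching typed device pointers
// inside a kernel's Launch. Callers must handle the nullptr returned when an address is missing.
//
//   auto *x = GetDeviceAddress<float>(inputs, 0);
//   auto *y = GetDeviceAddress<float>(outputs, 0);
//   if (x == nullptr || y == nullptr) {
//     return false;
//   }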

BACKEND_EXPORT std::vector<std::vector<int64_t>> GetShapes(const std::vector<KernelTensor *> &tensors);

BACKEND_EXPORT void ConvertLaunchInfoToAddr(const KernelLaunchInfo &launch_info, KernelLaunchAddr *mem_info);

template <typename T>
inline bool CheckNullInput(const std::vector<T> &input_shape) {
  // If input_shape.size() == 0, it means a scalar input; if input_shape.size() != 0 and input_shape contains 0,
  // it means a null input. Just return a null output.
  if (input_shape.size() != 0) {
    if (std::any_of(input_shape.begin(), input_shape.end(), [](T i) { return i == 0; })) {
      return true;
    }
  }
  return false;
}
#define CHECK_NULL_INPUT(input_shape) mindspore::kernel::CheckNullInput(input_shape)

template <typename T>
inline bool CheckShapeNull(const std::vector<T> &shape, std::string kernel_name, std::string param_name) {
  if (CHECK_NULL_INPUT(shape)) {
    MS_LOG(WARNING) << "For '" << kernel_name << "', the shape of " << param_name << " cannot contain zero, but got "
                    << shape;
    return true;
  }
  return false;
}

#define CHECK_SHAPE_NULL(shape, kernel_name, param_name) \
  mindspore::kernel::CheckShapeNull(shape, kernel_name, param_name)
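
// Usage sketch (illustrative only; the parameter name "x" is an assumption): a kernel can use CHECK_SHAPE_NULL to
// bail out early on a shape that contains a zero dimension.
//
//   ShapeVector x_shape = inputs[0]->GetShapeVector();
//   if (CHECK_SHAPE_NULL(x_shape, kernel_name_, "x")) {
//     return true;  // Null input: skip the launch and produce a null output.
//   }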
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_KERNEL_H_