1 /** 2 * Copyright 2024 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_ 18 #define MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_ 19 20 #include <map> 21 #include <memory> 22 #include <set> 23 #include <string> 24 #include <vector> 25 #include "include/common/debug/common.h" 26 #include "ir/dtype/tensor_type.h" 27 #include "mindrt/include/async/async.h" 28 #include "ops/auto_generate/gen_ops_primitive.h" 29 #include "runtime/device/device_address_utils.h" 30 #include "runtime/graph_scheduler/actor/actor_common.h" 31 #include "runtime/graph_scheduler/device_tensor_store.h" 32 #include "runtime/hardware/device_context.h" 33 #include "utils/log_adapter.h" 34 35 namespace mindspore { 36 37 namespace datadump { 38 using device::DeviceAddressPtr; 39 using kernel::KernelTensor; 40 using mindspore::device::DeviceContext; 41 using TensorPtr = tensor::TensorPtr; 42 43 class StatisticKernel { 44 public: StatisticKernel(const DeviceContext * device_context,string kernel_name,const std::set<TypeId> & dtype_id)45 StatisticKernel(const DeviceContext *device_context, string kernel_name, const std::set<TypeId> &dtype_id) 46 : device_context_(device_context), kernel_name_(kernel_name), supported_dtype_(dtype_id) { 47 MS_EXCEPTION_IF_NULL(device_context); 48 MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_); 49 MS_LOG(DEBUG) << "Statistic kernel mod " << kernel_name_ << " construct."; 50 kernel_mod_ = device_context_->GetKernelExecutor(false)->CreateKernelMod(kernel_name); 51 MS_EXCEPTION_IF_NULL(kernel_mod_); 52 } 53 DeviceAddressPtr GenerateDeviceAddress(const uint32_t &stream_id, const size_t &mem_size, const TypeId &dtype_id, 54 const ShapeVector &shape, const ValuePtr &value = nullptr); 55 DeviceAddressPtr GetWorkSpaceDeviceAddress(const uint32_t stream_id, const vector<KernelTensor *> &inputs, 56 const vector<KernelTensor *> &outputs); 57 DeviceAddressPtr GetOutputDeviceAddress(const uint32_t stream_id, TypeId dtype_id); 58 TensorPtr LaunchKernel(KernelTensor *input); 59 TensorPtr SyncDeviceToHostTensor(DeviceAddressPtr device_addr); CheckDataType(const TypeId & dtype_id)60 bool CheckDataType(const TypeId &dtype_id) { return supported_dtype_.find(dtype_id) != supported_dtype_.end(); } 61 62 protected: 63 const DeviceContext *device_context_{nullptr}; 64 string kernel_name_; 65 kernel::KernelModPtr kernel_mod_; 66 std::set<TypeId> supported_dtype_; 67 }; 68 69 class DimStatisticKernel : public StatisticKernel { 70 public: DimStatisticKernel(const DeviceContext * device_context,string kernel_name,const std::set<TypeId> & dtype_id)71 explicit DimStatisticKernel(const DeviceContext *device_context, string kernel_name, const std::set<TypeId> &dtype_id) 72 : StatisticKernel(device_context, kernel_name, dtype_id) {} 73 TensorPtr LaunchKernel(KernelTensor *input); 74 TensorPtr Launch(vector<KernelTensor *> inputs, DeviceAddressPtr output_addr, uint32_t stream_id); 75 DeviceAddressPtr GetAxisDeviceAddress(const uint32_t stream_id, size_t dim); 76 DeviceAddressPtr GetKeepDimsDeviceAddress(const uint32_t stream_id); 77 DeviceAddressPtr GetDtypeDeviceAddress(const uint32_t stream_id, const TypeId &); 78 }; 79 80 class MeanStatisticKernel : public DimStatisticKernel { 81 public: MeanStatisticKernel(const DeviceContext * device_context,const std::set<TypeId> & dtype_id)82 explicit MeanStatisticKernel(const DeviceContext *device_context, const std::set<TypeId> &dtype_id) 83 : DimStatisticKernel(device_context, ops::kNameMeanExt, dtype_id) {} 84 }; 85 86 class NormStatisticKernel : public DimStatisticKernel { 87 public: NormStatisticKernel(const DeviceContext * device_context,const std::set<TypeId> & dtype_id)88 explicit NormStatisticKernel(const DeviceContext *device_context, const std::set<TypeId> &dtype_id) 89 : DimStatisticKernel(device_context, ops::kNameNorm, dtype_id) {} 90 TensorPtr LaunchKernel(KernelTensor *input); 91 DeviceAddressPtr GetScalar(const uint32_t stream_id, float scalar = 2.0); 92 }; 93 94 TensorPtr CalL2Norm(const DeviceContext *device_context, KernelTensor *input); 95 TensorPtr CalMax(const DeviceContext *device_context, KernelTensor *input); 96 TensorPtr CalMin(const DeviceContext *device_context, KernelTensor *input); 97 TensorPtr CalMean(const DeviceContext *device_context, KernelTensor *input); 98 TensorPtr CalStatistic(const std::string &stat_name, const DeviceContext *device_context, KernelTensor *input); 99 100 } // namespace datadump 101 102 } // namespace mindspore 103 104 #endif // MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_ 105