• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2024 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_
18 #define MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_
19 
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>
#include "include/common/debug/common.h"
#include "ir/dtype/tensor_type.h"
#include "mindrt/include/async/async.h"
#include "ops/auto_generate/gen_ops_primitive.h"
#include "runtime/device/device_address_utils.h"
#include "runtime/graph_scheduler/actor/actor_common.h"
#include "runtime/graph_scheduler/device_tensor_store.h"
#include "runtime/hardware/device_context.h"
#include "utils/log_adapter.h"
34 
35 namespace mindspore {
36 
37 namespace datadump {
38 using device::DeviceAddressPtr;
39 using kernel::KernelTensor;
40 using mindspore::device::DeviceContext;
41 using TensorPtr = tensor::TensorPtr;
42 
43 class StatisticKernel {
44  public:
StatisticKernel(const DeviceContext * device_context,string kernel_name,const std::set<TypeId> & dtype_id)45   StatisticKernel(const DeviceContext *device_context, string kernel_name, const std::set<TypeId> &dtype_id)
46       : device_context_(device_context), kernel_name_(kernel_name), supported_dtype_(dtype_id) {
47     MS_EXCEPTION_IF_NULL(device_context);
48     MS_EXCEPTION_IF_NULL(device_context_->device_res_manager_);
49     MS_LOG(DEBUG) << "Statistic kernel mod " << kernel_name_ << " construct.";
50     kernel_mod_ = device_context_->GetKernelExecutor(false)->CreateKernelMod(kernel_name);
51     MS_EXCEPTION_IF_NULL(kernel_mod_);
52   }
53   DeviceAddressPtr GenerateDeviceAddress(const uint32_t &stream_id, const size_t &mem_size, const TypeId &dtype_id,
54                                          const ShapeVector &shape, const ValuePtr &value = nullptr);
55   DeviceAddressPtr GetWorkSpaceDeviceAddress(const uint32_t stream_id, const vector<KernelTensor *> &inputs,
56                                              const vector<KernelTensor *> &outputs);
57   DeviceAddressPtr GetOutputDeviceAddress(const uint32_t stream_id, TypeId dtype_id);
58   TensorPtr LaunchKernel(KernelTensor *input);
59   TensorPtr SyncDeviceToHostTensor(DeviceAddressPtr device_addr);
CheckDataType(const TypeId & dtype_id)60   bool CheckDataType(const TypeId &dtype_id) { return supported_dtype_.find(dtype_id) != supported_dtype_.end(); }
61 
62  protected:
63   const DeviceContext *device_context_{nullptr};
64   string kernel_name_;
65   kernel::KernelModPtr kernel_mod_;
66   std::set<TypeId> supported_dtype_;
67 };
68 
69 class DimStatisticKernel : public StatisticKernel {
70  public:
DimStatisticKernel(const DeviceContext * device_context,string kernel_name,const std::set<TypeId> & dtype_id)71   explicit DimStatisticKernel(const DeviceContext *device_context, string kernel_name, const std::set<TypeId> &dtype_id)
72       : StatisticKernel(device_context, kernel_name, dtype_id) {}
73   TensorPtr LaunchKernel(KernelTensor *input);
74   TensorPtr Launch(vector<KernelTensor *> inputs, DeviceAddressPtr output_addr, uint32_t stream_id);
75   DeviceAddressPtr GetAxisDeviceAddress(const uint32_t stream_id, size_t dim);
76   DeviceAddressPtr GetKeepDimsDeviceAddress(const uint32_t stream_id);
77   DeviceAddressPtr GetDtypeDeviceAddress(const uint32_t stream_id, const TypeId &);
78 };
79 
80 class MeanStatisticKernel : public DimStatisticKernel {
81  public:
MeanStatisticKernel(const DeviceContext * device_context,const std::set<TypeId> & dtype_id)82   explicit MeanStatisticKernel(const DeviceContext *device_context, const std::set<TypeId> &dtype_id)
83       : DimStatisticKernel(device_context, ops::kNameMeanExt, dtype_id) {}
84 };
85 
86 class NormStatisticKernel : public DimStatisticKernel {
87  public:
NormStatisticKernel(const DeviceContext * device_context,const std::set<TypeId> & dtype_id)88   explicit NormStatisticKernel(const DeviceContext *device_context, const std::set<TypeId> &dtype_id)
89       : DimStatisticKernel(device_context, ops::kNameNorm, dtype_id) {}
90   TensorPtr LaunchKernel(KernelTensor *input);
91   DeviceAddressPtr GetScalar(const uint32_t stream_id, float scalar = 2.0);
92 };
93 
94 TensorPtr CalL2Norm(const DeviceContext *device_context, KernelTensor *input);
95 TensorPtr CalMax(const DeviceContext *device_context, KernelTensor *input);
96 TensorPtr CalMin(const DeviceContext *device_context, KernelTensor *input);
97 TensorPtr CalMean(const DeviceContext *device_context, KernelTensor *input);
98 TensorPtr CalStatistic(const std::string &stat_name, const DeviceContext *device_context, KernelTensor *input);
99 
100 }  // namespace datadump
101 
102 }  // namespace mindspore
103 
104 #endif  // MINDSPORE_CCSRC_DEBUG_STATISTIC_KERNEL_H_
105