1 /**
2 * Copyright 2021-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "cpu_kernel/ms_kernel/trace.h"
18 #include <algorithm>
19 #include "cstring"
20 #include "securec.h"
21 #include "context/inc/cpu_kernel_utils.h"
22 #include "utils/eigen_tensor.h"
23 #include "utils/kernel_util.h"
24
namespace {
// Number of input tensors the kernel expects (passed to NormalCheck).
const uint32_t kInputNum = 1;
// Number of output tensors the kernel expects (passed to NormalCheck).
const uint32_t kOutputNum = 1;
// Required rank of the input tensor; Compute rejects any other rank.
const uint32_t InputShapeDim = 2;
// NOTE(review): the two constants below are not referenced by the code visible in this file;
// presumably they describe the expected output shape -- confirm before relying on them.
const uint32_t OutputShapeDim = 1;
const uint64_t OutputShapeDimSize = 1;
// Operator name used when registering the kernel.
const char *kTrace = "Trace";

// Expands to one `case` label of the dtype switch in TraceCpuKernel::Compute.
//   DTYPE  - data-type enumerator matched by the case label
//   INPUT  - input tensor pointer forwarded to TraceCompute
//   OUTPUT - output tensor pointer forwarded to TraceCompute
//   CTX    - kernel context, forwarded to TraceCompute and used for error logging
//   TYPE   - C++ element type TraceCompute is instantiated with
// On failure the enclosing function returns the status produced by TraceCompute.
// The log call uses the CTX argument (not a hard-coded identifier) so the macro does not
// depend on how the caller happens to name its context variable.
#define TRACE_COMPUTE_CASE(DTYPE, INPUT, OUTPUT, CTX, TYPE)       \
  case (DTYPE): {                                                 \
    uint32_t result = TraceCompute<TYPE>(INPUT, OUTPUT, CTX);     \
    if (result != KERNEL_STATUS_OK) {                             \
      CUST_KERNEL_LOG_ERROR(CTX, "Trace kernel compute failed."); \
      return result;                                              \
    }                                                             \
    break;                                                        \
  }
}  // namespace
43
44 namespace aicpu {
Compute(CpuKernelContext & ctx)45 uint32_t TraceCpuKernel::Compute(CpuKernelContext &ctx) {
46 CUST_KERNEL_HANDLE_ERROR(ctx, NormalCheck(ctx, kInputNum, kOutputNum), "Trace check input and output number failed.");
47
48 Tensor *input_tensor = ctx.Input(0);
49 CUST_KERNEL_CHECK_NULLPTR(ctx, input_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Trace get input data failed.")
50 CUST_KERNEL_CHECK_NULLPTR(ctx, input_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID,
51 "Trace get input shape failed")
52
53 if (input_tensor->GetTensorShape()->GetDims() != InputShapeDim) {
54 CUST_KERNEL_LOG_ERROR(ctx, "Trace input dim must be 2!");
55 return KERNEL_STATUS_PARAM_INVALID;
56 }
57
58 // check output tensor
59 Tensor *output_tensor = ctx.Output(0);
60 CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor, KERNEL_STATUS_PARAM_INVALID, "Trace get output failed.")
61 CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Trace get output data failed.")
62 CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID,
63 "Trace get output shape failed")
64
65 auto input_dtype = input_tensor->GetDataType();
66 auto output_dtype = output_tensor->GetDataType();
67 switch (input_dtype) {
68 TRACE_COMPUTE_CASE(DT_INT8, input_tensor, output_tensor, ctx, int8_t)
69 TRACE_COMPUTE_CASE(DT_UINT8, input_tensor, output_tensor, ctx, uint8_t)
70 TRACE_COMPUTE_CASE(DT_INT16, input_tensor, output_tensor, ctx, int16_t)
71 TRACE_COMPUTE_CASE(DT_UINT16, input_tensor, output_tensor, ctx, uint16_t)
72 TRACE_COMPUTE_CASE(DT_INT32, input_tensor, output_tensor, ctx, int32_t)
73 TRACE_COMPUTE_CASE(DT_UINT32, input_tensor, output_tensor, ctx, uint32_t)
74 TRACE_COMPUTE_CASE(DT_INT64, input_tensor, output_tensor, ctx, int64_t)
75 TRACE_COMPUTE_CASE(DT_UINT64, input_tensor, output_tensor, ctx, uint64_t)
76 TRACE_COMPUTE_CASE(DT_FLOAT16, input_tensor, output_tensor, ctx, Eigen::half)
77 TRACE_COMPUTE_CASE(DT_FLOAT, input_tensor, output_tensor, ctx, float)
78 TRACE_COMPUTE_CASE(DT_DOUBLE, input_tensor, output_tensor, ctx, double)
79 default:
80 CUST_KERNEL_LOG_ERROR(ctx, "Trace kernel data type [%u] not support", output_dtype);
81 return KERNEL_STATUS_PARAM_INVALID;
82 }
83 return KERNEL_STATUS_OK;
84 }
85
86 template <typename T>
TraceCompute(Tensor * input,Tensor * output,CpuKernelContext & ctx)87 uint32_t TraceCpuKernel::TraceCompute(Tensor *input, Tensor *output, CpuKernelContext &ctx) {
88 auto inputDataAddr = reinterpret_cast<T *>(input->GetData());
89 auto outputDataAddr = reinterpret_cast<T *>(output->GetData());
90 auto input_shape = ctx.Input(0)->GetTensorShape();
91 int64_t inputLine = input_shape->GetDimSize(0);
92 int64_t inputCol = input_shape->GetDimSize(1);
93 auto min_shape = std::min(inputLine, inputCol);
94 auto output_size = output->GetDataSize();
95 auto ret = memset_s(outputDataAddr, output_size, 0, sizeof(T));
96 if (ret != EOK) {
97 CUST_KERNEL_LOG_ERROR(ctx, "For 'Trace', memset_s failed, ret=%d.", ret);
98 return KERNEL_STATUS_INNER_ERROR;
99 }
100 for (int64_t i = 0; i < min_shape; i++) {
101 *(outputDataAddr) += *(inputDataAddr + i * inputCol + i);
102 }
103 return KERNEL_STATUS_OK;
104 }
105
// Registers TraceCpuKernel under the operator name stored in kTrace ("Trace") via the
// project's REGISTER_MS_CPU_KERNEL macro.
REGISTER_MS_CPU_KERNEL(kTrace, TraceCpuKernel);
107 } // namespace aicpu
108