• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "cpu_kernel/ms_kernel/trace.h"
18 #include <algorithm>
19 #include "cstring"
20 #include "securec.h"
21 #include "context/inc/cpu_kernel_utils.h"
22 #include "utils/eigen_tensor.h"
23 #include "utils/kernel_util.h"
24 
namespace {
const uint32_t kInputNum = 1;
const uint32_t kOutputNum = 1;
const uint32_t InputShapeDim = 2;
const uint32_t OutputShapeDim = 1;
const uint64_t OutputShapeDimSize = 1;
// Kernel name used for registration; the pointer itself is const so it cannot be re-pointed.
const char *const kTrace = "Trace";

// Expands to one `case` of the dtype switch in TraceCpuKernel::Compute: runs TraceCompute<TYPE>
// on (INPUT, OUTPUT, CTX), logs and returns its status on failure, otherwise breaks out of the
// switch. All logging goes through the CTX argument (not a hard-coded identifier), so the macro
// works regardless of what the caller names its context variable.
#define TRACE_COMPUTE_CASE(DTYPE, INPUT, OUTPUT, CTX, TYPE)       \
  case (DTYPE): {                                                 \
    uint32_t result = TraceCompute<TYPE>(INPUT, OUTPUT, CTX);     \
    if (result != KERNEL_STATUS_OK) {                             \
      CUST_KERNEL_LOG_ERROR(CTX, "Trace kernel compute failed."); \
      return result;                                              \
    }                                                             \
    break;                                                        \
  }
}  // namespace
43 
44 namespace aicpu {
Compute(CpuKernelContext & ctx)45 uint32_t TraceCpuKernel::Compute(CpuKernelContext &ctx) {
46   CUST_KERNEL_HANDLE_ERROR(ctx, NormalCheck(ctx, kInputNum, kOutputNum), "Trace check input and output number failed.");
47 
48   Tensor *input_tensor = ctx.Input(0);
49   CUST_KERNEL_CHECK_NULLPTR(ctx, input_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Trace get input data failed.")
50   CUST_KERNEL_CHECK_NULLPTR(ctx, input_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID,
51                             "Trace get input shape failed")
52 
53   if (input_tensor->GetTensorShape()->GetDims() != InputShapeDim) {
54     CUST_KERNEL_LOG_ERROR(ctx, "Trace input dim must be 2!");
55     return KERNEL_STATUS_PARAM_INVALID;
56   }
57 
58   // check output tensor
59   Tensor *output_tensor = ctx.Output(0);
60   CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor, KERNEL_STATUS_PARAM_INVALID, "Trace get output failed.")
61   CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor->GetData(), KERNEL_STATUS_PARAM_INVALID, "Trace get output data failed.")
62   CUST_KERNEL_CHECK_NULLPTR(ctx, output_tensor->GetTensorShape(), KERNEL_STATUS_PARAM_INVALID,
63                             "Trace get output shape failed")
64 
65   auto input_dtype = input_tensor->GetDataType();
66   auto output_dtype = output_tensor->GetDataType();
67   switch (input_dtype) {
68     TRACE_COMPUTE_CASE(DT_INT8, input_tensor, output_tensor, ctx, int8_t)
69     TRACE_COMPUTE_CASE(DT_UINT8, input_tensor, output_tensor, ctx, uint8_t)
70     TRACE_COMPUTE_CASE(DT_INT16, input_tensor, output_tensor, ctx, int16_t)
71     TRACE_COMPUTE_CASE(DT_UINT16, input_tensor, output_tensor, ctx, uint16_t)
72     TRACE_COMPUTE_CASE(DT_INT32, input_tensor, output_tensor, ctx, int32_t)
73     TRACE_COMPUTE_CASE(DT_UINT32, input_tensor, output_tensor, ctx, uint32_t)
74     TRACE_COMPUTE_CASE(DT_INT64, input_tensor, output_tensor, ctx, int64_t)
75     TRACE_COMPUTE_CASE(DT_UINT64, input_tensor, output_tensor, ctx, uint64_t)
76     TRACE_COMPUTE_CASE(DT_FLOAT16, input_tensor, output_tensor, ctx, Eigen::half)
77     TRACE_COMPUTE_CASE(DT_FLOAT, input_tensor, output_tensor, ctx, float)
78     TRACE_COMPUTE_CASE(DT_DOUBLE, input_tensor, output_tensor, ctx, double)
79     default:
80       CUST_KERNEL_LOG_ERROR(ctx, "Trace kernel data type [%u] not support", output_dtype);
81       return KERNEL_STATUS_PARAM_INVALID;
82   }
83   return KERNEL_STATUS_OK;
84 }
85 
86 template <typename T>
TraceCompute(Tensor * input,Tensor * output,CpuKernelContext & ctx)87 uint32_t TraceCpuKernel::TraceCompute(Tensor *input, Tensor *output, CpuKernelContext &ctx) {
88   auto inputDataAddr = reinterpret_cast<T *>(input->GetData());
89   auto outputDataAddr = reinterpret_cast<T *>(output->GetData());
90   auto input_shape = ctx.Input(0)->GetTensorShape();
91   int64_t inputLine = input_shape->GetDimSize(0);
92   int64_t inputCol = input_shape->GetDimSize(1);
93   auto min_shape = std::min(inputLine, inputCol);
94   auto output_size = output->GetDataSize();
95   auto ret = memset_s(outputDataAddr, output_size, 0, sizeof(T));
96   if (ret != EOK) {
97     CUST_KERNEL_LOG_ERROR(ctx, "For 'Trace', memset_s failed, ret=%d.", ret);
98     return KERNEL_STATUS_INNER_ERROR;
99   }
100   for (int64_t i = 0; i < min_shape; i++) {
101     *(outputDataAddr) += *(inputDataAddr + i * inputCol + i);
102   }
103   return KERNEL_STATUS_OK;
104 }
105 
106 REGISTER_MS_CPU_KERNEL(kTrace, TraceCpuKernel);
107 }  // namespace aicpu
108