1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include <vector>
17 #include "custom_aot_extra.h"
18
19 class add_reduce_kernel_attr : public AotKernelData {
20 public:
21 int64_t axis;
22 bool keep_dim;
23 };
24
CustomKernelInit(int * ndims,int64_t ** shapes,const char ** dtypes,AotExtra * extra)25 extern "C" int CustomKernelInit(int *ndims, int64_t **shapes, const char **dtypes, AotExtra *extra) {
26 size_t workspace_size = 1;
27 for (size_t i = 0; i < ndims[0]; i++) {
28 workspace_size *= shapes[0][i];
29 }
30
31 std::vector<size_t> workspace = {workspace_size * sizeof(float)};
32 extra->SetWorkSpace(workspace);
33
34 add_reduce_kernel_attr *kernel_data_ptr = new add_reduce_kernel_attr;
35 kernel_data_ptr->axis = extra->Attr<int64_t>("axis");
36 kernel_data_ptr->keep_dim = extra->Attr<bool>("keep_dim");
37 extra->SetKernelData(kernel_data_ptr);
38 return 0;
39 }
40
CustomKernelInferShape(int * ndims,int64_t ** shapes,AotExtra * extra)41 extern "C" std::vector<int64_t> CustomKernelInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
42 const int64_t kDynRankSize = -2;
43
44 if (shapes[0][0] == kDynRankSize) {
45 return std::vector<int64_t>{shapes[0][0]};
46 }
47 int64_t axis = extra->Attr<int64_t>("axis");
48 bool keep_dim = extra->Attr<bool>("keep_dim");
49 if (keep_dim) {
50 if (axis == 0) {
51 return std::vector<int64_t>{1, shapes[0][1]};
52 } else {
53 return std::vector<int64_t>{shapes[0][0], 1};
54 }
55 } else {
56 return std::vector<int64_t>{shapes[0][1 - axis]};
57 }
58 }
59
CustomKernel(int nparam,void ** params,int * ndims,int64_t ** shapes,const char ** dtypes,void * stream,void * extra_void)60 extern "C" int CustomKernel(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
61 void *extra_void) {
62 float *input_1 = static_cast<float *>(params[0]);
63 float *input_2 = static_cast<float *>(params[1]);
64 float *output = static_cast<float *>(params[2]);
65 float *tmp = static_cast<float *>(params[3]);
66
67 // Add
68 int in_size = 1;
69 for (int i = 0; i < ndims[0]; i++) {
70 in_size *= shapes[0][i];
71 }
72
73 for (int i = 0; i < in_size; i++) {
74 tmp[i] = input_1[i] + input_2[i];
75 }
76
77 // ReduceSum
78 AotExtra *extra = static_cast<AotExtra *>(extra_void);
79 auto kernel_ptr = static_cast<add_reduce_kernel_attr *>(extra->KernelData());
80 bool keep_dim = kernel_ptr->keep_dim;
81 int64_t axis = kernel_ptr->axis;
82 int64_t input_dim_1 = shapes[0][1];
83 int size;
84 if (keep_dim) {
85 size = shapes[1][0] * shapes[1][1];
86 } else {
87 size = shapes[1][0];
88 }
89 int ext = shapes[0][axis];
90 for (int i = 0; i < size; i++) {
91 output[i] = 0;
92 for (int j = 0; j < ext; j++) {
93 int idx = input_dim_1 * (i * axis + j * (1 - axis)) + i * (1 - axis) + j * axis;
94 output[i] = output[i] + tmp[idx];
95 }
96 }
97 return 0;
98 }