• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include <memory>
#include <vector>
#include "custom_aot_extra.h"
18 
19 class add_reduce_kernel_attr : public AotKernelData {
20  public:
21   int64_t axis;
22   bool keep_dim;
23 };
24 
CustomKernelInit(int * ndims,int64_t ** shapes,const char ** dtypes,AotExtra * extra)25 extern "C" int CustomKernelInit(int *ndims, int64_t **shapes, const char **dtypes, AotExtra *extra) {
26   size_t workspace_size = 1;
27   for (size_t i = 0; i < ndims[0]; i++) {
28     workspace_size *= shapes[0][i];
29   }
30 
31   std::vector<size_t> workspace = {workspace_size * sizeof(float)};
32   extra->SetWorkSpace(workspace);
33 
34   add_reduce_kernel_attr *kernel_data_ptr = new add_reduce_kernel_attr;
35   kernel_data_ptr->axis = extra->Attr<int64_t>("axis");
36   kernel_data_ptr->keep_dim = extra->Attr<bool>("keep_dim");
37   extra->SetKernelData(kernel_data_ptr);
38   return 0;
39 }
40 
CustomKernelInferShape(int * ndims,int64_t ** shapes,AotExtra * extra)41 extern "C" std::vector<int64_t> CustomKernelInferShape(int *ndims, int64_t **shapes, AotExtra *extra) {
42   const int64_t kDynRankSize = -2;
43 
44   if (shapes[0][0] == kDynRankSize) {
45     return std::vector<int64_t>{shapes[0][0]};
46   }
47   int64_t axis = extra->Attr<int64_t>("axis");
48   bool keep_dim = extra->Attr<bool>("keep_dim");
49   if (keep_dim) {
50     if (axis == 0) {
51       return std::vector<int64_t>{1, shapes[0][1]};
52     } else {
53       return std::vector<int64_t>{shapes[0][0], 1};
54     }
55   } else {
56     return std::vector<int64_t>{shapes[0][1 - axis]};
57   }
58 }
59 
CustomKernel(int nparam,void ** params,int * ndims,int64_t ** shapes,const char ** dtypes,void * stream,void * extra_void)60 extern "C" int CustomKernel(int nparam, void **params, int *ndims, int64_t **shapes, const char **dtypes, void *stream,
61                             void *extra_void) {
62   float *input_1 = static_cast<float *>(params[0]);
63   float *input_2 = static_cast<float *>(params[1]);
64   float *output = static_cast<float *>(params[2]);
65   float *tmp = static_cast<float *>(params[3]);
66 
67   // Add
68   int in_size = 1;
69   for (int i = 0; i < ndims[0]; i++) {
70     in_size *= shapes[0][i];
71   }
72 
73   for (int i = 0; i < in_size; i++) {
74     tmp[i] = input_1[i] + input_2[i];
75   }
76 
77   // ReduceSum
78   AotExtra *extra = static_cast<AotExtra *>(extra_void);
79   auto kernel_ptr = static_cast<add_reduce_kernel_attr *>(extra->KernelData());
80   bool keep_dim = kernel_ptr->keep_dim;
81   int64_t axis = kernel_ptr->axis;
82   int64_t input_dim_1 = shapes[0][1];
83   int size;
84   if (keep_dim) {
85     size = shapes[1][0] * shapes[1][1];
86   } else {
87     size = shapes[1][0];
88   }
89   int ext = shapes[0][axis];
90   for (int i = 0; i < size; i++) {
91     output[i] = 0;
92     for (int j = 0; j < ext; j++) {
93       int idx = input_dim_1 * (i * axis + j * (1 - axis)) + i * (1 - axis) + j * axis;
94       output[i] = output[i] + tmp[idx];
95     }
96   }
97   return 0;
98 }