/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #ifndef NNACL_KERNEL_ARITHMETIC_H_
17 #define NNACL_KERNEL_ARITHMETIC_H_
18 
19 #include "nnacl/op_base.h"
20 #include "nnacl/tensor_c.h"
21 #include "nnacl/kernel.h"
22 #include "nnacl/arithmetic_parameter.h"
23 
/**
 * Table of element-wise binary compute callbacks, one slot per element type
 * (float, int, bool).
 *
 * Each compute_* callback receives two input arrays, an output array and an
 * element count. Each optimzie_* callback receives the same arguments plus a
 * `scalar` flag. All callbacks return an int.
 *
 * NOTE(review): the int return is presumably an NNACL status code, and the
 * optimzie_* variants presumably cover the case where one operand is a
 * scalar -- confirm against the implementations.
 * NOTE(review): "Funcions" and "optimzie" are misspellings of "Functions" and
 * "optimize"; they are kept as-is because they are part of the public
 * interface.
 */
typedef struct ArithmeticFuncions {
  int primitive_type_;   // presumably the operator id this entry serves -- confirm
  int activation_type_;  // presumably the fused activation id -- confirm
  int (*compute_f32_)(const float *in1, const float *in2, float *out, int ele);
  int (*compute_int_)(const int *in1, const int *in2, int *out, int ele);
  int (*compute_bool_)(const bool *in1, const bool *in2, bool *out, int ele);
  int (*optimzie_f32_)(const float *in1, const float *in2, float *out, int ele, bool scalar);
  int (*optimzie_int_)(const int *in1, const int *in2, int *out, int ele, bool scalar);
  int (*optimzie_bool_)(const bool *in1, const bool *in2, bool *out, int ele, bool scalar);
} ArithmeticFuncions;
34 
35 typedef struct ArithmeticMatrixInfo {
36   bool is_const_;
37   bool is_valid_;
38   void *data_;
39   int64_t inner_size_;
40   int shape_[ARITHMETIC_SUPPORT_DIMS_NUM];
41   int shape_size_;
42   int *batch_post_sum_; /* shape size + 1 */
43 } ArithmeticMatrixInfo;
44 
/**
 * Boundaries of one block of work, expressed as a batch range plus offsets
 * inside the first and last batch, together with offset tables for the two
 * inputs.
 *
 * NOTE(review): whether batch_end_ is inclusive or exclusive, and who
 * allocates and frees a_offset_ / b_offset_, is not visible here -- confirm
 * against the implementation.
 */
typedef struct ArithmeticBlockBoundaryInfo {
  int batch_begin_;
  int batch_end_;
  int size_begin_;  // start-offset under the begin batch
  int size_end_;    // end-num under the ending batch
  int *a_offset_;   // presumably per-batch offsets into input A -- confirm
  int *b_offset_;   // presumably per-batch offsets into input B -- confirm
  bool init_offset_;  // presumably set once the offset tables are filled -- confirm
} ArithmeticBlockBoundaryInfo;
54 
55 typedef struct ArithmeticStruct {
56   KernelBase base_;
57   bool scalar_opt_;
58   int primitive_type_;
59   int ndim_;
60   int in_data_size_;
61   int out_data_size_;
62   int batch_tail_dim_;
63 
64   ArithmeticMatrixInfo a_matrix_;
65   ArithmeticMatrixInfo b_matrix_;
66   ArithmeticMatrixInfo c_matrix_;
67   ArithmeticFuncions functions_;
68 
69   void *broadcast_buffer_[TWO_TENSOR];
70   int block_boundary_infos_size_;
71   ArithmeticBlockBoundaryInfo block_boundary_infos_[MAX_THREAD_NUM];
72 
73   int in_shape0_[ARITHMETIC_SUPPORT_DIMS_NUM];
74   int in_elements_num0_;
75   int in_shape1_[ARITHMETIC_SUPPORT_DIMS_NUM];
76   int in_elements_num1_;
77   int out_shape_[ARITHMETIC_SUPPORT_DIMS_NUM];
78   int out_elements_num_;
79   int in_strides0_[ARITHMETIC_SUPPORT_DIMS_NUM];
80   int in_strides1_[ARITHMETIC_SUPPORT_DIMS_NUM];
81   int out_strides_[ARITHMETIC_SUPPORT_DIMS_NUM];
82   int multiples0_[ARITHMETIC_SUPPORT_DIMS_NUM];
83   int multiples1_[ARITHMETIC_SUPPORT_DIMS_NUM];
84 
85   void (*tile_function_)(const void *inPtr, void *outPtr, int dim, size_t ndim, const int *inShape,
86                          const int *inStrides, const int *outStrides, const int *multiple);
87   int (*execute_)(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size);
88   void (*init_function_)(KernelBase *base);
89 } ArithmeticStruct;
90 
91 KernelBase *CreateArithmetic(OpParameter *param, int data_type);
92 int ArithmeticPrepare(struct KernelBase *self);
93 int ArithmeticRelease(struct KernelBase *self);
94 int ArithmeticCompute(struct KernelBase *self);
95 int ArithmeticResize(struct KernelBase *self);
96 
97 #endif  // NNACL_KERNEL_ARITHMETIC_H_
98