1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/arithmetic.h"
18 #include "nnacl/op_base.h"
19 #include "nnacl/nnacl_common.h"
20 #include "nnacl/fp32/arithmetic_fp32.h"
21 #include "nnacl/fp32/mul_fp32.h"
22 #include "nnacl/tensor_c_utils.h"
23 #ifdef ENABLE_FP16
24 #include "nnacl/fp16/arithmetic_fp16.h"
25 #endif
26
/**
 * Selects the element-wise kernels for this node.
 *
 * Looks up the (primitive_type_, activation_type_) pair in a fixed dispatch table and copies the
 * matching entry into arithmetic->functions_. Each entry lists, in order: the float32 / int32 / bool
 * kernels used for same-shaped operands, followed by the float32 / int32 / bool "Opt" kernels used
 * when one operand is a scalar per batch (see ArithmeticDoExecute). A NULL slot means the data type
 * is not supported for that primitive/activation.
 *
 * If no entry matches, functions_ is left untouched; for kernels created by CreateArithmetic it is
 * zero-filled, so ArithmeticDoExecute rejects execution through its NULL checks.
 *
 * @param self kernel created by CreateArithmetic (must really be an ArithmeticStruct).
 */
void InitArithmeticRunFunction(KernelBase *self) {
  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
  const ArithmeticParameter *param = (const ArithmeticParameter *)(arithmetic->base_.param_);

  /* Built once with static storage instead of being rebuilt on the stack at every call. */
  static const ArithmeticFuncions fun_table[] = {
    {PrimType_MulFusion, ActType_Relu, ElementMulRelu, ElementMulReluInt, NULL, ElementOptMulRelu, ElementOptMulReluInt,
     NULL},
    {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6, ElementMulRelu6Int, NULL, ElementOptMulRelu6,
     ElementOptMulRelu6Int, NULL},
    {PrimType_MulFusion, ActType_No, ElementMul, ElementMulInt, NULL, ElementOptMul, ElementOptMulInt, NULL},
    {PrimType_AddFusion, ActType_Relu, ElementAddRelu, NULL, NULL, ElementOptAddRelu, NULL, NULL},
    {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6, NULL, NULL, ElementOptAddRelu6, NULL, NULL},
    {PrimType_AddFusion, ActType_No, ElementAdd, ElementAddInt, NULL, ElementOptAdd, ElementOptAddInt, NULL},
    {PrimType_SubFusion, ActType_Relu, ElementSubRelu, NULL, NULL, ElementOptSubRelu, NULL, NULL},
    {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6, NULL, NULL, ElementOptSubRelu6, NULL, NULL},
    {PrimType_SubFusion, ActType_No, ElementSub, ElementSubInt, NULL, ElementOptSub, ElementOptSubInt, NULL},
    {PrimType_DivFusion, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL},
    {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL},
    {PrimType_DivFusion, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL},
    {PrimType_RealDiv, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL},
    {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL},
    {PrimType_RealDiv, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL},
    {PrimType_LogicalAnd, ActType_No, ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool,
     ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool},
    {PrimType_LogicalOr, ActType_No, ElementLogicalOr, NULL, ElementLogicalOrBool, NULL, NULL, ElementOptLogicalOrBool},
    {PrimType_Maximum, ActType_No, ElementMaximum, ElementMaximumInt, NULL, ElementOptMaximum, ElementOptMaximumInt,
     NULL},
    {PrimType_Minimum, ActType_No, ElementMinimum, ElementMinimumInt, NULL, ElementOptMinimum, ElementOptMinimumInt,
     NULL},
    {PrimType_FloorMod, ActType_No, ElementFloorMod, ElementFloorModInt, NULL, ElementOptFloorMod,
     ElementOptFloorModInt, NULL},
    {PrimType_FloorDiv, ActType_No, ElementFloorDiv, ElementFloorDivInt, NULL, ElementOptFloorDiv,
     ElementOptFloorDivInt, NULL},
    {PrimType_Mod, ActType_No, ElementMod, ElementModInt, NULL, ElementOptMod, ElementOptModInt, NULL},
    {PrimType_SquaredDifference, ActType_No, ElementSquaredDifference, NULL, NULL, ElementOptSquaredDifference, NULL,
     NULL}};

  size_t length = sizeof(fun_table) / sizeof(fun_table[0]);
  for (size_t i = 0; i < length; i++) {
    if (fun_table[i].primitive_type_ == arithmetic->primitive_type_ &&
        fun_table[i].activation_type_ == param->activation_type_) {
      arithmetic->functions_ = fun_table[i];
      return;
    }
  }
}
72
ArithmeticRelease(struct KernelBase * self)73 int ArithmeticRelease(struct KernelBase *self) {
74 ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
75 NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
76 for (int i = 0; i < TWO_TENSOR; i++) {
77 if (arithmetic->broadcast_buffer_[i] != NULL) {
78 self->env_->Free(self->env_->allocator_, arithmetic->broadcast_buffer_[i]);
79 arithmetic->broadcast_buffer_[i] = NULL;
80 }
81 }
82
83 for (int i = 0; i < arithmetic->block_boundary_infos_size_; i++) {
84 if (arithmetic->block_boundary_infos_[i].a_offset_ != NULL) {
85 self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].a_offset_);
86 arithmetic->block_boundary_infos_[i].a_offset_ = NULL;
87 }
88 if (arithmetic->block_boundary_infos_[i].b_offset_ != NULL) {
89 self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].b_offset_);
90 arithmetic->block_boundary_infos_[i].b_offset_ = NULL;
91 }
92 }
93 arithmetic->block_boundary_infos_size_ = 0;
94
95 if (arithmetic->a_matrix_.batch_post_sum_ != NULL) {
96 self->env_->Free(self->env_->allocator_, arithmetic->a_matrix_.batch_post_sum_);
97 arithmetic->a_matrix_.batch_post_sum_ = NULL;
98 }
99
100 if (arithmetic->b_matrix_.batch_post_sum_ != NULL) {
101 self->env_->Free(self->env_->allocator_, arithmetic->b_matrix_.batch_post_sum_);
102 arithmetic->b_matrix_.batch_post_sum_ = NULL;
103 }
104
105 if (arithmetic->c_matrix_.batch_post_sum_ != NULL) {
106 self->env_->Free(self->env_->allocator_, arithmetic->c_matrix_.batch_post_sum_);
107 arithmetic->c_matrix_.batch_post_sum_ = NULL;
108 }
109 return NNACL_OK;
110 }
111
/**
 * Fills the a_offset_/b_offset_ byte-offset tables of block @p task_id and marks them initialized.
 *
 * For every batch index the block touches (batch_begin_ .. batch_end_, plus batch_end_ itself when
 * the block ends inside a batch, i.e. size_end_ != 0), the batch index is decomposed into
 * per-dimension coordinates of c_matrix_ (dimensions 0 .. batch_tail_dim_) using
 * c_matrix_.batch_post_sum_. Each coordinate is mapped onto a and b through
 * coord * operand_shape / output_shape; with integer division, this yields 0 when the operand
 * dimension is 1 (assuming coord < output dimension). The results are then accumulated with each
 * operand's own batch_post_sum_. The stored value is that element index times inner_size_ times
 * in_data_size_.
 *
 * NOTE(review): the tables are int, so offsets larger than INT_MAX bytes would not fit — confirm
 * the supported tensor sizes.
 */
void ArithmeticComputeOffset(ArithmeticStruct *arithmetic, int task_id) {
  ArithmeticBlockBoundaryInfo *info = &arithmetic->block_boundary_infos_[task_id];
  const ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  const ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  const ArithmeticMatrixInfo *c = &arithmetic->c_matrix_;
  info->init_offset_ = true;

  /* A trailing partial batch needs one extra table entry. */
  int64_t stop_batch = info->batch_end_;
  if (info->size_end_ != 0) {
    stop_batch += 1;
  }

  int slot = 0;
  for (int64_t batch = info->batch_begin_; batch < stop_batch; ++batch, ++slot) {
    int64_t remainder = batch;
    int64_t a_index = 0;
    int64_t b_index = 0;
    for (int dim = 0; dim <= arithmetic->batch_tail_dim_; ++dim) {
      if (dim > 0) {
        remainder = remainder % c->batch_post_sum_[dim];
      }
      if (dim == arithmetic->batch_tail_dim_) {
        a_index += (remainder * a->shape_[dim] / c->shape_[dim]);
        b_index += (remainder * b->shape_[dim] / c->shape_[dim]);
        continue;
      }
      int64_t coord = remainder / c->batch_post_sum_[dim + 1];
      a_index += (coord * a->shape_[dim] / c->shape_[dim]) * a->batch_post_sum_[dim + 1];
      b_index += (coord * b->shape_[dim] / c->shape_[dim]) * b->batch_post_sum_[dim + 1];
    }
    info->a_offset_[slot] = a_index * a->inner_size_ * arithmetic->in_data_size_;
    info->b_offset_[slot] = b_index * b->inner_size_ * arithmetic->in_data_size_;
  }
}
148
/**
 * Runs the selected element-wise kernel over @p size elements.
 *
 * The data type of the first input picks the float32 / bool / int32 kernel family. When scalar_opt_
 * is set, the "Opt" kernel is used, and its last argument tells whether the first operand is the
 * scalar (in_elements_num0_ == 1). Otherwise, the plain same-shape kernel is used.
 *
 * @return the kernel's status; the NULL-pointer error code if an input pointer or the required
 *         kernel is NULL; NNACL_UNSUPPORTED_DATA_TYPE for any other data type.
 */
int ArithmeticDoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) {
  ArithmeticStruct *arithmetic = (ArithmeticStruct *)base;
  int data_type = arithmetic->base_.in_[FIRST_INPUT]->data_type_;
  NNACL_CHECK_NULL_RETURN_ERR(input0);
  NNACL_CHECK_NULL_RETURN_ERR(input1);

  ArithmeticFuncions *funcs = &arithmetic->functions_;
  bool first_is_scalar = arithmetic->in_elements_num0_ == 1;

  switch (data_type) {
    case kNumberTypeFloat32:
      if (!arithmetic->scalar_opt_) {
        NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_f32_);
        return funcs->compute_f32_((const float *)input0, (const float *)input1, (float *)output, size);
      }
      NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_f32_);
      return funcs->optimzie_f32_((const float *)input0, (const float *)input1, (float *)output, size,
                                  first_is_scalar);
    case kNumberTypeBool:
      if (!arithmetic->scalar_opt_) {
        NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_bool_);
        return funcs->compute_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size);
      }
      NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_bool_);
      return funcs->optimzie_bool_((const bool *)input0, (const bool *)input1, (bool *)output, size,
                                   first_is_scalar);
    case kNumberTypeInt32:
      if (!arithmetic->scalar_opt_) {
        NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_int_);
        return funcs->compute_int_((const int *)input0, (const int *)input1, (int *)output, size);
      }
      NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_int_);
      return funcs->optimzie_int_((const int *)input0, (const int *)input1, (int *)output, size, first_is_scalar);
    default:
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }
}
190
ArithmeticRun(void * cdata,int task_id,float l,float r)191 int ArithmeticRun(void *cdata, int task_id, float l, float r) {
192 ArithmeticStruct *arithmetic = (ArithmeticStruct *)cdata;
193 NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
194 NNACL_CHECK_FALSE(task_id >= arithmetic->block_boundary_infos_size_, NNACL_ERR);
195
196 if (arithmetic->block_boundary_infos_[task_id].init_offset_ == false) {
197 ArithmeticComputeOffset(arithmetic, task_id);
198 }
199
200 ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id];
201 int64_t b_start = block_info->batch_begin_;
202 int64_t s_start = block_info->size_begin_;
203 int64_t s_end = block_info->size_end_;
204 int64_t index_start = 0;
205 int64_t index_end = block_info->batch_end_ - b_start;
206 uint8_t *a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
207 uint8_t *b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
208 uint8_t *c_ptr = (uint8_t *)(arithmetic->c_matrix_.data_) +
209 (b_start * arithmetic->c_matrix_.inner_size_ + s_start) * arithmetic->out_data_size_;
210 if (arithmetic->a_matrix_.inner_size_ > 1) {
211 a_ptr += s_start * arithmetic->in_data_size_;
212 }
213 if (arithmetic->b_matrix_.inner_size_ > 1) {
214 b_ptr += s_start * arithmetic->in_data_size_;
215 }
216
217 if (index_start == index_end) {
218 return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end - s_start);
219 }
220
221 int64_t size = arithmetic->c_matrix_.inner_size_ - s_start;
222 int ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, size);
223 if (ret != NNACL_OK) {
224 return ret;
225 }
226
227 ++index_start;
228 c_ptr += size * arithmetic->out_data_size_;
229 int64_t c_stride = arithmetic->c_matrix_.inner_size_ * arithmetic->out_data_size_;
230 for (; index_start < index_end; ++index_start) {
231 a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
232 b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
233 ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, arithmetic->c_matrix_.inner_size_);
234 if (ret != NNACL_OK) {
235 return ret;
236 }
237 c_ptr += c_stride;
238 }
239 if (s_end == 0) {
240 return NNACL_OK;
241 }
242 a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
243 b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
244 return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end);
245 }
246
/**
 * Returns @p matrix to its pristine state.
 *
 * Frees batch_post_sum_ through the kernel's allocator (if set), clears data_ and is_valid_,
 * empties the shape and sets inner_size_ back to 1.
 */
void ResetArithmeticMatric(KernelBase *base, ArithmeticMatrixInfo *matrix) {
  if (matrix->batch_post_sum_ != NULL) {
    base->env_->Free(base->env_->allocator_, matrix->batch_post_sum_);
    matrix->batch_post_sum_ = NULL;
  }
  matrix->shape_size_ = 0;
  matrix->inner_size_ = 1;
  matrix->data_ = NULL;
  matrix->is_valid_ = false;
}
258
/**
 * Publishes the (already optimized) a/b shapes into the broadcast bookkeeping.
 *
 * Requires a and b to have the same rank. Sets ndim_ to that rank, resets c_matrix_, and for every
 * dimension copies the operand sizes into in_shape0_/in_shape1_. It then stores the element-wise
 * maximum into both out_shape_ and c_matrix_.shape_.
 *
 * @return NNACL_OK, or NNACL_ARITHMETIC_SHAPE_INVALID on a rank mismatch or a dimension above INT_MAX.
 */
int UpdateArithmeticParameter(ArithmeticStruct *arithmetic) {
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  ArithmeticMatrixInfo *c = &arithmetic->c_matrix_;
  NNACL_CHECK_TRUE_RET(a->shape_size_ == b->shape_size_, NNACL_ARITHMETIC_SHAPE_INVALID);

  arithmetic->ndim_ = a->shape_size_;
  ResetArithmeticMatric(&arithmetic->base_, c);

  size_t dim = 0;
  while (dim < arithmetic->ndim_) {
    NNACL_CHECK_TRUE_RET(a->shape_[dim] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);
    NNACL_CHECK_TRUE_RET(b->shape_[dim] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);
    arithmetic->in_shape0_[dim] = a->shape_[dim];
    arithmetic->in_shape1_[dim] = b->shape_[dim];
    arithmetic->out_shape_[dim] = MSMAX(arithmetic->in_shape0_[dim], arithmetic->in_shape1_[dim]);
    c->shape_[c->shape_size_++] = MSMAX(a->shape_[dim], b->shape_[dim]);
    ++dim;
  }
  return NNACL_OK;
}
277
/*
 * Horizontal pass: starting at every dimension where at least one operand is 1, collapses the run up
 * to the first non-1 dimension of whichever operand has the longer run of 1s. That operand is all 1s
 * over the run, so the run broadcasts as a single dimension (stored as the product of the run).
 * Returns the number of dimensions written to merged0/merged1.
 */
static int ArithmeticMergeUnitRuns(const int *shape0, const int *shape1, int ndim, int *merged0, int *merged1) {
  int count = 0;
  int i = 0;
  while (i < ndim) {
    merged0[count] = shape0[i];
    merged1[count] = shape1[i];
    count++;
    if (shape0[i] != 1 && shape1[i] != 1) {
      ++i;
      continue;
    }
    int j0 = i;
    while (j0 < ndim && shape0[j0] == 1) {
      ++j0;
    }
    int j1 = i;
    while (j1 < ndim && shape1[j1] == 1) {
      ++j1;
    }
    int run_end = MSMAX(j0, j1);
    for (++i; i < run_end; ++i) {
      merged0[count - 1] *= shape0[i];
      merged1[count - 1] *= shape1[i];
    }
  }
  return count;
}

/*
 * Vertical pass: drops dimensions that are 1 in both operands, and collapses consecutive dimensions
 * whose sizes are equal in both operands (no broadcasting happens across them) into their product.
 * Returns the number of dimensions written to out0/out1.
 */
static int ArithmeticMergeEqualRuns(const int *in0, const int *in1, int count, int *out0, int *out1) {
  int size = 0;
  int i = 0;
  while (i < count) {
    if (in0[i] == 1 && in1[i] == 1) {
      ++i;
      continue;
    }
    out0[size] = in0[i];
    out1[size] = in1[i];
    size++;
    if (in0[i] != in1[i]) {
      ++i;
      continue;
    }
    for (++i; i < count && in0[i] == in1[i]; ++i) {
      out0[size - 1] *= in0[i];
      out1[size - 1] *= in1[i];
    }
  }
  return size;
}

/**
 * Reduces the a/b operand shapes to the smallest equivalent broadcast problem.
 *
 * Both shapes are right-aligned to the larger rank (missing leading dimensions become 1), then
 * simplified by ArithmeticMergeUnitRuns and ArithmeticMergeEqualRuns. The result is written back
 * to a_matrix_/b_matrix_ (same rank for both) and published through UpdateArithmeticParameter.
 *
 * @return NNACL_ARITHMETIC_SHAPE_INVALID if the rank exceeds the MAX_LEN working buffers, otherwise
 *         the status of UpdateArithmeticParameter.
 */
int OptimizeArithmeticShape(ArithmeticStruct *arithmetic) {
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  arithmetic->ndim_ = a->shape_size_ >= b->shape_size_ ? a->shape_size_ : b->shape_size_;
  int ndim = (int)arithmetic->ndim_;
  /* Every working buffer below holds MAX_LEN ints on the stack; refuse ranks that would overflow them. */
  NNACL_CHECK_TRUE_RET(ndim <= MAX_LEN, NNACL_ARITHMETIC_SHAPE_INVALID);

  /* Right-align both shapes to ndim, padding the missing leading dimensions with 1. */
  int shape0[MAX_LEN] = {0};
  int shape1[MAX_LEN] = {0};
  int pad0 = ndim - (int)a->shape_size_;
  int pad1 = ndim - (int)b->shape_size_;
  for (int i = 0; i < ndim; ++i) {
    shape0[i] = i < pad0 ? 1 : a->shape_[i - pad0];
    shape1[i] = i < pad1 ? 1 : b->shape_[i - pad1];
  }

  int merged0[MAX_LEN] = {0};
  int merged1[MAX_LEN] = {0};
  int merged_size = ArithmeticMergeUnitRuns(shape0, shape1, ndim, merged0, merged1);
  int final_size = ArithmeticMergeEqualRuns(merged0, merged1, merged_size, shape0, shape1);

  a->shape_size_ = final_size;
  b->shape_size_ = final_size;
  memcpy(a->shape_, shape0, final_size * sizeof(int));
  memcpy(b->shape_, shape1, final_size * sizeof(int));

  return UpdateArithmeticParameter(arithmetic);
}
359
/**
 * Re-initializes the a/b/c matrices from the current input tensors and re-optimizes their shapes.
 *
 * All three matrices are reset (freeing their batch_post_sum_). The raw shapes of the first and second
 * input tensors are then copied into a_matrix_ and b_matrix_ before OptimizeArithmeticShape runs.
 *
 * @return status of OptimizeArithmeticShape.
 */
int ResetArithmeticStatus(ArithmeticStruct *arithmetic) {
  KernelBase *base = &arithmetic->base_;
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;

  ResetArithmeticMatric(base, a);
  ResetArithmeticMatric(base, b);
  ResetArithmeticMatric(base, &arithmetic->c_matrix_);

  a->shape_size_ = base->in_[FIRST_INPUT]->shape_size_;
  memcpy(a->shape_, base->in_[FIRST_INPUT]->shape_, a->shape_size_ * sizeof(int));

  b->shape_size_ = base->in_[SECOND_INPUT]->shape_size_;
  memcpy(b->shape_, base->in_[SECOND_INPUT]->shape_, b->shape_size_ * sizeof(int));

  return OptimizeArithmeticShape(arithmetic);
}
374
/**
 * Expands one input into an output-shaped buffer through arithmetic->tile_function_.
 *
 * The shape/stride/multiples arrays of the first input are used when @p input_index equals
 * FIRST_INPUT, and those of the second input otherwise. The tiling starts at dimension 0 and
 * covers ndim_ dimensions. By default tile_function_ is TileOneDimensionFp32 (set in
 * CreateArithmetic).
 *
 * The original body wrote `return <call>;` inside this void function. That is a constraint
 * violation in ISO C (C11 6.8.6.4), so the call is now a plain statement.
 */
void ArithmeticDoBroadcast(ArithmeticStruct *arithmetic, void *in_data, void *out_data, int input_index) {
  bool is_first = input_index == FIRST_INPUT;
  int *in_shape = is_first ? arithmetic->in_shape0_ : arithmetic->in_shape1_;
  int *in_stride = is_first ? arithmetic->in_strides0_ : arithmetic->in_strides1_;
  int *multiples = is_first ? arithmetic->multiples0_ : arithmetic->multiples1_;
  arithmetic->tile_function_(in_data, out_data, 0, arithmetic->ndim_, in_shape, in_stride, arithmetic->out_strides_,
                             multiples);
}
382
/*
 * Handles one constant input for ArithmeticBroadCastConstTensor. The input is skipped when it is not
 * constant, when it already has out_elements_num_ elements, or when explicit broadcasting is not
 * preferred. Otherwise its data is tiled into a freshly allocated output-sized buffer, which is
 * recorded in broadcast_buffer_ so ArithmeticRelease can free it. The input's shape, strides and
 * element count are then overwritten with the output's, the matrix is marked is_valid_, and
 * *broadcasted is set.
 */
static int ArithmeticBroadcastOneConstInput(ArithmeticStruct *arithmetic, int input_index, int buffer_size,
                                            bool prefer_explicit_broadcast, bool *broadcasted) {
  bool is_first = input_index == FIRST_INPUT;
  ArithmeticMatrixInfo *matrix = is_first ? &arithmetic->a_matrix_ : &arithmetic->b_matrix_;
  if (!matrix->is_const_) {
    return NNACL_OK;
  }
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic->base_.in_[input_index]->data_);

  bool already_full = is_first ? (arithmetic->in_elements_num0_ == arithmetic->out_elements_num_)
                               : (arithmetic->in_elements_num1_ == arithmetic->out_elements_num_);
  if (already_full || !prefer_explicit_broadcast) {
    return NNACL_OK;
  }
  *broadcasted = true;

  matrix->data_ = arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, buffer_size);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix->data_);
  arithmetic->broadcast_buffer_[input_index] = matrix->data_;

  ArithmeticDoBroadcast(arithmetic, arithmetic->base_.in_[input_index]->data_, matrix->data_, input_index);

  int *in_shape = is_first ? arithmetic->in_shape0_ : arithmetic->in_shape1_;
  int *in_strides = is_first ? arithmetic->in_strides0_ : arithmetic->in_strides1_;
  if (is_first) {
    arithmetic->in_elements_num0_ = arithmetic->out_elements_num_;
  } else {
    arithmetic->in_elements_num1_ = arithmetic->out_elements_num_;
  }
  /* After the copy, the operand has exactly the output's shape. */
  for (size_t i = 0; i < arithmetic->ndim_; ++i) {
    in_shape[i] = arithmetic->out_shape_[i];
    in_strides[i] = arithmetic->out_strides_[i];
  }
  memcpy(matrix->shape_, arithmetic->c_matrix_.shape_, arithmetic->ndim_ * sizeof(int));
  matrix->is_valid_ = true;
  return NNACL_OK;
}

/**
 * Optionally pre-broadcasts constant inputs once at resize time, so Compute works on
 * output-shaped data.
 *
 * Explicit broadcasting is never used under PARALLEL_INFERENCE or for bool inputs. Otherwise, it is
 * used when ndim_ != 1. Inputs are handled first-then-second. If either input was expanded, the
 * shapes are re-optimized.
 *
 * @return NNACL_OK, a NULL/allocation error, or the status of OptimizeArithmeticShape.
 */
int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) {
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);

  CalcStructMultiplesAndStrides(arithmetic);

#ifdef PARALLEL_INFERENCE
  bool prefer_explicit_broadcast = false;
#else
  bool prefer_explicit_broadcast = arithmetic->ndim_ != 1;
#endif
  if (arithmetic->base_.in_[FIRST_INPUT]->data_type_ == kNumberTypeBool) {
    prefer_explicit_broadcast = false;
  }

  bool broadcasted = false;
  int buffer_size = GetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]) * arithmetic->in_data_size_;

  int ret =
    ArithmeticBroadcastOneConstInput(arithmetic, FIRST_INPUT, buffer_size, prefer_explicit_broadcast, &broadcasted);
  if (ret != NNACL_OK) {
    return ret;
  }
  ret = ArithmeticBroadcastOneConstInput(arithmetic, SECOND_INPUT, buffer_size, prefer_explicit_broadcast,
                                         &broadcasted);
  if (ret != NNACL_OK) {
    return ret;
  }

  return broadcasted ? OptimizeArithmeticShape(arithmetic) : NNACL_OK;
}
446
ArithmeticComputeOfflineInfo(ArithmeticStruct * arithmetic)447 int ArithmeticComputeOfflineInfo(ArithmeticStruct *arithmetic) {
448 int bread_pos = -1;
449 int last_dim = arithmetic->a_matrix_.shape_size_ - 1;
450 for (int i = last_dim; i >= 0; --i) {
451 if (arithmetic->a_matrix_.shape_[i] != arithmetic->b_matrix_.shape_[i]) {
452 bread_pos = i;
453 break;
454 }
455 }
456 arithmetic->batch_tail_dim_ = bread_pos;
457 if (bread_pos == last_dim && arithmetic->batch_tail_dim_ >= 0) {
458 --arithmetic->batch_tail_dim_;
459 }
460
461 for (int i = last_dim; i > arithmetic->batch_tail_dim_; --i) {
462 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->a_matrix_.inner_size_, arithmetic->a_matrix_.shape_[i], NNACL_ERR);
463 arithmetic->a_matrix_.inner_size_ *= arithmetic->a_matrix_.shape_[i];
464 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->b_matrix_.inner_size_, arithmetic->b_matrix_.shape_[i], NNACL_ERR);
465 arithmetic->b_matrix_.inner_size_ *= arithmetic->b_matrix_.shape_[i];
466 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(arithmetic->c_matrix_.inner_size_, arithmetic->c_matrix_.shape_[i], NNACL_ERR);
467 arithmetic->c_matrix_.inner_size_ *= arithmetic->c_matrix_.shape_[i];
468 }
469
470 arithmetic->a_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
471 arithmetic->base_.env_->allocator_, (arithmetic->a_matrix_.shape_size_ + 1) * sizeof(int));
472 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.batch_post_sum_);
473 for (int i = 0; i < arithmetic->a_matrix_.shape_size_ + 1; i++) {
474 arithmetic->a_matrix_.batch_post_sum_[i] = 1;
475 }
476
477 arithmetic->b_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
478 arithmetic->base_.env_->allocator_, (arithmetic->b_matrix_.shape_size_ + 1) * sizeof(int));
479 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.batch_post_sum_);
480 for (int i = 0; i < arithmetic->b_matrix_.shape_size_ + 1; i++) {
481 arithmetic->b_matrix_.batch_post_sum_[i] = 1;
482 }
483
484 arithmetic->c_matrix_.batch_post_sum_ = arithmetic->base_.env_->Alloc(
485 arithmetic->base_.env_->allocator_, (arithmetic->c_matrix_.shape_size_ + 1) * sizeof(int));
486 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.batch_post_sum_);
487 for (int i = 0; i < arithmetic->c_matrix_.shape_size_ + 1; i++) {
488 arithmetic->c_matrix_.batch_post_sum_[i] = 1;
489 }
490
491 for (int i = arithmetic->batch_tail_dim_; i >= 0; --i) {
492 if (i == arithmetic->batch_tail_dim_) {
493 arithmetic->a_matrix_.batch_post_sum_[i] = arithmetic->a_matrix_.shape_[i];
494 arithmetic->b_matrix_.batch_post_sum_[i] = arithmetic->b_matrix_.shape_[i];
495 arithmetic->c_matrix_.batch_post_sum_[i] = arithmetic->c_matrix_.shape_[i];
496 } else {
497 arithmetic->a_matrix_.batch_post_sum_[i] =
498 arithmetic->a_matrix_.shape_[i] * arithmetic->a_matrix_.batch_post_sum_[i + 1];
499 arithmetic->b_matrix_.batch_post_sum_[i] =
500 arithmetic->b_matrix_.shape_[i] * arithmetic->b_matrix_.batch_post_sum_[i + 1];
501 arithmetic->c_matrix_.batch_post_sum_[i] =
502 arithmetic->c_matrix_.shape_[i] * arithmetic->c_matrix_.batch_post_sum_[i + 1];
503 }
504 }
505
506 arithmetic->scalar_opt_ = false;
507 if (arithmetic->a_matrix_.inner_size_ == 1) {
508 arithmetic->in_elements_num0_ = 1;
509 arithmetic->scalar_opt_ = true;
510 }
511 if (arithmetic->b_matrix_.inner_size_ == 1) {
512 arithmetic->in_elements_num1_ = 1;
513 arithmetic->scalar_opt_ = true;
514 }
515 return NNACL_OK;
516 }
517
/**
 * Splits the output's element range into contiguous blocks, one per parallel task.
 *
 * The thread count is first refined through base_.UpdateThread(). Each block then covers
 * UP_DIV(total, thread_nr_) consecutive output elements; the last block may be shorter. A block is
 * described in c_matrix_ coordinates: batch_begin_/batch_end_ index batches of
 * c_matrix_.inner_size_ elements, and size_begin_/size_end_ are offsets inside those batches.
 * Every block owns a_offset_/b_offset_ tables, which ArithmeticComputeOffset() fills lazily.
 *
 * Fix: if allocating b_offset_ failed, the a_offset_ that had just been allocated leaked. The block
 * was not yet stored in block_boundary_infos_, so ArithmeticRelease() could never free it.
 *
 * On success, base_.thread_nr_ equals the number of blocks created.
 * @return NNACL_OK, or the allocation-failure code of NNACL_MALLOC_CHECK_NULL_RETURN_ERR.
 */
int ArithmeticChooseThreadCuttingStrategy(ArithmeticStruct *arithmetic) {
  int total_num = GetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]);
  arithmetic->base_.thread_nr_ =
    arithmetic->base_.UpdateThread(TC_TYPE(arithmetic->primitive_type_, arithmetic->functions_.activation_type_), 1, 1,
                                   total_num, arithmetic->base_.thread_nr_);

  int64_t block_size = UP_DIV(total_num, arithmetic->base_.thread_nr_);
  int64_t split_point = 0;
  while (split_point < total_num) {
    int64_t start = split_point;
    int64_t end = start + block_size;
    if (end > total_num) {
      end = total_num;
    }

    ArithmeticBlockBoundaryInfo block_boundary_info = {0};
    block_boundary_info.size_begin_ = start % arithmetic->c_matrix_.inner_size_;
    block_boundary_info.size_end_ = end % arithmetic->c_matrix_.inner_size_;
    block_boundary_info.batch_begin_ = start / arithmetic->c_matrix_.inner_size_;
    block_boundary_info.batch_end_ = end / arithmetic->c_matrix_.inner_size_;
    block_boundary_info.init_offset_ = false;

    /* ArithmeticComputeOffset() writes at most batch_end_ - batch_begin_ + 1 entries (the +1 is for a
     * trailing partial batch); TWO_TENSOR (presumably 2) keeps at least that much room. */
    int max_offset_size = block_boundary_info.batch_end_ - block_boundary_info.batch_begin_ + TWO_TENSOR;
    block_boundary_info.a_offset_ =
      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.a_offset_);
    block_boundary_info.b_offset_ =
      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
    if (block_boundary_info.b_offset_ == NULL) {
      /* Not registered yet, so nobody else would ever release a_offset_. */
      arithmetic->base_.env_->Free(arithmetic->base_.env_->allocator_, block_boundary_info.a_offset_);
      block_boundary_info.a_offset_ = NULL;
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.b_offset_);

    arithmetic->block_boundary_infos_[arithmetic->block_boundary_infos_size_++] = block_boundary_info;
    split_point = end;
  }

  arithmetic->base_.thread_nr_ = arithmetic->block_boundary_infos_size_;
  return NNACL_OK;
}
554
ArithmeticResize(struct KernelBase * self)555 int ArithmeticResize(struct KernelBase *self) {
556 ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
557 NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
558
559 ArithmeticRelease(&arithmetic->base_);
560
561 NNACL_CHECK_TRUE_RET(arithmetic->in_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
562 NNACL_CHECK_TRUE_RET(arithmetic->out_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
563 arithmetic->in_elements_num0_ = GetElementNum(self->in_[FIRST_INPUT]);
564 arithmetic->in_elements_num1_ = GetElementNum(self->in_[SECOND_INPUT]);
565 arithmetic->out_elements_num_ = GetElementNum(self->in_[OUTPUT_INDEX]);
566
567 int ret = ResetArithmeticStatus(arithmetic);
568 if (ret != NNACL_OK) {
569 return ret;
570 }
571
572 ret = ArithmeticBroadCastConstTensor(arithmetic);
573 if (ret != NNACL_OK) {
574 return ret;
575 }
576
577 ret = ArithmeticComputeOfflineInfo(arithmetic);
578 if (ret != NNACL_OK) {
579 return ret;
580 }
581
582 return ArithmeticChooseThreadCuttingStrategy(arithmetic);
583 }
584
ArithmeticPrepare(struct KernelBase * self)585 int ArithmeticPrepare(struct KernelBase *self) {
586 ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
587 NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
588
589 NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR);
590 NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
591
592 NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
593 NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
594 NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
595 NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
596
597 if (self->param_->quant_type_ != Quant_None) {
598 return NNACL_ERR;
599 }
600
601 arithmetic->primitive_type_ = self->param_->type_;
602 if (self->param_->type_ == PrimType_Eltwise) {
603 switch (((ArithmeticParameter *)(self->param_))->eltwise_mode_) {
604 case Eltwise_PROD:
605 arithmetic->primitive_type_ = PrimType_MulFusion;
606 break;
607 case Eltwise_SUM:
608 arithmetic->primitive_type_ = PrimType_AddFusion;
609 break;
610 case Eltwise_MAXIMUM:
611 arithmetic->primitive_type_ = PrimType_Maximum;
612 break;
613 default:
614 return NNACL_ELTWISE_INVALID_MOD;
615 }
616 }
617 arithmetic->init_function_(self);
618
619 arithmetic->a_matrix_.is_const_ = IsConst(self->in_[FIRST_INPUT]);
620 arithmetic->b_matrix_.is_const_ = IsConst(self->in_[SECOND_INPUT]);
621 return NNACL_OK;
622 }
623
ArithmeticCompute(struct KernelBase * self)624 int ArithmeticCompute(struct KernelBase *self) {
625 ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
626 NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
627 NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != self->in_[SECOND_INPUT]->data_type_,
628 NNACL_ARITHMETIC_DATA_TYPE_UNMATCH);
629
630 if (self->train_session_) {
631 arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
632 }
633
634 if (false == arithmetic->a_matrix_.is_valid_) {
635 arithmetic->a_matrix_.data_ = self->in_[FIRST_INPUT]->data_;
636 }
637 NNACL_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_);
638
639 if (false == arithmetic->b_matrix_.is_valid_) {
640 arithmetic->b_matrix_.data_ = self->in_[SECOND_INPUT]->data_;
641 }
642 NNACL_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_);
643
644 arithmetic->c_matrix_.data_ = self->out_[OUTPUT_INDEX]->data_;
645 NNACL_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.data_);
646
647 return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticRun, self, self->thread_nr_);
648 }
649
CreateArithmetic(OpParameter * param,int data_type)650 KernelBase *CreateArithmetic(OpParameter *param, int data_type) {
651 ArithmeticStruct *arithmetic = (ArithmeticStruct *)malloc(sizeof(ArithmeticStruct));
652 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic);
653 memset(arithmetic, 0, sizeof(ArithmeticStruct));
654 arithmetic->in_data_size_ = DataTypeCSize(data_type);
655 arithmetic->out_data_size_ = DataTypeCSize(data_type);
656 arithmetic->block_boundary_infos_size_ = 0;
657 arithmetic->a_matrix_.batch_post_sum_ = NULL;
658 arithmetic->b_matrix_.batch_post_sum_ = NULL;
659 arithmetic->c_matrix_.batch_post_sum_ = NULL;
660 arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
661 arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
662 arithmetic->tile_function_ = TileOneDimensionFp32;
663 arithmetic->init_function_ = InitArithmeticRunFunction;
664 arithmetic->execute_ = ArithmeticDoExecute;
665 arithmetic->base_.Prepare = ArithmeticPrepare;
666 arithmetic->base_.Resize = ArithmeticResize;
667 arithmetic->base_.Release = ArithmeticRelease;
668 arithmetic->base_.Compute = ArithmeticCompute;
669 return (KernelBase *)arithmetic;
670 }
671
672 REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat32, CreateArithmetic)
673 REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeInt32, CreateArithmetic)
674 REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeBool, CreateArithmetic)
675 REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat32, CreateArithmetic)
676 REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeInt32, CreateArithmetic)
677 REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat32, CreateArithmetic)
678 REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeInt32, CreateArithmetic)
679 REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat32, CreateArithmetic)
680 REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat32, CreateArithmetic)
681 REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeFloat32, CreateArithmetic)
682 REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeInt32, CreateArithmetic)
683 REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat32, CreateArithmetic)
684 REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeBool, CreateArithmetic)
685 REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeInt32, CreateArithmetic)
686 REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat32, CreateArithmetic)
687 REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeBool, CreateArithmetic)
688 REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat32, CreateArithmetic)
689 REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat32, CreateArithmetic)
690 REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeInt32, CreateArithmetic)
691 REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeInt32, CreateArithmetic)
692 REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat32, CreateArithmetic)
693 REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat32, CreateArithmetic)
694 REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeInt32, CreateArithmetic)
695 REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeInt32, CreateArithmetic)
696 REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat32, CreateArithmetic)
697 REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat32, CreateArithmetic)
698 REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeInt32, CreateArithmetic)
699