• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/arithmetic.h"
18 #include "nnacl/op_base.h"
19 #include "nnacl/nnacl_common.h"
20 #include "nnacl/fp32/arithmetic_fp32.h"
21 #include "nnacl/fp32/mul_fp32.h"
22 #include "nnacl/tensor_c_utils.h"
23 #ifdef ENABLE_FP16
24 #include "nnacl/fp16/arithmetic_fp16.h"
25 #endif
26 
/*
 * Selects the element-wise function set for this kernel.
 *
 * Walks a local table keyed by (primitive type, activation type) and copies the first entry whose key
 * equals arithmetic->primitive_type_ and the activation type stored in the kernel's ArithmeticParameter
 * into arithmetic->functions_. When nothing matches, functions_ is left untouched.
 */
void InitArithmeticRunFunction(KernelBase *self) {
  ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
  ArithmeticParameter *param = (ArithmeticParameter *)(arithmetic->base_.param_);

  // Each row: primitive type, activation type, then six function slots.
  // NOTE(review): slot order presumably is f32/int/bool compute followed by f32/int/bool scalar-optimized
  // variants - confirm against the ArithmeticFuncions declaration.
  ArithmeticFuncions fun_table[] = {
    {PrimType_MulFusion, ActType_Relu, ElementMulRelu, ElementMulReluInt, NULL, ElementOptMulRelu, ElementOptMulReluInt,
     NULL},
    {PrimType_MulFusion, ActType_Relu6, ElementMulRelu6, ElementMulRelu6Int, NULL, ElementOptMulRelu6,
     ElementOptMulRelu6Int, NULL},
    {PrimType_MulFusion, ActType_No, ElementMul, ElementMulInt, NULL, ElementOptMul, ElementOptMulInt, NULL},
    {PrimType_AddFusion, ActType_Relu, ElementAddRelu, NULL, NULL, ElementOptAddRelu, NULL, NULL},
    {PrimType_AddFusion, ActType_Relu6, ElementAddRelu6, NULL, NULL, ElementOptAddRelu6, NULL, NULL},
    {PrimType_AddFusion, ActType_No, ElementAdd, ElementAddInt, NULL, ElementOptAdd, ElementOptAddInt, NULL},
    {PrimType_SubFusion, ActType_Relu, ElementSubRelu, NULL, NULL, ElementOptSubRelu, NULL, NULL},
    {PrimType_SubFusion, ActType_Relu6, ElementSubRelu6, NULL, NULL, ElementOptSubRelu6, NULL, NULL},
    {PrimType_SubFusion, ActType_No, ElementSub, ElementSubInt, NULL, ElementOptSub, ElementOptSubInt, NULL},
    {PrimType_DivFusion, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL},
    {PrimType_DivFusion, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL},
    {PrimType_DivFusion, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL},
    {PrimType_RealDiv, ActType_Relu, ElementDivRelu, NULL, NULL, ElementOptDivRelu, NULL, NULL},
    {PrimType_RealDiv, ActType_Relu6, ElementDivRelu6, NULL, NULL, ElementOptDivRelu6, NULL, NULL},
    {PrimType_RealDiv, ActType_No, ElementDiv, NULL, NULL, ElementOptDiv, ElementOptDivInt, NULL},
    {PrimType_LogicalAnd, ActType_No, ElementLogicalAnd, ElementLogicalAndInt, ElementLogicalAndBool,
     ElementOptLogicalAnd, ElementOptLogicalAndInt, ElementOptLogicalAndBool},
    {PrimType_LogicalOr, ActType_No, ElementLogicalOr, NULL, ElementLogicalOrBool, NULL, NULL, ElementOptLogicalOrBool},
    {PrimType_Maximum, ActType_No, ElementMaximum, ElementMaximumInt, NULL, ElementOptMaximum, ElementOptMaximumInt,
     NULL},
    {PrimType_Minimum, ActType_No, ElementMinimum, ElementMinimumInt, NULL, ElementOptMinimum, ElementOptMinimumInt,
     NULL},
    {PrimType_FloorMod, ActType_No, ElementFloorMod, ElementFloorModInt, NULL, ElementOptFloorMod,
     ElementOptFloorModInt, NULL},
    {PrimType_FloorDiv, ActType_No, ElementFloorDiv, ElementFloorDivInt, NULL, ElementOptFloorDiv,
     ElementOptFloorDivInt, NULL},
    {PrimType_Mod, ActType_No, ElementMod, ElementModInt, NULL, ElementOptMod, ElementOptModInt, NULL},
    {PrimType_SquaredDifference, ActType_No, ElementSquaredDifference, NULL, NULL, ElementOptSquaredDifference, NULL,
     NULL}};

  const size_t entry_count = sizeof(fun_table) / sizeof(fun_table[0]);
  for (size_t idx = 0; idx < entry_count; idx++) {
    if (fun_table[idx].primitive_type_ != arithmetic->primitive_type_) {
      continue;
    }
    if (fun_table[idx].activation_type_ != param->activation_type_) {
      continue;
    }
    // First match wins.
    arithmetic->functions_ = fun_table[idx];
    return;
  }
}
72 
ArithmeticRelease(struct KernelBase * self)73 int ArithmeticRelease(struct KernelBase *self) {
74   ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
75   NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
76   for (int i = 0; i < TWO_TENSOR; i++) {
77     if (arithmetic->broadcast_buffer_[i] != NULL) {
78       self->env_->Free(self->env_->allocator_, arithmetic->broadcast_buffer_[i]);
79       arithmetic->broadcast_buffer_[i] = NULL;
80     }
81   }
82 
83   for (int i = 0; i < arithmetic->block_boundary_infos_size_; i++) {
84     if (arithmetic->block_boundary_infos_[i].a_offset_ != NULL) {
85       self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].a_offset_);
86       arithmetic->block_boundary_infos_[i].a_offset_ = NULL;
87     }
88     if (arithmetic->block_boundary_infos_[i].b_offset_ != NULL) {
89       self->env_->Free(self->env_->allocator_, arithmetic->block_boundary_infos_[i].b_offset_);
90       arithmetic->block_boundary_infos_[i].b_offset_ = NULL;
91     }
92   }
93   arithmetic->block_boundary_infos_size_ = 0;
94 
95   if (arithmetic->a_matrix_.batch_post_sum_ != NULL) {
96     self->env_->Free(self->env_->allocator_, arithmetic->a_matrix_.batch_post_sum_);
97     arithmetic->a_matrix_.batch_post_sum_ = NULL;
98   }
99 
100   if (arithmetic->b_matrix_.batch_post_sum_ != NULL) {
101     self->env_->Free(self->env_->allocator_, arithmetic->b_matrix_.batch_post_sum_);
102     arithmetic->b_matrix_.batch_post_sum_ = NULL;
103   }
104 
105   if (arithmetic->c_matrix_.batch_post_sum_ != NULL) {
106     self->env_->Free(self->env_->allocator_, arithmetic->c_matrix_.batch_post_sum_);
107     arithmetic->c_matrix_.batch_post_sum_ = NULL;
108   }
109   return NNACL_OK;
110 }
111 
/*
 * Fills the a_offset_ / b_offset_ arrays of block `task_id` and marks the block as initialized.
 *
 * One entry is produced per output batch index in [batch_begin_, batch_end_), plus one more when
 * size_end_ is non-zero. Each entry is the matching batch position in input a / b, scaled by that
 * input's inner_size_ and by in_data_size_; ArithmeticRun adds it to the input's uint8_t data pointer.
 */
void ArithmeticComputeOffset(ArithmeticStruct *arithmetic, int task_id) {
  ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id];
  block_info->init_offset_ = true;

  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  ArithmeticMatrixInfo *c = &arithmetic->c_matrix_;

  int64_t batch_stop = block_info->batch_end_;
  if (block_info->size_end_ != 0) {
    // A trailing partial batch also needs an offset entry.
    ++batch_stop;
  }

  int slot = 0;
  for (int64_t batch = block_info->batch_begin_; batch < batch_stop; ++batch, ++slot) {
    int64_t remainder = batch;
    int64_t a_batch = 0;
    int64_t b_batch = 0;
    for (int dim = 0; dim <= arithmetic->batch_tail_dim_; ++dim) {
      if (dim > 0) {
        remainder = remainder % c->batch_post_sum_[dim];
      }
      if (dim == arithmetic->batch_tail_dim_) {
        // Last batch dimension: the remainder is the coordinate itself.
        a_batch += (remainder * a->shape_[dim] / c->shape_[dim]);
        b_batch += (remainder * b->shape_[dim] / c->shape_[dim]);
        continue;
      }
      // Output coordinate along `dim`, mapped onto each input's extent and stride.
      int64_t coord = remainder / c->batch_post_sum_[dim + 1];
      a_batch += (coord * a->shape_[dim] / c->shape_[dim]) * a->batch_post_sum_[dim + 1];
      b_batch += (coord * b->shape_[dim] / c->shape_[dim]) * b->batch_post_sum_[dim + 1];
    }
    block_info->a_offset_[slot] = a_batch * a->inner_size_ * arithmetic->in_data_size_;
    block_info->b_offset_[slot] = b_batch * b->inner_size_ * arithmetic->in_data_size_;
  }
}
148 
/*
 * Runs the selected element-wise function on `size` elements.
 *
 * The data type of the kernel's first input chooses the float32, bool or int32 function family.
 * When arithmetic->scalar_opt_ is set, the "optimzie_*" variant is called with an extra flag that is
 * true when in_elements_num0_ == 1; otherwise the plain "compute_*" variant is used.
 *
 * Returns the callee's result, NNACL_UNSUPPORTED_DATA_TYPE for any other data type, or the error
 * produced by NNACL_CHECK_NULL_RETURN_ERR when an input pointer or the required function is NULL.
 */
int ArithmeticDoExecute(KernelBase *base, const void *input0, const void *input1, void *output, int64_t size) {
  ArithmeticStruct *arithmetic = (ArithmeticStruct *)base;
  int data_type = arithmetic->base_.in_[FIRST_INPUT]->data_type_;
  NNACL_CHECK_NULL_RETURN_ERR(input0);
  NNACL_CHECK_NULL_RETURN_ERR(input1);

  ArithmeticFuncions *funcs = &arithmetic->functions_;

  if (data_type == kNumberTypeFloat32) {
    const float *lhs = (const float *)input0;
    const float *rhs = (const float *)input1;
    float *dst = (float *)output;
    if (!arithmetic->scalar_opt_) {
      NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_f32_);
      return funcs->compute_f32_(lhs, rhs, dst, size);
    }
    NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_f32_);
    return funcs->optimzie_f32_(lhs, rhs, dst, size, arithmetic->in_elements_num0_ == 1);
  }

  if (data_type == kNumberTypeBool) {
    const bool *lhs = (const bool *)input0;
    const bool *rhs = (const bool *)input1;
    bool *dst = (bool *)output;
    if (!arithmetic->scalar_opt_) {
      NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_bool_);
      return funcs->compute_bool_(lhs, rhs, dst, size);
    }
    NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_bool_);
    return funcs->optimzie_bool_(lhs, rhs, dst, size, arithmetic->in_elements_num0_ == 1);
  }

  if (data_type == kNumberTypeInt32) {
    const int *lhs = (const int *)input0;
    const int *rhs = (const int *)input1;
    int *dst = (int *)output;
    if (!arithmetic->scalar_opt_) {
      NNACL_CHECK_NULL_RETURN_ERR(funcs->compute_int_);
      return funcs->compute_int_(lhs, rhs, dst, size);
    }
    NNACL_CHECK_NULL_RETURN_ERR(funcs->optimzie_int_);
    return funcs->optimzie_int_(lhs, rhs, dst, size, arithmetic->in_elements_num0_ == 1);
  }

  return NNACL_UNSUPPORTED_DATA_TYPE;
}
190 
ArithmeticRun(void * cdata,int task_id,float l,float r)191 int ArithmeticRun(void *cdata, int task_id, float l, float r) {
192   ArithmeticStruct *arithmetic = (ArithmeticStruct *)cdata;
193   NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
194   NNACL_CHECK_FALSE(task_id >= arithmetic->block_boundary_infos_size_, NNACL_ERR);
195 
196   if (arithmetic->block_boundary_infos_[task_id].init_offset_ == false) {
197     ArithmeticComputeOffset(arithmetic, task_id);
198   }
199 
200   ArithmeticBlockBoundaryInfo *block_info = &arithmetic->block_boundary_infos_[task_id];
201   int64_t b_start = block_info->batch_begin_;
202   int64_t s_start = block_info->size_begin_;
203   int64_t s_end = block_info->size_end_;
204   int64_t index_start = 0;
205   int64_t index_end = block_info->batch_end_ - b_start;
206   uint8_t *a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
207   uint8_t *b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
208   uint8_t *c_ptr = (uint8_t *)(arithmetic->c_matrix_.data_) +
209                    (b_start * arithmetic->c_matrix_.inner_size_ + s_start) * arithmetic->out_data_size_;
210   if (arithmetic->a_matrix_.inner_size_ > 1) {
211     a_ptr += s_start * arithmetic->in_data_size_;
212   }
213   if (arithmetic->b_matrix_.inner_size_ > 1) {
214     b_ptr += s_start * arithmetic->in_data_size_;
215   }
216 
217   if (index_start == index_end) {
218     return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end - s_start);
219   }
220 
221   int64_t size = arithmetic->c_matrix_.inner_size_ - s_start;
222   int ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, size);
223   if (ret != NNACL_OK) {
224     return ret;
225   }
226 
227   ++index_start;
228   c_ptr += size * arithmetic->out_data_size_;
229   int64_t c_stride = arithmetic->c_matrix_.inner_size_ * arithmetic->out_data_size_;
230   for (; index_start < index_end; ++index_start) {
231     a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
232     b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
233     ret = arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, arithmetic->c_matrix_.inner_size_);
234     if (ret != NNACL_OK) {
235       return ret;
236     }
237     c_ptr += c_stride;
238   }
239   if (s_end == 0) {
240     return NNACL_OK;
241   }
242   a_ptr = (uint8_t *)(arithmetic->a_matrix_.data_) + block_info->a_offset_[index_start];
243   b_ptr = (uint8_t *)(arithmetic->b_matrix_.data_) + block_info->b_offset_[index_start];
244   return arithmetic->execute_((KernelBase *)arithmetic, a_ptr, b_ptr, c_ptr, s_end);
245 }
246 
/*
 * Returns `matrix` to its empty state: frees batch_post_sum_ through base->env_ when set,
 * clears data_ and is_valid_, sets shape_size_ to 0 and inner_size_ to 1.
 */
void ResetArithmeticMatric(KernelBase *base, ArithmeticMatrixInfo *matrix) {
  if (matrix->batch_post_sum_ != NULL) {
    base->env_->Free(base->env_->allocator_, matrix->batch_post_sum_);
    matrix->batch_post_sum_ = NULL;
  }

  matrix->shape_size_ = 0;
  matrix->inner_size_ = 1;
  matrix->data_ = NULL;
  matrix->is_valid_ = false;
}
258 
/*
 * Derives ndim_, in_shape0_, in_shape1_, out_shape_ and the c matrix shape from the a / b matrix shapes.
 *
 * The a and b matrices must already have the same number of dimensions. The c matrix is reset first;
 * each of its dimensions, like out_shape_, is the larger of the two input dimensions.
 *
 * Returns NNACL_OK, or NNACL_ARITHMETIC_SHAPE_INVALID when the dimension counts differ or a dimension
 * exceeds INT_MAX.
 */
int UpdateArithmeticParameter(ArithmeticStruct *arithmetic) {
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  ArithmeticMatrixInfo *c = &arithmetic->c_matrix_;

  NNACL_CHECK_TRUE_RET(a->shape_size_ == b->shape_size_, NNACL_ARITHMETIC_SHAPE_INVALID);

  arithmetic->ndim_ = a->shape_size_;
  ResetArithmeticMatric(&arithmetic->base_, c);

  for (size_t dim = 0; dim < arithmetic->ndim_; ++dim) {
    NNACL_CHECK_TRUE_RET(a->shape_[dim] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);
    NNACL_CHECK_TRUE_RET(b->shape_[dim] <= INT_MAX, NNACL_ARITHMETIC_SHAPE_INVALID);

    arithmetic->in_shape0_[dim] = a->shape_[dim];
    arithmetic->in_shape1_[dim] = b->shape_[dim];
    arithmetic->out_shape_[dim] = MSMAX(arithmetic->in_shape0_[dim], arithmetic->in_shape1_[dim]);

    // Append to the freshly reset c shape.
    c->shape_[c->shape_size_] = MSMAX(a->shape_[dim], b->shape_[dim]);
    c->shape_size_++;
  }
  return NNACL_OK;
}
277 
/*
 * Collapses the a / b matrix shapes into an equivalent pair with fewer dimensions.
 *
 * Steps:
 *   1. Left-pad the shorter shape with 1s so both have ndim_ entries.
 *   2. Horizontal pass: where either shape has a 1, fold the following run of dimensions (as long as
 *      the longer run of 1s in either shape) into a single entry per shape.
 *   3. Vertical pass: drop entries that are 1 in both shapes and fold consecutive entries that are
 *      equal between the two shapes into a single entry.
 * The collapsed shapes are written back to a_matrix_ / b_matrix_, then UpdateArithmeticParameter is
 * called and its result returned.
 */
int OptimizeArithmeticShape(ArithmeticStruct *arithmetic) {
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  if (a->shape_size_ >= b->shape_size_) {
    arithmetic->ndim_ = a->shape_size_;
  } else {
    arithmetic->ndim_ = b->shape_size_;
  }

  /* step 1: pad both shapes with leading 1s up to ndim_ */
  int padded0[MAX_LEN] = {0};
  int padded1[MAX_LEN] = {0};
  int a_pad = arithmetic->ndim_ - a->shape_size_;
  int b_pad = arithmetic->ndim_ - b->shape_size_;
  int i = 0;
  for (; i < arithmetic->ndim_; ++i) {
    padded0[i] = i < a_pad ? 1 : a->shape_[i - a_pad];
    padded1[i] = i < b_pad ? 1 : b->shape_[i - b_pad];
  }

  /* step 2: horizontal comparison, merge the part of continuous 1 */
  int merged0[MAX_LEN] = {0};
  int merged1[MAX_LEN] = {0};
  int merged_size = 0;
  i = 0;
  while (i < arithmetic->ndim_) {
    merged0[merged_size] = padded0[i];
    merged1[merged_size] = padded1[i];
    merged_size++;
    if (padded0[i] != 1 && padded1[i] != 1) {
      ++i;
      continue;
    }

    // Find where the run of 1s starting at i ends in each shape.
    size_t run0 = i;
    while (run0 < arithmetic->ndim_ && padded0[run0] == 1) {
      ++run0;
    }
    size_t run1 = i;
    while (run1 < arithmetic->ndim_ && padded1[run1] == 1) {
      ++run1;
    }
    size_t run_end = MSMAX(run0, run1);
    // Always advances i by at least one, folding every further dimension before run_end.
    for (++i; i < run_end; ++i) {
      merged0[merged_size - 1] *= padded0[i];
      merged1[merged_size - 1] *= padded1[i];
    }
  }

  /* step 3: vertical comparison, merge the part of continuous equation */
  a->shape_size_ = 0;
  b->shape_size_ = 0;
  i = 0;
  while (i < merged_size) {
    if (merged0[i] == 1 && merged1[i] == 1) {
      ++i;
      continue;
    }
    // padded0 / padded1 are reused as the output buffers from here on.
    padded0[a->shape_size_++] = merged0[i];
    padded1[b->shape_size_++] = merged1[i];
    if (merged0[i] != merged1[i]) {
      ++i;
      continue;
    }
    for (++i; i < merged_size && merged0[i] == merged1[i]; ++i) {
      padded0[a->shape_size_ - 1] *= merged0[i];
      padded1[b->shape_size_ - 1] *= merged1[i];
    }
  }

  memcpy(a->shape_, padded0, a->shape_size_ * sizeof(int));
  memcpy(b->shape_, padded1, b->shape_size_ * sizeof(int));

  return UpdateArithmeticParameter(arithmetic);
}
359 
/*
 * Resets the a, b and c matrices, reloads the a / b shapes from the first and second input tensors,
 * and runs OptimizeArithmeticShape on them. Returns that function's result.
 */
int ResetArithmeticStatus(ArithmeticStruct *arithmetic) {
  KernelBase *base = &arithmetic->base_;
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;

  ResetArithmeticMatric(base, a);
  ResetArithmeticMatric(base, b);
  ResetArithmeticMatric(base, &arithmetic->c_matrix_);

  // Start again from the raw tensor shapes.
  a->shape_size_ = base->in_[FIRST_INPUT]->shape_size_;
  memcpy(a->shape_, base->in_[FIRST_INPUT]->shape_, a->shape_size_ * sizeof(int));

  b->shape_size_ = base->in_[SECOND_INPUT]->shape_size_;
  memcpy(b->shape_, base->in_[SECOND_INPUT]->shape_, b->shape_size_ * sizeof(int));

  return OptimizeArithmeticShape(arithmetic);
}
374 
/*
 * Expands one input to the output layout by calling arithmetic->tile_function_.
 *
 * in_data:     source buffer of the input.
 * out_data:    destination buffer passed to the tile function.
 * input_index: FIRST_INPUT selects the shape / strides / multiples of input 0; any other value
 *              selects those of input 1.
 *
 * The tile function is called starting at dimension 0 over ndim_ dimensions with out_strides_.
 */
void ArithmeticDoBroadcast(ArithmeticStruct *arithmetic, void *in_data, void *out_data, int input_index) {
  int *in_shape = arithmetic->in_shape1_;
  int *in_stride = arithmetic->in_strides1_;
  int *multiples = arithmetic->multiples1_;
  if (input_index == FIRST_INPUT) {
    in_shape = arithmetic->in_shape0_;
    in_stride = arithmetic->in_strides0_;
    multiples = arithmetic->multiples0_;
  }
  // Plain call: ISO C does not allow `return <expression>;` in a function returning void.
  arithmetic->tile_function_(in_data, out_data, 0, arithmetic->ndim_, in_shape, in_stride, arithmetic->out_strides_,
                             multiples);
}
382 
/*
 * Pre-expands constant inputs to the output shape so they need no broadcasting at run time.
 *
 * Explicit broadcast is only considered when PARALLEL_INFERENCE is not defined, ndim_ != 1 and the
 * first input is not bool. A constant input whose element count differs from out_elements_num_ then
 * gets a freshly allocated buffer, recorded in broadcast_buffer_ and in the matrix's data_, filled by
 * ArithmeticDoBroadcast. Its element count, shape and strides are set to those of the output and the
 * matrix is marked valid.
 *
 * Returns NNACL_OK when nothing was broadcast, the result of OptimizeArithmeticShape otherwise, or
 * the error raised by the NULL / allocation checks.
 */
int ArithmeticBroadCastConstTensor(ArithmeticStruct *arithmetic) {
  NNACL_CHECK_NULL_RETURN_ERR(arithmetic);

  CalcStructMultiplesAndStrides(arithmetic);

#ifdef PARALLEL_INFERENCE
  bool prefer_explicit_broadcast = false;
#else
  bool prefer_explicit_broadcast = arithmetic->ndim_ != 1;
#endif
  if (prefer_explicit_broadcast) {
    prefer_explicit_broadcast = arithmetic->base_.in_[FIRST_INPUT]->data_type_ != kNumberTypeBool;
  }

  KernelBase *base = &arithmetic->base_;
  ArithmeticMatrixInfo *a = &arithmetic->a_matrix_;
  ArithmeticMatrixInfo *b = &arithmetic->b_matrix_;
  ArithmeticMatrixInfo *c = &arithmetic->c_matrix_;

  bool broadcast_done = false;
  int buffer_size = GetElementNum(base->out_[OUTPUT_INDEX]) * arithmetic->in_data_size_;

  if (a->is_const_) {
    NNACL_CHECK_NULL_RETURN_ERR(base->in_[FIRST_INPUT]->data_);
    if (prefer_explicit_broadcast && arithmetic->in_elements_num0_ != arithmetic->out_elements_num_) {
      broadcast_done = true;

      a->data_ = base->env_->Alloc(base->env_->allocator_, buffer_size);
      NNACL_MALLOC_CHECK_NULL_RETURN_ERR(a->data_);
      // Recorded so ArithmeticRelease can free it.
      arithmetic->broadcast_buffer_[Index0] = a->data_;

      ArithmeticDoBroadcast(arithmetic, base->in_[FIRST_INPUT]->data_, a->data_, Index0);
      arithmetic->in_elements_num0_ = arithmetic->out_elements_num_;

      // shape must be equal to out
      for (size_t dim = 0; dim < arithmetic->ndim_; ++dim) {
        arithmetic->in_shape0_[dim] = arithmetic->out_shape_[dim];
        arithmetic->in_strides0_[dim] = arithmetic->out_strides_[dim];
      }
      memcpy(a->shape_, c->shape_, arithmetic->ndim_ * sizeof(int));
      a->is_valid_ = true;
    }
  }

  if (b->is_const_) {
    NNACL_CHECK_NULL_RETURN_ERR(base->in_[SECOND_INPUT]->data_);
    if (prefer_explicit_broadcast && arithmetic->in_elements_num1_ != arithmetic->out_elements_num_) {
      broadcast_done = true;

      b->data_ = base->env_->Alloc(base->env_->allocator_, buffer_size);
      NNACL_MALLOC_CHECK_NULL_RETURN_ERR(b->data_);
      arithmetic->broadcast_buffer_[Index1] = b->data_;

      ArithmeticDoBroadcast(arithmetic, base->in_[Index1]->data_, b->data_, Index1);
      arithmetic->in_elements_num1_ = arithmetic->out_elements_num_;

      // shape must be equal to out
      for (size_t dim = 0; dim < arithmetic->ndim_; ++dim) {
        arithmetic->in_shape1_[dim] = arithmetic->out_shape_[dim];
        arithmetic->in_strides1_[dim] = arithmetic->out_strides_[dim];
      }
      memcpy(b->shape_, c->shape_, arithmetic->ndim_ * sizeof(int));
      b->is_valid_ = true;
    }
  }

  // Shapes changed only if something was expanded; re-collapse them in that case.
  return broadcast_done ? OptimizeArithmeticShape(arithmetic) : NNACL_OK;
}
446 
/*
 * Precomputes the per-shape data used when executing:
 *   - batch_tail_dim_: index of the last dimension where the a and b shapes differ (-1 if none),
 *     moved one step down when that dimension is the last one;
 *   - inner_size_ of a / b / c: product of the dimensions after batch_tail_dim_;
 *   - batch_post_sum_ of a / b / c: arrays of shape_size_ + 1 ints, initialized to 1, holding for each
 *     dimension up to batch_tail_dim_ the product of the shape from that dimension to batch_tail_dim_;
 *   - scalar_opt_: set when a or b has inner_size_ == 1 (its in_elements_num is then forced to 1).
 *
 * Returns NNACL_OK, NNACL_ERR on inner-size overflow, or the allocation-check error.
 */
int ArithmeticComputeOfflineInfo(ArithmeticStruct *arithmetic) {
  ArithmeticMatrixInfo *matrices[] = {&arithmetic->a_matrix_, &arithmetic->b_matrix_, &arithmetic->c_matrix_};
  const int matrix_count = (int)(sizeof(matrices) / sizeof(matrices[0]));
  ArithmeticMatrixInfo *a = matrices[0];
  ArithmeticMatrixInfo *b = matrices[1];

  // Last dimension, scanning from the back, where the two input shapes disagree.
  int last_dim = a->shape_size_ - 1;
  int bread_pos = -1;
  for (int dim = last_dim; dim >= 0; --dim) {
    if (a->shape_[dim] != b->shape_[dim]) {
      bread_pos = dim;
      break;
    }
  }
  arithmetic->batch_tail_dim_ = bread_pos;
  if (bread_pos == last_dim && arithmetic->batch_tail_dim_ >= 0) {
    --arithmetic->batch_tail_dim_;
  }

  // Inner size: every dimension after the batch tail, for a, b and c in that order.
  for (int dim = last_dim; dim > arithmetic->batch_tail_dim_; --dim) {
    for (int m = 0; m < matrix_count; ++m) {
      NNACL_CHECK_INT_MUL_NOT_OVERFLOW(matrices[m]->inner_size_, matrices[m]->shape_[dim], NNACL_ERR);
      matrices[m]->inner_size_ *= matrices[m]->shape_[dim];
    }
  }

  // One batch_post_sum_ array per matrix, pre-filled with 1.
  for (int m = 0; m < matrix_count; ++m) {
    ArithmeticMatrixInfo *matrix = matrices[m];
    matrix->batch_post_sum_ =
      arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, (matrix->shape_size_ + 1) * sizeof(int));
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(matrix->batch_post_sum_);
    for (int k = 0; k < matrix->shape_size_ + 1; k++) {
      matrix->batch_post_sum_[k] = 1;
    }
  }

  // Suffix products over the batch dimensions, built from the tail towards dimension 0.
  for (int dim = arithmetic->batch_tail_dim_; dim >= 0; --dim) {
    for (int m = 0; m < matrix_count; ++m) {
      ArithmeticMatrixInfo *matrix = matrices[m];
      if (dim == arithmetic->batch_tail_dim_) {
        matrix->batch_post_sum_[dim] = matrix->shape_[dim];
      } else {
        matrix->batch_post_sum_[dim] = matrix->shape_[dim] * matrix->batch_post_sum_[dim + 1];
      }
    }
  }

  arithmetic->scalar_opt_ = false;
  if (a->inner_size_ == 1) {
    arithmetic->in_elements_num0_ = 1;
    arithmetic->scalar_opt_ = true;
  }
  if (b->inner_size_ == 1) {
    arithmetic->in_elements_num1_ = 1;
    arithmetic->scalar_opt_ = true;
  }
  return NNACL_OK;
}
517 
/*
 * Splits the output elements into contiguous blocks, one per task.
 *
 * The thread count is first refreshed through base_.UpdateThread. The output is then cut into blocks
 * of UP_DIV(total, thread_nr_) elements; for each block the begin / end batch index and the begin /
 * end position inside a batch (relative to c_matrix_.inner_size_) are recorded, and two offset arrays
 * are allocated for ArithmeticComputeOffset to fill later. Finally thread_nr_ is set to the number
 * of blocks created.
 *
 * Returns NNACL_OK, or the allocation-check error when an offset array cannot be allocated.
 */
int ArithmeticChooseThreadCuttingStrategy(ArithmeticStruct *arithmetic) {
  int total_num = GetElementNum(arithmetic->base_.out_[OUTPUT_INDEX]);
  arithmetic->base_.thread_nr_ =
    arithmetic->base_.UpdateThread(TC_TYPE(arithmetic->primitive_type_, arithmetic->functions_.activation_type_), 1, 1,
                                   total_num, arithmetic->base_.thread_nr_);

  // NOTE(review): assumes UpdateThread never returns 0 - confirm, UP_DIV divides by it.
  int64_t block_size = UP_DIV(total_num, arithmetic->base_.thread_nr_);
  int64_t split_point = 0;
  while (split_point < total_num) {
    int64_t start = split_point;
    int64_t end = start + block_size;
    if (end > total_num) {
      end = total_num;
    }
    ArithmeticBlockBoundaryInfo block_boundary_info;
    block_boundary_info.size_begin_ = start % arithmetic->c_matrix_.inner_size_;
    block_boundary_info.size_end_ = end % arithmetic->c_matrix_.inner_size_;
    block_boundary_info.batch_begin_ = start / arithmetic->c_matrix_.inner_size_;
    block_boundary_info.batch_end_ = end / arithmetic->c_matrix_.inner_size_;
    block_boundary_info.init_offset_ = false;

    int max_offset_size = block_boundary_info.batch_end_ - block_boundary_info.batch_begin_ + TWO_TENSOR;
    block_boundary_info.a_offset_ =
      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.a_offset_);
    block_boundary_info.b_offset_ =
      (int *)arithmetic->base_.env_->Alloc(arithmetic->base_.env_->allocator_, max_offset_size * sizeof(int));
    if (block_boundary_info.b_offset_ == NULL) {
      // This block is not stored in block_boundary_infos_ yet, so ArithmeticRelease cannot reach
      // a_offset_; free it here to avoid leaking it on the error path below.
      arithmetic->base_.env_->Free(arithmetic->base_.env_->allocator_, block_boundary_info.a_offset_);
      block_boundary_info.a_offset_ = NULL;
    }
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(block_boundary_info.b_offset_);

    arithmetic->block_boundary_infos_[arithmetic->block_boundary_infos_size_++] = block_boundary_info;
    split_point = end;
  }

  arithmetic->base_.thread_nr_ = arithmetic->block_boundary_infos_size_;
  return NNACL_OK;
}
554 
ArithmeticResize(struct KernelBase * self)555 int ArithmeticResize(struct KernelBase *self) {
556   ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
557   NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
558 
559   ArithmeticRelease(&arithmetic->base_);
560 
561   NNACL_CHECK_TRUE_RET(arithmetic->in_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
562   NNACL_CHECK_TRUE_RET(arithmetic->out_data_size_ != 0, NNACL_UNSUPPORTED_DATA_TYPE);
563   arithmetic->in_elements_num0_ = GetElementNum(self->in_[FIRST_INPUT]);
564   arithmetic->in_elements_num1_ = GetElementNum(self->in_[SECOND_INPUT]);
565   arithmetic->out_elements_num_ = GetElementNum(self->in_[OUTPUT_INDEX]);
566 
567   int ret = ResetArithmeticStatus(arithmetic);
568   if (ret != NNACL_OK) {
569     return ret;
570   }
571 
572   ret = ArithmeticBroadCastConstTensor(arithmetic);
573   if (ret != NNACL_OK) {
574     return ret;
575   }
576 
577   ret = ArithmeticComputeOfflineInfo(arithmetic);
578   if (ret != NNACL_OK) {
579     return ret;
580   }
581 
582   return ArithmeticChooseThreadCuttingStrategy(arithmetic);
583 }
584 
ArithmeticPrepare(struct KernelBase * self)585 int ArithmeticPrepare(struct KernelBase *self) {
586   ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
587   NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
588 
589   NNACL_CHECK_FALSE(self->in_size_ < TWO_TENSOR, NNACL_ERR);
590   NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_ERR);
591 
592   NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ < kNumberTypeBegin, NNACL_ERR);
593   NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
594   NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
595   NNACL_CHECK_FALSE(self->in_[SECOND_INPUT]->data_type_ > kNumberTypeEnd, NNACL_ERR);
596 
597   if (self->param_->quant_type_ != Quant_None) {
598     return NNACL_ERR;
599   }
600 
601   arithmetic->primitive_type_ = self->param_->type_;
602   if (self->param_->type_ == PrimType_Eltwise) {
603     switch (((ArithmeticParameter *)(self->param_))->eltwise_mode_) {
604       case Eltwise_PROD:
605         arithmetic->primitive_type_ = PrimType_MulFusion;
606         break;
607       case Eltwise_SUM:
608         arithmetic->primitive_type_ = PrimType_AddFusion;
609         break;
610       case Eltwise_MAXIMUM:
611         arithmetic->primitive_type_ = PrimType_Maximum;
612         break;
613       default:
614         return NNACL_ELTWISE_INVALID_MOD;
615     }
616   }
617   arithmetic->init_function_(self);
618 
619   arithmetic->a_matrix_.is_const_ = IsConst(self->in_[FIRST_INPUT]);
620   arithmetic->b_matrix_.is_const_ = IsConst(self->in_[SECOND_INPUT]);
621   return NNACL_OK;
622 }
623 
ArithmeticCompute(struct KernelBase * self)624 int ArithmeticCompute(struct KernelBase *self) {
625   ArithmeticStruct *arithmetic = (ArithmeticStruct *)self;
626   NNACL_CHECK_NULL_RETURN_ERR(arithmetic);
627   NNACL_CHECK_FALSE(self->in_[FIRST_INPUT]->data_type_ != self->in_[SECOND_INPUT]->data_type_,
628                     NNACL_ARITHMETIC_DATA_TYPE_UNMATCH);
629 
630   if (self->train_session_) {
631     arithmetic->in_data_size_ = DataTypeCSize(self->in_[FIRST_INPUT]->data_type_);
632   }
633 
634   if (false == arithmetic->a_matrix_.is_valid_) {
635     arithmetic->a_matrix_.data_ = self->in_[FIRST_INPUT]->data_;
636   }
637   NNACL_CHECK_NULL_RETURN_ERR(arithmetic->a_matrix_.data_);
638 
639   if (false == arithmetic->b_matrix_.is_valid_) {
640     arithmetic->b_matrix_.data_ = self->in_[SECOND_INPUT]->data_;
641   }
642   NNACL_CHECK_NULL_RETURN_ERR(arithmetic->b_matrix_.data_);
643 
644   arithmetic->c_matrix_.data_ = self->out_[OUTPUT_INDEX]->data_;
645   NNACL_CHECK_NULL_RETURN_ERR(arithmetic->c_matrix_.data_);
646 
647   return self->env_->ParallelLaunch(self->env_->thread_pool_, ArithmeticRun, self, self->thread_nr_);
648 }
649 
CreateArithmetic(OpParameter * param,int data_type)650 KernelBase *CreateArithmetic(OpParameter *param, int data_type) {
651   ArithmeticStruct *arithmetic = (ArithmeticStruct *)malloc(sizeof(ArithmeticStruct));
652   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(arithmetic);
653   memset(arithmetic, 0, sizeof(ArithmeticStruct));
654   arithmetic->in_data_size_ = DataTypeCSize(data_type);
655   arithmetic->out_data_size_ = DataTypeCSize(data_type);
656   arithmetic->block_boundary_infos_size_ = 0;
657   arithmetic->a_matrix_.batch_post_sum_ = NULL;
658   arithmetic->b_matrix_.batch_post_sum_ = NULL;
659   arithmetic->c_matrix_.batch_post_sum_ = NULL;
660   arithmetic->broadcast_buffer_[FIRST_INPUT] = NULL;
661   arithmetic->broadcast_buffer_[SECOND_INPUT] = NULL;
662   arithmetic->tile_function_ = TileOneDimensionFp32;
663   arithmetic->init_function_ = InitArithmeticRunFunction;
664   arithmetic->execute_ = ArithmeticDoExecute;
665   arithmetic->base_.Prepare = ArithmeticPrepare;
666   arithmetic->base_.Resize = ArithmeticResize;
667   arithmetic->base_.Release = ArithmeticRelease;
668   arithmetic->base_.Compute = ArithmeticCompute;
669   return (KernelBase *)arithmetic;
670 }
671 
/*
 * Kernel registrations: each line passes a (primitive type, data type) pair together with
 * CreateArithmetic to REG_KERNEL_CREATOR.
 * NOTE(review): REG_KERNEL_CREATOR is defined outside this file; presumably it makes CreateArithmetic
 * the factory for that pair - confirm against its definition.
 */
REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_MulFusion, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeBool, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_AddFusion, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_SubFusion, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_RealDiv, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Mod, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeBool, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_LogicalAnd, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_LogicalOr, kNumberTypeBool, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Maximum, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Minimum, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_FloorDiv, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_FloorMod, kNumberTypeInt32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_SquaredDifference, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_Eltwise, kNumberTypeFloat32, CreateArithmetic)
REG_KERNEL_CREATOR(PrimType_DivFusion, kNumberTypeInt32, CreateArithmetic)
699