• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "wrapper/fp32/arithmetic_fp32_wrapper.h"
/**
 * Thin wrapper around TileOneDimensionFp32 for float tensors.
 *
 * Every argument is forwarded unchanged; the only value supplied here is the
 * third argument of TileOneDimensionFp32, which is always 0 (presumably the
 * dimension at which tiling starts - confirm against TileOneDimensionFp32).
 */
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
                     const int *out_strides, const int *multiple) {
  const int first_dim = 0;
  TileOneDimensionFp32(in_data, out_data, first_dim, ndim, in_shape, in_strides, out_strides, multiple);
}
21 
/**
 * Invokes an element-wise arithmetic kernel on type-erased buffers.
 *
 * `arithmetic_func` is cast to the kernel signature selected by
 * `arithmetic_func_type` (float, bool or int) and, for float/int, by `is_opt`:
 * the "opt" signatures additionally receive `param`, the plain ones do not.
 * Bool kernels have a single signature, so `is_opt` is ignored for them.
 * Any other `arithmetic_func_type` results in no call at all.
 *
 * @param input0               first operand buffer
 * @param input1               second operand buffer
 * @param output               destination buffer
 * @param size                 element count handed to the kernel
 * @param is_opt               selects the "opt" kernel signature (float/int only)
 * @param arithmetic_func_type selects the element type of the kernel
 * @param arithmetic_func      kernel function pointer, stored as const void *
 * @param param                forwarded only to "opt" kernels
 */
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
                       ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
                       const ArithmeticParameter *param) {
  if (arithmetic_func_type == kArithmeticFuncFloat) {
    const float *lhs = (const float *)(input0);
    const float *rhs = (const float *)(input1);
    float *dst = (float *)(output);
    if (!is_opt) {
      ArithmeticRun plain_kernel = (ArithmeticRun)(arithmetic_func);
      plain_kernel(lhs, rhs, dst, size);
      return;
    }
    ArithmeticOptRun opt_kernel = (ArithmeticOptRun)(arithmetic_func);
    opt_kernel(lhs, rhs, dst, size, param);
    return;
  }

  if (arithmetic_func_type == kArithmeticFuncBool) {
    ArithmeticBoolRun bool_kernel = (ArithmeticBoolRun)(arithmetic_func);
    bool_kernel((const bool *)(input0), (const bool *)(input1), (bool *)(output), size);
    return;
  }

  if (arithmetic_func_type == kArithmeticFuncInt) {
    const int *lhs = (const int *)(input0);
    const int *rhs = (const int *)(input1);
    int *dst = (int *)(output);
    if (!is_opt) {
      ArithmeticIntRun plain_kernel = (ArithmeticIntRun)(arithmetic_func);
      plain_kernel(lhs, rhs, dst, size);
      return;
    }
    ArithmeticOptIntRun opt_kernel = (ArithmeticOptIntRun)(arithmetic_func);
    opt_kernel(lhs, rhs, dst, size, param);
  }
  /* Unrecognised function types fall through without invoking anything. */
}
46 
BatchScalarCalc(const void * input0,const void * input1,void * output,int batch_size,int size,bool is_opt,const void * arithmetic_func,const ArithmeticWrapperInfo * wrapper_info,const ArithmeticParameter * param)47 void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
48                      const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
49                      const ArithmeticParameter *param) {
50   int offset0 = wrapper_info->offset0_;
51   int offset1 = wrapper_info->offset1_;
52   int out_offset = wrapper_info->out_offset_;
53   int stride0 = wrapper_info->stride0_;
54   int stride1 = wrapper_info->stride1_;
55   int out_stride = wrapper_info->out_stride_;
56   for (int i = 0; i < batch_size; i++) {
57     ArithmeticExecute((const uint8_t *)(input0) + offset0, (const uint8_t *)(input1) + offset1,
58                       (uint8_t *)(output) + out_offset, size, is_opt, wrapper_info->arithmetic_func_type_,
59                       arithmetic_func, param);
60     offset0 += stride0;
61     offset1 += stride1;
62     out_offset += out_stride;
63   }
64 }
65 
/**
 * Recursively applies an element-wise kernel over broadcast inputs.
 *
 * For each dimension `dim` up to and including `break_pos`, the function
 * iterates over param->out_shape_[dim] and recurses into `dim + 1`, advancing
 * the three buffers by the per-dimension strides (scaled to bytes with
 * `data_type_len`). An input whose extent along `dim` is 1 always uses index 0,
 * i.e. its single slice is reused for every output index.
 *
 * Once `dim` exceeds `break_pos` the recursion bottoms out: the kernel is run
 * once (non-"opt" signature) on `out_count` elements, starting
 * `out_thread_stride` elements into each buffer, and the function returns.
 *
 * @param input0               first operand buffer at the current position
 * @param input1               second operand buffer at the current position
 * @param output               destination buffer at the current position
 * @param dim                  dimension currently being iterated
 * @param out_count            element count passed to the kernel at the leaf
 * @param out_thread_stride    element offset applied to all buffers at the leaf
 * @param break_pos            last dimension that is iterated recursively
 * @param data_type_len        size of one element in bytes
 * @param arithmetic_func_type element type selector for ArithmeticExecute
 * @param arithmetic_func      kernel function pointer, stored as const void *
 * @param param                shapes and strides of inputs and output
 */
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
                  int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
                  const void *arithmetic_func, const ArithmeticParameter *param) {
  if (dim > break_pos) {
    int offset = out_thread_stride * data_type_len;
    ArithmeticExecute((const uint8_t *)(input0) + offset, (const uint8_t *)(input1) + offset,
                      (uint8_t *)(output) + offset, out_count, false, arithmetic_func_type, arithmetic_func, param);
    /* Leaf of the recursion: without this return the loop below would keep
     * recursing with ever larger `dim`, indexing the shape/stride arrays past
     * the dimensions in use and never terminating. */
    return;
  }

  /* Byte strides along `dim` for input0, input1 and output respectively. */
  const int in0_step = param->in_strides0_[dim] * data_type_len;
  const int in1_step = param->in_strides1_[dim] * data_type_len;
  const int out_step = param->out_strides_[dim] * data_type_len;

  for (int i = 0; i < param->out_shape_[dim]; ++i) {
    /* An extent of 1 means the input is broadcast along this dimension. */
    int in0_index = param->in_shape0_[dim] == 1 ? 0 : i;
    int in1_index = param->in_shape1_[dim] == 1 ? 0 : i;
    BroadcastRun((const uint8_t *)(input0) + in0_index * in0_step, (const uint8_t *)(input1) + in1_index * in1_step,
                 (uint8_t *)(output) + i * out_step, dim + 1, out_count, out_thread_stride, break_pos,
                 data_type_len, arithmetic_func_type, arithmetic_func, param);
  }
}
84