/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 #include "wrapper/fp32/arithmetic_fp32_wrapper.h"
/**
 * Tiles a float tensor by delegating to TileOneDimensionFp32, entering it at dimension index 0.
 *
 * Every argument is forwarded unchanged; this wrapper only fixes the starting dimension.
 *
 * @param in_data     source float buffer
 * @param out_data    destination float buffer
 * @param ndim        number of dimensions handed to TileOneDimensionFp32
 * @param in_shape    per-dimension shape of the input
 * @param in_strides  per-dimension strides of the input
 * @param out_strides per-dimension strides of the output
 * @param multiple    per-dimension values forwarded as the tiling multiples
 */
void TileConstTensor(const float *in_data, float *out_data, size_t ndim, const int *in_shape, const int *in_strides,
                     const int *out_strides, const int *multiple) {
  // Dimension index at which TileOneDimensionFp32 is entered.
  const int first_dim = 0;
  TileOneDimensionFp32(in_data, out_data, first_dim, ndim, in_shape, in_strides, out_strides, multiple);
}
21
/**
 * Invokes one arithmetic kernel on a pair of input buffers, choosing the call signature from the element kind.
 *
 * arithmetic_func is an untyped function pointer; it is cast to the kernel type that matches
 * arithmetic_func_type (and, for float/int, is_opt) before being called:
 *   - kArithmeticFuncFloat: ArithmeticOptRun when is_opt, otherwise ArithmeticRun
 *   - kArithmeticFuncBool:  ArithmeticBoolRun (is_opt is not consulted)
 *   - kArithmeticFuncInt:   ArithmeticOptIntRun when is_opt, otherwise ArithmeticIntRun
 * The "opt" kernels additionally receive param. If arithmetic_func_type matches none of the three
 * kinds, no kernel is called and output is left untouched.
 *
 * @param input0               first operand buffer
 * @param input1               second operand buffer
 * @param output               destination buffer
 * @param size                 element count passed to the kernel
 * @param is_opt               selects the "opt" kernel signature for float and int
 * @param arithmetic_func_type element kind used to pick the kernel signature
 * @param arithmetic_func      kernel to call, stored as an untyped pointer
 * @param param                arithmetic parameters, forwarded only to the "opt" kernels
 */
void ArithmeticExecute(const void *input0, const void *input1, void *output, int size, bool is_opt,
                       ArithmeticFuncType arithmetic_func_type, const void *arithmetic_func,
                       const ArithmeticParameter *param) {
  if (arithmetic_func_type == kArithmeticFuncFloat) {
    const float *lhs = (const float *)(input0);
    const float *rhs = (const float *)(input1);
    float *dst = (float *)(output);
    if (!is_opt) {
      ((ArithmeticRun)(arithmetic_func))(lhs, rhs, dst, size);
      return;
    }
    ((ArithmeticOptRun)(arithmetic_func))(lhs, rhs, dst, size, param);
    return;
  }

  if (arithmetic_func_type == kArithmeticFuncBool) {
    // Bool kernels have a single signature, so is_opt plays no role here.
    ((ArithmeticBoolRun)(arithmetic_func))((const bool *)(input0), (const bool *)(input1), (bool *)(output), size);
    return;
  }

  if (arithmetic_func_type == kArithmeticFuncInt) {
    const int *lhs = (const int *)(input0);
    const int *rhs = (const int *)(input1);
    int *dst = (int *)(output);
    if (!is_opt) {
      ((ArithmeticIntRun)(arithmetic_func))(lhs, rhs, dst, size);
      return;
    }
    ((ArithmeticOptIntRun)(arithmetic_func))(lhs, rhs, dst, size, param);
  }
}
46
BatchScalarCalc(const void * input0,const void * input1,void * output,int batch_size,int size,bool is_opt,const void * arithmetic_func,const ArithmeticWrapperInfo * wrapper_info,const ArithmeticParameter * param)47 void BatchScalarCalc(const void *input0, const void *input1, void *output, int batch_size, int size, bool is_opt,
48 const void *arithmetic_func, const ArithmeticWrapperInfo *wrapper_info,
49 const ArithmeticParameter *param) {
50 int offset0 = wrapper_info->offset0_;
51 int offset1 = wrapper_info->offset1_;
52 int out_offset = wrapper_info->out_offset_;
53 int stride0 = wrapper_info->stride0_;
54 int stride1 = wrapper_info->stride1_;
55 int out_stride = wrapper_info->out_stride_;
56 for (int i = 0; i < batch_size; i++) {
57 ArithmeticExecute((const uint8_t *)(input0) + offset0, (const uint8_t *)(input1) + offset1,
58 (uint8_t *)(output) + out_offset, size, is_opt, wrapper_info->arithmetic_func_type_,
59 arithmetic_func, param);
60 offset0 += stride0;
61 offset1 += stride1;
62 out_offset += out_stride;
63 }
64 }
65
/**
 * Recursively walks the output dimensions starting at dim and applies the arithmetic kernel with broadcasting.
 *
 * For each index i of out_shape_[dim], the function recurses into dim + 1 with the buffers advanced by
 * the byte stride of that dimension. An input whose shape is 1 along the dimension is not advanced
 * (its index is pinned to 0), which is how it is broadcast against the other operand.
 *
 * Once dim exceeds break_pos the recursion stops: the kernel is run once (non-"opt" signature) on
 * out_count elements located out_thread_stride elements past the current buffer positions, and the
 * function returns.
 *
 * @param input0               first operand buffer at the current position
 * @param input1               second operand buffer at the current position
 * @param output               destination buffer at the current position
 * @param dim                  dimension currently being iterated
 * @param out_count            element count handed to the kernel at the leaf
 * @param out_thread_stride    element offset applied to all three buffers at the leaf
 * @param break_pos            last dimension that is iterated; dimensions beyond it are handled by the kernel
 * @param data_type_len        size in bytes of one element, used to turn element strides into byte offsets
 * @param arithmetic_func_type element kind forwarded to ArithmeticExecute
 * @param arithmetic_func      kernel forwarded to ArithmeticExecute
 * @param param                shapes and strides of both inputs and the output
 */
void BroadcastRun(const void *input0, const void *input1, void *output, int dim, int out_count, int out_thread_stride,
                  int break_pos, int data_type_len, ArithmeticFuncType arithmetic_func_type,
                  const void *arithmetic_func, const ArithmeticParameter *param) {
  if (dim > break_pos) {
    int offset = out_thread_stride * data_type_len;
    ArithmeticExecute((const uint8_t *)(input0) + offset, (const uint8_t *)(input1) + offset,
                      (uint8_t *)(output) + offset, out_count, false, arithmetic_func_type, arithmetic_func, param);
    // Base case: stop here. Falling through would index the shape/stride arrays at a dimension past
    // break_pos and keep recursing with dim + 1, re-running the kernel at every deeper level.
    return;
  }

  // Byte distance between consecutive indices along this dimension for each buffer.
  const int in0_step = param->in_strides0_[dim] * data_type_len;
  const int in1_step = param->in_strides1_[dim] * data_type_len;
  const int out_step = param->out_strides_[dim] * data_type_len;

  for (int i = 0; i < param->out_shape_[dim]; ++i) {
    // A size-1 input dimension is broadcast: always read index 0 along it.
    int idx0 = param->in_shape0_[dim] == 1 ? 0 : i;
    int idx1 = param->in_shape1_[dim] == 1 ? 0 : i;
    BroadcastRun((const uint8_t *)(input0) + idx0 * in0_step, (const uint8_t *)(input1) + idx1 * in1_step,
                 (uint8_t *)(output) + i * out_step, dim + 1, out_count, out_thread_stride, break_pos,
                 data_type_len, arithmetic_func_type, arithmetic_func, param);
  }
}
84