/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/f16/concat_f16.h"
#include "nnacl/kernel/concat.h"
#include "nnacl/fp16/cast_fp16.h"
#include "nnacl/fp16/utils_fp16.h"
#include "nnacl/tensor_c_utils.h"

23 typedef struct ConcatF16Struct {
24 ConcatStruct concat_;
25 void **tmp_buffer_; /* in_size + out_size */
26 } ConcatF16Struct;
27
/**
 * Prepare fp16 views of every input tensor and of the output tensor.
 *
 * Tensors already stored as fp16 are used through their raw data pointer;
 * fp32 tensors get a freshly allocated fp16 buffer from GetOrAllocFp16Data
 * (converted contents for inputs, empty staging space for the output).
 * Each such allocation is recorded in tmp_buffer_ so that
 * ConcatF16FreeTmpBuffer can release it later, including on error paths.
 *
 * Returns NNACL_OK on success, an NNACL error code otherwise.
 */
int ConcatEnsureFp16InputsAndOutput(ConcatF16Struct *concat_f16) {
  ConcatStruct *concat = &concat_f16->concat_;

  /* One tracking slot per input plus one per output; size_t avoids mixing
   * signed and unsigned arithmetic on the tensor counts. */
  size_t tmp_buffer_size = (concat->base_.in_size_ + concat->base_.out_size_) * sizeof(float16_t *);
  concat_f16->tmp_buffer_ = concat->base_.env_->Alloc(concat->base_.env_->allocator_, tmp_buffer_size);
  NNACL_CHECK_NULL_RETURN_ERR(concat_f16->tmp_buffer_);
  memset(concat_f16->tmp_buffer_, 0, tmp_buffer_size);

  for (size_t i = 0; i < concat->base_.in_size_; ++i) {
    if (!concat->is_with_data_[i]) {
      /* Empty tensors contribute nothing to the concat result. */
      continue;
    }

    concat->inputs_ptr_[i] = GetOrAllocFp16Data(concat->base_.in_[i], concat->base_.env_, true);
    NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_[i]);
    if (concat->base_.in_[i]->data_type_ == kNumberTypeFloat32 ||
        concat->base_.in_[i]->data_type_ == kNumberTypeFloat) {
      /* fp32 input: GetOrAllocFp16Data allocated a converted copy; track it for freeing. */
      concat_f16->tmp_buffer_[i] = concat->inputs_ptr_[i];
    }
  }

  concat->output_ = GetOrAllocFp16Data(concat->base_.out_[OUTPUT_INDEX], concat->base_.env_, false);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(concat->output_);
  if (concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat32 ||
      concat->base_.out_[OUTPUT_INDEX]->data_type_ == kNumberTypeFloat) {
    /* fp32 output: the fp16 staging buffer must be freed after converting back. */
    concat_f16->tmp_buffer_[concat->base_.in_size_] = concat->output_;
  }
  return NNACL_OK;
}
ConcatFp16Run(void * cdata,int task_id,float l,float r)58 int ConcatFp16Run(void *cdata, int task_id, float l, float r) {
59 ConcatF16Struct *concat_f16 = (ConcatF16Struct *)cdata;
60 NNACL_CHECK_NULL_RETURN_ERR(concat_f16);
61 ConcatStruct *concat = &concat_f16->concat_;
62 return DoConcat(concat, task_id);
63 }
/**
 * Release every fp16 scratch buffer recorded by ConcatEnsureFp16InputsAndOutput,
 * then the pointer table itself. Safe to call when nothing was allocated
 * (tmp_buffer_ == NULL is a no-op).
 */
void ConcatF16FreeTmpBuffer(ConcatF16Struct *concat_f16) {
  if (concat_f16->tmp_buffer_ != NULL) {
    /* size_t loop index: in_size_/out_size_ are unsigned counts, so this
     * avoids a signed/unsigned comparison. */
    size_t slot_num = concat_f16->concat_.base_.in_size_ + concat_f16->concat_.base_.out_size_;
    for (size_t i = 0; i < slot_num; i++) {
      if (concat_f16->tmp_buffer_[i] != NULL) {
        concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_[i]);
      }
      concat_f16->tmp_buffer_[i] = NULL;
    }

    /* free the pointer table itself */
    concat_f16->concat_.base_.env_->Free(concat_f16->concat_.base_.env_->allocator_, concat_f16->tmp_buffer_);
    concat_f16->tmp_buffer_ = NULL;
  }
}
/**
 * Compute entry point for the fp16 concat kernel.
 *
 * Stages fp32 tensors into fp16 buffers, runs the concat in parallel across
 * thread_nr_ tasks, then converts the fp16 result back into the output tensor
 * when that tensor is fp32. All staging buffers are released before returning,
 * on both success and error paths.
 */
int ConcatF16Compute(KernelBase *self) {
  ConcatF16Struct *concat_f16 = (ConcatF16Struct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat_f16);
  ConcatStruct *concat = &concat_f16->concat_;

  /* inner_sizes_[in_size_] appears to hold the output inner size; zero on
   * either axis means there is nothing to copy. */
  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
    return NNACL_OK;
  }

  int ret = ConcatEnsureFp16InputsAndOutput(concat_f16);
  if (ret != NNACL_OK) {
    ConcatF16FreeTmpBuffer(concat_f16);
    return ret;
  }

  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);
  ret = self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatFp16Run, self, self->thread_nr_);
  if (ret == NNACL_OK) {
    /* OUTPUT_INDEX (not FIRST_INPUT) for consistency with
     * ConcatEnsureFp16InputsAndOutput; both expand to index 0. */
    TensorC *output_tensor = concat->base_.out_[OUTPUT_INDEX];
    if (output_tensor->data_type_ == kNumberTypeFloat32 || output_tensor->data_type_ == kNumberTypeFloat) {
      float *output = output_tensor->data_;
      if (output == NULL) {
        ret = NNACL_CONCAT_F16_OUTPUT_DATA_INVALID;
      } else {
        /* copy the fp16 staging result back into the fp32 output tensor */
        Float16ToFloat32((float16_t *)concat->output_, output, GetElementNum(output_tensor));
      }
    }
  }

  ConcatF16FreeTmpBuffer(concat_f16);
  return ret;
}
CreateConcatF16(OpParameter * param,int data_type)114 KernelBase *CreateConcatF16(OpParameter *param, int data_type) {
115 ConcatF16Struct *concat_f16 = (ConcatF16Struct *)malloc(sizeof(ConcatF16Struct));
116 NNACL_CHECK_NULL_RETURN_NULL(concat_f16);
117 memset(concat_f16, 0, sizeof(ConcatF16Struct));
118
119 ConcatStruct *concat = &concat_f16->concat_;
120 concat->data_type_ = kNumberTypeFloat16;
121 concat->inner_sizes_ = NULL;
122 concat->inputs_ptr_ = NULL;
123 concat->is_with_data_ = NULL;
124 concat->base_.Prepare = ConcatPepare;
125 concat->base_.Resize = ConcatResize;
126 concat->base_.Release = ConcatRelease;
127 concat->base_.Compute = ConcatF16Compute;
128 concat_f16->tmp_buffer_ = NULL;
129 return (KernelBase *)concat;
130 }

REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat16, CreateConcatF16)