1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/concat.h"
18 #include "nnacl/concat_parameter.h"
19 #include "nnacl/tensor_c.h"
20 #include "nnacl/op_base.h"
21 #include "nnacl/nnacl_common.h"
22 #include "nnacl/tensor_c_utils.h"
23
/* Byte granularity used when splitting the output across threads:
 * ChooseConcatThreadCuttingStrategy divides the total output byte count by this value
 * and uses the quotient (clamped to [1, thread_nr_]) as the number of blocks. */
24 #define kConcatMinCostPerThread 16384
25
DoConcat(ConcatStruct * concat,int task_id)26 int DoConcat(ConcatStruct *concat, int task_id) {
27 NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
28 NNACL_CHECK_FALSE(task_id > concat->block_size_, NNACL_ERR);
29
30 int all_bytes = GetSize(concat->base_.out_[FIRST_INPUT]);
31 int64_t start = concat->block_splits_[task_id];
32 int64_t end = task_id < (concat->block_size_ - 1) ? concat->block_splits_[task_id + 1] : all_bytes;
33 int64_t start_row = start / concat->inner_sizes_[concat->base_.in_size_];
34 int64_t end_row = end / concat->inner_sizes_[concat->base_.in_size_];
35
36 size_t src_buf_size = concat->base_.in_size_ * sizeof(uint8_t *);
37 NNACL_CHECK_MALLOC_SIZE(src_buf_size);
38 uint8_t **src = (uint8_t **)concat->base_.env_->Alloc(concat->base_.env_->allocator_, src_buf_size);
39 NNACL_MALLOC_CHECK_NULL_RETURN_ERR(src);
40 for (size_t i = 0; i < concat->base_.in_size_; ++i) {
41 if (concat->is_with_data_[i]) {
42 src[i] = concat->inputs_ptr_[i] + start_row * concat->inner_sizes_[i];
43 }
44 }
45 uint8_t *out = concat->output_ + start;
46
47 int input_index = concat->block_boundary_infos_[task_id].begin_input_;
48 int end_index = concat->block_boundary_infos_[task_id].end_input_;
49 if (start_row == end_row) {
50 if (input_index == end_index) {
51 memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_,
52 concat->block_boundary_infos_[task_id].end_point_ - concat->block_boundary_infos_[task_id].begin_point_);
53 concat->base_.env_->Free(concat->base_.env_->allocator_, src);
54 return NNACL_OK;
55 }
56 int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_;
57 memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size);
58 out += size;
59 ++input_index;
60 for (; input_index < end_index; ++input_index) {
61 memcpy(out, src[input_index], concat->inner_sizes_[input_index]);
62 out += concat->inner_sizes_[input_index];
63 }
64 memcpy(out, src[input_index], concat->block_boundary_infos_[task_id].end_point_);
65 concat->base_.env_->Free(concat->base_.env_->allocator_, src);
66 return NNACL_OK;
67 }
68 for (int i = 0; i < input_index; ++i) {
69 src[i] += concat->inner_sizes_[i];
70 }
71 int64_t size = concat->inner_sizes_[input_index] - concat->block_boundary_infos_[task_id].begin_point_;
72 memcpy(out, src[input_index] + concat->block_boundary_infos_[task_id].begin_point_, size);
73 src[input_index] += concat->inner_sizes_[input_index];
74 out += size;
75 ++input_index;
76 for (; input_index < concat->base_.in_size_; ++input_index) {
77 memcpy(out, src[input_index], concat->inner_sizes_[input_index]);
78 src[input_index] += concat->inner_sizes_[input_index];
79 out += concat->inner_sizes_[input_index];
80 }
81 ++start_row;
82 for (; start_row < end_row; ++start_row) {
83 for (input_index = 0; input_index < concat->base_.in_size_; ++input_index) {
84 memcpy(out, src[input_index], concat->inner_sizes_[input_index]);
85 src[input_index] += concat->inner_sizes_[input_index];
86 out += concat->inner_sizes_[input_index];
87 }
88 }
89 for (input_index = 0; input_index < end_index; ++input_index) {
90 memcpy(out, src[input_index], concat->inner_sizes_[input_index]);
91 out += concat->inner_sizes_[input_index];
92 }
93 memcpy(out, src[end_index], concat->block_boundary_infos_[task_id].end_point_);
94
95 concat->base_.env_->Free(concat->base_.env_->allocator_, src);
96 return NNACL_OK;
97 }
98
ConcatRun(void * cdata,int task_id,float l,float r)99 int ConcatRun(void *cdata, int task_id, float l, float r) {
100 ConcatStruct *concat = (ConcatStruct *)cdata;
101 NNACL_CHECK_NULL_RETURN_ERR(concat);
102 return DoConcat(concat, task_id);
103 }
104
/* Recomputes the shape-dependent state of the kernel from the current input tensors.
 *
 * For every input i:
 *   - the product of the dimensions before axis_ must be equal for all inputs and is stored in
 *     outer_size_ (taken from input 0);
 *   - inner_sizes_[i] receives the element byte size multiplied by the dimensions from axis_ on;
 *   - is_with_data_[i] records whether that byte count is non-zero.
 * inner_sizes_[in_size_] receives the sum of all per-input byte counts.
 *
 * Returns NNACL_OK, or NNACL_CONCAT_AXIS_INVALID / NNACL_UNSUPPORTED_DATA_TYPE /
 * NNACL_CONCAT_SHAPE_INVALID when the corresponding check fails. */
int InitConcatDynamicStatus(ConcatStruct *concat) {
  ConcatParameter *param = (ConcatParameter *)concat->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  int64_t row_bytes = 0;
  size_t idx;
  for (idx = 0; idx < concat->base_.in_size_; idx++) {
    TensorC *tensor = concat->base_.in_[idx];
    NNACL_CHECK_FALSE(param->axis_ >= tensor->shape_size_, NNACL_CONCAT_AXIS_INVALID);

    /* Number of rows: product of the dimensions in front of the concat axis. */
    int64_t outer = 1;
    for (int dim = 0; dim < param->axis_; ++dim) {
      outer *= tensor->shape_[dim];
    }

    /* Bytes this input contributes to one row: element size times the remaining dimensions. */
    int slice_bytes = DataTypeCSize(concat->data_type_);
    NNACL_CHECK_TRUE_RET(slice_bytes > 0, NNACL_UNSUPPORTED_DATA_TYPE);
    for (int dim = param->axis_; dim < tensor->shape_size_; ++dim) {
      NNACL_CHECK_INT_MUL_NOT_OVERFLOW(slice_bytes, tensor->shape_[dim], NNACL_CONCAT_SHAPE_INVALID);
      slice_bytes *= tensor->shape_[dim];
    }

    if (idx != 0) {
      NNACL_CHECK_TRUE_RET(concat->outer_size_ == outer, NNACL_CONCAT_SHAPE_INVALID);
    } else {
      concat->outer_size_ = outer;
    }

    /* An empty slice adds nothing to the row total, so no separate branch is required. */
    concat->is_with_data_[idx] = (slice_bytes != 0);
    concat->inner_sizes_[idx] = slice_bytes;
    row_bytes += slice_bytes;
  }
  /* The slot after the last input stores the byte size of one full output row. */
  concat->inner_sizes_[idx] = row_bytes;
  return NNACL_OK;
}
142
/* Translates a byte offset inside one output row into (input index, byte offset inside that input).
 *
 * concat   kernel state; inner_sizes_ is read.
 * pre_sum  running totals of inner_sizes_ over the inputs (in_size_ entries).
 * offset   byte offset inside an output row.
 * input    out: index of the first input whose running total exceeds offset.
 * point    out: offset relative to the start of that input's row slice.
 *
 * NOTE(review): if offset is not smaller than every entry of pre_sum, the index ends up equal to
 * in_size_ and pre_sum is read at that position; callers have to keep offset below the final total. */
void ComputeConcatUnitBoundary(ConcatStruct *concat, int64_t *pre_sum, int offset, int *input, int64_t *point) {
  size_t hit = 0;
  while (hit < concat->base_.in_size_ && offset >= pre_sum[hit]) {
    ++hit;
  }
  *input = hit;
  *point = concat->inner_sizes_[hit] - (pre_sum[hit] - offset);
}
153
/* Splits the output byte range into contiguous blocks, one per worker task.
 *
 * The number of blocks is the output byte count divided by kConcatMinCostPerThread, clamped to
 * [1, thread_nr_]. Blocks differ in size by at most one byte (the remainder is handed out one byte
 * per block, starting with the first). For every block the start offset is stored in block_splits_
 * and the in-row position of its first and last byte in block_boundary_infos_. On success
 * block_size_ and base_.thread_nr_ are both set to the number of blocks produced.
 *
 * Returns NNACL_OK on success, NNACL_ERR when thread_nr_ is not positive or the row size is zero,
 * or the error produced by the allocation-check macros. */
int ChooseConcatThreadCuttingStrategy(ConcatStruct *concat) {
  NNACL_CHECK_TRUE_RET(concat->base_.thread_nr_ > 0, NNACL_ERR);

  /* Byte size of one output row; used as a divisor when converting offsets to in-row offsets. */
  const int64_t row_bytes = concat->inner_sizes_[concat->base_.in_size_];
  NNACL_CHECK_ZERO_RETURN_ERR(row_bytes);

  int all_bytes = GetSize(concat->base_.out_[FIRST_INPUT]);
  int64_t thread_count = MSMAX(1, MSMIN(all_bytes / kConcatMinCostPerThread, concat->base_.thread_nr_));

  NNACL_CHECK_ZERO_RETURN_ERR(thread_count);
  int64_t block_size = all_bytes / thread_count;
  int64_t remain_byte = all_bytes - block_size * thread_count;

  /* Validate the request size before allocating, as DoConcat does for its temporary table. */
  size_t pre_sum_bytes = concat->base_.in_size_ * sizeof(int64_t);
  NNACL_CHECK_MALLOC_SIZE(pre_sum_bytes);
  int64_t *pre_sum = (int64_t *)concat->base_.env_->Alloc(concat->base_.env_->allocator_, pre_sum_bytes);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pre_sum);

  /* pre_sum[i] = bytes contributed to one row by inputs 0..i. */
  int64_t running_total = 0;
  for (size_t i = 0; i < concat->base_.in_size_; ++i) {
    running_total += concat->inner_sizes_[i];
    pre_sum[i] = running_total;
  }

  concat->block_size_ = 0;

  int64_t block_split = 0;
  while (block_split < all_bytes) {
    int64_t start = block_split;
    concat->block_splits_[concat->block_size_] = start;

    block_split += block_size;
    if (remain_byte > 0) {
      /* Spread the remainder: the leading blocks each take one extra byte. */
      ++block_split;
      --remain_byte;
    }
    int64_t end = block_split > all_bytes ? all_bytes : block_split;

    /* Offsets of the block's first and one-past-last byte relative to the start of their row. */
    int64_t start_offset = start - DOWN_ROUND(start, row_bytes);
    int64_t end_offset = end - DOWN_ROUND(end, row_bytes);

    ConcatBlockBoundaryInfo block_boundary_info;
    ComputeConcatUnitBoundary(concat, pre_sum, start_offset, &block_boundary_info.begin_input_,
                              &block_boundary_info.begin_point_);
    ComputeConcatUnitBoundary(concat, pre_sum, end_offset, &block_boundary_info.end_input_,
                              &block_boundary_info.end_point_);
    concat->block_boundary_infos_[concat->block_size_] = block_boundary_info;
    concat->block_size_++;
  }

  /* One task per block. */
  concat->base_.thread_nr_ = concat->block_size_;
  concat->base_.env_->Free(concat->base_.env_->allocator_, pre_sum);
  return NNACL_OK;
}
199
/* Resize hook: normalises the concat axis, refreshes the shape-dependent state and, when there
 * is data to copy, recomputes the per-thread block layout.
 *
 * A negative axis_ is rewritten in place as an offset from the rank of the first input. The axis
 * must then lie in [0, rank of first input), otherwise NNACL_CONCAT_AXIS_INVALID is returned.
 * Errors from InitConcatDynamicStatus and ChooseConcatThreadCuttingStrategy are passed through. */
int ConcatResize(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);
  ConcatParameter *param = (ConcatParameter *)concat->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  if (param->axis_ < 0) {
    param->axis_ = self->in_[FIRST_INPUT]->shape_size_ + param->axis_;
  }
  NNACL_CHECK_FALSE(param->axis_ < 0, NNACL_CONCAT_AXIS_INVALID);
  NNACL_CHECK_FALSE(param->axis_ >= self->in_[FIRST_INPUT]->shape_size_, NNACL_CONCAT_AXIS_INVALID);

  int status = InitConcatDynamicStatus(concat);
  if (status != NNACL_OK) {
    return status;
  }

  /* Only plan the thread split when there is at least one row and the row is not empty. */
  if (concat->outer_size_ != 0 && concat->inner_sizes_[self->in_size_] != 0) {
    return ChooseConcatThreadCuttingStrategy(concat);
  }
  return NNACL_OK;
}
219
/* Prepare hook: allocates the per-input bookkeeping tables through the kernel's environment
 * allocator.
 *
 *   inputs_ptr_    one data pointer per input
 *   is_with_data_  one flag per input
 *   inner_sizes_   one int64_t per input plus one per output
 *
 * Returns NNACL_OK, or the null-check macro's error as soon as one allocation yields NULL (tables
 * allocated before that point stay assigned to the struct). */
int ConcatPepare(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);

  size_t ptr_table_bytes = self->in_size_ * sizeof(uint8_t *);
  concat->inputs_ptr_ = self->env_->Alloc(self->env_->allocator_, ptr_table_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_);

  size_t flag_table_bytes = self->in_size_ * sizeof(bool);
  concat->is_with_data_ = self->env_->Alloc(self->env_->allocator_, flag_table_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->is_with_data_);

  size_t size_table_bytes = (self->in_size_ + self->out_size_) * sizeof(int64_t);
  concat->inner_sizes_ = self->env_->Alloc(self->env_->allocator_, size_table_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->inner_sizes_);

  return NNACL_OK;
}
234
/* Release hook: returns the tables obtained in ConcatPepare to the environment allocator.
 *
 * Each pointer is reset to NULL after it has been freed, so calling Release again (or calling it
 * after a partially failed Prepare) never frees the same buffer twice and leaves no dangling
 * pointer in the struct.
 *
 * Returns NNACL_OK, or the null-check macro's error when self is NULL. */
int ConcatRelease(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);
  if (concat->inputs_ptr_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->inputs_ptr_);
    concat->inputs_ptr_ = NULL;
  }
  if (concat->is_with_data_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->is_with_data_);
    concat->is_with_data_ = NULL;
  }
  if (concat->inner_sizes_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->inner_sizes_);
    concat->inner_sizes_ = NULL;
  }
  return NNACL_OK;
}
249
/* Compute hook: captures the current input/output data pointers and launches the copy tasks.
 *
 * Returns NNACL_OK immediately when outer_size_ or the row byte size is zero. Otherwise every
 * input flagged in is_with_data_ must have a non-NULL data pointer, as must the first output;
 * the null-check macro's error is returned if not. The final result is whatever ParallelLaunch
 * returns for ConcatRun over thread_nr_ tasks. */
int ConcatCompute(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);

  /* Nothing to copy when there are no rows or the rows are empty. */
  if (!(concat->outer_size_ != 0 && concat->inner_sizes_[self->in_size_] != 0)) {
    return NNACL_OK;
  }

  for (size_t idx = 0; idx < self->in_size_; ++idx) {
    if (concat->is_with_data_[idx]) {
      NNACL_CHECK_NULL_RETURN_ERR(self->in_[idx]->data_);
      concat->inputs_ptr_[idx] = self->in_[idx]->data_;
    }
  }

  concat->output_ = self->out_[FIRST_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);

  return self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatRun, self, self->thread_nr_);
}
269
CreateConcat(OpParameter * param,int data_type)270 KernelBase *CreateConcat(OpParameter *param, int data_type) {
271 ConcatStruct *concat = (ConcatStruct *)malloc(sizeof(ConcatStruct));
272 NNACL_CHECK_NULL_RETURN_NULL(concat);
273 memset(concat, 0, sizeof(ConcatStruct));
274 if (data_type == kNumberTypeBool) {
275 concat->data_type_ = data_type;
276 } else {
277 concat->data_type_ = kNumberTypeFloat32;
278 }
279 concat->inner_sizes_ = NULL;
280 concat->inputs_ptr_ = NULL;
281 concat->is_with_data_ = NULL;
282 concat->base_.Prepare = ConcatPepare;
283 concat->base_.Resize = ConcatResize;
284 concat->base_.Release = ConcatRelease;
285 concat->base_.Compute = ConcatCompute;
286 return (KernelBase *)concat;
287 }
288
/* Associates CreateConcat with PrimType_Concat for the bool, int32 and float32 type ids.
 * NOTE(review): REG_KERNEL_CREATOR is defined outside this file; presumably it registers the
 * creator in the kernel registry - confirm against its definition. */
289 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeBool, CreateConcat)
290 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeInt32, CreateConcat)
291 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat32, CreateConcat)
292