• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/concat.h"
18 #include "nnacl/concat_parameter.h"
19 #include "nnacl/tensor_c.h"
20 #include "nnacl/op_base.h"
21 #include "nnacl/nnacl_common.h"
22 #include "nnacl/tensor_c_utils.h"
23 
24 #define kConcatMinCostPerThread 16384
25 
/* Appends `size` bytes from `src` at *out and advances *out past them. Zero-length copies are skipped entirely, so
 * the NULL cursor of an input without data is never handed to memcpy. */
static inline void ConcatAppendBytes(uint8_t **out, const uint8_t *src, int64_t size) {
  if (size > 0) {
    memcpy(*out, src, (size_t)size);
    *out += size;
  }
}

/* Moves one input's read cursor forward by `size` bytes. A zero size leaves the (possibly NULL) cursor untouched,
 * which avoids pointer arithmetic on NULL for inputs without data. */
static inline void ConcatAdvanceCursor(uint8_t **cursor, int64_t size) {
  if (size > 0) {
    *cursor += size;
  }
}

/*
 * Copies the output byte range owned by one task: [block_splits_[task_id], block_splits_[task_id + 1]) or, for the
 * last task, up to the end of the output.
 *
 * The output is treated as rows of inner_sizes_[in_size_] bytes. Each row is the concatenation of one
 * inner_sizes_[i]-byte slice from every input i. block_boundary_infos_[task_id] records which input the range starts
 * and ends in (begin_input_/end_input_) and the byte offset inside that input's slice (begin_point_/end_point_).
 *
 * Returns NNACL_OK on success. Returns an error code when task_id is out of range, the row size is zero, or the
 * scratch array of source cursors cannot be allocated.
 */
int DoConcat(ConcatStruct *concat, int task_id) {
  NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
  /* block_splits_/block_boundary_infos_ hold block_size_ entries, so task_id == block_size_ is already invalid. */
  NNACL_CHECK_FALSE(task_id >= concat->block_size_, NNACL_ERR);

  const size_t in_size = concat->base_.in_size_;
  const int64_t *inner_sizes = concat->inner_sizes_;
  const int64_t row_bytes = inner_sizes[in_size];
  NNACL_CHECK_ZERO_RETURN_ERR(row_bytes);
  const ConcatBlockBoundaryInfo *boundary = &concat->block_boundary_infos_[task_id];
  const int64_t begin_point = boundary->begin_point_;
  const int64_t end_point = boundary->end_point_;

  int all_bytes = GetSize(concat->base_.out_[FIRST_INPUT]);
  int64_t start = concat->block_splits_[task_id];
  int64_t end = task_id < (concat->block_size_ - 1) ? concat->block_splits_[task_id + 1] : all_bytes;
  int64_t start_row = start / row_bytes;
  int64_t end_row = end / row_bytes;

  size_t src_buf_size = in_size * sizeof(uint8_t *);
  NNACL_CHECK_MALLOC_SIZE(src_buf_size);
  uint8_t **src = (uint8_t **)concat->base_.env_->Alloc(concat->base_.env_->allocator_, src_buf_size);
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(src);
  /* Point every input's cursor at its slice in row `start_row`. Inputs without data get NULL instead of an
   * indeterminate value; the helpers above never copy from or offset such a cursor because its size is 0. */
  for (size_t i = 0; i < in_size; ++i) {
    src[i] = concat->is_with_data_[i] ? concat->inputs_ptr_[i] + start_row * inner_sizes[i] : NULL;
  }
  uint8_t *out = concat->output_ + start;

  int begin_input = boundary->begin_input_;
  int end_input = boundary->end_input_;
  if (start_row == end_row) {
    /* The whole range lies inside a single output row. */
    if (begin_input == end_input) {
      ConcatAppendBytes(&out, src[begin_input] + begin_point, end_point - begin_point);
    } else {
      ConcatAppendBytes(&out, src[begin_input] + begin_point, inner_sizes[begin_input] - begin_point);
      for (int i = begin_input + 1; i < end_input; ++i) {
        ConcatAppendBytes(&out, src[i], inner_sizes[i]);
      }
      ConcatAppendBytes(&out, src[end_input], end_point);
    }
    concat->base_.env_->Free(concat->base_.env_->allocator_, src);
    return NNACL_OK;
  }

  /* First, partial row: slices before begin_input are not part of this range, so only their cursors move on. */
  for (int i = 0; i < begin_input; ++i) {
    ConcatAdvanceCursor(&src[i], inner_sizes[i]);
  }
  ConcatAppendBytes(&out, src[begin_input] + begin_point, inner_sizes[begin_input] - begin_point);
  ConcatAdvanceCursor(&src[begin_input], inner_sizes[begin_input]);
  for (size_t i = (size_t)begin_input + 1; i < in_size; ++i) {
    ConcatAppendBytes(&out, src[i], inner_sizes[i]);
    ConcatAdvanceCursor(&src[i], inner_sizes[i]);
  }

  /* Complete rows strictly between the first and the last row of the range. */
  for (int64_t row = start_row + 1; row < end_row; ++row) {
    for (size_t i = 0; i < in_size; ++i) {
      ConcatAppendBytes(&out, src[i], inner_sizes[i]);
      ConcatAdvanceCursor(&src[i], inner_sizes[i]);
    }
  }

  /* Last, partial row: whole slices before end_input, then end_point bytes of end_input. */
  for (int i = 0; i < end_input; ++i) {
    ConcatAppendBytes(&out, src[i], inner_sizes[i]);
  }
  ConcatAppendBytes(&out, src[end_input], end_point);

  concat->base_.env_->Free(concat->base_.env_->allocator_, src);
  return NNACL_OK;
}
98 
ConcatRun(void * cdata,int task_id,float l,float r)99 int ConcatRun(void *cdata, int task_id, float l, float r) {
100   ConcatStruct *concat = (ConcatStruct *)cdata;
101   NNACL_CHECK_NULL_RETURN_ERR(concat);
102   return DoConcat(concat, task_id);
103 }
104 
/*
 * Recomputes the per-input geometry used by the cutting strategy and the copy kernel:
 *  - concat->outer_size_: product of the dimensions before param->axis_; it must agree for every input.
 *  - concat->inner_sizes_[i]: DataTypeCSize(data_type_) times the dimensions from param->axis_ to the last one.
 *  - concat->is_with_data_[i]: whether that inner size is non-zero.
 *  - concat->inner_sizes_[in_size_]: sum of all per-input inner sizes.
 *
 * Returns NNACL_OK on success. On bad input it returns NNACL_CONCAT_AXIS_INVALID, NNACL_UNSUPPORTED_DATA_TYPE or
 * NNACL_CONCAT_SHAPE_INVALID.
 */
int InitConcatDynamicStatus(ConcatStruct *concat) {
  ConcatParameter *param = (ConcatParameter *)concat->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);

  const size_t input_count = concat->base_.in_size_;
  int64_t row_total = 0;
  for (size_t idx = 0; idx < input_count; idx++) {
    TensorC *tensor = concat->base_.in_[idx];
    NNACL_CHECK_FALSE(param->axis_ >= tensor->shape_size_, NNACL_CONCAT_AXIS_INVALID);

    int64_t outer = 1;
    for (int d = 0; d < param->axis_; ++d) {
      outer *= tensor->shape_[d];
    }

    int inner_bytes = DataTypeCSize(concat->data_type_);
    NNACL_CHECK_TRUE_RET(inner_bytes > 0, NNACL_UNSUPPORTED_DATA_TYPE);
    for (int d = param->axis_; d < tensor->shape_size_; ++d) {
      NNACL_CHECK_INT_MUL_NOT_OVERFLOW(inner_bytes, tensor->shape_[d], NNACL_CONCAT_SHAPE_INVALID);
      inner_bytes *= tensor->shape_[d];
    }

    /* The first input defines the outer size; every later input has to match it. */
    if (idx != 0) {
      NNACL_CHECK_TRUE_RET(concat->outer_size_ == outer, NNACL_CONCAT_SHAPE_INVALID);
    } else {
      concat->outer_size_ = outer;
    }

    concat->is_with_data_[idx] = (inner_bytes != 0);
    concat->inner_sizes_[idx] = inner_bytes;
    row_total += inner_bytes; /* adding 0 for an empty input leaves the total unchanged */
  }
  concat->inner_sizes_[input_count] = row_total;
  return NNACL_OK;
}
142 
/*
 * Maps a byte offset inside one output row to (input index, offset inside that input's slice).
 *
 * pre_sum[k] must hold inner_sizes_[0] + ... + inner_sizes_[k]. *input receives the first index whose prefix sum
 * exceeds `offset`. *point receives offset minus the prefix sum of the preceding inputs, i.e. the position inside
 * that input's slice.
 *
 * Callers must keep `offset` below the total row size. Otherwise the search reaches index in_size_ and
 * pre_sum[in_size_] is read.
 */
void ComputeConcatUnitBoundary(ConcatStruct *concat, int64_t *pre_sum, int offset, int *input, int64_t *point) {
  size_t found = 0;
  while (found < concat->base_.in_size_ && offset >= pre_sum[found]) {
    ++found;
  }
  *input = (int)found;
  *point = concat->inner_sizes_[found] - (pre_sum[found] - offset);
}
153 
/*
 * Splits the output's bytes into contiguous per-task ranges.
 *
 * The task count is max(1, min(all_bytes / kConcatMinCostPerThread, thread_nr_)). Every range gets
 * all_bytes / task_count bytes, and the first (all_bytes % task_count) ranges get one extra byte. For each range the
 * start offset is stored in block_splits_. The positions of its first and end byte inside an output row are stored
 * in block_boundary_infos_ (see ComputeConcatUnitBoundary).
 *
 * On return block_size_ is the number of ranges and base_.thread_nr_ is set to it.
 * NOTE(review): block_splits_/block_boundary_infos_ receive block_size_ entries each; confirm in concat.h that their
 * capacity covers the largest thread_nr_.
 */
int ChooseConcatThreadCuttingStrategy(ConcatStruct *concat) {
  NNACL_CHECK_TRUE_RET(concat->base_.thread_nr_ > 0, NNACL_ERR);

  const int all_bytes = GetSize(concat->base_.out_[FIRST_INPUT]);
  const int64_t thread_count = MSMAX(1, MSMIN(all_bytes / kConcatMinCostPerThread, concat->base_.thread_nr_));
  NNACL_CHECK_ZERO_RETURN_ERR(thread_count);
  const int64_t bytes_per_block = all_bytes / thread_count;
  int64_t extra_bytes = all_bytes % thread_count;

  const size_t input_count = concat->base_.in_size_;
  int64_t *pre_sum =
    (int64_t *)concat->base_.env_->Alloc(concat->base_.env_->allocator_, input_count * sizeof(int64_t));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(pre_sum);
  /* Inclusive prefix sums of the per-input slice sizes: pre_sum[k] is where input k's slice ends inside a row. */
  int64_t running = 0;
  for (size_t k = 0; k < input_count; ++k) {
    running += concat->inner_sizes_[k];
    pre_sum[k] = running;
  }
  const int64_t row_bytes = concat->inner_sizes_[input_count];

  concat->block_size_ = 0;
  int64_t next_split = 0;
  while (next_split < all_bytes) {
    concat->block_splits_[concat->block_size_] = next_split;
    next_split += bytes_per_block;
    if (extra_bytes > 0) {
      ++next_split;
      --extra_bytes;
    }

    const int64_t range_begin = concat->block_splits_[concat->block_size_];
    const int64_t range_end = (next_split < all_bytes) ? next_split : all_bytes;
    const int64_t begin_in_row = range_begin - DOWN_ROUND(range_begin, row_bytes);
    const int64_t end_in_row = range_end - DOWN_ROUND(range_end, row_bytes);

    ConcatBlockBoundaryInfo info;
    ComputeConcatUnitBoundary(concat, pre_sum, begin_in_row, &info.begin_input_, &info.begin_point_);
    ComputeConcatUnitBoundary(concat, pre_sum, end_in_row, &info.end_input_, &info.end_point_);
    concat->block_boundary_infos_[concat->block_size_] = info;
    concat->block_size_++;
  }

  concat->base_.thread_nr_ = concat->block_size_;
  concat->base_.env_->Free(concat->base_.env_->allocator_, pre_sum);
  return NNACL_OK;
}
199 
/*
 * Resize hook.
 *
 * Normalizes a negative param->axis_ against the rank of the first input and validates it. It then recomputes the
 * per-input sizes and, when the output is non-empty, re-plans the per-task byte ranges.
 *
 * Returns NNACL_OK on success. Returns an error code when there is no first input, the axis is out of range, or a
 * downstream step fails.
 *
 * NOTE(review): param->axis_ is overwritten with the normalized value, so the original negative axis is lost. If the
 * input rank can change between resizes, later resizes normalize against the first resize's rank — confirm.
 */
int ConcatResize(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);
  ConcatParameter *param = (ConcatParameter *)concat->base_.param_;
  NNACL_CHECK_NULL_RETURN_ERR(param);
  /* The axis is normalized against in_[FIRST_INPUT], so that input has to exist before it is dereferenced. */
  NNACL_CHECK_FALSE(self->in_size_ == 0, NNACL_ERR);
  NNACL_CHECK_NULL_RETURN_ERR(self->in_[FIRST_INPUT]);

  param->axis_ = param->axis_ >= 0 ? param->axis_ : self->in_[FIRST_INPUT]->shape_size_ + param->axis_;
  NNACL_CHECK_FALSE(param->axis_ < 0, NNACL_CONCAT_AXIS_INVALID);
  NNACL_CHECK_FALSE(param->axis_ >= self->in_[FIRST_INPUT]->shape_size_, NNACL_CONCAT_AXIS_INVALID);

  int ret = InitConcatDynamicStatus(concat);
  NNACL_CHECK_FALSE(ret != NNACL_OK, ret);

  /* An empty output needs no thread plan; ConcatCompute returns early on the same condition. */
  if (concat->outer_size_ == 0 || concat->inner_sizes_[self->in_size_] == 0) {
    return NNACL_OK;
  }

  return ChooseConcatThreadCuttingStrategy(concat);
}
219 
/*
 * Prepare hook: allocates the per-input bookkeeping arrays from the kernel environment's allocator.
 *  - inputs_ptr_: one data pointer per input, filled in by ConcatCompute.
 *  - is_with_data_: one flag per input, filled in by InitConcatDynamicStatus.
 *  - inner_sizes_: in_size_ + out_size_ slots; slot in_size_ holds the summed row size.
 * Arrays allocated before a failure are left in the struct and are freed by ConcatRelease.
 */
int ConcatPepare(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);

  size_t ptr_bytes = self->in_size_ * sizeof(uint8_t *);
  concat->inputs_ptr_ = self->env_->Alloc(self->env_->allocator_, ptr_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->inputs_ptr_);

  size_t flag_bytes = self->in_size_ * sizeof(bool);
  concat->is_with_data_ = self->env_->Alloc(self->env_->allocator_, flag_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->is_with_data_);

  size_t size_bytes = (self->in_size_ + self->out_size_) * sizeof(int64_t);
  concat->inner_sizes_ = self->env_->Alloc(self->env_->allocator_, size_bytes);
  NNACL_CHECK_NULL_RETURN_ERR(concat->inner_sizes_);

  return NNACL_OK;
}
234 
/*
 * Release hook: returns the arrays allocated by ConcatPepare to the environment's allocator.
 * Each field is reset to NULL after being freed. A repeated Release, or a later Prepare/Release cycle, therefore
 * never frees a dangling pointer twice.
 */
int ConcatRelease(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);
  if (concat->inputs_ptr_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->inputs_ptr_);
    concat->inputs_ptr_ = NULL;
  }
  if (concat->is_with_data_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->is_with_data_);
    concat->is_with_data_ = NULL;
  }
  if (concat->inner_sizes_ != NULL) {
    self->env_->Free(self->env_->allocator_, concat->inner_sizes_);
    concat->inner_sizes_ = NULL;
  }
  return NNACL_OK;
}
249 
/*
 * Compute hook.
 *
 * Returns NNACL_OK immediately when the output has no bytes (outer_size_ or the summed row size is 0). Otherwise it
 * binds the data pointers of the inputs that carry data and of the output, then runs thread_nr_ ConcatRun tasks.
 */
int ConcatCompute(KernelBase *self) {
  ConcatStruct *concat = (ConcatStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(concat);
  bool nothing_to_copy = (concat->outer_size_ == 0) || (concat->inner_sizes_[self->in_size_] == 0);
  if (nothing_to_copy) {
    return NNACL_OK;
  }

  /* Empty inputs are never read by the copy kernel, so their (possibly NULL) data pointers are not required. */
  for (size_t idx = 0; idx < self->in_size_; ++idx) {
    if (concat->is_with_data_[idx]) {
      NNACL_CHECK_NULL_RETURN_ERR(self->in_[idx]->data_);
      concat->inputs_ptr_[idx] = self->in_[idx]->data_;
    }
  }

  concat->output_ = self->out_[FIRST_INPUT]->data_;
  NNACL_CHECK_NULL_RETURN_ERR(concat->output_);
  return self->env_->ParallelLaunch(self->env_->thread_pool_, ConcatRun, self, self->thread_nr_);
}
269 
CreateConcat(OpParameter * param,int data_type)270 KernelBase *CreateConcat(OpParameter *param, int data_type) {
271   ConcatStruct *concat = (ConcatStruct *)malloc(sizeof(ConcatStruct));
272   NNACL_CHECK_NULL_RETURN_NULL(concat);
273   memset(concat, 0, sizeof(ConcatStruct));
274   if (data_type == kNumberTypeBool) {
275     concat->data_type_ = data_type;
276   } else {
277     concat->data_type_ = kNumberTypeFloat32;
278   }
279   concat->inner_sizes_ = NULL;
280   concat->inputs_ptr_ = NULL;
281   concat->is_with_data_ = NULL;
282   concat->base_.Prepare = ConcatPepare;
283   concat->base_.Resize = ConcatResize;
284   concat->base_.Release = ConcatRelease;
285   concat->base_.Compute = ConcatCompute;
286   return (KernelBase *)concat;
287 }
288 
289 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeBool, CreateConcat)
290 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeInt32, CreateConcat)
291 REG_KERNEL_CREATOR(PrimType_Concat, kNumberTypeFloat32, CreateConcat)
292