1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/gather.h"
18 #include "nnacl/nnacl_common.h"
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/default_kernel_base.h"
21 #include "nnacl/tensor_c_utils.h"
22
23 #define kGatherMinCostPerThread 16384
24
/* Copies the gathered slices for indices in [begin, end) of a single outer batch.
 * For each index: negative values are wrapped by limit_; still-out-of-range
 * indices produce a zero-filled slice instead of reading out of bounds.
 * On return *int8_out points just past the written data and *int8_in is
 * advanced by byte_in_stride to the next outer batch.
 *
 * Fix: byte_in_stride was declared `int` although the caller computes it as
 * int64_t (limit_ * byte_inner_size_), silently truncating for large tensors;
 * widened to int64_t. */
void GatherHandleCopy(GatherStruct *gather, int8_t **int8_in, int8_t **int8_out, int begin, int end,
                      int64_t byte_in_stride) {
  for (; begin < end; ++begin) {
    int index = gather->indices_data_[begin];
    /* Support negative indexing (Python-style) by wrapping once. */
    index = (index < 0 ? index + gather->limit_ : index);
    if (index < 0 || index >= gather->limit_) {
      /* Out-of-range index: emit zeros rather than reading invalid memory. */
      memset(*int8_out, 0, gather->byte_inner_size_);
    } else {
      memcpy(*int8_out, *int8_in + index * gather->byte_inner_size_, gather->byte_inner_size_);
    }
    *int8_out += gather->byte_inner_size_;
  }
  *int8_in += byte_in_stride;
}
39
GatherRun(void * cdata,int task_id,float l,float r)40 int GatherRun(void *cdata, int task_id, float l, float r) {
41 GatherStruct *gather = (GatherStruct *)cdata;
42 NNACL_CHECK_NULL_RETURN_ERR(gather);
43 NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
44 NNACL_CHECK_FALSE(task_id >= gather->block_infos_size_, NNACL_ERR);
45
46 int8_t *int8_in = (int8_t *)(gather->base_.in_[FIRST_INPUT]->data_);
47 NNACL_CHECK_NULL_RETURN_ERR(int8_in);
48 int8_t *int8_out = (int8_t *)(gather->base_.out_[OUTPUT_INDEX]->data_);
49 NNACL_CHECK_NULL_RETURN_ERR(int8_out);
50 int begin_batch = gather->block_infos_[task_id].begin_batch_;
51 int begin_index = gather->block_infos_[task_id].begin_index_;
52 int end_batch = gather->block_infos_[task_id].end_batch_;
53 int end_index = gather->block_infos_[task_id].end_index_;
54 int64_t byte_in_stride = gather->limit_ * gather->byte_inner_size_;
55 int8_in += begin_batch * byte_in_stride;
56 int8_out += begin_batch * gather->indices_size_ * gather->byte_inner_size_ + begin_index * gather->byte_inner_size_;
57 if (begin_batch == end_batch) {
58 GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, end_index, byte_in_stride);
59 return NNACL_OK;
60 }
61 GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, gather->indices_size_, byte_in_stride);
62 ++begin_batch;
63 for (; begin_batch < end_batch; ++begin_batch) {
64 GatherHandleCopy(gather, &int8_in, &int8_out, 0, gather->indices_size_, byte_in_stride);
65 }
66 GatherHandleCopy(gather, &int8_in, &int8_out, 0, end_index, byte_in_stride);
67 return NNACL_OK;
68 }
69
/* Points indices_data_ at int32 indices directly, or converts
 * int64/float/float32/bool indices into a freshly allocated int32 buffer
 * (freed by GatherCompute after the launch).
 *
 * Fix: on an unsupported indices data type the function returned
 * NNACL_UNSUPPORTED_DATA_TYPE while the just-allocated conversion buffer was
 * still live — and the caller skips its free on error — leaking the buffer.
 * Now the buffer is released before returning the error. */
int AssignGatherIndicesData(GatherStruct *gather, bool is_indices_int32) {
  TensorC *indices_tensor = gather->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor->data_);

  if (is_indices_int32) {
    /* No conversion needed: reference the tensor's memory directly. */
    gather->indices_data_ = (int *)(indices_tensor->data_);
    return NNACL_OK;
  }

  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->indices_size_, (int)(sizeof(int)), NNACL_ERR);
  gather->indices_data_ =
    (int *)(gather->base_.env_->Alloc(gather->base_.env_->allocator_, gather->indices_size_ * sizeof(int)));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather->indices_data_);

  switch (indices_tensor->data_type_) {
    case kNumberTypeInt64:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((int64_t *)indices_tensor->data_)[i];
      }
      break;
    case kNumberTypeFloat:
    case kNumberTypeFloat32:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((float *)indices_tensor->data_)[i];
      }
      break;
    case kNumberTypeBool:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((bool *)indices_tensor->data_)[i];
      }
      break;
    default:
      /* Release the conversion buffer: the caller does not free on error. */
      gather->base_.env_->Free(gather->base_.env_->allocator_, gather->indices_data_);
      gather->indices_data_ = NULL;
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}
106
/* Derives the per-resize gather geometry from the input shape and axis:
 *   limit_           - extent of the gathered axis,
 *   outer_size_      - product of dims before the axis,
 *   byte_inner_size_ - element size times product of dims after the axis,
 *   indices_size_    - element count of the indices tensor.
 *
 * Fix: the two accumulation loops multiplied unchecked ints; added the same
 * NNACL_CHECK_INT_MUL_NOT_OVERFLOW guards this file already uses elsewhere. */
int InitGatherDynamicStatus(GatherStruct *gather) {
  int *in_shape = gather->base_.in_[FIRST_INPUT]->shape_;
  int in_rank = (int)gather->base_.in_[FIRST_INPUT]->shape_size_;
  NNACL_CHECK_TRUE_RET(gather->axis_ >= 0 && gather->axis_ < in_rank, NNACL_GATHER_AXIS_INVALID);
  gather->limit_ = in_shape[gather->axis_];
  gather->outer_size_ = 1;
  for (int i = 0; i < gather->axis_; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, in_shape[i], NNACL_ERR);
    gather->outer_size_ *= in_shape[i];
  }
  gather->byte_inner_size_ = (int)DataTypeCSize(gather->base_.out_[OUTPUT_INDEX]->data_type_);
  for (int i = gather->axis_ + 1; i < in_rank; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->byte_inner_size_, in_shape[i], NNACL_ERR);
    gather->byte_inner_size_ *= in_shape[i];
  }
  gather->indices_size_ = GetElementNum(gather->base_.in_[SECOND_INPUT]);
  return NNACL_OK;
}
123
/* Picks the thread count for this gather: a single thread for small outputs
 * (not worth the dispatch overhead), otherwise defers to the framework's
 * UpdateThread heuristic. */
void GatherUpdateThreadNumProcess(GatherStruct *gather) {
  int total_bytes = GetSize(gather->base_.out_[OUTPUT_INDEX]);
  if (total_bytes > kGatherMinCostPerThread) {
    gather->base_.thread_nr_ = gather->base_.UpdateThread(TC_PTYPE(PrimType_Gather), 0, gather->byte_inner_size_,
                                                          total_bytes, gather->base_.thread_nr_);
  } else {
    /* Output too small to amortize parallel launch cost. */
    gather->base_.thread_nr_ = 1;
  }
}
136
/* Splits the total work (outer_size_ x indices_size_ copy units) into at most
 * GATHER_BLOCK_INFOS_SIZE contiguous blocks, one per worker thread, recorded
 * in block_infos_. Each block is a half-open range expressed as
 * (begin_batch_, begin_index_) .. (end_batch_, end_index_).
 * NOTE(review): end_index_ == 0 with end_batch_ == outer_size_ denotes "end of
 * the last batch" (see the single-thread case and the modulo arithmetic). */
int ChooseGatherThreadCuttingStrategy(GatherStruct *gather) {
  gather->block_infos_size_ = 0;
  if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) {
    /* Empty output: nothing to schedule. */
    return NNACL_OK;
  }
  GatherUpdateThreadNumProcess(gather);
  /* block_infos_ is a fixed-size array; clamp the thread count to fit. */
  if (gather->base_.thread_nr_ > GATHER_BLOCK_INFOS_SIZE) {
    gather->base_.thread_nr_ = GATHER_BLOCK_INFOS_SIZE;
  }

  if (gather->base_.thread_nr_ == 1) {
    /* Single block covering every batch in full. */
    gather->block_infos_[gather->block_infos_size_].begin_batch_ = 0;
    gather->block_infos_[gather->block_infos_size_].begin_index_ = 0;
    gather->block_infos_[gather->block_infos_size_].end_batch_ = gather->outer_size_;
    gather->block_infos_[gather->block_infos_size_].end_index_ = 0;
    gather->block_infos_size_++;
  } else {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, gather->indices_size_, NNACL_ERR);
    int total_block = gather->outer_size_ * gather->indices_size_;
    /* Even split; the first `remain_block` blocks absorb one extra unit each
     * so all of total_block is covered. */
    int block_size = total_block / gather->base_.thread_nr_;
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_size, gather->base_.thread_nr_, NNACL_ERR);
    int remain_block = total_block - block_size * gather->base_.thread_nr_;
    int start = 0;
    while (start < total_block) {
      GatherBlockBoundaryInfo block_boundary_info;
      /* Convert the flat unit offset into (batch, index-within-batch). */
      block_boundary_info.begin_batch_ = start / gather->indices_size_;
      block_boundary_info.begin_index_ = start % gather->indices_size_;
      start += block_size;
      if (remain_block > 0) {
        ++start;
        --remain_block;
      }
      if (start >= total_block) {
        /* Clamp the last block to the exact end of the work. */
        start = total_block;
      }
      block_boundary_info.end_batch_ = start / gather->indices_size_;
      block_boundary_info.end_index_ = start % gather->indices_size_;
      gather->block_infos_[gather->block_infos_size_++] = block_boundary_info;
    }
    /* The loop may emit fewer blocks than threads; launch one task per block. */
    gather->base_.thread_nr_ = gather->block_infos_size_;
  }

  return NNACL_OK;
}
181
/* Resize hook: recomputes the gather geometry for the current shapes, then
 * rebuilds the per-thread block partition. */
int GatherResize(KernelBase *self) {
  GatherStruct *gather = (GatherStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather);

  int ret = InitGatherDynamicStatus(gather);
  if (ret != NNACL_OK) {
    return ret;
  }
  return ChooseGatherThreadCuttingStrategy(gather);
}
191
GatherPrepare(struct KernelBase * self)192 int GatherPrepare(struct KernelBase *self) {
193 GatherStruct *gather = (GatherStruct *)self;
194 NNACL_CHECK_NULL_RETURN_ERR(gather);
195 NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_GATHER_INPUT_TENSOR_INVALID);
196 NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_GATHER_OUTPUT_TENSOR_INVALID);
197 NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]);
198 NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]->data_);
199 gather->axis_ = *((int *)self->in_[THIRD_INPUT]->data_);
200 return NNACL_OK;
201 }
202
GatherCompute(struct KernelBase * self)203 int GatherCompute(struct KernelBase *self) {
204 GatherStruct *gather = (GatherStruct *)self;
205 NNACL_CHECK_NULL_RETURN_ERR(gather);
206
207 if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) {
208 return NNACL_OK;
209 }
210
211 bool is_indices_int32 = self->in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32;
212 int ret = AssignGatherIndicesData(gather, is_indices_int32);
213 if (ret != NNACL_OK) {
214 return ret;
215 }
216
217 ret = self->env_->ParallelLaunch(self->env_->thread_pool_, GatherRun, gather, gather->base_.thread_nr_);
218
219 if (!is_indices_int32) {
220 self->env_->Free(self->env_->allocator_, gather->indices_data_);
221 gather->indices_data_ = NULL;
222 }
223 return ret;
224 }
225
CreateGather(OpParameter * param,int data_type)226 KernelBase *CreateGather(OpParameter *param, int data_type) {
227 GatherStruct *gather = (GatherStruct *)malloc(sizeof(GatherStruct));
228 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather);
229 gather->indices_data_ = NULL;
230 gather->block_infos_size_ = 0;
231 gather->base_.Prepare = GatherPrepare;
232 gather->base_.Resize = GatherResize;
233 gather->base_.Release = DefaultRelease;
234 gather->base_.Compute = GatherCompute;
235 return (KernelBase *)gather;
236 }
237
238 REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat16, CreateGather)
239 REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat32, CreateGather)
240 REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeInt32, CreateGather)
241 REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeBool, CreateGather)
242