• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "nnacl/kernel/gather.h"
18 #include "nnacl/nnacl_common.h"
19 #include "nnacl/op_base.h"
20 #include "nnacl/kernel/default_kernel_base.h"
21 #include "nnacl/tensor_c_utils.h"
22 
23 #define kGatherMinCostPerThread 16384
24 
/**
 * Copies (end - begin) gathered slices into *int8_out, advancing *int8_out past
 * each written slice, then advances *int8_in by byte_in_stride (one full batch).
 *
 * Negative indices are wrapped once (index += limit_); indices still out of
 * [0, limit_) after wrapping produce a zero-filled output slice instead of an
 * out-of-bounds read.
 *
 * Fix: byte_in_stride is int64_t (the caller computes a 64-bit stride; an int
 * parameter silently truncated it for inputs larger than 2 GiB), and the slice
 * offset is computed in 64-bit for the same reason.
 */
void GatherHandleCopy(GatherStruct *gather, int8_t **int8_in, int8_t **int8_out, int begin, int end,
                      int64_t byte_in_stride) {
  for (; begin < end; ++begin) {
    int index = gather->indices_data_[begin];
    // Negative indices count back from the end of the gathered axis.
    index = (index < 0 ? index + gather->limit_ : index);
    if (index < 0 || index >= gather->limit_) {
      // Invalid index: emit zeros rather than reading out of bounds.
      memset(*int8_out, 0, gather->byte_inner_size_);
    } else {
      // 64-bit product: index * byte_inner_size_ can overflow 32-bit int.
      memcpy(*int8_out, *int8_in + (int64_t)index * gather->byte_inner_size_, gather->byte_inner_size_);
    }
    *int8_out += gather->byte_inner_size_;
  }
  *int8_in += byte_in_stride;
}
39 
GatherRun(void * cdata,int task_id,float l,float r)40 int GatherRun(void *cdata, int task_id, float l, float r) {
41   GatherStruct *gather = (GatherStruct *)cdata;
42   NNACL_CHECK_NULL_RETURN_ERR(gather);
43   NNACL_CHECK_FALSE(task_id < 0, NNACL_ERR);
44   NNACL_CHECK_FALSE(task_id >= gather->block_infos_size_, NNACL_ERR);
45 
46   int8_t *int8_in = (int8_t *)(gather->base_.in_[FIRST_INPUT]->data_);
47   NNACL_CHECK_NULL_RETURN_ERR(int8_in);
48   int8_t *int8_out = (int8_t *)(gather->base_.out_[OUTPUT_INDEX]->data_);
49   NNACL_CHECK_NULL_RETURN_ERR(int8_out);
50   int begin_batch = gather->block_infos_[task_id].begin_batch_;
51   int begin_index = gather->block_infos_[task_id].begin_index_;
52   int end_batch = gather->block_infos_[task_id].end_batch_;
53   int end_index = gather->block_infos_[task_id].end_index_;
54   int64_t byte_in_stride = gather->limit_ * gather->byte_inner_size_;
55   int8_in += begin_batch * byte_in_stride;
56   int8_out += begin_batch * gather->indices_size_ * gather->byte_inner_size_ + begin_index * gather->byte_inner_size_;
57   if (begin_batch == end_batch) {
58     GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, end_index, byte_in_stride);
59     return NNACL_OK;
60   }
61   GatherHandleCopy(gather, &int8_in, &int8_out, begin_index, gather->indices_size_, byte_in_stride);
62   ++begin_batch;
63   for (; begin_batch < end_batch; ++begin_batch) {
64     GatherHandleCopy(gather, &int8_in, &int8_out, 0, gather->indices_size_, byte_in_stride);
65   }
66   GatherHandleCopy(gather, &int8_in, &int8_out, 0, end_index, byte_in_stride);
67   return NNACL_OK;
68 }
69 
/**
 * Points gather->indices_data_ at int32 index data.
 *
 * If the indices tensor is already int32 the tensor memory is used in place
 * (no allocation). Otherwise a temporary int buffer is allocated from the
 * kernel allocator and the indices are converted element-wise from int64,
 * float/float32, or bool; the caller (GatherCompute) frees it after launch.
 *
 * Fix: on an unsupported indices data type the freshly allocated conversion
 * buffer was leaked (callers bail out on the error without freeing). It is
 * now released and indices_data_ reset to NULL before returning the error.
 */
int AssignGatherIndicesData(GatherStruct *gather, bool is_indices_int32) {
  TensorC *indices_tensor = gather->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor->data_);

  if (is_indices_int32) {
    // No conversion needed; alias the tensor's own memory.
    gather->indices_data_ = (int *)(indices_tensor->data_);
    return NNACL_OK;
  }

  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->indices_size_, (int)(sizeof(int)), NNACL_ERR);
  gather->indices_data_ =
    (int *)(gather->base_.env_->Alloc(gather->base_.env_->allocator_, gather->indices_size_ * sizeof(int)));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather->indices_data_);

  switch (indices_tensor->data_type_) {
    case kNumberTypeInt64:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((int64_t *)indices_tensor->data_)[i];
      }
      break;
    case kNumberTypeFloat:
    case kNumberTypeFloat32:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((float *)indices_tensor->data_)[i];
      }
      break;
    case kNumberTypeBool:
      for (int i = 0; i < gather->indices_size_; i++) {
        gather->indices_data_[i] = (int)((bool *)indices_tensor->data_)[i];
      }
      break;
    default:
      // Release the conversion buffer; callers return on error without freeing.
      gather->base_.env_->Free(gather->base_.env_->allocator_, gather->indices_data_);
      gather->indices_data_ = NULL;
      return NNACL_UNSUPPORTED_DATA_TYPE;
  }
  return NNACL_OK;
}
106 
/**
 * Derives the gather geometry from the current input shapes:
 *   limit_           size of the gathered axis,
 *   outer_size_      product of dimensions before the axis,
 *   byte_inner_size_ bytes in one output slice (element size x dims after axis),
 *   indices_size_    element count of the indices tensor.
 * Returns NNACL_GATHER_AXIS_INVALID when axis_ is outside [0, rank).
 */
int InitGatherDynamicStatus(GatherStruct *gather) {
  TensorC *input = gather->base_.in_[FIRST_INPUT];
  int rank = (int)input->shape_size_;
  NNACL_CHECK_TRUE_RET(gather->axis_ >= 0 && gather->axis_ < rank, NNACL_GATHER_AXIS_INVALID);

  gather->limit_ = input->shape_[gather->axis_];

  int outer = 1;
  for (int d = 0; d < gather->axis_; ++d) {
    outer *= input->shape_[d];
  }
  gather->outer_size_ = outer;

  int inner_bytes = (int)DataTypeCSize(gather->base_.out_[OUTPUT_INDEX]->data_type_);
  for (int d = gather->axis_ + 1; d < rank; ++d) {
    inner_bytes *= input->shape_[d];
  }
  gather->byte_inner_size_ = inner_bytes;

  gather->indices_size_ = GetElementNum(gather->base_.in_[SECOND_INPUT]);
  return NNACL_OK;
}
123 
/**
 * Decides the thread count for this gather.
 * Outputs of at most kGatherMinCostPerThread bytes run single-threaded;
 * larger outputs defer to the framework's UpdateThread cost model.
 *
 * Fix: reuse the already-computed all_bytes instead of calling GetSize()
 * a second time on the same tensor.
 */
void GatherUpdateThreadNumProcess(GatherStruct *gather) {
  int all_bytes = GetSize(gather->base_.out_[OUTPUT_INDEX]);
  if (all_bytes <= kGatherMinCostPerThread) {
    gather->base_.thread_nr_ = 1;
    return;
  }

  gather->base_.thread_nr_ = gather->base_.UpdateThread(TC_PTYPE(PrimType_Gather), 0, gather->byte_inner_size_,
                                                        all_bytes, gather->base_.thread_nr_);
}
136 
/**
 * Partitions the flattened (outer_size_ x indices_size_) copy work into
 * per-thread block boundary infos stored in block_infos_.
 *
 * Degenerate shapes produce zero blocks. A single thread gets one block
 * spanning all batches ([0, outer_size_) with end_index_ == 0, which
 * GatherRun interprets as whole-batch copies). Multi-threaded runs split
 * total work into thread_nr_ nearly equal chunks, distributing the
 * remainder one unit at a time to the earliest blocks.
 */
int ChooseGatherThreadCuttingStrategy(GatherStruct *gather) {
  gather->block_infos_size_ = 0;
  if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) {
    return NNACL_OK;  // Nothing to copy.
  }

  GatherUpdateThreadNumProcess(gather);
  if (gather->base_.thread_nr_ > GATHER_BLOCK_INFOS_SIZE) {
    // block_infos_ is a fixed-size array; never create more blocks than it holds.
    gather->base_.thread_nr_ = GATHER_BLOCK_INFOS_SIZE;
  }

  if (gather->base_.thread_nr_ == 1) {
    GatherBlockBoundaryInfo *single = &gather->block_infos_[0];
    single->begin_batch_ = 0;
    single->begin_index_ = 0;
    single->end_batch_ = gather->outer_size_;
    single->end_index_ = 0;
    gather->block_infos_size_ = 1;
    return NNACL_OK;
  }

  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather->outer_size_, gather->indices_size_, NNACL_ERR);
  int total_block = gather->outer_size_ * gather->indices_size_;
  int block_size = total_block / gather->base_.thread_nr_;
  NNACL_CHECK_INT_MUL_NOT_OVERFLOW(block_size, gather->base_.thread_nr_, NNACL_ERR);
  // The first remain_block chunks each absorb one extra unit of work.
  int remain_block = total_block - block_size * gather->base_.thread_nr_;
  int start = 0;
  while (start < total_block) {
    GatherBlockBoundaryInfo info;
    info.begin_batch_ = start / gather->indices_size_;
    info.begin_index_ = start % gather->indices_size_;
    start += block_size;
    if (remain_block > 0) {
      ++start;
      --remain_block;
    }
    if (start > total_block) {
      start = total_block;
    }
    info.end_batch_ = start / gather->indices_size_;
    info.end_index_ = start % gather->indices_size_;
    gather->block_infos_[gather->block_infos_size_++] = info;
  }
  gather->base_.thread_nr_ = gather->block_infos_size_;
  return NNACL_OK;
}
181 
/**
 * Resize hook: recomputes the gather geometry from the current tensor shapes,
 * then rebuilds the per-thread work partition.
 */
int GatherResize(KernelBase *self) {
  GatherStruct *gather = (GatherStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather);

  int ret = InitGatherDynamicStatus(gather);
  if (ret != NNACL_OK) {
    return ret;
  }
  return ChooseGatherThreadCuttingStrategy(gather);
}
191 
GatherPrepare(struct KernelBase * self)192 int GatherPrepare(struct KernelBase *self) {
193   GatherStruct *gather = (GatherStruct *)self;
194   NNACL_CHECK_NULL_RETURN_ERR(gather);
195   NNACL_CHECK_FALSE(self->in_size_ < THREE_TENSOR, NNACL_GATHER_INPUT_TENSOR_INVALID);
196   NNACL_CHECK_FALSE(self->out_size_ < ONE_TENSOR, NNACL_GATHER_OUTPUT_TENSOR_INVALID);
197   NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]);
198   NNACL_CHECK_NULL_RETURN_ERR(self->in_[THIRD_INPUT]->data_);
199   gather->axis_ = *((int *)self->in_[THIRD_INPUT]->data_);
200   return NNACL_OK;
201 }
202 
GatherCompute(struct KernelBase * self)203 int GatherCompute(struct KernelBase *self) {
204   GatherStruct *gather = (GatherStruct *)self;
205   NNACL_CHECK_NULL_RETURN_ERR(gather);
206 
207   if (gather->outer_size_ == 0 || gather->indices_size_ == 0 || gather->byte_inner_size_ == 0) {
208     return NNACL_OK;
209   }
210 
211   bool is_indices_int32 = self->in_[SECOND_INPUT]->data_type_ == kNumberTypeInt32;
212   int ret = AssignGatherIndicesData(gather, is_indices_int32);
213   if (ret != NNACL_OK) {
214     return ret;
215   }
216 
217   ret = self->env_->ParallelLaunch(self->env_->thread_pool_, GatherRun, gather, gather->base_.thread_nr_);
218 
219   if (!is_indices_int32) {
220     self->env_->Free(self->env_->allocator_, gather->indices_data_);
221     gather->indices_data_ = NULL;
222   }
223   return ret;
224 }
225 
CreateGather(OpParameter * param,int data_type)226 KernelBase *CreateGather(OpParameter *param, int data_type) {
227   GatherStruct *gather = (GatherStruct *)malloc(sizeof(GatherStruct));
228   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather);
229   gather->indices_data_ = NULL;
230   gather->block_infos_size_ = 0;
231   gather->base_.Prepare = GatherPrepare;
232   gather->base_.Resize = GatherResize;
233   gather->base_.Release = DefaultRelease;
234   gather->base_.Compute = GatherCompute;
235   return (KernelBase *)gather;
236 }
237 
// Register the shared creator for every data type this gather kernel supports.
REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat16, CreateGather)
REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeFloat32, CreateGather)
REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeInt32, CreateGather)
REG_KERNEL_CREATOR(PrimType_Gather, kNumberTypeBool, CreateGather)
242