1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "nnacl/kernel/gather_nd.h"
18 #include "nnacl/fp32/gatherNd_fp32.h"
19 #include "nnacl/kernel/default_kernel_base.h"
20 #include "nnacl/nnacl_common.h"
21
/* Precompute, for each index row j, the flat element offset into the input:
 *   in_offset_[j] = sum_k indices[j][k] * in_stride[k]
 * area_ is the number of contiguous elements copied per offset (the product
 * of the trailing, un-indexed input dimensions).
 * Returns NNACL_OK, or a NNACL_GATHER_ND_* error on invalid shapes/types. */
int GatherNdInitOffset(GatherNdStruct *gather_nd) {
  TensorC *input_tensor = gather_nd->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
  TensorC *indices_tensor = gather_nd->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor);

  if (indices_tensor->shape_size_ < 1) {
    return NNACL_GATHER_ND_INDICES_RANK_INVALID;
  }

  int in_rank = (int)input_tensor->shape_size_;
  int idx_lastshape = indices_tensor->shape_[indices_tensor->shape_size_ - 1];
  /* A rank-0 input cannot be indexed (and in_stride[in_rank - 1] below would
   * write out of bounds), and each index tuple must not be longer than the
   * input rank. */
  if (in_rank < 1 || idx_lastshape > in_rank) {
    return NNACL_GATHER_ND_INDICES_SHAPE_INVALID;
  }

  /* Elements per gathered slice: product of the un-indexed trailing dims. */
  gather_nd->area_ = 1;
  for (int i = idx_lastshape; i < in_rank; ++i) {
    gather_nd->area_ *= input_tensor->shape_[i];
  }

  /* Row-major strides of the input tensor. */
  int in_stride[MAX_SHAPE_SIZE] = {0};
  in_stride[in_rank - 1] = 1;
  for (int i = in_rank - 2; i >= 0; --i) {
    in_stride[i] = input_tensor->shape_[i + 1] * in_stride[i + 1];
  }

  int idx_stride = idx_lastshape;
  (void)memset(gather_nd->in_offset_, 0, gather_nd->count_ * sizeof(int));

  if (indices_tensor->data_type_ == kNumberTypeInt || indices_tensor->data_type_ == kNumberTypeInt32) {
    int32_t *indices_ptr = (int32_t *)indices_tensor->data_;
    NNACL_CHECK_NULL_RETURN_ERR(indices_ptr);
    for (int j = 0; j < gather_nd->count_; ++j) {
      for (int k = 0; k < idx_lastshape; ++k) {
        gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k];
      }
    }
  } else if (indices_tensor->data_type_ == kNumberTypeInt64) {
    int64_t *indices_ptr = (int64_t *)indices_tensor->data_;
    /* Fix: this branch previously skipped the NULL check that the int32
     * branch performs, dereferencing a possibly-NULL data pointer. */
    NNACL_CHECK_NULL_RETURN_ERR(indices_ptr);
    for (int j = 0; j < gather_nd->count_; ++j) {
      for (int k = 0; k < idx_lastshape; ++k) {
        gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k];
      }
    }
  } else {
    return NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID;
  }

  return NNACL_OK;
}
73
GatherNdRun(void * cdata,int task_id,float l,float r)74 int GatherNdRun(void *cdata, int task_id, float l, float r) {
75 GatherNdStruct *gather_nd = (GatherNdStruct *)cdata;
76 NNACL_CHECK_NULL_RETURN_ERR(gather_nd);
77 TensorC *input = gather_nd->base_.in_[FIRST_INPUT];
78 NNACL_CHECK_NULL_RETURN_ERR(input);
79
80 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, gather_nd->thread_stride_, NNACL_ERR);
81 int count = NNACL_MIN(gather_nd->thread_stride_, gather_nd->count_ - task_id * gather_nd->thread_stride_);
82 if (count <= 0) {
83 return NNACL_OK;
84 }
85
86 int offset = task_id * gather_nd->thread_stride_;
87 int dtype_len = DataTypeCSize(input->data_type_);
88 NNACL_CHECK_INT_MUL_NOT_OVERFLOW(offset, gather_nd->area_, NNACL_ERR);
89 int8_t *out_ptr = (int8_t *)gather_nd->out_ptr_ + offset * gather_nd->area_ * dtype_len;
90 return GatherNd(gather_nd->in_ptr_, out_ptr, gather_nd->in_offset_ + offset, gather_nd->area_, count, dtype_len);
91 }
92
/* Kernel Compute entry: caches the input/output data pointers, rebuilds the
 * offset table for the current index values, then gathers in parallel. */
int GatherNdCompute(KernelBase *self) {
  GatherNdStruct *gnd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gnd);

  TensorC *in_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
  gnd->in_ptr_ = in_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(gnd->in_ptr_);

  TensorC *out_tensor = self->out_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
  gnd->out_ptr_ = out_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(gnd->out_ptr_);

  /* Offsets depend on the runtime index data, so recompute on every call. */
  int ret = GatherNdInitOffset(gnd);
  if (ret != NNACL_OK) {
    return ret;
  }

  return self->env_->ParallelLaunch(self->env_->thread_pool_, GatherNdRun, self, self->thread_nr_);
}
114
/* Frees the per-resize offset table; safe to call repeatedly (the pointer is
 * reset to NULL after freeing). */
int GatherNdRelease(KernelBase *self) {
  GatherNdStruct *gnd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gnd);
  if (gnd->in_offset_ == NULL) {
    return NNACL_OK;
  }
  self->env_->Free(self->env_->allocator_, gnd->in_offset_);
  gnd->in_offset_ = NULL;
  return NNACL_OK;
}
124
/* Kernel Resize entry: sizes and allocates the offset table for the current
 * indices shape and derives the per-thread work split. */
int GatherNdResize(KernelBase *self) {
  GatherNdStruct *gather_nd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd);
  /* Fix: the original `(void)self->Release;` only named the function pointer
   * and never invoked it, so the in_offset_ buffer from a previous Resize
   * leaked. Actually call Release to free it before re-allocating. */
  (void)self->Release(self);

  TensorC *indices_tensor = self->in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor);
  /* Guard: shape_size_ is unsigned, so `shape_size_ - 1` would wrap for a
   * rank-0 indices tensor and drive the loop below out of bounds. */
  if (indices_tensor->shape_size_ < 1) {
    return NNACL_GATHER_ND_INDICES_RANK_INVALID;
  }

  /* count_ = product of all indices dims except the last, i.e. the number of
   * index tuples (offsets) to gather. */
  gather_nd->count_ = 1;
  int idx_rank = (int)indices_tensor->shape_size_;
  for (int i = 0; i < idx_rank - 1; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather_nd->count_, indices_tensor->shape_[i], NNACL_ERR);
    gather_nd->count_ *= indices_tensor->shape_[i];
  }

  /* Reject counts whose offset-table byte size would overflow int32. */
  int max_count = (int)(INT32_MAX / sizeof(int));
  if (gather_nd->count_ >= max_count) {
    return NNACL_GATHER_ND_COUNT_INVALID;
  }

  gather_nd->in_offset_ = self->env_->Alloc(self->env_->allocator_, gather_nd->count_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather_nd->in_offset_);

  /* Never spawn more workers than there are offsets to process. */
  gather_nd->base_.thread_nr_ = NNACL_MIN(gather_nd->base_.thread_nr_, gather_nd->count_);
  if (gather_nd->base_.thread_nr_ != 0) {
    gather_nd->thread_stride_ = UP_DIV(gather_nd->count_, gather_nd->base_.thread_nr_);
  }
  return NNACL_OK;
}
152
CreateGatherNd(OpParameter * param,int data_type)153 KernelBase *CreateGatherNd(OpParameter *param, int data_type) {
154 GatherNdStruct *gather_nd = (GatherNdStruct *)malloc(sizeof(GatherNdStruct));
155 NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_nd);
156 memset(gather_nd, 0, sizeof(GatherNdStruct));
157
158 gather_nd->base_.Prepare = DefaultPrepare2In1Out;
159 gather_nd->base_.Resize = GatherNdResize;
160 gather_nd->base_.Compute = GatherNdCompute;
161 gather_nd->base_.Release = GatherNdRelease;
162 return (KernelBase *)gather_nd;
163 }
164
/* Register the creator for every element type this kernel supports; the
 * gather itself is a typed memcpy, so bool/int32/fp32/fp16 all share it. */
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeBool, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeInt32, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat32, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat16, CreateGatherNd);
169