/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "nnacl/kernel/gather_nd.h"
#include "nnacl/fp32/gatherNd_fp32.h"
#include "nnacl/kernel/default_kernel_base.h"
#include "nnacl/nnacl_common.h"
/**
 * Precompute the flat input offset for every gathered row.
 *
 * For each of gather_nd->count_ index rows, accumulates
 * indices[row][k] * in_stride[k] over the last indices dimension into
 * gather_nd->in_offset_[row]. in_offset_ must already be allocated with
 * count_ entries (done in GatherNdResize).
 *
 * Returns NNACL_OK on success, or a NNACL_GATHER_ND_* error code when the
 * indices rank, shape, or data type is invalid.
 */
int GatherNdInitOffset(GatherNdStruct *gather_nd) {
  TensorC *input_tensor = gather_nd->base_.in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(input_tensor);
  TensorC *indices_tensor = gather_nd->base_.in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor);

  if (indices_tensor->shape_size_ < 1) {
    return NNACL_GATHER_ND_INDICES_RANK_INVALID;
  }

  int in_rank = (int)input_tensor->shape_size_;
  int idx_lastshape = indices_tensor->shape_[indices_tensor->shape_size_ - 1];
  if (idx_lastshape > in_rank) {
    return NNACL_GATHER_ND_INDICES_SHAPE_INVALID;
  }
  /* Fix: a rank-0 input cannot be gathered from, and without this guard the
   * in_stride[in_rank - 1] write below would index out of bounds. */
  if (in_rank < 1) {
    return NNACL_GATHER_ND_INDICES_SHAPE_INVALID;
  }

  /* Elements copied per gathered slice: product of the input dimensions not
   * consumed by an index component. */
  gather_nd->area_ = 1;
  for (int i = idx_lastshape; i < in_rank; ++i) {
    gather_nd->area_ *= input_tensor->shape_[i];
  }

  /* Row-major strides of the input tensor. */
  int in_stride[MAX_SHAPE_SIZE] = {0};
  in_stride[in_rank - 1] = 1;
  for (int i = in_rank - 2; i >= 0; --i) {
    in_stride[i] = input_tensor->shape_[i + 1] * in_stride[i + 1];
  }

  int idx_stride = idx_lastshape;
  (void)memset(gather_nd->in_offset_, 0, gather_nd->count_ * sizeof(int));

  if (indices_tensor->data_type_ == kNumberTypeInt || indices_tensor->data_type_ == kNumberTypeInt32) {
    int32_t *indices_ptr = (int32_t *)indices_tensor->data_;
    NNACL_CHECK_NULL_RETURN_ERR(indices_ptr);
    for (int j = 0; j < gather_nd->count_; ++j) {
      for (int k = 0; k < idx_lastshape; ++k) {
        gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k];
      }
    }
  } else if (indices_tensor->data_type_ == kNumberTypeInt64) {
    int64_t *indices_ptr = (int64_t *)indices_tensor->data_;
    /* Fix: this branch was missing the NULL check present in the int32 branch,
     * dereferencing unset tensor data on the int64 path. */
    NNACL_CHECK_NULL_RETURN_ERR(indices_ptr);
    for (int j = 0; j < gather_nd->count_; ++j) {
      for (int k = 0; k < idx_lastshape; ++k) {
        gather_nd->in_offset_[j] += indices_ptr[j * idx_stride + k] * in_stride[k];
      }
    }
  } else {
    return NNACL_GATHER_ND_INDICES_DATA_TYPE_INVALID;
  }

  return NNACL_OK;
}
73 
GatherNdRun(void * cdata,int task_id,float l,float r)74 int GatherNdRun(void *cdata, int task_id, float l, float r) {
75   GatherNdStruct *gather_nd = (GatherNdStruct *)cdata;
76   NNACL_CHECK_NULL_RETURN_ERR(gather_nd);
77   TensorC *input = gather_nd->base_.in_[FIRST_INPUT];
78   NNACL_CHECK_NULL_RETURN_ERR(input);
79 
80   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(task_id, gather_nd->thread_stride_, NNACL_ERR);
81   int count = NNACL_MIN(gather_nd->thread_stride_, gather_nd->count_ - task_id * gather_nd->thread_stride_);
82   if (count <= 0) {
83     return NNACL_OK;
84   }
85 
86   int offset = task_id * gather_nd->thread_stride_;
87   int dtype_len = DataTypeCSize(input->data_type_);
88   NNACL_CHECK_INT_MUL_NOT_OVERFLOW(offset, gather_nd->area_, NNACL_ERR);
89   int8_t *out_ptr = (int8_t *)gather_nd->out_ptr_ + offset * gather_nd->area_ * dtype_len;
90   return GatherNd(gather_nd->in_ptr_, out_ptr, gather_nd->in_offset_ + offset, gather_nd->area_, count, dtype_len);
91 }
92 
/**
 * Kernel entry point: caches the input/output data pointers, rebuilds the
 * per-row gather offsets, then runs GatherNdRun across the thread pool.
 */
int GatherNdCompute(KernelBase *self) {
  GatherNdStruct *gather_nd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd);

  TensorC *in_tensor = self->in_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(in_tensor);
  gather_nd->in_ptr_ = in_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd->in_ptr_);

  TensorC *out_tensor = self->out_[FIRST_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(out_tensor);
  gather_nd->out_ptr_ = out_tensor->data_;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd->out_ptr_);

  int offset_ret = GatherNdInitOffset(gather_nd);
  if (offset_ret != NNACL_OK) {
    return offset_ret;
  }

  return self->env_->ParallelLaunch(self->env_->thread_pool_, GatherNdRun, self, self->thread_nr_);
}
114 
/** Frees the offset table allocated in GatherNdResize; safe to call twice. */
int GatherNdRelease(KernelBase *self) {
  GatherNdStruct *gather_nd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd);
  if (gather_nd->in_offset_ == NULL) {
    return NNACL_OK;
  }
  self->env_->Free(self->env_->allocator_, gather_nd->in_offset_);
  gather_nd->in_offset_ = NULL; /* prevent double-free on repeated release */
  return NNACL_OK;
}
124 
/**
 * Recomputes per-run bookkeeping after a shape change: the number of gathered
 * rows (count_), the offset table allocation, and the per-thread stride.
 *
 * Returns NNACL_OK, NNACL_GATHER_ND_COUNT_INVALID when the row count would
 * overflow the offset table, or an allocation error code.
 */
int GatherNdResize(KernelBase *self) {
  /* Fix: the original `(void)self->Release;` only named the function pointer
   * without calling it — a no-op — so the previous in_offset_ buffer leaked on
   * every resize. Invoke it to free any prior allocation. */
  (void)self->Release(self);
  GatherNdStruct *gather_nd = (GatherNdStruct *)self;
  NNACL_CHECK_NULL_RETURN_ERR(gather_nd);
  TensorC *indices_tensor = self->in_[SECOND_INPUT];
  NNACL_CHECK_NULL_RETURN_ERR(indices_tensor);

  /* count_ = product of all indices dimensions except the last. Cast before
   * subtracting so a rank-0 tensor cannot wrap around if shape_size_ is an
   * unsigned type (assumed unsigned — TODO confirm against TensorC). */
  gather_nd->count_ = 1;
  for (int i = 0; i < (int)indices_tensor->shape_size_ - 1; ++i) {
    NNACL_CHECK_INT_MUL_NOT_OVERFLOW(gather_nd->count_, indices_tensor->shape_[i], NNACL_ERR);
    gather_nd->count_ *= indices_tensor->shape_[i];
  }

  /* Reject counts whose offset table would exceed INT32_MAX bytes. */
  int max_count = (int)(INT32_MAX / sizeof(int));
  if (gather_nd->count_ >= max_count) {
    return NNACL_GATHER_ND_COUNT_INVALID;
  }

  gather_nd->in_offset_ = self->env_->Alloc(self->env_->allocator_, gather_nd->count_ * sizeof(int));
  NNACL_MALLOC_CHECK_NULL_RETURN_ERR(gather_nd->in_offset_);

  /* Never launch more workers than there are rows to process. */
  gather_nd->base_.thread_nr_ = NNACL_MIN(gather_nd->base_.thread_nr_, gather_nd->count_);
  if (gather_nd->base_.thread_nr_ != 0) {
    gather_nd->thread_stride_ = UP_DIV(gather_nd->count_, gather_nd->base_.thread_nr_);
  }
  return NNACL_OK;
}
152 
CreateGatherNd(OpParameter * param,int data_type)153 KernelBase *CreateGatherNd(OpParameter *param, int data_type) {
154   GatherNdStruct *gather_nd = (GatherNdStruct *)malloc(sizeof(GatherNdStruct));
155   NNACL_MALLOC_CHECK_NULL_RETURN_NULL(gather_nd);
156   memset(gather_nd, 0, sizeof(GatherNdStruct));
157 
158   gather_nd->base_.Prepare = DefaultPrepare2In1Out;
159   gather_nd->base_.Resize = GatherNdResize;
160   gather_nd->base_.Compute = GatherNdCompute;
161   gather_nd->base_.Release = GatherNdRelease;
162   return (KernelBase *)gather_nd;
163 }
164 
/* Register the GatherNd creator for every element type this kernel supports.
 * (Fix: strips web-viewer line-number prefixes fused into these lines.) */
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeBool, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeInt32, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat32, CreateGatherNd);
REG_KERNEL_CREATOR(PrimType_GatherNd, kNumberTypeFloat16, CreateGatherNd);
169