1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/runtime/kernel/arm/fp32/gather_fp32.h"
18 #include <limits>
19 #include "schema/model_generated.h"
20 #include "src/kernel_registry.h"
21
22 using mindspore::kernel::KERNEL_ARCH;
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 using mindspore::schema::PrimitiveType_Gather;
27
28 namespace mindspore::kernel {
namespace {
// Index into in_tensors_ of the tensor that Init() reads the gather axis from.
// Note: despite the name, index 2 is the *third* input; the identifier is kept
// unchanged because Init() refers to it.
constexpr int kSecondInput{2};
}  // namespace
Init()32 int GatherCPUKernel::Init() {
33 CHECK_LESS_RETURN(in_tensors_.size(), kInputSize2);
34 CHECK_LESS_RETURN(out_tensors_.size(), 1);
35 CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
36 axis_ = *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
37 if (!InferShapeDone()) {
38 return RET_OK;
39 }
40 return ReSize();
41 }
42
ReSize()43 int GatherCPUKernel::ReSize() { return RET_OK; }
44
DoGather(int task_id)45 int GatherCPUKernel::DoGather(int task_id) {
46 auto input_tensor = in_tensors_.at(0);
47 auto indices_tensor = in_tensors_.at(1);
48 auto out_tensor = out_tensors_.at(0);
49
50 auto in_shape = input_tensor->shape();
51 int in_rank = in_shape.size();
52 int indices_element_size = indices_tensor->ElementsNum();
53 MS_CHECK_LT(axis_, in_rank, RET_ERROR);
54 const int limit = in_shape.at(axis_);
55
56 int outer_size = 1, inner_size = 1;
57 for (int i = 0; i < axis_; ++i) {
58 outer_size *= in_shape.at(i);
59 }
60 for (int i = axis_ + 1; i < in_rank; ++i) {
61 inner_size *= in_shape.at(i);
62 }
63 int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
64 int count = MSMIN(stride, outer_size - stride * task_id);
65 if (count <= 0) {
66 return RET_OK;
67 }
68 auto thread_stride = stride * task_id;
69
70 int8_t *int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
71 CHECK_NULL_RETURN(int8_in);
72 int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
73 CHECK_NULL_RETURN(int8_out);
74
75 int data_size = static_cast<int>(lite::DataTypeSize(input_tensor->data_type()));
76 int8_in += thread_stride * limit * inner_size * data_size;
77 int8_out += thread_stride * indices_element_size * inner_size * data_size;
78
79 int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
80
81 return error_code;
82 }
83
GatherRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)84 int GatherRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
85 auto gather_kernel = reinterpret_cast<GatherCPUKernel *>(cdata);
86 auto error_code = gather_kernel->DoGather(task_id);
87 if (error_code != RET_OK) {
88 MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
89 }
90 return error_code;
91 }
92
Run()93 int GatherCPUKernel::Run() {
94 auto indices_tensor = in_tensors_.at(1);
95 int indices_num = indices_tensor->ElementsNum();
96 bool isIndicesInt32 = indices_tensor->data_type() == kNumberTypeInt32;
97 int ret = AssignIndicesData(isIndicesInt32, indices_num, indices_tensor);
98 if (ret != RET_OK) {
99 MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
100 return ret;
101 }
102
103 ret = ParallelLaunch(this->ms_context_, GatherRun, this, op_parameter_->thread_num_);
104 if (ret != RET_OK) {
105 MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
106 }
107 if (!isIndicesInt32) {
108 ms_context_->allocator->Free(indices_data_);
109 indices_data_ = nullptr;
110 }
111 return ret;
112 }
113
AssignIndicesData(bool isIndicesInt32,int indices_num,lite::Tensor * indices_tensor)114 int GatherCPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor) {
115 if (!isIndicesInt32) {
116 if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
117 MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
118 return RET_ERROR;
119 }
120 indices_data_ = reinterpret_cast<int32_t *>(ms_context_->allocator->Malloc(sizeof(int32_t) * indices_num));
121 if (indices_data_ == nullptr) {
122 MS_LOG(ERROR) << "Memory allocation failed";
123 return RET_ERROR;
124 }
125 if (indices_tensor->data_type() == kNumberTypeInt64) {
126 for (int i = 0; i < indices_num; i++) {
127 indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->MutableData())[i];
128 }
129 } else {
130 for (int i = 0; i < indices_num; i++) {
131 indices_data_[i] = static_cast<int>(reinterpret_cast<float *>(indices_tensor->MutableData())[i]);
132 }
133 }
134 } else {
135 indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->MutableData());
136 }
137 return RET_OK;
138 }
139
140 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
141 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Gather, LiteKernelCreator<GatherCPUKernel>)
142 } // namespace mindspore::kernel
143