/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/fp16/gather_fp16.h"
#include <limits>
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "nnacl/fp16/cast_fp16.h"
#include "src/runtime/infer_manager.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Gather;

namespace mindspore::kernel {
namespace {
// Index of the axis tensor among the kernel inputs (data, indices, axis).
constexpr int kSecondInput = 2;
}  // namespace
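
// The destructor releases the fp16 buffer that Init() may have allocated to
// hold the converted copy of a constant float32 input tensor.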
GatherFp16CPUKernel::~GatherFp16CPUKernel() {
  if (input_data_) {
    ms_context_->allocator->Free(input_data_);
    input_data_ = nullptr;
  }
}

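// Init() validates the input/output tensors, eagerly converts a constant
// float32 input to float16 so the conversion happens only once, and reads the
// gather axis from the third input tensor.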
int GatherFp16CPUKernel::Init() {
  CHECK_LESS_RETURN(in_tensors_.size(), 3);
  CHECK_LESS_RETURN(out_tensors_.size(), 1);
  auto input_tensor = in_tensors_.at(0);
  CHECK_NULL_RETURN(input_tensor);
  if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data() != nullptr) {
    const_input_ = true;
    input_data_ =
      reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
    if (input_data_ == nullptr) {
      MS_LOG(ERROR) << "Malloc failed";
      return RET_ERROR;
    }
    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
  }
  CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
  (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ =
    *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

int GatherFp16CPUKernel::ReSize() { return RET_OK; }

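// DoGather() copies one thread's share of the output. The input is viewed as
// [outer_size, limit, inner_size] around the gather axis; the outer dimension
// is split evenly across threads, and each thread gathers indices_element_size
// rows of inner_size fp16 elements from its slice.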
int GatherFp16CPUKernel::DoGather(int task_id) {
  auto input_tensor = in_tensors_.at(0);
  auto indices_tensor = in_tensors_.at(1);
  auto out_tensor = out_tensors_.at(0);
  auto in_shape = input_tensor->shape();
  int in_rank = in_shape.size();
  int indices_element_size = indices_tensor->ElementsNum();
  auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
  MS_CHECK_LT(axis, in_shape.size(), RET_ERROR);
  const int limit = in_shape.at(axis);
  int outer_size = 1, inner_size = 1;
  for (int i = 0; i < axis; ++i) {
    outer_size *= in_shape.at(i);
  }
  for (int i = axis + 1; i < in_rank; ++i) {
    inner_size *= in_shape.at(i);
  }
  int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
  int count = MSMIN(stride, outer_size - stride * task_id);
  if (count <= 0) {
    return RET_OK;
  }
  auto thread_stride = stride * task_id;
  int8_t *int8_in = nullptr;
  if (input_tensor->data_type() == kNumberTypeFloat32) {
    int8_in = reinterpret_cast<int8_t *>(input_data_);
  } else if (input_tensor->data_type() == kNumberTypeFloat16) {
    int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
  } else {
    MS_LOG(ERROR) << "input data type error";
    return RET_ERROR;
  }
  int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
  CHECK_NULL_RETURN(int8_in);
  CHECK_NULL_RETURN(int8_out);
  int data_size = lite::DataTypeSize(kNumberTypeFloat16);
  int8_in += thread_stride * limit * inner_size * data_size;
  int8_out += thread_stride * indices_element_size * inner_size * data_size;
  int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
  return error_code;
}

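// ParallelLaunch callback: dispatches one slice of the gather to DoGather().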
int GatherRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto gather_kernel = reinterpret_cast<GatherFp16CPUKernel *>(cdata);
  auto error_code = gather_kernel->DoGather(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
  }
  return error_code;
}

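// Releases per-run temporaries: the int32 copy of the indices (when the
// indices tensor is not already int32) and the fp16 copy of a non-constant
// float32 input.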
void GatherFp16CPUKernel::FreeIndicesData() {
  if (!is_indices_int32_) {
    ms_context_->allocator->Free(indices_data_);
    indices_data_ = nullptr;
  }
  if (!const_input_ && input_data_) {
    ms_context_->allocator->Free(input_data_);
    input_data_ = nullptr;
  }
}

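// Run() prepares int32 indices, converts a non-constant float32 input to fp16,
// launches DoGather() across threads, and then frees the temporary buffers.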
int GatherFp16CPUKernel::Run() {
  auto indices_tensor = in_tensors_.at(1);
  int indices_num = indices_tensor->ElementsNum();
  is_indices_int32_ = indices_tensor->data_type() == kNumberTypeInt32;
  int ret = AssignIndicesData(is_indices_int32_, indices_num, indices_tensor);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
    return ret;
  }
  if (!const_input_) {
    auto input_tensor = in_tensors_.at(0);
    CHECK_NULL_RETURN(input_tensor->data());
    if (input_tensor->data_type() == kNumberTypeFloat32) {
      input_data_ =
        reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
      if (input_data_ == nullptr) {
        MS_LOG(ERROR) << "Malloc data failed";
        FreeIndicesData();
        return RET_ERROR;
      }
      Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
    }
  }
  ret = ParallelLaunch(this->ms_context_, GatherRunFp16, this, op_parameter_->thread_num_);
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "Gather function error, error_code[" << ret << "]";
  }
  FreeIndicesData();
  return ret;
}

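// Normalizes the indices to int32: int64 or float16 indices are copied into a
// freshly allocated int32 buffer, while int32 indices are used in place.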
int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor) {
  CHECK_NULL_RETURN(indices_tensor->data());
  if (!isIndicesInt32) {
    if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
      MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
      return RET_ERROR;
    }
    if (indices_tensor->data_type() != kNumberTypeInt64 && indices_tensor->data_type() != kNumberTypeFloat16) {
      MS_LOG(ERROR) << "The data type of indices tensor is wrong";
      indices_data_ = nullptr;
      return RET_ERROR;
    }
    indices_data_ = reinterpret_cast<int32_t *>(ms_context_->allocator->Malloc(sizeof(int32_t) * indices_num));
    if (indices_data_ == nullptr) {
      MS_LOG(ERROR) << "Memory allocation failed";
      return RET_ERROR;
    }
    if (indices_tensor->data_type() == kNumberTypeInt64) {
      for (int i = 0; i < indices_num; i++) {
        indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->data())[i];
      }
    } else {
      for (int i = 0; i < indices_num; i++) {
        indices_data_[i] = reinterpret_cast<float16_t *>(indices_tensor->data())[i];
      }
    }
  } else {
    indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->data());
  }
  return RET_OK;
}

REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Gather, LiteKernelCreator<GatherFp16CPUKernel>)
}  // namespace mindspore::kernel