• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/fp16/gather_fp16.h"
18 #include <limits>
19 #include "schema/model_generated.h"
20 #include "src/kernel_registry.h"
21 #include "nnacl/fp16/cast_fp16.h"
22 #include "src/runtime/infer_manager.h"
23 
24 using mindspore::kernel::KERNEL_ARCH;
25 using mindspore::lite::KernelRegistrar;
26 using mindspore::lite::RET_ERROR;
27 using mindspore::lite::RET_MEMORY_FAILED;
28 using mindspore::lite::RET_OK;
29 using mindspore::schema::PrimitiveType_Gather;
30 
31 namespace mindspore::kernel {
namespace {
// Position, within in_tensors_, of the tensor whose data Init() reads as the gather axis.
constexpr int kSecondInput{2};
}  // namespace
~GatherFp16CPUKernel()35 GatherFp16CPUKernel::~GatherFp16CPUKernel() {
36   if (input_data_) {
37     ms_context_->allocator->Free(input_data_);
38     input_data_ = nullptr;
39   }
40 }
41 
Init()42 int GatherFp16CPUKernel::Init() {
43   CHECK_LESS_RETURN(in_tensors_.size(), 3);
44   CHECK_LESS_RETURN(out_tensors_.size(), 1);
45   auto input_tensor = in_tensors_.at(0);
46   CHECK_NULL_RETURN(input_tensor);
47   if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data() != nullptr) {
48     const_input_ = true;
49     input_data_ =
50       reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
51     if (input_data_ == nullptr) {
52       MS_LOG(ERROR) << "Malloc failed";
53       return RET_ERROR;
54     }
55     Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
56   }
57   CHECK_NULL_RETURN(in_tensors_.at(kSecondInput)->data());
58   (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_ =
59     *(reinterpret_cast<int *>(in_tensors_.at(kSecondInput)->data()));
60   if (!InferShapeDone()) {
61     return RET_OK;
62   }
63   return ReSize();
64 }
65 
ReSize()66 int GatherFp16CPUKernel::ReSize() { return RET_OK; }
67 
DoGather(int task_id)68 int GatherFp16CPUKernel::DoGather(int task_id) {
69   auto input_tensor = in_tensors_.at(0);
70   auto indices_tensor = in_tensors_.at(1);
71   auto out_tensor = out_tensors_.at(0);
72   auto in_shape = input_tensor->shape();
73   int in_rank = in_shape.size();
74   int indices_element_size = indices_tensor->ElementsNum();
75   auto axis = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
76   MS_CHECK_LT(axis, in_shape.size(), RET_ERROR);
77   const int limit = in_shape.at(axis);
78   int outer_size = 1, inner_size = 1;
79   for (int i = 0; i < axis; ++i) {
80     outer_size *= in_shape.at(i);
81   }
82   for (int i = axis + 1; i < in_rank; ++i) {
83     inner_size *= in_shape.at(i);
84   }
85   int stride = UP_DIV(outer_size, op_parameter_->thread_num_);
86   int count = MSMIN(stride, outer_size - stride * task_id);
87   if (count <= 0) {
88     return RET_OK;
89   }
90   auto thread_stride = stride * task_id;
91   int8_t *int8_in = nullptr;
92   if (input_tensor->data_type() == kNumberTypeFloat32) {
93     int8_in = reinterpret_cast<int8_t *>(input_data_);
94   } else if (input_tensor->data_type() == kNumberTypeFloat16) {
95     int8_in = reinterpret_cast<int8_t *>(input_tensor->data());
96   } else {
97     MS_LOG(ERROR) << "input data type error";
98     return RET_ERROR;
99   }
100   int8_t *int8_out = reinterpret_cast<int8_t *>(out_tensor->data());
101   CHECK_NULL_RETURN(int8_in);
102   CHECK_NULL_RETURN(int8_out);
103   int data_size = lite::DataTypeSize(kNumberTypeFloat16);
104   int8_in += thread_stride * limit * inner_size * data_size;
105   int8_out += thread_stride * indices_element_size * inner_size * data_size;
106   int error_code = Gather(int8_in, count, inner_size, limit, indices_data_, indices_element_size, int8_out, data_size);
107   return error_code;
108 }
109 
GatherRunFp16(void * cdata,int task_id,float lhs_scale,float rhs_scale)110 int GatherRunFp16(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
111   auto gather_kernel = reinterpret_cast<GatherFp16CPUKernel *>(cdata);
112   auto error_code = gather_kernel->DoGather(task_id);
113   if (error_code != RET_OK) {
114     MS_LOG(ERROR) << "GatherRun error task_id[" << task_id << "] error_code[" << error_code << "]";
115   }
116   return error_code;
117 }
118 
FreeIndicesData()119 void GatherFp16CPUKernel::FreeIndicesData() {
120   if (!is_indices_int32_) {
121     ms_context_->allocator->Free(indices_data_);
122     indices_data_ = nullptr;
123   }
124   if (!const_input_ && input_data_) {
125     ms_context_->allocator->Free(input_data_);
126     input_data_ = nullptr;
127   }
128 }
129 
Run()130 int GatherFp16CPUKernel::Run() {
131   auto indices_tensor = in_tensors_.at(1);
132   int indices_num = indices_tensor->ElementsNum();
133   is_indices_int32_ = indices_tensor->data_type() == kNumberTypeInt32;
134   int ret = AssignIndicesData(is_indices_int32_, indices_num, indices_tensor);
135   if (ret != RET_OK) {
136     MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
137     return ret;
138   }
139   if (!const_input_) {
140     auto input_tensor = in_tensors_.at(0);
141     CHECK_NULL_RETURN(input_tensor->data());
142     if (input_tensor->data_type() == kNumberTypeFloat32) {
143       input_data_ =
144         reinterpret_cast<float16_t *>(ms_context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
145       if (input_data_ == nullptr) {
146         MS_LOG(ERROR) << "Malloc data failed";
147         FreeIndicesData();
148         return RET_ERROR;
149       }
150       Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data()), input_data_, input_tensor->ElementsNum());
151     }
152   }
153   ret = ParallelLaunch(this->ms_context_, GatherRunFp16, this, op_parameter_->thread_num_);
154   if (ret != RET_OK) {
155     MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]";
156   }
157   FreeIndicesData();
158   return ret;
159 }
160 
AssignIndicesData(bool isIndicesInt32,int indices_num,const lite::Tensor * indices_tensor)161 int GatherFp16CPUKernel::AssignIndicesData(bool isIndicesInt32, int indices_num, const lite::Tensor *indices_tensor) {
162   CHECK_NULL_RETURN(indices_tensor->data());
163   if (!isIndicesInt32) {
164     if (indices_num >= std::numeric_limits<int>::max() / static_cast<int>(sizeof(int))) {
165       MS_LOG(ERROR) << "Input indices_num is invalid, indices_num: " << indices_num;
166       return RET_ERROR;
167     }
168     if (indices_tensor->data_type() != kNumberTypeInt64 && indices_tensor->data_type() != kNumberTypeFloat16) {
169       MS_LOG(ERROR) << "The data type of indices tensor is wrong";
170       indices_data_ = nullptr;
171       return RET_ERROR;
172     }
173     indices_data_ = reinterpret_cast<int32_t *>(ms_context_->allocator->Malloc(sizeof(int32_t) * indices_num));
174     if (indices_data_ == nullptr) {
175       MS_LOG(ERROR) << "Memory allocation failed";
176       return RET_ERROR;
177     }
178     if (indices_tensor->data_type() == kNumberTypeInt64) {
179       for (int i = 0; i < indices_num; i++) {
180         indices_data_[i] = reinterpret_cast<int64_t *>(indices_tensor->data())[i];
181       }
182     } else {
183       for (int i = 0; i < indices_num; i++) {
184         indices_data_[i] = reinterpret_cast<float16_t *>(indices_tensor->data())[i];
185       }
186     }
187   } else {
188     indices_data_ = reinterpret_cast<int32_t *>(indices_tensor->data());
189   }
190   return RET_OK;
191 }
192 
// Registers GatherFp16CPUKernel as the kCPU implementation of PrimitiveType_Gather for kNumberTypeFloat16.
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Gather, LiteKernelCreator<GatherFp16CPUKernel>)
194 }  // namespace mindspore::kernel
195