/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/base/one_hot_base.h"
#include "nnacl/fp32/one_hot_fp32.h"
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
#include "nnacl/fp16/one_hot_fp16.h"
#endif
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_NULL_PTR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_OneHot;

namespace mindspore::kernel {
namespace {
constexpr size_t kInputNum = 4;
constexpr size_t kInputNumOpt = 3;
constexpr size_t kOutputNum = 1;
}  // namespace

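// Init() validates input/output tensor counts, caches the thread count and
// the one-hot axis from OneHotParameter, and defers shape-dependent setup to
// ReSize() once shape inference has completed.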
int OneHotCPUKernel::Init() {
  // indices depth on_value off_value
  if ((in_tensors_.size() != kInputNum && in_tensors_.size() != kInputNumOpt) || out_tensors_.size() != kOutputNum) {
    MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << " or " << kInputNumOpt << ", got "
                  << in_tensors_.size() << ", output size should be " << kOutputNum << ", got " << out_tensors_.size();
    return RET_ERROR;
  }
  if (ms_context_ == nullptr) {
    MS_LOG(ERROR) << "OneHot context nullptr";
    return RET_NULL_PTR;
  }
  thread_num_ = op_parameter_->thread_num_;

  auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
  if (param == nullptr) {
    MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
    return RET_NULL_PTR;
  }
  axis_ = param->axis_;

  if (!InferShapeDone()) {
    return RET_OK;
  }
  return ReSize();
}

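// ReSize() recomputes the iteration bounds: outer_size_ is the product of the
// indices dimensions before the one-hot axis, and inner_size_ is the number
// of remaining indices elements per outer slice.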
int OneHotCPUKernel::ReSize() {
  auto indices = in_tensors_.at(0);
  if (indices == nullptr) {
    MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
    return RET_NULL_PTR;
  }
  auto indices_shape = indices->shape();
  const int indices_rank = static_cast<int>(indices_shape.size());
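  // The output rank is indices_rank + 1 (the depth dimension is inserted at
  // axis_), so a negative axis is normalized against the output rank.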
  if (axis_ < 0) {
    axis_ += indices_rank + 1;
  }

  outer_size_ = 1;
  for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
    outer_size_ *= indices_shape[i];
  }
  if (outer_size_ == 0) {
    return RET_ERROR;
  }
  inner_size_ = indices->ElementsNum() / outer_size_;

  return RET_OK;
}

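// RunOneHot() is the ParallelLaunch callback: it forwards each task_id to the
// kernel's OneHotImpl(). The lhs_scale/rhs_scale arguments are part of the
// callback signature and are not used here.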
int RunOneHot(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto onehot_kernel = reinterpret_cast<OneHotCPUKernel *>(cdata);
  if (onehot_kernel == nullptr) {
    MS_LOG(ERROR) << "cast OneHotCPUKernel failed";
    return RET_ERROR;
  }
  auto error_code = onehot_kernel->OneHotImpl(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "RunOneHot error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

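// OneHotImpl() fills this task's share of the output with on_value_ and
// off_value_ according to the indices, dispatching on the output data type
// (fp32, or fp16 on ARM builds with FP16 enabled).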
int OneHotCPUKernel::OneHotImpl(int task_id) {
  auto indices_data = static_cast<int *>(in_tensors_.at(0)->data());
  if (indices_data == nullptr) {
    return RET_NULL_PTR;
  }
  auto output = out_tensors_.at(0);
  if (output == nullptr) {
    MS_LOG(ERROR) << "OneHot output nullptr";
    return RET_NULL_PTR;
  }
  auto output_data = output->data();
  if (output_data == nullptr) {
    return RET_NULL_PTR;
  }
  auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);

  if (output->data_type() == kNumberTypeFloat32) {
    auto ret = OneHotToFp32(indices_data, on_value_, off_value_, reinterpret_cast<float *>(output_data), one_hot_param,
                            task_id, thread_num_);
    return ret;
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  } else if (output->data_type() == kNumberTypeFloat16) {
    auto ret = OneHotToFp16(indices_data, on_value_, off_value_, reinterpret_cast<float16_t *>(output_data),
                            one_hot_param, task_id, thread_num_);
    return ret;
#endif
  } else {
    MS_LOG(ERROR) << "OneHot output datatype is unsupported: " << output->data_type();
    return RET_ERROR;
  }
}

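// InitParamsAndOnOffValue() runs on every Run() call, since depth and the
// on/off values are read from input tensors whose contents may change between
// invocations. The 4-input form carries separate on_value/off_value tensors;
// the 3-input form packs {off_value, on_value} into inputs[2] and also
// enables negative-index support.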
int OneHotCPUKernel::InitParamsAndOnOffValue() {
  auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
  if (one_hot_param == nullptr) {
    MS_LOG(ERROR) << "cast OneHotParameter nullptr";
    return RET_NULL_PTR;
  }

  auto depth_tensor = in_tensors_.at(1);
  if (depth_tensor == nullptr) {
    MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr";
    return RET_NULL_PTR;
  }
  const int *depth = reinterpret_cast<int *>(depth_tensor->MutableData());
  if (depth == nullptr) {
    return RET_NULL_PTR;
  }
  one_hot_param->depth_ = *depth;

  if (in_tensors_.size() == kInputNum) {
    // 4 inputs: indices, depth, on_value, off_value
    one_hot_param->support_neg_index_ = false;
    auto ret = InitOnOffValueForFourInputs();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Init on off value failed";
      return ret;
    }
  } else {
    // 3 inputs: indices, depth, off_on_value
    one_hot_param->support_neg_index_ = true;
    auto ret = InitOnOffValueForThreeInputs();
    if (ret != RET_OK) {
      MS_LOG(ERROR) << "Init on off value failed";
      return ret;
    }
  }

  one_hot_param->outer_size_ = outer_size_;
  one_hot_param->inner_size_ = inner_size_;

  return RET_OK;
}

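// Reads the scalar on_value from inputs[2] and off_value from inputs[3];
// each may be fp32 or, on ARM builds with FP16 enabled, fp16.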
int OneHotCPUKernel::InitOnOffValueForFourInputs() {
  auto on_value_tensor = in_tensors_.at(2);
  if (on_value_tensor == nullptr) {
    MS_LOG(ERROR) << "OneHot on_value tensor is nullptr";
    return RET_NULL_PTR;
  }
  if (on_value_tensor->data_type() == kNumberTypeFloat32) {
    const auto *on_value = reinterpret_cast<float *>(on_value_tensor->data());
    if (on_value == nullptr) {
      return RET_NULL_PTR;
    }
    this->on_value_ = *on_value;
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  } else if (on_value_tensor->data_type() == kNumberTypeFloat16) {
    const auto *on_value = reinterpret_cast<float16_t *>(on_value_tensor->data());
    if (on_value == nullptr) {
      return RET_NULL_PTR;
    }
    this->on_value_ = *on_value;
#endif
  } else {
    MS_LOG(ERROR) << "OneHot on value datatype is unsupported: " << on_value_tensor->data_type();
    return RET_NULL_PTR;
  }

  auto off_value_tensor = in_tensors_.at(3);
  if (off_value_tensor == nullptr) {
    MS_LOG(ERROR) << "OneHot off_value tensor is nullptr";
    return RET_NULL_PTR;
  }

  if (off_value_tensor->data_type() == kNumberTypeFloat32) {
    const auto *off_value = reinterpret_cast<float *>(off_value_tensor->data());
    if (off_value == nullptr) {
      return RET_NULL_PTR;
    }
    this->off_value_ = *off_value;
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  } else if (off_value_tensor->data_type() == kNumberTypeFloat16) {
    const auto *off_value = reinterpret_cast<float16_t *>(off_value_tensor->data());
    if (off_value == nullptr) {
      return RET_NULL_PTR;
    }
    this->off_value_ = *off_value;
#endif
  } else {
    MS_LOG(ERROR) << "OneHot off value datatype is unsupported: " << off_value_tensor->data_type();
    return RET_NULL_PTR;
  }
  return RET_OK;
}

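// Reads the packed pair from inputs[2] for the 3-input variant: element 0 is
// off_value and element 1 is on_value.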
int OneHotCPUKernel::InitOnOffValueForThreeInputs() {
  auto off_on_tensor = in_tensors_.at(2);
  if (off_on_tensor == nullptr) {
    MS_LOG(ERROR) << "OneHot inputs[2] off_on_value nullptr";
    return RET_NULL_PTR;
  }

  if (off_on_tensor->data_type() == kNumberTypeFloat32) {
    const auto *off_on_values = reinterpret_cast<float *>(off_on_tensor->data());
    if (off_on_values == nullptr) {
      return RET_NULL_PTR;
    }
    this->off_value_ = off_on_values[0];
    this->on_value_ = off_on_values[1];
#if defined(ENABLE_ARM) && defined(ENABLE_FP16)
  } else if (off_on_tensor->data_type() == kNumberTypeFloat16) {
    const auto *off_on_values = reinterpret_cast<float16_t *>(off_on_tensor->data());
    if (off_on_values == nullptr) {
      return RET_NULL_PTR;
    }
    this->off_value_ = off_on_values[0];
    this->on_value_ = off_on_values[1];
#endif
  } else {
    MS_LOG(ERROR) << "OneHot off_on_value datatype is unsupported: " << off_on_tensor->data_type();
    return RET_NULL_PTR;
  }
  return RET_OK;
}

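// Run() refreshes depth and the on/off values, then distributes the one-hot
// fill across threads via ParallelLaunch.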
int OneHotCPUKernel::Run() {
  auto ret = InitParamsAndOnOffValue();
  if (ret != RET_OK) {
    MS_LOG(ERROR) << "OneHot init param failed: " << ret;
    return ret;
  }
  int error_code = ParallelLaunch(this->ms_context_, RunOneHot, this, op_parameter_->thread_num_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "OneHot function error, error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

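// The kernel is registered for int32 indices only; the floating-point output
// type is dispatched at run time in OneHotImpl().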
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OneHot, LiteKernelCreator<OneHotCPUKernel>)
}  // namespace mindspore::kernel