• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/base/one_hot_base.h"
18 #include "nnacl/fp32/one_hot_fp32.h"
19 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
20 #include "nnacl/fp16/one_hot_fp16.h"
21 #endif
22 #include "schema/model_generated.h"
23 #include "src/kernel_registry.h"
24 #include "include/errorcode.h"
25 
26 using mindspore::kernel::KERNEL_ARCH;
27 using mindspore::lite::KernelRegistrar;
28 using mindspore::lite::RET_ERROR;
29 using mindspore::lite::RET_NULL_PTR;
30 using mindspore::lite::RET_OK;
31 using mindspore::schema::PrimitiveType_OneHot;
32 
33 namespace mindspore::kernel {
namespace {
// Input count when on_value and off_value arrive as separate tensors:
// indices, depth, on_value, off_value.
constexpr size_t kInputNum{4};
// Input count when the off/on values arrive packed in one tensor:
// indices, depth, off_on_value.
constexpr size_t kInputNumOpt{3};
// The kernel produces exactly one output tensor.
constexpr size_t kOutputNum{1};
}  // namespace
39 
Init()40 int OneHotCPUKernel::Init() {
41   // indices depth on_value off_value
42   if ((in_tensors_.size() != kInputNum && in_tensors_.size() != kInputNumOpt) || out_tensors_.size() != kOutputNum) {
43     MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << " or " << kInputNumOpt << ", got "
44                   << in_tensors_.size() << ", output size should be" << kOutputNum << ", got " << out_tensors_.size();
45     return RET_ERROR;
46   }
47   if (ms_context_ == nullptr) {
48     MS_LOG(ERROR) << "OneHot context nullptr";
49     return RET_NULL_PTR;
50   }
51   thread_num_ = op_parameter_->thread_num_;
52 
53   auto param = reinterpret_cast<OneHotParameter *>(op_parameter_);
54   if (param == nullptr) {
55     MS_LOG(ERROR) << "OneHot op_parameter_ nullptr";
56     return RET_NULL_PTR;
57   }
58   axis_ = param->axis_;
59 
60   if (!InferShapeDone()) {
61     return RET_OK;
62   }
63   return ReSize();
64 }
65 
ReSize()66 int OneHotCPUKernel::ReSize() {
67   auto indices = in_tensors_.at(0);
68   if (indices == nullptr) {
69     MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr";
70     return RET_NULL_PTR;
71   }
72   auto indices_shape = indices->shape();
73   const int indices_rank = static_cast<int>(indices_shape.size());
74   if (axis_ < 0) {
75     axis_ += indices_rank + 1;
76   }
77 
78   outer_size_ = 1;
79   for (size_t i = 0; i < static_cast<size_t>(axis_); i++) {
80     outer_size_ *= indices_shape[i];
81   }
82   if (outer_size_ == 0) {
83     return RET_ERROR;
84   }
85   inner_size_ = indices->ElementsNum() / outer_size_;
86 
87   return RET_OK;
88 }
89 
RunOneHot(void * cdata,int task_id,float lhs_scale,float rhs_scale)90 int RunOneHot(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
91   auto onehot_kernel = reinterpret_cast<OneHotCPUKernel *>(cdata);
92   if (onehot_kernel == nullptr) {
93     MS_LOG(ERROR) << "cast OneHotCPUKernel failed";
94     return RET_ERROR;
95   }
96   auto error_code = onehot_kernel->OneHotImpl(task_id);
97   if (error_code != RET_OK) {
98     MS_LOG(ERROR) << "RunOneHot error task_id[" << task_id << "] error_code[" << error_code << "]";
99     return RET_ERROR;
100   }
101   return RET_OK;
102 }
103 
OneHotImpl(int task_id)104 int OneHotCPUKernel::OneHotImpl(int task_id) {
105   auto indices_data = static_cast<int *>(in_tensors_.at(0)->data());
106   if (indices_data == nullptr) {
107     return RET_NULL_PTR;
108   }
109   auto output = out_tensors_.at(0);
110   if (output == nullptr) {
111     MS_LOG(ERROR) << "OneHot output nullptr";
112     return RET_NULL_PTR;
113   }
114   auto output_data = output->data();
115   if (output_data == nullptr) {
116     return RET_NULL_PTR;
117   }
118   auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
119 
120   if (output->data_type() == kNumberTypeFloat32) {
121     auto ret = OneHotToFp32(indices_data, on_value_, off_value_, reinterpret_cast<float *>(output_data), one_hot_param,
122                             task_id, thread_num_);
123     return ret;
124 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
125   } else if (output->data_type() == kNumberTypeFloat16) {
126     auto ret = OneHotToFp16(indices_data, on_value_, off_value_, reinterpret_cast<float16_t *>(output_data),
127                             one_hot_param, task_id, thread_num_);
128     return ret;
129 #endif
130   } else {
131     MS_LOG(ERROR) << "OneHot output datatype is unsupported: " << output->data_type();
132     return RET_ERROR;
133   }
134 }
135 
InitParamsAndOnOffValue()136 int OneHotCPUKernel::InitParamsAndOnOffValue() {
137   auto one_hot_param = reinterpret_cast<OneHotParameter *>(op_parameter_);
138   if (one_hot_param == nullptr) {
139     MS_LOG(ERROR) << "cast OneHotParameter nullptr";
140     return RET_NULL_PTR;
141   }
142 
143   auto depth_tensor = in_tensors_.at(1);
144   if (depth_tensor == nullptr) {
145     MS_LOG(ERROR) << "OneHot inputs[1] depth nullptr";
146     return RET_NULL_PTR;
147   }
148   const int *depth = reinterpret_cast<int *>(depth_tensor->MutableData());
149   if (depth == nullptr) {
150     return RET_NULL_PTR;
151   }
152   one_hot_param->depth_ = *depth;
153 
154   if (in_tensors_.size() == kInputNum) {
155     // 4 inputs: indices, depth, on_value, off_value
156     one_hot_param->support_neg_index_ = false;
157     auto ret = InitOnOffValueForFourInputs();
158     if (ret != RET_OK) {
159       MS_LOG(ERROR) << "Init on off value failed";
160       return RET_NULL_PTR;
161     }
162   } else {
163     // 3 inputs: indices, depth, off_on_value
164     one_hot_param->support_neg_index_ = true;
165     auto ret = InitOnOffValueForThreeInputs();
166     if (ret != RET_OK) {
167       MS_LOG(ERROR) << "Init on off value failed";
168       return RET_NULL_PTR;
169     }
170   }
171 
172   one_hot_param->outer_size_ = outer_size_;
173   one_hot_param->inner_size_ = inner_size_;
174 
175   return RET_OK;
176 }
177 
InitOnOffValueForFourInputs()178 int OneHotCPUKernel::InitOnOffValueForFourInputs() {
179   auto on_value_tensor = in_tensors_.at(2);
180   if (on_value_tensor == nullptr) {
181     MS_LOG(ERROR) << "OneHot on_value tensor is nullptr";
182     return RET_NULL_PTR;
183   }
184   if (on_value_tensor->data_type() == kNumberTypeFloat32) {
185     const auto *on_value = reinterpret_cast<float *>(on_value_tensor->data());
186     if (on_value == nullptr) {
187       return RET_NULL_PTR;
188     }
189     this->on_value_ = *on_value;
190 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
191   } else if (on_value_tensor->data_type() == kNumberTypeFloat16) {
192     const auto *on_value = reinterpret_cast<float16_t *>(on_value_tensor->data());
193     if (on_value == nullptr) {
194       return RET_NULL_PTR;
195     }
196     this->on_value_ = *on_value;
197 #endif
198   } else {
199     MS_LOG(ERROR) << "OneHot on value datatype is unsupported: " << on_value_tensor->data_type();
200     return RET_NULL_PTR;
201   }
202 
203   auto off_value_tensor = in_tensors_.at(3);
204   if (off_value_tensor == nullptr) {
205     MS_LOG(ERROR) << "OneHot off_value tensor is nullptr";
206     return RET_NULL_PTR;
207   }
208 
209   if (off_value_tensor->data_type() == kNumberTypeFloat32) {
210     const auto *off_value = reinterpret_cast<float *>(off_value_tensor->data());
211     if (off_value == nullptr) {
212       return RET_NULL_PTR;
213     }
214     this->off_value_ = *off_value;
215 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
216   } else if (off_value_tensor->data_type() == kNumberTypeFloat16) {
217     const auto *off_value = reinterpret_cast<float16_t *>(off_value_tensor->data());
218     if (off_value == nullptr) {
219       return RET_NULL_PTR;
220     }
221     this->off_value_ = *off_value;
222 #endif
223   } else {
224     MS_LOG(ERROR) << "OneHot off value datatype is unsupported: " << off_value_tensor->data_type();
225     return RET_NULL_PTR;
226   }
227   return RET_OK;
228 }
229 
InitOnOffValueForThreeInputs()230 int OneHotCPUKernel::InitOnOffValueForThreeInputs() {
231   auto off_on_tensor = in_tensors_.at(2);
232   if (off_on_tensor == nullptr) {
233     MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr";
234     return RET_NULL_PTR;
235   }
236 
237   if (off_on_tensor->data_type() == kNumberTypeFloat32) {
238     const auto *off_on_values = reinterpret_cast<float *>(off_on_tensor->data());
239     if (off_on_values == nullptr) {
240       return RET_NULL_PTR;
241     }
242     this->off_value_ = off_on_values[0];
243     this->on_value_ = off_on_values[1];
244 #if defined(ENABLE_ARM) && defined(ENABLE_FP16)
245   } else if (off_on_tensor->data_type() == kNumberTypeFloat16) {
246     const auto *off_on_values = reinterpret_cast<float16_t *>(off_on_tensor->data());
247     if (off_on_values == nullptr) {
248       return RET_NULL_PTR;
249     }
250     this->off_value_ = off_on_values[0];
251     this->on_value_ = off_on_values[1];
252 #endif
253   } else {
254     MS_LOG(ERROR) << "OneHot off value datatype is unsupported: " << off_on_tensor->data_type();
255     return RET_NULL_PTR;
256   }
257   return RET_OK;
258 }
259 
Run()260 int OneHotCPUKernel::Run() {
261   auto ret = InitParamsAndOnOffValue();
262   if (ret != RET_OK) {
263     MS_LOG(ERROR) << "OneHot init param failed:" << ret;
264     return ret;
265   }
266   int error_code = ParallelLaunch(this->ms_context_, RunOneHot, this, op_parameter_->thread_num_);
267   if (error_code != RET_OK) {
268     MS_LOG(ERROR) << "OneHot function error error_code[" << error_code << "]";
269     return RET_ERROR;
270   }
271   return RET_OK;
272 }
273 
// Registers OneHotCPUKernel as the CPU kernel for PrimitiveType_OneHot under kNumberTypeInt32
// (presumably the indices dtype — OneHotImpl reads indices as int; confirm against the registry).
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_OneHot, LiteKernelCreator<OneHotCPUKernel>)
275 }  // namespace mindspore::kernel
276