1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "src/runtime/kernel/arm/base/select.h"
#include <cstddef>
#include <cstdint>
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#ifndef CONTROLFLOW_TENSORLIST_CLIP
#include "src/tensorlist.h"
#endif
22
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_NULL_PTR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_Select;
28
29 namespace mindspore::kernel {
30 constexpr static int kFirstIdx = 1;
31 constexpr static int kSecondIdx = 2;
32
Init()33 int SelectCPUKernel::Init() { return RET_OK; }
34
ReSize()35 int SelectCPUKernel::ReSize() { return RET_OK; }
36
37 // inputs: bool*1 true-data*n false-data*n
38 // output: data*n
Run()39 int SelectCPUKernel::Run() {
40 MS_ASSERT(in_tensors_.size() >= 3);
41 MS_ASSERT(in_tensors_.size() == out_tensors_.size() * 2 + 1);
42 auto bool_tensor = in_tensors_.front();
43 MS_ASSERT(bool_tensor != nullptr);
44 MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
45 if (bool_tensor->Size() == 1) {
46 auto condition = static_cast<bool *>(bool_tensor->data());
47 if (condition == nullptr) {
48 MS_LOG(ERROR) << "data of bool tensor is nullptr";
49 return lite::RET_NULL_PTR;
50 }
51 if (*condition) {
52 auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), this->in_tensors_.begin() + 1,
53 this->in_tensors_.begin() + 1 + this->out_tensors_.size());
54 if (ret != RET_OK) {
55 MS_LOG(ERROR) << "carry data error : " << ret;
56 return ret;
57 }
58 } else {
59 auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(),
60 this->in_tensors_.begin() + 1 + this->out_tensors_.size(),
61 this->in_tensors_.begin() + 1 + 2 * this->out_tensors_.size());
62 if (ret != RET_OK) {
63 MS_LOG(ERROR) << "carry data error : " << ret;
64 return ret;
65 }
66 }
67 } else {
68 MS_ASSERT(bool_tensor->shape().size() == in_tensors_.at(1)->shape().size());
69 for (size_t i = 0; i < in_tensors_.at(1)->shape().size(); i++) {
70 if (bool_tensor->shape()[i] != in_tensors_.at(1)->shape()[i]) {
71 MS_LOG(ERROR) << "Tensor shapes differ in dim: " << i << " in_tensors_.at(0): " << bool_tensor->shape()[i]
72 << " in_tensors_.at(1): " << in_tensors_.at(1)->shape()[i];
73 return RET_ERROR;
74 }
75 }
76 MS_ASSERT(in_tensors_.at(1)->Size() == out_tensors_.at(0)->Size());
77 auto size = in_tensors_.at(1)->ElementsNum();
78 auto condition = static_cast<bool *>(bool_tensor->data());
79 auto input1 = static_cast<float *>(in_tensors_.at(kFirstIdx)->data());
80 auto input2 = static_cast<float *>(in_tensors_.at(kSecondIdx)->data());
81 auto output = static_cast<float *>(out_tensors_.at(0)->data());
82 if (condition == nullptr || input1 == nullptr || input2 == nullptr || output == nullptr) {
83 return RET_NULL_PTR;
84 }
85 for (int i = 0; i < size; i++) {
86 output[i] = condition[i] ? input1[i] : input2[i];
87 }
88 }
89 return RET_OK;
90 }
91
92 REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Select, LiteKernelCreator<SelectCPUKernel>)
93 } // namespace mindspore::kernel
94