• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/base/select.h"
17 #include "src/kernel_registry.h"
18 #include "include/errorcode.h"
19 #ifndef CONTROLFLOW_TENSORLIST_CLIP
20 #include "src/tensorlist.h"
21 #endif
22 
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_NULL_PTR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_Select;
28 
29 namespace mindspore::kernel {
// Positions inside in_tensors_ that the element-wise branch of SelectCPUKernel::Run() reads.
// Slot 0 holds the bool condition tensor, so the data operands start at slot 1.
static constexpr int kFirstIdx{1};   // operand copied where the condition element is true
static constexpr int kSecondIdx{2};  // operand copied where the condition element is false
32 
Init()33 int SelectCPUKernel::Init() { return RET_OK; }
34 
ReSize()35 int SelectCPUKernel::ReSize() { return RET_OK; }
36 
37 // inputs: bool*1 true-data*n false-data*n
38 // output: data*n
Run()39 int SelectCPUKernel::Run() {
40   MS_ASSERT(in_tensors_.size() >= 3);
41   MS_ASSERT(in_tensors_.size() == out_tensors_.size() * 2 + 1);
42   auto bool_tensor = in_tensors_.front();
43   MS_ASSERT(bool_tensor != nullptr);
44   MS_ASSERT(bool_tensor->data_type() == kNumberTypeBool);
45   if (bool_tensor->Size() == 1) {
46     auto condition = static_cast<bool *>(bool_tensor->data());
47     if (condition == nullptr) {
48       MS_LOG(ERROR) << "data of bool tensor is nullptr";
49       return lite::RET_NULL_PTR;
50     }
51     if (*condition) {
52       auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(), this->in_tensors_.begin() + 1,
53                           this->in_tensors_.begin() + 1 + this->out_tensors_.size());
54       if (ret != RET_OK) {
55         MS_LOG(ERROR) << "carry data error : " << ret;
56         return ret;
57       }
58     } else {
59       auto ret = MoveData(this->out_tensors_.begin(), this->out_tensors_.end(),
60                           this->in_tensors_.begin() + 1 + this->out_tensors_.size(),
61                           this->in_tensors_.begin() + 1 + 2 * this->out_tensors_.size());
62       if (ret != RET_OK) {
63         MS_LOG(ERROR) << "carry data error : " << ret;
64         return ret;
65       }
66     }
67   } else {
68     MS_ASSERT(bool_tensor->shape().size() == in_tensors_.at(1)->shape().size());
69     for (size_t i = 0; i < in_tensors_.at(1)->shape().size(); i++) {
70       if (bool_tensor->shape()[i] != in_tensors_.at(1)->shape()[i]) {
71         MS_LOG(ERROR) << "Tensor shapes differ in dim: " << i << " in_tensors_.at(0): " << bool_tensor->shape()[i]
72                       << " in_tensors_.at(1): " << in_tensors_.at(1)->shape()[i];
73         return RET_ERROR;
74       }
75     }
76     MS_ASSERT(in_tensors_.at(1)->Size() == out_tensors_.at(0)->Size());
77     auto size = in_tensors_.at(1)->ElementsNum();
78     auto condition = static_cast<bool *>(bool_tensor->data());
79     auto input1 = static_cast<float *>(in_tensors_.at(kFirstIdx)->data());
80     auto input2 = static_cast<float *>(in_tensors_.at(kSecondIdx)->data());
81     auto output = static_cast<float *>(out_tensors_.at(0)->data());
82     if (condition == nullptr || input1 == nullptr || input2 == nullptr || output == nullptr) {
83       return RET_NULL_PTR;
84     }
85     for (int i = 0; i < size; i++) {
86       output[i] = condition[i] ? input1[i] : input2[i];
87     }
88   }
89   return RET_OK;
90 }
91 
92 REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_Select, LiteKernelCreator<SelectCPUKernel>)
93 }  // namespace mindspore::kernel
94