1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/string/hashtable_lookup.h"
17 #include <string>
18 #include <algorithm>
19 #include "src/kernel_registry.h"
20 #include "src/common/string_util.h"
21
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_HashtableLookup;
26
27 namespace mindspore::kernel {
Init()28 int HashtableLookupCPUKernel::Init() {
29 if (!InferShapeDone()) {
30 return RET_OK;
31 }
32 return ReSize();
33 }
34
ReSize()35 int HashtableLookupCPUKernel::ReSize() { return RET_OK; }
36
// Three-way comparator over int keys, in the form expected by bsearch().
//
// lhs, rhs: pointers to the two int values to compare.
// Returns a negative value, zero, or a positive value when *lhs is respectively less than, equal to,
// or greater than *rhs.
//
// The result is built from comparisons instead of `*lhs - *rhs`: the subtraction overflows (undefined
// behaviour) when the operands are far apart, e.g. INT_MAX and -1, and then reports the wrong ordering.
static int CmpKeyFunc(const void *lhs, const void *rhs) {
  const int left = *static_cast<const int *>(lhs);
  const int right = *static_cast<const int *>(rhs);
  return (left > right) - (left < right);
}
40
Run()41 int HashtableLookupCPUKernel::Run() {
42 auto input_tensor = in_tensors_.at(0);
43 auto keys_tensor = in_tensors_.at(1);
44 auto values_tensor = in_tensors_.at(2);
45 auto output_tensor = out_tensors_.at(0);
46 auto hits_tensor = out_tensors_.at(1);
47
48 int rows = GetStringCount(values_tensor);
49 if (rows < 0) {
50 MS_LOG(ERROR) << "get string cnt fail!";
51 return RET_ERROR;
52 }
53 int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
54 uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
55 std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum());
56 std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);
57 lite::StringPack null_string_pack = {0, nullptr};
58
59 for (int i = 0; i < input_tensor->ElementsNum(); i++) {
60 int index = -1;
61 void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc);
62 if (p != nullptr) {
63 index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
64 }
65 if (index >= rows || index < 0) {
66 output_string_pack[i] = null_string_pack;
67 hits_data[i] = 0;
68 } else {
69 output_string_pack[i] = all_string_pack[index];
70 hits_data[i] = 1;
71 }
72 }
73 WriteStringsToTensor(output_tensor, output_string_pack);
74 return RET_OK;
75 }
76
CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * parameter,const lite::Context * ctx,const kernel::KernelKey & desc)77 kernel::InnerKernel *CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor *> &inputs,
78 const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
79 const lite::Context *ctx, const kernel::KernelKey &desc) {
80 auto *kernel = new (std::nothrow)
81 HashtableLookupCPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
82 if (kernel == nullptr) {
83 MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!";
84 free(parameter);
85 return nullptr;
86 }
87 return kernel;
88 }
89
// Ties CpuHashtableLookupKernelCreator to the (kCPU, kNumberTypeInt32, PrimitiveType_HashtableLookup) key.
// NOTE(review): presumably this registers the creator with the kernel registry at static-init time —
// confirm against the REG_KERNEL definition in src/kernel_registry.h.
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator)
91 } // namespace mindspore::kernel
92