• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/string/hashtable_lookup.h"
17 #include <string>
18 #include <algorithm>
19 #include "src/kernel_registry.h"
20 #include "src/common/string_util.h"
21 
22 using mindspore::lite::KernelRegistrar;
23 using mindspore::lite::RET_ERROR;
24 using mindspore::lite::RET_OK;
25 using mindspore::schema::PrimitiveType_HashtableLookup;
26 
27 namespace mindspore::kernel {
Init()28 int HashtableLookupCPUKernel::Init() {
29   if (!InferShapeDone()) {
30     return RET_OK;
31   }
32   return ReSize();
33 }
34 
ReSize()35 int HashtableLookupCPUKernel::ReSize() { return RET_OK; }
36 
/**
 * Three-way comparator for bsearch over a table of int32 keys.
 *
 * @param lhs pointer to the int32 key being searched for (bsearch passes the
 *            search key as the first argument).
 * @param rhs pointer to an int32 element of the key table.
 * @return a negative value, zero, or a positive value when *lhs is less than,
 *         equal to, or greater than *rhs respectively.
 */
static int CmpKeyFunc(const void *lhs, const void *rhs) {
  const int32_t a = *static_cast<const int32_t *>(lhs);
  const int32_t b = *static_cast<const int32_t *>(rhs);
  // Compare instead of subtracting: `a - b` is signed overflow (undefined
  // behavior) for keys far apart, e.g. INT32_MIN - 1, and the wrapped result
  // would send bsearch the wrong way.
  return (a > b) - (a < b);
}
40 
Run()41 int HashtableLookupCPUKernel::Run() {
42   auto input_tensor = in_tensors_.at(0);
43   auto keys_tensor = in_tensors_.at(1);
44   auto values_tensor = in_tensors_.at(2);
45   auto output_tensor = out_tensors_.at(0);
46   auto hits_tensor = out_tensors_.at(1);
47 
48   int rows = GetStringCount(values_tensor);
49   if (rows < 0) {
50     MS_LOG(ERROR) << "get string cnt fail!";
51     return RET_ERROR;
52   }
53   int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
54   uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
55   std::vector<lite::StringPack> output_string_pack(input_tensor->ElementsNum());
56   std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(values_tensor);
57   lite::StringPack null_string_pack = {0, nullptr};
58 
59   for (int i = 0; i < input_tensor->ElementsNum(); i++) {
60     int index = -1;
61     void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc);
62     if (p != nullptr) {
63       index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
64     }
65     if (index >= rows || index < 0) {
66       output_string_pack[i] = null_string_pack;
67       hits_data[i] = 0;
68     } else {
69       output_string_pack[i] = all_string_pack[index];
70       hits_data[i] = 1;
71     }
72   }
73   WriteStringsToTensor(output_tensor, output_string_pack);
74   return RET_OK;
75 }
76 
CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor * > & inputs,const std::vector<lite::Tensor * > & outputs,OpParameter * parameter,const lite::Context * ctx,const kernel::KernelKey & desc)77 kernel::InnerKernel *CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor *> &inputs,
78                                                      const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
79                                                      const lite::Context *ctx, const kernel::KernelKey &desc) {
80   auto *kernel = new (std::nothrow)
81     HashtableLookupCPUKernel(parameter, inputs, outputs, static_cast<const lite::InnerContext *>(ctx));
82   if (kernel == nullptr) {
83     MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!";
84     free(parameter);
85     return nullptr;
86   }
87   return kernel;
88 }
89 
// Registers CpuHashtableLookupKernelCreator as the CPU kernel factory for
// PrimitiveType_HashtableLookup with data type kNumberTypeInt32.
90 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator)
91 }  // namespace mindspore::kernel
92