• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/string/lsh_projection.h"
17 
18 #include "include/errorcode.h"
19 #include "src/common/string_util.h"
20 #include "src/kernel_registry.h"
21 
22 using mindspore::kernel::KERNEL_ARCH;
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_NULL_PTR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_LshProjection;
28 
29 namespace mindspore::kernel {
Init()30 int LshProjectionCPUKernel::Init() {
31   CHECK_LESS_RETURN(in_tensors_.size(), C2NUM);
32   CHECK_LESS_RETURN(out_tensors_.size(), 1);
33   CHECK_NULL_RETURN(in_tensors_[0]);
34   CHECK_NULL_RETURN(in_tensors_[1]);
35   CHECK_NULL_RETURN(out_tensors_[0]);
36   if (!InferShapeDone()) {
37     return RET_OK;
38   }
39   return ReSize();
40 }
41 
ReSize()42 int LshProjectionCPUKernel::ReSize() { return RET_OK; }
43 
LshProjectionRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)44 int LshProjectionRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
45   auto kernel = reinterpret_cast<LshProjectionCPUKernel *>(cdata);
46   return kernel->DoExecute(task_id);
47 }
48 
Run()49 int LshProjectionCPUKernel::Run() {
50   auto input0_tensor = in_tensors_.at(0);
51   auto input1_tensor = in_tensors_.at(1);
52   auto out_tensor = out_tensors_.at(0);
53 
54   hash_seed_ = reinterpret_cast<float *>(input0_tensor->MutableData());
55   CHECK_NULL_RETURN(hash_seed_);
56   feature_ = reinterpret_cast<int32_t *>(input1_tensor->MutableData());
57   CHECK_NULL_RETURN(feature_);
58   weight_ = in_tensors_.size() == 2 ? nullptr : reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
59   CHECK_NULL_RETURN(weight_);
60   output_ = reinterpret_cast<int32_t *>(out_tensor->MutableData());
61   CHECK_NULL_RETURN(output_);
62 
63   param_->hash_buff_size_ = sizeof(float) + sizeof(int32_t);
64   param_->feature_num_ = input1_tensor->ElementsNum();
65   param_->hash_shape_[0] = input0_tensor->DimensionSize(0);
66   param_->hash_shape_[1] = input0_tensor->DimensionSize(1);
67   param_->thread_stride_ = op_parameter_->thread_num_ > 1 ? UP_DIV(param_->hash_shape_[0], op_parameter_->thread_num_)
68                                                           : param_->hash_shape_[0];
69   auto ret = MallocKeys();
70   if (ret != RET_OK) {
71     return ret;
72   }
73   ret = ParallelLaunch(this->ms_context_, LshProjectionRun, this, op_parameter_->thread_num_);
74   if (ret != RET_OK) {
75     MS_LOG(ERROR) << "LshProjection kernel parallel launch failed";
76   }
77   FreeKeys();
78   return ret;
79 }
80 
MallocKeys()81 int LshProjectionCPUKernel::MallocKeys() {
82   param_->hash_buffs_ =
83     static_cast<char **>(ms_context_->allocator->Malloc(op_parameter_->thread_num_ * sizeof(char *)));
84   if (param_->hash_buffs_ == nullptr) {
85     MS_LOG(ERROR) << "Memory allocation failed";
86     return RET_ERROR;
87   }
88   for (int i = 0; i < op_parameter_->thread_num_; i++) {
89     param_->hash_buffs_[i] = static_cast<char *>(ms_context_->allocator->Malloc(param_->hash_buff_size_));
90     if (param_->hash_buffs_[i] == nullptr) {
91       FreeKeys();
92       MS_LOG(ERROR) << "Memory allocation failed";
93       return RET_ERROR;
94     }
95   }
96   return RET_OK;
97 }
98 
FreeKeys()99 void LshProjectionCPUKernel::FreeKeys() {
100   if (param_->hash_buffs_ != nullptr) {
101     for (int i = 0; i < op_parameter_->thread_num_; i++) {
102       ms_context_->allocator->Free(param_->hash_buffs_[i]);
103       param_->hash_buffs_[i] = nullptr;
104     }
105     ms_context_->allocator->Free(param_->hash_buffs_);
106     param_->hash_buffs_ = nullptr;
107   }
108 }
109 
DoExecute(int task_id)110 int LshProjectionCPUKernel::DoExecute(int task_id) {
111   int cur_group_num = MSMIN(param_->hash_shape_[0] - task_id * param_->thread_stride_, param_->thread_stride_);
112   int start = task_id * param_->thread_stride_;
113   int end = start + cur_group_num;
114   char *hash_buff = param_->hash_buffs_[task_id];
115 
116   switch (param_->lsh_type_) {
117     case schema::LshProjectionType_SPARSE:
118       LshProjectionSparse(hash_seed_, feature_, weight_, output_, param_, start, end, hash_buff);
119       break;
120     case schema::LshProjectionType_DENSE:
121       LshProjectionDense(hash_seed_, feature_, weight_, output_, param_, start, end, hash_buff);
122       break;
123     default:
124       return RET_ERROR;
125   }
126   return RET_OK;
127 }
128 
GetSignBit(int32_t * feature,float * weight,float seed,LshProjectionParameter * para,char * hash_buff)129 int LshProjectionCPUKernel::GetSignBit(int32_t *feature, float *weight, float seed, LshProjectionParameter *para,
130                                        char *hash_buff) {
131   MS_ASSERT(feature != nullptr);
132   MS_ASSERT(weight != nullptr);
133   MS_ASSERT(para != nullptr);
134   MS_ASSERT(hash_buff != nullptr);
135   double score = 0.0;
136   for (int i = 0; i < para->feature_num_; i++) {
137     memcpy(hash_buff, &seed, sizeof(float));
138     memcpy(hash_buff + sizeof(float), &(feature[i]), sizeof(int32_t));
139     int64_t hash_i = static_cast<int64_t>(lite::StringHash64(hash_buff, para->hash_buff_size_));
140     double hash_d = static_cast<double>(hash_i);
141     if (weight == nullptr) {
142       score += hash_d;
143     } else {
144       score += weight[i] * hash_d;
145     }
146   }
147   return (score > 0) ? 1 : 0;
148 }
149 
LshProjectionSparse(float * hashSeed,int32_t * feature,float * weight,int32_t * output,LshProjectionParameter * para,int32_t start,int32_t end,char * hash_buff)150 void LshProjectionCPUKernel::LshProjectionSparse(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
151                                                  LshProjectionParameter *para, int32_t start, int32_t end,
152                                                  char *hash_buff) {
153   MS_ASSERT(hashSeed != nullptr);
154   MS_ASSERT(feature != nullptr);
155   MS_ASSERT(weight != nullptr);
156   MS_ASSERT(output != nullptr);
157   MS_ASSERT(para != nullptr);
158   MS_ASSERT(hash_buff != nullptr);
159   for (int i = start; i < end; i++) {
160     int32_t hash_sign = 0;
161     for (int j = 0; j < para->hash_shape_[1]; j++) {
162       int bit = GetSignBit(feature, weight, hashSeed[i * para->hash_shape_[1] + j], para, hash_buff);
163       hash_sign = (hash_sign << 1) | bit;
164     }
165     output[i] = hash_sign + i * (1 << para->hash_shape_[1]);
166   }
167 }
168 
LshProjectionDense(float * hashSeed,int32_t * feature,float * weight,int32_t * output,LshProjectionParameter * para,int32_t start,int32_t end,char * hash_buff)169 void LshProjectionCPUKernel::LshProjectionDense(float *hashSeed, int32_t *feature, float *weight, int32_t *output,
170                                                 LshProjectionParameter *para, int32_t start, int32_t end,
171                                                 char *hash_buff) {
172   MS_ASSERT(feature != nullptr);
173   MS_ASSERT(weight != nullptr);
174   MS_ASSERT(output != nullptr);
175   MS_ASSERT(para != nullptr);
176   MS_ASSERT(hash_buff != nullptr);
177   for (int i = start; i < end; i++) {
178     for (int j = 0; j < para->hash_shape_[1]; j++) {
179       output[i * para->hash_shape_[1] + j] =
180         GetSignBit(feature, weight, hashSeed[i * para->hash_shape_[1] + j], para, hash_buff);
181     }
182   }
183 }
184 
// Registers LshProjectionCPUKernel (constructed through LiteKernelCreator) under the key
// (kCPU, kNumberTypeFloat32, PrimitiveType_LshProjection).
185 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LshProjection, LiteKernelCreator<LshProjectionCPUKernel>)
186 }  // namespace mindspore::kernel
187