• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h"
18 #include "mindspore/core/ops/embedding_lookup.h"
19 #include "utils/check_convert_utils.h"
20 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
21 
22 namespace mindspore {
23 namespace kernel {
24 namespace {
25 constexpr size_t kEmbeddingLookupInputsNum = 3;
26 constexpr size_t kEmbeddingLookUpInputParamsMaxDim = 2;
27 constexpr size_t kOffsetIndex = 2;
28 using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
29 
30 #define ADD_KERNEL(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type) \
31   {                                                                                                              \
32     KernelAttr()                                                                                                 \
33       .AddInputAttr(kNumberType##input_params_dtype)                                                             \
34       .AddInputAttr(kNumberType##input_indices_dtype)                                                            \
35       .AddInputAttr(kNumberTypeInt64)                                                                            \
36       .AddOutputAttr(kNumberType##output_dtype),                                                                 \
37       &EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int64_t>                 \
38   }
39 
40 #define ADD_KERNEL_INT32(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type) \
41   {                                                                                                                    \
42     KernelAttr()                                                                                                       \
43       .AddInputAttr(kNumberType##input_params_dtype)                                                                   \
44       .AddInputAttr(kNumberType##input_indices_dtype)                                                                  \
45       .AddInputAttr(kNumberTypeInt32)                                                                                  \
46       .AddOutputAttr(kNumberType##output_dtype),                                                                       \
47       &EmbeddingLookUpCpuKernelMod::LaunchKernel<input_params_type, input_indices_type, int32_t>                       \
48   }
49 
50 template <typename T, typename S>
LookUpTableTask(const T * input_addr,const S * indices_addr,T * output_addr,size_t indices_lens,size_t outer_dim_size,int64_t offset,size_t first_dim_size,std::string kernel_name_)51 void LookUpTableTask(const T *input_addr, const S *indices_addr, T *output_addr, size_t indices_lens,
52                      size_t outer_dim_size, int64_t offset, size_t first_dim_size, std::string kernel_name_) {
53   auto type_size = sizeof(T);
54   size_t lens = outer_dim_size * type_size;
55   for (size_t i = 0; i < indices_lens; ++i) {
56     S index = indices_addr[i] - static_cast<S>(offset);
57     if (index >= 0 && index < SizeToInt(first_dim_size)) {
58       size_t pos = static_cast<size_t>(index) * outer_dim_size;
59       auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
60       if (ret != EOK) {
61         MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret;
62       }
63     } else {
64       auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
65       if (ret != EOK) {
66         MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset failed. Error no: " << ret;
67       }
68     }
69     output_addr += outer_dim_size;
70   }
71 }
72 
// Subtracts `offset` (converted to S) from each of the first `indices_lens` entries of `indices_addr`, in place,
// so that indices are rebased to start from zero.
template <typename S>
void RectifyIndex(S *indices_addr, size_t indices_lens, int64_t offset) {
  const S shift = static_cast<S>(offset);
  S *const stop = indices_addr + indices_lens;
  for (S *cursor = indices_addr; cursor != stop; ++cursor) {
    *cursor -= shift;
  }
}
80 }  // namespace
81 
// Returns the supported dtype combinations, each paired with the LaunchKernel instantiation that handles it.
// Every KernelAttr lists the params, indices and offset input dtypes and the output dtype; the output dtype
// always equals the params dtype. The table is a function-local static, built once and returned by reference.
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &EmbeddingLookUpCpuKernelMod::GetFuncList() const {
  static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
    // int32 indices, int64 offset.
    ADD_KERNEL(Bool, Int32, Bool, bool, int32_t),
    ADD_KERNEL(Int8, Int32, Int8, int8_t, int32_t),
    ADD_KERNEL(Int16, Int32, Int16, int16_t, int32_t),
    ADD_KERNEL(Int32, Int32, Int32, int32_t, int32_t),
    ADD_KERNEL(Int64, Int32, Int64, int64_t, int32_t),
    ADD_KERNEL(UInt8, Int32, UInt8, uint8_t, int32_t),
    ADD_KERNEL(UInt16, Int32, UInt16, uint16_t, int32_t),
    ADD_KERNEL(UInt32, Int32, UInt32, uint32_t, int32_t),
    ADD_KERNEL(UInt64, Int32, UInt64, uint64_t, int32_t),
    ADD_KERNEL(Float16, Int32, Float16, float16, int32_t),
    ADD_KERNEL(Float32, Int32, Float32, float, int32_t),
    ADD_KERNEL(Float64, Int32, Float64, double, int32_t),

    // int64 indices, int64 offset.
    ADD_KERNEL(Bool, Int64, Bool, bool, int64_t),
    ADD_KERNEL(Int8, Int64, Int8, int8_t, int64_t),
    ADD_KERNEL(Int16, Int64, Int16, int16_t, int64_t),
    ADD_KERNEL(Int32, Int64, Int32, int32_t, int64_t),
    ADD_KERNEL(Int64, Int64, Int64, int64_t, int64_t),
    ADD_KERNEL(UInt8, Int64, UInt8, uint8_t, int64_t),
    ADD_KERNEL(UInt16, Int64, UInt16, uint16_t, int64_t),
    ADD_KERNEL(UInt32, Int64, UInt32, uint32_t, int64_t),
    ADD_KERNEL(UInt64, Int64, UInt64, uint64_t, int64_t),
    ADD_KERNEL(Float16, Int64, Float16, float16, int64_t),
    ADD_KERNEL(Float32, Int64, Float32, float, int64_t),
    ADD_KERNEL(Float64, Int64, Float64, double, int64_t),

    // int32 offset: only int32 / float32 params with int32 indices are registered.
    ADD_KERNEL_INT32(Int32, Int32, Int32, int32_t, int32_t),
    ADD_KERNEL_INT32(Float32, Int32, Float32, float, int32_t)};

  return func_list;
}
115 
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)116 bool EmbeddingLookUpCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
117                                        const std::vector<KernelTensor *> &outputs) {
118   if (primitive_->HasAttr(kAttrEnableEmbeddingStorage)) {
119     enable_embedding_storage_ = GetValue<bool>(primitive_->GetAttr(kAttrEnableEmbeddingStorage));
120   }
121   if (primitive_->HasAttr(kAttrParameterKey)) {
122     parameter_key_ = GetValue<int32_t>(primitive_->GetAttr(kAttrParameterKey));
123   }
124 
125   return MatchKernelFunc(kernel_name_, inputs, outputs);
126 }
127 
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)128 int EmbeddingLookUpCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
129                                         const std::vector<KernelTensor *> &outputs) {
130   if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
131     return ret;
132   }
133 
134   if (inputs.size() != kEmbeddingLookupInputsNum || outputs.size() != 1) {
135     MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kEmbeddingLookupInputsNum
136                   << ", but got " << inputs.size() << " and " << outputs.size();
137   }
138 
139   std::vector<int64_t> input_params_shape = inputs[kIndex0]->GetShapeVector();
140   if (input_params_shape.empty() || input_params_shape.size() > kEmbeddingLookUpInputParamsMaxDim) {
141     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 1-"
142                       << kEmbeddingLookUpInputParamsMaxDim << "D, but got " << input_params_shape.size() << "D.";
143   }
144   first_dim_size_ = LongToSize(input_params_shape[0]);
145   outer_dim_size_ = 1;
146   for (size_t i = 1; i < input_params_shape.size(); ++i) {
147     outer_dim_size_ *= LongToSize(input_params_shape[i]);
148   }
149   input_params_dtype_ = inputs[kIndex0]->dtype_id();
150 
151   std::vector<int64_t> input_indices_shape = inputs[kIndex1]->GetShapeVector();
152   input_indices_lens_ = SizeOf(input_indices_shape);
153   input_indices_dtype_ = inputs[kIndex1]->dtype_id();
154   return KRET_OK;
155 }
156 
157 template <typename T, typename S, typename G>
LaunchKernel(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > &,const std::vector<KernelTensor * > & outputs)158 bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
159                                                const std::vector<KernelTensor *> &,
160                                                const std::vector<KernelTensor *> &outputs) {
161   T *input_params_addr = GetDeviceAddress<T>(inputs, 0);
162   S *input_indices_addr = GetDeviceAddress<S>(inputs, 1);
163   T *output_addr = GetDeviceAddress<T>(outputs, 0);
164   G offset = static_cast<G *>(inputs[kOffsetIndex]->device_ptr())[0];
165   offset_ = static_cast<int64_t>(offset);
166 
167   if (enable_embedding_storage_) {
168     if (offset_ != 0) {
169       // Indices should start from zero, so minus offset first.
170       auto rectify_index_task = [&](size_t start, size_t end) {
171         size_t task_proc_lens = end - start;
172         RectifyIndex<S>(input_indices_addr + start, task_proc_lens, offset_);
173       };
174       ParallelLaunchAutoSearch(rectify_index_task, input_indices_lens_, this, &parallel_search_info_);
175     }
176 
177     auto embedding_storage = embedding_storage_manager.Get(parameter_key_);
178     MS_ERROR_IF_NULL(embedding_storage);
179     if (!embedding_storage->Get({input_indices_addr, inputs[1]->size()}, {output_addr, outputs[0]->size()})) {
180       MS_LOG(ERROR) << "For '" << kernel_name_
181                     << "', lookup embedding from embedding storage failed, parameter key: " << parameter_key_;
182       return false;
183     }
184     return true;
185   }
186 
187   auto task = [&](size_t start, size_t end) {
188     size_t task_proc_lens = end - start;
189     LookUpTableTask<T, S>(input_params_addr, input_indices_addr + start, output_addr + start * outer_dim_size_,
190                           task_proc_lens, outer_dim_size_, offset_, first_dim_size_, kernel_name_);
191   };
192 
193   ParallelLaunchAutoSearch(task, input_indices_lens_, this, &parallel_search_info_);
194   return true;
195 }
196 
// Registers EmbeddingLookUpCpuKernelMod with the NativeCpuKernelMod kernel factory under the op name EmbeddingLookup.
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, EmbeddingLookup, EmbeddingLookUpCpuKernelMod);
198 }  // namespace kernel
199 }  // namespace mindspore
200