1 /**
2 * Copyright 2020-2022 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "plugin/device/cpu/kernel/embedding_look_up_cpu_kernel.h"
18 #include "mindspore/core/ops/embedding_lookup.h"
19 #include "utils/check_convert_utils.h"
20 #include "include/backend/distributed/embedding_cache/embedding_cache_utils.h"
21
22 namespace mindspore {
23 namespace kernel {
24 namespace {
25 constexpr size_t kEmbeddingLookupInputsNum = 3;
26 constexpr size_t kEmbeddingLookUpInputParamsMaxDim = 2;
27 constexpr size_t kOffsetIndex = 2;
28 using KernelRunFunc = EmbeddingLookUpCpuKernelMod::KernelRunFunc;
29
// Builds one {KernelAttr, launch function} entry for GetFuncList(). The attribute lists the dtypes of the
// params, indices and offset inputs followed by the output dtype; the function is LaunchKernel instantiated
// with the matching C++ element types <params, indices, offset>.
#define EMBEDDING_LOOKUP_FUNC_ENTRY(params_type_id, indices_type_id, offset_type_id, output_type_id, params_type, \
                                    indices_type, offset_type)                                                  \
  {                                                                                                             \
    KernelAttr()                                                                                                \
      .AddInputAttr(params_type_id)                                                                             \
      .AddInputAttr(indices_type_id)                                                                            \
      .AddInputAttr(offset_type_id)                                                                             \
      .AddOutputAttr(output_type_id),                                                                           \
      &EmbeddingLookUpCpuKernelMod::LaunchKernel<params_type, indices_type, offset_type>                        \
  }

// Registration entry whose offset input is an int64 tensor.
#define ADD_KERNEL(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type)    \
  EMBEDDING_LOOKUP_FUNC_ENTRY(kNumberType##input_params_dtype, kNumberType##input_indices_dtype, kNumberTypeInt64, \
                              kNumberType##output_dtype, input_params_type, input_indices_type, int64_t)

// Registration entry whose offset input is an int32 tensor.
#define ADD_KERNEL_INT32(input_params_dtype, input_indices_dtype, output_dtype, input_params_type, input_indices_type) \
  EMBEDDING_LOOKUP_FUNC_ENTRY(kNumberType##input_params_dtype, kNumberType##input_indices_dtype, kNumberTypeInt32,    \
                              kNumberType##output_dtype, input_params_type, input_indices_type, int32_t)
49
50 template <typename T, typename S>
LookUpTableTask(const T * input_addr,const S * indices_addr,T * output_addr,size_t indices_lens,size_t outer_dim_size,int64_t offset,size_t first_dim_size,std::string kernel_name_)51 void LookUpTableTask(const T *input_addr, const S *indices_addr, T *output_addr, size_t indices_lens,
52 size_t outer_dim_size, int64_t offset, size_t first_dim_size, std::string kernel_name_) {
53 auto type_size = sizeof(T);
54 size_t lens = outer_dim_size * type_size;
55 for (size_t i = 0; i < indices_lens; ++i) {
56 S index = indices_addr[i] - static_cast<S>(offset);
57 if (index >= 0 && index < SizeToInt(first_dim_size)) {
58 size_t pos = static_cast<size_t>(index) * outer_dim_size;
59 auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens);
60 if (ret != EOK) {
61 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memcpy failed. Error no: " << ret;
62 }
63 } else {
64 auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens);
65 if (ret != EOK) {
66 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset failed. Error no: " << ret;
67 }
68 }
69 output_addr += outer_dim_size;
70 }
71 }
72
// Shifts every index in place by `-offset`, so that indices start from zero.
// `indices_addr` must point to `indices_lens` writable elements.
template <typename S>
void RectifyIndex(S *indices_addr, size_t indices_lens, int64_t offset) {
  const S shift = static_cast<S>(offset);
  S *const stop = indices_addr + indices_lens;
  for (S *cursor = indices_addr; cursor != stop; ++cursor) {
    *cursor -= shift;
  }
}
80 } // namespace
81
GetFuncList() const82 const std::vector<std::pair<KernelAttr, KernelRunFunc>> &EmbeddingLookUpCpuKernelMod::GetFuncList() const {
83 static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
84 ADD_KERNEL(Bool, Int32, Bool, bool, int32_t),
85 ADD_KERNEL(Int8, Int32, Int8, int8_t, int32_t),
86 ADD_KERNEL(Int16, Int32, Int16, int16_t, int32_t),
87 ADD_KERNEL(Int32, Int32, Int32, int32_t, int32_t),
88 ADD_KERNEL(Int64, Int32, Int64, int64_t, int32_t),
89 ADD_KERNEL(UInt8, Int32, UInt8, uint8_t, int32_t),
90 ADD_KERNEL(UInt16, Int32, UInt16, uint16_t, int32_t),
91 ADD_KERNEL(UInt32, Int32, UInt32, uint32_t, int32_t),
92 ADD_KERNEL(UInt64, Int32, UInt64, uint64_t, int32_t),
93 ADD_KERNEL(Float16, Int32, Float16, float16, int32_t),
94 ADD_KERNEL(Float32, Int32, Float32, float, int32_t),
95 ADD_KERNEL(Float64, Int32, Float64, double, int32_t),
96
97 ADD_KERNEL(Bool, Int64, Bool, bool, int64_t),
98 ADD_KERNEL(Int8, Int64, Int8, int8_t, int64_t),
99 ADD_KERNEL(Int16, Int64, Int16, int16_t, int64_t),
100 ADD_KERNEL(Int32, Int64, Int32, int32_t, int64_t),
101 ADD_KERNEL(Int64, Int64, Int64, int64_t, int64_t),
102 ADD_KERNEL(UInt8, Int64, UInt8, uint8_t, int64_t),
103 ADD_KERNEL(UInt16, Int64, UInt16, uint16_t, int64_t),
104 ADD_KERNEL(UInt32, Int64, UInt32, uint32_t, int64_t),
105 ADD_KERNEL(UInt64, Int64, UInt64, uint64_t, int64_t),
106 ADD_KERNEL(Float16, Int64, Float16, float16, int64_t),
107 ADD_KERNEL(Float32, Int64, Float32, float, int64_t),
108 ADD_KERNEL(Float64, Int64, Float64, double, int64_t),
109
110 ADD_KERNEL_INT32(Int32, Int32, Int32, int32_t, int32_t),
111 ADD_KERNEL_INT32(Float32, Int32, Float32, float, int32_t)};
112
113 return func_list;
114 }
115
Init(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)116 bool EmbeddingLookUpCpuKernelMod::Init(const std::vector<KernelTensor *> &inputs,
117 const std::vector<KernelTensor *> &outputs) {
118 if (primitive_->HasAttr(kAttrEnableEmbeddingStorage)) {
119 enable_embedding_storage_ = GetValue<bool>(primitive_->GetAttr(kAttrEnableEmbeddingStorage));
120 }
121 if (primitive_->HasAttr(kAttrParameterKey)) {
122 parameter_key_ = GetValue<int32_t>(primitive_->GetAttr(kAttrParameterKey));
123 }
124
125 return MatchKernelFunc(kernel_name_, inputs, outputs);
126 }
127
Resize(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > & outputs)128 int EmbeddingLookUpCpuKernelMod::Resize(const std::vector<KernelTensor *> &inputs,
129 const std::vector<KernelTensor *> &outputs) {
130 if (int ret = KernelMod::Resize(inputs, outputs); ret != KRET_OK) {
131 return ret;
132 }
133
134 if (inputs.size() != kEmbeddingLookupInputsNum || outputs.size() != 1) {
135 MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output size must be " << kEmbeddingLookupInputsNum
136 << ", but got " << inputs.size() << " and " << outputs.size();
137 }
138
139 std::vector<int64_t> input_params_shape = inputs[kIndex0]->GetShapeVector();
140 if (input_params_shape.empty() || input_params_shape.size() > kEmbeddingLookUpInputParamsMaxDim) {
141 MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 1-"
142 << kEmbeddingLookUpInputParamsMaxDim << "D, but got " << input_params_shape.size() << "D.";
143 }
144 first_dim_size_ = LongToSize(input_params_shape[0]);
145 outer_dim_size_ = 1;
146 for (size_t i = 1; i < input_params_shape.size(); ++i) {
147 outer_dim_size_ *= LongToSize(input_params_shape[i]);
148 }
149 input_params_dtype_ = inputs[kIndex0]->dtype_id();
150
151 std::vector<int64_t> input_indices_shape = inputs[kIndex1]->GetShapeVector();
152 input_indices_lens_ = SizeOf(input_indices_shape);
153 input_indices_dtype_ = inputs[kIndex1]->dtype_id();
154 return KRET_OK;
155 }
156
157 template <typename T, typename S, typename G>
LaunchKernel(const std::vector<KernelTensor * > & inputs,const std::vector<KernelTensor * > &,const std::vector<KernelTensor * > & outputs)158 bool EmbeddingLookUpCpuKernelMod::LaunchKernel(const std::vector<KernelTensor *> &inputs,
159 const std::vector<KernelTensor *> &,
160 const std::vector<KernelTensor *> &outputs) {
161 T *input_params_addr = GetDeviceAddress<T>(inputs, 0);
162 S *input_indices_addr = GetDeviceAddress<S>(inputs, 1);
163 T *output_addr = GetDeviceAddress<T>(outputs, 0);
164 G offset = static_cast<G *>(inputs[kOffsetIndex]->device_ptr())[0];
165 offset_ = static_cast<int64_t>(offset);
166
167 if (enable_embedding_storage_) {
168 if (offset_ != 0) {
169 // Indices should start from zero, so minus offset first.
170 auto rectify_index_task = [&](size_t start, size_t end) {
171 size_t task_proc_lens = end - start;
172 RectifyIndex<S>(input_indices_addr + start, task_proc_lens, offset_);
173 };
174 ParallelLaunchAutoSearch(rectify_index_task, input_indices_lens_, this, ¶llel_search_info_);
175 }
176
177 auto embedding_storage = embedding_storage_manager.Get(parameter_key_);
178 MS_ERROR_IF_NULL(embedding_storage);
179 if (!embedding_storage->Get({input_indices_addr, inputs[1]->size()}, {output_addr, outputs[0]->size()})) {
180 MS_LOG(ERROR) << "For '" << kernel_name_
181 << "', lookup embedding from embedding storage failed, parameter key: " << parameter_key_;
182 return false;
183 }
184 return true;
185 }
186
187 auto task = [&](size_t start, size_t end) {
188 size_t task_proc_lens = end - start;
189 LookUpTableTask<T, S>(input_params_addr, input_indices_addr + start, output_addr + start * outer_dim_size_,
190 task_proc_lens, outer_dim_size_, offset_, first_dim_size_, kernel_name_);
191 };
192
193 ParallelLaunchAutoSearch(task, input_indices_lens_, this, ¶llel_search_info_);
194 return true;
195 }
196
// Registers EmbeddingLookUpCpuKernelMod with the NativeCpuKernelMod kernel factory under the name
// "EmbeddingLookup".
MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, EmbeddingLookup, EmbeddingLookUpCpuKernelMod);
198 } // namespace kernel
199 } // namespace mindspore
200