1 /**
2 * Copyright 2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
17 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
18
#include <cstdint>
#include <functional>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include "include/backend/device_address.h"
#include "plugin/device/cpu/hal/device/cpu_hash_table.h"
27
28 namespace mindspore {
29 namespace device {
30 namespace cpu {
31 using CreateHashTableFunc = std::function<void(const UserDataPtr &)>;
32 using ImportHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;
33 using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;
34
35 constexpr size_t kCreateFuncIndex = 0;
36 constexpr size_t kImportFuncIndex = 1;
37 constexpr size_t kClearFuncIndex = 2;
38
39 /**
40 * @brief Create CPU hash table and set into `user_data`.
41 * @param[in] `user_data`: The input user data which contains meta information to create CPU hash table.
42 */
43 template <typename KeyType, typename ValueType>
CreateCPUHashTable(const UserDataPtr & user_data)44 void CreateCPUHashTable(const UserDataPtr &user_data) {
45 MS_EXCEPTION_IF_NULL(user_data);
46 auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
47 auto default_value = user_data->get<Value>(kHashTableDefaultValue);
48 MS_EXCEPTION_IF_NULL(shape_vector);
49 MS_EXCEPTION_IF_NULL(default_value);
50
51 int32_t value_size = 1;
52 for (size_t i = 0; i < (*shape_vector).size(); ++i) {
53 value_size *= (*shape_vector)[i];
54 }
55 if (value_size <= 0) {
56 MS_LOG(WARNING) << "Invalid value size:" << value_size;
57 }
58 if (default_value->isa<StringImm>()) {
59 user_data->set<CPUHashTable<KeyType, ValueType>>(
60 kUserDataData,
61 std::make_shared<CPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
62 } else if (default_value->isa<FloatImm>()) {
63 user_data->set<CPUHashTable<KeyType, ValueType>>(
64 kUserDataData, std::make_shared<CPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
65 } else {
66 MS_LOG(EXCEPTION) << "Invalid Default Value:" << default_value;
67 }
68 }
69
70 /**
71 * @brief Import key, value, status tensors to CPU hash table.
72 * @param[in] `user_data`: The input user data which contains CPU hash table need to import.
73 * @param[in] `tensor_data`: The host pointer of tensor which need to be imported into CPU hash table.
74 * @param[in] `size`: The data length in bytes of tensor data which need to be imported into CPU hash table.
75 * @return Whether the function was successfully executed.
76 */
77 template <typename KeyType, typename ValueType>
ImportCPUHashTable(const UserDataPtr & user_data,const void * tensor_data,size_t size)78 bool ImportCPUHashTable(const UserDataPtr &user_data, const void *tensor_data, size_t size) {
79 MS_EXCEPTION_IF_NULL(user_data);
80 MS_EXCEPTION_IF_NULL(tensor_data);
81 const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
82 MS_EXCEPTION_IF_NULL(cpu_hash_table);
83 if (!cpu_hash_table->Import({const_cast<void *>(tensor_data), size})) {
84 MS_LOG(ERROR) << "Import for hash table failed.";
85 return false;
86 }
87 return true;
88 }
89
90 /**
91 * @brief Clear all resource in CPU hash table and reset all statistics.
92 * @param[in] `user_data`: The input user data which contains CPU hash table need to clear.
93 */
94 template <typename KeyType, typename ValueType>
ClearCPUHashTable(const UserDataPtr & user_data)95 void ClearCPUHashTable(const UserDataPtr &user_data) {
96 MS_EXCEPTION_IF_NULL(user_data);
97 const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
98 MS_EXCEPTION_IF_NULL(cpu_hash_table);
99 if (!cpu_hash_table->Clear()) {
100 MS_LOG(EXCEPTION) << "Clear user data failed.";
101 }
102 }
103
104 static std::map<std::pair<TypeId, TypeId>, std::tuple<CreateHashTableFunc, ImportHashTableFunc, ClearHashTableFunc>>
105 cpu_hash_table_funcs = {
106 {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
107 std::make_tuple(CreateCPUHashTable<int, float>, ImportCPUHashTable<int, float>, ClearCPUHashTable<int, float>)},
108 {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
109 std::make_tuple(CreateCPUHashTable<int64_t, float>, ImportCPUHashTable<int64_t, float>,
110 ClearCPUHashTable<int64_t, float>)}};
111 } // namespace cpu
112 } // namespace device
113 } // namespace mindspore
114 #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
115