• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
17 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
18 
19 #include <map>
20 #include <tuple>
21 #include <utility>
22 #include <memory>
23 #include <functional>
24 #include <string>
25 #include "include/backend/device_address.h"
26 #include "plugin/device/cpu/hal/device/cpu_hash_table.h"
27 
28 namespace mindspore {
29 namespace device {
30 namespace cpu {
31 using CreateHashTableFunc = std::function<void(const UserDataPtr &)>;  // Callable that creates a hash table from user data.
32 using ImportHashTableFunc = std::function<bool(const UserDataPtr &, const void *, size_t)>;  // Callable that imports a host buffer (pointer, size); returns success.
33 using ClearHashTableFunc = std::function<void(const UserDataPtr &)>;  // Callable that clears the hash table held by user data.
34 
35 constexpr size_t kCreateFuncIndex = 0;  // Position of CreateHashTableFunc in the tuples stored in cpu_hash_table_funcs.
36 constexpr size_t kImportFuncIndex = 1;  // Position of ImportHashTableFunc in the tuples stored in cpu_hash_table_funcs.
37 constexpr size_t kClearFuncIndex = 2;  // Position of ClearHashTableFunc in the tuples stored in cpu_hash_table_funcs.
38 
39 /**
40  * @brief Create CPU hash table and set into `user_data`.
41  * @param[in] `user_data`: The input user data which contains meta information to create CPU hash table.
42  */
43 template <typename KeyType, typename ValueType>
CreateCPUHashTable(const UserDataPtr & user_data)44 void CreateCPUHashTable(const UserDataPtr &user_data) {
45   MS_EXCEPTION_IF_NULL(user_data);
46   auto shape_vector = user_data->get<ShapeVector>(kHashTableShapeVector);
47   auto default_value = user_data->get<Value>(kHashTableDefaultValue);
48   MS_EXCEPTION_IF_NULL(shape_vector);
49   MS_EXCEPTION_IF_NULL(default_value);
50 
51   int32_t value_size = 1;
52   for (size_t i = 0; i < (*shape_vector).size(); ++i) {
53     value_size *= (*shape_vector)[i];
54   }
55   if (value_size <= 0) {
56     MS_LOG(WARNING) << "Invalid value size:" << value_size;
57   }
58   if (default_value->isa<StringImm>()) {
59     user_data->set<CPUHashTable<KeyType, ValueType>>(
60       kUserDataData,
61       std::make_shared<CPUHashTable<KeyType, ValueType>>(value_size, GetValue<std::string>(default_value)));
62   } else if (default_value->isa<FloatImm>()) {
63     user_data->set<CPUHashTable<KeyType, ValueType>>(
64       kUserDataData, std::make_shared<CPUHashTable<KeyType, float>>(value_size, GetValue<float>(default_value)));
65   } else {
66     MS_LOG(EXCEPTION) << "Invalid Default Value:" << default_value;
67   }
68 }
69 
70 /**
71  * @brief Import key, value, status tensors to CPU hash table.
72  * @param[in] `user_data`: The input user data which contains CPU hash table need to import.
73  * @param[in] `tensor_data`: The host pointer of tensor which need to be imported into CPU hash table.
74  * @param[in] `size`: The data length in bytes of tensor data which need to be imported into CPU hash table.
75  * @return Whether the function was successfully executed.
76  */
77 template <typename KeyType, typename ValueType>
ImportCPUHashTable(const UserDataPtr & user_data,const void * tensor_data,size_t size)78 bool ImportCPUHashTable(const UserDataPtr &user_data, const void *tensor_data, size_t size) {
79   MS_EXCEPTION_IF_NULL(user_data);
80   MS_EXCEPTION_IF_NULL(tensor_data);
81   const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
82   MS_EXCEPTION_IF_NULL(cpu_hash_table);
83   if (!cpu_hash_table->Import({const_cast<void *>(tensor_data), size})) {
84     MS_LOG(ERROR) << "Import for hash table failed.";
85     return false;
86   }
87   return true;
88 }
89 
90 /**
91  * @brief Clear all resource in CPU hash table and reset all statistics.
92  * @param[in] `user_data`: The input user data which contains CPU hash table need to clear.
93  */
94 template <typename KeyType, typename ValueType>
ClearCPUHashTable(const UserDataPtr & user_data)95 void ClearCPUHashTable(const UserDataPtr &user_data) {
96   MS_EXCEPTION_IF_NULL(user_data);
97   const auto &cpu_hash_table = user_data->get<CPUHashTable<KeyType, ValueType>>(kUserDataData);
98   MS_EXCEPTION_IF_NULL(cpu_hash_table);
99   if (!cpu_hash_table->Clear()) {
100     MS_LOG(EXCEPTION) << "Clear user data failed.";
101   }
102 }
103 
104 static std::map<std::pair<TypeId, TypeId>, std::tuple<CreateHashTableFunc, ImportHashTableFunc, ClearHashTableFunc>>
105   cpu_hash_table_funcs = {
106     {std::make_pair(TypeId::kNumberTypeInt32, TypeId::kNumberTypeFloat32),
107      std::make_tuple(CreateCPUHashTable<int, float>, ImportCPUHashTable<int, float>, ClearCPUHashTable<int, float>)},
108     {std::make_pair(TypeId::kNumberTypeInt64, TypeId::kNumberTypeFloat32),
109      std::make_tuple(CreateCPUHashTable<int64_t, float>, ImportCPUHashTable<int64_t, float>,
110                      ClearCPUHashTable<int64_t, float>)}};
111 }  // namespace cpu
112 }  // namespace device
113 }  // namespace mindspore
114 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_HAL_DEVICE_CPU_HASH_TABLE_UTIL_H_
115