• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_USER_DATA_H_
18 #define MINDSPORE_CORE_USER_DATA_H_
19 
20 #include <string>
21 #include <memory>
22 #include <utility>
23 #include "utils/hash_map.h"
24 
25 namespace mindspore {
26 class UserData {
27  public:
28   using DataMap = mindspore::HashMap<std::string, std::shared_ptr<void>>;
29 
30   UserData() = default;
UserData(const UserData & other)31   UserData(const UserData &other) : data_(other.data_ ? std::make_unique<DataMap>(*other.data_) : nullptr) {}
UserData(UserData && other)32   UserData(UserData &&other) : data_(std::move(other.data_)) {}
33   UserData &operator=(const UserData &other) {
34     if (this == &other) {
35       return *this;
36     }
37     data_ = (other.data_ ? std::make_unique<DataMap>(*other.data_) : nullptr);
38     return *this;
39   }
40   UserData &operator=(UserData &&other) {
41     if (this == &other) {
42       return *this;
43     }
44     data_ = std::move(other.data_);
45     return *this;
46   }
47   ~UserData() = default;
48 
49   template <typename T>
set(const std::string & key,const std::shared_ptr<T> & value)50   void set(const std::string &key, const std::shared_ptr<T> &value) {
51     InitData();
52     if (value == nullptr) {
53       (void)data_->erase(key);
54     } else {
55       (void)data_->insert_or_assign(key, value);
56     }
57   }
58 
59   template <typename T>
get(const std::string & key)60   std::shared_ptr<T> get(const std::string &key) const {
61     if (data_ == nullptr) {
62       return nullptr;
63     }
64     auto iter = data_->find(key);
65     if (iter == data_->end()) {
66       return nullptr;
67     }
68     return std::static_pointer_cast<T>(iter->second);
69   }
70 
has(const std::string & key)71   bool has(const std::string &key) const { return (data_ != nullptr) && (data_->find(key) != data_->end()); }
72 
size()73   size_t size() const {
74     if (data_ == nullptr) {
75       return 0;
76     }
77     return data_->size();
78   }
79 
80  private:
InitData()81   void InitData() {
82     if (data_ == nullptr) {
83       data_ = std::make_unique<DataMap>();
84     }
85   }
86   std::unique_ptr<DataMap> data_;
87 };
88 
89 using UserDataPtr = std::shared_ptr<UserData>;
90 
// User data key names.
constexpr auto kUserDataData = "user_data_data";
constexpr auto kUserDataType = "user_data_type";
// User data keys for hash tables.
constexpr auto kHashTableKeyType = "hash_table_key_type";
constexpr auto kHashTableValueType = "hash_table_value_type";
constexpr auto kHashTableShapeVector = "hash_table_shape_vector";
constexpr auto kHashTableDefaultValue = "hash_table_default_value";
constexpr auto kHashTablePermitFilter = "hash_table_permit_filter";
constexpr auto kHashTableEvictFilter = "hash_table_evict_filter";

// Kind of user data. kUserDataTypeUnknown is 0 and kUserTypeHashTable is 1.
// NOTE(review): presumably this is the value associated with kUserDataType; confirm with callers.
enum class UserDataType { kUserDataTypeUnknown = 0, kUserTypeHashTable };
103 }  // namespace mindspore
104 
105 #endif  // MINDSPORE_CORE_USER_DATA_H_
106