1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CORE_UTILS_ANY_H_
17 #define MINDSPORE_CORE_UTILS_ANY_H_
18
19 #include <iostream>
20 #include <string>
21 #include <typeinfo>
22 #include <typeindex>
23 #include <memory>
24 #include <functional>
25 #include <sstream>
26 #include <vector>
27 #include <utility>
28 #include "utils/overload.h"
29 #include "utils/log_adapter.h"
30 #include "utils/misc.h"
31
32 namespace mindspore {
// Usage: AnyPtr sp = std::make_shared<Any>(aname);
34 template <class T>
type(const T & t)35 std::string type(const T &t) {
36 return demangle(typeid(t).name());
37 }
38
39 class MS_CORE_API Any {
40 public:
41 // constructors
Any()42 Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {}
Any(const Any & other)43 Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {}
Any(Any && other)44 Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {}
45
46 Any &operator=(Any &&other);
47 // right reference constructor
48 template <class T, class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, Any>::value, T>::type>
Any(T && t)49 Any(T &&t) : m_tpIndex(typeid(typename std::decay<T>::type)) { // NOLINT
50 BasePtr new_val = std::make_unique<Derived<typename std::decay<T>::type>>(std::forward<T>(t));
51 std::swap(m_ptr, new_val);
52 }
53
54 ~Any() = default;
55
56 // judge whether is empty
empty()57 bool empty() const { return m_ptr == nullptr; }
58
59 // judge the is relation
60 template <class T>
is()61 bool is() const {
62 return m_tpIndex == std::type_index(typeid(T));
63 }
64
type()65 const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); }
66
Hash()67 std::size_t Hash() const {
68 std::stringstream buffer;
69 buffer << m_tpIndex.name();
70 if (m_ptr != nullptr) {
71 buffer << m_ptr->GetString();
72 }
73 return std::hash<std::string>()(buffer.str());
74 }
75
76 template <typename T>
Apply(const std::function<void (T &)> & fn)77 bool Apply(const std::function<void(T &)> &fn) {
78 if (type() == typeid(T)) {
79 T x = cast<T>();
80 fn(x);
81 return true;
82 }
83 return false;
84 }
85
GetString()86 std::string GetString() const {
87 if (m_ptr != nullptr) {
88 return m_ptr->GetString();
89 } else {
90 return std::string("");
91 }
92 }
93
94 friend std::ostream &operator<<(std::ostream &os, const Any &any) {
95 os << any.GetString();
96 return os;
97 }
98
99 // type cast
100 template <class T>
cast()101 T &cast() const {
102 if (!is<T>() || !m_ptr) {
103 // Use MS_LOGFATAL replace throw std::bad_cast()
104 MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name();
105 }
106 auto ptr = static_cast<Derived<T> *>(m_ptr.get());
107 return ptr->m_value;
108 }
109
110 bool operator==(const Any &other) const {
111 if (m_tpIndex != other.m_tpIndex) {
112 return false;
113 }
114 if (m_ptr == nullptr && other.m_ptr == nullptr) {
115 return true;
116 }
117 if (m_ptr == nullptr || other.m_ptr == nullptr) {
118 return false;
119 }
120 return *m_ptr == *other.m_ptr;
121 }
122
123 bool operator!=(const Any &other) const { return !(operator==(other)); }
124
125 Any &operator=(const Any &other);
126
127 bool operator<(const Any &other) const;
128
ToString()129 std::string ToString() const {
130 std::ostringstream buffer;
131 if (m_tpIndex == typeid(float)) {
132 buffer << "<float> " << cast<float>();
133 } else if (m_tpIndex == typeid(double)) {
134 buffer << "<double> " << cast<double>();
135 } else if (m_tpIndex == typeid(int)) {
136 buffer << "<int> " << cast<int>();
137 } else if (m_tpIndex == typeid(bool)) {
138 buffer << "<bool> " << cast<bool>();
139 } else if (m_ptr != nullptr) {
140 buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString();
141 }
142 return buffer.str();
143 }
144 #ifdef _MSC_VER
dump()145 void dump() const { std::cout << ToString() << std::endl; }
146 #else
dump()147 __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; }
148 #endif
149
150 private:
151 struct Base;
152 using BasePtr = std::unique_ptr<Base>;
153
154 // type base definition
155 struct Base {
156 virtual const std::type_info &type() const = 0;
157 virtual BasePtr clone() const = 0;
158 virtual ~Base() = default;
159 virtual bool operator==(const Base &other) const = 0;
160 virtual std::string GetString() = 0;
161 };
162
163 template <typename T>
164 struct Derived : public Base {
165 template <typename... Args>
DerivedDerived166 explicit Derived(Args &&... args) : m_value(std::forward<Args>(args)...), serialize_cache_("") {}
167
168 bool operator==(const Base &other) const override {
169 if (typeid(*this) != typeid(other)) {
170 return false;
171 }
172 return m_value == static_cast<const Derived<T> &>(other).m_value;
173 }
174
typeDerived175 const std::type_info &type() const override { return typeid(T); }
176
cloneDerived177 BasePtr clone() const override { return std::make_unique<Derived<T>>(m_value); }
178
~DerivedDerived179 ~Derived() override {}
180
GetStringDerived181 std::string GetString() override {
182 std::stringstream buffer;
183 buffer << m_value;
184 return buffer.str();
185 }
186
187 T m_value;
188 std::string serialize_cache_;
189 };
190
191 // clone method
clone()192 BasePtr clone() const {
193 if (m_ptr != nullptr) {
194 return m_ptr->clone();
195 }
196 return nullptr;
197 }
198
199 BasePtr m_ptr; // point to real data
200 std::type_index m_tpIndex; // type info of data
201 };
202
// Shared-ownership handle to an Any, e.g. `AnyPtr sp = std::make_shared<Any>(value);`.
using AnyPtr = std::shared_ptr<Any>;
204
205 struct AnyHash {
operatorAnyHash206 std::size_t operator()(const Any &c) const { return c.Hash(); }
207 };
208
209 struct AnyLess {
operatorAnyLess210 bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); }
211 };
212
// Declared here; the definition is not part of this header.
// NOTE(review): presumably reports whether `any` holds a literal/scalar value
// (e.g. a number, bool or string) — confirm against the out-of-line definition.
bool AnyIsLiteral(const Any &any);
214 } // namespace mindspore
215
216 #endif // MINDSPORE_CORE_UTILS_ANY_H_
217