1 /**
2 * Copyright 2019-2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef MINDSPORE_CORE_UTILS_ANY_H_
17 #define MINDSPORE_CORE_UTILS_ANY_H_
18
19 #include <iostream>
20 #include <string>
21 #include <typeinfo>
22 #include <typeindex>
23 #include <memory>
24 #include <functional>
25 #include <sstream>
26 #include <vector>
27 #include <utility>
28
29 #include "utils/overload.h"
30 #include "utils/log_adapter.h"
31 #include "utils/misc.h"
32
33 namespace mindspore {
// Usage: AnyPtr sp = std::make_shared<Any>(aname);
35 template <class T>
type(const T & t)36 std::string type(const T &t) {
37 return demangle(typeid(t).name());
38 }
39
40 class Any {
41 public:
42 // constructors
Any()43 Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {}
Any(const Any & other)44 Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {}
Any(Any && other)45 Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {}
46
47 Any &operator=(Any &&other);
48 // right reference constructor
49 template <class T, class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, Any>::value, T>::type>
Any(T && t)50 Any(T &&t) : m_tpIndex(typeid(typename std::decay<T>::type)) { // NOLINT
51 BasePtr new_val(new Derived<typename std::decay<T>::type>(std::forward<T>(t)));
52 std::swap(m_ptr, new_val);
53 }
54
55 ~Any() = default;
56
57 // judge whether is empty
empty()58 bool empty() const { return m_ptr == nullptr; }
59
60 // judge the is relation
61 template <class T>
is()62 bool is() const {
63 return m_tpIndex == std::type_index(typeid(T));
64 }
65
type()66 const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); }
67
Hash()68 std::size_t Hash() const {
69 std::stringstream buffer;
70 buffer << m_tpIndex.name();
71 if (m_ptr != nullptr) {
72 buffer << m_ptr->GetString();
73 }
74 return std::hash<std::string>()(buffer.str());
75 }
76
77 template <typename T>
Apply(const std::function<void (T &)> & fn)78 bool Apply(const std::function<void(T &)> &fn) {
79 if (type() == typeid(T)) {
80 T x = cast<T>();
81 fn(x);
82 return true;
83 }
84 return false;
85 }
86
GetString()87 std::string GetString() const {
88 if (m_ptr != nullptr) {
89 return m_ptr->GetString();
90 } else {
91 return std::string("");
92 }
93 }
94
95 friend std::ostream &operator<<(std::ostream &os, const Any &any) {
96 os << any.GetString();
97 return os;
98 }
99
100 // type cast
101 template <class T>
cast()102 T &cast() const {
103 if (!is<T>() || !m_ptr) {
104 // Use MS_LOGFATAL replace throw std::bad_cast()
105 MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name();
106 }
107 auto ptr = static_cast<Derived<T> *>(m_ptr.get());
108 return ptr->m_value;
109 }
110
111 bool operator==(const Any &other) const {
112 if (m_tpIndex != other.m_tpIndex) {
113 return false;
114 }
115 if (m_ptr == nullptr && other.m_ptr == nullptr) {
116 return true;
117 }
118 if (m_ptr == nullptr || other.m_ptr == nullptr) {
119 return false;
120 }
121 return *m_ptr == *other.m_ptr;
122 }
123
124 bool operator!=(const Any &other) const { return !(operator==(other)); }
125
126 Any &operator=(const Any &other);
127
128 bool operator<(const Any &other) const;
129
ToString()130 std::string ToString() const {
131 std::ostringstream buffer;
132 if (m_tpIndex == typeid(float)) {
133 buffer << "<float> " << cast<float>();
134 } else if (m_tpIndex == typeid(double)) {
135 buffer << "<double> " << cast<double>();
136 } else if (m_tpIndex == typeid(int)) {
137 buffer << "<int> " << cast<int>();
138 } else if (m_tpIndex == typeid(bool)) {
139 buffer << "<bool> " << cast<bool>();
140 } else if (m_ptr != nullptr) {
141 buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString();
142 }
143 return buffer.str();
144 }
145 #ifdef _MSC_VER
dump()146 void dump() const { std::cout << ToString() << std::endl; }
147 #else
dump()148 __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; }
149 #endif
150
151 private:
152 struct Base;
153 using BasePtr = std::unique_ptr<Base>;
154
155 // type base definition
156 struct Base {
157 virtual const std::type_info &type() const = 0;
158 virtual BasePtr clone() const = 0;
159 virtual ~Base() = default;
160 virtual bool operator==(const Base &other) const = 0;
161 virtual std::string GetString() = 0;
162 };
163
164 template <typename T>
165 struct Derived : public Base {
166 template <typename... Args>
DerivedDerived167 explicit Derived(Args &&... args) : m_value(std::forward<Args>(args)...), serialize_cache_("") {}
168
169 bool operator==(const Base &other) const override {
170 if (typeid(*this) != typeid(other)) {
171 return false;
172 }
173 return m_value == static_cast<const Derived<T> &>(other).m_value;
174 }
175
typeDerived176 const std::type_info &type() const override { return typeid(T); }
177
cloneDerived178 BasePtr clone() const override { return BasePtr(new Derived<T>(m_value)); }
179
~DerivedDerived180 ~Derived() override {}
181
GetStringDerived182 std::string GetString() override {
183 std::stringstream buffer;
184 buffer << m_value;
185 return buffer.str();
186 }
187
188 T m_value;
189 std::string serialize_cache_;
190 };
191
192 // clone method
clone()193 BasePtr clone() const {
194 if (m_ptr != nullptr) {
195 return m_ptr->clone();
196 }
197 return nullptr;
198 }
199
200 BasePtr m_ptr; // point to real data
201 std::type_index m_tpIndex; // type info of data
202 };
203
// Shared-ownership handle to an Any, e.g. `AnyPtr sp = std::make_shared<Any>(value);`.
using AnyPtr = std::shared_ptr<Any>;
205
206 struct AnyHash {
operatorAnyHash207 std::size_t operator()(const Any &c) const { return c.Hash(); }
208 };
209
210 struct AnyLess {
operatorAnyLess211 bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); }
212 };
213
// Declared here and defined out of line. Judging by the name, this presumably reports
// whether `any` holds a literal (scalar/string-like) value.
// NOTE(review): confirm against the definition.
bool AnyIsLiteral(const Any &any);
215 } // namespace mindspore
216
217 #endif // MINDSPORE_CORE_UTILS_ANY_H_
218