• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_UTILS_ANY_H_
17 #define MINDSPORE_CORE_UTILS_ANY_H_
18 
19 #include <iostream>
20 #include <string>
21 #include <typeinfo>
22 #include <typeindex>
23 #include <memory>
24 #include <functional>
25 #include <sstream>
26 #include <vector>
27 #include <utility>
28 #include "utils/overload.h"
29 #include "utils/log_adapter.h"
30 #include "utils/misc.h"
31 
32 namespace mindspore {
33 // usage:AnyPtr sp = std::make_shared<Any>(aname);
34 template <class T>
type(const T & t)35 std::string type(const T &t) {
36   return demangle(typeid(t).name());
37 }
38 
39 class MS_CORE_API Any {
40  public:
41   // constructors
Any()42   Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {}
Any(const Any & other)43   Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {}
Any(Any && other)44   Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {}
45 
46   Any &operator=(Any &&other);
47   // right reference constructor
48   template <class T, class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, Any>::value, T>::type>
Any(T && t)49   Any(T &&t) : m_tpIndex(typeid(typename std::decay<T>::type)) {  // NOLINT
50     BasePtr new_val = std::make_unique<Derived<typename std::decay<T>::type>>(std::forward<T>(t));
51     std::swap(m_ptr, new_val);
52   }
53 
54   ~Any() = default;
55 
56   // judge whether is empty
empty()57   bool empty() const { return m_ptr == nullptr; }
58 
59   // judge the is relation
60   template <class T>
is()61   bool is() const {
62     return m_tpIndex == std::type_index(typeid(T));
63   }
64 
type()65   const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); }
66 
Hash()67   std::size_t Hash() const {
68     std::stringstream buffer;
69     buffer << m_tpIndex.name();
70     if (m_ptr != nullptr) {
71       buffer << m_ptr->GetString();
72     }
73     return std::hash<std::string>()(buffer.str());
74   }
75 
76   template <typename T>
Apply(const std::function<void (T &)> & fn)77   bool Apply(const std::function<void(T &)> &fn) {
78     if (type() == typeid(T)) {
79       T x = cast<T>();
80       fn(x);
81       return true;
82     }
83     return false;
84   }
85 
GetString()86   std::string GetString() const {
87     if (m_ptr != nullptr) {
88       return m_ptr->GetString();
89     } else {
90       return std::string("");
91     }
92   }
93 
94   friend std::ostream &operator<<(std::ostream &os, const Any &any) {
95     os << any.GetString();
96     return os;
97   }
98 
99   // type cast
100   template <class T>
cast()101   T &cast() const {
102     if (!is<T>() || !m_ptr) {
103       // Use MS_LOGFATAL replace throw std::bad_cast()
104       MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name();
105     }
106     auto ptr = static_cast<Derived<T> *>(m_ptr.get());
107     return ptr->m_value;
108   }
109 
110   bool operator==(const Any &other) const {
111     if (m_tpIndex != other.m_tpIndex) {
112       return false;
113     }
114     if (m_ptr == nullptr && other.m_ptr == nullptr) {
115       return true;
116     }
117     if (m_ptr == nullptr || other.m_ptr == nullptr) {
118       return false;
119     }
120     return *m_ptr == *other.m_ptr;
121   }
122 
123   bool operator!=(const Any &other) const { return !(operator==(other)); }
124 
125   Any &operator=(const Any &other);
126 
127   bool operator<(const Any &other) const;
128 
ToString()129   std::string ToString() const {
130     std::ostringstream buffer;
131     if (m_tpIndex == typeid(float)) {
132       buffer << "<float> " << cast<float>();
133     } else if (m_tpIndex == typeid(double)) {
134       buffer << "<double> " << cast<double>();
135     } else if (m_tpIndex == typeid(int)) {
136       buffer << "<int> " << cast<int>();
137     } else if (m_tpIndex == typeid(bool)) {
138       buffer << "<bool> " << cast<bool>();
139     } else if (m_ptr != nullptr) {
140       buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString();
141     }
142     return buffer.str();
143   }
144 #ifdef _MSC_VER
dump()145   void dump() const { std::cout << ToString() << std::endl; }
146 #else
dump()147   __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; }
148 #endif
149 
150  private:
151   struct Base;
152   using BasePtr = std::unique_ptr<Base>;
153 
154   // type base definition
155   struct Base {
156     virtual const std::type_info &type() const = 0;
157     virtual BasePtr clone() const = 0;
158     virtual ~Base() = default;
159     virtual bool operator==(const Base &other) const = 0;
160     virtual std::string GetString() = 0;
161   };
162 
163   template <typename T>
164   struct Derived : public Base {
165     template <typename... Args>
DerivedDerived166     explicit Derived(Args &&... args) : m_value(std::forward<Args>(args)...), serialize_cache_("") {}
167 
168     bool operator==(const Base &other) const override {
169       if (typeid(*this) != typeid(other)) {
170         return false;
171       }
172       return m_value == static_cast<const Derived<T> &>(other).m_value;
173     }
174 
typeDerived175     const std::type_info &type() const override { return typeid(T); }
176 
cloneDerived177     BasePtr clone() const override { return std::make_unique<Derived<T>>(m_value); }
178 
~DerivedDerived179     ~Derived() override {}
180 
GetStringDerived181     std::string GetString() override {
182       std::stringstream buffer;
183       buffer << m_value;
184       return buffer.str();
185     }
186 
187     T m_value;
188     std::string serialize_cache_;
189   };
190 
191   // clone method
clone()192   BasePtr clone() const {
193     if (m_ptr != nullptr) {
194       return m_ptr->clone();
195     }
196     return nullptr;
197   }
198 
199   BasePtr m_ptr;              // point to real data
200   std::type_index m_tpIndex;  // type info of data
201 };
202 
203 using AnyPtr = std::shared_ptr<Any>;
204 
205 struct AnyHash {
operatorAnyHash206   std::size_t operator()(const Any &c) const { return c.Hash(); }
207 };
208 
209 struct AnyLess {
operatorAnyLess210   bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); }
211 };
212 
213 bool AnyIsLiteral(const Any &any);
214 }  // namespace mindspore
215 
216 #endif  // MINDSPORE_CORE_UTILS_ANY_H_
217