• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CORE_UTILS_ANY_H_
17 #define MINDSPORE_CORE_UTILS_ANY_H_
18 
19 #include <iostream>
20 #include <string>
21 #include <typeinfo>
22 #include <typeindex>
23 #include <memory>
24 #include <functional>
25 #include <sstream>
26 #include <vector>
27 #include <utility>
28 
29 #include "utils/overload.h"
30 #include "utils/log_adapter.h"
31 #include "utils/misc.h"
32 
33 namespace mindspore {
// Usage: AnyPtr sp = std::make_shared<Any>(aname);
35 template <class T>
type(const T & t)36 std::string type(const T &t) {
37   return demangle(typeid(t).name());
38 }
39 
40 class Any {
41  public:
42   // constructors
Any()43   Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {}
Any(const Any & other)44   Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {}
Any(Any && other)45   Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {}
46 
47   Any &operator=(Any &&other);
48   // right reference constructor
49   template <class T, class = typename std::enable_if<!std::is_same<typename std::decay<T>::type, Any>::value, T>::type>
Any(T && t)50   Any(T &&t) : m_tpIndex(typeid(typename std::decay<T>::type)) {  // NOLINT
51     BasePtr new_val(new Derived<typename std::decay<T>::type>(std::forward<T>(t)));
52     std::swap(m_ptr, new_val);
53   }
54 
55   ~Any() = default;
56 
57   // judge whether is empty
empty()58   bool empty() const { return m_ptr == nullptr; }
59 
60   // judge the is relation
61   template <class T>
is()62   bool is() const {
63     return m_tpIndex == std::type_index(typeid(T));
64   }
65 
type()66   const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); }
67 
Hash()68   std::size_t Hash() const {
69     std::stringstream buffer;
70     buffer << m_tpIndex.name();
71     if (m_ptr != nullptr) {
72       buffer << m_ptr->GetString();
73     }
74     return std::hash<std::string>()(buffer.str());
75   }
76 
77   template <typename T>
Apply(const std::function<void (T &)> & fn)78   bool Apply(const std::function<void(T &)> &fn) {
79     if (type() == typeid(T)) {
80       T x = cast<T>();
81       fn(x);
82       return true;
83     }
84     return false;
85   }
86 
GetString()87   std::string GetString() const {
88     if (m_ptr != nullptr) {
89       return m_ptr->GetString();
90     } else {
91       return std::string("");
92     }
93   }
94 
95   friend std::ostream &operator<<(std::ostream &os, const Any &any) {
96     os << any.GetString();
97     return os;
98   }
99 
100   // type cast
101   template <class T>
cast()102   T &cast() const {
103     if (!is<T>() || !m_ptr) {
104       // Use MS_LOGFATAL replace throw std::bad_cast()
105       MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name();
106     }
107     auto ptr = static_cast<Derived<T> *>(m_ptr.get());
108     return ptr->m_value;
109   }
110 
111   bool operator==(const Any &other) const {
112     if (m_tpIndex != other.m_tpIndex) {
113       return false;
114     }
115     if (m_ptr == nullptr && other.m_ptr == nullptr) {
116       return true;
117     }
118     if (m_ptr == nullptr || other.m_ptr == nullptr) {
119       return false;
120     }
121     return *m_ptr == *other.m_ptr;
122   }
123 
124   bool operator!=(const Any &other) const { return !(operator==(other)); }
125 
126   Any &operator=(const Any &other);
127 
128   bool operator<(const Any &other) const;
129 
ToString()130   std::string ToString() const {
131     std::ostringstream buffer;
132     if (m_tpIndex == typeid(float)) {
133       buffer << "<float> " << cast<float>();
134     } else if (m_tpIndex == typeid(double)) {
135       buffer << "<double> " << cast<double>();
136     } else if (m_tpIndex == typeid(int)) {
137       buffer << "<int> " << cast<int>();
138     } else if (m_tpIndex == typeid(bool)) {
139       buffer << "<bool> " << cast<bool>();
140     } else if (m_ptr != nullptr) {
141       buffer << "<" << demangle(m_tpIndex.name()) << "> " << m_ptr->GetString();
142     }
143     return buffer.str();
144   }
145 #ifdef _MSC_VER
dump()146   void dump() const { std::cout << ToString() << std::endl; }
147 #else
dump()148   __attribute__((used)) void dump() const { std::cout << ToString() << std::endl; }
149 #endif
150 
151  private:
152   struct Base;
153   using BasePtr = std::unique_ptr<Base>;
154 
155   // type base definition
156   struct Base {
157     virtual const std::type_info &type() const = 0;
158     virtual BasePtr clone() const = 0;
159     virtual ~Base() = default;
160     virtual bool operator==(const Base &other) const = 0;
161     virtual std::string GetString() = 0;
162   };
163 
164   template <typename T>
165   struct Derived : public Base {
166     template <typename... Args>
DerivedDerived167     explicit Derived(Args &&... args) : m_value(std::forward<Args>(args)...), serialize_cache_("") {}
168 
169     bool operator==(const Base &other) const override {
170       if (typeid(*this) != typeid(other)) {
171         return false;
172       }
173       return m_value == static_cast<const Derived<T> &>(other).m_value;
174     }
175 
typeDerived176     const std::type_info &type() const override { return typeid(T); }
177 
cloneDerived178     BasePtr clone() const override { return BasePtr(new Derived<T>(m_value)); }
179 
~DerivedDerived180     ~Derived() override {}
181 
GetStringDerived182     std::string GetString() override {
183       std::stringstream buffer;
184       buffer << m_value;
185       return buffer.str();
186     }
187 
188     T m_value;
189     std::string serialize_cache_;
190   };
191 
192   // clone method
clone()193   BasePtr clone() const {
194     if (m_ptr != nullptr) {
195       return m_ptr->clone();
196     }
197     return nullptr;
198   }
199 
200   BasePtr m_ptr;              // point to real data
201   std::type_index m_tpIndex;  // type info of data
202 };
203 
// Shared-ownership pointer to an Any (alias for std::shared_ptr<Any>).
204 using AnyPtr = std::shared_ptr<Any>;
205 
206 struct AnyHash {
operatorAnyHash207   std::size_t operator()(const Any &c) const { return c.Hash(); }
208 };
209 
210 struct AnyLess {
operatorAnyLess211   bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); }
212 };
213 
// Predicate on an Any; only declared in this header, no definition is visible here.
// NOTE(review): presumably reports whether the held value is of a literal type -- confirm against
// the definition.
214 bool AnyIsLiteral(const Any &any);
215 }  // namespace mindspore
216 
217 #endif  // MINDSPORE_CORE_UTILS_ANY_H_
218