• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "BackendId.hpp"

#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
10 
11 namespace armnn
12 {
13 
// Only a declaration is needed here: the aliases below merely name
// std::vector<BackendOptions> and do not require the struct to be complete.
struct BackendOptions;

/// Alias for a list of per-backend option groups (presumably options applied to a whole network).
using NetworkOptions = std::vector<BackendOptions>;
/// Alias for a list of per-backend option groups (presumably options applied to a model).
using ModelOptions = std::vector<BackendOptions>;
18 
19 /// Struct for the users to pass backend specific options
20 struct BackendOptions
21 {
22 private:
23     template<typename T>
24     struct CheckAllowed
25     {
26         static const bool value = std::is_same<T, int>::value ||
27                                   std::is_same<T, float>::value ||
28                                   std::is_same<T, bool>::value ||
29                                   std::is_same<T, std::string>::value ||
30                                   std::is_same<T, const char*>::value;
31     };
32 public:
33 
34     /// Very basic type safe variant
35     class Var
36     {
37 
38     public:
39         /// Constructors
Var(int i)40         explicit Var(int i) : m_Vals(i), m_Type(VarTypes::Integer) {};
Var(float f)41         explicit Var(float f) : m_Vals(f), m_Type(VarTypes::Float) {};
Var(bool b)42         explicit Var(bool b) : m_Vals(b), m_Type(VarTypes::Boolean) {};
Var(const char * s)43         explicit Var(const char* s) : m_Vals(s), m_Type(VarTypes::String) {};
Var(std::string s)44         explicit Var(std::string s) : m_Vals(s), m_Type(VarTypes::String) {};
45 
46         /// Disallow implicit conversions from types not explicitly allowed below.
47         template<typename DisallowedType>
Var(DisallowedType)48         Var(DisallowedType)
49         {
50             static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for Var<DisallowedType>.");
51             assert(false && "Unreachable code");
52         }
53 
54         /// Copy Construct
Var(const Var & other)55         Var(const Var& other)
56             : m_Type(other.m_Type)
57         {
58             switch(m_Type)
59             {
60                 case VarTypes::String:
61                 {
62                     new (&m_Vals.s) std::string(other.m_Vals.s);
63                     break;
64                 }
65                 default:
66                 {
67                     DoOp(other, [](auto& a, auto& b)
68                         {
69                             a = b;
70                         });
71                     break;
72                 }
73             }
74         }
75 
76         /// Copy operator
operator =(const Var & other)77         Var& operator=(const Var& other)
78         {
79             // Destroy existing string
80             if (m_Type == VarTypes::String)
81             {
82                 Destruct(m_Vals.s);
83             }
84 
85             m_Type = other.m_Type;
86             switch(m_Type)
87             {
88                 case VarTypes::String:
89                 {
90 
91                     new (&m_Vals.s) std::string(other.m_Vals.s);
92                     break;
93                 }
94                 default:
95                 {
96                     DoOp(other, [](auto& a, auto& b)
97                         {
98                             a = b;
99                         });
100                     break;
101                 }
102             }
103 
104             return *this;
105         };
106 
107         /// Type getters
IsBool() const108         bool IsBool() const { return m_Type == VarTypes::Boolean; }
IsInt() const109         bool IsInt() const { return m_Type == VarTypes::Integer; }
IsFloat() const110         bool IsFloat() const { return m_Type == VarTypes::Float; }
IsString() const111         bool IsString() const { return m_Type == VarTypes::String; }
112 
113         /// Value getters
AsBool() const114         bool AsBool() const { assert(IsBool()); return m_Vals.b; }
AsInt() const115         int AsInt() const { assert(IsInt()); return m_Vals.i; }
AsFloat() const116         float AsFloat() const { assert(IsFloat()); return m_Vals.f; }
AsString() const117         std::string AsString() const { assert(IsString()); return m_Vals.s; }
118 
119         /// Destructor
~Var()120         ~Var()
121         {
122             DoOp(*this, [this](auto& a, auto&)
123                 {
124                     Destruct(a);
125                 });
126         }
127     private:
128         template<typename Func>
DoOp(const Var & other,Func func)129         void DoOp(const Var& other, Func func)
130         {
131             if (other.IsBool())
132             {
133                 func(m_Vals.b, other.m_Vals.b);
134             }
135             else if (other.IsInt())
136             {
137                 func(m_Vals.i, other.m_Vals.i);
138             }
139             else if (other.IsFloat())
140             {
141                 func(m_Vals.f, other.m_Vals.f);
142             }
143             else if (other.IsString())
144             {
145                 func(m_Vals.s, other.m_Vals.s);
146             }
147         }
148 
149         template<typename Destructable>
Destruct(Destructable & d)150         void Destruct(Destructable& d)
151         {
152             if (std::is_destructible<Destructable>::value)
153             {
154                 d.~Destructable();
155             }
156         }
157 
158     private:
159         /// Types which can be stored
160         enum class VarTypes
161         {
162             Boolean,
163             Integer,
164             Float,
165             String,
166         };
167 
168         /// Union of potential type values.
169         union Vals
170         {
171             int i;
172             float f;
173             bool b;
174             std::string s;
175 
Vals()176             Vals(){}
~Vals()177             ~Vals(){}
178 
Vals(int i)179             explicit Vals(int i) : i(i) {};
Vals(float f)180             explicit Vals(float f) : f(f) {};
Vals(bool b)181             explicit Vals(bool b) : b(b) {};
Vals(const char * s)182             explicit Vals(const char* s) : s(std::string(s)) {}
Vals(std::string s)183             explicit Vals(std::string s) : s(s) {}
184        };
185 
186         Vals m_Vals;
187         VarTypes m_Type;
188     };
189 
190     struct BackendOption
191     {
192     public:
BackendOptionarmnn::BackendOptions::BackendOption193         BackendOption(std::string name, bool value)
194             : m_Name(name), m_Value(value)
195         {}
BackendOptionarmnn::BackendOptions::BackendOption196         BackendOption(std::string name, int value)
197             : m_Name(name), m_Value(value)
198         {}
BackendOptionarmnn::BackendOptions::BackendOption199         BackendOption(std::string name, float value)
200             : m_Name(name), m_Value(value)
201         {}
BackendOptionarmnn::BackendOptions::BackendOption202         BackendOption(std::string name, std::string value)
203             : m_Name(name), m_Value(value)
204         {}
BackendOptionarmnn::BackendOptions::BackendOption205         BackendOption(std::string name, const char* value)
206             : m_Name(name), m_Value(value)
207         {}
208 
209         template<typename DisallowedType>
BackendOptionarmnn::BackendOptions::BackendOption210         BackendOption(std::string, DisallowedType)
211             : m_Value(0)
212         {
213             static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for BackendOption.");
214             assert(false && "Unreachable code");
215         }
216 
217         BackendOption(const BackendOption& other) = default;
218         BackendOption(BackendOption&& other) = default;
219         BackendOption& operator=(const BackendOption& other) = default;
220         BackendOption& operator=(BackendOption&& other) = default;
221         ~BackendOption() = default;
222 
GetNamearmnn::BackendOptions::BackendOption223         std::string GetName() const   { return m_Name; }
GetValuearmnn::BackendOptions::BackendOption224         Var GetValue() const          { return m_Value; }
225 
226     private:
227         std::string m_Name;         ///< Name of the option
228         Var         m_Value;        ///< Value of the option. (Bool, int, Float, String)
229     };
230 
BackendOptionsarmnn::BackendOptions231     explicit BackendOptions(BackendId backend)
232         : m_TargetBackend(backend)
233     {}
234 
BackendOptionsarmnn::BackendOptions235     BackendOptions(BackendId backend, std::initializer_list<BackendOption> options)
236         : m_TargetBackend(backend)
237         , m_Options(options)
238     {}
239 
240     BackendOptions(const BackendOptions& other) = default;
241     BackendOptions(BackendOptions&& other) = default;
242     BackendOptions& operator=(const BackendOptions& other) = default;
243     BackendOptions& operator=(BackendOptions&& other) = default;
244 
AddOptionarmnn::BackendOptions245     void AddOption(BackendOption&& option)
246     {
247         m_Options.push_back(option);
248     }
249 
AddOptionarmnn::BackendOptions250     void AddOption(const BackendOption& option)
251     {
252         m_Options.push_back(option);
253     }
254 
GetBackendIdarmnn::BackendOptions255     const BackendId& GetBackendId() const noexcept { return m_TargetBackend; }
GetOptionCountarmnn::BackendOptions256     size_t GetOptionCount() const noexcept { return m_Options.size(); }
GetOptionarmnn::BackendOptions257     const BackendOption& GetOption(size_t idx) const { return m_Options[idx]; }
258 
259 private:
260     /// The id for the backend to which the options should be passed.
261     BackendId m_TargetBackend;
262 
263     /// The array of options to pass to the backend context
264     std::vector<BackendOption> m_Options;
265 };
266 
267 
268 template <typename F>
ParseOptions(const std::vector<BackendOptions> & options,BackendId backend,F f)269 void ParseOptions(const std::vector<BackendOptions>& options, BackendId backend, F f)
270 {
271     for (auto optionsGroup : options)
272     {
273         if (optionsGroup.GetBackendId() == backend)
274         {
275             for (size_t i=0; i < optionsGroup.GetOptionCount(); i++)
276             {
277                 const BackendOptions::BackendOption option = optionsGroup.GetOption(i);
278                 f(option.GetName(), option.GetValue());
279             }
280         }
281     }
282 }
283 
284 } //namespace armnn
285