1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include "BackendId.hpp"

#include <cassert>
#include <initializer_list>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
10
11 namespace armnn
12 {
13
// Forward declaration: the aliases below only need the name, the definition follows.
struct BackendOptions;

/// A list of per-backend option groups (see BackendOptions).
using NetworkOptions = std::vector<BackendOptions>;
/// A list of per-backend option groups; same type as NetworkOptions.
using ModelOptions = std::vector<BackendOptions>;
18
19 /// Struct for the users to pass backend specific options
20 struct BackendOptions
21 {
22 private:
23 template<typename T>
24 struct CheckAllowed
25 {
26 static const bool value = std::is_same<T, int>::value ||
27 std::is_same<T, float>::value ||
28 std::is_same<T, bool>::value ||
29 std::is_same<T, std::string>::value ||
30 std::is_same<T, const char*>::value;
31 };
32 public:
33
34 /// Very basic type safe variant
35 class Var
36 {
37
38 public:
39 /// Constructors
Var(int i)40 explicit Var(int i) : m_Vals(i), m_Type(VarTypes::Integer) {};
Var(float f)41 explicit Var(float f) : m_Vals(f), m_Type(VarTypes::Float) {};
Var(bool b)42 explicit Var(bool b) : m_Vals(b), m_Type(VarTypes::Boolean) {};
Var(const char * s)43 explicit Var(const char* s) : m_Vals(s), m_Type(VarTypes::String) {};
Var(std::string s)44 explicit Var(std::string s) : m_Vals(s), m_Type(VarTypes::String) {};
45
46 /// Disallow implicit conversions from types not explicitly allowed below.
47 template<typename DisallowedType>
Var(DisallowedType)48 Var(DisallowedType)
49 {
50 static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for Var<DisallowedType>.");
51 assert(false && "Unreachable code");
52 }
53
54 /// Copy Construct
Var(const Var & other)55 Var(const Var& other)
56 : m_Type(other.m_Type)
57 {
58 switch(m_Type)
59 {
60 case VarTypes::String:
61 {
62 new (&m_Vals.s) std::string(other.m_Vals.s);
63 break;
64 }
65 default:
66 {
67 DoOp(other, [](auto& a, auto& b)
68 {
69 a = b;
70 });
71 break;
72 }
73 }
74 }
75
76 /// Copy operator
operator =(const Var & other)77 Var& operator=(const Var& other)
78 {
79 // Destroy existing string
80 if (m_Type == VarTypes::String)
81 {
82 Destruct(m_Vals.s);
83 }
84
85 m_Type = other.m_Type;
86 switch(m_Type)
87 {
88 case VarTypes::String:
89 {
90
91 new (&m_Vals.s) std::string(other.m_Vals.s);
92 break;
93 }
94 default:
95 {
96 DoOp(other, [](auto& a, auto& b)
97 {
98 a = b;
99 });
100 break;
101 }
102 }
103
104 return *this;
105 };
106
107 /// Type getters
IsBool() const108 bool IsBool() const { return m_Type == VarTypes::Boolean; }
IsInt() const109 bool IsInt() const { return m_Type == VarTypes::Integer; }
IsFloat() const110 bool IsFloat() const { return m_Type == VarTypes::Float; }
IsString() const111 bool IsString() const { return m_Type == VarTypes::String; }
112
113 /// Value getters
AsBool() const114 bool AsBool() const { assert(IsBool()); return m_Vals.b; }
AsInt() const115 int AsInt() const { assert(IsInt()); return m_Vals.i; }
AsFloat() const116 float AsFloat() const { assert(IsFloat()); return m_Vals.f; }
AsString() const117 std::string AsString() const { assert(IsString()); return m_Vals.s; }
118
119 /// Destructor
~Var()120 ~Var()
121 {
122 DoOp(*this, [this](auto& a, auto&)
123 {
124 Destruct(a);
125 });
126 }
127 private:
128 template<typename Func>
DoOp(const Var & other,Func func)129 void DoOp(const Var& other, Func func)
130 {
131 if (other.IsBool())
132 {
133 func(m_Vals.b, other.m_Vals.b);
134 }
135 else if (other.IsInt())
136 {
137 func(m_Vals.i, other.m_Vals.i);
138 }
139 else if (other.IsFloat())
140 {
141 func(m_Vals.f, other.m_Vals.f);
142 }
143 else if (other.IsString())
144 {
145 func(m_Vals.s, other.m_Vals.s);
146 }
147 }
148
149 template<typename Destructable>
Destruct(Destructable & d)150 void Destruct(Destructable& d)
151 {
152 if (std::is_destructible<Destructable>::value)
153 {
154 d.~Destructable();
155 }
156 }
157
158 private:
159 /// Types which can be stored
160 enum class VarTypes
161 {
162 Boolean,
163 Integer,
164 Float,
165 String,
166 };
167
168 /// Union of potential type values.
169 union Vals
170 {
171 int i;
172 float f;
173 bool b;
174 std::string s;
175
Vals()176 Vals(){}
~Vals()177 ~Vals(){}
178
Vals(int i)179 explicit Vals(int i) : i(i) {};
Vals(float f)180 explicit Vals(float f) : f(f) {};
Vals(bool b)181 explicit Vals(bool b) : b(b) {};
Vals(const char * s)182 explicit Vals(const char* s) : s(std::string(s)) {}
Vals(std::string s)183 explicit Vals(std::string s) : s(s) {}
184 };
185
186 Vals m_Vals;
187 VarTypes m_Type;
188 };
189
190 struct BackendOption
191 {
192 public:
BackendOptionarmnn::BackendOptions::BackendOption193 BackendOption(std::string name, bool value)
194 : m_Name(name), m_Value(value)
195 {}
BackendOptionarmnn::BackendOptions::BackendOption196 BackendOption(std::string name, int value)
197 : m_Name(name), m_Value(value)
198 {}
BackendOptionarmnn::BackendOptions::BackendOption199 BackendOption(std::string name, float value)
200 : m_Name(name), m_Value(value)
201 {}
BackendOptionarmnn::BackendOptions::BackendOption202 BackendOption(std::string name, std::string value)
203 : m_Name(name), m_Value(value)
204 {}
BackendOptionarmnn::BackendOptions::BackendOption205 BackendOption(std::string name, const char* value)
206 : m_Name(name), m_Value(value)
207 {}
208
209 template<typename DisallowedType>
BackendOptionarmnn::BackendOptions::BackendOption210 BackendOption(std::string, DisallowedType)
211 : m_Value(0)
212 {
213 static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for BackendOption.");
214 assert(false && "Unreachable code");
215 }
216
217 BackendOption(const BackendOption& other) = default;
218 BackendOption(BackendOption&& other) = default;
219 BackendOption& operator=(const BackendOption& other) = default;
220 BackendOption& operator=(BackendOption&& other) = default;
221 ~BackendOption() = default;
222
GetNamearmnn::BackendOptions::BackendOption223 std::string GetName() const { return m_Name; }
GetValuearmnn::BackendOptions::BackendOption224 Var GetValue() const { return m_Value; }
225
226 private:
227 std::string m_Name; ///< Name of the option
228 Var m_Value; ///< Value of the option. (Bool, int, Float, String)
229 };
230
BackendOptionsarmnn::BackendOptions231 explicit BackendOptions(BackendId backend)
232 : m_TargetBackend(backend)
233 {}
234
BackendOptionsarmnn::BackendOptions235 BackendOptions(BackendId backend, std::initializer_list<BackendOption> options)
236 : m_TargetBackend(backend)
237 , m_Options(options)
238 {}
239
240 BackendOptions(const BackendOptions& other) = default;
241 BackendOptions(BackendOptions&& other) = default;
242 BackendOptions& operator=(const BackendOptions& other) = default;
243 BackendOptions& operator=(BackendOptions&& other) = default;
244
AddOptionarmnn::BackendOptions245 void AddOption(BackendOption&& option)
246 {
247 m_Options.push_back(option);
248 }
249
AddOptionarmnn::BackendOptions250 void AddOption(const BackendOption& option)
251 {
252 m_Options.push_back(option);
253 }
254
GetBackendIdarmnn::BackendOptions255 const BackendId& GetBackendId() const noexcept { return m_TargetBackend; }
GetOptionCountarmnn::BackendOptions256 size_t GetOptionCount() const noexcept { return m_Options.size(); }
GetOptionarmnn::BackendOptions257 const BackendOption& GetOption(size_t idx) const { return m_Options[idx]; }
258
259 private:
260 /// The id for the backend to which the options should be passed.
261 BackendId m_TargetBackend;
262
263 /// The array of options to pass to the backend context
264 std::vector<BackendOption> m_Options;
265 };
266
267
268 template <typename F>
ParseOptions(const std::vector<BackendOptions> & options,BackendId backend,F f)269 void ParseOptions(const std::vector<BackendOptions>& options, BackendId backend, F f)
270 {
271 for (auto optionsGroup : options)
272 {
273 if (optionsGroup.GetBackendId() == backend)
274 {
275 for (size_t i=0; i < optionsGroup.GetOptionCount(); i++)
276 {
277 const BackendOptions::BackendOption option = optionsGroup.GetOption(i);
278 f(option.GetName(), option.GetValue());
279 }
280 }
281 }
282 }
283
284 } //namespace armnn
285