• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
17 #define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 namespace tflite {
25 namespace tools {
26 
template <typename T>
class TypedToolParam;

// Type-erased base class for a single tool parameter. The concrete value lives
// in a TypedToolParam<T>; callers recover the typed view with AsTyped<T>() or
// AsConstTyped<T>(), which check the requested type against the type recorded
// at construction before downcasting.
class ToolParam {
 protected:
  // The value types a ToolParam may hold.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps a C++ type to its ParamType. Only declared here; the definitions are
  // provided out of line (NOTE(review): presumably in tool_params.cc).
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a TypedToolParam<T> holding `default_value`, with its position set
  // to `position`. Only the value and position are set by this factory.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value,
                                           int position = 0) {
    // Take ownership immediately so no raw owning pointer is left dangling
    // across the SetPosition() call.
    std::unique_ptr<TypedToolParam<T>> param(
        new TypedToolParam<T>(default_value));
    param->SetPosition(position);
    return std::unique_ptr<ToolParam>(std::move(param));
  }

  // Returns this param viewed as a TypedToolParam<T>. The requested type is
  // checked against the stored type by AssertHasSameType (defined out of line)
  // before the static downcast.
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped<T>(); performs the same type check.
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  // `type` records which value type the derived class holds.
  explicit ToolParam(ParamType type) : type_(type) {}

  // Whether a value has been explicitly assigned (false after construction).
  bool HasValueSet() const { return has_value_set_; }

  int GetPosition() const { return position_; }
  void SetPosition(int position) { position_ = position; }

  // Copies state from `other`. The base implementation does nothing; derived
  // classes override it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this param, owned by the caller.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_ = false;

  // Represents the relative ordering among a set of params.
  // Note: in our code, a ToolParam is generally used together with a
  // tflite::Flag so that its value could be set when parsing commandline flags.
  // In this case, the `position_` is simply the index of the particular flag
  // into the list of commandline flags (i.e. named 'argv' in general).
  int position_ = 0;

 private:
  // Defined out of line. NOTE(review): presumably aborts on a type mismatch;
  // confirm in tool_params.cc.
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};
85 
86 template <typename T>
87 class TypedToolParam : public ToolParam {
88  public:
TypedToolParam(const T & value)89   explicit TypedToolParam(const T& value)
90       : ToolParam(GetValueType<T>()), value_(value) {}
91 
Set(const T & value)92   void Set(const T& value) {
93     value_ = value;
94     has_value_set_ = true;
95   }
96 
Get()97   T Get() const { return value_; }
98 
Set(const ToolParam & other)99   void Set(const ToolParam& other) override {
100     Set(other.AsConstTyped<T>()->Get());
101     SetPosition(other.AsConstTyped<T>()->GetPosition());
102   }
103 
Clone()104   std::unique_ptr<ToolParam> Clone() const override {
105     return ToolParam::Create<T>(value_, position_);
106   }
107 
108  private:
109   T value_;
110 };
111 
112 // A map-like container for holding values of different types.
113 class ToolParams {
114  public:
115   // Add a ToolParam instance `value` w/ `name` to this container.
AddParam(const std::string & name,std::unique_ptr<ToolParam> value)116   void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
117     params_[name] = std::move(value);
118   }
119 
HasParam(const std::string & name)120   bool HasParam(const std::string& name) const {
121     return params_.find(name) != params_.end();
122   }
123 
Empty()124   bool Empty() const { return params_.empty(); }
125 
GetParam(const std::string & name)126   const ToolParam* GetParam(const std::string& name) const {
127     const auto& entry = params_.find(name);
128     if (entry == params_.end()) return nullptr;
129     return entry->second.get();
130   }
131 
132   template <typename T>
133   void Set(const std::string& name, const T& value, int position = 0) {
134     AssertParamExists(name);
135     params_.at(name)->AsTyped<T>()->Set(value);
136     params_.at(name)->AsTyped<T>()->SetPosition(position);
137   }
138 
139   template <typename T>
HasValueSet(const std::string & name)140   bool HasValueSet(const std::string& name) const {
141     AssertParamExists(name);
142     return params_.at(name)->AsConstTyped<T>()->HasValueSet();
143   }
144 
145   template <typename T>
GetPosition(const std::string & name)146   int GetPosition(const std::string& name) const {
147     AssertParamExists(name);
148     return params_.at(name)->AsConstTyped<T>()->GetPosition();
149   }
150 
151   template <typename T>
Get(const std::string & name)152   T Get(const std::string& name) const {
153     AssertParamExists(name);
154     return params_.at(name)->AsConstTyped<T>()->Get();
155   }
156 
157   // Set the value of all same parameters from 'other'.
158   void Set(const ToolParams& other);
159 
160   // Merge the value of all parameters from 'other'. 'overwrite' indicates
161   // whether the value of the same paratmeter is overwritten or not.
162   void Merge(const ToolParams& other, bool overwrite = false);
163 
164  private:
165   void AssertParamExists(const std::string& name) const;
166   std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
167 };
168 
// Logs "<description>: [<value>]" for the param `name` of type `type` held in
// the ToolParams object `params`. The line goes through TFLITE_MAY_LOG at INFO
// severity, conditioned on `verbose || params.HasValueSet<type>(name)`.
// TFLITE_MAY_LOG is not defined in this header; the file expanding this macro
// must make it available (NOTE(review): presumably via
// tensorflow/lite/tools/logging.h).
// `params` and `description` are parenthesized so arguments such as
// `*params_ptr` or a conditional expression expand correctly.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)         \
  do {                                                                   \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name))  \
        << (description) << ": [" << (params).Get<type>(name) << "]";    \
  } while (0)
174 
175 }  // namespace tools
176 }  // namespace tflite
177 #endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
178