• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
17 #define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
18 #include <memory>
19 #include <string>
20 #include <unordered_map>
21 #include <utility>
22 #include <vector>
23 
24 namespace tflite {
25 namespace tools {
26 
template <typename T>
class TypedToolParam;

// Type-erased base for a single tool parameter. The concrete value lives in a
// TypedToolParam<T>; callers get back to the typed view with AsTyped<T>() or
// AsConstTyped<T>(), both of which check the requested type against the type
// this param was constructed with before downcasting.
class ToolParam {
 protected:
  // The set of value types a ToolParam can carry.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps a C++ value type to its ParamType tag. Only declared here; the
  // definitions are not part of this header.
  template <typename T>
  static ParamType GetValueType();

 public:
  // Factory: builds a TypedToolParam<T> holding `default_value` and placed at
  // `position`. The new param reports HasValueSet() == false.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value,
                                           int position = 0) {
    std::unique_ptr<TypedToolParam<T>> typed(
        new TypedToolParam<T>(default_value));
    typed->SetPosition(position);
    return std::unique_ptr<ToolParam>(typed.release());
  }

  // Mutable typed view of this param. AssertHasSameType is invoked with the
  // tag for T and the stored tag before the static downcast.
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    const ParamType requested = GetValueType<T>();
    AssertHasSameType(requested, type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Read-only counterpart of AsTyped<T>(), with the same type check.
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    const ParamType requested = GetValueType<T>();
    AssertHasSameType(requested, type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  explicit ToolParam(ParamType type) : type_(type) {}

  // True once a value has been explicitly stored; starts out false.
  bool HasValueSet() const { return has_value_set_; }

  int GetPosition() const { return position_; }
  void SetPosition(int position) { position_ = position; }

  // Copies state from another param. The base version intentionally does
  // nothing; TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this param.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_ = false;

  // Represents the relative ordering among a set of params.
  // Note: in our code, a ToolParam is generally used together with a
  // tflite::Flag so that its value could be set when parsing commandline flags.
  // In this case, the `position_` is simply the index of the particular flag
  // into the list of commandline flags (i.e. named 'argv' in general).
  int position_ = 0;

 private:
  // Checks that two type tags match; defined out of line (behavior on a
  // mismatch is not visible in this header).
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};
85 
86 template <typename T>
87 class TypedToolParam : public ToolParam {
88  public:
TypedToolParam(const T & value)89   explicit TypedToolParam(const T& value)
90       : ToolParam(GetValueType<T>()), value_(value) {}
91 
Set(const T & value)92   void Set(const T& value) {
93     value_ = value;
94     has_value_set_ = true;
95   }
96 
Get()97   T Get() const { return value_; }
98 
Set(const ToolParam & other)99   void Set(const ToolParam& other) override {
100     Set(other.AsConstTyped<T>()->Get());
101     SetPosition(other.AsConstTyped<T>()->GetPosition());
102   }
103 
Clone()104   std::unique_ptr<ToolParam> Clone() const override {
105     return ToolParam::Create<T>(value_, position_);
106   }
107 
108  private:
109   T value_;
110 };
111 
112 // A map-like container for holding values of different types.
113 class ToolParams {
114  public:
115   // Add a ToolParam instance `value` w/ `name` to this container.
AddParam(const std::string & name,std::unique_ptr<ToolParam> value)116   void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
117     params_[name] = std::move(value);
118   }
119 
RemoveParam(const std::string & name)120   void RemoveParam(const std::string& name) { params_.erase(name); }
121 
HasParam(const std::string & name)122   bool HasParam(const std::string& name) const {
123     return params_.find(name) != params_.end();
124   }
125 
Empty()126   bool Empty() const { return params_.empty(); }
127 
GetParam(const std::string & name)128   const ToolParam* GetParam(const std::string& name) const {
129     const auto& entry = params_.find(name);
130     if (entry == params_.end()) return nullptr;
131     return entry->second.get();
132   }
133 
134   template <typename T>
135   void Set(const std::string& name, const T& value, int position = 0) {
136     AssertParamExists(name);
137     params_.at(name)->AsTyped<T>()->Set(value);
138     params_.at(name)->AsTyped<T>()->SetPosition(position);
139   }
140 
141   template <typename T>
HasValueSet(const std::string & name)142   bool HasValueSet(const std::string& name) const {
143     AssertParamExists(name);
144     return params_.at(name)->AsConstTyped<T>()->HasValueSet();
145   }
146 
147   template <typename T>
GetPosition(const std::string & name)148   int GetPosition(const std::string& name) const {
149     AssertParamExists(name);
150     return params_.at(name)->AsConstTyped<T>()->GetPosition();
151   }
152 
153   template <typename T>
Get(const std::string & name)154   T Get(const std::string& name) const {
155     AssertParamExists(name);
156     return params_.at(name)->AsConstTyped<T>()->Get();
157   }
158 
159   // Set the value of all same parameters from 'other'.
160   void Set(const ToolParams& other);
161 
162   // Merge the value of all parameters from 'other'. 'overwrite' indicates
163   // whether the value of the same paratmeter is overwritten or not.
164   void Merge(const ToolParams& other, bool overwrite = false);
165 
166  private:
167   void AssertParamExists(const std::string& name) const;
168   std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
169 };
170 
// Logs "<description>: [<value>]" for the param `name` (of C++ type `type`)
// held in `params`, passing `(verbose) || params.HasValueSet<type>(name)` as
// the condition to TFLITE_MAY_LOG (defined elsewhere; presumably it only emits
// when the condition is true — verify). `params` is parenthesized so that
// expression arguments such as `*params_ptr` expand correctly.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)        \
  do {                                                                  \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name)) \
        << description << ": [" << (params).Get<type>(name) << "]";     \
  } while (0)
176 
177 }  // namespace tools
178 }  // namespace tflite
179 #endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
180