/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tflite {
namespace tools {

template <typename T>
class TypedToolParam;

// Type-erased base for a single tool parameter. The concrete value lives in a
// TypedToolParam<T>; use AsTyped<T>() / AsConstTyped<T>() to reach it.
class ToolParam {
 protected:
  // The value types a ToolParam can be tagged with.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps a C++ type to its ParamType tag. Only declared here; the
  // specializations are defined out of line (presumably in the matching .cc
  // file -- check there for the exact set of supported T).
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a TypedToolParam<T> holding `default_value` with the given
  // `position`. The returned param reports HasValueSet() == false.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value,
                                           int position = 0) {
    auto param = std::make_unique<TypedToolParam<T>>(default_value);
    param->SetPosition(position);
    return param;
  }

  // Downcasts to the concrete TypedToolParam<T>. AssertHasSameType() (defined
  // out of line; presumably fails hard on mismatch) is called with the tag of
  // T and the stored tag before the static_cast is performed.
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped<T>(); performs the same type check.
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  explicit ToolParam(ParamType type)
      : has_value_set_(false), position_(0), type_(type) {}

  // True once a value has been assigned through TypedToolParam<T>::Set();
  // the default value passed at construction does not count.
  bool HasValueSet() const { return has_value_set_; }

  int GetPosition() const { return position_; }
  void SetPosition(int position) { position_ = position; }

  // Copies the state of `other` into this param. The base implementation is a
  // no-op; TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this param.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_;

  // Represents the relative ordering among a set of params.
  // Note: in our code, a ToolParam is generally used together with a
  // tflite::Flag so that its value could be set when parsing commandline flags.
  // In this case, the `position_` is simply the index of the particular flag
  // into the list of commandline flags (i.e. named 'argv' in general).
  int position_;

 private:
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};

// Concrete ToolParam holding a value of type T.
template <typename T>
class TypedToolParam : public ToolParam {
 public:
  explicit TypedToolParam(const T& value)
      : ToolParam(GetValueType<T>()), value_(value) {}

  // Stores `value` and marks the param as explicitly set.
  void Set(const T& value) {
    value_ = value;
    has_value_set_ = true;
  }

  T Get() const { return value_; }

  // Copies the value and position of `other`, which must also hold a T.
  // NOTE: this always marks this param as set, even when `other` only holds
  // its default value.
  void Set(const ToolParam& other) override {
    const TypedToolParam<T>* typed_other = other.AsConstTyped<T>();
    Set(typed_other->Get());
    SetPosition(typed_other->GetPosition());
  }

  // NOTE: the clone is built via Create(), so it keeps the value and position
  // but reports HasValueSet() == false regardless of this param's state.
  std::unique_ptr<ToolParam> Clone() const override {
    return ToolParam::Create<T>(value_, position_);
  }

 private:
  T value_;
};

// A map-like container for holding values of different types, keyed by name.
class ToolParams {
 public:
  // Add a ToolParam instance `value` w/ `name` to this container. An existing
  // entry with the same name is replaced.
  void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
    params_[name] = std::move(value);
  }

  // Removes the entry for `name`; does nothing if there is none.
  void RemoveParam(const std::string& name) { params_.erase(name); }

  bool HasParam(const std::string& name) const {
    return params_.find(name) != params_.end();
  }

  bool Empty() const { return params_.empty(); }

  // Returns the param registered under `name`, or nullptr if there is none.
  // The returned pointer is owned by this container.
  const ToolParam* GetParam(const std::string& name) const {
    const auto it = params_.find(name);
    return it == params_.end() ? nullptr : it->second.get();
  }

  // Sets the value and position of the existing param `name` (which must hold
  // a T) and marks it as explicitly set.
  template <typename T>
  void Set(const std::string& name, const T& value, int position = 0) {
    AssertParamExists(name);
    TypedToolParam<T>* param = params_.at(name)->AsTyped<T>();
    param->Set(value);
    param->SetPosition(position);
  }

  // Whether the param `name` (which must hold a T) has been explicitly set.
  template <typename T>
  bool HasValueSet(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->HasValueSet();
  }

  // Position of the param `name`, which must hold a T.
  template <typename T>
  int GetPosition(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->GetPosition();
  }

  // Current value of the param `name`, which must hold a T.
  template <typename T>
  T Get(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->Get();
  }

  // For every parameter present in both containers, set its value from
  // 'other'.
  void Set(const ToolParams& other);

  // Merge the values of all parameters from 'other'. 'overwrite' indicates
  // whether the value of a parameter present in both containers is
  // overwritten.
  void Merge(const ToolParams& other, bool overwrite = false);

 private:
  // Defined out of line; presumably fails hard when `name` is not registered.
  void AssertParamExists(const std::string& name) const;
  std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
};

// Logs "<description>: [<value>]" for the param `name` (of type `type`) held
// in `params`, when `verbose` is true or the param has been explicitly set.
// TFLITE_MAY_LOG is not defined by this header; the including file must make
// it available (NOTE(review): presumably tensorflow/lite/tools/logging.h --
// confirm). Arguments are parenthesized so expressions such as `*p` or a
// conditional `description` expand correctly.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)          \
  do {                                                                     \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name))   \
        << (description) << ": [" << (params).Get<type>(name) << "]";     \
  } while (0)

}  // namespace tools
}  // namespace tflite
#endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_