/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tflite {
namespace tools {

template <typename T>
class TypedToolParam;

// Type-erased base class for a single tool parameter.
//
// Every instance carries a type tag fixed at construction, a flag recording
// whether a value has been explicitly set, and a position. The actual value
// is stored by the TypedToolParam<T> subclass.
class ToolParam {
 protected:
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps the C++ type T to its ParamType tag. Only declared here; the
  // definitions (specializations) live outside this header.
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a TypedToolParam<T> holding `default_value` with the given
  // `position`. The new param reports HasValueSet() == false.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value,
                                           int position = 0) {
    auto param = std::make_unique<TypedToolParam<T>>(default_value);
    param->SetPosition(position);
    return std::unique_ptr<ToolParam>(std::move(param));
  }

  // Returns this param downcast to TypedToolParam<T>. The stored type tag is
  // first checked against T via AssertHasSameType (defined outside this
  // header; presumably fails hard on a mismatch — confirm).
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped().
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  // `type` is fixed for the lifetime of the param. The value starts out
  // "not set" and the position starts at 0.
  explicit ToolParam(ParamType type)
      : has_value_set_(false), position_(0), type_(type) {}

  // True once a value has been assigned through a typed Set() call.
  bool HasValueSet() const { return has_value_set_; }

  int GetPosition() const { return position_; }
  void SetPosition(int position) { position_ = position; }

  // Copies state from `other`. The base implementation is a no-op;
  // TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this param.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  // False after construction; set to true by TypedToolParam<T>::Set().
  bool has_value_set_;

  // Represents the relative ordering among a set of params.
  // Note: in our code, a ToolParam is generally used together with a
  // tflite::Flag so that its value could be set when parsing commandline flags.
  // In this case, the `position_` is simply the index of the particular flag
  // into the list of commandline flags (i.e. named 'argv' in general).
  int position_;

 private:
  // Compares two type tags; declared here and defined outside this header.
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};

// Concrete ToolParam that stores a value of type T.
template <typename T>
class TypedToolParam : public ToolParam {
 public:
  // Stores `value` as the default. HasValueSet() stays false until Set().
  explicit TypedToolParam(const T& value)
      : ToolParam(GetValueType<T>()), value_(value) {}

  // Assigns `value` and marks the param as explicitly set.
  void Set(const T& value) {
    value_ = value;
    has_value_set_ = true;
  }

  T Get() const { return value_; }

  // Copies the value and position from `other`, which must hold type T.
  // NOTE: this always marks this param as set, even when `other` itself
  // still holds only its default value.
  void Set(const ToolParam& other) override {
    const TypedToolParam<T>* typed_other = other.AsConstTyped<T>();
    Set(typed_other->Get());
    SetPosition(typed_other->GetPosition());
  }

  // Returns a copy with the same value and position.
  // NOTE(review): the has-value-set flag is not carried over, so the clone
  // reports HasValueSet() == false even if this param was explicitly set.
  // Confirm whether callers rely on this before changing it.
  std::unique_ptr<ToolParam> Clone() const override {
    return ToolParam::Create<T>(value_, position_);
  }

 private:
  T value_;
};

// A map-like container, keyed by name, for holding params of different types.
class ToolParams {
 public:
  // Adds the ToolParam `value` under `name`. Any param already stored under
  // that name is replaced.
  void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
    params_[name] = std::move(value);
  }

  bool HasParam(const std::string& name) const {
    return params_.find(name) != params_.end();
  }

  bool Empty() const { return params_.empty(); }

  // Returns the param stored under `name`, or nullptr if there is none.
  // The returned pointer remains owned by this container.
  const ToolParam* GetParam(const std::string& name) const {
    const auto it = params_.find(name);
    return it == params_.end() ? nullptr : it->second.get();
  }

  // Assigns `value` to the param `name`, marks it as set and records
  // `position`. The param must exist and hold type T.
  template <typename T>
  void Set(const std::string& name, const T& value, int position = 0) {
    AssertParamExists(name);
    TypedToolParam<T>* param = params_.at(name)->AsTyped<T>();
    param->Set(value);
    param->SetPosition(position);
  }

  // Whether the param `name` (which must exist and hold type T) has been
  // explicitly set.
  template <typename T>
  bool HasValueSet(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->HasValueSet();
  }

  // Position of the param `name` (which must exist and hold type T).
  template <typename T>
  int GetPosition(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->GetPosition();
  }

  // Current value of the param `name` (which must exist and hold type T).
  template <typename T>
  T Get(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->Get();
  }

  // Sets the value of every parameter that exists both here and in 'other',
  // taking the value from 'other'.
  void Set(const ToolParams& other);

  // Merge the value of all parameters from 'other'. 'overwrite' indicates
  // whether the value of the same parameter is overwritten or not.
  void Merge(const ToolParams& other, bool overwrite = false);

 private:
  // Checks that a param named `name` exists; defined outside this header.
  void AssertParamExists(const std::string& name) const;

  std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
};

// Logs "<description>: [<value>]" for the param `name` of type `type`. The
// line is emitted only when `verbose` is true or the param has been set
// explicitly. The macro arguments are parenthesized so that expressions such
// as `*params_ptr` or a conditional `description` expand correctly.
// TFLITE_MAY_LOG must be made available by the including file.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)         \
  do {                                                                    \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name))  \
        << (description) << ": [" << (params).Get<type>(name) << "]";    \
  } while (0)

}  // namespace tools
}  // namespace tflite
#endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_