/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tflite {
namespace tools {

template <typename T>
class TypedToolParam;

// Type-erased base class for a single parameter value. The concrete value
// lives in TypedToolParam<T>; this base records which ParamType the value has
// and whether it has been explicitly set since construction.
class ToolParam {
 protected:
  // Tag identifying the value type stored by the derived TypedToolParam.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps a C++ type to its ParamType tag. Only declared here; the definitions
  // are not part of this header.
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a heap-allocated parameter of type T holding 'default_value'.
  // The returned parameter reports HasValueSet() == false.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value) {
    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(default_value));
  }

  // Returns this object viewed as a TypedToolParam<T>. The stored type tag is
  // first passed to AssertHasSameType() together with the tag for T.
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped().
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  // Constructs a parameter of the given type whose value is not yet marked as
  // set.
  explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}

  // Returns true once a value has been explicitly assigned through Set().
  bool HasValueSet() const { return has_value_set_; }

  // Copies the value from another parameter. The base implementation does
  // nothing; TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this parameter.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_;

 private:
  // Checks that the two type tags agree. Defined out of line, so the failure
  // behavior is not visible from this header.
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};

// A ToolParam that stores a concrete value of type T.
template <typename T>
class TypedToolParam : public ToolParam {
 public:
  // Stores 'value' as the initial (default) value; this does not count as the
  // value having been set.
  explicit TypedToolParam(const T& value)
      : ToolParam(GetValueType<T>()), value_(value) {}

  // Assigns a new value and marks the parameter as explicitly set.
  void Set(const T& value) {
    value_ = value;
    has_value_set_ = true;
  }

  // Returns a copy of the stored value.
  T Get() const { return value_; }

  // Copies the value held by 'other', which is accessed as a
  // TypedToolParam<T> via AsConstTyped<T>(), and marks this parameter as set.
  void Set(const ToolParam& other) override {
    Set(other.AsConstTyped<T>()->Get());
  }

  // Returns a new TypedToolParam<T> constructed from the current value. Since
  // the copy goes through the constructor, its HasValueSet() is false even if
  // this parameter's value has been set.
  std::unique_ptr<ToolParam> Clone() const override {
    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
  }

 private:
  T value_;
};

// A map-like container for holding values of different types.
class ToolParams {
 public:
  // Registers 'value' under 'name', taking ownership of it. An existing entry
  // with the same name is replaced.
  void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
    params_[name] = std::move(value);
  }

  // Returns true if a parameter called 'name' has been added.
  bool HasParam(const std::string& name) const {
    return params_.find(name) != params_.end();
  }

  // Returns true if no parameter has been added.
  bool Empty() const { return params_.empty(); }

  // Returns a non-owning pointer to the parameter called 'name', or nullptr
  // if there is no such parameter.
  const ToolParam* GetParam(const std::string& name) const {
    const auto it = params_.find(name);
    if (it == params_.end()) return nullptr;
    return it->second.get();
  }

  // Sets the parameter called 'name' to 'value'. AssertParamExists() is
  // invoked on 'name' before the lookup.
  template <typename T>
  void Set(const std::string& name, const T& value) {
    AssertParamExists(name);
    params_.at(name)->AsTyped<T>()->Set(value);
  }

  // Returns whether the parameter called 'name' has had its value explicitly
  // set. AssertParamExists() is invoked on 'name' before the lookup.
  template <typename T>
  bool HasValueSet(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->HasValueSet();
  }

  // Returns a copy of the value of the parameter called 'name'.
  // AssertParamExists() is invoked on 'name' before the lookup.
  template <typename T>
  T Get(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->Get();
  }

  // Set the value of all same parameters from 'other'.
  void Set(const ToolParams& other);

  // Merge the value of all parameters from 'other'. 'overwrite' indicates
  // whether the value of the same parameter is overwritten or not.
  void Merge(const ToolParams& other, bool overwrite = false);

 private:
  // Checks that a parameter called 'name' exists. Defined out of line, so the
  // failure behavior is not visible from this header.
  void AssertParamExists(const std::string& name) const;

  std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
};

// Logs "<description>: [<value>]" for the parameter 'name' of 'params' when
// either 'verbose' is true or the parameter's value has been explicitly set.
// 'params' is parenthesized in the expansion so that an expression argument
// such as '*ptr' binds as intended.
// NOTE: TFLITE_MAY_LOG is not defined by this header; the including file is
// expected to provide it.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)        \
  do {                                                                  \
    TFLITE_MAY_LOG(INFO, (verbose) || (params).HasValueSet<type>(name)) \
        << description << ": [" << (params).Get<type>(name) << "]";     \
  } while (0)

}  // namespace tools
}  // namespace tflite
#endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_