/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_
#define TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_

#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace tflite {
namespace tools {

template <typename T>
class TypedToolParam;

// Type-erased base class for a single parameter value.
//
// It records the ParamType tag supplied at construction and whether a value
// has been explicitly set since construction. The value itself lives in the
// derived TypedToolParam<T>.
class ToolParam {
 protected:
  // Tags for the value types a parameter can hold.
  enum class ParamType { TYPE_INT32, TYPE_FLOAT, TYPE_BOOL, TYPE_STRING };

  // Maps a C++ type T to its ParamType tag. Only declared here; the
  // per-type definitions are not visible in this header.
  template <typename T>
  static ParamType GetValueType();

 public:
  // Creates a heap-allocated TypedToolParam<T> holding 'default_value' and
  // returns it through the type-erased base pointer.
  template <typename T>
  static std::unique_ptr<ToolParam> Create(const T& default_value) {
    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(default_value));
  }

  // Returns this object viewed as TypedToolParam<T>.
  //
  // AssertHasSameType() is called with T's tag and the stored tag first. The
  // cast itself is an unchecked static_cast, so correctness relies on that
  // assertion (defined out of line) rejecting a mismatched T.
  template <typename T>
  TypedToolParam<T>* AsTyped() {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<TypedToolParam<T>*>(this);
  }

  // Const counterpart of AsTyped().
  template <typename T>
  const TypedToolParam<T>* AsConstTyped() const {
    AssertHasSameType(GetValueType<T>(), type_);
    return static_cast<const TypedToolParam<T>*>(this);
  }

  virtual ~ToolParam() = default;

  // Stores the type tag; the "value set" flag starts out false.
  explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}

  // Whether a value has been explicitly set (see TypedToolParam::Set).
  bool HasValueSet() const { return has_value_set_; }

  // Copies the value from another parameter. The base implementation does
  // nothing; TypedToolParam<T> overrides it.
  virtual void Set(const ToolParam&) {}

  // Returns a newly allocated copy of this parameter.
  virtual std::unique_ptr<ToolParam> Clone() const = 0;

 protected:
  bool has_value_set_;

 private:
  // Compares two type tags. Defined out of line; the behavior on mismatch is
  // not visible from this header.
  static void AssertHasSameType(ParamType a, ParamType b);

  const ParamType type_;
};

// A parameter holding a value of type T.
template <typename T>
class TypedToolParam : public ToolParam {
 public:
  // Stores 'value' as the initial (default) value. This does not mark the
  // parameter as having a value set.
  explicit TypedToolParam(const T& value)
      : ToolParam(GetValueType<T>()), value_(value) {}

  // Replaces the stored value and marks the parameter as set.
  void Set(const T& value) {
    value_ = value;
    has_value_set_ = true;
  }

  // Returns a copy of the stored value.
  T Get() const { return value_; }

  // Copies the value held by 'other', which is accessed as a
  // TypedToolParam<T>, and marks this parameter as set.
  void Set(const ToolParam& other) override {
    Set(other.AsConstTyped<T>()->Get());
  }

  // Returns a new TypedToolParam<T> constructed from the current value.
  // Because it goes through the constructor, the clone's "value set" flag
  // starts out false regardless of this object's flag.
  std::unique_ptr<ToolParam> Clone() const override {
    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
  }

 private:
  T value_;
};

// A map-like container for holding values of different types.
class ToolParams {
 public:
  // Registers 'value' under 'name', replacing any parameter already stored
  // under that name.
  void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
    params_[name] = std::move(value);
  }

  // Whether a parameter named 'name' has been added.
  bool HasParam(const std::string& name) const {
    return params_.find(name) != params_.end();
  }

  // Whether no parameters have been added.
  bool Empty() const { return params_.empty(); }

  // Returns the parameter stored under 'name', or nullptr if there is none.
  // The returned pointer is owned by this container.
  const ToolParam* GetParam(const std::string& name) const {
    const auto it = params_.find(name);
    if (it == params_.end()) return nullptr;
    return it->second.get();
  }

  // Sets the value of the parameter 'name', accessed as type T.
  // AssertParamExists() (defined out of line) is called first.
  template <typename T>
  void Set(const std::string& name, const T& value) {
    AssertParamExists(name);
    params_.at(name)->AsTyped<T>()->Set(value);
  }

  // Whether the parameter 'name', accessed as type T, has had a value
  // explicitly set.
  template <typename T>
  bool HasValueSet(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->HasValueSet();
  }

  // Returns a copy of the value of the parameter 'name', accessed as type T.
  template <typename T>
  T Get(const std::string& name) const {
    AssertParamExists(name);
    return params_.at(name)->AsConstTyped<T>()->Get();
  }

  // Set the value of all same parameters from 'other'.
  void Set(const ToolParams& other);

  // Merge the value of all parameters from 'other'. 'overwrite' indicates
  // whether the value of the same parameter is overwritten or not.
  void Merge(const ToolParams& other, bool overwrite = false);

 private:
  // Checks that a parameter named 'name' exists. Defined out of line.
  void AssertParamExists(const std::string& name) const;

  std::unordered_map<std::string, std::unique_ptr<ToolParam>> params_;
};

// Logs "<description>: [<value>]" for the parameter 'name' of C++ type 'type'
// in 'params' when 'verbose' is true or the parameter has a value set.
//
// TFLITE_MAY_LOG is not defined in this header, so it must be visible at the
// point of expansion. 'params' and 'name' are each evaluated twice.
#define LOG_TOOL_PARAM(params, type, name, description, verbose)         \
  do {                                                                   \
    TFLITE_MAY_LOG(INFO,                                                 \
                   (verbose) || (params).HasValueSet<type>(name))        \
        << (description) << ": [" << (params).Get<type>(name) << "]";    \
  } while (0)

}  // namespace tools
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_TOOL_PARAMS_H_