1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_ 16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_ 17 18 #include <string> 19 #include <vector> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/container/flat_hash_set.h" 23 #include "absl/status/status.h" 24 #include "absl/strings/string_view.h" 25 #include "tensorflow_lite_support/cc/port/statusor.h" 26 27 namespace tflite { 28 namespace task { 29 namespace vision { 30 31 // Structure mapping a numerical class index output to a Knowledge Graph entity 32 // ID or any other string label representing this class. Optionally it is 33 // possible to specify an additional display name (in a given language) which is 34 // typically used for display purposes. 35 struct LabelMapItem { 36 // E.g. name = "/m/02xwb" 37 std::string name; 38 // E.g. display_name = "Fruit" 39 std::string display_name; 40 // Optional list of children (e.g. subcategories) used to represent a 41 // hierarchy. 42 std::vector<std::string> child_name; 43 }; 44 45 // Builds a label map from labels and (optional) display names file contents, 46 // both expected to contain one label per line. Those are typically obtained 47 // from TFLite Model Metadata TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS 48 // associated files. 49 // Returns an error e.g. if there's a mismatch between the number of labels and 50 // display names. 51 tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles( 52 absl::string_view labels_file, absl::string_view display_names_file); 53 54 // A class that represents a hierarchy of labels as specified in a label map. 55 // 56 // For example, it is useful to determine if one label is a descendant of 57 // another label or not. This can be used to implement labels pruning based on 58 // hierarchy, e.g. if both "fruit" and "banana" have been inferred by a given 59 // classifier model prune "fruit" from the final results as "banana" is a more 60 // fine-grained descendant. 61 class LabelHierarchy { 62 public: 63 LabelHierarchy() = default; 64 65 // Initializes the hierarchy of labels from a given label map vector. Returns 66 // an error status in case of failure, typically if the input label map does 67 // not contain any hierarchical relations between labels. 68 absl::Status InitializeFromLabelMap( 69 std::vector<LabelMapItem> label_map_items); 70 71 // Returns true if `descendant_name` is a descendant of `ancestor_name` in the 72 // hierarchy of labels. Invalid names, i.e. names which do not exist in the 73 // label map used at initialization time, are ignored. 74 bool HaveAncestorDescendantRelationship( 75 const std::string& ancestor_name, 76 const std::string& descendant_name) const; 77 78 private: 79 // Retrieve and return all parent names, if any, for the input label name. 80 absl::flat_hash_set<std::string> GetParents(const std::string& name) const; 81 82 // Retrieve all ancestor names, if any, for the input label name. 83 void GetAncestors(const std::string& name, 84 absl::flat_hash_set<std::string>* ancestors) const; 85 86 // Label name (key) to parent names (value) direct mapping. 87 absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>> 88 parents_map_; 89 }; 90 91 } // namespace vision 92 } // namespace task 93 } // namespace tflite 94 95 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_ 96