• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/status/status.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow_lite_support/cc/port/statusor.h"
26 
27 namespace tflite {
28 namespace task {
29 namespace vision {
30 
// Structure mapping a numerical class index output to a Knowledge Graph entity
// ID or any other string label representing this class. Optionally, an
// additional display name (in a given language) can be specified; it is
// typically used for display purposes.
struct LabelMapItem {
  // Label identifying the class, e.g. name = "/m/02xwb".
  std::string name;
  // Optional human-readable label, e.g. display_name = "Fruit".
  std::string display_name;
  // Optional list of children (e.g. subcategories) used to represent a
  // hierarchy of labels.
  // NOTE(review): entries appear to be the `name` of other items in the same
  // label map -- confirm against the LabelHierarchy implementation.
  std::vector<std::string> child_name;
};
44 
45 // Builds a label map from labels and (optional) display names file contents,
46 // both expected to contain one label per line. Those are typically obtained
47 // from TFLite Model Metadata TENSOR_AXIS_LABELS or TENSOR_VALUE_LABELS
48 // associated files.
49 // Returns an error e.g. if there's a mismatch between the number of labels and
50 // display names.
51 tflite::support::StatusOr<std::vector<LabelMapItem>> BuildLabelMapFromFiles(
52     absl::string_view labels_file, absl::string_view display_names_file);
53 
54 // A class that represents a hierarchy of labels as specified in a label map.
55 //
56 // For example, it is useful to determine if one label is a descendant of
57 // another label or not. This can be used to implement labels pruning based on
58 // hierarchy, e.g. if both "fruit" and "banana" have been inferred by a given
59 // classifier model prune "fruit" from the final results as "banana" is a more
60 // fine-grained descendant.
61 class LabelHierarchy {
62  public:
63   LabelHierarchy() = default;
64 
65   // Initializes the hierarchy of labels from a given label map vector. Returns
66   // an error status in case of failure, typically if the input label map does
67   // not contain any hierarchical relations between labels.
68   absl::Status InitializeFromLabelMap(
69       std::vector<LabelMapItem> label_map_items);
70 
71   // Returns true if `descendant_name` is a descendant of `ancestor_name` in the
72   // hierarchy of labels. Invalid names, i.e. names which do not exist in the
73   // label map used at initialization time, are ignored.
74   bool HaveAncestorDescendantRelationship(
75       const std::string& ancestor_name,
76       const std::string& descendant_name) const;
77 
78  private:
79   // Retrieve and return all parent names, if any, for the input label name.
80   absl::flat_hash_set<std::string> GetParents(const std::string& name) const;
81 
82   // Retrieve all ancestor names, if any, for the input label name.
83   void GetAncestors(const std::string& name,
84                     absl::flat_hash_set<std::string>* ancestors) const;
85 
86   // Label name (key) to parent names (value) direct mapping.
87   absl::flat_hash_map<std::string, absl::flat_hash_set<std::string>>
88       parents_map_;
89 };
90 
91 }  // namespace vision
92 }  // namespace task
93 }  // namespace tflite
94 
95 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_CORE_LABEL_MAP_ITEM_H_
96