• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
16 #define TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
17 
18 #include <functional>
19 #include <map>
20 #include <regex>  // NOLINT
21 #include <utility>
22 #include <vector>
23 
24 #include "absl/container/flat_hash_map.h"
25 #include "absl/memory/memory.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/lite/tools/evaluation/proto/evaluation_config.pb.h"
28 #include "tensorflow/lite/tools/evaluation/proto/evaluation_stages.pb.h"
29 
30 namespace tflite {
31 namespace evaluation {
32 
33 class EvaluationStage;
34 
35 typedef std::function<std::unique_ptr<EvaluationStage>(
36     const EvaluationStageConfig&)>
37     FactoryFunc;
38 
39 // Superclass for a single stage of an EvaluationPipeline.
40 // Provides basic functionality for construction and accessing
41 // initializers/inputs/outputs.
42 // Every subclass of EvaluationStage will define its own behavior by specifying
43 // appropriate accessor TAGs and implementing the Init, Run and Close methods.
44 class EvaluationStage {
45  public:
46   // Initializes an EvaluationStage. Returns false if initialization failed,
47   // true otherwise.
48   // Should be called only once, before any call to Run().
49   // object_map should contain {initializer name : object pointer} mappings
50   // required for initialization.
51   //
52   // NOTE: EvaluationStage will not take ownership of any elements of
53   // object_map.
54   bool Init(absl::flat_hash_map<std::string, void*>& object_map);
55 
56   // An individual run of the EvaluationStage. Returns false if there was a
57   // failure, true otherwise.
58   // Init() should be called before any calls to run().
59   // Inputs are acquired from and outputs are written to the incoming
60   // object_map, using appropriate TAGs.
61   //
62   // NOTE: The EvaluationStage should maintain ownership of outputs it
63   // populates into object_map. Ownership of inputs must be maintained
64   // elsewhere.
65   virtual bool Run(absl::flat_hash_map<std::string, void*>& object_map) = 0;
66 
67   // Returns the latest metrics based on all Run() calls made so far.
68   virtual EvaluationStageMetrics LatestMetrics() = 0;
69 
70   // The canonical way to instantiate EvaluationStages.
71   // Remember to call <classname>_ENABLE() first.
Create(const EvaluationStageConfig & config)72   static std::unique_ptr<EvaluationStage> Create(
73       const EvaluationStageConfig& config) {
74     if (!config.has_specification() ||
75         !config.specification().has_process_class()) {
76       LOG(ERROR) << "Process specification not present in config: "
77                  << config.name();
78       return nullptr;
79     }
80     auto& factory_ptr =
81         (*GetFactoryMapPtr())[config.specification().process_class()];
82     if (!factory_ptr) return nullptr;
83     return factory_ptr(config);
84   }
85 
86   // Used by DEFINE_REGISTRATION.
87   // This method takes ownership of factory.
88   // Should only be used via DEFINE_REGISTRATION macro.
RegisterStage(const ProcessClass & process_class,FactoryFunc class_factory)89   static void RegisterStage(const ProcessClass& process_class,
90                             FactoryFunc class_factory) {
91     (*GetFactoryMapPtr())[process_class] = std::move(class_factory);
92   }
93 
94   virtual ~EvaluationStage() = default;
95 
96  protected:
97   // Constructs an EvaluationStage.
98   // Each subclass constructor must invoke this constructor.
99   //
100   // NOTE: Do NOT use constructors to obtain new EvaluationStages. Use
101   // EvaluationStage::Create instead.
EvaluationStage(const EvaluationStageConfig & config)102   explicit EvaluationStage(const EvaluationStageConfig& config)
103       : config_(config) {}
104 
105   // Class-specific initialization, to be overridden by EvaluationStage
106   // sub-classes. Gets called in EvaluationStage::Init().
107   //
108   // NOTE: This object should not take ownership of any elements of object_map.
109   virtual bool DoInit(absl::flat_hash_map<std::string, void*>& object_map) = 0;
110 
111   // The three following functions return the initializer/input/output TAGs used
112   // by an EvaluationStage. These should be mapped to meaningful names in the
113   // EvaluationStageConfig, and to required objects during calls to Init/Run.
114   // Format for TAGs: [A-Z0-9_]+ (Uppercase letters, numbers, "_")
115   // Refer docs in tflite.evaluation.EvaluationStageConfig for more information.
116 
117   // Returns the expected initializer TAGs.
118   virtual std::vector<std::string> GetInitializerTags() = 0;
119 
120   // Returns the expected input TAGs.
121   virtual std::vector<std::string> GetInputTags() = 0;
122 
123   // Returns the expected output TAGs.
124   virtual std::vector<std::string> GetOutputTags() = 0;
125 
126   // Populates a pointer to the object corresponding to provided TAG.
127   // Returns true if success, false otherwise.
128   // object_map contain a {name : object pointer} mapping, with the
129   // name being mapped to the expected TAG in the EvaluationStageConfig.
130   // NOTE: object pointer must be non-NULL.
131   template <class T>
GetObjectFromTag(const std::string & tag,absl::flat_hash_map<std::string,void * > & object_map,T ** object_ptr)132   bool GetObjectFromTag(const std::string& tag,
133                         absl::flat_hash_map<std::string, void*>& object_map,
134                         T** object_ptr) {
135     *object_ptr = nullptr;
136     // Find name corresponding to TAG.
137     auto mapping_iter = tags_to_names_map_.find(tag);
138     if (mapping_iter == tags_to_names_map_.end()) {
139       LOG(ERROR) << "Unexpected TAG to GetObjectFromTag: " << tag;
140       return false;
141     }
142     const std::string& expected_name = mapping_iter->second;
143 
144     // Find object from name.
145     auto object_iter = object_map.find(expected_name);
146     if (object_iter == object_map.end()) {
147       LOG(ERROR) << "Could not find object for name: " << expected_name;
148       return false;
149     }
150     if (!object_iter->second) {
151       LOG(ERROR) << "Found null pointer for name: " << expected_name;
152       return false;
153     }
154     *object_ptr = static_cast<T*>(object_iter->second);
155     return true;
156   }
157 
158   // Maps the appropriate name to a given object in object_map. The name is
159   // derived from mappings provided in the EvaluationStageConfig.
160   // Returns false if tag is invalid, true otherwise.
161   //
162   // NOTE: The EvaluationStage must maintain ownership of object for the
163   // lifetime of object_map
AssignObjectToTag(const std::string & tag,void * object_ptr,absl::flat_hash_map<std::string,void * > & object_map)164   bool AssignObjectToTag(const std::string& tag, void* object_ptr,
165                          absl::flat_hash_map<std::string, void*>& object_map) {
166     // Find name corresponding to TAG.
167     auto mapping_iter = tags_to_names_map_.find(tag);
168     if (mapping_iter == tags_to_names_map_.end()) {
169       LOG(ERROR) << "Unexpected TAG to AssignObjectToTag: " << tag;
170       return false;
171     }
172     const std::string& expected_name = mapping_iter->second;
173 
174     object_map[expected_name] = object_ptr;
175     return true;
176   }
177 
178   EvaluationStageConfig config_;
179 
180  private:
181   // Verifies that all TAGs from expected_tags are present in
182   // tag_to_name_mappings, and then populates tags_to_names_map_ with the
183   // appropriate entries. Returns false in case any TAG/mapping is invalid, true
184   // otherwise.
185   // expected_tags should be a list of TAG-strings.
186   // tag_to_name_mappings should be RepeatedPtrField of strings mapping TAGs to
187   // names in the form "SOME_TAG:some_name".
188   bool ProcessExpectedTags(const std::vector<std::string>& expected_tags,
189                            std::vector<std::string>& tag_to_name_mappings);
190 
GetFactoryMapPtr()191   static std::map<ProcessClass, FactoryFunc>* GetFactoryMapPtr() {
192     return process_class_to_factory_map_;
193   }
194 
195   // Used by factories.
196   static std::map<ProcessClass, FactoryFunc>* process_class_to_factory_map_;
197 
198   // Maps expected TAGs to their names as defined by the EvaluationStageConfig.
199   absl::flat_hash_map<std::string, std::string> tags_to_names_map_;
200 
201   // To ensure correct formatting in the config.
202   const std::regex kTagNameMappingPattern{"^([A-Z0-9_]+):([a-z0-9_]+)$",
203                                           std::regex::optimize};
204 
205   // To ensure correct formatting in TAG names.
206   const std::regex kTagPattern{"^[A-Z0-9_]+$", std::regex::optimize};
207 };
208 
// Place DECLARE_FACTORY(<classname>) in the header of each new EvaluationStage
// subclass. It declares the registration function <classname>_ENABLE().
#define DECLARE_FACTORY(classname) void classname##_ENABLE();

// Place DEFINE_FACTORY(<classname>, <processclass>) in the implementation file
// of each new EvaluationStage subclass. It defines <classname>_ENABLE(), which
// registers, for <processclass>, a factory constructing <classname> from an
// EvaluationStageConfig. Call <classname>_ENABLE() before using
// EvaluationStage::Create for the class.
#define DEFINE_FACTORY(classname, processclass)                 \
  void classname##_ENABLE() {                                   \
    EvaluationStage::RegisterStage(                             \
        processclass, [](const EvaluationStageConfig& config) { \
          return absl::make_unique<classname>(config);          \
        });                                                     \
  }
222 
// Use this to assign a non-nullptr pointer to tag in object_map.
// Makes the enclosing function return false if the tag is not known to this
// stage (see AssignObjectToTag).
// Wrapped in do { ... } while (0) so that the macro behaves as a single
// statement, e.g. inside an unbraced if/else.
#define ASSIGN_OBJECT(tag, ptr, object_map)         \
  do {                                              \
    if (!AssignObjectToTag(tag, ptr, object_map)) { \
      return false;                                 \
    }                                               \
  } while (0)

// Use this to obtain pointers to required objects.
// Makes the enclosing function return false if the name corresponding to tag
// is not found, or if the pointer found is nullptr (see GetObjectFromTag).
// Wrapped in do { ... } while (0) so that the macro behaves as a single
// statement, e.g. inside an unbraced if/else.
#define GET_OBJECT(tag, object_map, location)           \
  do {                                                  \
    if (!GetObjectFromTag(tag, object_map, location)) { \
      return false;                                     \
    }                                                   \
  } while (0)
236 
237 }  // namespace evaluation
238 }  // namespace tflite
239 
240 #endif  // TENSORFLOW_LITE_TOOLS_EVALUATION_EVALUATION_STAGE_H_
241