• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We use a null implementation for mobile platforms.
25 #ifdef IS_MOBILE_PLATFORM
26 
27 #include <functional>
28 #include <map>
29 #include <memory>
30 
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/macros.h"
33 
34 namespace tensorflow {
35 namespace monitoring {
36 
37 // MetricCollector which has a null implementation.
38 template <MetricKind metric_kind, typename Value, int NumLabels>
39 class MetricCollector {
40  public:
41   ~MetricCollector() = default;
42 
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)43   void CollectValue(const std::array<std::string, NumLabels>& labels,
44                     Value value) {}
45 
46  private:
47   friend class MetricCollectorGetter;
48 
MetricCollector()49   MetricCollector() {}
50 };
51 
52 // MetricCollectorGetter which has a null implementation.
53 class MetricCollectorGetter {
54  public:
55   template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)56   MetricCollector<metric_kind, Value, NumLabels> Get(
57       const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
58     return MetricCollector<metric_kind, Value, NumLabels>();
59   }
60 
61  private:
MetricCollectorGetter()62   MetricCollectorGetter() {}
63 };
64 
65 // CollectionRegistry which has a null implementation.
66 class CollectionRegistry {
67  public:
68   ~CollectionRegistry() = default;
69 
Default()70   static CollectionRegistry* Default() { return new CollectionRegistry(); }
71 
72   using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
73 
74   // RegistrationHandle which has a null implementation.
75   class RegistrationHandle {
76    public:
RegistrationHandle()77     RegistrationHandle() {}
78 
~RegistrationHandle()79     ~RegistrationHandle() {}
80   };
81 
Register(const AbstractMetricDef * metric_def,const CollectionFunction & collection_function)82   std::unique_ptr<RegistrationHandle> Register(
83       const AbstractMetricDef* metric_def,
84       const CollectionFunction& collection_function) {
85     return std::unique_ptr<RegistrationHandle>(new RegistrationHandle());
86   }
87 
88  private:
CollectionRegistry()89   CollectionRegistry() {}
90 
91   TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
92 };
93 
94 }  // namespace monitoring
95 }  // namespace tensorflow
96 #else  // !defined(IS_MOBILE_PLATFORM)
97 
98 #include <functional>
99 #include <map>
100 #include <memory>
101 #include <utility>
102 
103 #include "tensorflow/core/framework/summary.pb.h"
104 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
105 #include "tensorflow/core/lib/monitoring/metric_def.h"
106 #include "tensorflow/core/lib/monitoring/types.h"
107 #include "tensorflow/core/platform/env.h"
108 #include "tensorflow/core/platform/logging.h"
109 #include "tensorflow/core/platform/macros.h"
110 #include "tensorflow/core/platform/mutex.h"
111 #include "tensorflow/core/platform/stringpiece.h"
112 #include "tensorflow/core/platform/thread_annotations.h"
113 #include "tensorflow/core/platform/types.h"
114 
115 namespace tensorflow {
116 namespace monitoring {
117 
// Forward declarations: the classes below befriend these types before they are
// defined.
namespace test_util { class CollectionRegistryTestAccess; }  // namespace test_util
namespace internal { class Collector; }  // namespace internal
125 
126 // Metric implementations would get an instance of this class using the
127 // MetricCollectorGetter in the collection-function lambda, so that their values
128 // can be collected.
129 //
130 // Read the documentation on CollectionRegistry::Register() for more details.
131 //
132 // For example:
133 //   auto metric_collector = metric_collector_getter->Get(&metric_def);
134 //   metric_collector.CollectValue(some_labels, some_value);
135 //   metric_collector.CollectValue(others_labels, other_value);
136 //
137 // This class is NOT thread-safe.
138 template <MetricKind metric_kind, typename Value, int NumLabels>
139 class MetricCollector {
140  public:
141   ~MetricCollector() = default;
142 
143   // Collects the value with these labels.
144   void CollectValue(const std::array<std::string, NumLabels>& labels,
145                     Value value);
146 
147  private:
148   friend class internal::Collector;
149 
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)150   MetricCollector(
151       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
152       const uint64 registration_time_millis,
153       internal::Collector* const collector, PointSet* const point_set)
154       : metric_def_(metric_def),
155         registration_time_millis_(registration_time_millis),
156         collector_(collector),
157         point_set_(point_set) {
158     point_set_->metric_name = std::string(metric_def->name());
159   }
160 
161   const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
162   const uint64 registration_time_millis_;
163   internal::Collector* const collector_;
164   PointSet* const point_set_;
165 
166   // This is made copyable because we can't hand out references of this class
167   // from MetricCollectorGetter because this class is templatized, and we need
168   // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
169   // can't own an instance of this class.
170 };
171 
172 // Returns a MetricCollector with the same template parameters as the
173 // metric-definition, so that the values of a metric can be collected.
174 //
175 // The collection-function defined by a metric takes this as a parameter.
176 //
177 // Read the documentation on CollectionRegistry::Register() for more details.
178 class MetricCollectorGetter {
179  public:
180   // Returns the MetricCollector with the same template parameters as the
181   // metric_def.
182   template <MetricKind metric_kind, typename Value, int NumLabels>
183   MetricCollector<metric_kind, Value, NumLabels> Get(
184       const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
185 
186  private:
187   friend class internal::Collector;
188 
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)189   MetricCollectorGetter(internal::Collector* const collector,
190                         const AbstractMetricDef* const allowed_metric_def,
191                         const uint64 registration_time_millis)
192       : collector_(collector),
193         allowed_metric_def_(allowed_metric_def),
194         registration_time_millis_(registration_time_millis) {}
195 
196   internal::Collector* const collector_;
197   const AbstractMetricDef* const allowed_metric_def_;
198   const uint64 registration_time_millis_;
199 };
200 
201 // A collection registry for metrics.
202 //
203 // Metrics are registered here so that their state can be collected later and
204 // exported.
205 //
206 // This class is thread-safe.
207 class CollectionRegistry {
208  public:
209   ~CollectionRegistry() = default;
210 
211   // Returns the default registry for the process.
212   //
213   // This registry belongs to this library and should never be deleted.
214   static CollectionRegistry* Default();
215 
216   using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
217 
218   // Registers the metric and the collection-function which can be used to
219   // collect its values. Returns a Registration object, which when upon
220   // destruction would cause the metric to be unregistered from this registry.
221   //
222   // IMPORTANT: Delete the handle before the metric-def is deleted.
223   //
224   // Example usage;
225   // CollectionRegistry::Default()->Register(
226   //   &metric_def,
227   //   [&](MetricCollectorGetter getter) {
228   //     auto metric_collector = getter.Get(&metric_def);
229   //     for (const auto& cell : cells) {
230   //       metric_collector.CollectValue(cell.labels(), cell.value());
231   //     }
232   //   });
233   class RegistrationHandle;
234   std::unique_ptr<RegistrationHandle> Register(
235       const AbstractMetricDef* metric_def,
236       const CollectionFunction& collection_function)
237       TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
238 
239   // Options for collecting metrics.
240   struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions241     CollectMetricsOptions() {}
242     bool collect_metric_descriptors = true;
243   };
244   // Goes through all the registered metrics, collects their definitions
245   // (optionally) and current values and returns them in a standard format.
246   std::unique_ptr<CollectedMetrics> CollectMetrics(
247       const CollectMetricsOptions& options) const;
248 
249  private:
250   friend class test_util::CollectionRegistryTestAccess;
251   friend class internal::Collector;
252 
253   explicit CollectionRegistry(Env* env);
254 
255   // Unregisters the metric from this registry. This is private because the
256   // public interface provides a Registration handle which automatically calls
257   // this upon destruction.
258   void Unregister(const AbstractMetricDef* metric_def) TF_LOCKS_EXCLUDED(mu_);
259 
260   // TF environment, mainly used for timestamping.
261   Env* const env_;
262 
263   mutable mutex mu_;
264 
265   // Information required for collection.
266   struct CollectionInfo {
267     const AbstractMetricDef* const metric_def;
268     CollectionFunction collection_function;
269     uint64 registration_time_millis;
270   };
271   std::map<StringPiece, CollectionInfo> registry_ TF_GUARDED_BY(mu_);
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
274 };
275 
276 ////
277 // Implementation details follow. API readers may skip.
278 ////
279 
280 class CollectionRegistry::RegistrationHandle {
281  public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)282   RegistrationHandle(CollectionRegistry* const export_registry,
283                      const AbstractMetricDef* const metric_def)
284       : export_registry_(export_registry), metric_def_(metric_def) {}
285 
~RegistrationHandle()286   ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
287 
288  private:
289   CollectionRegistry* const export_registry_;
290   const AbstractMetricDef* const metric_def_;
291 };
292 
293 namespace internal {
294 
295 template <typename Value>
296 void CollectValue(Value value, Point* point);
297 
298 template <>
CollectValue(int64_t value,Point * const point)299 inline void CollectValue(int64_t value, Point* const point) {
300   point->value_type = ValueType::kInt64;
301   point->int64_value = value;
302 }
303 
304 template <>
CollectValue(std::function<int64 ()> value_fn,Point * const point)305 inline void CollectValue(std::function<int64()> value_fn, Point* const point) {
306   point->value_type = ValueType::kInt64;
307   point->int64_value = value_fn();
308 }
309 
310 template <>
CollectValue(std::string value,Point * const point)311 inline void CollectValue(std::string value, Point* const point) {
312   point->value_type = ValueType::kString;
313   point->string_value = std::move(value);
314 }
315 
316 template <>
CollectValue(std::function<std::string ()> value_fn,Point * const point)317 inline void CollectValue(std::function<std::string()> value_fn,
318                          Point* const point) {
319   point->value_type = ValueType::kString;
320   point->string_value = value_fn();
321 }
322 
323 template <>
CollectValue(bool value,Point * const point)324 inline void CollectValue(bool value, Point* const point) {
325   point->value_type = ValueType::kBool;
326   point->bool_value = value;
327 }
328 
329 template <>
CollectValue(std::function<bool ()> value_fn,Point * const point)330 inline void CollectValue(std::function<bool()> value_fn, Point* const point) {
331   point->value_type = ValueType::kBool;
332   point->bool_value = value_fn();
333 }
334 
335 template <>
CollectValue(HistogramProto value,Point * const point)336 inline void CollectValue(HistogramProto value, Point* const point) {
337   point->value_type = ValueType::kHistogram;
338   // This is inefficient. If and when we hit snags, we can change the API to do
339   // this more efficiently.
340   point->histogram_value = std::move(value);
341 }
342 
343 template <>
CollectValue(Percentiles value,Point * const point)344 inline void CollectValue(Percentiles value, Point* const point) {
345   point->value_type = ValueType::kPercentiles;
346   point->percentiles_value = std::move(value);
347 }
348 
349 // Used by the CollectionRegistry class to collect all the values of all the
350 // metrics in the registry. This is an implementation detail of the
351 // CollectionRegistry class, please do not depend on this.
352 //
353 // This cannot be a private nested class because we need to forward declare this
354 // so that the MetricCollector and MetricCollectorGetter classes can be friends
355 // with it.
356 //
357 // This class is thread-safe.
358 class Collector {
359  public:
Collector(const uint64 collection_time_millis)360   explicit Collector(const uint64 collection_time_millis)
361       : collected_metrics_(new CollectedMetrics()),
362         collection_time_millis_(collection_time_millis) {}
363 
364   template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)365   MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
366       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
367       const uint64 registration_time_millis,
368       internal::Collector* const collector) TF_LOCKS_EXCLUDED(mu_) {
369     auto* const point_set = [&]() {
370       mutex_lock l(mu_);
371       return collected_metrics_->point_set_map
372           .insert(std::make_pair(std::string(metric_def->name()),
373                                  std::unique_ptr<PointSet>(new PointSet())))
374           .first->second.get();
375     }();
376     return MetricCollector<metric_kind, Value, NumLabels>(
377         metric_def, registration_time_millis, collector, point_set);
378   }
379 
collection_time_millis()380   uint64 collection_time_millis() const { return collection_time_millis_; }
381 
382   void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
383       TF_LOCKS_EXCLUDED(mu_);
384 
385   void CollectMetricValues(
386       const CollectionRegistry::CollectionInfo& collection_info);
387 
388   std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
389       TF_LOCKS_EXCLUDED(mu_);
390 
391  private:
392   mutable mutex mu_;
393   std::unique_ptr<CollectedMetrics> collected_metrics_ TF_GUARDED_BY(mu_);
394   const uint64 collection_time_millis_;
395 
396   TF_DISALLOW_COPY_AND_ASSIGN(Collector);
397 };
398 
399 // Write the timestamps for the point based on the MetricKind.
400 //
401 // Gauge metrics will have start and end timestamps set to the collection time.
402 //
403 // Cumulative metrics will have the start timestamp set to the time when the
404 // collection function was registered, while the end timestamp will be set to
405 // the collection time.
406 template <MetricKind kind>
407 void WriteTimestamps(const uint64 registration_time_millis,
408                      const uint64 collection_time_millis, Point* const point);
409 
410 template <>
411 inline void WriteTimestamps<MetricKind::kGauge>(
412     const uint64 registration_time_millis, const uint64 collection_time_millis,
413     Point* const point) {
414   point->start_timestamp_millis = collection_time_millis;
415   point->end_timestamp_millis = collection_time_millis;
416 }
417 
418 template <>
419 inline void WriteTimestamps<MetricKind::kCumulative>(
420     const uint64 registration_time_millis, const uint64 collection_time_millis,
421     Point* const point) {
422   point->start_timestamp_millis = registration_time_millis;
423   // There's a chance that the clock goes backwards on the same machine, so we
424   // protect ourselves against that.
425   point->end_timestamp_millis =
426       registration_time_millis < collection_time_millis
427           ? collection_time_millis
428           : registration_time_millis;
429 }
430 
431 }  // namespace internal
432 
433 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)434 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
435     const std::array<std::string, NumLabels>& labels, Value value) {
436   point_set_->points.emplace_back(new Point());
437   auto* const point = point_set_->points.back().get();
438   const std::vector<std::string> label_descriptions =
439       metric_def_->label_descriptions();
440   point->labels.reserve(NumLabels);
441   for (int i = 0; i < NumLabels; ++i) {
442     point->labels.push_back({});
443     auto* const label = &point->labels.back();
444     label->name = label_descriptions[i];
445     label->value = labels[i];
446   }
447   internal::CollectValue(std::move(value), point);
448   internal::WriteTimestamps<metric_kind>(
449       registration_time_millis_, collector_->collection_time_millis(), point);
450 }
451 
452 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)453 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
454     const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
455   if (allowed_metric_def_ != metric_def) {
456     LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
457                << " but instead got: " << metric_def->name();
458   }
459 
460   return collector_->GetMetricCollector(metric_def, registration_time_millis_,
461                                         collector_);
462 }
463 
// Abstract interface implemented by metric exporters (see
// REGISTER_TF_METRICS_EXPORTER below for how one is installed).
class Exporter {
 public:
  virtual ~Exporter() = default;

  // Called once when the exporter is registered. Presumably it starts
  // recurring exports — exact semantics are up to the implementation.
  virtual void PeriodicallyExportMetrics() = 0;

  // Exports metrics; presumably a one-shot export (implementation-defined).
  virtual void ExportMetrics() = 0;
};
470 
471 namespace exporter_registration {
472 
473 class ExporterRegistration {
474  public:
ExporterRegistration(Exporter * exporter)475   explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) {
476     exporter_->PeriodicallyExportMetrics();
477   }
478 
479  private:
480   Exporter* exporter_;
481 };
482 
483 }  // namespace exporter_registration
484 
// Registers an Exporter subclass. Expands to a file-static
// ExporterRegistration that default-constructs `exporter` with `new` and starts
// its periodic export during static initialization.
#define REGISTER_TF_METRICS_EXPORTER(exporter) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter)

// Extra level of indirection so that __COUNTER__ is macro-expanded to a number
// before it reaches the ## token-paste below (## suppresses expansion).
#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(unique_id, exporter) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ(unique_id, exporter)

#define REGISTER_TF_METRICS_EXPORTER_UNIQ(unique_id, exporter)                 \
  static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \
      exporter_registration_##unique_id(new exporter())
494 
495 }  // namespace monitoring
496 }  // namespace tensorflow
497 
498 #endif  // IS_MOBILE_PLATFORM
499 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
500