• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We use a null implementation for mobile platforms.
25 #ifdef IS_MOBILE_PLATFORM
26 
27 #include <functional>
28 #include <map>
29 #include <memory>
30 
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/macros.h"
33 
34 namespace tensorflow {
35 namespace monitoring {
36 
37 // MetricCollector which has a null implementation.
38 template <MetricKind metric_kind, typename Value, int NumLabels>
39 class MetricCollector {
40  public:
41   ~MetricCollector() = default;
42 
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)43   void CollectValue(const std::array<std::string, NumLabels>& labels,
44                     Value value) {}
45 
46  private:
47   friend class MetricCollectorGetter;
48 
MetricCollector()49   MetricCollector() {}
50 };
51 
52 // MetricCollectorGetter which has a null implementation.
53 class MetricCollectorGetter {
54  public:
55   template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)56   MetricCollector<metric_kind, Value, NumLabels> Get(
57       const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
58     return MetricCollector<metric_kind, Value, NumLabels>();
59   }
60 
61  private:
MetricCollectorGetter()62   MetricCollectorGetter() {}
63 };
64 
65 // CollectionRegistry which has a null implementation.
66 class CollectionRegistry {
67  public:
68   ~CollectionRegistry() = default;
69 
Default()70   static CollectionRegistry* Default() { return new CollectionRegistry(); }
71 
72   using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
73 
74   // RegistrationHandle which has a null implementation.
75   class RegistrationHandle {
76    public:
RegistrationHandle()77     RegistrationHandle() {}
78 
~RegistrationHandle()79     ~RegistrationHandle() {}
80   };
81 
Register(const AbstractMetricDef * metric_def,const CollectionFunction & collection_function)82   std::unique_ptr<RegistrationHandle> Register(
83       const AbstractMetricDef* metric_def,
84       const CollectionFunction& collection_function) {
85     return std::unique_ptr<RegistrationHandle>(new RegistrationHandle());
86   }
87 
88  private:
CollectionRegistry()89   CollectionRegistry() {}
90 
91   TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
92 };
93 
94 }  // namespace monitoring
95 }  // namespace tensorflow
96 #else  // !defined(IS_MOBILE_PLATFORM)
97 
98 #include <functional>
99 #include <map>
100 #include <memory>
101 #include <utility>
102 
103 #include "tensorflow/core/framework/summary.pb.h"
104 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
105 #include "tensorflow/core/lib/monitoring/metric_def.h"
106 #include "tensorflow/core/lib/monitoring/types.h"
107 #include "tensorflow/core/platform/env.h"
108 #include "tensorflow/core/platform/logging.h"
109 #include "tensorflow/core/platform/macros.h"
110 #include "tensorflow/core/platform/mutex.h"
111 #include "tensorflow/core/platform/stringpiece.h"
112 #include "tensorflow/core/platform/thread_annotations.h"
113 #include "tensorflow/core/platform/types.h"
114 
115 namespace tensorflow {
116 namespace monitoring {
117 
// Forward declarations of the classes befriended by MetricCollector,
// MetricCollectorGetter and CollectionRegistry further down this header.
namespace test_util {
class CollectionRegistryTestAccess;  // Presumably test-only access; confirm.
}  // namespace test_util

namespace internal {
class Collector;  // Defined later in this header.
}  // namespace internal
125 
126 // Metric implementations would get an instance of this class using the
127 // MetricCollectorGetter in the collection-function lambda, so that their values
128 // can be collected.
129 //
130 // Read the documentation on CollectionRegistry::Register() for more details.
131 //
132 // For example:
133 //   auto metric_collector = metric_collector_getter->Get(&metric_def);
134 //   metric_collector.CollectValue(some_labels, some_value);
135 //   metric_collector.CollectValue(others_labels, other_value);
136 //
137 // This class is NOT thread-safe.
138 template <MetricKind metric_kind, typename Value, int NumLabels>
139 class MetricCollector {
140  public:
141   ~MetricCollector() = default;
142 
143   // Collects the value with these labels.
144   void CollectValue(const std::array<std::string, NumLabels>& labels,
145                     Value value);
146 
147  private:
148   friend class internal::Collector;
149 
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)150   MetricCollector(
151       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
152       const uint64 registration_time_millis,
153       internal::Collector* const collector, PointSet* const point_set)
154       : metric_def_(metric_def),
155         registration_time_millis_(registration_time_millis),
156         collector_(collector),
157         point_set_(point_set) {
158     point_set_->metric_name = std::string(metric_def->name());
159   }
160 
161   const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
162   const uint64 registration_time_millis_;
163   internal::Collector* const collector_;
164   PointSet* const point_set_;
165 
166   // This is made copyable because we can't hand out references of this class
167   // from MetricCollectorGetter because this class is templatized, and we need
168   // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
169   // can't own an instance of this class.
170 };
171 
172 // Returns a MetricCollector with the same template parameters as the
173 // metric-definition, so that the values of a metric can be collected.
174 //
175 // The collection-function defined by a metric takes this as a parameter.
176 //
177 // Read the documentation on CollectionRegistry::Register() for more details.
178 class MetricCollectorGetter {
179  public:
180   // Returns the MetricCollector with the same template parameters as the
181   // metric_def.
182   template <MetricKind metric_kind, typename Value, int NumLabels>
183   MetricCollector<metric_kind, Value, NumLabels> Get(
184       const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
185 
186  private:
187   friend class internal::Collector;
188 
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)189   MetricCollectorGetter(internal::Collector* const collector,
190                         const AbstractMetricDef* const allowed_metric_def,
191                         const uint64 registration_time_millis)
192       : collector_(collector),
193         allowed_metric_def_(allowed_metric_def),
194         registration_time_millis_(registration_time_millis) {}
195 
196   internal::Collector* const collector_;
197   const AbstractMetricDef* const allowed_metric_def_;
198   const uint64 registration_time_millis_;
199 };
200 
201 // A collection registry for metrics.
202 //
203 // Metrics are registered here so that their state can be collected later and
204 // exported.
205 //
206 // This class is thread-safe.
207 class CollectionRegistry {
208  public:
209   ~CollectionRegistry() = default;
210 
211   // Returns the default registry for the process.
212   //
213   // This registry belongs to this library and should never be deleted.
214   static CollectionRegistry* Default();
215 
216   using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
217 
218   // Registers the metric and the collection-function which can be used to
219   // collect its values. Returns a Registration object, which when upon
220   // destruction would cause the metric to be unregistered from this registry.
221   //
222   // IMPORTANT: Delete the handle before the metric-def is deleted.
223   //
224   // Example usage;
225   // CollectionRegistry::Default()->Register(
226   //   &metric_def,
227   //   [&](MetricCollectorGetter getter) {
228   //     auto metric_collector = getter.Get(&metric_def);
229   //     for (const auto& cell : cells) {
230   //       metric_collector.CollectValue(cell.labels(), cell.value());
231   //     }
232   //   });
233   class RegistrationHandle;
234   std::unique_ptr<RegistrationHandle> Register(
235       const AbstractMetricDef* metric_def,
236       const CollectionFunction& collection_function)
237       TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
238 
239   // Options for collecting metrics.
240   struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions241     CollectMetricsOptions() {}
242     bool collect_metric_descriptors = true;
243   };
244   // Goes through all the registered metrics, collects their definitions
245   // (optionally) and current values and returns them in a standard format.
246   std::unique_ptr<CollectedMetrics> CollectMetrics(
247       const CollectMetricsOptions& options) const;
248 
249  private:
250   friend class test_util::CollectionRegistryTestAccess;
251   friend class internal::Collector;
252 
253   explicit CollectionRegistry(Env* env);
254 
255   // Unregisters the metric from this registry. This is private because the
256   // public interface provides a Registration handle which automatically calls
257   // this upon destruction.
258   void Unregister(const AbstractMetricDef* metric_def) TF_LOCKS_EXCLUDED(mu_);
259 
260   // TF environment, mainly used for timestamping.
261   Env* const env_;
262 
263   mutable mutex mu_;
264 
265   // Information required for collection.
266   struct CollectionInfo {
267     const AbstractMetricDef* const metric_def;
268     CollectionFunction collection_function;
269     uint64 registration_time_millis;
270   };
271   std::map<StringPiece, CollectionInfo> registry_ TF_GUARDED_BY(mu_);
272 
273   TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
274 };
275 
////
// Implementation details follow. API readers may skip.
////
279 
280 class CollectionRegistry::RegistrationHandle {
281  public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)282   RegistrationHandle(CollectionRegistry* const export_registry,
283                      const AbstractMetricDef* const metric_def)
284       : export_registry_(export_registry), metric_def_(metric_def) {}
285 
~RegistrationHandle()286   ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
287 
288  private:
289   CollectionRegistry* const export_registry_;
290   const AbstractMetricDef* const metric_def_;
291 };
292 
293 namespace internal {
294 
295 template <typename Value>
296 void CollectValue(Value value, Point* point);
297 
298 template <>
CollectValue(int64_t value,Point * const point)299 inline void CollectValue(int64_t value, Point* const point) {
300   point->value_type = ValueType::kInt64;
301   point->int64_value = value;
302 }
303 
304 template <>
CollectValue(std::function<int64_t ()> value_fn,Point * const point)305 inline void CollectValue(std::function<int64_t()> value_fn,
306                          Point* const point) {
307   point->value_type = ValueType::kInt64;
308   point->int64_value = value_fn();
309 }
310 
311 template <>
CollectValue(std::string value,Point * const point)312 inline void CollectValue(std::string value, Point* const point) {
313   point->value_type = ValueType::kString;
314   point->string_value = std::move(value);
315 }
316 
317 template <>
CollectValue(std::function<std::string ()> value_fn,Point * const point)318 inline void CollectValue(std::function<std::string()> value_fn,
319                          Point* const point) {
320   point->value_type = ValueType::kString;
321   point->string_value = value_fn();
322 }
323 
324 template <>
CollectValue(bool value,Point * const point)325 inline void CollectValue(bool value, Point* const point) {
326   point->value_type = ValueType::kBool;
327   point->bool_value = value;
328 }
329 
330 template <>
CollectValue(std::function<bool ()> value_fn,Point * const point)331 inline void CollectValue(std::function<bool()> value_fn, Point* const point) {
332   point->value_type = ValueType::kBool;
333   point->bool_value = value_fn();
334 }
335 
336 template <>
CollectValue(HistogramProto value,Point * const point)337 inline void CollectValue(HistogramProto value, Point* const point) {
338   point->value_type = ValueType::kHistogram;
339   // This is inefficient. If and when we hit snags, we can change the API to do
340   // this more efficiently.
341   point->histogram_value = std::move(value);
342 }
343 
344 template <>
CollectValue(Percentiles value,Point * const point)345 inline void CollectValue(Percentiles value, Point* const point) {
346   point->value_type = ValueType::kPercentiles;
347   point->percentiles_value = std::move(value);
348 }
349 
350 // Used by the CollectionRegistry class to collect all the values of all the
351 // metrics in the registry. This is an implementation detail of the
352 // CollectionRegistry class, please do not depend on this.
353 //
354 // This cannot be a private nested class because we need to forward declare this
355 // so that the MetricCollector and MetricCollectorGetter classes can be friends
356 // with it.
357 //
358 // This class is thread-safe.
359 class Collector {
360  public:
Collector(const uint64 collection_time_millis)361   explicit Collector(const uint64 collection_time_millis)
362       : collected_metrics_(new CollectedMetrics()),
363         collection_time_millis_(collection_time_millis) {}
364 
365   template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)366   MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
367       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
368       const uint64 registration_time_millis,
369       internal::Collector* const collector) TF_LOCKS_EXCLUDED(mu_) {
370     auto* const point_set = [&]() {
371       mutex_lock l(mu_);
372       return collected_metrics_->point_set_map
373           .insert(std::make_pair(std::string(metric_def->name()),
374                                  std::unique_ptr<PointSet>(new PointSet())))
375           .first->second.get();
376     }();
377     return MetricCollector<metric_kind, Value, NumLabels>(
378         metric_def, registration_time_millis, collector, point_set);
379   }
380 
collection_time_millis()381   uint64 collection_time_millis() const { return collection_time_millis_; }
382 
383   void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
384       TF_LOCKS_EXCLUDED(mu_);
385 
386   void CollectMetricValues(
387       const CollectionRegistry::CollectionInfo& collection_info);
388 
389   std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
390       TF_LOCKS_EXCLUDED(mu_);
391 
392  private:
393   mutable mutex mu_;
394   std::unique_ptr<CollectedMetrics> collected_metrics_ TF_GUARDED_BY(mu_);
395   const uint64 collection_time_millis_;
396 
397   TF_DISALLOW_COPY_AND_ASSIGN(Collector);
398 };
399 
400 // Write the timestamps for the point based on the MetricKind.
401 //
402 // Gauge metrics will have start and end timestamps set to the collection time.
403 //
404 // Cumulative metrics will have the start timestamp set to the time when the
405 // collection function was registered, while the end timestamp will be set to
406 // the collection time.
407 template <MetricKind kind>
408 void WriteTimestamps(const uint64 registration_time_millis,
409                      const uint64 collection_time_millis, Point* const point);
410 
411 template <>
412 inline void WriteTimestamps<MetricKind::kGauge>(
413     const uint64 registration_time_millis, const uint64 collection_time_millis,
414     Point* const point) {
415   point->start_timestamp_millis = collection_time_millis;
416   point->end_timestamp_millis = collection_time_millis;
417 }
418 
419 template <>
420 inline void WriteTimestamps<MetricKind::kCumulative>(
421     const uint64 registration_time_millis, const uint64 collection_time_millis,
422     Point* const point) {
423   point->start_timestamp_millis = registration_time_millis;
424   // There's a chance that the clock goes backwards on the same machine, so we
425   // protect ourselves against that.
426   point->end_timestamp_millis =
427       registration_time_millis < collection_time_millis
428           ? collection_time_millis
429           : registration_time_millis;
430 }
431 
432 }  // namespace internal
433 
434 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)435 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
436     const std::array<std::string, NumLabels>& labels, Value value) {
437   point_set_->points.emplace_back(new Point());
438   auto* const point = point_set_->points.back().get();
439   const std::vector<std::string> label_descriptions =
440       metric_def_->label_descriptions();
441   point->labels.reserve(NumLabels);
442   for (int i = 0; i < NumLabels; ++i) {
443     point->labels.push_back({});
444     auto* const label = &point->labels.back();
445     label->name = label_descriptions[i];
446     label->value = labels[i];
447   }
448   internal::CollectValue(std::move(value), point);
449   internal::WriteTimestamps<metric_kind>(
450       registration_time_millis_, collector_->collection_time_millis(), point);
451 }
452 
453 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)454 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
455     const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
456   if (allowed_metric_def_ != metric_def) {
457     LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
458                << " but instead got: " << metric_def->name();
459   }
460 
461   return collector_->GetMetricCollector(metric_def, registration_time_millis_,
462                                         collector_);
463 }
464 
// Interface implemented by metric exporters.
class Exporter {
 public:
  virtual ~Exporter() = default;

  // Called once by ExporterRegistration when the exporter is registered. By
  // its name, expected to start periodic export — confirm in implementations.
  virtual void PeriodicallyExportMetrics() = 0;

  // By its name, exports the metrics once — confirm in implementations.
  virtual void ExportMetrics() = 0;
};
471 
472 namespace exporter_registration {
473 
474 class ExporterRegistration {
475  public:
ExporterRegistration(Exporter * exporter)476   explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) {
477     exporter_->PeriodicallyExportMetrics();
478   }
479 
480  private:
481   Exporter* exporter_;
482 };
483 
484 }  // namespace exporter_registration
485 
// Registers an exporter type. The type must be default-constructible. When the
// enclosing translation unit's statics are initialized, a heap-allocated
// instance is created and passed to ExporterRegistration, which calls its
// PeriodicallyExportMetrics(). Nothing here deletes that instance.
//
// The UNIQ_HELPER indirection forces __COUNTER__ to expand before it is
// token-pasted, so each registration gets a distinct variable name.
#define REGISTER_TF_METRICS_EXPORTER(exporter_class) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter_class)

#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(counter, exporter_class) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ(counter, exporter_class)

#define REGISTER_TF_METRICS_EXPORTER_UNIQ(counter, exporter_class)             \
  static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \
      exporter_registration_##counter(new exporter_class())
495 
496 }  // namespace monitoring
497 }  // namespace tensorflow
498 
499 #endif  // IS_MOBILE_PLATFORM
500 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
501