1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23
24 // We use a null implementation for mobile platforms.
25 #ifdef IS_MOBILE_PLATFORM
26
27 #include <functional>
28 #include <map>
29 #include <memory>
30
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/macros.h"
33
34 namespace tensorflow {
35 namespace monitoring {
36
37 // MetricCollector which has a null implementation.
38 template <MetricKind metric_kind, typename Value, int NumLabels>
39 class MetricCollector {
40 public:
41 ~MetricCollector() = default;
42
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)43 void CollectValue(const std::array<std::string, NumLabels>& labels,
44 Value value) {}
45
46 private:
47 friend class MetricCollectorGetter;
48
MetricCollector()49 MetricCollector() {}
50 };
51
52 // MetricCollectorGetter which has a null implementation.
53 class MetricCollectorGetter {
54 public:
55 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)56 MetricCollector<metric_kind, Value, NumLabels> Get(
57 const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
58 return MetricCollector<metric_kind, Value, NumLabels>();
59 }
60
61 private:
MetricCollectorGetter()62 MetricCollectorGetter() {}
63 };
64
65 // CollectionRegistry which has a null implementation.
66 class CollectionRegistry {
67 public:
68 ~CollectionRegistry() = default;
69
Default()70 static CollectionRegistry* Default() { return new CollectionRegistry(); }
71
72 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
73
74 // RegistrationHandle which has a null implementation.
75 class RegistrationHandle {
76 public:
RegistrationHandle()77 RegistrationHandle() {}
78
~RegistrationHandle()79 ~RegistrationHandle() {}
80 };
81
Register(const AbstractMetricDef * metric_def,const CollectionFunction & collection_function)82 std::unique_ptr<RegistrationHandle> Register(
83 const AbstractMetricDef* metric_def,
84 const CollectionFunction& collection_function) {
85 return std::unique_ptr<RegistrationHandle>(new RegistrationHandle());
86 }
87
88 private:
CollectionRegistry()89 CollectionRegistry() {}
90
91 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
92 };
93
94 } // namespace monitoring
95 } // namespace tensorflow
96 #else // !defined(IS_MOBILE_PLATFORM)
97
98 #include <functional>
99 #include <map>
100 #include <memory>
101 #include <utility>
102
103 #include "tensorflow/core/framework/summary.pb.h"
104 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
105 #include "tensorflow/core/lib/monitoring/metric_def.h"
106 #include "tensorflow/core/lib/monitoring/types.h"
107 #include "tensorflow/core/platform/env.h"
108 #include "tensorflow/core/platform/logging.h"
109 #include "tensorflow/core/platform/macros.h"
110 #include "tensorflow/core/platform/mutex.h"
111 #include "tensorflow/core/platform/stringpiece.h"
112 #include "tensorflow/core/platform/thread_annotations.h"
113 #include "tensorflow/core/platform/types.h"
114
115 namespace tensorflow {
116 namespace monitoring {
117
// Forward declarations required by the friend declarations that follow.
// internal::Collector is defined later in this header.
// test_util::CollectionRegistryTestAccess is only forward-declared here.
namespace test_util { class CollectionRegistryTestAccess; }  // namespace test_util
namespace internal { class Collector; }  // namespace internal
125
126 // Metric implementations would get an instance of this class using the
127 // MetricCollectorGetter in the collection-function lambda, so that their values
128 // can be collected.
129 //
130 // Read the documentation on CollectionRegistry::Register() for more details.
131 //
132 // For example:
133 // auto metric_collector = metric_collector_getter->Get(&metric_def);
134 // metric_collector.CollectValue(some_labels, some_value);
135 // metric_collector.CollectValue(others_labels, other_value);
136 //
137 // This class is NOT thread-safe.
138 template <MetricKind metric_kind, typename Value, int NumLabels>
139 class MetricCollector {
140 public:
141 ~MetricCollector() = default;
142
143 // Collects the value with these labels.
144 void CollectValue(const std::array<std::string, NumLabels>& labels,
145 Value value);
146
147 private:
148 friend class internal::Collector;
149
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)150 MetricCollector(
151 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
152 const uint64 registration_time_millis,
153 internal::Collector* const collector, PointSet* const point_set)
154 : metric_def_(metric_def),
155 registration_time_millis_(registration_time_millis),
156 collector_(collector),
157 point_set_(point_set) {
158 point_set_->metric_name = std::string(metric_def->name());
159 }
160
161 const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
162 const uint64 registration_time_millis_;
163 internal::Collector* const collector_;
164 PointSet* const point_set_;
165
166 // This is made copyable because we can't hand out references of this class
167 // from MetricCollectorGetter because this class is templatized, and we need
168 // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
169 // can't own an instance of this class.
170 };
171
172 // Returns a MetricCollector with the same template parameters as the
173 // metric-definition, so that the values of a metric can be collected.
174 //
175 // The collection-function defined by a metric takes this as a parameter.
176 //
177 // Read the documentation on CollectionRegistry::Register() for more details.
178 class MetricCollectorGetter {
179 public:
180 // Returns the MetricCollector with the same template parameters as the
181 // metric_def.
182 template <MetricKind metric_kind, typename Value, int NumLabels>
183 MetricCollector<metric_kind, Value, NumLabels> Get(
184 const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
185
186 private:
187 friend class internal::Collector;
188
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)189 MetricCollectorGetter(internal::Collector* const collector,
190 const AbstractMetricDef* const allowed_metric_def,
191 const uint64 registration_time_millis)
192 : collector_(collector),
193 allowed_metric_def_(allowed_metric_def),
194 registration_time_millis_(registration_time_millis) {}
195
196 internal::Collector* const collector_;
197 const AbstractMetricDef* const allowed_metric_def_;
198 const uint64 registration_time_millis_;
199 };
200
201 // A collection registry for metrics.
202 //
203 // Metrics are registered here so that their state can be collected later and
204 // exported.
205 //
206 // This class is thread-safe.
207 class CollectionRegistry {
208 public:
209 ~CollectionRegistry() = default;
210
211 // Returns the default registry for the process.
212 //
213 // This registry belongs to this library and should never be deleted.
214 static CollectionRegistry* Default();
215
216 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
217
218 // Registers the metric and the collection-function which can be used to
219 // collect its values. Returns a Registration object, which when upon
220 // destruction would cause the metric to be unregistered from this registry.
221 //
222 // IMPORTANT: Delete the handle before the metric-def is deleted.
223 //
224 // Example usage;
225 // CollectionRegistry::Default()->Register(
226 // &metric_def,
227 // [&](MetricCollectorGetter getter) {
228 // auto metric_collector = getter.Get(&metric_def);
229 // for (const auto& cell : cells) {
230 // metric_collector.CollectValue(cell.labels(), cell.value());
231 // }
232 // });
233 class RegistrationHandle;
234 std::unique_ptr<RegistrationHandle> Register(
235 const AbstractMetricDef* metric_def,
236 const CollectionFunction& collection_function)
237 TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
238
239 // Options for collecting metrics.
240 struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions241 CollectMetricsOptions() {}
242 bool collect_metric_descriptors = true;
243 };
244 // Goes through all the registered metrics, collects their definitions
245 // (optionally) and current values and returns them in a standard format.
246 std::unique_ptr<CollectedMetrics> CollectMetrics(
247 const CollectMetricsOptions& options) const;
248
249 private:
250 friend class test_util::CollectionRegistryTestAccess;
251 friend class internal::Collector;
252
253 explicit CollectionRegistry(Env* env);
254
255 // Unregisters the metric from this registry. This is private because the
256 // public interface provides a Registration handle which automatically calls
257 // this upon destruction.
258 void Unregister(const AbstractMetricDef* metric_def) TF_LOCKS_EXCLUDED(mu_);
259
260 // TF environment, mainly used for timestamping.
261 Env* const env_;
262
263 mutable mutex mu_;
264
265 // Information required for collection.
266 struct CollectionInfo {
267 const AbstractMetricDef* const metric_def;
268 CollectionFunction collection_function;
269 uint64 registration_time_millis;
270 };
271 std::map<StringPiece, CollectionInfo> registry_ TF_GUARDED_BY(mu_);
272
273 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
274 };
275
276 ////
277 // Implementation details follow. API readers may skip.
278 ////
279
280 class CollectionRegistry::RegistrationHandle {
281 public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)282 RegistrationHandle(CollectionRegistry* const export_registry,
283 const AbstractMetricDef* const metric_def)
284 : export_registry_(export_registry), metric_def_(metric_def) {}
285
~RegistrationHandle()286 ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
287
288 private:
289 CollectionRegistry* const export_registry_;
290 const AbstractMetricDef* const metric_def_;
291 };
292
293 namespace internal {
294
295 template <typename Value>
296 void CollectValue(Value value, Point* point);
297
298 template <>
CollectValue(int64_t value,Point * const point)299 inline void CollectValue(int64_t value, Point* const point) {
300 point->value_type = ValueType::kInt64;
301 point->int64_value = value;
302 }
303
304 template <>
CollectValue(std::function<int64_t ()> value_fn,Point * const point)305 inline void CollectValue(std::function<int64_t()> value_fn,
306 Point* const point) {
307 point->value_type = ValueType::kInt64;
308 point->int64_value = value_fn();
309 }
310
311 template <>
CollectValue(std::string value,Point * const point)312 inline void CollectValue(std::string value, Point* const point) {
313 point->value_type = ValueType::kString;
314 point->string_value = std::move(value);
315 }
316
317 template <>
CollectValue(std::function<std::string ()> value_fn,Point * const point)318 inline void CollectValue(std::function<std::string()> value_fn,
319 Point* const point) {
320 point->value_type = ValueType::kString;
321 point->string_value = value_fn();
322 }
323
324 template <>
CollectValue(bool value,Point * const point)325 inline void CollectValue(bool value, Point* const point) {
326 point->value_type = ValueType::kBool;
327 point->bool_value = value;
328 }
329
330 template <>
CollectValue(std::function<bool ()> value_fn,Point * const point)331 inline void CollectValue(std::function<bool()> value_fn, Point* const point) {
332 point->value_type = ValueType::kBool;
333 point->bool_value = value_fn();
334 }
335
336 template <>
CollectValue(HistogramProto value,Point * const point)337 inline void CollectValue(HistogramProto value, Point* const point) {
338 point->value_type = ValueType::kHistogram;
339 // This is inefficient. If and when we hit snags, we can change the API to do
340 // this more efficiently.
341 point->histogram_value = std::move(value);
342 }
343
344 template <>
CollectValue(Percentiles value,Point * const point)345 inline void CollectValue(Percentiles value, Point* const point) {
346 point->value_type = ValueType::kPercentiles;
347 point->percentiles_value = std::move(value);
348 }
349
350 // Used by the CollectionRegistry class to collect all the values of all the
351 // metrics in the registry. This is an implementation detail of the
352 // CollectionRegistry class, please do not depend on this.
353 //
354 // This cannot be a private nested class because we need to forward declare this
355 // so that the MetricCollector and MetricCollectorGetter classes can be friends
356 // with it.
357 //
358 // This class is thread-safe.
359 class Collector {
360 public:
Collector(const uint64 collection_time_millis)361 explicit Collector(const uint64 collection_time_millis)
362 : collected_metrics_(new CollectedMetrics()),
363 collection_time_millis_(collection_time_millis) {}
364
365 template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)366 MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
367 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
368 const uint64 registration_time_millis,
369 internal::Collector* const collector) TF_LOCKS_EXCLUDED(mu_) {
370 auto* const point_set = [&]() {
371 mutex_lock l(mu_);
372 return collected_metrics_->point_set_map
373 .insert(std::make_pair(std::string(metric_def->name()),
374 std::unique_ptr<PointSet>(new PointSet())))
375 .first->second.get();
376 }();
377 return MetricCollector<metric_kind, Value, NumLabels>(
378 metric_def, registration_time_millis, collector, point_set);
379 }
380
collection_time_millis()381 uint64 collection_time_millis() const { return collection_time_millis_; }
382
383 void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
384 TF_LOCKS_EXCLUDED(mu_);
385
386 void CollectMetricValues(
387 const CollectionRegistry::CollectionInfo& collection_info);
388
389 std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
390 TF_LOCKS_EXCLUDED(mu_);
391
392 private:
393 mutable mutex mu_;
394 std::unique_ptr<CollectedMetrics> collected_metrics_ TF_GUARDED_BY(mu_);
395 const uint64 collection_time_millis_;
396
397 TF_DISALLOW_COPY_AND_ASSIGN(Collector);
398 };
399
400 // Write the timestamps for the point based on the MetricKind.
401 //
402 // Gauge metrics will have start and end timestamps set to the collection time.
403 //
404 // Cumulative metrics will have the start timestamp set to the time when the
405 // collection function was registered, while the end timestamp will be set to
406 // the collection time.
407 template <MetricKind kind>
408 void WriteTimestamps(const uint64 registration_time_millis,
409 const uint64 collection_time_millis, Point* const point);
410
411 template <>
412 inline void WriteTimestamps<MetricKind::kGauge>(
413 const uint64 registration_time_millis, const uint64 collection_time_millis,
414 Point* const point) {
415 point->start_timestamp_millis = collection_time_millis;
416 point->end_timestamp_millis = collection_time_millis;
417 }
418
419 template <>
420 inline void WriteTimestamps<MetricKind::kCumulative>(
421 const uint64 registration_time_millis, const uint64 collection_time_millis,
422 Point* const point) {
423 point->start_timestamp_millis = registration_time_millis;
424 // There's a chance that the clock goes backwards on the same machine, so we
425 // protect ourselves against that.
426 point->end_timestamp_millis =
427 registration_time_millis < collection_time_millis
428 ? collection_time_millis
429 : registration_time_millis;
430 }
431
432 } // namespace internal
433
434 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)435 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
436 const std::array<std::string, NumLabels>& labels, Value value) {
437 point_set_->points.emplace_back(new Point());
438 auto* const point = point_set_->points.back().get();
439 const std::vector<std::string> label_descriptions =
440 metric_def_->label_descriptions();
441 point->labels.reserve(NumLabels);
442 for (int i = 0; i < NumLabels; ++i) {
443 point->labels.push_back({});
444 auto* const label = &point->labels.back();
445 label->name = label_descriptions[i];
446 label->value = labels[i];
447 }
448 internal::CollectValue(std::move(value), point);
449 internal::WriteTimestamps<metric_kind>(
450 registration_time_millis_, collector_->collection_time_millis(), point);
451 }
452
453 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)454 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
455 const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
456 if (allowed_metric_def_ != metric_def) {
457 LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
458 << " but instead got: " << metric_def->name();
459 }
460
461 return collector_->GetMetricCollector(metric_def, registration_time_millis_,
462 collector_);
463 }
464
// Interface for metric exporters. Concrete exporters are registered with
// REGISTER_TF_METRICS_EXPORTER, which constructs them and calls
// PeriodicallyExportMetrics() once.
class Exporter {
 public:
  virtual ~Exporter() = default;

  // Presumably starts recurring export; the schedule is implementation-defined.
  virtual void PeriodicallyExportMetrics() = 0;

  // Presumably performs a single export; semantics are implementation-defined.
  virtual void ExportMetrics() = 0;
};
471
472 namespace exporter_registration {
473
474 class ExporterRegistration {
475 public:
ExporterRegistration(Exporter * exporter)476 explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) {
477 exporter_->PeriodicallyExportMetrics();
478 }
479
480 private:
481 Exporter* exporter_;
482 };
483
484 } // namespace exporter_registration
485
// Registers `exporter_class`, a default-constructible type convertible to
// Exporter*, by defining a static ExporterRegistration around a
// heap-allocated instance that is never freed.
//
// The _UNIQ_HELPER indirection makes __COUNTER__ expand to a number before it
// is pasted with ##. Each use therefore gets a distinct variable name.
#define REGISTER_TF_METRICS_EXPORTER(exporter_class) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter_class)

#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(counter, exporter_class) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ(counter, exporter_class)

#define REGISTER_TF_METRICS_EXPORTER_UNIQ(counter, exporter_class)             \
  static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \
      exporter_registration_##counter(new exporter_class())
495
496 } // namespace monitoring
497 } // namespace tensorflow
498
499 #endif // IS_MOBILE_PLATFORM
500 #endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
501