1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23
24 // We use a null implementation for mobile platforms.
25 #ifdef IS_MOBILE_PLATFORM
26
27 #include <functional>
28 #include <map>
29 #include <memory>
30
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/macros.h"
33
34 namespace tensorflow {
35 namespace monitoring {
36
37 // MetricCollector which has a null implementation.
38 template <MetricKind metric_kind, typename Value, int NumLabels>
39 class MetricCollector {
40 public:
41 ~MetricCollector() = default;
42
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)43 void CollectValue(const std::array<std::string, NumLabels>& labels,
44 Value value) {}
45
46 private:
47 friend class MetricCollectorGetter;
48
MetricCollector()49 MetricCollector() {}
50 };
51
52 // MetricCollectorGetter which has a null implementation.
53 class MetricCollectorGetter {
54 public:
55 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)56 MetricCollector<metric_kind, Value, NumLabels> Get(
57 const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
58 return MetricCollector<metric_kind, Value, NumLabels>();
59 }
60
61 private:
MetricCollectorGetter()62 MetricCollectorGetter() {}
63 };
64
65 // CollectionRegistry which has a null implementation.
66 class CollectionRegistry {
67 public:
68 ~CollectionRegistry() = default;
69
Default()70 static CollectionRegistry* Default() { return new CollectionRegistry(); }
71
72 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
73
74 // RegistrationHandle which has a null implementation.
75 class RegistrationHandle {
76 public:
RegistrationHandle()77 RegistrationHandle() {}
78
~RegistrationHandle()79 ~RegistrationHandle() {}
80 };
81
Register(const AbstractMetricDef * metric_def,const CollectionFunction & collection_function)82 std::unique_ptr<RegistrationHandle> Register(
83 const AbstractMetricDef* metric_def,
84 const CollectionFunction& collection_function) {
85 return std::unique_ptr<RegistrationHandle>(new RegistrationHandle());
86 }
87
88 private:
CollectionRegistry()89 CollectionRegistry() {}
90
91 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
92 };
93
94 } // namespace monitoring
95 } // namespace tensorflow
96 #else // !defined(IS_MOBILE_PLATFORM)
97
98 #include <functional>
99 #include <map>
100 #include <memory>
101 #include <utility>
102
103 #include "tensorflow/core/framework/summary.pb.h"
104 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
105 #include "tensorflow/core/lib/monitoring/metric_def.h"
106 #include "tensorflow/core/lib/monitoring/types.h"
107 #include "tensorflow/core/platform/env.h"
108 #include "tensorflow/core/platform/logging.h"
109 #include "tensorflow/core/platform/macros.h"
110 #include "tensorflow/core/platform/mutex.h"
111 #include "tensorflow/core/platform/stringpiece.h"
112 #include "tensorflow/core/platform/thread_annotations.h"
113 #include "tensorflow/core/platform/types.h"
114
115 namespace tensorflow {
116 namespace monitoring {
117
118 namespace test_util {
119 class CollectionRegistryTestAccess;
120 } // namespace test_util
121
122 namespace internal {
123 class Collector;
124 } // namespace internal
125
126 // Metric implementations would get an instance of this class using the
127 // MetricCollectorGetter in the collection-function lambda, so that their values
128 // can be collected.
129 //
130 // Read the documentation on CollectionRegistry::Register() for more details.
131 //
132 // For example:
133 // auto metric_collector = metric_collector_getter->Get(&metric_def);
134 // metric_collector.CollectValue(some_labels, some_value);
135 // metric_collector.CollectValue(others_labels, other_value);
136 //
137 // This class is NOT thread-safe.
138 template <MetricKind metric_kind, typename Value, int NumLabels>
139 class MetricCollector {
140 public:
141 ~MetricCollector() = default;
142
143 // Collects the value with these labels.
144 void CollectValue(const std::array<std::string, NumLabels>& labels,
145 Value value);
146
147 private:
148 friend class internal::Collector;
149
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)150 MetricCollector(
151 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
152 const uint64 registration_time_millis,
153 internal::Collector* const collector, PointSet* const point_set)
154 : metric_def_(metric_def),
155 registration_time_millis_(registration_time_millis),
156 collector_(collector),
157 point_set_(point_set) {
158 point_set_->metric_name = std::string(metric_def->name());
159 }
160
161 const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
162 const uint64 registration_time_millis_;
163 internal::Collector* const collector_;
164 PointSet* const point_set_;
165
166 // This is made copyable because we can't hand out references of this class
167 // from MetricCollectorGetter because this class is templatized, and we need
168 // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
169 // can't own an instance of this class.
170 };
171
172 // Returns a MetricCollector with the same template parameters as the
173 // metric-definition, so that the values of a metric can be collected.
174 //
175 // The collection-function defined by a metric takes this as a parameter.
176 //
177 // Read the documentation on CollectionRegistry::Register() for more details.
178 class MetricCollectorGetter {
179 public:
180 // Returns the MetricCollector with the same template parameters as the
181 // metric_def.
182 template <MetricKind metric_kind, typename Value, int NumLabels>
183 MetricCollector<metric_kind, Value, NumLabels> Get(
184 const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
185
186 private:
187 friend class internal::Collector;
188
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)189 MetricCollectorGetter(internal::Collector* const collector,
190 const AbstractMetricDef* const allowed_metric_def,
191 const uint64 registration_time_millis)
192 : collector_(collector),
193 allowed_metric_def_(allowed_metric_def),
194 registration_time_millis_(registration_time_millis) {}
195
196 internal::Collector* const collector_;
197 const AbstractMetricDef* const allowed_metric_def_;
198 const uint64 registration_time_millis_;
199 };
200
201 // A collection registry for metrics.
202 //
203 // Metrics are registered here so that their state can be collected later and
204 // exported.
205 //
206 // This class is thread-safe.
207 class CollectionRegistry {
208 public:
209 ~CollectionRegistry() = default;
210
211 // Returns the default registry for the process.
212 //
213 // This registry belongs to this library and should never be deleted.
214 static CollectionRegistry* Default();
215
216 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
217
218 // Registers the metric and the collection-function which can be used to
219 // collect its values. Returns a Registration object, which when upon
220 // destruction would cause the metric to be unregistered from this registry.
221 //
222 // IMPORTANT: Delete the handle before the metric-def is deleted.
223 //
224 // Example usage;
225 // CollectionRegistry::Default()->Register(
226 // &metric_def,
227 // [&](MetricCollectorGetter getter) {
228 // auto metric_collector = getter.Get(&metric_def);
229 // for (const auto& cell : cells) {
230 // metric_collector.CollectValue(cell.labels(), cell.value());
231 // }
232 // });
233 class RegistrationHandle;
234 std::unique_ptr<RegistrationHandle> Register(
235 const AbstractMetricDef* metric_def,
236 const CollectionFunction& collection_function)
237 TF_LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
238
239 // Options for collecting metrics.
240 struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions241 CollectMetricsOptions() {}
242 bool collect_metric_descriptors = true;
243 };
244 // Goes through all the registered metrics, collects their definitions
245 // (optionally) and current values and returns them in a standard format.
246 std::unique_ptr<CollectedMetrics> CollectMetrics(
247 const CollectMetricsOptions& options) const;
248
249 private:
250 friend class test_util::CollectionRegistryTestAccess;
251 friend class internal::Collector;
252
253 explicit CollectionRegistry(Env* env);
254
255 // Unregisters the metric from this registry. This is private because the
256 // public interface provides a Registration handle which automatically calls
257 // this upon destruction.
258 void Unregister(const AbstractMetricDef* metric_def) TF_LOCKS_EXCLUDED(mu_);
259
260 // TF environment, mainly used for timestamping.
261 Env* const env_;
262
263 mutable mutex mu_;
264
265 // Information required for collection.
266 struct CollectionInfo {
267 const AbstractMetricDef* const metric_def;
268 CollectionFunction collection_function;
269 uint64 registration_time_millis;
270 };
271 std::map<StringPiece, CollectionInfo> registry_ TF_GUARDED_BY(mu_);
272
273 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
274 };
275
276 ////
277 // Implementation details follow. API readers may skip.
278 ////
279
280 class CollectionRegistry::RegistrationHandle {
281 public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)282 RegistrationHandle(CollectionRegistry* const export_registry,
283 const AbstractMetricDef* const metric_def)
284 : export_registry_(export_registry), metric_def_(metric_def) {}
285
~RegistrationHandle()286 ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
287
288 private:
289 CollectionRegistry* const export_registry_;
290 const AbstractMetricDef* const metric_def_;
291 };
292
293 namespace internal {
294
295 template <typename Value>
296 void CollectValue(Value value, Point* point);
297
298 template <>
CollectValue(int64_t value,Point * const point)299 inline void CollectValue(int64_t value, Point* const point) {
300 point->value_type = ValueType::kInt64;
301 point->int64_value = value;
302 }
303
304 template <>
CollectValue(std::function<int64 ()> value_fn,Point * const point)305 inline void CollectValue(std::function<int64()> value_fn, Point* const point) {
306 point->value_type = ValueType::kInt64;
307 point->int64_value = value_fn();
308 }
309
310 template <>
CollectValue(std::string value,Point * const point)311 inline void CollectValue(std::string value, Point* const point) {
312 point->value_type = ValueType::kString;
313 point->string_value = std::move(value);
314 }
315
316 template <>
CollectValue(std::function<std::string ()> value_fn,Point * const point)317 inline void CollectValue(std::function<std::string()> value_fn,
318 Point* const point) {
319 point->value_type = ValueType::kString;
320 point->string_value = value_fn();
321 }
322
323 template <>
CollectValue(bool value,Point * const point)324 inline void CollectValue(bool value, Point* const point) {
325 point->value_type = ValueType::kBool;
326 point->bool_value = value;
327 }
328
329 template <>
CollectValue(std::function<bool ()> value_fn,Point * const point)330 inline void CollectValue(std::function<bool()> value_fn, Point* const point) {
331 point->value_type = ValueType::kBool;
332 point->bool_value = value_fn();
333 }
334
335 template <>
CollectValue(HistogramProto value,Point * const point)336 inline void CollectValue(HistogramProto value, Point* const point) {
337 point->value_type = ValueType::kHistogram;
338 // This is inefficient. If and when we hit snags, we can change the API to do
339 // this more efficiently.
340 point->histogram_value = std::move(value);
341 }
342
343 template <>
CollectValue(Percentiles value,Point * const point)344 inline void CollectValue(Percentiles value, Point* const point) {
345 point->value_type = ValueType::kPercentiles;
346 point->percentiles_value = std::move(value);
347 }
348
349 // Used by the CollectionRegistry class to collect all the values of all the
350 // metrics in the registry. This is an implementation detail of the
351 // CollectionRegistry class, please do not depend on this.
352 //
353 // This cannot be a private nested class because we need to forward declare this
354 // so that the MetricCollector and MetricCollectorGetter classes can be friends
355 // with it.
356 //
357 // This class is thread-safe.
358 class Collector {
359 public:
Collector(const uint64 collection_time_millis)360 explicit Collector(const uint64 collection_time_millis)
361 : collected_metrics_(new CollectedMetrics()),
362 collection_time_millis_(collection_time_millis) {}
363
364 template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)365 MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
366 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
367 const uint64 registration_time_millis,
368 internal::Collector* const collector) TF_LOCKS_EXCLUDED(mu_) {
369 auto* const point_set = [&]() {
370 mutex_lock l(mu_);
371 return collected_metrics_->point_set_map
372 .insert(std::make_pair(std::string(metric_def->name()),
373 std::unique_ptr<PointSet>(new PointSet())))
374 .first->second.get();
375 }();
376 return MetricCollector<metric_kind, Value, NumLabels>(
377 metric_def, registration_time_millis, collector, point_set);
378 }
379
collection_time_millis()380 uint64 collection_time_millis() const { return collection_time_millis_; }
381
382 void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
383 TF_LOCKS_EXCLUDED(mu_);
384
385 void CollectMetricValues(
386 const CollectionRegistry::CollectionInfo& collection_info);
387
388 std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
389 TF_LOCKS_EXCLUDED(mu_);
390
391 private:
392 mutable mutex mu_;
393 std::unique_ptr<CollectedMetrics> collected_metrics_ TF_GUARDED_BY(mu_);
394 const uint64 collection_time_millis_;
395
396 TF_DISALLOW_COPY_AND_ASSIGN(Collector);
397 };
398
399 // Write the timestamps for the point based on the MetricKind.
400 //
401 // Gauge metrics will have start and end timestamps set to the collection time.
402 //
403 // Cumulative metrics will have the start timestamp set to the time when the
404 // collection function was registered, while the end timestamp will be set to
405 // the collection time.
406 template <MetricKind kind>
407 void WriteTimestamps(const uint64 registration_time_millis,
408 const uint64 collection_time_millis, Point* const point);
409
410 template <>
411 inline void WriteTimestamps<MetricKind::kGauge>(
412 const uint64 registration_time_millis, const uint64 collection_time_millis,
413 Point* const point) {
414 point->start_timestamp_millis = collection_time_millis;
415 point->end_timestamp_millis = collection_time_millis;
416 }
417
418 template <>
419 inline void WriteTimestamps<MetricKind::kCumulative>(
420 const uint64 registration_time_millis, const uint64 collection_time_millis,
421 Point* const point) {
422 point->start_timestamp_millis = registration_time_millis;
423 // There's a chance that the clock goes backwards on the same machine, so we
424 // protect ourselves against that.
425 point->end_timestamp_millis =
426 registration_time_millis < collection_time_millis
427 ? collection_time_millis
428 : registration_time_millis;
429 }
430
431 } // namespace internal
432
433 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<std::string,NumLabels> & labels,Value value)434 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
435 const std::array<std::string, NumLabels>& labels, Value value) {
436 point_set_->points.emplace_back(new Point());
437 auto* const point = point_set_->points.back().get();
438 const std::vector<std::string> label_descriptions =
439 metric_def_->label_descriptions();
440 point->labels.reserve(NumLabels);
441 for (int i = 0; i < NumLabels; ++i) {
442 point->labels.push_back({});
443 auto* const label = &point->labels.back();
444 label->name = label_descriptions[i];
445 label->value = labels[i];
446 }
447 internal::CollectValue(std::move(value), point);
448 internal::WriteTimestamps<metric_kind>(
449 registration_time_millis_, collector_->collection_time_millis(), point);
450 }
451
452 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)453 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
454 const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
455 if (allowed_metric_def_ != metric_def) {
456 LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
457 << " but instead got: " << metric_def->name();
458 }
459
460 return collector_->GetMetricCollector(metric_def, registration_time_millis_,
461 collector_);
462 }
463
// Abstract interface implemented by metric exporters.
//
// PeriodicallyExportMetrics() is called once on an exporter when it is
// wrapped in an exporter_registration::ExporterRegistration.
class Exporter {
 public:
  virtual ~Exporter() = default;

  virtual void PeriodicallyExportMetrics() = 0;

  virtual void ExportMetrics() = 0;
};
470
471 namespace exporter_registration {
472
473 class ExporterRegistration {
474 public:
ExporterRegistration(Exporter * exporter)475 explicit ExporterRegistration(Exporter* exporter) : exporter_(exporter) {
476 exporter_->PeriodicallyExportMetrics();
477 }
478
479 private:
480 Exporter* exporter_;
481 };
482
483 } // namespace exporter_registration
484
// Defines a static ExporterRegistration named exporter_registration_<id> that
// wraps a heap-allocated, default-constructed instance of `exporter_type`.
#define REGISTER_TF_METRICS_EXPORTER_UNIQ(id, exporter_type)                   \
  static ::tensorflow::monitoring::exporter_registration::ExporterRegistration \
      exporter_registration_##id(new exporter_type())

// Extra level of indirection so that `id` (normally __COUNTER__) is fully
// expanded before it is pasted into the variable name above.
#define REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(id, exporter_type) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ(id, exporter_type)

// Public entry point: registers the exporter class `exporter_type`, using
// __COUNTER__ to give each registration a distinct variable name.
#define REGISTER_TF_METRICS_EXPORTER(exporter_type) \
  REGISTER_TF_METRICS_EXPORTER_UNIQ_HELPER(__COUNTER__, exporter_type)
494
495 } // namespace monitoring
496 } // namespace tensorflow
497
498 #endif // IS_MOBILE_PLATFORM
499 #endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
500