1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18
19 #include <map>
20 #include <memory>
21
22 #include "tensorflow/core/framework/summary.pb.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
25 #include "tensorflow/core/lib/monitoring/metric_def.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/core/platform/types.h"
32
33 namespace tensorflow {
34 namespace monitoring {
35
36 namespace test_util {
37 class CollectionRegistryTestAccess;
38 } // namespace test_util
39
40 namespace internal {
41 class Collector;
42 } // namespace internal
43
44 // Metric implementations would get an instance of this class using the
45 // MetricCollectorGetter in the collection-function lambda, so that their values
46 // can be collected.
47 //
48 // Read the documentation on CollectionRegistry::Register() for more details.
49 //
50 // For example:
51 // auto metric_collector = metric_collector_getter->Get(&metric_def);
52 // metric_collector.CollectValue(some_labels, some_value);
53 // metric_collector.CollectValue(others_labels, other_value);
54 //
55 // This class is NOT thread-safe.
56 template <MetricKind metric_kind, typename Value, int NumLabels>
57 class MetricCollector {
58 public:
59 ~MetricCollector() = default;
60
61 // Collects the value with these labels.
62 void CollectValue(const std::array<string, NumLabels>& labels,
63 const Value& value);
64
65 private:
66 friend class internal::Collector;
67
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)68 MetricCollector(
69 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
70 const uint64 registration_time_millis,
71 internal::Collector* const collector, PointSet* const point_set)
72 : metric_def_(metric_def),
73 registration_time_millis_(registration_time_millis),
74 collector_(collector),
75 point_set_(point_set) {
76 point_set_->metric_name = string(metric_def->name());
77 }
78
79 const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
80 const uint64 registration_time_millis_;
81 internal::Collector* const collector_;
82 PointSet* const point_set_;
83
84 // This is made copyable because we can't hand out references of this class
85 // from MetricCollectorGetter because this class is templatized, and we need
86 // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
87 // can't own an instance of this class.
88 };
89
90 // Returns a MetricCollector with the same template parameters as the
91 // metric-definition, so that the values of a metric can be collected.
92 //
93 // The collection-function defined by a metric takes this as a parameter.
94 //
95 // Read the documentation on CollectionRegistry::Register() for more details.
96 class MetricCollectorGetter {
97 public:
98 // Returns the MetricCollector with the same template parameters as the
99 // metric_def.
100 template <MetricKind metric_kind, typename Value, int NumLabels>
101 MetricCollector<metric_kind, Value, NumLabels> Get(
102 const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
103
104 private:
105 friend class internal::Collector;
106
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)107 MetricCollectorGetter(internal::Collector* const collector,
108 const AbstractMetricDef* const allowed_metric_def,
109 const uint64 registration_time_millis)
110 : collector_(collector),
111 allowed_metric_def_(allowed_metric_def),
112 registration_time_millis_(registration_time_millis) {}
113
114 internal::Collector* const collector_;
115 const AbstractMetricDef* const allowed_metric_def_;
116 const uint64 registration_time_millis_;
117 };
118
119 // A collection registry for metrics.
120 //
121 // Metrics are registered here so that their state can be collected later and
122 // exported.
123 //
124 // This class is thread-safe.
125 class CollectionRegistry {
126 public:
127 ~CollectionRegistry() = default;
128
129 // Returns the default registry for the process.
130 //
131 // This registry belongs to this library and should never be deleted.
132 static CollectionRegistry* Default();
133
134 using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
135
136 // Registers the metric and the collection-function which can be used to
137 // collect its values. Returns a Registration object, which when upon
138 // destruction would cause the metric to be unregistered from this registry.
139 //
140 // IMPORTANT: Delete the handle before the metric-def is deleted.
141 //
142 // Example usage;
143 // CollectionRegistry::Default()->Register(
144 // &metric_def,
145 // [&](MetricCollectorGetter getter) {
146 // auto metric_collector = getter.Get(&metric_def);
147 // for (const auto& cell : cells) {
148 // metric_collector.CollectValue(cell.labels(), cell.value());
149 // }
150 // });
151 class RegistrationHandle;
152 std::unique_ptr<RegistrationHandle> Register(
153 const AbstractMetricDef* metric_def,
154 const CollectionFunction& collection_function)
155 LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
156
157 // Options for collecting metrics.
158 struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions159 CollectMetricsOptions() {}
160 bool collect_metric_descriptors = true;
161 };
162 // Goes through all the registered metrics, collects their definitions
163 // (optionally) and current values and returns them in a standard format.
164 std::unique_ptr<CollectedMetrics> CollectMetrics(
165 const CollectMetricsOptions& options) const;
166
167 private:
168 friend class test_util::CollectionRegistryTestAccess;
169 friend class internal::Collector;
170
171 CollectionRegistry(Env* env);
172
173 // Unregisters the metric from this registry. This is private because the
174 // public interface provides a Registration handle which automatically calls
175 // this upon destruction.
176 void Unregister(const AbstractMetricDef* metric_def) LOCKS_EXCLUDED(mu_);
177
178 // TF environment, mainly used for timestamping.
179 Env* const env_;
180
181 mutable mutex mu_;
182
183 // Information required for collection.
184 struct CollectionInfo {
185 const AbstractMetricDef* const metric_def;
186 CollectionFunction collection_function;
187 uint64 registration_time_millis;
188 };
189 std::map<StringPiece, CollectionInfo> registry_ GUARDED_BY(mu_);
190
191 TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
192 };
193
194 ////
195 // Implementation details follow. API readers may skip.
196 ////
197
198 class CollectionRegistry::RegistrationHandle {
199 public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)200 RegistrationHandle(CollectionRegistry* const export_registry,
201 const AbstractMetricDef* const metric_def)
202 : export_registry_(export_registry), metric_def_(metric_def) {}
203
~RegistrationHandle()204 ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
205
206 private:
207 CollectionRegistry* const export_registry_;
208 const AbstractMetricDef* const metric_def_;
209 };
210
211 namespace internal {
212
213 template <typename Value>
214 void CollectValue(const Value& value, Point* point);
215
216 template <>
CollectValue(const int64 & value,Point * const point)217 inline void CollectValue(const int64& value, Point* const point) {
218 point->value_type = ValueType::kInt64;
219 point->int64_value = value;
220 }
221
222 template <>
CollectValue(const string & value,Point * const point)223 inline void CollectValue(const string& value, Point* const point) {
224 point->value_type = ValueType::kString;
225 point->string_value = value;
226 }
227
228 template <>
CollectValue(const bool & value,Point * const point)229 inline void CollectValue(const bool& value, Point* const point) {
230 point->value_type = ValueType::kBool;
231 point->bool_value = value;
232 }
233
234 template <>
CollectValue(const HistogramProto & value,Point * const point)235 inline void CollectValue(const HistogramProto& value, Point* const point) {
236 point->value_type = ValueType::kHistogram;
237 // This is inefficient. If and when we hit snags, we can change the API to do
238 // this more efficiently.
239 point->histogram_value = value;
240 }
241
242 // Used by the CollectionRegistry class to collect all the values of all the
243 // metrics in the registry. This is an implementation detail of the
244 // CollectionRegistry class, please do not depend on this.
245 //
246 // This cannot be a private nested class because we need to forward declare this
247 // so that the MetricCollector and MetricCollectorGetter classes can be friends
248 // with it.
249 //
250 // This class is thread-safe.
251 class Collector {
252 public:
Collector(const uint64 collection_time_millis)253 Collector(const uint64 collection_time_millis)
254 : collected_metrics_(new CollectedMetrics()),
255 collection_time_millis_(collection_time_millis) {}
256
257 template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)258 MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
259 const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
260 const uint64 registration_time_millis,
261 internal::Collector* const collector) LOCKS_EXCLUDED(mu_) {
262 auto* const point_set = [&]() {
263 mutex_lock l(mu_);
264 return collected_metrics_->point_set_map
265 .insert(std::make_pair(string(metric_def->name()),
266 std::unique_ptr<PointSet>(new PointSet())))
267 .first->second.get();
268 }();
269 return MetricCollector<metric_kind, Value, NumLabels>(
270 metric_def, registration_time_millis, collector, point_set);
271 }
272
collection_time_millis()273 uint64 collection_time_millis() const { return collection_time_millis_; }
274
275 void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
276 LOCKS_EXCLUDED(mu_);
277
278 void CollectMetricValues(
279 const CollectionRegistry::CollectionInfo& collection_info);
280
281 std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
282 LOCKS_EXCLUDED(mu_);
283
284 private:
285 mutable mutex mu_;
286 std::unique_ptr<CollectedMetrics> collected_metrics_ GUARDED_BY(mu_);
287 const uint64 collection_time_millis_;
288
289 TF_DISALLOW_COPY_AND_ASSIGN(Collector);
290 };
291
292 // Write the timestamps for the point based on the MetricKind.
293 //
294 // Gauge metrics will have start and end timestamps set to the collection time.
295 //
296 // Cumulative metrics will have the start timestamp set to the time when the
297 // collection function was registered, while the end timestamp will be set to
298 // the collection time.
299 template <MetricKind kind>
300 void WriteTimestamps(const uint64 registration_time_millis,
301 const uint64 collection_time_millis, Point* const point);
302
303 template <>
304 inline void WriteTimestamps<MetricKind::kGauge>(
305 const uint64 registration_time_millis, const uint64 collection_time_millis,
306 Point* const point) {
307 point->start_timestamp_millis = collection_time_millis;
308 point->end_timestamp_millis = collection_time_millis;
309 }
310
311 template <>
312 inline void WriteTimestamps<MetricKind::kCumulative>(
313 const uint64 registration_time_millis, const uint64 collection_time_millis,
314 Point* const point) {
315 point->start_timestamp_millis = registration_time_millis;
316 // There's a chance that the clock goes backwards on the same machine, so we
317 // protect ourselves against that.
318 point->end_timestamp_millis =
319 registration_time_millis < collection_time_millis
320 ? collection_time_millis
321 : registration_time_millis;
322 }
323
324 } // namespace internal
325
326 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<string,NumLabels> & labels,const Value & value)327 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
328 const std::array<string, NumLabels>& labels, const Value& value) {
329 point_set_->points.emplace_back(new Point());
330 auto* const point = point_set_->points.back().get();
331 const std::vector<string> label_descriptions =
332 metric_def_->label_descriptions();
333 point->labels.reserve(NumLabels);
334 for (int i = 0; i < NumLabels; ++i) {
335 point->labels.push_back({});
336 auto* const label = &point->labels.back();
337 label->name = label_descriptions[i];
338 label->value = labels[i];
339 }
340 internal::CollectValue(value, point);
341 internal::WriteTimestamps<metric_kind>(
342 registration_time_millis_, collector_->collection_time_millis(), point);
343 }
344
345 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)346 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
347 const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
348 if (allowed_metric_def_ != metric_def) {
349 LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
350 << " but instead got: " << metric_def->name();
351 }
352
353 return collector_->GetMetricCollector(metric_def, registration_time_millis_,
354 collector_);
355 }
356
357 } // namespace monitoring
358 } // namespace tensorflow
359
360 #endif // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
361