• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
18 
19 #include <map>
20 #include <memory>
21 
22 #include "tensorflow/core/framework/summary.pb.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/monitoring/collected_metrics.h"
25 #include "tensorflow/core/lib/monitoring/metric_def.h"
26 #include "tensorflow/core/platform/env.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/macros.h"
29 #include "tensorflow/core/platform/mutex.h"
30 #include "tensorflow/core/platform/thread_annotations.h"
31 #include "tensorflow/core/platform/types.h"
32 
33 namespace tensorflow {
34 namespace monitoring {
35 
36 namespace test_util {
37 class CollectionRegistryTestAccess;
38 }  // namespace test_util
39 
40 namespace internal {
41 class Collector;
42 }  // namespace internal
43 
44 // Metric implementations would get an instance of this class using the
45 // MetricCollectorGetter in the collection-function lambda, so that their values
46 // can be collected.
47 //
48 // Read the documentation on CollectionRegistry::Register() for more details.
49 //
50 // For example:
51 //   auto metric_collector = metric_collector_getter->Get(&metric_def);
52 //   metric_collector.CollectValue(some_labels, some_value);
53 //   metric_collector.CollectValue(others_labels, other_value);
54 //
55 // This class is NOT thread-safe.
56 template <MetricKind metric_kind, typename Value, int NumLabels>
57 class MetricCollector {
58  public:
59   ~MetricCollector() = default;
60 
61   // Collects the value with these labels.
62   void CollectValue(const std::array<string, NumLabels>& labels,
63                     const Value& value);
64 
65  private:
66   friend class internal::Collector;
67 
MetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector,PointSet * const point_set)68   MetricCollector(
69       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
70       const uint64 registration_time_millis,
71       internal::Collector* const collector, PointSet* const point_set)
72       : metric_def_(metric_def),
73         registration_time_millis_(registration_time_millis),
74         collector_(collector),
75         point_set_(point_set) {
76     point_set_->metric_name = string(metric_def->name());
77   }
78 
79   const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
80   const uint64 registration_time_millis_;
81   internal::Collector* const collector_;
82   PointSet* const point_set_;
83 
84   // This is made copyable because we can't hand out references of this class
85   // from MetricCollectorGetter because this class is templatized, and we need
86   // MetricCollectorGetter not to be templatized and hence MetricCollectorGetter
87   // can't own an instance of this class.
88 };
89 
90 // Returns a MetricCollector with the same template parameters as the
91 // metric-definition, so that the values of a metric can be collected.
92 //
93 // The collection-function defined by a metric takes this as a parameter.
94 //
95 // Read the documentation on CollectionRegistry::Register() for more details.
96 class MetricCollectorGetter {
97  public:
98   // Returns the MetricCollector with the same template parameters as the
99   // metric_def.
100   template <MetricKind metric_kind, typename Value, int NumLabels>
101   MetricCollector<metric_kind, Value, NumLabels> Get(
102       const MetricDef<metric_kind, Value, NumLabels>* const metric_def);
103 
104  private:
105   friend class internal::Collector;
106 
MetricCollectorGetter(internal::Collector * const collector,const AbstractMetricDef * const allowed_metric_def,const uint64 registration_time_millis)107   MetricCollectorGetter(internal::Collector* const collector,
108                         const AbstractMetricDef* const allowed_metric_def,
109                         const uint64 registration_time_millis)
110       : collector_(collector),
111         allowed_metric_def_(allowed_metric_def),
112         registration_time_millis_(registration_time_millis) {}
113 
114   internal::Collector* const collector_;
115   const AbstractMetricDef* const allowed_metric_def_;
116   const uint64 registration_time_millis_;
117 };
118 
119 // A collection registry for metrics.
120 //
121 // Metrics are registered here so that their state can be collected later and
122 // exported.
123 //
124 // This class is thread-safe.
125 class CollectionRegistry {
126  public:
127   ~CollectionRegistry() = default;
128 
129   // Returns the default registry for the process.
130   //
131   // This registry belongs to this library and should never be deleted.
132   static CollectionRegistry* Default();
133 
134   using CollectionFunction = std::function<void(MetricCollectorGetter getter)>;
135 
136   // Registers the metric and the collection-function which can be used to
137   // collect its values. Returns a Registration object, which when upon
138   // destruction would cause the metric to be unregistered from this registry.
139   //
140   // IMPORTANT: Delete the handle before the metric-def is deleted.
141   //
142   // Example usage;
143   // CollectionRegistry::Default()->Register(
144   //   &metric_def,
145   //   [&](MetricCollectorGetter getter) {
146   //     auto metric_collector = getter.Get(&metric_def);
147   //     for (const auto& cell : cells) {
148   //       metric_collector.CollectValue(cell.labels(), cell.value());
149   //     }
150   //   });
151   class RegistrationHandle;
152   std::unique_ptr<RegistrationHandle> Register(
153       const AbstractMetricDef* metric_def,
154       const CollectionFunction& collection_function)
155       LOCKS_EXCLUDED(mu_) TF_MUST_USE_RESULT;
156 
157   // Options for collecting metrics.
158   struct CollectMetricsOptions {
CollectMetricsOptionsCollectMetricsOptions159     CollectMetricsOptions() {}
160     bool collect_metric_descriptors = true;
161   };
162   // Goes through all the registered metrics, collects their definitions
163   // (optionally) and current values and returns them in a standard format.
164   std::unique_ptr<CollectedMetrics> CollectMetrics(
165       const CollectMetricsOptions& options) const;
166 
167  private:
168   friend class test_util::CollectionRegistryTestAccess;
169   friend class internal::Collector;
170 
171   CollectionRegistry(Env* env);
172 
173   // Unregisters the metric from this registry. This is private because the
174   // public interface provides a Registration handle which automatically calls
175   // this upon destruction.
176   void Unregister(const AbstractMetricDef* metric_def) LOCKS_EXCLUDED(mu_);
177 
178   // TF environment, mainly used for timestamping.
179   Env* const env_;
180 
181   mutable mutex mu_;
182 
183   // Information required for collection.
184   struct CollectionInfo {
185     const AbstractMetricDef* const metric_def;
186     CollectionFunction collection_function;
187     uint64 registration_time_millis;
188   };
189   std::map<StringPiece, CollectionInfo> registry_ GUARDED_BY(mu_);
190 
191   TF_DISALLOW_COPY_AND_ASSIGN(CollectionRegistry);
192 };
193 
194 ////
195 // Implementation details follow. API readers may skip.
196 ////
197 
198 class CollectionRegistry::RegistrationHandle {
199  public:
RegistrationHandle(CollectionRegistry * const export_registry,const AbstractMetricDef * const metric_def)200   RegistrationHandle(CollectionRegistry* const export_registry,
201                      const AbstractMetricDef* const metric_def)
202       : export_registry_(export_registry), metric_def_(metric_def) {}
203 
~RegistrationHandle()204   ~RegistrationHandle() { export_registry_->Unregister(metric_def_); }
205 
206  private:
207   CollectionRegistry* const export_registry_;
208   const AbstractMetricDef* const metric_def_;
209 };
210 
211 namespace internal {
212 
213 template <typename Value>
214 void CollectValue(const Value& value, Point* point);
215 
216 template <>
CollectValue(const int64 & value,Point * const point)217 inline void CollectValue(const int64& value, Point* const point) {
218   point->value_type = ValueType::kInt64;
219   point->int64_value = value;
220 }
221 
222 template <>
CollectValue(const string & value,Point * const point)223 inline void CollectValue(const string& value, Point* const point) {
224   point->value_type = ValueType::kString;
225   point->string_value = value;
226 }
227 
228 template <>
CollectValue(const bool & value,Point * const point)229 inline void CollectValue(const bool& value, Point* const point) {
230   point->value_type = ValueType::kBool;
231   point->bool_value = value;
232 }
233 
234 template <>
CollectValue(const HistogramProto & value,Point * const point)235 inline void CollectValue(const HistogramProto& value, Point* const point) {
236   point->value_type = ValueType::kHistogram;
237   // This is inefficient. If and when we hit snags, we can change the API to do
238   // this more efficiently.
239   point->histogram_value = value;
240 }
241 
242 // Used by the CollectionRegistry class to collect all the values of all the
243 // metrics in the registry. This is an implementation detail of the
244 // CollectionRegistry class, please do not depend on this.
245 //
246 // This cannot be a private nested class because we need to forward declare this
247 // so that the MetricCollector and MetricCollectorGetter classes can be friends
248 // with it.
249 //
250 // This class is thread-safe.
251 class Collector {
252  public:
Collector(const uint64 collection_time_millis)253   Collector(const uint64 collection_time_millis)
254       : collected_metrics_(new CollectedMetrics()),
255         collection_time_millis_(collection_time_millis) {}
256 
257   template <MetricKind metric_kind, typename Value, int NumLabels>
GetMetricCollector(const MetricDef<metric_kind,Value,NumLabels> * const metric_def,const uint64 registration_time_millis,internal::Collector * const collector)258   MetricCollector<metric_kind, Value, NumLabels> GetMetricCollector(
259       const MetricDef<metric_kind, Value, NumLabels>* const metric_def,
260       const uint64 registration_time_millis,
261       internal::Collector* const collector) LOCKS_EXCLUDED(mu_) {
262     auto* const point_set = [&]() {
263       mutex_lock l(mu_);
264       return collected_metrics_->point_set_map
265           .insert(std::make_pair(string(metric_def->name()),
266                                  std::unique_ptr<PointSet>(new PointSet())))
267           .first->second.get();
268     }();
269     return MetricCollector<metric_kind, Value, NumLabels>(
270         metric_def, registration_time_millis, collector, point_set);
271   }
272 
collection_time_millis()273   uint64 collection_time_millis() const { return collection_time_millis_; }
274 
275   void CollectMetricDescriptor(const AbstractMetricDef* const metric_def)
276       LOCKS_EXCLUDED(mu_);
277 
278   void CollectMetricValues(
279       const CollectionRegistry::CollectionInfo& collection_info);
280 
281   std::unique_ptr<CollectedMetrics> ConsumeCollectedMetrics()
282       LOCKS_EXCLUDED(mu_);
283 
284  private:
285   mutable mutex mu_;
286   std::unique_ptr<CollectedMetrics> collected_metrics_ GUARDED_BY(mu_);
287   const uint64 collection_time_millis_;
288 
289   TF_DISALLOW_COPY_AND_ASSIGN(Collector);
290 };
291 
292 // Write the timestamps for the point based on the MetricKind.
293 //
294 // Gauge metrics will have start and end timestamps set to the collection time.
295 //
296 // Cumulative metrics will have the start timestamp set to the time when the
297 // collection function was registered, while the end timestamp will be set to
298 // the collection time.
299 template <MetricKind kind>
300 void WriteTimestamps(const uint64 registration_time_millis,
301                      const uint64 collection_time_millis, Point* const point);
302 
303 template <>
304 inline void WriteTimestamps<MetricKind::kGauge>(
305     const uint64 registration_time_millis, const uint64 collection_time_millis,
306     Point* const point) {
307   point->start_timestamp_millis = collection_time_millis;
308   point->end_timestamp_millis = collection_time_millis;
309 }
310 
311 template <>
312 inline void WriteTimestamps<MetricKind::kCumulative>(
313     const uint64 registration_time_millis, const uint64 collection_time_millis,
314     Point* const point) {
315   point->start_timestamp_millis = registration_time_millis;
316   // There's a chance that the clock goes backwards on the same machine, so we
317   // protect ourselves against that.
318   point->end_timestamp_millis =
319       registration_time_millis < collection_time_millis
320           ? collection_time_millis
321           : registration_time_millis;
322 }
323 
324 }  // namespace internal
325 
326 template <MetricKind metric_kind, typename Value, int NumLabels>
CollectValue(const std::array<string,NumLabels> & labels,const Value & value)327 void MetricCollector<metric_kind, Value, NumLabels>::CollectValue(
328     const std::array<string, NumLabels>& labels, const Value& value) {
329   point_set_->points.emplace_back(new Point());
330   auto* const point = point_set_->points.back().get();
331   const std::vector<string> label_descriptions =
332       metric_def_->label_descriptions();
333   point->labels.reserve(NumLabels);
334   for (int i = 0; i < NumLabels; ++i) {
335     point->labels.push_back({});
336     auto* const label = &point->labels.back();
337     label->name = label_descriptions[i];
338     label->value = labels[i];
339   }
340   internal::CollectValue(value, point);
341   internal::WriteTimestamps<metric_kind>(
342       registration_time_millis_, collector_->collection_time_millis(), point);
343 }
344 
345 template <MetricKind metric_kind, typename Value, int NumLabels>
Get(const MetricDef<metric_kind,Value,NumLabels> * const metric_def)346 MetricCollector<metric_kind, Value, NumLabels> MetricCollectorGetter::Get(
347     const MetricDef<metric_kind, Value, NumLabels>* const metric_def) {
348   if (allowed_metric_def_ != metric_def) {
349     LOG(FATAL) << "Expected collection for: " << allowed_metric_def_->name()
350                << " but instead got: " << metric_def->name();
351   }
352 
353   return collector_->GetMetricCollector(metric_def, registration_time_millis_,
354                                         collector_);
355 }
356 
357 }  // namespace monitoring
358 }  // namespace tensorflow
359 
360 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COLLECTION_REGISTRY_H_
361