1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18
19 // We replace this implementation with a null implementation for mobile
20 // platforms.
21 #include "tensorflow/core/platform/platform.h"
22 #ifdef IS_MOBILE_PLATFORM
23 #include "tensorflow/core/lib/monitoring/mobile_counter.h"
24 #else
25
26 #include <array>
27 #include <atomic>
28 #include <map>
29
30 #include "tensorflow/core/lib/monitoring/collection_registry.h"
31 #include "tensorflow/core/lib/monitoring/metric_def.h"
32 #include "tensorflow/core/platform/logging.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/mutex.h"
35 #include "tensorflow/core/platform/thread_annotations.h"
36
37 namespace tensorflow {
38 namespace monitoring {
39
40 // CounterCell stores each value of an Counter.
41 //
42 // A cell can be passed off to a module which may repeatedly update it without
43 // needing further map-indexing computations. This improves both encapsulation
44 // (separate modules can own a cell each, without needing to know about the map
45 // to which both cells belong) and performance (since map indexing and
46 // associated locking are both avoided).
47 //
48 // This class is thread-safe.
49 class CounterCell {
50 public:
CounterCell(int64 value)51 CounterCell(int64 value) : value_(value) {}
~CounterCell()52 ~CounterCell() {}
53
54 // Atomically increments the value by step.
55 // REQUIRES: Step be non-negative.
56 void IncrementBy(int64 step);
57
58 // Retrieves the current value.
59 int64 value() const;
60
61 private:
62 std::atomic<int64> value_;
63
64 TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
65 };
66
67 // A stateful class for updating a cumulative integer metric.
68 //
69 // This class encapsulates a set of values (or a single value for a label-less
70 // metric). Each value is identified by a tuple of labels. The class allows the
71 // user to increment each value.
72 //
73 // Counter allocates storage and maintains a cell for each value. You can
74 // retrieve an individual cell using a label-tuple and update it separately.
75 // This improves performance since operations related to retrieval, like
76 // map-indexing and locking, are avoided.
77 //
78 // This class is thread-safe.
79 template <int NumLabels>
80 class Counter {
81 public:
~Counter()82 ~Counter() {
83 // Deleted here, before the metric_def is destroyed.
84 registration_handle_.reset();
85 }
86
87 // Creates the metric based on the metric-definition arguments.
88 //
89 // Example;
90 // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
91 // "Tensorflow counter", "MyLabelName");
92 template <typename... MetricDefArgs>
93 static Counter* New(MetricDefArgs&&... metric_def_args);
94
95 // Retrieves the cell for the specified labels, creating it on demand if
96 // not already present.
97 template <typename... Labels>
98 CounterCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_);
99
100 private:
Counter(const MetricDef<MetricKind::kCumulative,int64,NumLabels> & metric_def)101 explicit Counter(
102 const MetricDef<MetricKind::kCumulative, int64, NumLabels>& metric_def)
103 : metric_def_(metric_def),
104 registration_handle_(CollectionRegistry::Default()->Register(
105 &metric_def_, [&](MetricCollectorGetter getter) {
106 auto metric_collector = getter.Get(&metric_def_);
107
108 mutex_lock l(mu_);
109 for (const auto& cell : cells_) {
110 metric_collector.CollectValue(cell.first, cell.second.value());
111 }
112 })) {}
113
114 mutable mutex mu_;
115
116 // The metric definition. This will be used to identify the metric when we
117 // register it for collection.
118 const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
119
120 std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
121
122 using LabelArray = std::array<string, NumLabels>;
123 std::map<LabelArray, CounterCell> cells_ GUARDED_BY(mu_);
124
125 TF_DISALLOW_COPY_AND_ASSIGN(Counter);
126 };
127
////
// Implementation details follow. API readers may skip.
////
131
IncrementBy(const int64 step)132 inline void CounterCell::IncrementBy(const int64 step) {
133 DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
134 value_ += step;
135 }
136
value()137 inline int64 CounterCell::value() const { return value_; }
138
139 template <int NumLabels>
140 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)141 Counter<NumLabels>* Counter<NumLabels>::New(
142 MetricDefArgs&&... metric_def_args) {
143 return new Counter<NumLabels>(
144 MetricDef<MetricKind::kCumulative, int64, NumLabels>(
145 std::forward<MetricDefArgs>(metric_def_args)...));
146 }
147
148 template <int NumLabels>
149 template <typename... Labels>
GetCell(const Labels &...labels)150 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
151 LOCKS_EXCLUDED(mu_) {
152 // Provides a more informative error message than the one during array
153 // construction below.
154 static_assert(sizeof...(Labels) == NumLabels,
155 "Mismatch between Counter<NumLabels> and number of labels "
156 "provided in GetCell(...).");
157
158 const LabelArray& label_array = {{labels...}};
159 mutex_lock l(mu_);
160 const auto found_it = cells_.find(label_array);
161 if (found_it != cells_.end()) {
162 return &(found_it->second);
163 }
164 return &(cells_
165 .emplace(std::piecewise_construct,
166 std::forward_as_tuple(label_array),
167 std::forward_as_tuple(0))
168 .first->second);
169 }
170
171 } // namespace monitoring
172 } // namespace tensorflow
173
174 #endif // IS_MOBILE_PLATFORM
175 #endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
176