1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 #define TENSORFLOW_INCLUDED_FROM_COUNTER_H // prevent accidental use of
28 // mobile_counter.h
29 #include "tensorflow/core/lib/monitoring/mobile_counter.h"
30 #undef TENSORFLOW_INCLUDED_FROM_COUNTER_H
31 #else
32
33 #include <array>
34 #include <atomic>
35 #include <map>
36
37 #include "tensorflow/core/lib/core/status.h"
38 #include "tensorflow/core/lib/monitoring/collection_registry.h"
39 #include "tensorflow/core/lib/monitoring/metric_def.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/thread_annotations.h"
44
45 namespace tensorflow {
46 namespace monitoring {
47
48 // CounterCell stores each value of an Counter.
49 //
50 // A cell can be passed off to a module which may repeatedly update it without
51 // needing further map-indexing computations. This improves both encapsulation
52 // (separate modules can own a cell each, without needing to know about the map
53 // to which both cells belong) and performance (since map indexing and
54 // associated locking are both avoided).
55 //
56 // This class is thread-safe.
57 class CounterCell {
58 public:
CounterCell(int64 value)59 explicit CounterCell(int64 value) : value_(value) {}
~CounterCell()60 ~CounterCell() {}
61
62 // Atomically increments the value by step.
63 // REQUIRES: Step be non-negative.
64 void IncrementBy(int64 step);
65
66 // Retrieves the current value.
67 int64 value() const;
68
69 private:
70 std::atomic<int64> value_;
71
72 TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
73 };
74
75 // A stateful class for updating a cumulative integer metric.
76 //
77 // This class encapsulates a set of values (or a single value for a label-less
78 // metric). Each value is identified by a tuple of labels. The class allows the
79 // user to increment each value.
80 //
81 // Counter allocates storage and maintains a cell for each value. You can
82 // retrieve an individual cell using a label-tuple and update it separately.
83 // This improves performance since operations related to retrieval, like
84 // map-indexing and locking, are avoided.
85 //
86 // This class is thread-safe.
87 template <int NumLabels>
88 class Counter {
89 public:
~Counter()90 ~Counter() {
91 // Deleted here, before the metric_def is destroyed.
92 registration_handle_.reset();
93 }
94
95 // Creates the metric based on the metric-definition arguments.
96 //
97 // Example;
98 // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
99 // "Tensorflow counter", "MyLabelName");
100 template <typename... MetricDefArgs>
101 static Counter* New(MetricDefArgs&&... metric_def_args);
102
103 // Retrieves the cell for the specified labels, creating it on demand if
104 // not already present.
105 template <typename... Labels>
106 CounterCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
107
GetStatus()108 Status GetStatus() { return status_; }
109
110 private:
Counter(const MetricDef<MetricKind::kCumulative,int64,NumLabels> & metric_def)111 explicit Counter(
112 const MetricDef<MetricKind::kCumulative, int64, NumLabels>& metric_def)
113 : metric_def_(metric_def),
114 registration_handle_(CollectionRegistry::Default()->Register(
115 &metric_def_, [&](MetricCollectorGetter getter) {
116 auto metric_collector = getter.Get(&metric_def_);
117
118 mutex_lock l(mu_);
119 for (const auto& cell : cells_) {
120 metric_collector.CollectValue(cell.first, cell.second.value());
121 }
122 })) {
123 if (registration_handle_) {
124 status_ = Status::OK();
125 } else {
126 status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
127 "Another metric with the same name already exists.");
128 }
129 }
130
131 mutable mutex mu_;
132
133 Status status_;
134
135 // The metric definition. This will be used to identify the metric when we
136 // register it for collection.
137 const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
138
139 std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
140
141 using LabelArray = std::array<string, NumLabels>;
142 std::map<LabelArray, CounterCell> cells_ TF_GUARDED_BY(mu_);
143
144 TF_DISALLOW_COPY_AND_ASSIGN(Counter);
145 };
146
147 ////
148 // Implementation details follow. API readers may skip.
149 ////
150
IncrementBy(const int64 step)151 inline void CounterCell::IncrementBy(const int64 step) {
152 DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
153 value_ += step;
154 }
155
value()156 inline int64 CounterCell::value() const { return value_; }
157
158 template <int NumLabels>
159 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)160 Counter<NumLabels>* Counter<NumLabels>::New(
161 MetricDefArgs&&... metric_def_args) {
162 return new Counter<NumLabels>(
163 MetricDef<MetricKind::kCumulative, int64, NumLabels>(
164 std::forward<MetricDefArgs>(metric_def_args)...));
165 }
166
167 template <int NumLabels>
168 template <typename... Labels>
GetCell(const Labels &...labels)169 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
170 TF_LOCKS_EXCLUDED(mu_) {
171 // Provides a more informative error message than the one during array
172 // construction below.
173 static_assert(sizeof...(Labels) == NumLabels,
174 "Mismatch between Counter<NumLabels> and number of labels "
175 "provided in GetCell(...).");
176
177 const LabelArray& label_array = {{labels...}};
178 mutex_lock l(mu_);
179 const auto found_it = cells_.find(label_array);
180 if (found_it != cells_.end()) {
181 return &(found_it->second);
182 }
183 return &(cells_
184 .emplace(std::piecewise_construct,
185 std::forward_as_tuple(label_array),
186 std::forward_as_tuple(0))
187 .first->second);
188 }
189
190 } // namespace monitoring
191 } // namespace tensorflow
192
193 #endif // IS_MOBILE_PLATFORM
194 #endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
195