1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 #include "tensorflow/core/lib/monitoring/mobile_counter.h"
28 #else
29
30 #include <array>
31 #include <atomic>
32 #include <map>
33
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/monitoring/collection_registry.h"
36 #include "tensorflow/core/lib/monitoring/metric_def.h"
37 #include "tensorflow/core/platform/logging.h"
38 #include "tensorflow/core/platform/macros.h"
39 #include "tensorflow/core/platform/mutex.h"
40 #include "tensorflow/core/platform/thread_annotations.h"
41
42 namespace tensorflow {
43 namespace monitoring {
44
// CounterCell stores each value of a Counter.
46 //
47 // A cell can be passed off to a module which may repeatedly update it without
48 // needing further map-indexing computations. This improves both encapsulation
49 // (separate modules can own a cell each, without needing to know about the map
50 // to which both cells belong) and performance (since map indexing and
51 // associated locking are both avoided).
52 //
53 // This class is thread-safe.
54 class CounterCell {
55 public:
CounterCell(int64 value)56 explicit CounterCell(int64 value) : value_(value) {}
~CounterCell()57 ~CounterCell() {}
58
59 // Atomically increments the value by step.
60 // REQUIRES: Step be non-negative.
61 void IncrementBy(int64 step);
62
63 // Retrieves the current value.
64 int64 value() const;
65
66 private:
67 std::atomic<int64> value_;
68
69 TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
70 };
71
72 // A stateful class for updating a cumulative integer metric.
73 //
74 // This class encapsulates a set of values (or a single value for a label-less
75 // metric). Each value is identified by a tuple of labels. The class allows the
76 // user to increment each value.
77 //
78 // Counter allocates storage and maintains a cell for each value. You can
79 // retrieve an individual cell using a label-tuple and update it separately.
80 // This improves performance since operations related to retrieval, like
81 // map-indexing and locking, are avoided.
82 //
83 // This class is thread-safe.
84 template <int NumLabels>
85 class Counter {
86 public:
~Counter()87 ~Counter() {
88 // Deleted here, before the metric_def is destroyed.
89 registration_handle_.reset();
90 }
91
92 // Creates the metric based on the metric-definition arguments.
93 //
94 // Example;
95 // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
96 // "Tensorflow counter", "MyLabelName");
97 template <typename... MetricDefArgs>
98 static Counter* New(MetricDefArgs&&... metric_def_args);
99
100 // Retrieves the cell for the specified labels, creating it on demand if
101 // not already present.
102 template <typename... Labels>
103 CounterCell* GetCell(const Labels&... labels) LOCKS_EXCLUDED(mu_);
104
GetStatus()105 Status GetStatus() { return status_; }
106
107 private:
Counter(const MetricDef<MetricKind::kCumulative,int64,NumLabels> & metric_def)108 explicit Counter(
109 const MetricDef<MetricKind::kCumulative, int64, NumLabels>& metric_def)
110 : metric_def_(metric_def),
111 registration_handle_(CollectionRegistry::Default()->Register(
112 &metric_def_, [&](MetricCollectorGetter getter) {
113 auto metric_collector = getter.Get(&metric_def_);
114
115 mutex_lock l(mu_);
116 for (const auto& cell : cells_) {
117 metric_collector.CollectValue(cell.first, cell.second.value());
118 }
119 })) {
120 if (registration_handle_) {
121 status_ = Status::OK();
122 } else {
123 status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
124 "Another metric with the same name already exists.");
125 }
126 }
127
128 mutable mutex mu_;
129
130 Status status_;
131
132 // The metric definition. This will be used to identify the metric when we
133 // register it for collection.
134 const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
135
136 std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
137
138 using LabelArray = std::array<string, NumLabels>;
139 std::map<LabelArray, CounterCell> cells_ GUARDED_BY(mu_);
140
141 TF_DISALLOW_COPY_AND_ASSIGN(Counter);
142 };
143
144 ////
145 // Implementation details follow. API readers may skip.
146 ////
147
IncrementBy(const int64 step)148 inline void CounterCell::IncrementBy(const int64 step) {
149 DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
150 value_ += step;
151 }
152
value()153 inline int64 CounterCell::value() const { return value_; }
154
155 template <int NumLabels>
156 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)157 Counter<NumLabels>* Counter<NumLabels>::New(
158 MetricDefArgs&&... metric_def_args) {
159 return new Counter<NumLabels>(
160 MetricDef<MetricKind::kCumulative, int64, NumLabels>(
161 std::forward<MetricDefArgs>(metric_def_args)...));
162 }
163
164 template <int NumLabels>
165 template <typename... Labels>
GetCell(const Labels &...labels)166 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
167 LOCKS_EXCLUDED(mu_) {
168 // Provides a more informative error message than the one during array
169 // construction below.
170 static_assert(sizeof...(Labels) == NumLabels,
171 "Mismatch between Counter<NumLabels> and number of labels "
172 "provided in GetCell(...).");
173
174 const LabelArray& label_array = {{labels...}};
175 mutex_lock l(mu_);
176 const auto found_it = cells_.find(label_array);
177 if (found_it != cells_.end()) {
178 return &(found_it->second);
179 }
180 return &(cells_
181 .emplace(std::piecewise_construct,
182 std::forward_as_tuple(label_array),
183 std::forward_as_tuple(0))
184 .first->second);
185 }
186
187 } // namespace monitoring
188 } // namespace tensorflow
189
190 #endif // IS_MOBILE_PLATFORM
191 #endif // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
192