• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 #define TENSORFLOW_INCLUDED_FROM_COUNTER_H  // prevent accidental use of
28                                             // mobile_counter.h
29 #include "tensorflow/core/lib/monitoring/mobile_counter.h"
30 #undef TENSORFLOW_INCLUDED_FROM_COUNTER_H
31 #else
32 
33 #include <array>
34 #include <atomic>
35 #include <map>
36 
37 #include "tensorflow/core/lib/core/status.h"
38 #include "tensorflow/core/lib/monitoring/collection_registry.h"
39 #include "tensorflow/core/lib/monitoring/metric_def.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/thread_annotations.h"
44 
45 namespace tensorflow {
46 namespace monitoring {
47 
48 // CounterCell stores each value of an Counter.
49 //
50 // A cell can be passed off to a module which may repeatedly update it without
51 // needing further map-indexing computations. This improves both encapsulation
52 // (separate modules can own a cell each, without needing to know about the map
53 // to which both cells belong) and performance (since map indexing and
54 // associated locking are both avoided).
55 //
56 // This class is thread-safe.
57 class CounterCell {
58  public:
CounterCell(int64 value)59   explicit CounterCell(int64 value) : value_(value) {}
~CounterCell()60   ~CounterCell() {}
61 
62   // Atomically increments the value by step.
63   // REQUIRES: Step be non-negative.
64   void IncrementBy(int64 step);
65 
66   // Retrieves the current value.
67   int64 value() const;
68 
69  private:
70   std::atomic<int64> value_;
71 
72   TF_DISALLOW_COPY_AND_ASSIGN(CounterCell);
73 };
74 
75 // A stateful class for updating a cumulative integer metric.
76 //
77 // This class encapsulates a set of values (or a single value for a label-less
78 // metric). Each value is identified by a tuple of labels. The class allows the
79 // user to increment each value.
80 //
81 // Counter allocates storage and maintains a cell for each value. You can
82 // retrieve an individual cell using a label-tuple and update it separately.
83 // This improves performance since operations related to retrieval, like
84 // map-indexing and locking, are avoided.
85 //
86 // This class is thread-safe.
87 template <int NumLabels>
88 class Counter {
89  public:
~Counter()90   ~Counter() {
91     // Deleted here, before the metric_def is destroyed.
92     registration_handle_.reset();
93   }
94 
95   // Creates the metric based on the metric-definition arguments.
96   //
97   // Example;
98   // auto* counter_with_label = Counter<1>::New("/tensorflow/counter",
99   //   "Tensorflow counter", "MyLabelName");
100   template <typename... MetricDefArgs>
101   static Counter* New(MetricDefArgs&&... metric_def_args);
102 
103   // Retrieves the cell for the specified labels, creating it on demand if
104   // not already present.
105   template <typename... Labels>
106   CounterCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
107 
GetStatus()108   Status GetStatus() { return status_; }
109 
110  private:
Counter(const MetricDef<MetricKind::kCumulative,int64,NumLabels> & metric_def)111   explicit Counter(
112       const MetricDef<MetricKind::kCumulative, int64, NumLabels>& metric_def)
113       : metric_def_(metric_def),
114         registration_handle_(CollectionRegistry::Default()->Register(
115             &metric_def_, [&](MetricCollectorGetter getter) {
116               auto metric_collector = getter.Get(&metric_def_);
117 
118               mutex_lock l(mu_);
119               for (const auto& cell : cells_) {
120                 metric_collector.CollectValue(cell.first, cell.second.value());
121               }
122             })) {
123     if (registration_handle_) {
124       status_ = Status::OK();
125     } else {
126       status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
127                        "Another metric with the same name already exists.");
128     }
129   }
130 
131   mutable mutex mu_;
132 
133   Status status_;
134 
135   // The metric definition. This will be used to identify the metric when we
136   // register it for collection.
137   const MetricDef<MetricKind::kCumulative, int64, NumLabels> metric_def_;
138 
139   std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
140 
141   using LabelArray = std::array<string, NumLabels>;
142   std::map<LabelArray, CounterCell> cells_ TF_GUARDED_BY(mu_);
143 
144   TF_DISALLOW_COPY_AND_ASSIGN(Counter);
145 };
146 
147 ////
148 //  Implementation details follow. API readers may skip.
149 ////
150 
IncrementBy(const int64 step)151 inline void CounterCell::IncrementBy(const int64 step) {
152   DCHECK_LE(0, step) << "Must not decrement cumulative metrics.";
153   value_ += step;
154 }
155 
value()156 inline int64 CounterCell::value() const { return value_; }
157 
158 template <int NumLabels>
159 template <typename... MetricDefArgs>
New(MetricDefArgs &&...metric_def_args)160 Counter<NumLabels>* Counter<NumLabels>::New(
161     MetricDefArgs&&... metric_def_args) {
162   return new Counter<NumLabels>(
163       MetricDef<MetricKind::kCumulative, int64, NumLabels>(
164           std::forward<MetricDefArgs>(metric_def_args)...));
165 }
166 
167 template <int NumLabels>
168 template <typename... Labels>
GetCell(const Labels &...labels)169 CounterCell* Counter<NumLabels>::GetCell(const Labels&... labels)
170     TF_LOCKS_EXCLUDED(mu_) {
171   // Provides a more informative error message than the one during array
172   // construction below.
173   static_assert(sizeof...(Labels) == NumLabels,
174                 "Mismatch between Counter<NumLabels> and number of labels "
175                 "provided in GetCell(...).");
176 
177   const LabelArray& label_array = {{labels...}};
178   mutex_lock l(mu_);
179   const auto found_it = cells_.find(label_array);
180   if (found_it != cells_.end()) {
181     return &(found_it->second);
182   }
183   return &(cells_
184                .emplace(std::piecewise_construct,
185                         std::forward_as_tuple(label_array),
186                         std::forward_as_tuple(0))
187                .first->second);
188 }
189 
190 }  // namespace monitoring
191 }  // namespace tensorflow
192 
193 #endif  // IS_MOBILE_PLATFORM
194 #endif  // TENSORFLOW_CORE_LIB_MONITORING_COUNTER_H_
195