1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
18
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 #define TENSORFLOW_INCLUDED_FROM_SAMPLER_H // prevent accidental use of
28 // mobile_sampler.h
29 #include "tensorflow/core/lib/monitoring/mobile_sampler.h"
30 #undef TENSORFLOW_INCLUDED_FROM_SAMPLER_H
31 #else
32
#include <float.h>

#include <array>
#include <initializer_list>
#include <map>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/monitoring/collection_registry.h"
#include "tensorflow/core/lib/monitoring/metric_def.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
45
46 namespace tensorflow {
47 namespace monitoring {
48
// SamplerCell stores each value of a Sampler.
//
// A cell can be passed off to a module which may repeatedly update it without
// needing further map-indexing computations. This improves both encapsulation
// (separate modules can own a cell each, without needing to know about the map
// to which both cells belong) and performance (since map indexing and
// associated locking are both avoided).
//
// This class is thread-safe.
58 class SamplerCell {
59 public:
SamplerCell(const std::vector<double> & bucket_limits)60 SamplerCell(const std::vector<double>& bucket_limits)
61 : histogram_(bucket_limits) {}
62
~SamplerCell()63 ~SamplerCell() {}
64
65 // Atomically adds a sample.
66 void Add(double sample);
67
68 // Returns the current histogram value as a proto.
69 HistogramProto value() const;
70
71 private:
72 histogram::ThreadSafeHistogram histogram_;
73
74 TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
75 };
76
77 // Bucketing strategies for the samplers.
78 //
79 // We automatically add -DBL_MAX and DBL_MAX to the ranges, so that no sample
80 // goes out of bounds.
81 //
82 // WARNING: If you are changing the interface here, please do change the same in
83 // mobile_sampler.h.
class Buckets {
 public:
  virtual ~Buckets() = default;

  // Factory for exponentially growing buckets. The resulting limits have the
  // form:
  //   [-DBL_MAX, ..., scale * growth_factor^i,
  //    scale * growth_factor^(i + 1), ..., DBL_MAX].
  //
  // For example, powers of 2 with a bucket count of 10 are requested with the
  // arguments (1, 2, 10).
  static std::unique_ptr<Buckets> Exponential(double scale,
                                              double growth_factor,
                                              int bucket_count);

  // Factory for caller-specified buckets. The resulting limits have the form:
  //   [-DBL_MAX, ..., bucket_limits[i], bucket_limits[i + 1], ..., DBL_MAX].
  static std::unique_ptr<Buckets> Explicit(
      std::initializer_list<double> bucket_limits);

  // Vector-taking overload of the Explicit factory. It is primarily meant for
  // the CLIF layer code paths that are incompatible with initializer_lists.
  static std::unique_ptr<Buckets> Explicit(std::vector<double> bucket_limits);

  // Returns the bucket boundaries held by the concrete bucketing strategy.
  virtual const std::vector<double>& explicit_bounds() const = 0;
};
109
110 // A stateful class for updating a cumulative histogram metric.
111 //
112 // This class encapsulates a set of histograms (or a single histogram for a
113 // label-less metric) configured with a list of increasing bucket boundaries.
114 // Each histogram is identified by a tuple of labels. The class allows the
115 // user to add a sample to each histogram value.
116 //
117 // Sampler allocates storage and maintains a cell for each value. You can
118 // retrieve an individual cell using a label-tuple and update it separately.
119 // This improves performance since operations related to retrieval, like
120 // map-indexing and locking, are avoided.
121 //
122 // This class is thread-safe.
123 template <int NumLabels>
124 class Sampler {
125 public:
~Sampler()126 ~Sampler() {
127 // Deleted here, before the metric_def is destroyed.
128 registration_handle_.reset();
129 }
130
131 // Creates the metric based on the metric-definition arguments and buckets.
132 //
133 // Example;
134 // auto* sampler_with_label = Sampler<1>::New({"/tensorflow/sampler",
135 // "Tensorflow sampler", "MyLabelName"}, {10.0, 20.0, 30.0});
136 static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
137 NumLabels>& metric_def,
138 std::unique_ptr<Buckets> buckets);
139
140 // Retrieves the cell for the specified labels, creating it on demand if
141 // not already present.
142 template <typename... Labels>
143 SamplerCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
144
GetStatus()145 Status GetStatus() { return status_; }
146
147 private:
148 friend class SamplerCell;
149
Sampler(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)150 Sampler(const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
151 metric_def,
152 std::unique_ptr<Buckets> buckets)
153 : metric_def_(metric_def),
154 buckets_(std::move(buckets)),
155 registration_handle_(CollectionRegistry::Default()->Register(
156 &metric_def_, [&](MetricCollectorGetter getter) {
157 auto metric_collector = getter.Get(&metric_def_);
158
159 mutex_lock l(mu_);
160 for (const auto& cell : cells_) {
161 metric_collector.CollectValue(cell.first, cell.second.value());
162 }
163 })) {
164 if (registration_handle_) {
165 status_ = Status::OK();
166 } else {
167 status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
168 "Another metric with the same name already exists.");
169 }
170 }
171
172 mutable mutex mu_;
173
174 Status status_;
175
176 // The metric definition. This will be used to identify the metric when we
177 // register it for collection.
178 const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>
179 metric_def_;
180
181 // Bucket limits for the histograms in the cells.
182 std::unique_ptr<Buckets> buckets_;
183
184 // Registration handle with the CollectionRegistry.
185 std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
186
187 using LabelArray = std::array<string, NumLabels>;
188 // we need a container here that guarantees pointer stability of the value,
189 // namely, the pointer of the value should remain valid even after more cells
190 // are inserted.
191 std::map<LabelArray, SamplerCell> cells_ TF_GUARDED_BY(mu_);
192
193 TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
194 };
195
196 ////
197 // Implementation details follow. API readers may skip.
198 ////
199
Add(const double sample)200 inline void SamplerCell::Add(const double sample) { histogram_.Add(sample); }
201
value()202 inline HistogramProto SamplerCell::value() const {
203 HistogramProto pb;
204 histogram_.EncodeToProto(&pb, true /* preserve_zero_buckets */);
205 return pb;
206 }
207
208 template <int NumLabels>
New(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)209 Sampler<NumLabels>* Sampler<NumLabels>::New(
210 const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
211 metric_def,
212 std::unique_ptr<Buckets> buckets) {
213 return new Sampler<NumLabels>(metric_def, std::move(buckets));
214 }
215
216 template <int NumLabels>
217 template <typename... Labels>
GetCell(const Labels &...labels)218 SamplerCell* Sampler<NumLabels>::GetCell(const Labels&... labels)
219 TF_LOCKS_EXCLUDED(mu_) {
220 // Provides a more informative error message than the one during array
221 // construction below.
222 static_assert(sizeof...(Labels) == NumLabels,
223 "Mismatch between Sampler<NumLabels> and number of labels "
224 "provided in GetCell(...).");
225
226 const LabelArray& label_array = {{labels...}};
227 mutex_lock l(mu_);
228 const auto found_it = cells_.find(label_array);
229 if (found_it != cells_.end()) {
230 return &(found_it->second);
231 }
232 return &(cells_
233 .emplace(std::piecewise_construct,
234 std::forward_as_tuple(label_array),
235 std::forward_as_tuple(buckets_->explicit_bounds()))
236 .first->second);
237 }
238
239 } // namespace monitoring
240 } // namespace tensorflow
241
242 #endif // IS_MOBILE_PLATFORM
243 #endif // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
244