• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
17 #define TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
18 
19 // clang-format off
20 // Required for IS_MOBILE_PLATFORM
21 #include "tensorflow/core/platform/platform.h"
22 // clang-format on
23 
24 // We replace this implementation with a null implementation for mobile
25 // platforms.
26 #ifdef IS_MOBILE_PLATFORM
27 #define TENSORFLOW_INCLUDED_FROM_SAMPLER_H  // prevent accidental use of
28                                             // mobile_sampler.h
29 #include "tensorflow/core/lib/monitoring/mobile_sampler.h"
30 #undef TENSORFLOW_INCLUDED_FROM_SAMPLER_H
31 #else
32 
33 #include <float.h>
34 
35 #include <map>
36 
37 #include "tensorflow/core/framework/summary.pb.h"
38 #include "tensorflow/core/lib/core/status.h"
39 #include "tensorflow/core/lib/histogram/histogram.h"
40 #include "tensorflow/core/lib/monitoring/collection_registry.h"
41 #include "tensorflow/core/lib/monitoring/metric_def.h"
42 #include "tensorflow/core/platform/macros.h"
43 #include "tensorflow/core/platform/mutex.h"
44 #include "tensorflow/core/platform/thread_annotations.h"
45 
46 namespace tensorflow {
47 namespace monitoring {
48 
49 // SamplerCell stores each value of an Sampler.
50 //
51 // A cell can be passed off to a module which may repeatedly update it without
52 // needing further map-indexing computations. This improves both encapsulation
53 // (separate modules can own a cell each, without needing to know about the map
54 // to which both cells belong) and performance (since map indexing and
55 // associated locking are both avoided).
56 //
57 // This class is thread-safe.
58 class SamplerCell {
59  public:
SamplerCell(const std::vector<double> & bucket_limits)60   SamplerCell(const std::vector<double>& bucket_limits)
61       : histogram_(bucket_limits) {}
62 
~SamplerCell()63   ~SamplerCell() {}
64 
65   // Atomically adds a sample.
66   void Add(double sample);
67 
68   // Returns the current histogram value as a proto.
69   HistogramProto value() const;
70 
71  private:
72   histogram::ThreadSafeHistogram histogram_;
73 
74   TF_DISALLOW_COPY_AND_ASSIGN(SamplerCell);
75 };
76 
// Bucketing strategies for the samplers.
//
// We automatically add -DBL_MAX and DBL_MAX to the ranges, so that no sample
// goes out of bounds.
//
// WARNING: If you are changing the interface here, please do change the same in
// mobile_sampler.h.
class Buckets {
 public:
  virtual ~Buckets() = default;

  // Sets up buckets of the form:
  // [-DBL_MAX, ..., scale * growth_factor^i,
  //   scale * growth_factor^(i + 1), ..., DBL_MAX].
  //
  // So for powers of 2 with a bucket count of 10, you would say (1, 2, 10).
  static std::unique_ptr<Buckets> Exponential(double scale,
                                              double growth_factor,
                                              int bucket_count);

  // Sets up buckets of the form:
  // [-DBL_MAX, ..., bucket_limits[i], bucket_limits[i + 1], ..., DBL_MAX].
  static std::unique_ptr<Buckets> Explicit(
      std::initializer_list<double> bucket_limits);

  // This alternative Explicit Buckets factory method is primarily meant to be
  // used by the CLIF layer code paths that are incompatible with
  // initializer_lists.
  static std::unique_ptr<Buckets> Explicit(std::vector<double> bucket_limits);

  // Returns the bucket limits. Sampler passes these to the constructor of
  // every SamplerCell it creates.
  virtual const std::vector<double>& explicit_bounds() const = 0;
};
109 
110 // A stateful class for updating a cumulative histogram metric.
111 //
112 // This class encapsulates a set of histograms (or a single histogram for a
113 // label-less metric) configured with a list of increasing bucket boundaries.
114 // Each histogram is identified by a tuple of labels. The class allows the
115 // user to add a sample to each histogram value.
116 //
117 // Sampler allocates storage and maintains a cell for each value. You can
118 // retrieve an individual cell using a label-tuple and update it separately.
119 // This improves performance since operations related to retrieval, like
120 // map-indexing and locking, are avoided.
121 //
122 // This class is thread-safe.
123 template <int NumLabels>
124 class Sampler {
125  public:
~Sampler()126   ~Sampler() {
127     // Deleted here, before the metric_def is destroyed.
128     registration_handle_.reset();
129   }
130 
131   // Creates the metric based on the metric-definition arguments and buckets.
132   //
133   // Example;
134   // auto* sampler_with_label = Sampler<1>::New({"/tensorflow/sampler",
135   //   "Tensorflow sampler", "MyLabelName"}, {10.0, 20.0, 30.0});
136   static Sampler* New(const MetricDef<MetricKind::kCumulative, HistogramProto,
137                                       NumLabels>& metric_def,
138                       std::unique_ptr<Buckets> buckets);
139 
140   // Retrieves the cell for the specified labels, creating it on demand if
141   // not already present.
142   template <typename... Labels>
143   SamplerCell* GetCell(const Labels&... labels) TF_LOCKS_EXCLUDED(mu_);
144 
GetStatus()145   Status GetStatus() { return status_; }
146 
147  private:
148   friend class SamplerCell;
149 
Sampler(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)150   Sampler(const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
151               metric_def,
152           std::unique_ptr<Buckets> buckets)
153       : metric_def_(metric_def),
154         buckets_(std::move(buckets)),
155         registration_handle_(CollectionRegistry::Default()->Register(
156             &metric_def_, [&](MetricCollectorGetter getter) {
157               auto metric_collector = getter.Get(&metric_def_);
158 
159               mutex_lock l(mu_);
160               for (const auto& cell : cells_) {
161                 metric_collector.CollectValue(cell.first, cell.second.value());
162               }
163             })) {
164     if (registration_handle_) {
165       status_ = Status::OK();
166     } else {
167       status_ = Status(tensorflow::error::Code::ALREADY_EXISTS,
168                        "Another metric with the same name already exists.");
169     }
170   }
171 
172   mutable mutex mu_;
173 
174   Status status_;
175 
176   // The metric definition. This will be used to identify the metric when we
177   // register it for collection.
178   const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>
179       metric_def_;
180 
181   // Bucket limits for the histograms in the cells.
182   std::unique_ptr<Buckets> buckets_;
183 
184   // Registration handle with the CollectionRegistry.
185   std::unique_ptr<CollectionRegistry::RegistrationHandle> registration_handle_;
186 
187   using LabelArray = std::array<string, NumLabels>;
188   // we need a container here that guarantees pointer stability of the value,
189   // namely, the pointer of the value should remain valid even after more cells
190   // are inserted.
191   std::map<LabelArray, SamplerCell> cells_ TF_GUARDED_BY(mu_);
192 
193   TF_DISALLOW_COPY_AND_ASSIGN(Sampler);
194 };
195 
196 ////
197 //  Implementation details follow. API readers may skip.
198 ////
199 
Add(const double sample)200 inline void SamplerCell::Add(const double sample) { histogram_.Add(sample); }
201 
value()202 inline HistogramProto SamplerCell::value() const {
203   HistogramProto pb;
204   histogram_.EncodeToProto(&pb, true /* preserve_zero_buckets */);
205   return pb;
206 }
207 
208 template <int NumLabels>
New(const MetricDef<MetricKind::kCumulative,HistogramProto,NumLabels> & metric_def,std::unique_ptr<Buckets> buckets)209 Sampler<NumLabels>* Sampler<NumLabels>::New(
210     const MetricDef<MetricKind::kCumulative, HistogramProto, NumLabels>&
211         metric_def,
212     std::unique_ptr<Buckets> buckets) {
213   return new Sampler<NumLabels>(metric_def, std::move(buckets));
214 }
215 
216 template <int NumLabels>
217 template <typename... Labels>
GetCell(const Labels &...labels)218 SamplerCell* Sampler<NumLabels>::GetCell(const Labels&... labels)
219     TF_LOCKS_EXCLUDED(mu_) {
220   // Provides a more informative error message than the one during array
221   // construction below.
222   static_assert(sizeof...(Labels) == NumLabels,
223                 "Mismatch between Sampler<NumLabels> and number of labels "
224                 "provided in GetCell(...).");
225 
226   const LabelArray& label_array = {{labels...}};
227   mutex_lock l(mu_);
228   const auto found_it = cells_.find(label_array);
229   if (found_it != cells_.end()) {
230     return &(found_it->second);
231   }
232   return &(cells_
233                .emplace(std::piecewise_construct,
234                         std::forward_as_tuple(label_array),
235                         std::forward_as_tuple(buckets_->explicit_bounds()))
236                .first->second);
237 }
238 
239 }  // namespace monitoring
240 }  // namespace tensorflow
241 
242 #endif  // IS_MOBILE_PLATFORM
243 #endif  // TENSORFLOW_CORE_LIB_MONITORING_SAMPLER_H_
244