• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
16 #define TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
17 
#include <memory>
#include <string>
#include <utility>

#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
23 
24 namespace tensorflow {
25 
26 class Summary;
27 class SummaryWriterInterface;
28 namespace data {
29 
30 // A `StatsAggregator` accumulates statistics incrementally. A
31 // `StatsAggregator` can accumulate multiple different statistics, distinguished
32 // by a string name.
33 //
34 // The class currently supports accumulating `Histogram`, `scalar` objects and
35 // tfstreamz metrics, and we expect to add other methods in future.
36 //
37 // NOTE(mrry): `StatsAggregator` is a virtual interface because we anticipate
38 // that many different implementations will have the same interface. For
39 // example, we have different implementations in "stats_aggregator_ops.cc" for
40 // simple in-memory implementation that integrates with the pull-based summary
41 // API, and for the push-based `SummaryWriterInterface`, and we may add
42 // implementations that work well with other custom monitoring services.
43 class StatsAggregator {
44  public:
~StatsAggregator()45   virtual ~StatsAggregator() {}
46 
47   // Add the given `values` to the histogram with the given `name`. Each
48   // element of `values` will be treated as a separate sample in the histogram.
49   virtual void AddToHistogram(const string& name,
50                               gtl::ArraySlice<double> values,
51                               int64 global_step) = 0;
52 
53   // TODO(shivaniagrawal): consistency in double and float usage.
54   // Add the given `value` as Scalar with the given `name`.
55   virtual void AddScalar(const string& name, float value,
56                          int64 global_step) = 0;
57 
58   // Stores a protocol buffer representation of the aggregator state in the
59   // given `out_summary`.
60   virtual void EncodeToProto(Summary* out_summary) = 0;
61 
62   // Sets a `summary_writer` with this stats_aggregator.
63   virtual Status SetSummaryWriter(SummaryWriterInterface* summary_writer) = 0;
64 
65   // Increment the `label` cell of metrics mapped with `name` by given `value`.
66   virtual void IncrementCounter(const string& name, const string& label,
67                                 int64 val) = 0;
68 };
69 
70 // A `StatsAggregatorResource` wraps a sharable `StatsAggregator` as a resource
71 // in the TensorFlow resource manager.
72 //
73 // NOTE(mrry): This class is separate from `StatsAggregator` in order to
74 // simplify the memory management of the shared object. Most users of
75 // `StatsAggregator` interact with a `std::shared_ptr<StatsAggregator>` whereas
76 // the `ResourceBase` API requires explicit reference counting.
77 class StatsAggregatorResource : public ResourceBase {
78  public:
79   // Creates a new resource from the given `stats_aggregator`.
StatsAggregatorResource(std::unique_ptr<StatsAggregator> stats_aggregator)80   StatsAggregatorResource(std::unique_ptr<StatsAggregator> stats_aggregator)
81       : stats_aggregator_(stats_aggregator.release()) {}
82 
83   // Returns the wrapped `StatsAggregator`.
stats_aggregator()84   std::shared_ptr<StatsAggregator> stats_aggregator() const {
85     return stats_aggregator_;
86   }
87 
DebugString()88   string DebugString() const override { return "StatsAggregatorResource"; }
89 
90  private:
91   const std::shared_ptr<StatsAggregator> stats_aggregator_;
92 };
93 
94 }  // namespace data
95 }  // namespace tensorflow
96 
97 #endif  // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
98