• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/stats_aggregator.h"
16 
17 #include <memory>
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/resource_op_kernel.h"
21 #include "tensorflow/core/framework/summary.pb.h"
22 #include "tensorflow/core/lib/histogram/histogram.h"
23 #include "tensorflow/core/lib/monitoring/counter.h"
24 #include "tensorflow/core/lib/monitoring/gauge.h"
25 #include "tensorflow/core/lib/monitoring/sampler.h"
26 #include "tensorflow/core/platform/macros.h"
27 
28 namespace tensorflow {
29 namespace data {
30 namespace {
31 
get_counters_map_lock()32 static mutex* get_counters_map_lock() {
33   static mutex counters_map_lock(LINKER_INITIALIZED);
34   return &counters_map_lock;
35 }
36 
get_counters_map()37 static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
38   static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
39       new std::unordered_map<string, monitoring::Counter<1>*>;
40   return counters_map;
41 }
42 
43 class StatsAggregatorImpl : public StatsAggregator {
44  public:
StatsAggregatorImpl()45   StatsAggregatorImpl() {}
46 
AddToHistogram(const string & name,gtl::ArraySlice<double> values)47   void AddToHistogram(const string& name,
48                       gtl::ArraySlice<double> values) override {
49     mutex_lock l(mu_);
50     histogram::Histogram& histogram = histograms_[name];
51     for (double value : values) {
52       histogram.Add(value);
53     }
54   }
55 
AddScalar(const string & name,float value)56   void AddScalar(const string& name, float value) override {
57     mutex_lock l(mu_);
58     scalars_[name] = value;
59   }
60 
EncodeToProto(Summary * out_summary)61   void EncodeToProto(Summary* out_summary) override {
62     mutex_lock l(mu_);
63     for (const auto& pair : histograms_) {
64       const string& name = pair.first;
65       const histogram::Histogram& histogram = pair.second;
66 
67       Summary::Value* value = out_summary->add_value();
68       value->set_tag(name);
69       histogram.EncodeToProto(value->mutable_histo(),
70                               false /* doesn't preserve zero buckets */);
71     }
72     for (const auto& pair : scalars_) {
73       Summary::Value* value = out_summary->add_value();
74       value->set_tag(pair.first);
75       value->set_simple_value(pair.second);
76     }
77   }
78 
IncrementCounter(const string & name,const string & label,int64 val)79   void IncrementCounter(const string& name, const string& label,
80                         int64 val) override {
81     mutex_lock l(*get_counters_map_lock());
82     auto counters_map = get_counters_map();
83     if (counters_map->find(name) == counters_map->end()) {
84       counters_map->emplace(
85           name,
86           monitoring::Counter<1>::New(
87               /*streamz name*/ name,
88               /*streamz description*/
89               strings::StrCat(name, " generated or consumed by the component."),
90               /*streamz label name*/ "component_descriptor"));
91     }
92     counters_map->at(name)->GetCell(label)->IncrementBy(val);
93   }
94 
95  private:
96   mutex mu_;
97   std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
98   std::unordered_map<string, float> scalars_ GUARDED_BY(mu_);
99   TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
100 };
101 
102 class StatsAggregatorHandleOp
103     : public ResourceOpKernel<StatsAggregatorResource> {
104  public:
StatsAggregatorHandleOp(OpKernelConstruction * ctx)105   explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
106       : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
107 
108  private:
CreateResource(StatsAggregatorResource ** ret)109   Status CreateResource(StatsAggregatorResource** ret) override
110       EXCLUSIVE_LOCKS_REQUIRED(mu_) {
111     *ret =
112         new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
113     return Status::OK();
114   }
115 
VerifyResource(StatsAggregatorResource * resource)116   Status VerifyResource(StatsAggregatorResource* resource) override {
117     return Status::OK();
118   }
119 };
120 
121 class StatsAggregatorSummaryOp : public OpKernel {
122  public:
StatsAggregatorSummaryOp(OpKernelConstruction * ctx)123   explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
124       : OpKernel(ctx) {}
125 
Compute(OpKernelContext * ctx)126   void Compute(OpKernelContext* ctx) override {
127     const Tensor& resource_handle_t = ctx->input(0);
128     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
129                 errors::InvalidArgument("resource_handle must be a scalar"));
130 
131     StatsAggregatorResource* resource;
132     OP_REQUIRES_OK(ctx,
133                    LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
134     core::ScopedUnref unref_iterator(resource);
135 
136     Tensor* summary_t;
137     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
138     Summary summary;
139     resource->stats_aggregator()->EncodeToProto(&summary);
140     summary_t->scalar<string>()() = summary.SerializeAsString();
141   }
142 };
143 
144 REGISTER_KERNEL_BUILDER(
145     Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
146     StatsAggregatorHandleOp);
147 REGISTER_KERNEL_BUILDER(
148     Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
149     StatsAggregatorSummaryOp);
150 
151 }  // namespace
152 }  // namespace data
153 }  // namespace tensorflow
154