1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/stats_aggregator.h"
16
17 #include <memory>
18
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/resource_op_kernel.h"
21 #include "tensorflow/core/framework/summary.pb.h"
22 #include "tensorflow/core/lib/histogram/histogram.h"
23 #include "tensorflow/core/lib/monitoring/counter.h"
24 #include "tensorflow/core/lib/monitoring/gauge.h"
25 #include "tensorflow/core/lib/monitoring/sampler.h"
26 #include "tensorflow/core/platform/macros.h"
27
28 namespace tensorflow {
29 namespace data {
30 namespace {
31
get_counters_map_lock()32 static mutex* get_counters_map_lock() {
33 static mutex counters_map_lock(LINKER_INITIALIZED);
34 return &counters_map_lock;
35 }
36
get_counters_map()37 static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
38 static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
39 new std::unordered_map<string, monitoring::Counter<1>*>;
40 return counters_map;
41 }
42
43 class StatsAggregatorImpl : public StatsAggregator {
44 public:
StatsAggregatorImpl()45 StatsAggregatorImpl() {}
46
AddToHistogram(const string & name,gtl::ArraySlice<double> values)47 void AddToHistogram(const string& name,
48 gtl::ArraySlice<double> values) override {
49 mutex_lock l(mu_);
50 histogram::Histogram& histogram = histograms_[name];
51 for (double value : values) {
52 histogram.Add(value);
53 }
54 }
55
AddScalar(const string & name,float value)56 void AddScalar(const string& name, float value) override {
57 mutex_lock l(mu_);
58 scalars_[name] = value;
59 }
60
EncodeToProto(Summary * out_summary)61 void EncodeToProto(Summary* out_summary) override {
62 mutex_lock l(mu_);
63 for (const auto& pair : histograms_) {
64 const string& name = pair.first;
65 const histogram::Histogram& histogram = pair.second;
66
67 Summary::Value* value = out_summary->add_value();
68 value->set_tag(name);
69 histogram.EncodeToProto(value->mutable_histo(),
70 false /* doesn't preserve zero buckets */);
71 }
72 for (const auto& pair : scalars_) {
73 Summary::Value* value = out_summary->add_value();
74 value->set_tag(pair.first);
75 value->set_simple_value(pair.second);
76 }
77 }
78
IncrementCounter(const string & name,const string & label,int64 val)79 void IncrementCounter(const string& name, const string& label,
80 int64 val) override {
81 mutex_lock l(*get_counters_map_lock());
82 auto counters_map = get_counters_map();
83 if (counters_map->find(name) == counters_map->end()) {
84 counters_map->emplace(
85 name,
86 monitoring::Counter<1>::New(
87 /*streamz name*/ name,
88 /*streamz description*/
89 strings::StrCat(name, " generated or consumed by the component."),
90 /*streamz label name*/ "component_descriptor"));
91 }
92 counters_map->at(name)->GetCell(label)->IncrementBy(val);
93 }
94
95 private:
96 mutex mu_;
97 std::unordered_map<string, histogram::Histogram> histograms_ GUARDED_BY(mu_);
98 std::unordered_map<string, float> scalars_ GUARDED_BY(mu_);
99 TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
100 };
101
102 class StatsAggregatorHandleOp
103 : public ResourceOpKernel<StatsAggregatorResource> {
104 public:
StatsAggregatorHandleOp(OpKernelConstruction * ctx)105 explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
106 : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
107
108 private:
CreateResource(StatsAggregatorResource ** ret)109 Status CreateResource(StatsAggregatorResource** ret) override
110 EXCLUSIVE_LOCKS_REQUIRED(mu_) {
111 *ret =
112 new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
113 return Status::OK();
114 }
115
VerifyResource(StatsAggregatorResource * resource)116 Status VerifyResource(StatsAggregatorResource* resource) override {
117 return Status::OK();
118 }
119 };
120
121 class StatsAggregatorSummaryOp : public OpKernel {
122 public:
StatsAggregatorSummaryOp(OpKernelConstruction * ctx)123 explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
124 : OpKernel(ctx) {}
125
Compute(OpKernelContext * ctx)126 void Compute(OpKernelContext* ctx) override {
127 const Tensor& resource_handle_t = ctx->input(0);
128 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
129 errors::InvalidArgument("resource_handle must be a scalar"));
130
131 StatsAggregatorResource* resource;
132 OP_REQUIRES_OK(ctx,
133 LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
134 core::ScopedUnref unref_iterator(resource);
135
136 Tensor* summary_t;
137 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
138 Summary summary;
139 resource->stats_aggregator()->EncodeToProto(&summary);
140 summary_t->scalar<string>()() = summary.SerializeAsString();
141 }
142 };
143
144 REGISTER_KERNEL_BUILDER(
145 Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
146 StatsAggregatorHandleOp);
147 REGISTER_KERNEL_BUILDER(
148 Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
149 StatsAggregatorSummaryOp);
150
151 } // namespace
152 } // namespace data
153 } // namespace tensorflow
154