/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_op_kernel.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/events_writer.h"

namespace tensorflow {
namespace data {
namespace experimental {
namespace {

get_counters_map_lock()36 static mutex* get_counters_map_lock() {
37 static mutex counters_map_lock(LINKER_INITIALIZED);
38 return &counters_map_lock;
39 }
40
get_counters_map()41 static std::unordered_map<string, monitoring::Counter<1>*>* get_counters_map() {
42 static std::unordered_map<string, monitoring::Counter<1>*>* counters_map =
43 new std::unordered_map<string, monitoring::Counter<1>*>;
44 return counters_map;
45 }
46
47 class StatsAggregatorImpl : public StatsAggregator {
48 public:
StatsAggregatorImpl()49 StatsAggregatorImpl() {}
50
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64 steps)51 void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
52 const int64 steps) override {
53 mutex_lock l(mu_);
54 histogram::Histogram& histogram = histograms_[name];
55 for (double value : values) {
56 histogram.Add(value);
57 }
58 }
59
AddScalar(const string & name,float value,const int64 steps)60 void AddScalar(const string& name, float value, const int64 steps) override {
61 mutex_lock l(mu_);
62 scalars_[name] = value;
63 }
64
EncodeToProto(Summary * out_summary)65 void EncodeToProto(Summary* out_summary) override {
66 mutex_lock l(mu_);
67 for (const auto& pair : histograms_) {
68 const string& name = pair.first;
69 const histogram::Histogram& histogram = pair.second;
70
71 Summary::Value* value = out_summary->add_value();
72 value->set_tag(name);
73 histogram.EncodeToProto(value->mutable_histo(),
74 false /* doesn't preserve zero buckets */);
75 }
76 for (const auto& pair : scalars_) {
77 Summary::Value* value = out_summary->add_value();
78 value->set_tag(pair.first);
79 value->set_simple_value(pair.second);
80 }
81 }
82
83 // StatsAggregator implementation for V2 is based on push-based summary, no-op
84 // in V1.
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)85 Status SetSummaryWriter(
86 SummaryWriterInterface* summary_writer_interface) override {
87 return Status::OK();
88 }
89
IncrementCounter(const string & name,const string & label,int64 val)90 void IncrementCounter(const string& name, const string& label,
91 int64 val) override {
92 mutex_lock l(*get_counters_map_lock());
93 auto counters_map = get_counters_map();
94 if (counters_map->find(name) == counters_map->end()) {
95 counters_map->emplace(
96 name,
97 monitoring::Counter<1>::New(
98 /*streamz name*/ name,
99 /*streamz description*/
100 strings::StrCat(name, " generated or consumed by the component."),
101 /*streamz label name*/ "component_descriptor"));
102 }
103 counters_map->at(name)->GetCell(label)->IncrementBy(val);
104 }
105
106 private:
107 mutex mu_;
108 std::unordered_map<string, histogram::Histogram> histograms_
109 TF_GUARDED_BY(mu_);
110 std::unordered_map<string, float> scalars_ TF_GUARDED_BY(mu_);
111 TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImpl);
112 };
113
114 class StatsAggregatorHandleOp
115 : public ResourceOpKernel<StatsAggregatorResource> {
116 public:
StatsAggregatorHandleOp(OpKernelConstruction * ctx)117 explicit StatsAggregatorHandleOp(OpKernelConstruction* ctx)
118 : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
119
120 private:
CreateResource(StatsAggregatorResource ** ret)121 Status CreateResource(StatsAggregatorResource** ret) override
122 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
123 *ret =
124 new StatsAggregatorResource(absl::make_unique<StatsAggregatorImpl>());
125 return Status::OK();
126 }
127 };
128
129 class StatsAggregatorImplV2 : public StatsAggregator {
130 public:
StatsAggregatorImplV2()131 StatsAggregatorImplV2() {}
132
~StatsAggregatorImplV2()133 ~StatsAggregatorImplV2() override {
134 if (summary_writer_interface_) {
135 summary_writer_interface_->Unref();
136 }
137 }
138
AddToHistogram(const string & name,gtl::ArraySlice<double> values,const int64 steps)139 void AddToHistogram(const string& name, gtl::ArraySlice<double> values,
140 const int64 steps) override {
141 mutex_lock l(mu_);
142 histogram::Histogram& histogram = histograms_[name];
143 for (double value : values) {
144 histogram.Add(value);
145 }
146 AddToEvents(name, steps, histogram);
147 }
148
AddScalar(const string & name,float value,const int64 steps)149 void AddScalar(const string& name, float value, const int64 steps) override {
150 mutex_lock l(mu_);
151 AddToEvents(name, steps, value);
152 }
153
154 // TODO(b/116314787): expose this is public API to manually flush summary.
Flush()155 Status Flush() {
156 mutex_lock l(mu_);
157 if (summary_writer_interface_)
158 TF_RETURN_IF_ERROR(summary_writer_interface_->Flush());
159 return Status::OK();
160 }
161
IncrementCounter(const string & name,const string & label,int64 val)162 void IncrementCounter(const string& name, const string& label,
163 int64 val) override {
164 mutex_lock l(*get_counters_map_lock());
165 auto counters_map = get_counters_map();
166 if (counters_map->find(name) == counters_map->end()) {
167 counters_map->emplace(
168 name, monitoring::Counter<1>::New(
169 /*streamz name*/ "/tensorflow/" + name,
170 /*streamz description*/
171 name + " generated or consumed by the component.",
172 /*streamz label name*/ "component_descriptor"));
173 }
174 counters_map->at(name)->GetCell(label)->IncrementBy(val);
175 }
176
177 // StatsAggregator implementation for V1 is based on pull-based summary, no-op
178 // in V2.
EncodeToProto(Summary * out_summary)179 void EncodeToProto(Summary* out_summary) override {}
180
SetSummaryWriter(SummaryWriterInterface * summary_writer_interface)181 Status SetSummaryWriter(
182 SummaryWriterInterface* summary_writer_interface) override {
183 mutex_lock l(mu_);
184 if (summary_writer_interface_) {
185 summary_writer_interface_->Unref();
186 // If we create stats_aggregator twice in a program, we would end up with
187 // already existing resource. In this case emitting an error if a
188 // `summary_writer_resource` is present is not the intended behavior, we
189 // could either Unref the existing summary_writer_resource or not set the
190 // new resource at all.
191 }
192 summary_writer_interface_ = summary_writer_interface;
193 summary_writer_interface_->Ref();
194 return Status::OK();
195 }
196
197 private:
AddToEvents(const string & name,const int64 steps,const float scalar_value)198 void AddToEvents(const string& name, const int64 steps,
199 const float scalar_value) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
200 if (summary_writer_interface_ == nullptr) {
201 return;
202 }
203 std::unique_ptr<Event> e{new Event};
204 e->set_step(steps);
205 e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
206 // maybe expose GetWallTime in SummaryWriterInterface
207 Summary::Value* v = e->mutable_summary()->add_value();
208 v->set_tag(name);
209 v->set_simple_value(scalar_value);
210 TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
211 }
212
AddToEvents(const string & name,const int64 steps,const histogram::Histogram & histogram)213 void AddToEvents(const string& name, const int64 steps,
214 const histogram::Histogram& histogram)
215 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
216 if (summary_writer_interface_ == nullptr) {
217 return;
218 }
219 std::unique_ptr<Event> e{new Event};
220 e->set_step(steps);
221 e->set_wall_time(EnvTime::NowMicros() / 1.0e6);
222 Summary::Value* v = e->mutable_summary()->add_value();
223 v->set_tag(name);
224 histogram.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
225 TF_CHECK_OK(summary_writer_interface_->WriteEvent(std::move(e)));
226 }
227
228 mutex mu_;
229 SummaryWriterInterface* summary_writer_interface_ TF_GUARDED_BY(mu_) =
230 nullptr;
231 // not owned, we might be associating the default summary_writer from the
232 // context
233 std::unordered_map<string, histogram::Histogram> histograms_
234 TF_GUARDED_BY(mu_);
235 TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorImplV2);
236 };
237
238 class StatsAggregatorHandleOpV2
239 : public ResourceOpKernel<StatsAggregatorResource> {
240 public:
StatsAggregatorHandleOpV2(OpKernelConstruction * ctx)241 explicit StatsAggregatorHandleOpV2(OpKernelConstruction* ctx)
242 : ResourceOpKernel<StatsAggregatorResource>(ctx) {}
243
244 private:
CreateResource(StatsAggregatorResource ** ret)245 Status CreateResource(StatsAggregatorResource** ret) override
246 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
247 *ret =
248 new StatsAggregatorResource(absl::make_unique<StatsAggregatorImplV2>());
249 return Status::OK();
250 }
251 };
252
// Serializes a stats aggregator's accumulated state into a scalar string
// tensor holding a Summary proto (pull-based / V1 usage).
class StatsAggregatorSummaryOp : public OpKernel {
 public:
  explicit StatsAggregatorSummaryOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& handle_t = ctx->input(0);
    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(handle_t.shape()),
                errors::InvalidArgument("resource_handle must be a scalar"));

    core::RefCountPtr<StatsAggregatorResource> resource;
    OP_REQUIRES_OK(ctx,
                   LookupResource(ctx, HandleFromInput(ctx, 0), &resource));

    // Output is a scalar string containing the serialized Summary proto.
    Tensor* summary_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &summary_t));
    Summary summary;
    resource->stats_aggregator()->EncodeToProto(&summary);
    summary_t->scalar<tstring>()() = summary.SerializeAsString();
  }
};

275 class StatsAggregatorSetSummaryWriterOp : public OpKernel {
276 public:
StatsAggregatorSetSummaryWriterOp(OpKernelConstruction * ctx)277 explicit StatsAggregatorSetSummaryWriterOp(OpKernelConstruction* ctx)
278 : OpKernel(ctx) {}
279
Compute(OpKernelContext * ctx)280 void Compute(OpKernelContext* ctx) override {
281 const Tensor& resource_handle_t = ctx->input(0);
282 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
283 errors::InvalidArgument("resource_handle must be a scalar"));
284
285 core::RefCountPtr<StatsAggregatorResource> resource;
286 OP_REQUIRES_OK(ctx,
287 LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
288
289 const Tensor& summary_resource_handle_t = ctx->input(1);
290 OP_REQUIRES(ctx,
291 TensorShapeUtils::IsScalar(summary_resource_handle_t.shape()),
292 errors::InvalidArgument("resource_handle must be a scalar"));
293 core::RefCountPtr<SummaryWriterInterface> summary_resource;
294 OP_REQUIRES_OK(
295 ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &summary_resource));
296 TF_CHECK_OK(
297 resource->stats_aggregator()->SetSummaryWriter(summary_resource.get()));
298 }
299 };
300
301 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandle").Device(DEVICE_CPU),
302 StatsAggregatorHandleOp);
303 REGISTER_KERNEL_BUILDER(
304 Name("ExperimentalStatsAggregatorHandle").Device(DEVICE_CPU),
305 StatsAggregatorHandleOp);
306
307 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorHandleV2").Device(DEVICE_CPU),
308 StatsAggregatorHandleOpV2);
309
310 REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
311 StatsAggregatorSummaryOp);
312 REGISTER_KERNEL_BUILDER(
313 Name("ExperimentalStatsAggregatorSummary").Device(DEVICE_CPU),
314 StatsAggregatorSummaryOp);
315
316 REGISTER_KERNEL_BUILDER(
317 Name("StatsAggregatorSetSummaryWriter").Device(DEVICE_CPU),
318 StatsAggregatorSetSummaryWriterOp);
319
}  // namespace
}  // namespace experimental
}  // namespace data
}  // namespace tensorflow