# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StatsAggregator for aggregating statistics from `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  To record statistics, use one of the custom transformation functions defined
  in this module when defining your `tf.data.Dataset`. All statistics will be
  aggregated by the `StatsAggregator` that is associated with a particular
  iterator (see below). For example, to record the latency of producing each
  element by iterating over a dataset:

  ```python
  dataset = ...
  dataset = dataset.apply(tf.data.experimental.latency_stats("total_latency"))
  ```

  To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
  the following pattern:

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  dataset = ...

  # Apply `StatsOptions` to associate `dataset` with `aggregator`.
  options = tf.data.Options()
  options.experimental_stats.aggregator = aggregator
  dataset = dataset.with_options(options)
  ```

  To get a protocol buffer summary of the currently aggregated statistics,
  use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
  is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
  so that the summaries will be included with any existing summaries.

  ```python
  aggregator = tf.data.experimental.StatsAggregator()
  # ...
  stats_summary = aggregator.get_summary()
  tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  def __init__(self):
    """Creates a `StatsAggregator`."""
    # Handle to the stateful aggregator resource created by the underlying op;
    # all statistics recorded by associated datasets accumulate into it.
    self._resource = ged_ops.experimental_stats_aggregator_handle()

  # TODO(b/116314787): Update this/add support for V2 summary API.
  def get_summary(self):
    """Returns a string `tf.Tensor` that summarizes the aggregated statistics.

    The returned tensor will contain a serialized `tf.summary.Summary` protocol
    buffer, which can be used with the standard TensorBoard logging facilities.

    Returns:
      A scalar string `tf.Tensor` that summarizes the aggregated statistics.
    """
    # Delegates to the generated op, which serializes the current contents of
    # the aggregator resource created in `__init__`.
    return ged_ops.experimental_stats_aggregator_summary(self._resource)