# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Experimental API for gathering statistics from `tf.data` pipelines."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.data.ops import dataset_ops
21from tensorflow.python.data.ops import iterator_ops
22from tensorflow.python.data.util import nest
23from tensorflow.python.data.util import sparse
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import gen_dataset_ops
27
28
class StatsAggregator(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  To record statistics, use one of the custom transformation functions defined
  in this module when defining your @{tf.data.Dataset}. All statistics will be
  aggregated by the `StatsAggregator` that is associated with a particular
  iterator (see below). For example, to record the total number of bytes
  produced by iterating over a dataset:

  ```python
  dataset = ...
  dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
  ```

  To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use
  the following pattern:

  ```python
  dataset = ...
  iterator = dataset.make_one_shot_iterator()
  stats_aggregator = stats_ops.StatsAggregator()
  set_op = stats_aggregator.subscribe(iterator)

  with tf.Session() as sess:
    # Running `set_op` will associate `iterator` with `stats_aggregator`.
    sess.run(set_op)
  ```

  To get a protocol buffer summary of the currently aggregated statistics,
  use the tensor returned by `StatsAggregator.get_summary()`. The easiest
  way to do this is to add the returned tensor to the
  @{tf.GraphKeys.SUMMARIES} collection, so that the summaries will be
  included with any existing summaries.

  ```python
  stats_aggregator = stats_ops.StatsAggregator()
  stats_summary = stats_aggregator.get_summary()
  tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
  ```

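  As a minimal sketch (assuming summaries are written with an existing
  @{tf.summary.FileWriter} named `writer`), the aggregated statistics can
  then be fetched and logged alongside any other summaries:

  ```python
  summary_op = tf.summary.merge_all()

  with tf.Session() as sess:
    # `writer` is assumed to be a pre-existing `tf.summary.FileWriter`.
    summary = sess.run(summary_op)
    writer.add_summary(summary, global_step=0)
  ```
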
  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  def __init__(self):
    """Creates a `StatsAggregator`."""
    self._resource = gen_dataset_ops.stats_aggregator_handle()

  def get_summary(self):
    """Returns a string @{tf.Tensor} that summarizes the aggregated statistics.

    The returned tensor will contain a serialized @{tf.summary.Summary} protocol
    buffer, which can be used with the standard TensorBoard logging facilities.

    Returns:
      A scalar string @{tf.Tensor} that summarizes the aggregated statistics.
    """
    return gen_dataset_ops.stats_aggregator_summary(self._resource)

  def subscribe(self, iterator):
    """Returns a @{tf.Operation} to associate this aggregator with `iterator`.

    Note: Each @{tf.data.Iterator} can be associated with at most one
    `StatsAggregator`. After running the operation that this method
    returns, all statistics recorded in the iteration of `iterator`
    will be stored in this `StatsAggregator`.

    Args:
      iterator: A @{tf.data.Iterator} object.

    Returns:
      A @{tf.Operation} that, when run, associates this aggregator with
      `iterator`.
    """
    if not isinstance(iterator, iterator_ops.Iterator):
      raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
    return gen_dataset_ops.iterator_set_stats_aggregator(
        iterator._iterator_resource, self._resource)  # pylint: disable=protected-access


def bytes_produced_stats(tag):
  """Records the number of bytes produced by each element of the input dataset.

  To consume the statistics, associate a `StatsAggregator` with an iterator
  over the output dataset.

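  For example (a minimal sketch; `stats_ops` refers to this module, and the
  tag string is arbitrary):

  ```python
  dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
  iterator = dataset.make_one_shot_iterator()

  stats_aggregator = stats_ops.StatsAggregator()
  # Running `set_op` associates `iterator` with `stats_aggregator`.
  set_op = stats_aggregator.subscribe(iterator)
  ```
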
  Args:
    tag: String. All statistics recorded by the returned transformation will
      be associated with the given `tag`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
                         tag)

  return _apply_fn


def latency_stats(tag):
  """Records the latency of producing each element of the input dataset.

  To consume the statistics, associate a `StatsAggregator` with an iterator
  over the output dataset.

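  For example (a minimal sketch; the tag string is arbitrary):

  ```python
  dataset = dataset.apply(stats_ops.latency_stats("element_latency"))
  ```
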
  Args:
    tag: String. All statistics recorded by the returned transformation will
      be associated with the given `tag`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)

  return _apply_fn


class _StatsDataset(dataset_ops.Dataset):
  """A `Dataset` that acts as an identity, and also records statistics."""

  def __init__(self, input_dataset, op_function, tag):
    super(_StatsDataset, self).__init__()
    self._input_dataset = input_dataset
    self._op_function = op_function
    self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)

  def _as_variant_tensor(self):
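    # Wrap the input dataset's variant tensor in the stats op. The output
    # types and shapes are converted to their dense representations and
    # flattened, since the op's attrs expect flat lists rather than nested
    # structures.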
    return self._op_function(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._tag,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types

  @property
  def output_classes(self):
    return self._input_dataset.output_classes