• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for `tf.data.Dataset.reduce()`."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import time
21
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.python.data.experimental.ops import testing
26from tensorflow.python.data.kernel_tests import test_base
27from tensorflow.python.data.ops import dataset_ops
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import combinations
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import errors
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import sparse_tensor
35from tensorflow.python.framework import test_util
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import variables
39from tensorflow.python.platform import test
40
41
42class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase):
43
44  @combinations.generate(test_base.default_test_combinations())
45  def testSum(self):
46    for i in range(10):
47      ds = dataset_ops.Dataset.range(1, i + 1)
48      result = ds.reduce(np.int64(0), lambda x, y: x + y)
49      self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
50
51  @combinations.generate(test_base.default_test_combinations())
52  def testSumTuple(self):
53
54    def reduce_fn(state, value):
55      v1, v2 = value
56      return state + v1 + v2
57
58    for i in range(10):
59      ds = dataset_ops.Dataset.range(1, i + 1)
60      ds = dataset_ops.Dataset.zip((ds, ds))
61      result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn)
62      self.assertEqual(((i + 1) * i), self.evaluate(result))
63
64  @combinations.generate(test_base.default_test_combinations())
65  def testSumAndCount(self):
66
67    def reduce_fn(state, value):
68      s, c = state
69      return s + value, c + 1
70
71    for i in range(10):
72      ds = dataset_ops.Dataset.range(1, i + 1)
73      result = ds.reduce((constant_op.constant(0, dtype=dtypes.int64),
74                          constant_op.constant(0, dtype=dtypes.int64)),
75                         reduce_fn)
76      s, c = self.evaluate(result)
77      self.assertEqual(((i + 1) * i) // 2, s)
78      self.assertEqual(i, c)
79
80  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
81  def testSquareUsingPlaceholder(self):
82    delta = array_ops.placeholder(dtype=dtypes.int64)
83
84    def reduce_fn(state, _):
85      return state + delta
86
87    for i in range(10):
88      ds = dataset_ops.Dataset.range(1, i + 1)
89      result = ds.reduce(np.int64(0), reduce_fn)
90      with self.cached_session() as sess:
91        square = sess.run(result, feed_dict={delta: i})
92        self.assertEqual(i * i, square)
93
94  @combinations.generate(test_base.default_test_combinations())
95  def testSparse(self):
96
97    def reduce_fn(_, value):
98      return value
99
100    def make_sparse_fn(i):
101      return sparse_tensor.SparseTensorValue(
102          indices=np.array([[0, 0]]),
103          values=(i * np.array([1])),
104          dense_shape=np.array([1, 1]))
105
106    for i in range(10):
107      ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
108      result = ds.reduce(make_sparse_fn(0), reduce_fn)
109      self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result))
110
111  @combinations.generate(test_base.default_test_combinations())
112  def testNested(self):
113
114    def reduce_fn(state, value):
115      state["dense"] += value["dense"]
116      state["sparse"] = value["sparse"]
117      return state
118
119    def make_sparse_fn(i):
120      return sparse_tensor.SparseTensorValue(
121          indices=np.array([[0, 0]]),
122          values=(i * np.array([1])),
123          dense_shape=np.array([1, 1]))
124
125    def map_fn(i):
126      return {"dense": math_ops.cast(i, dtype=dtypes.int64),
127              "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
128
129    for i in range(10):
130      ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
131      result = ds.reduce(map_fn(0), reduce_fn)
132      result = self.evaluate(result)
133      self.assertEqual(((i + 1) * i) // 2, result["dense"])
134      self.assertValuesEqual(make_sparse_fn(i), result["sparse"])
135
136  @combinations.generate(test_base.default_test_combinations())
137  def testDatasetSideEffect(self):
138    counter_var = variables.Variable(0)
139
140    def increment_fn(x):
141      counter_var.assign_add(1)
142      return x
143
144    def dataset_fn():
145      return dataset_ops.Dataset.range(10).map(increment_fn)
146
147    def reduce_fn(state, value):
148      return state + value
149
150    @def_function.function
151    def fn():
152      _ = dataset_fn().reduce(np.int64(0), reduce_fn)
153      return "hello"
154
155    self.evaluate(counter_var.initializer)
156    self.assertEqual(self.evaluate(fn()), b"hello")
157    self.assertEqual(self.evaluate(counter_var), 10)
158
159  @combinations.generate(test_base.default_test_combinations())
160  def testSideEffect(self):
161    counter_var = variables.Variable(0)
162
163    def dataset_fn():
164      return dataset_ops.Dataset.range(10)
165
166    def reduce_fn(state, value):
167      counter_var.assign_add(1)
168      return state + value
169
170    @def_function.function
171    def fn():
172      _ = dataset_fn().reduce(np.int64(0), reduce_fn)
173      return "hello"
174
175    self.evaluate(counter_var.initializer)
176    self.assertEqual(self.evaluate(fn()), b"hello")
177    self.assertEqual(self.evaluate(counter_var), 10)
178
179  @combinations.generate(test_base.default_test_combinations())
180  def testAutomaticControlDependencies(self):
181    counter_var = variables.Variable(1)
182
183    def dataset_fn():
184      return dataset_ops.Dataset.range(1)
185
186    def reduce1_fn(state, value):
187      counter_var.assign(counter_var + 1)
188      return state + value
189
190    def reduce2_fn(state, value):
191      counter_var.assign(counter_var * 2)
192      return state + value
193
194    @def_function.function
195    def fn():
196      _ = dataset_fn().reduce(np.int64(0), reduce1_fn)
197      _ = dataset_fn().reduce(np.int64(0), reduce2_fn)
198      return "hello"
199
200    self.evaluate(counter_var.initializer)
201    self.assertEqual(self.evaluate(fn()), b"hello")
202    self.assertEqual(self.evaluate(counter_var), 4)
203
204  @combinations.generate(test_base.default_test_combinations())
205  def testNestedAutomaticControlDependencies(self):
206    counter_var = variables.Variable(0)
207
208    def map_fn(x):
209      counter_var.assign_add(1)
210      return x
211
212    def dataset_fn():
213      return dataset_ops.Dataset.range(10).map(map_fn)
214
215    @def_function.function
216    def fn():
217      for _ in dataset_fn():
218        pass
219      return counter_var
220
221    self.evaluate(counter_var.initializer)
222    self.assertEqual(self.evaluate(fn()), 10)
223
224  @combinations.generate(test_base.default_test_combinations())
225  def testStateOnGPU(self):
226    if not test_util.is_gpu_available():
227      self.skipTest("No GPUs available.")
228
229    state = constant_op.constant(0, dtype=dtypes.int64)
230
231    def reduce_fn(state, value):
232      with ops.device("/gpu:0"):
233        return state + value
234
235    for i in range(10):
236      ds = dataset_ops.Dataset.range(1, i + 1)
237      result = ds.reduce(state, reduce_fn)
238      self.assertEqual(((i + 1) * i) // 2, self.evaluate(result))
239
240  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
241  def testCancellation(self):
242    ds = dataset_ops.Dataset.from_tensors(1).repeat()
243    result = ds.reduce(0, lambda x, y: x + y)
244    with self.cached_session() as sess:
245      # The `result` op is guaranteed to not complete before cancelled because
246      # the dataset that is being reduced is infinite.
247      thread = self.checkedThread(self.assert_op_cancelled, args=(result,))
248      thread.start()
249      time.sleep(0.2)
250      sess.close()
251      thread.join()
252
253  @combinations.generate(test_base.default_test_combinations())
254  def testInvalidFunction(self):
255    ds = dataset_ops.Dataset.range(5)
256    with self.assertRaises(errors.InvalidArgumentError):
257      self.evaluate(ds.reduce(0, lambda _, __: ()))
258
259  @combinations.generate(test_base.default_test_combinations())
260  def testOptions(self):
261    dataset = dataset_ops.Dataset.range(5)
262    dataset = dataset.apply(testing.assert_next(["MapAndBatch"]))
263    dataset = dataset.map(lambda x: x * 2).batch(5)
264    self.evaluate(dataset.reduce(0, lambda state, value: state))
265
266
267if __name__ == "__main__":
268  test.main()
269