1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for `tf.data.Dataset.reduce()`.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import time 21 22from absl.testing import parameterized 23import numpy as np 24 25from tensorflow.python.data.experimental.ops import testing 26from tensorflow.python.data.kernel_tests import test_base 27from tensorflow.python.data.ops import dataset_ops 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import combinations 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import errors 33from tensorflow.python.framework import ops 34from tensorflow.python.framework import sparse_tensor 35from tensorflow.python.framework import test_util 36from tensorflow.python.ops import array_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import variables 39from tensorflow.python.platform import test 40 41 42class ReduceTest(test_base.DatasetTestBase, parameterized.TestCase): 43 44 @combinations.generate(test_base.default_test_combinations()) 45 def testSum(self): 46 for i in range(10): 47 ds = dataset_ops.Dataset.range(1, i + 1) 48 result = ds.reduce(np.int64(0), lambda x, y: x + y) 49 self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) 50 51 @combinations.generate(test_base.default_test_combinations()) 52 def testSumTuple(self): 53 54 def reduce_fn(state, value): 55 v1, v2 = value 56 return state + v1 + v2 57 58 for i in range(10): 59 ds = dataset_ops.Dataset.range(1, i + 1) 60 ds = dataset_ops.Dataset.zip((ds, ds)) 61 result = ds.reduce(constant_op.constant(0, dtype=dtypes.int64), reduce_fn) 62 self.assertEqual(((i + 1) * i), self.evaluate(result)) 63 64 @combinations.generate(test_base.default_test_combinations()) 65 def testSumAndCount(self): 66 67 def reduce_fn(state, value): 68 s, c = state 69 return s + value, c + 1 70 71 for i in range(10): 72 ds = dataset_ops.Dataset.range(1, i + 1) 73 result = ds.reduce((constant_op.constant(0, dtype=dtypes.int64), 74 constant_op.constant(0, dtype=dtypes.int64)), 75 reduce_fn) 76 s, c = self.evaluate(result) 77 self.assertEqual(((i + 1) * i) // 2, s) 78 self.assertEqual(i, c) 79 80 @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) 81 def testSquareUsingPlaceholder(self): 82 delta = array_ops.placeholder(dtype=dtypes.int64) 83 84 def reduce_fn(state, _): 85 return state + delta 86 87 for i in range(10): 88 ds = dataset_ops.Dataset.range(1, i + 1) 89 result = ds.reduce(np.int64(0), reduce_fn) 90 with self.cached_session() as sess: 91 square = sess.run(result, feed_dict={delta: i}) 92 self.assertEqual(i * i, square) 93 94 @combinations.generate(test_base.default_test_combinations()) 95 def testSparse(self): 96 97 def reduce_fn(_, value): 98 return value 99 100 def make_sparse_fn(i): 101 return sparse_tensor.SparseTensorValue( 102 indices=np.array([[0, 0]]), 103 values=(i * np.array([1])), 104 dense_shape=np.array([1, 1])) 105 106 for i in range(10): 107 ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1)) 108 result = ds.reduce(make_sparse_fn(0), reduce_fn) 109 self.assertValuesEqual(make_sparse_fn(i + 1), self.evaluate(result)) 110 111 @combinations.generate(test_base.default_test_combinations()) 112 def testNested(self): 113 114 def reduce_fn(state, value): 115 state["dense"] += value["dense"] 116 state["sparse"] = value["sparse"] 117 return state 118 119 def make_sparse_fn(i): 120 return sparse_tensor.SparseTensorValue( 121 indices=np.array([[0, 0]]), 122 values=(i * np.array([1])), 123 dense_shape=np.array([1, 1])) 124 125 def map_fn(i): 126 return {"dense": math_ops.cast(i, dtype=dtypes.int64), 127 "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))} 128 129 for i in range(10): 130 ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn) 131 result = ds.reduce(map_fn(0), reduce_fn) 132 result = self.evaluate(result) 133 self.assertEqual(((i + 1) * i) // 2, result["dense"]) 134 self.assertValuesEqual(make_sparse_fn(i), result["sparse"]) 135 136 @combinations.generate(test_base.default_test_combinations()) 137 def testDatasetSideEffect(self): 138 counter_var = variables.Variable(0) 139 140 def increment_fn(x): 141 counter_var.assign_add(1) 142 return x 143 144 def dataset_fn(): 145 return dataset_ops.Dataset.range(10).map(increment_fn) 146 147 def reduce_fn(state, value): 148 return state + value 149 150 @def_function.function 151 def fn(): 152 _ = dataset_fn().reduce(np.int64(0), reduce_fn) 153 return "hello" 154 155 self.evaluate(counter_var.initializer) 156 self.assertEqual(self.evaluate(fn()), b"hello") 157 self.assertEqual(self.evaluate(counter_var), 10) 158 159 @combinations.generate(test_base.default_test_combinations()) 160 def testSideEffect(self): 161 counter_var = variables.Variable(0) 162 163 def dataset_fn(): 164 return dataset_ops.Dataset.range(10) 165 166 def reduce_fn(state, value): 167 counter_var.assign_add(1) 168 return state + value 169 170 @def_function.function 171 def fn(): 172 _ = dataset_fn().reduce(np.int64(0), reduce_fn) 173 return "hello" 174 175 self.evaluate(counter_var.initializer) 176 self.assertEqual(self.evaluate(fn()), b"hello") 177 self.assertEqual(self.evaluate(counter_var), 10) 178 179 @combinations.generate(test_base.default_test_combinations()) 180 def testAutomaticControlDependencies(self): 181 counter_var = variables.Variable(1) 182 183 def dataset_fn(): 184 return dataset_ops.Dataset.range(1) 185 186 def reduce1_fn(state, value): 187 counter_var.assign(counter_var + 1) 188 return state + value 189 190 def reduce2_fn(state, value): 191 counter_var.assign(counter_var * 2) 192 return state + value 193 194 @def_function.function 195 def fn(): 196 _ = dataset_fn().reduce(np.int64(0), reduce1_fn) 197 _ = dataset_fn().reduce(np.int64(0), reduce2_fn) 198 return "hello" 199 200 self.evaluate(counter_var.initializer) 201 self.assertEqual(self.evaluate(fn()), b"hello") 202 self.assertEqual(self.evaluate(counter_var), 4) 203 204 @combinations.generate(test_base.default_test_combinations()) 205 def testNestedAutomaticControlDependencies(self): 206 counter_var = variables.Variable(0) 207 208 def map_fn(x): 209 counter_var.assign_add(1) 210 return x 211 212 def dataset_fn(): 213 return dataset_ops.Dataset.range(10).map(map_fn) 214 215 @def_function.function 216 def fn(): 217 for _ in dataset_fn(): 218 pass 219 return counter_var 220 221 self.evaluate(counter_var.initializer) 222 self.assertEqual(self.evaluate(fn()), 10) 223 224 @combinations.generate(test_base.default_test_combinations()) 225 def testStateOnGPU(self): 226 if not test_util.is_gpu_available(): 227 self.skipTest("No GPUs available.") 228 229 state = constant_op.constant(0, dtype=dtypes.int64) 230 231 def reduce_fn(state, value): 232 with ops.device("/gpu:0"): 233 return state + value 234 235 for i in range(10): 236 ds = dataset_ops.Dataset.range(1, i + 1) 237 result = ds.reduce(state, reduce_fn) 238 self.assertEqual(((i + 1) * i) // 2, self.evaluate(result)) 239 240 @combinations.generate(combinations.combine(tf_api_version=1, mode="graph")) 241 def testCancellation(self): 242 ds = dataset_ops.Dataset.from_tensors(1).repeat() 243 result = ds.reduce(0, lambda x, y: x + y) 244 with self.cached_session() as sess: 245 # The `result` op is guaranteed to not complete before cancelled because 246 # the dataset that is being reduced is infinite. 247 thread = self.checkedThread(self.assert_op_cancelled, args=(result,)) 248 thread.start() 249 time.sleep(0.2) 250 sess.close() 251 thread.join() 252 253 @combinations.generate(test_base.default_test_combinations()) 254 def testInvalidFunction(self): 255 ds = dataset_ops.Dataset.range(5) 256 with self.assertRaises(errors.InvalidArgumentError): 257 self.evaluate(ds.reduce(0, lambda _, __: ())) 258 259 @combinations.generate(test_base.default_test_combinations()) 260 def testOptions(self): 261 dataset = dataset_ops.Dataset.range(5) 262 dataset = dataset.apply(testing.assert_next(["MapAndBatch"])) 263 dataset = dataset.map(lambda x: x * 2).batch(5) 264 self.evaluate(dataset.reduce(0, lambda state, value: state)) 265 266 267if __name__ == "__main__": 268 test.main() 269