# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for xla.reduce_window."""

import numpy as np

from tensorflow.compiler.tests import xla_test
from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import googletest


class ReduceWindowTest(xla_test.XLATestCase):
  """Test cases for xla.reduce_window."""

  def _reduce_window(self, operand, init, reducer, **kwargs):
    """Builds an xla.reduce_window op over `operand` and evaluates it.

    Args:
      operand: numpy array, fed into the graph through a placeholder.
      init: initial value handed to xla.reduce_window.
      reducer: binary Defun used to combine the elements of each window.
      **kwargs: window options forwarded unchanged to xla.reduce_window
        (e.g. window_dimensions, window_strides, padding).

    Returns:
      The evaluated result of the reduce_window op.
    """
    with self.session():
      feed = array_ops.placeholder(operand.dtype)
      # Only the op under test is created inside the XLA test scope; the
      # placeholder that feeds it is created outside.
      with self.test_scope():
        result = xla.reduce_window(feed, init, reducer, **kwargs)
      return result.eval(feed_dict={feed: operand})

  def testReduceWindow(self):
    """Checks 1-D and 2-D reductions with strides and padding per dtype."""
    # TODO(b/77644762): float16 and float64 ReduceWindow are unimplemented.
    supported_dtypes = {dtypes.bfloat16.as_numpy_dtype, np.float32}
    for dtype in set(self.numeric_types) & supported_dtypes:

      @function.Defun(dtype, dtype)
      def sum_reducer(x, y):
        return x + y

      @function.Defun(dtype, dtype)
      def mul_reducer(x, y):
        return x * y

      vector = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
      matrix = np.array([[1, 2, 3, 4], [4, 3, 2, 1], [2, 4, 0, 1]],
                        dtype=dtype)

      # Each case: (expected, operand, init, reducer, reduce_window kwargs).
      cases = [
          # Sliding sum of adjacent pairs with the default stride.
          ([3, 5, 7, 9, 11, 13], vector, 0.0, sum_reducer,
           dict(window_dimensions=[2])),
          # Non-overlapping pairs; the trailing 7 does not fill a window.
          ([3, 7, 11], vector, 0.0, sum_reducer,
           dict(window_dimensions=[2], window_strides=[2])),
          # A size-1 window with stride 3 picks every third element.
          ([1, 4, 7], vector, 0.0, sum_reducer,
           dict(window_dimensions=[1], window_strides=[3])),
          # 2x2 sliding product; windows that include the 0 yield 0.
          ([[24, 36, 24], [96, 0, 0]], matrix, 1.0, mul_reducer,
           dict(window_dimensions=[2, 2], window_strides=[1, 1])),
          # Strided 2x2 sum over a padded matrix; the first and last output
          # rows come from windows lying entirely in the padding.
          ([[0, 0, 0], [5, 10, 5], [2, 4, 1], [0, 0, 0]], matrix, 0.0,
           sum_reducer,
           dict(window_dimensions=[2, 2], window_strides=[2, 2],
                padding=[[2, 3], [1, 2]])),
      ]
      for expected, operand, init, reducer, kwargs in cases:
        self.assertAllClose(
            np.array(expected, dtype=dtype),
            self._reduce_window(operand, init, reducer, **kwargs))


if __name__ == '__main__':
  googletest.main()