• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for ReduceJoin op from string_ops."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import string_ops
26from tensorflow.python.platform import test
27
28
def _input_array(num_dims):
  """Builds a hypercube of byte strings spelling each cell's linear index.

  Every element is its own flat (row-major) index written in binary and
  zero-padded to `num_dims` digits.

  Args:
    num_dims: The number of dimensions to create.

  Returns:
    An ndarray of shape [2] * num_dims with dtype "S<num_dims>".
  """
  element_count = 2**num_dims
  binary_spec = "0%db" % num_dims
  labels = [format(index, binary_spec) for index in range(element_count)]
  flat = np.array(labels, dtype="S%d" % num_dims)
  return flat.reshape([2] * num_dims)
41
42
def _joined_array(num_dims, reduce_dim):
  """Builds the expected reduce_join output for `_input_array(num_dims)`.

  Each output element is the concatenation of the two input elements whose
  indices differ only along `reduce_dim` (index 0 first, then index 1).

  Args:
    num_dims: The number of dimensions of the original input array.
    reduce_dim: The dimension to reduce.

  Returns:
    An ndarray of shape [2] * (num_dims - 1) with dtype "S<2 * num_dims>".
  """
  remaining_dims = num_dims - 1
  index_spec = "0%db" % remaining_dims
  joined = []
  for linear_index in range(2**remaining_dims):
    # Binary digits of the output position; the reduced digit is spliced in
    # at `reduce_dim` to recover the two contributing input labels.
    bits = format(linear_index, index_spec)
    head, tail = bits[:reduce_dim], bits[reduce_dim:]
    joined.append(head + "0" + tail + head + "1" + tail)
  flat = np.array(joined, dtype="S%d" % (2 * num_dims))
  return flat.reshape([2] * remaining_dims)
62
63
class UnicodeTestCase(test.TestCase):
  """Base test case that compares string arrays after converting to unicode."""

  def assertAllEqualUnicode(self, truth, actual):
    """Asserts `truth` and `actual` are equal once both are cast to "U"."""
    expected_unicode = np.array(truth).astype("U")
    actual_unicode = np.array(actual).astype("U")
    self.assertAllEqual(expected_unicode, actual_unicode)
70
71
class ReduceJoinTestHelperTest(UnicodeTestCase):
  """Sanity checks for the array-building helpers used by the op tests."""

  def testInputArray(self):
    num_dims = 3
    expected = [format(index, "03b") for index in range(2**num_dims)]
    flattened = _input_array(num_dims).reshape([-1])
    self.assertAllEqualUnicode(expected, flattened)

  def testJoinedArray(self):
    num_dims = 3
    # Expected joins, listed in order of the dimension being reduced.
    expected_by_dim = [
        [["000100", "001101"], ["010110", "011111"]],
        [["000010", "001011"], ["100110", "101111"]],
        [["000001", "010011"], ["100101", "110111"]],
    ]
    joined_by_dim = [
        _joined_array(num_dims, reduce_dim=dim) for dim in range(num_dims)
    ]
    for expected, joined in zip(expected_by_dim, joined_by_dim):
      self.assertAllEqualUnicode(expected, joined)
92
93
class ReduceJoinTest(UnicodeTestCase):
  """Tests string_ops.reduce_join over various ranks, axes and arguments."""

  def _testReduceJoin(self,
                      input_array,
                      truth,
                      truth_shape,
                      axis,
                      keep_dims=False,
                      separator=""):
    """Runs one reduce_join call and checks its value and static shape.

    Args:
      input_array: The string input to be joined.
      truth: An array or np.array of the expected result.
      truth_shape: An array or np.array of the expected shape.
      axis: The indices to reduce over.
      keep_dims: Whether or not to retain reduced dimensions.
      separator: The separator to use for joining.
    """
    join_kwargs = {
        "inputs": input_array,
        "axis": axis,
        "keep_dims": keep_dims,
        "separator": separator,
    }
    with self.cached_session():
      joined = string_ops.reduce_join(**join_kwargs)
      joined_value = self.evaluate(joined)

    self.assertAllEqualUnicode(truth, joined_value)
    self.assertAllEqual(truth_shape, joined.get_shape())

  def _testMultipleReduceJoin(self, input_array, axis, separator=" "):
    """Checks a multi-axis reduce_join against chained single-axis joins.

    The reference value is produced by applying reduce_join once per entry of
    `axis`, always keeping the reduced dimension. Single-axis behaviour itself
    is covered by the tests that go through _testReduceJoin.

    Args:
      input_array: The input to test.
      axis: The indices to reduce.
      separator: The separator to use when joining.
    """

    def join(inputs, over, keep):
      return string_ops.reduce_join(
          inputs=inputs, axis=over, keep_dims=keep, separator=separator)

    with self.cached_session():
      reduced = join(input_array, axis, False)
      reduced_kept = join(input_array, axis, True)

      reference = input_array
      for single_axis in axis:
        reference = join(reference, single_axis, True)
      if not axis:
        # No join ran, so the reference is still the raw input; wrap it with
        # constant_op.constant before it is squeezed and evaluated.
        reference = constant_op.constant(reference)
      reference_squeezed = array_ops.squeeze(reference, axis=axis)

      reduced_value = self.evaluate(reduced)
      reduced_kept_value = self.evaluate(reduced_kept)
      reference_value = self.evaluate(reference)
      reference_squeezed_value = self.evaluate(reference_squeezed)

    self.assertAllEqualUnicode(reference_value, reduced_kept_value)
    self.assertAllEqualUnicode(reference_squeezed_value, reduced_value)
    self.assertAllEqual(reference.get_shape(), reduced_kept.get_shape())
    self.assertAllEqual(reference_squeezed.get_shape(), reduced.get_shape())

  def testRankOne(self):
    self._testReduceJoin(
        ["this", "is", "a", "test"], "thisisatest", [], axis=0)

  def testRankTwo(self):
    words = [["this", "is", "a", "test"],
             ["please", "do", "not", "panic"]]
    # Each case is (expected value, expected shape, axis).
    cases = [
        (["thisplease", "isdo", "anot", "testpanic"], [4], 0),
        (["thisisatest", "pleasedonotpanic"], [2], 1),
        ("thisisatestpleasedonotpanic", [], None),
        # Using axis=[] is a no-op.
        (words, [2, 4], []),
    ]
    for expected_value, expected_shape, axis in cases:
      self._testReduceJoin(words, expected_value, expected_shape, axis=axis)

  def testRankFive(self):
    hypercube = _input_array(num_dims=5)
    expected_per_dim = [
        _joined_array(num_dims=5, reduce_dim=dim) for dim in range(5)
    ]
    expected_shape = [2] * 4
    for dim, expected in enumerate(expected_per_dim):
      self._testReduceJoin(hypercube, expected, expected_shape, axis=dim)

  def testNegative(self):
    hypercube = _input_array(num_dims=5)
    expected_per_dim = [
        _joined_array(num_dims=5, reduce_dim=dim) for dim in range(5)
    ]
    expected_shape = [2] * 4
    for dim, expected in enumerate(expected_per_dim):
      # dim - 5 is the negative index that refers to the same dimension.
      self._testReduceJoin(hypercube, expected, expected_shape, axis=dim - 5)

  def testSingletonDimension(self):
    # Insert a size-1 dimension at every possible position.
    padded_inputs = []
    for position in range(6):
      padded_shape = [2] * position + [1] + [2] * (5 - position)
      padded_inputs.append(_input_array(num_dims=5).reshape(padded_shape))
    expected = _input_array(num_dims=5)
    expected_shape = [2] * 5
    for position, padded in enumerate(padded_inputs):
      self._testReduceJoin(padded, expected, expected_shape, axis=position)

  def testSeparator(self):
    words = [["this", "is", "a", "test"],
             ["please", "do", "not", "panic"]]
    # Each case is (expected value, expected shape, axis).
    cases = [
        (["this  please", "is  do", "a  not", "test  panic"], [4], 0),
        (["this  is  a  test", "please  do  not  panic"], [2], 1),
    ]
    for expected_value, expected_shape, axis in cases:
      self._testReduceJoin(
          words, expected_value, expected_shape, axis=axis, separator="  ")

  @test_util.run_deprecated_v1
  def testUnknownShape(self):
    fed_value = [["a"], ["b"]]
    expected_value = ["ab"]
    expected_shape = None
    with self.cached_session():
      placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
      reduced = string_ops.reduce_join(placeholder, axis=0)
      reduced_value = reduced.eval(feed_dict={placeholder.name: fed_value})
      self.assertAllEqualUnicode(expected_value, reduced_value)
      self.assertAllEqual(expected_shape, reduced.get_shape())

  @test_util.run_deprecated_v1
  def testUnknownIndices(self):
    words = [["this", "is", "a", "test"],
             ["please", "do", "not", "panic"]]
    expected_dim_zero = ["thisplease", "isdo", "anot", "testpanic"]
    expected_dim_one = ["thisisatest", "pleasedonotpanic"]
    expected_shape = None
    with self.cached_session():
      placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
      reduced = string_ops.reduce_join(words, axis=placeholder)
      value_dim_zero = reduced.eval(feed_dict={placeholder.name: [0]})
      value_dim_one = reduced.eval(feed_dict={placeholder.name: [1]})
      self.assertAllEqualUnicode(expected_dim_zero, value_dim_zero)
      self.assertAllEqualUnicode(expected_dim_one, value_dim_one)
      self.assertAllEqual(expected_shape, reduced.get_shape())

  def testKeepDims(self):
    words = [["this", "is", "a", "test"],
             ["please", "do", "not", "panic"]]

    self._testReduceJoin(
        words,
        [["thisplease", "isdo", "anot", "testpanic"]],
        [1, 4],
        axis=0,
        keep_dims=True)
    self._testReduceJoin(
        words,
        [["thisisatest"], ["pleasedonotpanic"]],
        [2, 1],
        axis=1,
        keep_dims=True)

    # Reducing over every axis (axis=None) with a constant tensor input.
    self._testReduceJoin(
        constant_op.constant(words),
        [["thisisatestpleasedonotpanic"]],
        [1, 1],
        keep_dims=True,
        axis=None)

    # Using axis=[] is a no-op.
    self._testReduceJoin(words, words, [2, 4], keep_dims=True, axis=[])

  def testMultiIndex(self):
    num_dims = 3
    hypercube = _input_array(num_dims=num_dims)
    # Also tests [].
    for axis_count in range(num_dims + 1):
      for axes in itertools.permutations(range(num_dims), axis_count):
        self._testMultipleReduceJoin(hypercube, axis=axes)

  @test_util.run_deprecated_v1
  def testInvalidReductionIndices(self):
    # Each case is (inputs, axis, expected error pattern).
    invalid_cases = [
        ("", 0, "Invalid reduction dim"),
        ([[""]], -3, "Invalid reduction dimension -3"),
        ([[""]], 2, "Invalid reduction dimension 2"),
        ([[""]], [0, -3], "Invalid reduction dimension -3"),
        ([[""]], [0, 2], "Invalid reduction dimension 2"),
    ]
    with self.cached_session():
      for inputs, axis, error_pattern in invalid_cases:
        with self.assertRaisesRegex(ValueError, error_pattern):
          string_ops.reduce_join(inputs=inputs, axis=axis)

  def testZeroDims(self):
    with self.cached_session():
      empty_input = np.zeros([0, 1], dtype=str)

      # Reduction that drops the dim of size 0.
      dropped = string_ops.reduce_join(inputs=empty_input, axis=0)
      self.assertAllEqualUnicode([""], self.evaluate(dropped))

      # Reduction that keeps the dim of size 0.
      kept = string_ops.reduce_join(inputs=empty_input, axis=1)
      kept_shape = self.evaluate(kept).shape
      self.assertAllEqual([0], kept_shape)

  @test_util.run_deprecated_v1
  def testInvalidArgsUnknownShape(self):
    with self.cached_session():
      placeholder = array_ops.placeholder(dtypes.string, name="placeholder")
      index_too_high = string_ops.reduce_join(placeholder, axis=1)
      duplicate_index = string_ops.reduce_join(placeholder, axis=[-1, 1])
      # The input shape is unknown when the ops are built, so these errors
      # are expected when the ops are evaluated with a fed value.
      with self.assertRaisesOpError("Invalid reduction dimension 1"):
        index_too_high.eval(feed_dict={placeholder.name: [""]})
      with self.assertRaisesOpError("Duplicate reduction dimension 1"):
        duplicate_index.eval(feed_dict={placeholder.name: [[""]]})

  @test_util.run_deprecated_v1
  def testInvalidArgsUnknownIndices(self):
    # Each case is (fed axis value, expected error pattern).
    invalid_cases = (
        (-2, "reduction dimension -2"),
        (2, "reduction dimension 2"),
    )
    with self.cached_session():
      placeholder = array_ops.placeholder(dtypes.int32, name="placeholder")
      reduced = string_ops.reduce_join(["test", "test2"], axis=placeholder)

      for fed_axis, error_pattern in invalid_cases:
        with self.assertRaisesOpError(error_pattern):
          reduced.eval(feed_dict={placeholder.name: fed_axis})

  def testDeprecatedArgs(self):
    expected = constant_op.constant(["foobar"])
    # Old names: keep_dims and reduction_indices
    with_old_names = string_ops.reduce_join(
        ["foo", "bar"], reduction_indices=0, keep_dims=True)
    self.assertAllEqual(expected, with_old_names)
    # New names keepdims and axis.
    with_new_names = string_ops.reduce_join(
        ["foo", "bar"], axis=0, keepdims=True)
    self.assertAllEqual(expected, with_new_names)
357
358
if __name__ == "__main__":
  # Only run the tests when this file is executed as a script, not on import.
  test.main()
361