• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import operator
21import re
22import textwrap
23
24import numpy as np
25from six.moves import range  # pylint: disable=redefined-builtin
26
27from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc
28from tensorflow.contrib.labeled_tensor.python.ops import core
29from tensorflow.contrib.labeled_tensor.python.ops import test_util
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.platform import test as test_lib
37
38
class AxisTest(test_lib.TestCase):
  """Tests for `core.Axis` and the `concat_axes` / `as_axis` helpers."""

  def setUp(self):
    # Chain to the base TestCase so its per-test setup still runs.
    super(AxisTest, self).setUp()
    d_7 = tensor_shape.Dimension(7)
    p_rgb = ['red', 'green', 'blue']

    # Fixture axes: a bare Dimension under two different names, string labels,
    # integer labels, and an axis whose value is None (unknown size).
    self.i_7 = core.Axis('7', d_7)
    self.i_7p = core.Axis('7prime', d_7)
    self.i_rgb = core.Axis('rgb', p_rgb)
    self.i_range = core.Axis('range', range(7))
    self.i_unknown = core.Axis('unknown', None)

  def test_equality(self):
    # Every fixture axis equals itself and differs from all of the others.
    axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
    for i, axis_0 in enumerate(axes):
      for j, axis_1 in enumerate(axes):
        if i == j:
          self.assertEqual(axis_0, axis_1)
        else:
          self.assertNotEqual(axis_0, axis_1)

  def test_axis_value(self):
    self.assertEqual(self.i_7.value, tensor_shape.Dimension(7))
    # assertEqual reports both values on failure; assertTrue(a == b) does not.
    self.assertEqual(self.i_range.value, tuple(range(7)))

  def test_axis_input(self):
    # Rebuilding an axis from its own name and value round-trips.
    axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown]
    for axis in axes:
      self.assertEqual(axis, core.Axis(axis.name, axis.value))

  def test_axis_value_input(self):
    # range, list and numpy-array labels all build the same axis.
    axis = self.i_range
    for value in [range(7), list(range(7)), np.arange(7)]:
      self.assertEqual(axis, core.Axis(axis.name, value))

  def test_size(self):
    self.assertEqual(len(self.i_7), 7)
    self.assertEqual(len(self.i_rgb), 3)
    self.assertEqual(len(self.i_range), 7)
    self.assertIsNone(self.i_unknown.size)

  def test_concat_single(self):
    red = core.Axis('rgb', ['red'])

    self.assertEqual(core.concat_axes([red]), red)

  def test_concat_many(self):
    red = core.Axis('rgb', ['red'])
    green = core.Axis('rgb', ['green'])
    blue = core.Axis('rgb', ['blue'])
    red_green_blue = core.Axis('rgb', ['red', 'green', 'blue'])

    self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue)

  def test_concat_different_names(self):
    # Both axes use the same labels; only their names differ.
    red = core.Axis('red', ['red'])
    green = core.Axis('green', ['red'])
    with self.assertRaises(ValueError):
      core.concat_axes([red, green])

  def test_concat_unknown(self):
    red = core.Axis('rgb', None)
    green = core.Axis('rgb', None)
    self.assertEqual(core.concat_axes([red, green]), red)

  def test_repr(self):
    self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7))

  def test_invalid_input(self):
    # A dict used as a label is rejected with TypeError.
    with self.assertRaises(TypeError):
      core.Axis('foo', [{}])
    # A label list containing a duplicate (1) is rejected with ValueError.
    with self.assertRaises(ValueError):
      core.Axis('foo', [1, 2, 3, 1])
    red = core.Axis('foo', ['red'])
    # A non-Axis entry fails the runtime type check.
    with self.assertRaises(tc.Error):
      core.concat_axes([red, 1])

  def test_as_axis(self):
    # Both a (name, size) tuple and an existing Axis convert to the same axis.
    self.assertEqual(self.i_7, core.as_axis(('7', 7)))
    self.assertEqual(self.i_7, core.as_axis(self.i_7))
120
121
class AxesTest(test_lib.TestCase):
  """Tests for the ordered `core.Axes` collection."""

  def setUp(self):
    # Chain to the base TestCase so its per-test setup still runs.
    super(AxesTest, self).setUp()
    d_7 = tensor_shape.Dimension(7)
    d_8 = tensor_shape.Dimension(8)
    p_rgb = ['red', 'green', 'blue']
    p_range = range(7)

    self.i_8 = core.Axis('8', d_8)

    # a0 and a1 are built separately from identical specs.
    self.a0 = core.Axes([('d7', d_7)])
    self.a1 = core.Axes([('d7', d_7)])
    self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)])
    self.a3 = core.Axes([('8', d_8), ('range', p_range)])

  def test_equality(self):
    self.assertEqual(self.a0, self.a0)
    self.assertEqual(self.a0, self.a1)
    self.assertNotEqual(self.a0, self.a2)

  def test_repr(self):
    self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0))

  def test_remove(self):
    a = self.a3.remove('range')
    self.assertEqual(a, core.Axes([self.i_8]))
    # Removing an axis name that is not present raises KeyError.
    with self.assertRaises(KeyError):
      self.a3.remove('foobar')

  def test_typecheck_error_message(self):
    # `range` comes from six.moves, so its __name__ depends on the Python
    # version; it is interpolated rather than hard-coded.
    pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., '
               'Union(Union(numpy.ndarray, %s, list, tuple), '
               'Optional(Union(tensorflow.Dimension, int))))))' %
               range.__name__)
    # Escape the literal pattern, then turn each '...' into a wildcard.
    regexp = re.escape(pattern).replace(re.escape('...'), '.*')
    with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp):
      core.Axes(None)
159
160
class LabeledTensorTest(test_util.Base):
  """Tests for `core.LabeledTensor` construction, indexing and properties."""

  def setUp(self):
    # Chain to the base test class so its per-test setup still runs.
    super(LabeledTensorTest, self).setUp()
    tensor = array_ops.ones([7, 3, 8, 1])
    # The four axes are given as integer labels, string labels, a plain int
    # size and a tensor_shape.Dimension, respectively.
    a0 = ('x', range(7))
    a1 = ('channel', ['red', 'green', 'blue'])
    a2 = ('y', 8)
    a3 = ('z', tensor_shape.Dimension(1))

    self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])

  def test_repr(self):
    pattern = textwrap.dedent("""\
    <LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32
     axes=[('x', ...),
           ('channel', ...),
           ('y', Dimension(8)),
           ('z', Dimension(1))]>""")
    # Escape the literal pattern, then turn each '...' into a wildcard.
    regexp = re.escape(pattern).replace(re.escape('...'), '.*')
    self.assertRegexpMatches(repr(self.lt), regexp)

  def test_reuse_existing_axes(self):
    alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes)
    self.assertLabeledTensorsEqual(alt_lt, self.lt)

  def test_reuse_existing_axis_objects(self):
    alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values())
    self.assertLabeledTensorsEqual(alt_lt, self.lt)

  def test_indexing_scalars(self):
    # Each scalar index drops the corresponding axis.
    actual = self.lt[:, :, :, 0]
    expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0],
                                  list(self.lt.axes.values())[:-1])
    self.assertLabeledTensorsEqual(actual, expected)

    actual = self.lt[1, :, :, 0]
    expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0],
                                  list(self.lt.axes.values())[1:-1])
    self.assertLabeledTensorsEqual(actual, expected)

    actual = self.lt[1, 2, :, 0]
    expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0],
                                  list(self.lt.axes.values())[2:-1])
    self.assertLabeledTensorsEqual(actual, expected)

  def test_indexing_1d(self):
    lt_1d = self.lt[1, 2, :, 0]
    actual = lt_1d[3]
    expected = core.LabeledTensor(lt_1d.tensor[3], [])
    self.assertLabeledTensorsEqual(actual, expected)

  def test_indexing_slices(self):
    # Slicing a labeled axis also slices its labels.
    actual = self.lt[:3, :, :, :]
    axes = [('x', range(3))] + list(self.lt.axes.values())[1:]
    expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes)
    self.assertLabeledTensorsEqual(actual, expected)

  def test_invalid_indexing(self):
    # Too few and too many indices are both rejected.
    with self.assertRaises(ValueError):
      self.lt[0]  # pylint: disable=pointless-statement
    with self.assertRaises(ValueError):
      self.lt[:, :, :, :, 0]  # pylint: disable=pointless-statement

  def test_unknown_size(self):
    tensor = array_ops.placeholder(dtypes.string, [None])
    actual = core.LabeledTensor(tensor, ['x'])
    self.assertIsNone(actual.axes['x'].size)
    self.assertIsNone(actual.axes['x'].value.value)

  def test_eq(self):
    self.assertEqual(self.lt, self.lt)
    # A LabeledTensor never compares equal to its raw tensor, in either order.
    self.assertNotEqual(self.lt, self.lt.tensor)
    self.assertNotEqual(self.lt.tensor, self.lt)

  def test_hash(self):
    lt1 = self.lt
    lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes)
    self.assertEqual(lt1, lt2)
    self.assertEqual(hash(lt1), hash(lt2))

  def test_name(self):
    self.assertEqual(self.lt.name, self.lt.tensor.name)

  def test_dtype(self):
    self.assertEqual(self.lt.dtype, self.lt.tensor.dtype)

  def test_shape(self):
    self.assertEqual(self.lt.shape, self.lt.tensor.shape)

  def test_get_shape(self):
    self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape())

  def test_convert_to_tensor(self):
    # Conversion hands back the wrapped tensor object itself, not a copy.
    expected = self.lt.tensor
    actual = ops.convert_to_tensor(self.lt)
    self.assertIs(expected, actual)
257
258
class Base(test_util.Base):
  """Shared fixture for the op tests below.

  Builds `original_lt`: a LabeledTensor with axes ('x', 'channel', 'z',
  'probs') of sizes (7, 3, 4, 11). It holds the values 0 .. 923 produced by
  `math_ops.range`. Axis 'z' is unlabeled; the others carry labels.

  Also builds two slices of it:
    * `x_probs_lt`, selecting z=0 and channel=0;
    * `channel_probs_lt`, selecting x=3 and z=0.
  """

  def setUp(self):
    # Chain to the base test class so its per-test setup still runs.
    super(Base, self).setUp()
    self.x_size = 7
    self.channel_size = 3
    self.z_size = 4
    self.probs_size = 11

    tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size *
                            self.probs_size)
    tensor = array_ops.reshape(
        tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size])
    a0 = ('x', range(self.x_size))
    a1 = ('channel', ['red', 'green', 'blue'])
    a2 = 'z'
    a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size))

    self.tensor = tensor
    self.a0 = a0
    self.a1 = a1
    self.a2 = a2
    self.a3 = a3
    self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3])

    self.x_probs_lt = core.slice_function(self.original_lt,
                                          {'z': 0,
                                           'channel': 0})
    self.channel_probs_lt = core.slice_function(self.original_lt,
                                                {'x': 3,
                                                 'z': 0})
289
290
class IdentityTest(Base):
  """Tests for `core.identity`."""

  def test_name(self):
    # The op name should reveal that it came from the labeled identity wrapper.
    result = core.identity(self.original_lt)
    self.assertIn('lt_identity', result.name)
296
297
class SliceFunctionTest(Base):
  """Tests for `core.slice_function` with scalar and slice selections."""

  def test_name(self):
    sliced = core.slice_function(self.original_lt, {'channel': 1})
    self.assertIn('lt_slice', sliced.name)

  def test_scalar(self):
    # Selecting a single channel drops the 'channel' axis entirely.
    sliced = core.slice_function(self.original_lt, {'channel': 1})
    expected = core.LabeledTensor(self.tensor[:, 1, :, :],
                                  [self.a0, self.a2, self.a3])
    self.assertLabeledTensorsEqual(sliced, expected)

  def test_slice(self):
    sliced = core.slice_function(self.original_lt, {'channel': slice(0, 2)})
    expected = core.LabeledTensor(
        self.tensor[:, :2, :, :],
        [self.a0, ('channel', ['red', 'green']), self.a2, self.a3])
    self.assertLabeledTensorsEqual(sliced, expected)

  def test_slices(self):
    selection = {'x': slice(1, 5),
                 'channel': slice(1, None)}
    sliced = core.slice_function(self.original_lt, selection)
    expected_axes = [('x', range(1, 5)), ('channel', ['green', 'blue']),
                     self.a2, self.a3]
    expected = core.LabeledTensor(self.tensor[1:5, 1:, :, :], expected_axes)
    self.assertLabeledTensorsEqual(sliced, expected)

  def test_slice_unlabeled(self):
    # 'z' is an unlabeled axis, so the expected axis is still just its name.
    sliced = core.slice_function(self.original_lt, {'z': slice(1, 3)})
    expected = core.LabeledTensor(self.tensor[:, :, 1:3, :],
                                  [self.a0, self.a1, 'z', self.a3])
    self.assertLabeledTensorsEqual(sliced, expected)

  def test_slice_unknown_shape(self):
    placeholder = array_ops.placeholder(dtypes.float32, [None, 1])
    unknown_lt = core.LabeledTensor(placeholder, ['x', 'y'])
    sliced = core.slice_function(unknown_lt, {'y': 0})
    self.assertEqual(list(sliced.axes.values()), [unknown_lt.axes['x']])
346
347
class TransposeTest(Base):
  """Tests for `core.transpose`."""

  def test_name(self):
    result = core.transpose(self.original_lt, self.original_lt.axes.keys())
    self.assertIn('lt_transpose', result.name)

  def test_identity(self):
    # Transposing into the existing axis order leaves the tensor unchanged.
    result = core.transpose(self.original_lt, self.original_lt.axes.keys())
    self.assertLabeledTensorsEqual(result, self.original_lt)

  def test(self):
    result = core.transpose(self.original_lt, ['z', 'channel', 'x', 'probs'])
    permuted = array_ops.transpose(self.tensor, [2, 1, 0, 3])
    expected = core.LabeledTensor(permuted,
                                  [self.a2, self.a1, self.a0, self.a3])
    self.assertLabeledTensorsEqual(result, expected)

  def test_default_axis_order(self):
    # With no explicit order, the axes come out reversed.
    result = core.transpose(self.original_lt)
    reversed_axes = list(self.original_lt.axes.values())[::-1]
    expected = core.LabeledTensor(
        array_ops.transpose(self.tensor, [3, 2, 1, 0]), reversed_axes)
    self.assertLabeledTensorsEqual(result, expected)

  def test_invalid_input(self):
    # An order that omits an axis, or names an unknown one, is rejected.
    bad_orders = [['channel', 'x', 'probs'], ['z', 'foo', 'x', 'probs']]
    for order in bad_orders:
      with self.assertRaises(ValueError):
        core.transpose(self.original_lt, order)
384
385
class ExpandDimsTest(Base):
  """Tests for `core.expand_dims`."""

  def test_name(self):
    result = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
    self.assertIn('lt_expand', result.name)

  def test_identity(self):
    # Requesting exactly the existing axes adds nothing.
    result = core.expand_dims(self.original_lt, self.original_lt.axes.keys())
    self.assertLabeledTensorsEqual(result, self.original_lt)

  def test(self):
    # New names 'foo', 'bar' and 'grok' are inserted as size-1 axes.
    result = core.expand_dims(
        self.original_lt, ['foo', 'x', 'bar', 'channel', 'z', 'probs', 'grok'])
    new_shape = [1, self.x_size, 1, self.channel_size, self.z_size,
                 self.probs_size, 1]
    expected = core.LabeledTensor(
        array_ops.reshape(self.tensor, new_shape),
        ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok'])
    self.assertLabeledTensorsEqual(result, expected)

  def test_label(self):
    # A (name, label) tuple adds a size-1 axis carrying that single label.
    result = core.expand_dims(
        self.original_lt, ['x', 'channel', ('foo', 'bar'), 'z', 'probs'])
    new_shape = [self.x_size, self.channel_size, 1, self.z_size,
                 self.probs_size]
    expected = core.LabeledTensor(
        array_ops.reshape(self.tensor, new_shape),
        [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3])
    self.assertLabeledTensorsEqual(result, expected)

  def test_unknown_dimension(self):
    placeholder = array_ops.placeholder(dtypes.float32, [None])
    source_lt = core.LabeledTensor(placeholder, ['x'])
    result = core.expand_dims(source_lt, ['x', 'y'])
    self.assertEqual(result.axes, core.Axes([('x', None), ('y', 1)]))

  def test_invalid_input(self):
    # One order drops 'x'; the other swaps 'x' and 'z'.
    bad_orders = [
        ['foo', 'not_x', 'bar', 'channel', 'z', 'probs', 'grok'],
        ['foo', 'z', 'bar', 'channel', 'x', 'probs', 'grok'],
    ]
    for order in bad_orders:
      with self.assertRaises(core.AxisOrderError):
        core.expand_dims(self.original_lt, order)
438
439
class AxisOrderScopeTest(Base):
  """Tests that `core.axis_order_scope` nests and restores correctly."""

  def test(self):
    outer_order = ['x', 'y', 'z']
    inner_order = ['a', 'b', 'c']

    # No axis order is set outside of any scope.
    self.assertIsNone(core.get_axis_order())

    with core.axis_order_scope(outer_order):
      self.assertEqual(core.get_axis_order(), outer_order)

      # A scope entered without an argument clears the order while active.
      with core.axis_order_scope():
        self.assertIsNone(core.get_axis_order())

        with core.axis_order_scope(inner_order):
          self.assertEqual(core.get_axis_order(), inner_order)

        self.assertIsNone(core.get_axis_order())

      # Leaving each scope restores the enclosing order.
      self.assertEqual(core.get_axis_order(), outer_order)

    self.assertIsNone(core.get_axis_order())
462
463
class CheckAxisOrderTest(Base):
  """Tests for `core.check_axis_order`."""

  def test_passes(self):
    axis_order = ['w', 'x', 'y', 'z']
    # The full order, a suffix of it, and a prefix of it must all be accepted.
    cases = [
        ((1, 1, 1, 1), axis_order),
        ((1, 1, 1), axis_order[1:]),
        ((1, 1, 1), axis_order[:-1]),
    ]
    for shape, names in cases:
      candidate = core.LabeledTensor(array_ops.ones(shape), names)
      core.check_axis_order(candidate, axis_order)

  def test_invalid(self):
    axis_order = ['w', 'x', 'y', 'z']
    lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order)
    # No order argument and no active scope.
    with self.assertRaises(core.AxisOrderError):
      core.check_axis_order(lt)
    # An order missing one of the tensor's axes, or a reversed order.
    for bad_order in (axis_order[:-1], axis_order[::-1]):
      with self.assertRaises(core.AxisOrderError):
        core.check_axis_order(lt, bad_order)

  def test_scope(self):
    axis_order = ['w', 'x', 'y', 'z']
    lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order)
    # Without an explicit order, the check uses the active scope's order.
    with core.axis_order_scope(axis_order):
      core.check_axis_order(lt)
493
494
class ImposeAxisOrderTest(Base):
  """Tests for `core.impose_axis_order`."""

  def test_identity(self):
    axis_order = ['w', 'x', 'y', 'z']

    # Imposing an order the tensor already follows leaves it unchanged. This
    # holds both for a full match and for a tensor using only a prefix.
    full_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)
    self.assertLabeledTensorsEqual(
        full_lt, core.impose_axis_order(full_lt, axis_order))

    prefix_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3])
    self.assertLabeledTensorsEqual(
        prefix_lt, core.impose_axis_order(prefix_lt, axis_order))

  def test_reverse(self):
    axis_order = ['w', 'x', 'y', 'z']
    reversed_order = axis_order[::-1]

    full_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)
    actual = core.impose_axis_order(full_lt, reversed_order)
    expected = core.transpose(full_lt, reversed_order)
    self.assertLabeledTensorsEqual(expected, actual)

    # 'z' is absent from this tensor, so only w/x/y get reordered.
    prefix_lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3])
    actual = core.impose_axis_order(prefix_lt, reversed_order)
    expected = core.transpose(prefix_lt, ['y', 'x', 'w'])
    self.assertLabeledTensorsEqual(expected, actual)

  def test_scope(self):
    axis_order = ['w', 'x', 'y', 'z']

    lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order)
    expected = core.transpose(lt, axis_order[::-1])
    # With no order argument, the active scope's order is imposed.
    with core.axis_order_scope(axis_order[::-1]):
      actual = core.impose_axis_order(lt)
    self.assertLabeledTensorsEqual(expected, actual)

  def test_invalid(self):
    lt = core.LabeledTensor(
        array_ops.reshape(math_ops.range(2), (1, 2)), ['x', 'y'])
    # No order at all (no scope), and an order that omits 'y'.
    for extra_args in ((), (['x'],)):
      with self.assertRaises(ValueError):
        core.impose_axis_order(lt, *extra_args)
541
542
class FindConsistentOrderingTest(Base):
  """Tests for the private helper `core._find_consistent_ordering`."""

  def test(self):
    # Each case is (first, second, expected). `expected` is the merged ordering
    # that respects both inputs, or None when the inputs conflict.
    cases = [
        ([], [], []),
        (['x'], [], ['x']),
        ([], ['x'], ['x']),
        (['x'], ['x'], ['x']),
        (['x'], ['y'], ['x', 'y']),
        (['y'], ['x'], ['y', 'x']),
        (['x', 'y'], ['x', 'y'], ['x', 'y']),
        (['x', 'y'], ['y', 'x'], None),
        (['x', 'y'], ['y', 'z'], ['x', 'y', 'z']),
        (['x', 'z'], ['y', 'z'], ['x', 'y', 'z']),
        (['x', 'y'], ['x', 'z'], ['x', 'y', 'z']),
        (['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']),
        (['x', 'y', 'z'], ['z', 'x'], None),
        (['x', 'y', 'z'], ['x'], ['x', 'y', 'z']),
        ([], ['x', 'y', 'z'], ['x', 'y', 'z']),
    ]
    # pylint: disable=protected-access
    for first, second, expected in cases:
      actual = core._find_consistent_ordering(first, second)
      details = (first, second, expected, actual)
      msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r'
             % details)
      self.assertEqual(expected, actual, msg=msg)
568
569
class AlignTest(Base):
  """Tests for `core.align`, which broadcasts two inputs to shared axes."""

  def test_name(self):
    first, second, _ = core.align(self.original_lt, self.original_lt)
    # Output i is named under the 'lt_align' scope with suffix '/i'.
    for index, aligned in enumerate([first, second]):
      self.assertIn('lt_align', aligned.name)
      self.assertIn('/%d' % index, aligned.name)

  def test_identical_shaped_inputs(self):
    offset_lt = core.LabeledTensor(self.original_lt.tensor + 1,
                                   self.original_lt.axes)

    aligned, aligned_offset, broadcast_axes = core.align(self.original_lt,
                                                         offset_lt)

    # Inputs that already share axes pass through unchanged.
    self.assertLabeledTensorsEqual(aligned, self.original_lt)
    self.assertLabeledTensorsEqual(aligned_offset, offset_lt)
    self.assertEqual(broadcast_axes, self.original_lt.axes)

  def test_different_inputs(self):
    # The correct axis ordering is ['x', 'channel', 'probs'].
    aligned_x, aligned_channel, broadcast_axes = core.align(
        self.x_probs_lt, self.channel_probs_lt)

    # Each input gains a size-1 axis for the dimension it lacks.
    expected_x = core.LabeledTensor(
        array_ops.reshape(self.x_probs_lt.tensor,
                          [self.x_size, 1, self.probs_size]),
        [self.a0, 'channel', self.a3])
    self.assertLabeledTensorsEqual(aligned_x, expected_x)

    expected_channel = core.LabeledTensor(
        array_ops.reshape(self.channel_probs_lt.tensor,
                          [1, self.channel_size, self.probs_size]),
        ['x', self.a1, self.a3])
    self.assertLabeledTensorsEqual(aligned_channel, expected_channel)

    self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3]))

  def test_axis_order_scope(self):
    xz_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'z'])
    yz_lt = core.LabeledTensor(array_ops.ones((4, 3)), ['y', 'z'])

    # Without a scope, argument order decides where 'x' and 'y' land.
    _, _, broadcast_axes = core.align(xz_lt, yz_lt)
    self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])

    _, _, broadcast_axes = core.align(yz_lt, xz_lt)
    self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z'])

    # A complete scope overrides the argument order.
    with core.axis_order_scope(['x', 'y', 'z']):
      _, _, broadcast_axes = core.align(yz_lt, xz_lt)
      self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z'])

    # A scope omitting 'z' is rejected whichever argument comes first.
    with core.axis_order_scope(['x', 'y']):
      for left, right in ((xz_lt, yz_lt), (yz_lt, xz_lt)):
        with self.assertRaises(core.AxisOrderError):
          core.align(left, right)

  def test_invalid_input(self):
    # Same axis name, different labels: the inputs cannot be aligned.
    lt_0 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(5))])
    lt_1 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(1, 6))])
    with self.assertRaises(ValueError):
      core.align(lt_0, lt_1)
637
638
class ConvertToLabeledTensorTest(Base):
  """Tests for `core.convert_to_labeled_tensor`."""

  # TODO(shoyer): Simplify these tests once we can reuse labeled tensors in
  # assertLabeledTensorsEqual.

  def test_labeled_tensor(self):
    converted = core.convert_to_labeled_tensor(self.original_lt)
    self.assertLabeledTensorsEqual(converted, self.original_lt)

  def test_python_scalar(self):
    converted = core.convert_to_labeled_tensor(42)
    expected = core.LabeledTensor(ops.convert_to_tensor(42), [])
    self.assertLabeledTensorsEqual(converted, expected)

  def test_numpy_array(self):
    converted = core.convert_to_labeled_tensor(np.array(42))
    expected = core.LabeledTensor(ops.convert_to_tensor(42), [])
    self.assertLabeledTensorsEqual(converted, expected)

  def test_tensor(self):
    converted = core.convert_to_labeled_tensor(constant_op.constant(42))
    expected = core.LabeledTensor(ops.convert_to_tensor(42), [])
    self.assertLabeledTensorsEqual(converted, expected)

  def test_invalid_input(self):
    # Unlabeled inputs with one or more dimensions raise ValueError.
    invalid_inputs = [math_ops.range(5), np.array([1, 2])]
    for value in invalid_inputs:
      with self.assertRaises(ValueError):
        core.convert_to_labeled_tensor(value)
668
669
class DocStringCheckMixin(object):
  """Checks that each wrapped op names itself and cites its TF counterpart.

  Subclasses must define `self.ops`, a list of
  `(op_name, infix_op, tf_op, lt_op)` tuples; entries with `lt_op` None are
  skipped.
  """

  def test_function_docstring_and_name(self):
    wrapped = [(name, lt_op) for name, _, _, lt_op in self.ops
               if lt_op is not None]
    for op_name, lt_op in wrapped:
      self.assertIn('tf.%s' % op_name, lt_op.__doc__)
      self.assertEqual(op_name, lt_op.__name__)
678
679
class UnaryOpsTestsMixin(object):
  """Shared checks for element-wise unary ops.

  Subclasses must define `self.ops`, a list of
  `(op_name, infix_op, tf_op, lt_op)` tuples, and `self.test_lt`, the input
  LabeledTensor.
  """

  def test_core_op(self):
    source = self.test_lt
    for op_name, _, tf_op, lt_op in self.ops:
      if tf_op is None:
        continue
      # The reference applies the raw TF op and keeps the input's axes.
      golden_lt = core.LabeledTensor(tf_op(source.tensor), source.axes)
      actual_lt = lt_op(source)
      self.assertIn(op_name, actual_lt.name)
      self.assertLabeledTensorsEqual(golden_lt, actual_lt)

  def test_infix(self):
    source = self.test_lt
    for op_name, infix_op, _, _ in self.ops:
      if infix_op is None:
        continue
      expected_lt = core.LabeledTensor(infix_op(source.tensor), source.axes)
      actual_lt = infix_op(source)
      self.assertIn(op_name, actual_lt.name)
      self.assertLabeledTensorsEqual(expected_lt, actual_lt)
700
701
class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
  """Runs the unary-op mixins over the core element-wise math wrappers."""

  def setUp(self):
    super(CoreUnaryOpsTest, self).setUp()

    # Each entry is (op_name, infix_op, tf_op, lt_op); None marks a variant
    # that does not exist. Every op is listed exactly once. Earlier versions
    # repeated 'log' and 'lgamma', which only re-ran identical checks.
    self.ops = [
        ('abs', operator.abs, math_ops.abs, core.abs_function),
        ('neg', operator.neg, math_ops.negative, core.neg),
        # TODO(shoyer): add unary + to core TensorFlow
        ('pos', None, None, None),
        ('sign', None, math_ops.sign, core.sign),
        ('reciprocal', None, math_ops.reciprocal, core.reciprocal),
        ('square', None, math_ops.square, core.square),
        ('round', None, math_ops.round, core.round_function),
        ('sqrt', None, math_ops.sqrt, core.sqrt),
        ('rsqrt', None, math_ops.rsqrt, core.rsqrt),
        ('log', None, math_ops.log, core.log),
        ('exp', None, math_ops.exp, core.exp),
        ('ceil', None, math_ops.ceil, core.ceil),
        ('floor', None, math_ops.floor, core.floor),
        ('cos', None, math_ops.cos, core.cos),
        ('sin', None, math_ops.sin, core.sin),
        ('tan', None, math_ops.tan, core.tan),
        ('acos', None, math_ops.acos, core.acos),
        ('asin', None, math_ops.asin, core.asin),
        ('atan', None, math_ops.atan, core.atan),
        ('lgamma', None, math_ops.lgamma, core.lgamma),
        ('digamma', None, math_ops.digamma, core.digamma),
        ('erf', None, math_ops.erf, core.erf),
        ('erfc', None, math_ops.erfc, core.erfc),
    ]
    # Cast the fixture to float32 and divide by its element count. The values
    # 0 .. N-1 then lie in [0, 1).
    total_size = np.prod([v.size for v in self.original_lt.axes.values()])
    self.test_lt = core.LabeledTensor(
        math_ops.cast(self.original_lt, dtypes.float32) / total_size,
        self.original_lt.axes)
739
740
class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin):
  """Runs the unary-op mixins over `core.logical_not` (infix `~`)."""

  def setUp(self):
    super(LogicalNotTest, self).setUp()
    logical_not_entry = ('logical_not', operator.invert,
                         math_ops.logical_not, core.logical_not)
    self.ops = [logical_not_entry]
    # A comparison result serves as the logical input.
    self.test_lt = self.original_lt < 10
748
749
class BinaryOpsTestsMixin(object):
  """Shared checks for element-wise binary ops.

  Subclasses must define these attributes:
    * `self.ops`: `(op_name, infix_op, tf_op, lt_op)` tuples;
    * `self.test_lt_1`, `self.test_lt_2`: the LabeledTensor operands;
    * `self.test_lt_1_broadcast`, `self.test_lt_2_broadcast`: the same
      operands as plain tensors, already in broadcast layout;
    * `self.broadcast_axes`: the axes of the broadcast result.
  """

  def test_core_op(self):
    for op_name, _, tf_op, lt_op in self.ops:
      # The reference applies the raw TF op to the pre-broadcast tensors.
      golden_lt = core.LabeledTensor(
          tf_op(self.test_lt_1_broadcast, self.test_lt_2_broadcast),
          self.broadcast_axes)
      actual_lt = lt_op(self.test_lt_1, self.test_lt_2)
      self.assertIn(op_name, actual_lt.name)
      self.assertLabeledTensorsEqual(golden_lt, actual_lt)

  def test_infix(self):
    for op_name, infix_op, _, lt_op in self.ops:
      if infix_op is None:
        continue
      # Operator syntax must match calling the named wrapper directly.
      expected_lt = lt_op(self.test_lt_1, self.test_lt_2)
      actual_lt = infix_op(self.test_lt_1, self.test_lt_2)
      self.assertIn(op_name, actual_lt.name)
      self.assertLabeledTensorsEqual(expected_lt, actual_lt)
769
770
class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
  """Runs the binary-op mixins over core arithmetic and comparison wrappers.

  `x_probs_lt` and `channel_probs_lt` share only the 'probs' axis. Each op
  therefore has to broadcast them to axes ('x', 'channel', 'probs').
  """

  def setUp(self):
    super(CoreBinaryOpsTest, self).setUp()

    # Reference operands, reshaped by hand to the broadcast layout
    # [x, channel, probs] with size 1 on the axis each input lacks.
    self.x_probs_broadcast_tensor = array_ops.reshape(
        self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size])

    self.channel_probs_broadcast_tensor = array_ops.reshape(
        self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size])

    # == and != are not element-wise for tf.Tensor, so they shouldn't be
    # element-wise for LabeledTensor either. That is why 'equal' and
    # 'not_equal' have no infix operator below.
    self.ops = [
        ('add', operator.add, math_ops.add, core.add),
        ('sub', operator.sub, math_ops.subtract, core.sub),
        ('mul', operator.mul, math_ops.multiply, core.mul),
        ('div', operator.truediv, math_ops.div, core.div),
        ('mod', operator.mod, math_ops.mod, core.mod),
        ('pow', operator.pow, math_ops.pow, core.pow_function),
        ('equal', None, math_ops.equal, core.equal),
        ('less', operator.lt, math_ops.less, core.less),
        ('less_equal', operator.le, math_ops.less_equal, core.less_equal),
        ('not_equal', None, math_ops.not_equal, core.not_equal),
        ('greater', operator.gt, math_ops.greater, core.greater),
        ('greater_equal', operator.ge, math_ops.greater_equal,
         core.greater_equal),
    ]
    self.test_lt_1 = self.x_probs_lt
    self.test_lt_2 = self.channel_probs_lt
    self.test_lt_1_broadcast = self.x_probs_broadcast_tensor
    self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor
    self.broadcast_axes = [self.a0, self.a1, self.a3]

  def test_reflexive(self):
    # Add 1 so that all elements are > 0, as required for division.
    labeled_tensor = self.x_probs_lt + 1
    for op_name, infix_op, _, lt_op in self.ops:
      if infix_op is None:
        continue
      expected_lt = lt_op(2, labeled_tensor)
      actual_lt = infix_op(2, labeled_tensor)
      # With a plain scalar on the left, Python falls back to the reflected
      # method of the LabeledTensor. For comparisons that is the mirrored
      # operator (`2 < lt` runs `lt > 2`), so less and greater swap in the
      # recorded op name (and vice versa).
      expected_name = op_name
      if 'less' in op_name:
        expected_name = op_name.replace('less', 'greater')
      elif 'greater' in op_name:
        expected_name = op_name.replace('greater', 'less')
      self.assertIn(expected_name, actual_lt.name)
      self.assertLabeledTensorsEqual(expected_lt, actual_lt)
818
819
class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
  """Runs the binary-op mixins over the logical and/or/xor wrappers."""

  def setUp(self):
    super(LogicalBinaryOpsTest, self).setUp()

    self.ops = [
        ('logical_and', operator.and_, math_ops.logical_and, core.logical_and),
        ('logical_or', operator.or_, math_ops.logical_or, core.logical_or),
        ('logical_xor', operator.xor, math_ops.logical_xor, core.logical_xor),
    ]
    # Both operands are comparisons of the same fixture, so they share axes.
    # Their raw tensors therefore already serve as the broadcast forms.
    lhs = self.original_lt < 10
    rhs = self.original_lt < 5
    self.test_lt_1 = lhs
    self.test_lt_2 = rhs
    self.test_lt_1_broadcast = lhs.tensor
    self.test_lt_2_broadcast = rhs.tensor
    self.broadcast_axes = lhs.axes
835
836
class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin):
  """Runs the binary-op mixins over float-valued special functions."""

  def setUp(self):
    super(FloatBinaryOpsTest, self).setUp()

    self.ops = [
        ('igamma', None, math_ops.igamma, core.igamma),
        ('igammac', None, math_ops.igammac, core.igammac),
        ('zeta', None, math_ops.zeta, core.zeta),
        ('polygamma', None, math_ops.polygamma, core.polygamma),
        ('maximum', None, math_ops.maximum, core.maximum),
        ('minimum', None, math_ops.minimum, core.minimum),
        ('squared_difference', None, math_ops.squared_difference,
         core.squared_difference),
    ]
    # First operand: the fixture cast to float32 and divided by its element
    # count. Second operand: its complement (1 - x), on the same axes.
    element_count = np.prod(
        [axis.size for axis in self.original_lt.axes.values()])
    as_float = math_ops.cast(self.original_lt, dtypes.float32)
    first = core.LabeledTensor(as_float / element_count,
                               self.original_lt.axes)
    second = 1.0 - first
    self.test_lt_1 = first
    self.test_lt_2 = second
    self.test_lt_1_broadcast = first.tensor
    self.test_lt_2_broadcast = second.tensor
    self.broadcast_axes = first.axes
861
862
if __name__ == '__main__':
  # Script entry point: hand off to the TensorFlow test runner
  # (tensorflow.python.platform.test).
  test_lib.main()
865