1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import operator 21import re 22import textwrap 23 24import numpy as np 25from six.moves import range # pylint: disable=redefined-builtin 26 27from tensorflow.contrib.labeled_tensor.python.ops import _typecheck as tc 28from tensorflow.contrib.labeled_tensor.python.ops import core 29from tensorflow.contrib.labeled_tensor.python.ops import test_util 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import ops 33from tensorflow.python.framework import tensor_shape 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import math_ops 36from tensorflow.python.platform import test as test_lib 37 38 39class AxisTest(test_lib.TestCase): 40 41 def setUp(self): 42 d_7 = tensor_shape.Dimension(7) 43 p_rgb = ['red', 'green', 'blue'] 44 45 self.i_7 = core.Axis('7', d_7) 46 self.i_7p = core.Axis('7prime', d_7) 47 self.i_rgb = core.Axis('rgb', p_rgb) 48 self.i_range = core.Axis('range', range(7)) 49 self.i_unknown = core.Axis('unknown', None) 50 51 def test_equality(self): 52 53 axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown] 54 for i, axis_0 in enumerate(axes): 55 for j, axis_1 in enumerate(axes): 56 if i == j: 57 self.assertEqual(axis_0, axis_1) 58 else: 59 self.assertNotEqual(axis_0, axis_1) 60 61 def test_axis_value(self): 62 self.assertEqual(self.i_7.value, tensor_shape.Dimension(7)) 63 self.assertTrue(self.i_range.value == tuple(range(7))) 64 65 def test_axis_input(self): 66 axes = [self.i_7, self.i_7p, self.i_rgb, self.i_range, self.i_unknown] 67 for axis in axes: 68 self.assertEqual(axis, core.Axis(axis.name, axis.value)) 69 70 def test_axis_value_input(self): 71 axis = self.i_range 72 for value in [range(7), list(range(7)), np.arange(7)]: 73 self.assertEqual(axis, core.Axis(axis.name, value)) 74 75 def test_size(self): 76 self.assertEqual(len(self.i_7), 7) 77 self.assertEqual(len(self.i_rgb), 3) 78 self.assertEqual(len(self.i_range), 7) 79 self.assertEqual(self.i_unknown.size, None) 80 81 def test_concat_single(self): 82 red = core.Axis('rgb', ['red']) 83 84 self.assertEqual(core.concat_axes([red]), red) 85 86 def test_concat_many(self): 87 red = core.Axis('rgb', ['red']) 88 green = core.Axis('rgb', ['green']) 89 blue = core.Axis('rgb', ['blue']) 90 red_green_blue = core.Axis('rgb', ['red', 'green', 'blue']) 91 92 self.assertEqual(core.concat_axes([red, green, blue]), red_green_blue) 93 94 def test_concat_different_names(self): 95 red = core.Axis('red', ['red']) 96 green = core.Axis('green', ['red']) 97 with self.assertRaises(ValueError): 98 core.concat_axes([red, green]) 99 100 def test_concat_unknown(self): 101 red = core.Axis('rgb', None) 102 green = core.Axis('rgb', None) 103 self.assertEqual(core.concat_axes([red, green]), red) 104 105 def test_repr(self): 106 self.assertEqual("Axis('7', Dimension(7))", repr(self.i_7)) 107 108 def test_invalid_input(self): 109 with self.assertRaises(TypeError): 110 core.Axis('foo', [{}]) 111 with self.assertRaises(ValueError): 112 core.Axis('foo', [1, 2, 3, 1]) 113 red = core.Axis('foo', ['red']) 114 with self.assertRaises(tc.Error): 115 core.concat_axes([red, 1]) 116 117 def test_as_axis(self): 118 self.assertEqual(self.i_7, core.as_axis(('7', 7))) 119 self.assertEqual(self.i_7, core.as_axis(self.i_7)) 120 121 122class AxesTest(test_lib.TestCase): 123 124 def setUp(self): 125 d_7 = tensor_shape.Dimension(7) 126 d_8 = tensor_shape.Dimension(8) 127 p_rgb = ['red', 'green', 'blue'] 128 p_range = range(7) 129 130 self.i_8 = core.Axis('8', d_8) 131 132 self.a0 = core.Axes([('d7', d_7)]) 133 self.a1 = core.Axes([('d7', d_7)]) 134 self.a2 = core.Axes([('d7', d_7), ('rgb', p_rgb)]) 135 self.a3 = core.Axes([('8', d_8), ('range', p_range)]) 136 137 def test_equality(self): 138 self.assertEqual(self.a0, self.a0) 139 self.assertEqual(self.a0, self.a1) 140 self.assertNotEqual(self.a0, self.a2) 141 142 def test_repr(self): 143 self.assertEqual("Axes([('d7', Dimension(7))])", repr(self.a0)) 144 145 def test_remove(self): 146 a = self.a3.remove('range') 147 self.assertEqual(a, core.Axes([self.i_8])) 148 with self.assertRaises(KeyError): 149 self.a3.remove('foobar') 150 151 def test_typecheck_error_message(self): 152 pattern = ('List(Union(labeled_tensor.Axis, Tuple(..., ' 153 'Union(Union(numpy.ndarray, %s, list, tuple), ' 154 'Optional(Union(tensorflow.Dimension, int))))))' % 155 range.__name__) 156 regexp = re.escape(pattern).replace(re.escape('...'), '.*') 157 with self.assertRaisesRegexp(tc.Error, 'allowed type ' + regexp): 158 core.Axes(None) 159 160 161class LabeledTensorTest(test_util.Base): 162 163 def setUp(self): 164 tensor = array_ops.ones([7, 3, 8, 1]) 165 a0 = ('x', range(7)) 166 a1 = ('channel', ['red', 'green', 'blue']) 167 a2 = ('y', 8) 168 a3 = ('z', tensor_shape.Dimension(1)) 169 170 self.lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) 171 172 def test_repr(self): 173 pattern = textwrap.dedent("""\ 174 <LabeledTensor '...' shape=(7, 3, 8, 1) dtype=float32 175 axes=[('x', ...), 176 ('channel', ...), 177 ('y', Dimension(8)), 178 ('z', Dimension(1))]>""") 179 regexp = re.escape(pattern).replace(re.escape('...'), '.*') 180 self.assertRegexpMatches(repr(self.lt), regexp) 181 182 def test_reuse_existing_axes(self): 183 alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes) 184 self.assertLabeledTensorsEqual(alt_lt, self.lt) 185 186 def test_reuse_existing_axis_objects(self): 187 alt_lt = core.LabeledTensor(self.lt.tensor, self.lt.axes.values()) 188 self.assertLabeledTensorsEqual(alt_lt, self.lt) 189 190 def test_indexing_scalars(self): 191 actual = self.lt[:, :, :, 0] 192 expected = core.LabeledTensor(self.lt.tensor[:, :, :, 0], 193 list(self.lt.axes.values())[:-1]) 194 self.assertLabeledTensorsEqual(actual, expected) 195 196 actual = self.lt[1, :, :, 0] 197 expected = core.LabeledTensor(self.lt.tensor[1, :, :, 0], 198 list(self.lt.axes.values())[1:-1]) 199 self.assertLabeledTensorsEqual(actual, expected) 200 201 actual = self.lt[1, 2, :, 0] 202 expected = core.LabeledTensor(self.lt.tensor[1, 2, :, 0], 203 list(self.lt.axes.values())[2:-1]) 204 self.assertLabeledTensorsEqual(actual, expected) 205 206 def test_indexing_1d(self): 207 lt_1d = self.lt[1, 2, :, 0] 208 actual = lt_1d[3] 209 expected = core.LabeledTensor(lt_1d.tensor[3], []) 210 self.assertLabeledTensorsEqual(actual, expected) 211 212 def test_indexing_slices(self): 213 actual = self.lt[:3, :, :, :] 214 axes = [('x', range(3))] + list(self.lt.axes.values())[1:] 215 expected = core.LabeledTensor(self.lt.tensor[:3, :, :, :], axes) 216 self.assertLabeledTensorsEqual(actual, expected) 217 218 def test_invalid_indexing(self): 219 with self.assertRaises(ValueError): 220 self.lt[0] # pylint: disable=pointless-statement 221 with self.assertRaises(ValueError): 222 self.lt[:, :, :, :, 0] # pylint: disable=pointless-statement 223 224 def test_unknown_size(self): 225 tensor = array_ops.placeholder(dtypes.string, [None]) 226 actual = core.LabeledTensor(tensor, ['x']) 227 self.assertIsNone(actual.axes['x'].size) 228 self.assertIsNone(actual.axes['x'].value.value) 229 230 def test_eq(self): 231 self.assertEqual(self.lt, self.lt) 232 self.assertNotEqual(self.lt, self.lt.tensor) 233 self.assertNotEqual(self.lt.tensor, self.lt) 234 235 def test_hash(self): 236 lt1 = self.lt 237 lt2 = core.LabeledTensor(self.lt.tensor, self.lt.axes) 238 self.assertEqual(lt1, lt2) 239 self.assertEqual(hash(lt1), hash(lt2)) 240 241 def test_name(self): 242 self.assertEqual(self.lt.name, self.lt.tensor.name) 243 244 def test_dtype(self): 245 self.assertEqual(self.lt.dtype, self.lt.tensor.dtype) 246 247 def test_shape(self): 248 self.assertEqual(self.lt.shape, self.lt.tensor.shape) 249 250 def test_get_shape(self): 251 self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape()) 252 253 def test_convert_to_tensor(self): 254 expected = self.lt.tensor 255 actual = ops.convert_to_tensor(self.lt) 256 self.assertIs(expected, actual) 257 258 259class Base(test_util.Base): 260 261 def setUp(self): 262 self.x_size = 7 263 self.channel_size = 3 264 self.z_size = 4 265 self.probs_size = 11 266 267 tensor = math_ops.range(0, self.x_size * self.channel_size * self.z_size * 268 self.probs_size) 269 tensor = array_ops.reshape( 270 tensor, [self.x_size, self.channel_size, self.z_size, self.probs_size]) 271 a0 = ('x', range(self.x_size)) 272 a1 = ('channel', ['red', 'green', 'blue']) 273 a2 = 'z' 274 a3 = ('probs', np.linspace(0.0, 1.0, self.probs_size)) 275 276 self.tensor = tensor 277 self.a0 = a0 278 self.a1 = a1 279 self.a2 = a2 280 self.a3 = a3 281 self.original_lt = core.LabeledTensor(tensor, [a0, a1, a2, a3]) 282 283 self.x_probs_lt = core.slice_function(self.original_lt, 284 {'z': 0, 285 'channel': 0}) 286 self.channel_probs_lt = core.slice_function(self.original_lt, 287 {'x': 3, 288 'z': 0}) 289 290 291class IdentityTest(Base): 292 293 def test_name(self): 294 identity_lt = core.identity(self.original_lt) 295 self.assertIn('lt_identity', identity_lt.name) 296 297 298class SliceFunctionTest(Base): 299 300 def test_name(self): 301 select_lt = core.slice_function(self.original_lt, {'channel': 1}) 302 self.assertIn('lt_slice', select_lt.name) 303 304 def test_scalar(self): 305 select_lt = core.slice_function(self.original_lt, {'channel': 1}) 306 golden_lt = core.LabeledTensor(self.tensor[:, 1, :, :], 307 [self.a0, self.a2, self.a3]) 308 309 self.assertLabeledTensorsEqual(select_lt, golden_lt) 310 311 def test_slice(self): 312 select_lt = core.slice_function(self.original_lt, {'channel': slice(0, 2)}) 313 314 a1_sliced = ('channel', ['red', 'green']) 315 golden_lt = core.LabeledTensor(self.tensor[:, :2, :, :], 316 [self.a0, a1_sliced, self.a2, self.a3]) 317 318 self.assertLabeledTensorsEqual(select_lt, golden_lt) 319 320 def test_slices(self): 321 select_lt = core.slice_function( 322 self.original_lt, {'x': slice(1, 5), 323 'channel': slice(1, None)}) 324 325 a0_sliced = ('x', range(1, 5)) 326 a1_sliced = ('channel', ['green', 'blue']) 327 golden_lt = core.LabeledTensor(self.tensor[1:5, 1:, :, :], 328 [a0_sliced, a1_sliced, self.a2, self.a3]) 329 330 self.assertLabeledTensorsEqual(select_lt, golden_lt) 331 332 def test_slice_unlabeled(self): 333 select_lt = core.slice_function(self.original_lt, {'z': slice(1, 3)}) 334 335 a2_sliced = 'z' 336 golden_lt = core.LabeledTensor(self.tensor[:, :, 1:3, :], 337 [self.a0, self.a1, a2_sliced, self.a3]) 338 339 self.assertLabeledTensorsEqual(select_lt, golden_lt) 340 341 def test_slice_unknown_shape(self): 342 lt = core.LabeledTensor( 343 array_ops.placeholder(dtypes.float32, [None, 1]), ['x', 'y']) 344 sliced_lt = core.slice_function(lt, {'y': 0}) 345 self.assertEqual(list(sliced_lt.axes.values()), [lt.axes['x']]) 346 347 348class TransposeTest(Base): 349 350 def test_name(self): 351 transpose_lt = core.transpose(self.original_lt, 352 self.original_lt.axes.keys()) 353 self.assertIn('lt_transpose', transpose_lt.name) 354 355 def test_identity(self): 356 transpose_lt = core.transpose(self.original_lt, 357 self.original_lt.axes.keys()) 358 golden_lt = self.original_lt 359 360 self.assertLabeledTensorsEqual(transpose_lt, golden_lt) 361 362 def test(self): 363 transpose_lt = core.transpose(self.original_lt, 364 ['z', 'channel', 'x', 'probs']) 365 golden_lt = core.LabeledTensor( 366 array_ops.transpose(self.tensor, [2, 1, 0, 3]), 367 [self.a2, self.a1, self.a0, self.a3]) 368 369 self.assertLabeledTensorsEqual(transpose_lt, golden_lt) 370 371 def test_default_axis_order(self): 372 transpose_lt = core.transpose(self.original_lt) 373 golden_lt = core.LabeledTensor( 374 array_ops.transpose(self.tensor, [3, 2, 1, 0]), 375 list(reversed(list(self.original_lt.axes.values())))) 376 377 self.assertLabeledTensorsEqual(transpose_lt, golden_lt) 378 379 def test_invalid_input(self): 380 with self.assertRaises(ValueError): 381 core.transpose(self.original_lt, ['channel', 'x', 'probs']) 382 with self.assertRaises(ValueError): 383 core.transpose(self.original_lt, ['z', 'foo', 'x', 'probs']) 384 385 386class ExpandDimsTest(Base): 387 388 def test_name(self): 389 expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys()) 390 self.assertIn('lt_expand', expand_lt.name) 391 392 def test_identity(self): 393 expand_lt = core.expand_dims(self.original_lt, self.original_lt.axes.keys()) 394 golden_lt = self.original_lt 395 396 self.assertLabeledTensorsEqual(expand_lt, golden_lt) 397 398 def test(self): 399 expand_lt = core.expand_dims( 400 self.original_lt, ['foo', 'x', 'bar', 'channel', 'z', 'probs', 'grok']) 401 golden_lt = core.LabeledTensor( 402 array_ops.reshape(self.tensor, [ 403 1, self.x_size, 1, self.channel_size, self.z_size, self.probs_size, 404 1 405 ]), ['foo', self.a0, 'bar', self.a1, self.a2, self.a3, 'grok']) 406 407 self.assertLabeledTensorsEqual(expand_lt, golden_lt) 408 409 def test_label(self): 410 expand_lt = core.expand_dims(self.original_lt, [ 411 'x', 412 'channel', 413 ('foo', 'bar'), 414 'z', 415 'probs', 416 ]) 417 golden_lt = core.LabeledTensor( 418 array_ops.reshape( 419 self.tensor, 420 [self.x_size, self.channel_size, 1, self.z_size, self.probs_size]), 421 [self.a0, self.a1, ('foo', ['bar']), self.a2, self.a3]) 422 423 self.assertLabeledTensorsEqual(expand_lt, golden_lt) 424 425 def test_unknown_dimension(self): 426 orig_lt = core.LabeledTensor( 427 array_ops.placeholder(dtypes.float32, [None]), ['x']) 428 expand_lt = core.expand_dims(orig_lt, ['x', 'y']) 429 self.assertEqual(expand_lt.axes, core.Axes([('x', None), ('y', 1)])) 430 431 def test_invalid_input(self): 432 with self.assertRaises(core.AxisOrderError): 433 core.expand_dims(self.original_lt, 434 ['foo', 'not_x', 'bar', 'channel', 'z', 'probs', 'grok']) 435 with self.assertRaises(core.AxisOrderError): 436 core.expand_dims(self.original_lt, 437 ['foo', 'z', 'bar', 'channel', 'x', 'probs', 'grok']) 438 439 440class AxisOrderScopeTest(Base): 441 442 def test(self): 443 xyz = ['x', 'y', 'z'] 444 abc = ['a', 'b', 'c'] 445 446 self.assertIsNone(core.get_axis_order()) 447 448 with core.axis_order_scope(xyz): 449 self.assertEqual(core.get_axis_order(), xyz) 450 451 with core.axis_order_scope(): 452 self.assertIsNone(core.get_axis_order()) 453 454 with core.axis_order_scope(abc): 455 self.assertEqual(core.get_axis_order(), abc) 456 457 self.assertIsNone(core.get_axis_order()) 458 459 self.assertEqual(core.get_axis_order(), xyz) 460 461 self.assertIsNone(core.get_axis_order()) 462 463 464class CheckAxisOrderTest(Base): 465 466 def test_passes(self): 467 axis_order = ['w', 'x', 'y', 'z'] 468 469 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order) 470 core.check_axis_order(lt, axis_order) 471 472 lt = core.LabeledTensor(array_ops.ones((1, 1, 1)), axis_order[1:]) 473 core.check_axis_order(lt, axis_order) 474 475 lt = core.LabeledTensor(array_ops.ones((1, 1, 1)), axis_order[:-1]) 476 core.check_axis_order(lt, axis_order) 477 478 def test_invalid(self): 479 axis_order = ['w', 'x', 'y', 'z'] 480 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order) 481 with self.assertRaises(core.AxisOrderError): 482 core.check_axis_order(lt) 483 with self.assertRaises(core.AxisOrderError): 484 core.check_axis_order(lt, axis_order[:-1]) 485 with self.assertRaises(core.AxisOrderError): 486 core.check_axis_order(lt, axis_order[::-1]) 487 488 def test_scope(self): 489 axis_order = ['w', 'x', 'y', 'z'] 490 lt = core.LabeledTensor(array_ops.ones((1, 1, 1, 1)), axis_order) 491 with core.axis_order_scope(axis_order): 492 core.check_axis_order(lt) 493 494 495class ImposeAxisOrderTest(Base): 496 497 def test_identity(self): 498 axis_order = ['w', 'x', 'y', 'z'] 499 lt = core.LabeledTensor( 500 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order) 501 actual = core.impose_axis_order(lt, axis_order) 502 self.assertLabeledTensorsEqual(lt, actual) 503 504 lt = core.LabeledTensor( 505 array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3]) 506 actual = core.impose_axis_order(lt, axis_order) 507 self.assertLabeledTensorsEqual(lt, actual) 508 509 def test_reverse(self): 510 axis_order = ['w', 'x', 'y', 'z'] 511 512 lt = core.LabeledTensor( 513 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order) 514 actual = core.impose_axis_order(lt, axis_order[::-1]) 515 expected = core.transpose(lt, axis_order[::-1]) 516 self.assertLabeledTensorsEqual(expected, actual) 517 518 lt = core.LabeledTensor( 519 array_ops.reshape(math_ops.range(6), (1, 2, 3)), axis_order[:3]) 520 actual = core.impose_axis_order(lt, axis_order[::-1]) 521 expected = core.transpose(lt, ['y', 'x', 'w']) 522 self.assertLabeledTensorsEqual(expected, actual) 523 524 def test_scope(self): 525 axis_order = ['w', 'x', 'y', 'z'] 526 527 lt = core.LabeledTensor( 528 array_ops.reshape(math_ops.range(24), (1, 2, 3, 4)), axis_order) 529 expected = core.transpose(lt, axis_order[::-1]) 530 with core.axis_order_scope(axis_order[::-1]): 531 actual = core.impose_axis_order(lt) 532 self.assertLabeledTensorsEqual(expected, actual) 533 534 def test_invalid(self): 535 lt = core.LabeledTensor( 536 array_ops.reshape(math_ops.range(2), (1, 2)), ['x', 'y']) 537 with self.assertRaises(ValueError): 538 core.impose_axis_order(lt) 539 with self.assertRaises(ValueError): 540 core.impose_axis_order(lt, ['x']) 541 542 543class FindConsistentOrderingTest(Base): 544 545 def test(self): 546 cases = [ 547 ([], [], []), 548 (['x'], [], ['x']), 549 ([], ['x'], ['x']), 550 (['x'], ['x'], ['x']), 551 (['x'], ['y'], ['x', 'y']), 552 (['y'], ['x'], ['y', 'x']), 553 (['x', 'y'], ['x', 'y'], ['x', 'y']), 554 (['x', 'y'], ['y', 'x'], None), 555 (['x', 'y'], ['y', 'z'], ['x', 'y', 'z']), 556 (['x', 'z'], ['y', 'z'], ['x', 'y', 'z']), 557 (['x', 'y'], ['x', 'z'], ['x', 'y', 'z']), 558 (['w', 'x'], ['y', 'z'], ['w', 'x', 'y', 'z']), 559 (['x', 'y', 'z'], ['z', 'x'], None), 560 (['x', 'y', 'z'], ['x'], ['x', 'y', 'z']), 561 ([], ['x', 'y', 'z'], ['x', 'y', 'z']), 562 ] 563 for a, b, expected in cases: 564 actual = core._find_consistent_ordering(a, b) 565 msg = ('unexpected ordering between %r and %r:\nexpected: %r\nactual: %r' 566 % (a, b, expected, actual)) 567 self.assertEqual(expected, actual, msg=msg) 568 569 570class AlignTest(Base): 571 572 def test_name(self): 573 align_lt_0, align_lt_1, _ = core.align(self.original_lt, self.original_lt) 574 self.assertIn('lt_align', align_lt_0.name) 575 self.assertIn('/0', align_lt_0.name) 576 self.assertIn('lt_align', align_lt_1.name) 577 self.assertIn('/1', align_lt_1.name) 578 579 def test_identical_shaped_inputs(self): 580 offset_tensor = self.original_lt.tensor + 1 581 offset_lt = core.LabeledTensor(offset_tensor, self.original_lt.axes) 582 583 align_lt, align_offset_lt, broadcast_axes = core.align(self.original_lt, 584 offset_lt) 585 586 self.assertLabeledTensorsEqual(align_lt, self.original_lt) 587 self.assertLabeledTensorsEqual(align_offset_lt, offset_lt) 588 self.assertEqual(broadcast_axes, self.original_lt.axes) 589 590 def test_different_inputs(self): 591 # The correct axis ordering is ['x', 'channel', 'probs']. 592 align_x_probs_lt, align_channel_probs_lt, broadcast_axes = core.align( 593 self.x_probs_lt, self.channel_probs_lt) 594 595 x_probs_golden_lt = core.LabeledTensor( 596 array_ops.reshape(self.x_probs_lt.tensor, 597 [self.x_size, 1, self.probs_size]), 598 [self.a0, 'channel', self.a3]) 599 600 self.assertLabeledTensorsEqual(align_x_probs_lt, x_probs_golden_lt) 601 602 channel_probs_golden_lt = core.LabeledTensor( 603 array_ops.reshape(self.channel_probs_lt.tensor, 604 [1, self.channel_size, self.probs_size]), 605 ['x', self.a1, self.a3]) 606 607 self.assertLabeledTensorsEqual(align_channel_probs_lt, 608 channel_probs_golden_lt) 609 610 self.assertEqual(broadcast_axes, core.Axes([self.a0, self.a1, self.a3])) 611 612 def test_axis_order_scope(self): 613 xz_lt = core.LabeledTensor(array_ops.ones((2, 3)), ['x', 'z']) 614 yz_lt = core.LabeledTensor(array_ops.ones((4, 3)), ['y', 'z']) 615 616 _, _, broadcast_axes = core.align(xz_lt, yz_lt) 617 self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z']) 618 619 _, _, broadcast_axes = core.align(yz_lt, xz_lt) 620 self.assertEqual(list(broadcast_axes.keys()), ['y', 'x', 'z']) 621 622 with core.axis_order_scope(['x', 'y', 'z']): 623 _, _, broadcast_axes = core.align(yz_lt, xz_lt) 624 self.assertEqual(list(broadcast_axes.keys()), ['x', 'y', 'z']) 625 626 with core.axis_order_scope(['x', 'y']): 627 with self.assertRaises(core.AxisOrderError): 628 core.align(xz_lt, yz_lt) 629 with self.assertRaises(core.AxisOrderError): 630 core.align(yz_lt, xz_lt) 631 632 def test_invalid_input(self): 633 lt_0 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(5))]) 634 lt_1 = core.LabeledTensor(array_ops.zeros([5]), [('a', range(1, 6))]) 635 with self.assertRaises(ValueError): 636 core.align(lt_0, lt_1) 637 638 639class ConvertToLabeledTensorTest(Base): 640 641 # TODO(shoyer): Simplify these tests once we can reuse labeled tensors in 642 # assertLabeledTensorsEqual. 643 644 def test_labeled_tensor(self): 645 actual = core.convert_to_labeled_tensor(self.original_lt) 646 self.assertLabeledTensorsEqual(actual, self.original_lt) 647 648 def test_python_scalar(self): 649 actual = core.convert_to_labeled_tensor(42) 650 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), []) 651 self.assertLabeledTensorsEqual(actual, golden_lt) 652 653 def test_numpy_array(self): 654 actual = core.convert_to_labeled_tensor(np.array(42)) 655 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), []) 656 self.assertLabeledTensorsEqual(actual, golden_lt) 657 658 def test_tensor(self): 659 actual = core.convert_to_labeled_tensor(constant_op.constant(42)) 660 golden_lt = core.LabeledTensor(ops.convert_to_tensor(42), []) 661 self.assertLabeledTensorsEqual(actual, golden_lt) 662 663 def test_invalid_input(self): 664 with self.assertRaises(ValueError): 665 core.convert_to_labeled_tensor(math_ops.range(5)) 666 with self.assertRaises(ValueError): 667 core.convert_to_labeled_tensor(np.array([1, 2])) 668 669 670class DocStringCheckMixin(object): 671 # requires self.ops to be defined 672 673 def test_function_docstring_and_name(self): 674 for op_name, _, _, lt_op in self.ops: 675 if lt_op is not None: 676 self.assertIn('tf.%s' % op_name, lt_op.__doc__) 677 self.assertEqual(op_name, lt_op.__name__) 678 679 680class UnaryOpsTestsMixin(object): 681 # requires self.ops and self.test_lt to be defined 682 683 def test_core_op(self): 684 for op_name, _, tf_op, lt_op in self.ops: 685 if tf_op is not None: 686 golden_lt = core.LabeledTensor( 687 tf_op(self.test_lt.tensor), self.test_lt.axes) 688 actual_lt = lt_op(self.test_lt) 689 self.assertIn(op_name, actual_lt.name) 690 self.assertLabeledTensorsEqual(golden_lt, actual_lt) 691 692 def test_infix(self): 693 for op_name, infix_op, _, _ in self.ops: 694 if infix_op is not None: 695 expected_lt = core.LabeledTensor( 696 infix_op(self.test_lt.tensor), self.test_lt.axes) 697 actual_lt = infix_op(self.test_lt) 698 self.assertIn(op_name, actual_lt.name) 699 self.assertLabeledTensorsEqual(expected_lt, actual_lt) 700 701 702class CoreUnaryOpsTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin): 703 704 def setUp(self): 705 super(CoreUnaryOpsTest, self).setUp() 706 707 self.ops = [ 708 ('abs', operator.abs, math_ops.abs, core.abs_function), 709 ('neg', operator.neg, math_ops.negative, core.neg), 710 # TODO(shoyer): add unary + to core TensorFlow 711 ('pos', None, None, None), 712 ('sign', None, math_ops.sign, core.sign), 713 ('reciprocal', None, math_ops.reciprocal, core.reciprocal), 714 ('square', None, math_ops.square, core.square), 715 ('round', None, math_ops.round, core.round_function), 716 ('sqrt', None, math_ops.sqrt, core.sqrt), 717 ('rsqrt', None, math_ops.rsqrt, core.rsqrt), 718 ('log', None, math_ops.log, core.log), 719 ('exp', None, math_ops.exp, core.exp), 720 ('log', None, math_ops.log, core.log), 721 ('ceil', None, math_ops.ceil, core.ceil), 722 ('floor', None, math_ops.floor, core.floor), 723 ('cos', None, math_ops.cos, core.cos), 724 ('sin', None, math_ops.sin, core.sin), 725 ('tan', None, math_ops.tan, core.tan), 726 ('acos', None, math_ops.acos, core.acos), 727 ('asin', None, math_ops.asin, core.asin), 728 ('atan', None, math_ops.atan, core.atan), 729 ('lgamma', None, math_ops.lgamma, core.lgamma), 730 ('digamma', None, math_ops.digamma, core.digamma), 731 ('erf', None, math_ops.erf, core.erf), 732 ('erfc', None, math_ops.erfc, core.erfc), 733 ('lgamma', None, math_ops.lgamma, core.lgamma), 734 ] 735 total_size = np.prod([v.size for v in self.original_lt.axes.values()]) 736 self.test_lt = core.LabeledTensor( 737 math_ops.cast(self.original_lt, dtypes.float32) / total_size, 738 self.original_lt.axes) 739 740 741class LogicalNotTest(Base, DocStringCheckMixin, UnaryOpsTestsMixin): 742 743 def setUp(self): 744 super(LogicalNotTest, self).setUp() 745 self.ops = [('logical_not', operator.invert, math_ops.logical_not, 746 core.logical_not),] 747 self.test_lt = self.original_lt < 10 748 749 750class BinaryOpsTestsMixin(object): 751 # requires self.ops, self.test_lt_1, self.test_lt_2, self.test_lt_1_broadcast 752 # and self.test_lt_2_broadcast to be defined 753 754 def test_core_op(self): 755 for op_name, _, tf_op, lt_op in self.ops: 756 golden_tensor = tf_op(self.test_lt_1_broadcast, self.test_lt_2_broadcast) 757 golden_lt = core.LabeledTensor(golden_tensor, self.broadcast_axes) 758 actual_lt = lt_op(self.test_lt_1, self.test_lt_2) 759 self.assertIn(op_name, actual_lt.name) 760 self.assertLabeledTensorsEqual(golden_lt, actual_lt) 761 762 def test_infix(self): 763 for op_name, infix_op, _, lt_op in self.ops: 764 if infix_op is not None: 765 expected_lt = lt_op(self.test_lt_1, self.test_lt_2) 766 actual_lt = infix_op(self.test_lt_1, self.test_lt_2) 767 self.assertIn(op_name, actual_lt.name) 768 self.assertLabeledTensorsEqual(expected_lt, actual_lt) 769 770 771class CoreBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin): 772 773 def setUp(self): 774 super(CoreBinaryOpsTest, self).setUp() 775 776 self.x_probs_broadcast_tensor = array_ops.reshape( 777 self.x_probs_lt.tensor, [self.x_size, 1, self.probs_size]) 778 779 self.channel_probs_broadcast_tensor = array_ops.reshape( 780 self.channel_probs_lt.tensor, [1, self.channel_size, self.probs_size]) 781 782 # == and != are not element-wise for tf.Tensor, so they shouldn't be 783 # elementwise for LabeledTensor, either. 784 self.ops = [ 785 ('add', operator.add, math_ops.add, core.add), 786 ('sub', operator.sub, math_ops.subtract, core.sub), 787 ('mul', operator.mul, math_ops.multiply, core.mul), 788 ('div', operator.truediv, math_ops.div, core.div), 789 ('mod', operator.mod, math_ops.mod, core.mod), 790 ('pow', operator.pow, math_ops.pow, core.pow_function), 791 ('equal', None, math_ops.equal, core.equal), 792 ('less', operator.lt, math_ops.less, core.less), 793 ('less_equal', operator.le, math_ops.less_equal, core.less_equal), 794 ('not_equal', None, math_ops.not_equal, core.not_equal), 795 ('greater', operator.gt, math_ops.greater, core.greater), 796 ('greater_equal', operator.ge, math_ops.greater_equal, 797 core.greater_equal), 798 ] 799 self.test_lt_1 = self.x_probs_lt 800 self.test_lt_2 = self.channel_probs_lt 801 self.test_lt_1_broadcast = self.x_probs_broadcast_tensor 802 self.test_lt_2_broadcast = self.channel_probs_broadcast_tensor 803 self.broadcast_axes = [self.a0, self.a1, self.a3] 804 805 def test_reflexive(self): 806 labeled_tensor = self.x_probs_lt + 1 # all elements must be >0 for division 807 for op_name, infix_op, _, lt_op in self.ops: 808 if infix_op is not None: 809 expected_lt = lt_op(2, labeled_tensor) 810 actual_lt = infix_op(2, labeled_tensor) 811 # Python uses greater for the reflexive version of less (and vise-versa) 812 if 'less' in op_name: 813 op_name = op_name.replace('less', 'greater') 814 elif 'greater' in op_name: 815 op_name = op_name.replace('greater', 'less') 816 self.assertIn(op_name, actual_lt.name) 817 self.assertLabeledTensorsEqual(expected_lt, actual_lt) 818 819 820class LogicalBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin): 821 822 def setUp(self): 823 super(LogicalBinaryOpsTest, self).setUp() 824 825 self.ops = [ 826 ('logical_and', operator.and_, math_ops.logical_and, core.logical_and), 827 ('logical_or', operator.or_, math_ops.logical_or, core.logical_or), 828 ('logical_xor', operator.xor, math_ops.logical_xor, core.logical_xor), 829 ] 830 self.test_lt_1 = self.original_lt < 10 831 self.test_lt_2 = self.original_lt < 5 832 self.test_lt_1_broadcast = self.test_lt_1.tensor 833 self.test_lt_2_broadcast = self.test_lt_2.tensor 834 self.broadcast_axes = self.test_lt_1.axes 835 836 837class FloatBinaryOpsTest(Base, DocStringCheckMixin, BinaryOpsTestsMixin): 838 839 def setUp(self): 840 super(FloatBinaryOpsTest, self).setUp() 841 842 self.ops = [ 843 ('igamma', None, math_ops.igamma, core.igamma), 844 ('igammac', None, math_ops.igammac, core.igammac), 845 ('zeta', None, math_ops.zeta, core.zeta), 846 ('polygamma', None, math_ops.polygamma, core.polygamma), 847 ('maximum', None, math_ops.maximum, core.maximum), 848 ('minimum', None, math_ops.minimum, core.minimum), 849 ('squared_difference', None, math_ops.squared_difference, 850 core.squared_difference), 851 ] 852 total_size = np.prod([v.size for v in self.original_lt.axes.values()]) 853 test_lt = core.LabeledTensor( 854 math_ops.cast(self.original_lt, dtypes.float32) / total_size, 855 self.original_lt.axes) 856 self.test_lt_1 = test_lt 857 self.test_lt_2 = 1.0 - test_lt 858 self.test_lt_1_broadcast = self.test_lt_1.tensor 859 self.test_lt_2_broadcast = self.test_lt_2.tensor 860 self.broadcast_axes = self.test_lt_1.axes 861 862 863if __name__ == '__main__': 864 test_lib.main() 865