1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test cases for binary operators.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests.xla_test import XLATestCase 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import errors 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import bitwise_ops 28from tensorflow.python.ops import gen_math_ops 29from tensorflow.python.ops import gen_nn_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops import nn_ops 32from tensorflow.python.platform import googletest 33 34 35class BinaryOpsTest(XLATestCase): 36 """Test cases for binary operators.""" 37 38 def _testBinary(self, op, a, b, expected, equality_test=None): 39 with self.test_session() as session: 40 with self.test_scope(): 41 pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a") 42 pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b") 43 output = op(pa, pb) 44 result = session.run(output, {pa: a, pb: b}) 45 if equality_test is None: 46 equality_test = self.assertAllCloseAccordingToType 47 equality_test(result, expected, rtol=1e-3) 48 49 def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): 50 self._testBinary(op, a, b, expected, equality_test) 51 self._testBinary(op, b, a, expected, equality_test) 52 53 def ListsAreClose(self, result, expected, rtol): 54 """Tests closeness of two lists of floats.""" 55 self.assertEqual(len(result), len(expected)) 56 for i in range(len(result)): 57 self.assertAllCloseAccordingToType(result[i], expected[i], rtol) 58 59 def testFloatOps(self): 60 for dtype in self.float_types: 61 if dtype == dtypes.bfloat16.as_numpy_dtype: 62 a = -1.01 63 b = 4.1 64 else: 65 a = -1.001 66 b = 4.01 67 self._testBinary( 68 lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), 69 np.array([[[[-1, 2.00009999], [-3, b]]]], dtype=dtype), 70 np.array([[[[a, 2], [-3.00009, 4]]]], dtype=dtype), 71 expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) 72 73 self._testBinary( 74 gen_math_ops._real_div, 75 np.array([3, 3, -1.5, -8, 44], dtype=dtype), 76 np.array([2, -2, 7, -4, 0], dtype=dtype), 77 expected=np.array( 78 [1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype)) 79 80 self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81)) 81 82 self._testBinary( 83 math_ops.pow, 84 np.array([1, 2], dtype=dtype), 85 np.zeros(shape=[0, 2], dtype=dtype), 86 expected=np.zeros(shape=[0, 2], dtype=dtype)) 87 self._testBinary( 88 math_ops.pow, 89 np.array([10, 4], dtype=dtype), 90 np.array([2, 3], dtype=dtype), 91 expected=np.array([100, 64], dtype=dtype)) 92 self._testBinary( 93 math_ops.pow, 94 dtype(2), 95 np.array([3, 4], dtype=dtype), 96 expected=np.array([8, 16], dtype=dtype)) 97 self._testBinary( 98 math_ops.pow, 99 np.array([[2], [3]], dtype=dtype), 100 dtype(4), 101 expected=np.array([[16], [81]], dtype=dtype)) 102 103 self._testBinary( 104 math_ops.atan2, 105 np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), 106 np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), 107 expected=np.array( 108 [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) 109 110 self._testBinary( 111 gen_math_ops._reciprocal_grad, 112 np.array([4, -3, -2, 1], dtype=dtype), 113 np.array([5, -6, 7, -8], dtype=dtype), 114 expected=np.array([-80, 54, -28, 8], dtype=dtype)) 115 116 self._testBinary( 117 gen_math_ops._sigmoid_grad, 118 np.array([4, 3, 2, 1], dtype=dtype), 119 np.array([5, 6, 7, 8], dtype=dtype), 120 expected=np.array([-60, -36, -14, 0], dtype=dtype)) 121 122 self._testBinary( 123 gen_math_ops._rsqrt_grad, 124 np.array([4, 3, 2, 1], dtype=dtype), 125 np.array([5, 6, 7, 8], dtype=dtype), 126 expected=np.array([-160, -81, -28, -4], dtype=dtype)) 127 128 self._testBinary( 129 gen_math_ops._sqrt_grad, 130 np.array([4, 3, 2, 1], dtype=dtype), 131 np.array([5, 6, 7, 8], dtype=dtype), 132 expected=np.array([0.625, 1, 1.75, 4], dtype=dtype)) 133 134 self._testBinary( 135 gen_nn_ops._softplus_grad, 136 np.array([4, 3, 2, 1], dtype=dtype), 137 np.array([5, 6, 7, 8], dtype=dtype), 138 expected=np.array( 139 [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype)) 140 141 self._testBinary( 142 gen_nn_ops._softsign_grad, 143 np.array([4, 3, 2, 1], dtype=dtype), 144 np.array([5, 6, 7, 8], dtype=dtype), 145 expected=np.array( 146 [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype)) 147 148 self._testBinary( 149 gen_math_ops._tanh_grad, 150 np.array([4, 3, 2, 1], dtype=dtype), 151 np.array([5, 6, 7, 8], dtype=dtype), 152 expected=np.array([-75, -48, -21, 0], dtype=dtype)) 153 154 self._testBinary( 155 gen_nn_ops._elu_grad, 156 np.array([1, 2, 3, 4, 5, 6], dtype=dtype), 157 np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype), 158 expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype)) 159 160 self._testBinary( 161 gen_nn_ops._selu_grad, 162 np.array([1, 2, 3, 4, 5, 6], dtype=dtype), 163 np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype), 164 expected=np.array( 165 [1.158099340847, 2.7161986816948, 4.67429802254, 166 4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype)) 167 168 self._testBinary( 169 gen_nn_ops._relu_grad, 170 np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype), 171 np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype), 172 expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype)) 173 174 self._testBinary( 175 gen_nn_ops._relu6_grad, 176 np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype), 177 np.array( 178 [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype), 179 expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype)) 180 181 self._testBinary( 182 gen_nn_ops._softmax_cross_entropy_with_logits, 183 np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype), 184 np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype), 185 expected=[ 186 np.array([1.44019, 2.44019], dtype=dtype), 187 np.array([[-0.067941, -0.112856, -0.063117, 0.243914], 188 [-0.367941, -0.212856, 0.036883, 0.543914]], 189 dtype=dtype), 190 ], 191 equality_test=self.ListsAreClose) 192 193 self._testBinary( 194 gen_nn_ops._sparse_softmax_cross_entropy_with_logits, 195 np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], 196 [0.9, 1.0, 1.1, 1.2]], dtype=dtype), 197 np.array([2, 1, 7], dtype=np.int32), 198 expected=[ 199 np.array([1.342536, 1.442536, np.nan], dtype=dtype), 200 np.array([[0.213838, 0.236328, -0.738817, 0.288651], 201 [0.213838, -0.763672, 0.261183, 0.288651], 202 [np.nan, np.nan, np.nan, np.nan]], 203 dtype=dtype), 204 ], 205 equality_test=self.ListsAreClose) 206 207 def testIntOps(self): 208 for dtype in self.int_types: 209 self._testBinary( 210 gen_math_ops._truncate_div, 211 np.array([3, 3, -1, -9, -8], dtype=dtype), 212 np.array([2, -2, 7, 2, -4], dtype=dtype), 213 expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) 214 self._testSymmetricBinary( 215 bitwise_ops.bitwise_and, 216 np.array([0b1, 0b101, 0b1000], dtype=dtype), 217 np.array([0b0, 0b101, 0b1001], dtype=dtype), 218 expected=np.array([0b0, 0b101, 0b1000], dtype=dtype)) 219 self._testSymmetricBinary( 220 bitwise_ops.bitwise_or, 221 np.array([0b1, 0b101, 0b1000], dtype=dtype), 222 np.array([0b0, 0b101, 0b1001], dtype=dtype), 223 expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) 224 225 lhs = np.array([0, 5, 3, 14], dtype=dtype) 226 rhs = np.array([5, 0, 7, 11], dtype=dtype) 227 self._testBinary( 228 bitwise_ops.left_shift, lhs, rhs, 229 expected=np.left_shift(lhs, rhs)) 230 self._testBinary( 231 bitwise_ops.right_shift, lhs, rhs, 232 expected=np.right_shift(lhs, rhs)) 233 234 if dtype in [np.int8, np.int16, np.int32, np.int64]: 235 lhs = np.array([-1, -5, -3, -14], dtype=dtype) 236 rhs = np.array([5, 0, 1, 11], dtype=dtype) 237 self._testBinary( 238 bitwise_ops.right_shift, lhs, rhs, 239 expected=np.right_shift(lhs, rhs)) 240 241 def testNumericOps(self): 242 for dtype in self.numeric_types: 243 self._testBinary( 244 math_ops.add, 245 np.array([1, 2], dtype=dtype), 246 np.array([10, 20], dtype=dtype), 247 expected=np.array([11, 22], dtype=dtype)) 248 self._testBinary( 249 math_ops.add, 250 dtype(5), 251 np.array([1, 2], dtype=dtype), 252 expected=np.array([6, 7], dtype=dtype)) 253 self._testBinary( 254 math_ops.add, 255 np.array([[1], [2]], dtype=dtype), 256 dtype(7), 257 expected=np.array([[8], [9]], dtype=dtype)) 258 259 self._testBinary( 260 math_ops.subtract, 261 np.array([1, 2], dtype=dtype), 262 np.array([10, 20], dtype=dtype), 263 expected=np.array([-9, -18], dtype=dtype)) 264 self._testBinary( 265 math_ops.subtract, 266 dtype(5), 267 np.array([1, 2], dtype=dtype), 268 expected=np.array([4, 3], dtype=dtype)) 269 self._testBinary( 270 math_ops.subtract, 271 np.array([[1], [2]], dtype=dtype), 272 dtype(7), 273 expected=np.array([[-6], [-5]], dtype=dtype)) 274 275 if dtype not in self.complex_types: # min/max not supported for complex 276 self._testBinary( 277 math_ops.maximum, 278 np.array([1, 2], dtype=dtype), 279 np.array([10, 20], dtype=dtype), 280 expected=np.array([10, 20], dtype=dtype)) 281 self._testBinary( 282 math_ops.maximum, 283 dtype(5), 284 np.array([1, 20], dtype=dtype), 285 expected=np.array([5, 20], dtype=dtype)) 286 self._testBinary( 287 math_ops.maximum, 288 np.array([[10], [2]], dtype=dtype), 289 dtype(7), 290 expected=np.array([[10], [7]], dtype=dtype)) 291 292 self._testBinary( 293 math_ops.minimum, 294 np.array([1, 20], dtype=dtype), 295 np.array([10, 2], dtype=dtype), 296 expected=np.array([1, 2], dtype=dtype)) 297 self._testBinary( 298 math_ops.minimum, 299 dtype(5), 300 np.array([1, 20], dtype=dtype), 301 expected=np.array([1, 5], dtype=dtype)) 302 self._testBinary( 303 math_ops.minimum, 304 np.array([[10], [2]], dtype=dtype), 305 dtype(7), 306 expected=np.array([[7], [2]], dtype=dtype)) 307 308 self._testBinary( 309 math_ops.multiply, 310 np.array([1, 20], dtype=dtype), 311 np.array([10, 2], dtype=dtype), 312 expected=np.array([10, 40], dtype=dtype)) 313 self._testBinary( 314 math_ops.multiply, 315 dtype(5), 316 np.array([1, 20], dtype=dtype), 317 expected=np.array([5, 100], dtype=dtype)) 318 self._testBinary( 319 math_ops.multiply, 320 np.array([[10], [2]], dtype=dtype), 321 dtype(7), 322 expected=np.array([[70], [14]], dtype=dtype)) 323 324 # Complex support for squared_difference is incidental, see b/68205550 325 if dtype not in self.complex_types: 326 self._testBinary( 327 math_ops.squared_difference, 328 np.array([1, 2], dtype=dtype), 329 np.array([10, 20], dtype=dtype), 330 expected=np.array([81, 324], dtype=dtype)) 331 self._testBinary( 332 math_ops.squared_difference, 333 dtype(5), 334 np.array([1, 2], dtype=dtype), 335 expected=np.array([16, 9], dtype=dtype)) 336 self._testBinary( 337 math_ops.squared_difference, 338 np.array([[1], [2]], dtype=dtype), 339 dtype(7), 340 expected=np.array([[36], [25]], dtype=dtype)) 341 342 self._testBinary( 343 nn_ops.bias_add, 344 np.array([[1, 2], [3, 4]], dtype=dtype), 345 np.array([2, -1], dtype=dtype), 346 expected=np.array([[3, 1], [5, 3]], dtype=dtype)) 347 self._testBinary( 348 nn_ops.bias_add, 349 np.array([[[[1, 2], [3, 4]]]], dtype=dtype), 350 np.array([2, -1], dtype=dtype), 351 expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) 352 353 def testComplexOps(self): 354 for dtype in self.complex_types: 355 ctypes = {np.complex64: np.float32} 356 self._testBinary( 357 math_ops.complex, 358 np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), 359 np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]), 360 expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype)) 361 362 self._testBinary( 363 lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), 364 np.array( 365 [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]], 366 dtype=dtype), 367 np.array( 368 [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype), 369 expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) 370 371 self._testBinary( 372 gen_math_ops._real_div, 373 np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype), 374 np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype), 375 expected=np.array( 376 [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2], 377 dtype=dtype)) 378 379 # Test inf/nan scenarios. 380 self._testBinary( 381 gen_math_ops._real_div, 382 np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype), 383 np.array([0, 0, 0, 0, 0, 0], dtype=dtype), 384 expected=np.array( 385 [ 386 dtype(1 + 1j) / 0, 387 dtype(1) / 0, 388 dtype(1j) / 0, 389 dtype(-1) / 0, 390 dtype(-1j) / 0, 391 dtype(1 - 1j) / 0 392 ], 393 dtype=dtype)) 394 395 self._testBinary( 396 math_ops.pow, 397 dtype(3 + 2j), 398 dtype(4 - 5j), 399 expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) 400 self._testBinary( # empty rhs 401 math_ops.pow, 402 np.array([1 + 2j, 2 - 3j], dtype=dtype), 403 np.zeros(shape=[0, 2], dtype=dtype), 404 expected=np.zeros(shape=[0, 2], dtype=dtype)) 405 self._testBinary( # to zero power 406 math_ops.pow, 407 np.array([1 + 2j, 2 - 3j], dtype=dtype), 408 np.zeros(shape=[1, 2], dtype=dtype), 409 expected=np.ones(shape=[1, 2], dtype=dtype)) 410 lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) 411 rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) 412 scalar = dtype(2 + 2j) 413 self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) 414 self._testBinary( 415 math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) 416 self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) 417 418 lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) 419 rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) 420 self._testBinary( 421 gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) 422 423 self._testBinary( 424 gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) 425 426 self._testBinary( 427 gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) 428 429 self._testBinary( 430 gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) 431 432 self._testBinary( 433 gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) 434 435 def testComplexMath(self): 436 for dtype in self.complex_types: 437 self._testBinary( 438 math_ops.add, 439 np.array([1 + 3j, 2 + 7j], dtype=dtype), 440 np.array([10 - 4j, 20 + 17j], dtype=dtype), 441 expected=np.array([11 - 1j, 22 + 24j], dtype=dtype)) 442 self._testBinary( 443 math_ops.add, 444 dtype(5 - 7j), 445 np.array([1 + 2j, 2 + 4j], dtype=dtype), 446 expected=np.array([6 - 5j, 7 - 3j], dtype=dtype)) 447 self._testBinary( 448 math_ops.add, 449 np.array([[1 - 2j], [2 + 1j]], dtype=dtype), 450 dtype(7 + 5j), 451 expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype)) 452 453 self._testBinary( 454 math_ops.subtract, 455 np.array([1 + 3j, 2 + 7j], dtype=dtype), 456 np.array([10 - 4j, 20 + 17j], dtype=dtype), 457 expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype)) 458 self._testBinary( 459 math_ops.subtract, 460 dtype(5 - 7j), 461 np.array([1 + 2j, 2 + 4j], dtype=dtype), 462 expected=np.array([4 - 9j, 3 - 11j], dtype=dtype)) 463 self._testBinary( 464 math_ops.subtract, 465 np.array([[1 - 2j], [2 + 1j]], dtype=dtype), 466 dtype(7 + 5j), 467 expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype)) 468 469 self._testBinary( 470 math_ops.multiply, 471 np.array([1 + 3j, 2 + 7j], dtype=dtype), 472 np.array([10 - 4j, 20 + 17j], dtype=dtype), 473 expected=np.array( 474 [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype)) 475 self._testBinary( 476 math_ops.multiply, 477 dtype(5 - 7j), 478 np.array([1 + 2j, 2 + 4j], dtype=dtype), 479 expected=np.array( 480 [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype)) 481 self._testBinary( 482 math_ops.multiply, 483 np.array([[1 - 2j], [2 + 1j]], dtype=dtype), 484 dtype(7 + 5j), 485 expected=np.array( 486 [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype)) 487 488 self._testBinary( 489 math_ops.div, 490 np.array([8 - 1j, 2 + 16j], dtype=dtype), 491 np.array([2 + 4j, 4 - 8j], dtype=dtype), 492 expected=np.array( 493 [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype)) 494 self._testBinary( 495 math_ops.div, 496 dtype(1 + 2j), 497 np.array([2 + 4j, 4 - 8j], dtype=dtype), 498 expected=np.array( 499 [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype)) 500 self._testBinary( 501 math_ops.div, 502 np.array([2 + 4j, 4 - 8j], dtype=dtype), 503 dtype(1 + 2j), 504 expected=np.array( 505 [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype)) 506 507 # TODO(b/68205550): math_ops.squared_difference shouldn't be supported. 508 509 self._testBinary( 510 nn_ops.bias_add, 511 np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype), 512 np.array([2 + 6j, -1 - 3j], dtype=dtype), 513 expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype)) 514 self._testBinary( 515 nn_ops.bias_add, 516 np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype), 517 np.array([2 + 1j, -1 + 2j], dtype=dtype), 518 expected=np.array( 519 [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype)) 520 521 def _testDivision(self, dtype): 522 """Test cases for division operators.""" 523 self._testBinary( 524 math_ops.div, 525 np.array([10, 20], dtype=dtype), 526 np.array([10, 2], dtype=dtype), 527 expected=np.array([1, 10], dtype=dtype)) 528 self._testBinary( 529 math_ops.div, 530 dtype(40), 531 np.array([2, 20], dtype=dtype), 532 expected=np.array([20, 2], dtype=dtype)) 533 self._testBinary( 534 math_ops.div, 535 np.array([[10], [4]], dtype=dtype), 536 dtype(2), 537 expected=np.array([[5], [2]], dtype=dtype)) 538 539 if dtype not in self.complex_types: # floordiv unsupported for complex. 540 self._testBinary( 541 gen_math_ops._floor_div, 542 np.array([3, 3, -1, -9, -8], dtype=dtype), 543 np.array([2, -2, 7, 2, -4], dtype=dtype), 544 expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) 545 546 def testIntDivision(self): 547 for dtype in self.int_types: 548 self._testDivision(dtype) 549 550 def testFloatDivision(self): 551 for dtype in self.float_types | self.complex_types: 552 self._testDivision(dtype) 553 554 def _testRemainder(self, dtype): 555 """Test cases for remainder operators.""" 556 self._testBinary( 557 gen_math_ops._floor_mod, 558 np.array([3, 3, -1, -8], dtype=dtype), 559 np.array([2, -2, 7, -4], dtype=dtype), 560 expected=np.array([1, -1, 6, 0], dtype=dtype)) 561 self._testBinary( 562 gen_math_ops._truncate_mod, 563 np.array([3, 3, -1, -8], dtype=dtype), 564 np.array([2, -2, 7, -4], dtype=dtype), 565 expected=np.array([1, 1, -1, 0], dtype=dtype)) 566 567 def testIntRemainder(self): 568 for dtype in self.int_types: 569 self._testRemainder(dtype) 570 571 def testFloatRemainder(self): 572 for dtype in self.float_types: 573 self._testRemainder(dtype) 574 575 def testLogicalOps(self): 576 self._testBinary( 577 math_ops.logical_and, 578 np.array([[True, False], [False, True]], dtype=np.bool), 579 np.array([[False, True], [False, True]], dtype=np.bool), 580 expected=np.array([[False, False], [False, True]], dtype=np.bool)) 581 582 self._testBinary( 583 math_ops.logical_or, 584 np.array([[True, False], [False, True]], dtype=np.bool), 585 np.array([[False, True], [False, True]], dtype=np.bool), 586 expected=np.array([[True, True], [False, True]], dtype=np.bool)) 587 588 def testComparisons(self): 589 self._testBinary( 590 math_ops.equal, 591 np.array([1, 5, 20], dtype=np.float32), 592 np.array([10, 5, 2], dtype=np.float32), 593 expected=np.array([False, True, False], dtype=np.bool)) 594 self._testBinary( 595 math_ops.equal, 596 np.float32(5), 597 np.array([1, 5, 20], dtype=np.float32), 598 expected=np.array([False, True, False], dtype=np.bool)) 599 self._testBinary( 600 math_ops.equal, 601 np.array([[10], [7], [2]], dtype=np.float32), 602 np.float32(7), 603 expected=np.array([[False], [True], [False]], dtype=np.bool)) 604 605 self._testBinary( 606 math_ops.not_equal, 607 np.array([1, 5, 20], dtype=np.float32), 608 np.array([10, 5, 2], dtype=np.float32), 609 expected=np.array([True, False, True], dtype=np.bool)) 610 self._testBinary( 611 math_ops.not_equal, 612 np.float32(5), 613 np.array([1, 5, 20], dtype=np.float32), 614 expected=np.array([True, False, True], dtype=np.bool)) 615 self._testBinary( 616 math_ops.not_equal, 617 np.array([[10], [7], [2]], dtype=np.float32), 618 np.float32(7), 619 expected=np.array([[True], [False], [True]], dtype=np.bool)) 620 621 for greater_op in [math_ops.greater, (lambda x, y: x > y)]: 622 self._testBinary( 623 greater_op, 624 np.array([1, 5, 20], dtype=np.float32), 625 np.array([10, 5, 2], dtype=np.float32), 626 expected=np.array([False, False, True], dtype=np.bool)) 627 self._testBinary( 628 greater_op, 629 np.float32(5), 630 np.array([1, 5, 20], dtype=np.float32), 631 expected=np.array([True, False, False], dtype=np.bool)) 632 self._testBinary( 633 greater_op, 634 np.array([[10], [7], [2]], dtype=np.float32), 635 np.float32(7), 636 expected=np.array([[True], [False], [False]], dtype=np.bool)) 637 638 for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]: 639 self._testBinary( 640 greater_equal_op, 641 np.array([1, 5, 20], dtype=np.float32), 642 np.array([10, 5, 2], dtype=np.float32), 643 expected=np.array([False, True, True], dtype=np.bool)) 644 self._testBinary( 645 greater_equal_op, 646 np.float32(5), 647 np.array([1, 5, 20], dtype=np.float32), 648 expected=np.array([True, True, False], dtype=np.bool)) 649 self._testBinary( 650 greater_equal_op, 651 np.array([[10], [7], [2]], dtype=np.float32), 652 np.float32(7), 653 expected=np.array([[True], [True], [False]], dtype=np.bool)) 654 655 for less_op in [math_ops.less, (lambda x, y: x < y)]: 656 self._testBinary( 657 less_op, 658 np.array([1, 5, 20], dtype=np.float32), 659 np.array([10, 5, 2], dtype=np.float32), 660 expected=np.array([True, False, False], dtype=np.bool)) 661 self._testBinary( 662 less_op, 663 np.float32(5), 664 np.array([1, 5, 20], dtype=np.float32), 665 expected=np.array([False, False, True], dtype=np.bool)) 666 self._testBinary( 667 less_op, 668 np.array([[10], [7], [2]], dtype=np.float32), 669 np.float32(7), 670 expected=np.array([[False], [False], [True]], dtype=np.bool)) 671 672 for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]: 673 self._testBinary( 674 less_equal_op, 675 np.array([1, 5, 20], dtype=np.float32), 676 np.array([10, 5, 2], dtype=np.float32), 677 expected=np.array([True, True, False], dtype=np.bool)) 678 self._testBinary( 679 less_equal_op, 680 np.float32(5), 681 np.array([1, 5, 20], dtype=np.float32), 682 expected=np.array([False, True, True], dtype=np.bool)) 683 self._testBinary( 684 less_equal_op, 685 np.array([[10], [7], [2]], dtype=np.float32), 686 np.float32(7), 687 expected=np.array([[False], [True], [True]], dtype=np.bool)) 688 689 def testBroadcasting(self): 690 """Tests broadcasting behavior of an operator.""" 691 692 for dtype in self.numeric_types: 693 self._testBinary( 694 math_ops.add, 695 np.array(3, dtype=dtype), 696 np.array([10, 20], dtype=dtype), 697 expected=np.array([13, 23], dtype=dtype)) 698 self._testBinary( 699 math_ops.add, 700 np.array([10, 20], dtype=dtype), 701 np.array(4, dtype=dtype), 702 expected=np.array([14, 24], dtype=dtype)) 703 704 # [1,3] x [4,1] => [4,3] 705 self._testBinary( 706 math_ops.add, 707 np.array([[10, 20, 30]], dtype=dtype), 708 np.array([[1], [2], [3], [4]], dtype=dtype), 709 expected=np.array( 710 [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]], 711 dtype=dtype)) 712 713 # [3] * [4,1] => [4,3] 714 self._testBinary( 715 math_ops.add, 716 np.array([10, 20, 30], dtype=dtype), 717 np.array([[1], [2], [3], [4]], dtype=dtype), 718 expected=np.array( 719 [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]], 720 dtype=dtype)) 721 722 def testFill(self): 723 for dtype in self.numeric_types: 724 self._testBinary( 725 array_ops.fill, 726 np.array([], dtype=np.int32), 727 dtype(-42), 728 expected=dtype(-42)) 729 self._testBinary( 730 array_ops.fill, 731 np.array([1, 2], dtype=np.int32), 732 dtype(7), 733 expected=np.array([[7, 7]], dtype=dtype)) 734 self._testBinary( 735 array_ops.fill, 736 np.array([3, 2], dtype=np.int32), 737 dtype(50), 738 expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype)) 739 740 # Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below. 741 def _testMatMul(self, op): 742 for dtype in self.float_types: 743 self._testBinary( 744 op, 745 np.array([[-0.25]], dtype=dtype), 746 np.array([[8]], dtype=dtype), 747 expected=np.array([[-2]], dtype=dtype)) 748 self._testBinary( 749 op, 750 np.array([[100, 10, 0.5]], dtype=dtype), 751 np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype), 752 expected=np.array([[123, 354]], dtype=dtype)) 753 self._testBinary( 754 op, 755 np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype), 756 np.array([[100], [10]], dtype=dtype), 757 expected=np.array([[130], [250], [680]], dtype=dtype)) 758 self._testBinary( 759 op, 760 np.array([[1000, 100], [10, 1]], dtype=dtype), 761 np.array([[1, 2], [3, 4]], dtype=dtype), 762 expected=np.array([[1300, 2400], [13, 24]], dtype=dtype)) 763 764 self._testBinary( 765 op, 766 np.array([], dtype=dtype).reshape((2, 0)), 767 np.array([], dtype=dtype).reshape((0, 3)), 768 expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype)) 769 770 def testMatMul(self): 771 self._testMatMul(math_ops.matmul) 772 773 # TODO(phawkins): failing on GPU, no registered kernel. 774 def DISABLED_testSparseMatMul(self): 775 # Binary wrappers for sparse_matmul with different hints 776 def SparseMatmulWrapperTF(a, b): 777 return math_ops.sparse_matmul(a, b, a_is_sparse=True) 778 779 def SparseMatmulWrapperFT(a, b): 780 return math_ops.sparse_matmul(a, b, b_is_sparse=True) 781 782 def SparseMatmulWrapperTT(a, b): 783 return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) 784 785 self._testMatMul(math_ops.sparse_matmul) 786 self._testMatMul(SparseMatmulWrapperTF) 787 self._testMatMul(SparseMatmulWrapperFT) 788 self._testMatMul(SparseMatmulWrapperTT) 789 790 def testBatchMatMul(self): 791 # Same tests as for tf.matmul above. 792 self._testMatMul(math_ops.matmul) 793 794 # Tests with batches of matrices. 795 self._testBinary( 796 math_ops.matmul, 797 np.array([[[-0.25]]], dtype=np.float32), 798 np.array([[[8]]], dtype=np.float32), 799 expected=np.array([[[-2]]], dtype=np.float32)) 800 self._testBinary( 801 math_ops.matmul, 802 np.array([[[-0.25]], [[4]]], dtype=np.float32), 803 np.array([[[8]], [[2]]], dtype=np.float32), 804 expected=np.array([[[-2]], [[8]]], dtype=np.float32)) 805 self._testBinary( 806 math_ops.matmul, 807 np.array( 808 [[[[7, 13], [10, 1]], [[2, 0.25], [20, 2]]], 809 [[[3, 5], [30, 3]], [[0.75, 1], [40, 4]]]], 810 dtype=np.float32), 811 np.array( 812 [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]], 813 [[55, 66], [77, 88]]]], 814 dtype=np.float32), 815 expected=np.array( 816 [[[[46, 66], [13, 24]], [[11.75, 14], [114, 136]]], 817 [[[198, 286], [429, 792]], [[118.25, 137.5], [2508, 2992]]]], 818 dtype=np.float32)) 819 820 self._testBinary( 821 math_ops.matmul, 822 np.array([], dtype=np.float32).reshape((2, 2, 0)), 823 np.array([], dtype=np.float32).reshape((2, 0, 3)), 824 expected=np.array( 825 [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]], 826 dtype=np.float32)) 827 self._testBinary( 828 math_ops.matmul, 829 np.array([], dtype=np.float32).reshape((0, 2, 4)), 830 np.array([], dtype=np.float32).reshape((0, 4, 3)), 831 expected=np.array([], dtype=np.float32).reshape(0, 2, 3)) 832 833 # Regression test for b/31472796. 834 if hasattr(np, "matmul"): 835 x = np.arange(0, 3 * 5 * 2 * 7, dtype=np.float32).reshape((3, 5, 2, 7)) 836 self._testBinary( 837 lambda x, y: math_ops.matmul(x, y, adjoint_b=True), 838 x, x, 839 expected=np.matmul(x, x.transpose([0, 1, 3, 2]))) 840 841 def testExpandDims(self): 842 for dtype in self.numeric_types: 843 self._testBinary( 844 array_ops.expand_dims, 845 dtype(7), 846 np.int32(0), 847 expected=np.array([7], dtype=dtype)) 848 self._testBinary( 849 array_ops.expand_dims, 850 np.array([42], dtype=dtype), 851 np.int32(0), 852 expected=np.array([[42]], dtype=dtype)) 853 self._testBinary( 854 array_ops.expand_dims, 855 np.array([], dtype=dtype), 856 np.int32(0), 857 expected=np.array([[]], dtype=dtype)) 858 self._testBinary( 859 array_ops.expand_dims, 860 np.array([[[1, 2], [3, 4]]], dtype=dtype), 861 np.int32(0), 862 expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype)) 863 self._testBinary( 864 array_ops.expand_dims, 865 np.array([[[1, 2], [3, 4]]], dtype=dtype), 866 np.int32(1), 867 expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype)) 868 self._testBinary( 869 array_ops.expand_dims, 870 np.array([[[1, 2], [3, 4]]], dtype=dtype), 871 np.int32(2), 872 expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype)) 873 self._testBinary( 874 array_ops.expand_dims, 875 np.array([[[1, 2], [3, 4]]], dtype=dtype), 876 np.int32(3), 877 expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype)) 878 879 def testPad(self): 880 for dtype in self.numeric_types: 881 self._testBinary( 882 array_ops.pad, 883 np.array( 884 [[1, 2, 3], [4, 5, 6]], dtype=dtype), 885 np.array( 886 [[1, 2], [2, 1]], dtype=np.int32), 887 expected=np.array( 888 [[0, 0, 0, 0, 0, 0], 889 [0, 0, 1, 2, 3, 0], 890 [0, 0, 4, 5, 6, 0], 891 [0, 0, 0, 0, 0, 0], 892 [0, 0, 0, 0, 0, 0]], 893 dtype=dtype)) 894 895 self._testBinary( 896 lambda x, y: array_ops.pad(x, y, constant_values=7), 897 np.array( 898 [[1, 2, 3], [4, 5, 6]], dtype=dtype), 899 np.array( 900 [[0, 3], [2, 1]], dtype=np.int32), 901 expected=np.array( 902 [[7, 7, 1, 2, 3, 7], 903 [7, 7, 4, 5, 6, 7], 904 [7, 7, 7, 7, 7, 7], 905 [7, 7, 7, 7, 7, 7], 906 [7, 7, 7, 7, 7, 7]], 907 dtype=dtype)) 908 909 def testMirrorPad(self): 910 mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") 911 for dtype in self.numeric_types: 912 self._testBinary( 913 mirror_pad, 914 np.array( 915 [ 916 [1, 2, 3], # 917 [4, 5, 6], # 918 ], 919 dtype=dtype), 920 np.array([[ 921 1, 922 1, 923 ], [2, 2]], dtype=np.int32), 924 expected=np.array( 925 [ 926 [6, 5, 4, 5, 6, 5, 4], # 927 [3, 2, 1, 2, 3, 2, 1], # 928 [6, 5, 4, 5, 6, 5, 4], # 929 [3, 2, 1, 2, 3, 2, 1] 930 ], 931 dtype=dtype)) 932 self._testBinary( 933 mirror_pad, 934 np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype), 935 np.array([[0, 0], [0, 0]], dtype=np.int32), 936 expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype)) 937 self._testBinary( 938 mirror_pad, 939 np.array( 940 [ 941 [1, 2, 3], # 942 [4, 5, 6], # 943 [7, 8, 9] 944 ], 945 dtype=dtype), 946 np.array([[2, 2], [0, 0]], dtype=np.int32), 947 expected=np.array( 948 [ 949 [7, 8, 9], # 950 [4, 5, 6], # 951 [1, 2, 3], # 952 [4, 5, 6], # 953 [7, 8, 9], # 954 [4, 5, 6], # 955 [1, 2, 3] 956 ], 957 dtype=dtype)) 958 self._testBinary( 959 mirror_pad, 960 np.array( 961 [ 962 [[1, 2, 3], [4, 5, 6]], 963 [[7, 8, 9], [10, 11, 12]], 964 ], dtype=dtype), 965 np.array([[0, 0], [1, 1], [1, 1]], dtype=np.int32), 966 expected=np.array( 967 [ 968 [ 969 [5, 4, 5, 6, 5], # 970 [2, 1, 2, 3, 2], # 971 [5, 4, 5, 6, 5], # 972 [2, 1, 2, 3, 2], # 973 ], 974 [ 975 [11, 10, 11, 12, 11], # 976 [8, 7, 8, 9, 8], # 977 [11, 10, 11, 12, 11], # 978 [8, 7, 8, 9, 8], # 979 ] 980 ], 981 dtype=dtype)) 982 983 def testReshape(self): 984 for dtype in self.numeric_types: 985 self._testBinary( 986 array_ops.reshape, 987 np.array([], dtype=dtype), 988 np.array([0, 4], dtype=np.int32), 989 expected=np.zeros(shape=[0, 4], dtype=dtype)) 990 self._testBinary( 991 array_ops.reshape, 992 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 993 np.array([2, 3], dtype=np.int32), 994 expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) 995 self._testBinary( 996 array_ops.reshape, 997 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 998 np.array([3, 2], dtype=np.int32), 999 expected=np.array([[0, 1], [2, 3], [4, 5]], dtype=dtype)) 1000 self._testBinary( 1001 array_ops.reshape, 1002 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 1003 np.array([-1, 6], dtype=np.int32), 1004 expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype)) 1005 self._testBinary( 1006 array_ops.reshape, 1007 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 1008 np.array([6, -1], dtype=np.int32), 1009 expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype)) 1010 self._testBinary( 1011 array_ops.reshape, 1012 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 1013 np.array([2, -1], dtype=np.int32), 1014 expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) 1015 self._testBinary( 1016 array_ops.reshape, 1017 np.array([0, 1, 2, 3, 4, 5], dtype=dtype), 1018 np.array([-1, 3], dtype=np.int32), 1019 expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype)) 1020 1021 def testSplit(self): 1022 for dtype in self.numeric_types: 1023 for axis in [0, -3]: 1024 self._testBinary( 1025 lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), 1026 np.int32(axis), 1027 np.array([[[1], [2]], [[3], [4]], [[5], [6]]], 1028 dtype=dtype), 1029 expected=[ 1030 np.array([[[1], [2]]], dtype=dtype), 1031 np.array([[[3], [4]]], dtype=dtype), 1032 np.array([[[5], [6]]], dtype=dtype), 1033 ], 1034 equality_test=self.ListsAreClose) 1035 1036 for axis in [1, -2]: 1037 self._testBinary( 1038 lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), 1039 np.int32(axis), 1040 np.array([[[1], [2]], [[3], [4]], [[5], [6]]], 1041 dtype=dtype), 1042 expected=[ 1043 np.array([[[1]], [[3]], [[5]]], dtype=dtype), 1044 np.array([[[2]], [[4]], [[6]]], dtype=dtype), 1045 ], 1046 equality_test=self.ListsAreClose) 1047 1048 def testTile(self): 1049 for dtype in self.numeric_types: 1050 self._testBinary( 1051 array_ops.tile, 1052 np.array([[6]], dtype=dtype), 1053 np.array([1, 2], dtype=np.int32), 1054 expected=np.array([[6, 6]], dtype=dtype)) 1055 self._testBinary( 1056 array_ops.tile, 1057 np.array([[1], [2]], dtype=dtype), 1058 np.array([1, 2], dtype=np.int32), 1059 expected=np.array([[1, 1], [2, 2]], dtype=dtype)) 1060 self._testBinary( 1061 array_ops.tile, 1062 np.array([[1, 2], [3, 4]], dtype=dtype), 1063 np.array([3, 2], dtype=np.int32), 1064 expected=np.array( 1065 [[1, 2, 1, 2], 1066 [3, 4, 3, 4], 1067 [1, 2, 1, 2], 1068 [3, 4, 3, 4], 1069 [1, 2, 1, 2], 1070 [3, 4, 3, 4]], 1071 dtype=dtype)) 1072 self._testBinary( 1073 array_ops.tile, 1074 np.array([[1, 2], [3, 4]], dtype=dtype), 1075 np.array([1, 1], dtype=np.int32), 1076 expected=np.array( 1077 [[1, 2], 1078 [3, 4]], 1079 dtype=dtype)) 1080 self._testBinary( 1081 array_ops.tile, 1082 np.array([[1, 2]], dtype=dtype), 1083 np.array([3, 1], dtype=np.int32), 1084 expected=np.array( 1085 [[1, 2], 1086 [1, 2], 1087 [1, 2]], 1088 dtype=dtype)) 1089 1090 def testTranspose(self): 1091 for dtype in self.numeric_types: 1092 self._testBinary( 1093 array_ops.transpose, 1094 np.zeros(shape=[1, 0, 4], dtype=dtype), 1095 np.array([1, 2, 0], dtype=np.int32), 1096 expected=np.zeros(shape=[0, 4, 1], dtype=dtype)) 1097 self._testBinary( 1098 array_ops.transpose, 1099 np.array([[1, 2], [3, 4]], dtype=dtype), 1100 np.array([0, 1], dtype=np.int32), 1101 expected=np.array([[1, 2], [3, 4]], dtype=dtype)) 1102 self._testBinary( 1103 array_ops.transpose, 1104 np.array([[1, 2], [3, 4]], dtype=dtype), 1105 np.array([1, 0], dtype=np.int32), 1106 expected=np.array([[1, 3], [2, 4]], dtype=dtype)) 1107 1108 def testCross(self): 1109 for dtype in self.float_types: 1110 self._testBinary( 1111 gen_math_ops.cross, 1112 np.zeros((4, 3), dtype=dtype), 1113 np.zeros((4, 3), dtype=dtype), 1114 expected=np.zeros((4, 3), dtype=dtype)) 1115 self._testBinary( 1116 gen_math_ops.cross, 1117 np.array([1, 2, 3], dtype=dtype), 1118 np.array([4, 5, 6], dtype=dtype), 1119 expected=np.array([-3, 6, -3], dtype=dtype)) 1120 self._testBinary( 1121 gen_math_ops.cross, 1122 np.array([[1, 2, 3], [10, 11, 12]], dtype=dtype), 1123 np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype), 1124 expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype)) 1125 1126 def testBroadcastArgs(self): 1127 self._testBinary(array_ops.broadcast_dynamic_shape, 1128 np.array([2, 3, 5], dtype=np.int32), 1129 np.array([1], dtype=np.int32), 1130 expected=np.array([2, 3, 5], dtype=np.int32)) 1131 1132 self._testBinary(array_ops.broadcast_dynamic_shape, 1133 np.array([1], dtype=np.int32), 1134 np.array([2, 3, 5], dtype=np.int32), 1135 expected=np.array([2, 3, 5], dtype=np.int32)) 1136 1137 self._testBinary(array_ops.broadcast_dynamic_shape, 1138 np.array([2, 3, 5], dtype=np.int32), 1139 np.array([5], dtype=np.int32), 1140 expected=np.array([2, 3, 5], dtype=np.int32)) 1141 1142 self._testBinary(array_ops.broadcast_dynamic_shape, 1143 np.array([5], dtype=np.int32), 1144 np.array([2, 3, 5], dtype=np.int32), 1145 expected=np.array([2, 3, 5], dtype=np.int32)) 1146 1147 self._testBinary(array_ops.broadcast_dynamic_shape, 1148 np.array([2, 3, 5], dtype=np.int32), 1149 np.array([3, 5], dtype=np.int32), 1150 expected=np.array([2, 3, 5], dtype=np.int32)) 1151 1152 self._testBinary(array_ops.broadcast_dynamic_shape, 1153 np.array([3, 5], dtype=np.int32), 1154 np.array([2, 3, 5], dtype=np.int32), 1155 expected=np.array([2, 3, 5], dtype=np.int32)) 1156 1157 self._testBinary(array_ops.broadcast_dynamic_shape, 1158 np.array([2, 3, 5], dtype=np.int32), 1159 np.array([3, 1], dtype=np.int32), 1160 expected=np.array([2, 3, 5], dtype=np.int32)) 1161 1162 self._testBinary(array_ops.broadcast_dynamic_shape, 1163 np.array([3, 1], dtype=np.int32), 1164 np.array([2, 3, 5], dtype=np.int32), 1165 expected=np.array([2, 3, 5], dtype=np.int32)) 1166 1167 self._testBinary(array_ops.broadcast_dynamic_shape, 1168 np.array([2, 1, 5], dtype=np.int32), 1169 np.array([3, 1], dtype=np.int32), 1170 expected=np.array([2, 3, 5], dtype=np.int32)) 1171 1172 self._testBinary(array_ops.broadcast_dynamic_shape, 1173 np.array([3, 1], dtype=np.int32), 1174 np.array([2, 1, 5], dtype=np.int32), 1175 expected=np.array([2, 3, 5], dtype=np.int32)) 1176 1177 with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, 1178 "Incompatible shapes"): 1179 self._testBinary(array_ops.broadcast_dynamic_shape, 1180 np.array([1, 2, 3], dtype=np.int32), 1181 np.array([4, 5, 6], dtype=np.int32), 1182 expected=None) 1183 1184 def testMatrixSetDiag(self): 1185 for dtype in self.numeric_types: 1186 # Square 1187 self._testBinary( 1188 array_ops.matrix_set_diag, 1189 np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], 1190 dtype=dtype), 1191 np.array([1.0, 2.0, 3.0], dtype=dtype), 1192 expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], 1193 dtype=dtype)) 1194 1195 self._testBinary( 1196 array_ops.matrix_set_diag, 1197 np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], 1198 [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], 1199 dtype=dtype), 1200 np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), 1201 expected=np.array( 1202 [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], 1203 [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], 1204 dtype=dtype)) 1205 1206 # Rectangular 1207 self._testBinary( 1208 array_ops.matrix_set_diag, 1209 np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), 1210 np.array([3.0, 4.0], dtype=dtype), 1211 expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) 1212 1213 self._testBinary( 1214 array_ops.matrix_set_diag, 1215 np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), 1216 np.array([3.0, 4.0], dtype=dtype), 1217 expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) 1218 1219 self._testBinary( 1220 array_ops.matrix_set_diag, 1221 np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], 1222 [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), 1223 np.array([[-1.0, -2.0], [-4.0, -5.0]], 1224 dtype=dtype), 1225 expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], 1226 [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], 1227 dtype=dtype)) 1228 1229if __name__ == "__main__": 1230 googletest.main() 1231