• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Test cases for binary operators."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests.xla_test import XLATestCase
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import errors
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import bitwise_ops
28from tensorflow.python.ops import gen_math_ops
29from tensorflow.python.ops import gen_nn_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.ops import nn_ops
32from tensorflow.python.platform import googletest
33
34
35class BinaryOpsTest(XLATestCase):
36  """Test cases for binary operators."""
37
38  def _testBinary(self, op, a, b, expected, equality_test=None):
39    with self.test_session() as session:
40      with self.test_scope():
41        pa = array_ops.placeholder(dtypes.as_dtype(a.dtype), a.shape, name="a")
42        pb = array_ops.placeholder(dtypes.as_dtype(b.dtype), b.shape, name="b")
43        output = op(pa, pb)
44      result = session.run(output, {pa: a, pb: b})
45      if equality_test is None:
46        equality_test = self.assertAllCloseAccordingToType
47      equality_test(result, expected, rtol=1e-3)
48
49  def _testSymmetricBinary(self, op, a, b, expected, equality_test=None):
50    self._testBinary(op, a, b, expected, equality_test)
51    self._testBinary(op, b, a, expected, equality_test)
52
53  def ListsAreClose(self, result, expected, rtol):
54    """Tests closeness of two lists of floats."""
55    self.assertEqual(len(result), len(expected))
56    for i in range(len(result)):
57      self.assertAllCloseAccordingToType(result[i], expected[i], rtol)
58
59  def testFloatOps(self):
60    for dtype in self.float_types:
61      if dtype == dtypes.bfloat16.as_numpy_dtype:
62        a = -1.01
63        b = 4.1
64      else:
65        a = -1.001
66        b = 4.01
67      self._testBinary(
68          lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
69          np.array([[[[-1, 2.00009999], [-3, b]]]], dtype=dtype),
70          np.array([[[[a, 2], [-3.00009, 4]]]], dtype=dtype),
71          expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
72
73      self._testBinary(
74          gen_math_ops._real_div,
75          np.array([3, 3, -1.5, -8, 44], dtype=dtype),
76          np.array([2, -2, 7, -4, 0], dtype=dtype),
77          expected=np.array(
78              [1.5, -1.5, -0.2142857, 2, float("inf")], dtype=dtype))
79
80      self._testBinary(math_ops.pow, dtype(3), dtype(4), expected=dtype(81))
81
82      self._testBinary(
83          math_ops.pow,
84          np.array([1, 2], dtype=dtype),
85          np.zeros(shape=[0, 2], dtype=dtype),
86          expected=np.zeros(shape=[0, 2], dtype=dtype))
87      self._testBinary(
88          math_ops.pow,
89          np.array([10, 4], dtype=dtype),
90          np.array([2, 3], dtype=dtype),
91          expected=np.array([100, 64], dtype=dtype))
92      self._testBinary(
93          math_ops.pow,
94          dtype(2),
95          np.array([3, 4], dtype=dtype),
96          expected=np.array([8, 16], dtype=dtype))
97      self._testBinary(
98          math_ops.pow,
99          np.array([[2], [3]], dtype=dtype),
100          dtype(4),
101          expected=np.array([[16], [81]], dtype=dtype))
102
103      self._testBinary(
104          math_ops.atan2,
105          np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype),
106          np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype),
107          expected=np.array(
108              [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype))
109
110      self._testBinary(
111          gen_math_ops._reciprocal_grad,
112          np.array([4, -3, -2, 1], dtype=dtype),
113          np.array([5, -6, 7, -8], dtype=dtype),
114          expected=np.array([-80, 54, -28, 8], dtype=dtype))
115
116      self._testBinary(
117          gen_math_ops._sigmoid_grad,
118          np.array([4, 3, 2, 1], dtype=dtype),
119          np.array([5, 6, 7, 8], dtype=dtype),
120          expected=np.array([-60, -36, -14, 0], dtype=dtype))
121
122      self._testBinary(
123          gen_math_ops._rsqrt_grad,
124          np.array([4, 3, 2, 1], dtype=dtype),
125          np.array([5, 6, 7, 8], dtype=dtype),
126          expected=np.array([-160, -81, -28, -4], dtype=dtype))
127
128      self._testBinary(
129          gen_math_ops._sqrt_grad,
130          np.array([4, 3, 2, 1], dtype=dtype),
131          np.array([5, 6, 7, 8], dtype=dtype),
132          expected=np.array([0.625, 1, 1.75, 4], dtype=dtype))
133
134      self._testBinary(
135          gen_nn_ops._softplus_grad,
136          np.array([4, 3, 2, 1], dtype=dtype),
137          np.array([5, 6, 7, 8], dtype=dtype),
138          expected=np.array(
139              [3.97322869, 2.99258232, 1.99817801, 0.99966466], dtype=dtype))
140
141      self._testBinary(
142          gen_nn_ops._softsign_grad,
143          np.array([4, 3, 2, 1], dtype=dtype),
144          np.array([5, 6, 7, 8], dtype=dtype),
145          expected=np.array(
146              [0.11111111, 0.06122449, 0.03125, 0.01234568], dtype=dtype))
147
148      self._testBinary(
149          gen_math_ops._tanh_grad,
150          np.array([4, 3, 2, 1], dtype=dtype),
151          np.array([5, 6, 7, 8], dtype=dtype),
152          expected=np.array([-75, -48, -21, 0], dtype=dtype))
153
154      self._testBinary(
155          gen_nn_ops._elu_grad,
156          np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
157          np.array([-.6, -.4, -.2, 0, .2, .4], dtype=dtype),
158          expected=np.array([0.4, 1.2, 2.4, 4, 5, 6], dtype=dtype))
159
160      self._testBinary(
161          gen_nn_ops._selu_grad,
162          np.array([1, 2, 3, 4, 5, 6], dtype=dtype),
163          np.array([-.6, -.4, -.2, .2, .4, .6], dtype=dtype),
164          expected=np.array(
165              [1.158099340847, 2.7161986816948, 4.67429802254,
166               4.202803949422, 5.2535049367774, 6.30420592413], dtype=dtype))
167
168      self._testBinary(
169          gen_nn_ops._relu_grad,
170          np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=dtype),
171          np.array([0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9], dtype=dtype),
172          expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10], dtype=dtype))
173
174      self._testBinary(
175          gen_nn_ops._relu6_grad,
176          np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], dtype=dtype),
177          np.array(
178              [0, 0, 0, 0, 0, 0.1, 0.3, 0.5, 0.7, 0.9, 6.1, 10.0], dtype=dtype),
179          expected=np.array([0, 0, 0, 0, 0, 6, 7, 8, 9, 10, 0, 0], dtype=dtype))
180
181      self._testBinary(
182          gen_nn_ops._softmax_cross_entropy_with_logits,
183          np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=dtype),
184          np.array([[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]], dtype=dtype),
185          expected=[
186              np.array([1.44019, 2.44019], dtype=dtype),
187              np.array([[-0.067941, -0.112856, -0.063117, 0.243914],
188                        [-0.367941, -0.212856, 0.036883, 0.543914]],
189                       dtype=dtype),
190          ],
191          equality_test=self.ListsAreClose)
192
193      self._testBinary(
194          gen_nn_ops._sparse_softmax_cross_entropy_with_logits,
195          np.array([[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
196                    [0.9, 1.0, 1.1, 1.2]], dtype=dtype),
197          np.array([2, 1, 7], dtype=np.int32),
198          expected=[
199              np.array([1.342536, 1.442536, np.nan], dtype=dtype),
200              np.array([[0.213838, 0.236328, -0.738817, 0.288651],
201                        [0.213838, -0.763672, 0.261183, 0.288651],
202                        [np.nan, np.nan, np.nan, np.nan]],
203                       dtype=dtype),
204          ],
205          equality_test=self.ListsAreClose)
206
207  def testIntOps(self):
208    for dtype in self.int_types:
209      self._testBinary(
210          gen_math_ops._truncate_div,
211          np.array([3, 3, -1, -9, -8], dtype=dtype),
212          np.array([2, -2, 7, 2, -4], dtype=dtype),
213          expected=np.array([1, -1, 0, -4, 2], dtype=dtype))
214      self._testSymmetricBinary(
215          bitwise_ops.bitwise_and,
216          np.array([0b1, 0b101, 0b1000], dtype=dtype),
217          np.array([0b0, 0b101, 0b1001], dtype=dtype),
218          expected=np.array([0b0, 0b101, 0b1000], dtype=dtype))
219      self._testSymmetricBinary(
220          bitwise_ops.bitwise_or,
221          np.array([0b1, 0b101, 0b1000], dtype=dtype),
222          np.array([0b0, 0b101, 0b1001], dtype=dtype),
223          expected=np.array([0b1, 0b101, 0b1001], dtype=dtype))
224
225      lhs = np.array([0, 5, 3, 14], dtype=dtype)
226      rhs = np.array([5, 0, 7, 11], dtype=dtype)
227      self._testBinary(
228          bitwise_ops.left_shift, lhs, rhs,
229          expected=np.left_shift(lhs, rhs))
230      self._testBinary(
231          bitwise_ops.right_shift, lhs, rhs,
232          expected=np.right_shift(lhs, rhs))
233
234      if dtype in [np.int8, np.int16, np.int32, np.int64]:
235        lhs = np.array([-1, -5, -3, -14], dtype=dtype)
236        rhs = np.array([5, 0, 1, 11], dtype=dtype)
237        self._testBinary(
238            bitwise_ops.right_shift, lhs, rhs,
239            expected=np.right_shift(lhs, rhs))
240
241  def testNumericOps(self):
242    for dtype in self.numeric_types:
243      self._testBinary(
244          math_ops.add,
245          np.array([1, 2], dtype=dtype),
246          np.array([10, 20], dtype=dtype),
247          expected=np.array([11, 22], dtype=dtype))
248      self._testBinary(
249          math_ops.add,
250          dtype(5),
251          np.array([1, 2], dtype=dtype),
252          expected=np.array([6, 7], dtype=dtype))
253      self._testBinary(
254          math_ops.add,
255          np.array([[1], [2]], dtype=dtype),
256          dtype(7),
257          expected=np.array([[8], [9]], dtype=dtype))
258
259      self._testBinary(
260          math_ops.subtract,
261          np.array([1, 2], dtype=dtype),
262          np.array([10, 20], dtype=dtype),
263          expected=np.array([-9, -18], dtype=dtype))
264      self._testBinary(
265          math_ops.subtract,
266          dtype(5),
267          np.array([1, 2], dtype=dtype),
268          expected=np.array([4, 3], dtype=dtype))
269      self._testBinary(
270          math_ops.subtract,
271          np.array([[1], [2]], dtype=dtype),
272          dtype(7),
273          expected=np.array([[-6], [-5]], dtype=dtype))
274
275      if dtype not in self.complex_types:  # min/max not supported for complex
276        self._testBinary(
277            math_ops.maximum,
278            np.array([1, 2], dtype=dtype),
279            np.array([10, 20], dtype=dtype),
280            expected=np.array([10, 20], dtype=dtype))
281        self._testBinary(
282            math_ops.maximum,
283            dtype(5),
284            np.array([1, 20], dtype=dtype),
285            expected=np.array([5, 20], dtype=dtype))
286        self._testBinary(
287            math_ops.maximum,
288            np.array([[10], [2]], dtype=dtype),
289            dtype(7),
290            expected=np.array([[10], [7]], dtype=dtype))
291
292        self._testBinary(
293            math_ops.minimum,
294            np.array([1, 20], dtype=dtype),
295            np.array([10, 2], dtype=dtype),
296            expected=np.array([1, 2], dtype=dtype))
297        self._testBinary(
298            math_ops.minimum,
299            dtype(5),
300            np.array([1, 20], dtype=dtype),
301            expected=np.array([1, 5], dtype=dtype))
302        self._testBinary(
303            math_ops.minimum,
304            np.array([[10], [2]], dtype=dtype),
305            dtype(7),
306            expected=np.array([[7], [2]], dtype=dtype))
307
308      self._testBinary(
309          math_ops.multiply,
310          np.array([1, 20], dtype=dtype),
311          np.array([10, 2], dtype=dtype),
312          expected=np.array([10, 40], dtype=dtype))
313      self._testBinary(
314          math_ops.multiply,
315          dtype(5),
316          np.array([1, 20], dtype=dtype),
317          expected=np.array([5, 100], dtype=dtype))
318      self._testBinary(
319          math_ops.multiply,
320          np.array([[10], [2]], dtype=dtype),
321          dtype(7),
322          expected=np.array([[70], [14]], dtype=dtype))
323
324      # Complex support for squared_difference is incidental, see b/68205550
325      if dtype not in self.complex_types:
326        self._testBinary(
327            math_ops.squared_difference,
328            np.array([1, 2], dtype=dtype),
329            np.array([10, 20], dtype=dtype),
330            expected=np.array([81, 324], dtype=dtype))
331        self._testBinary(
332            math_ops.squared_difference,
333            dtype(5),
334            np.array([1, 2], dtype=dtype),
335            expected=np.array([16, 9], dtype=dtype))
336        self._testBinary(
337            math_ops.squared_difference,
338            np.array([[1], [2]], dtype=dtype),
339            dtype(7),
340            expected=np.array([[36], [25]], dtype=dtype))
341
342      self._testBinary(
343          nn_ops.bias_add,
344          np.array([[1, 2], [3, 4]], dtype=dtype),
345          np.array([2, -1], dtype=dtype),
346          expected=np.array([[3, 1], [5, 3]], dtype=dtype))
347      self._testBinary(
348          nn_ops.bias_add,
349          np.array([[[[1, 2], [3, 4]]]], dtype=dtype),
350          np.array([2, -1], dtype=dtype),
351          expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype))
352
353  def testComplexOps(self):
354    for dtype in self.complex_types:
355      ctypes = {np.complex64: np.float32}
356      self._testBinary(
357          math_ops.complex,
358          np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]),
359          np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]),
360          expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype))
361
362      self._testBinary(
363          lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001),
364          np.array(
365              [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]],
366              dtype=dtype),
367          np.array(
368              [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype),
369          expected=np.array([[[[False, True], [True, False]]]], dtype=dtype))
370
371      self._testBinary(
372          gen_math_ops._real_div,
373          np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype),
374          np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype),
375          expected=np.array(
376              [1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2],
377              dtype=dtype))
378
379      # Test inf/nan scenarios.
380      self._testBinary(
381          gen_math_ops._real_div,
382          np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype),
383          np.array([0, 0, 0, 0, 0, 0], dtype=dtype),
384          expected=np.array(
385              [
386                  dtype(1 + 1j) / 0,
387                  dtype(1) / 0,
388                  dtype(1j) / 0,
389                  dtype(-1) / 0,
390                  dtype(-1j) / 0,
391                  dtype(1 - 1j) / 0
392              ],
393              dtype=dtype))
394
395      self._testBinary(
396          math_ops.pow,
397          dtype(3 + 2j),
398          dtype(4 - 5j),
399          expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
400      self._testBinary(  # empty rhs
401          math_ops.pow,
402          np.array([1 + 2j, 2 - 3j], dtype=dtype),
403          np.zeros(shape=[0, 2], dtype=dtype),
404          expected=np.zeros(shape=[0, 2], dtype=dtype))
405      self._testBinary(  # to zero power
406          math_ops.pow,
407          np.array([1 + 2j, 2 - 3j], dtype=dtype),
408          np.zeros(shape=[1, 2], dtype=dtype),
409          expected=np.ones(shape=[1, 2], dtype=dtype))
410      lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
411      rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
412      scalar = dtype(2 + 2j)
413      self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
414      self._testBinary(
415          math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
416      self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar))
417
418      lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
419      rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
420      self._testBinary(
421          gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs)
422
423      self._testBinary(
424          gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))
425
426      self._testBinary(
427          gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)
428
429      self._testBinary(
430          gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
431
432      self._testBinary(
433          gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs))
434
435  def testComplexMath(self):
436    for dtype in self.complex_types:
437      self._testBinary(
438          math_ops.add,
439          np.array([1 + 3j, 2 + 7j], dtype=dtype),
440          np.array([10 - 4j, 20 + 17j], dtype=dtype),
441          expected=np.array([11 - 1j, 22 + 24j], dtype=dtype))
442      self._testBinary(
443          math_ops.add,
444          dtype(5 - 7j),
445          np.array([1 + 2j, 2 + 4j], dtype=dtype),
446          expected=np.array([6 - 5j, 7 - 3j], dtype=dtype))
447      self._testBinary(
448          math_ops.add,
449          np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
450          dtype(7 + 5j),
451          expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype))
452
453      self._testBinary(
454          math_ops.subtract,
455          np.array([1 + 3j, 2 + 7j], dtype=dtype),
456          np.array([10 - 4j, 20 + 17j], dtype=dtype),
457          expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype))
458      self._testBinary(
459          math_ops.subtract,
460          dtype(5 - 7j),
461          np.array([1 + 2j, 2 + 4j], dtype=dtype),
462          expected=np.array([4 - 9j, 3 - 11j], dtype=dtype))
463      self._testBinary(
464          math_ops.subtract,
465          np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
466          dtype(7 + 5j),
467          expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype))
468
469      self._testBinary(
470          math_ops.multiply,
471          np.array([1 + 3j, 2 + 7j], dtype=dtype),
472          np.array([10 - 4j, 20 + 17j], dtype=dtype),
473          expected=np.array(
474              [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype))
475      self._testBinary(
476          math_ops.multiply,
477          dtype(5 - 7j),
478          np.array([1 + 2j, 2 + 4j], dtype=dtype),
479          expected=np.array(
480              [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype))
481      self._testBinary(
482          math_ops.multiply,
483          np.array([[1 - 2j], [2 + 1j]], dtype=dtype),
484          dtype(7 + 5j),
485          expected=np.array(
486              [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype))
487
488      self._testBinary(
489          math_ops.div,
490          np.array([8 - 1j, 2 + 16j], dtype=dtype),
491          np.array([2 + 4j, 4 - 8j], dtype=dtype),
492          expected=np.array(
493              [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype))
494      self._testBinary(
495          math_ops.div,
496          dtype(1 + 2j),
497          np.array([2 + 4j, 4 - 8j], dtype=dtype),
498          expected=np.array(
499              [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype))
500      self._testBinary(
501          math_ops.div,
502          np.array([2 + 4j, 4 - 8j], dtype=dtype),
503          dtype(1 + 2j),
504          expected=np.array(
505              [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype))
506
507      # TODO(b/68205550): math_ops.squared_difference shouldn't be supported.
508
509      self._testBinary(
510          nn_ops.bias_add,
511          np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype),
512          np.array([2 + 6j, -1 - 3j], dtype=dtype),
513          expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype))
514      self._testBinary(
515          nn_ops.bias_add,
516          np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype),
517          np.array([2 + 1j, -1 + 2j], dtype=dtype),
518          expected=np.array(
519              [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype))
520
521  def _testDivision(self, dtype):
522    """Test cases for division operators."""
523    self._testBinary(
524        math_ops.div,
525        np.array([10, 20], dtype=dtype),
526        np.array([10, 2], dtype=dtype),
527        expected=np.array([1, 10], dtype=dtype))
528    self._testBinary(
529        math_ops.div,
530        dtype(40),
531        np.array([2, 20], dtype=dtype),
532        expected=np.array([20, 2], dtype=dtype))
533    self._testBinary(
534        math_ops.div,
535        np.array([[10], [4]], dtype=dtype),
536        dtype(2),
537        expected=np.array([[5], [2]], dtype=dtype))
538
539    if dtype not in self.complex_types:  # floordiv unsupported for complex.
540      self._testBinary(
541          gen_math_ops._floor_div,
542          np.array([3, 3, -1, -9, -8], dtype=dtype),
543          np.array([2, -2, 7, 2, -4], dtype=dtype),
544          expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
545
546  def testIntDivision(self):
547    for dtype in self.int_types:
548      self._testDivision(dtype)
549
550  def testFloatDivision(self):
551    for dtype in self.float_types | self.complex_types:
552      self._testDivision(dtype)
553
554  def _testRemainder(self, dtype):
555    """Test cases for remainder operators."""
556    self._testBinary(
557        gen_math_ops._floor_mod,
558        np.array([3, 3, -1, -8], dtype=dtype),
559        np.array([2, -2, 7, -4], dtype=dtype),
560        expected=np.array([1, -1, 6, 0], dtype=dtype))
561    self._testBinary(
562        gen_math_ops._truncate_mod,
563        np.array([3, 3, -1, -8], dtype=dtype),
564        np.array([2, -2, 7, -4], dtype=dtype),
565        expected=np.array([1, 1, -1, 0], dtype=dtype))
566
567  def testIntRemainder(self):
568    for dtype in self.int_types:
569      self._testRemainder(dtype)
570
571  def testFloatRemainder(self):
572    for dtype in self.float_types:
573      self._testRemainder(dtype)
574
575  def testLogicalOps(self):
576    self._testBinary(
577        math_ops.logical_and,
578        np.array([[True, False], [False, True]], dtype=np.bool),
579        np.array([[False, True], [False, True]], dtype=np.bool),
580        expected=np.array([[False, False], [False, True]], dtype=np.bool))
581
582    self._testBinary(
583        math_ops.logical_or,
584        np.array([[True, False], [False, True]], dtype=np.bool),
585        np.array([[False, True], [False, True]], dtype=np.bool),
586        expected=np.array([[True, True], [False, True]], dtype=np.bool))
587
588  def testComparisons(self):
589    self._testBinary(
590        math_ops.equal,
591        np.array([1, 5, 20], dtype=np.float32),
592        np.array([10, 5, 2], dtype=np.float32),
593        expected=np.array([False, True, False], dtype=np.bool))
594    self._testBinary(
595        math_ops.equal,
596        np.float32(5),
597        np.array([1, 5, 20], dtype=np.float32),
598        expected=np.array([False, True, False], dtype=np.bool))
599    self._testBinary(
600        math_ops.equal,
601        np.array([[10], [7], [2]], dtype=np.float32),
602        np.float32(7),
603        expected=np.array([[False], [True], [False]], dtype=np.bool))
604
605    self._testBinary(
606        math_ops.not_equal,
607        np.array([1, 5, 20], dtype=np.float32),
608        np.array([10, 5, 2], dtype=np.float32),
609        expected=np.array([True, False, True], dtype=np.bool))
610    self._testBinary(
611        math_ops.not_equal,
612        np.float32(5),
613        np.array([1, 5, 20], dtype=np.float32),
614        expected=np.array([True, False, True], dtype=np.bool))
615    self._testBinary(
616        math_ops.not_equal,
617        np.array([[10], [7], [2]], dtype=np.float32),
618        np.float32(7),
619        expected=np.array([[True], [False], [True]], dtype=np.bool))
620
621    for greater_op in [math_ops.greater, (lambda x, y: x > y)]:
622      self._testBinary(
623          greater_op,
624          np.array([1, 5, 20], dtype=np.float32),
625          np.array([10, 5, 2], dtype=np.float32),
626          expected=np.array([False, False, True], dtype=np.bool))
627      self._testBinary(
628          greater_op,
629          np.float32(5),
630          np.array([1, 5, 20], dtype=np.float32),
631          expected=np.array([True, False, False], dtype=np.bool))
632      self._testBinary(
633          greater_op,
634          np.array([[10], [7], [2]], dtype=np.float32),
635          np.float32(7),
636          expected=np.array([[True], [False], [False]], dtype=np.bool))
637
638    for greater_equal_op in [math_ops.greater_equal, (lambda x, y: x >= y)]:
639      self._testBinary(
640          greater_equal_op,
641          np.array([1, 5, 20], dtype=np.float32),
642          np.array([10, 5, 2], dtype=np.float32),
643          expected=np.array([False, True, True], dtype=np.bool))
644      self._testBinary(
645          greater_equal_op,
646          np.float32(5),
647          np.array([1, 5, 20], dtype=np.float32),
648          expected=np.array([True, True, False], dtype=np.bool))
649      self._testBinary(
650          greater_equal_op,
651          np.array([[10], [7], [2]], dtype=np.float32),
652          np.float32(7),
653          expected=np.array([[True], [True], [False]], dtype=np.bool))
654
655    for less_op in [math_ops.less, (lambda x, y: x < y)]:
656      self._testBinary(
657          less_op,
658          np.array([1, 5, 20], dtype=np.float32),
659          np.array([10, 5, 2], dtype=np.float32),
660          expected=np.array([True, False, False], dtype=np.bool))
661      self._testBinary(
662          less_op,
663          np.float32(5),
664          np.array([1, 5, 20], dtype=np.float32),
665          expected=np.array([False, False, True], dtype=np.bool))
666      self._testBinary(
667          less_op,
668          np.array([[10], [7], [2]], dtype=np.float32),
669          np.float32(7),
670          expected=np.array([[False], [False], [True]], dtype=np.bool))
671
672    for less_equal_op in [math_ops.less_equal, (lambda x, y: x <= y)]:
673      self._testBinary(
674          less_equal_op,
675          np.array([1, 5, 20], dtype=np.float32),
676          np.array([10, 5, 2], dtype=np.float32),
677          expected=np.array([True, True, False], dtype=np.bool))
678      self._testBinary(
679          less_equal_op,
680          np.float32(5),
681          np.array([1, 5, 20], dtype=np.float32),
682          expected=np.array([False, True, True], dtype=np.bool))
683      self._testBinary(
684          less_equal_op,
685          np.array([[10], [7], [2]], dtype=np.float32),
686          np.float32(7),
687          expected=np.array([[False], [True], [True]], dtype=np.bool))
688
689  def testBroadcasting(self):
690    """Tests broadcasting behavior of an operator."""
691
692    for dtype in self.numeric_types:
693      self._testBinary(
694          math_ops.add,
695          np.array(3, dtype=dtype),
696          np.array([10, 20], dtype=dtype),
697          expected=np.array([13, 23], dtype=dtype))
698      self._testBinary(
699          math_ops.add,
700          np.array([10, 20], dtype=dtype),
701          np.array(4, dtype=dtype),
702          expected=np.array([14, 24], dtype=dtype))
703
704      # [1,3] x [4,1] => [4,3]
705      self._testBinary(
706          math_ops.add,
707          np.array([[10, 20, 30]], dtype=dtype),
708          np.array([[1], [2], [3], [4]], dtype=dtype),
709          expected=np.array(
710              [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
711              dtype=dtype))
712
713      # [3] * [4,1] => [4,3]
714      self._testBinary(
715          math_ops.add,
716          np.array([10, 20, 30], dtype=dtype),
717          np.array([[1], [2], [3], [4]], dtype=dtype),
718          expected=np.array(
719              [[11, 21, 31], [12, 22, 32], [13, 23, 33], [14, 24, 34]],
720              dtype=dtype))
721
722  def testFill(self):
723    for dtype in self.numeric_types:
724      self._testBinary(
725          array_ops.fill,
726          np.array([], dtype=np.int32),
727          dtype(-42),
728          expected=dtype(-42))
729      self._testBinary(
730          array_ops.fill,
731          np.array([1, 2], dtype=np.int32),
732          dtype(7),
733          expected=np.array([[7, 7]], dtype=dtype))
734      self._testBinary(
735          array_ops.fill,
736          np.array([3, 2], dtype=np.int32),
737          dtype(50),
738          expected=np.array([[50, 50], [50, 50], [50, 50]], dtype=dtype))
739
740  # Helper method used by testMatMul, testSparseMatMul, testBatchMatMul below.
741  def _testMatMul(self, op):
742    for dtype in self.float_types:
743      self._testBinary(
744          op,
745          np.array([[-0.25]], dtype=dtype),
746          np.array([[8]], dtype=dtype),
747          expected=np.array([[-2]], dtype=dtype))
748      self._testBinary(
749          op,
750          np.array([[100, 10, 0.5]], dtype=dtype),
751          np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
752          expected=np.array([[123, 354]], dtype=dtype))
753      self._testBinary(
754          op,
755          np.array([[1, 3], [2, 5], [6, 8]], dtype=dtype),
756          np.array([[100], [10]], dtype=dtype),
757          expected=np.array([[130], [250], [680]], dtype=dtype))
758      self._testBinary(
759          op,
760          np.array([[1000, 100], [10, 1]], dtype=dtype),
761          np.array([[1, 2], [3, 4]], dtype=dtype),
762          expected=np.array([[1300, 2400], [13, 24]], dtype=dtype))
763
764      self._testBinary(
765          op,
766          np.array([], dtype=dtype).reshape((2, 0)),
767          np.array([], dtype=dtype).reshape((0, 3)),
768          expected=np.array([[0, 0, 0], [0, 0, 0]], dtype=dtype))
769
770  def testMatMul(self):
771    self._testMatMul(math_ops.matmul)
772
773  # TODO(phawkins): failing on GPU, no registered kernel.
774  def DISABLED_testSparseMatMul(self):
775    # Binary wrappers for sparse_matmul with different hints
776    def SparseMatmulWrapperTF(a, b):
777      return math_ops.sparse_matmul(a, b, a_is_sparse=True)
778
779    def SparseMatmulWrapperFT(a, b):
780      return math_ops.sparse_matmul(a, b, b_is_sparse=True)
781
782    def SparseMatmulWrapperTT(a, b):
783      return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True)
784
785    self._testMatMul(math_ops.sparse_matmul)
786    self._testMatMul(SparseMatmulWrapperTF)
787    self._testMatMul(SparseMatmulWrapperFT)
788    self._testMatMul(SparseMatmulWrapperTT)
789
790  def testBatchMatMul(self):
791    # Same tests as for tf.matmul above.
792    self._testMatMul(math_ops.matmul)
793
794    # Tests with batches of matrices.
795    self._testBinary(
796        math_ops.matmul,
797        np.array([[[-0.25]]], dtype=np.float32),
798        np.array([[[8]]], dtype=np.float32),
799        expected=np.array([[[-2]]], dtype=np.float32))
800    self._testBinary(
801        math_ops.matmul,
802        np.array([[[-0.25]], [[4]]], dtype=np.float32),
803        np.array([[[8]], [[2]]], dtype=np.float32),
804        expected=np.array([[[-2]], [[8]]], dtype=np.float32))
805    self._testBinary(
806        math_ops.matmul,
807        np.array(
808            [[[[7, 13], [10, 1]], [[2, 0.25], [20, 2]]],
809             [[[3, 5], [30, 3]], [[0.75, 1], [40, 4]]]],
810            dtype=np.float32),
811        np.array(
812            [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]], [[[11, 22], [33, 44]],
813                                                    [[55, 66], [77, 88]]]],
814            dtype=np.float32),
815        expected=np.array(
816            [[[[46, 66], [13, 24]], [[11.75, 14], [114, 136]]],
817             [[[198, 286], [429, 792]], [[118.25, 137.5], [2508, 2992]]]],
818            dtype=np.float32))
819
820    self._testBinary(
821        math_ops.matmul,
822        np.array([], dtype=np.float32).reshape((2, 2, 0)),
823        np.array([], dtype=np.float32).reshape((2, 0, 3)),
824        expected=np.array(
825            [[[0, 0, 0], [0, 0, 0]], [[0, 0, 0], [0, 0, 0]]],
826            dtype=np.float32))
827    self._testBinary(
828        math_ops.matmul,
829        np.array([], dtype=np.float32).reshape((0, 2, 4)),
830        np.array([], dtype=np.float32).reshape((0, 4, 3)),
831        expected=np.array([], dtype=np.float32).reshape(0, 2, 3))
832
833    # Regression test for b/31472796.
834    if hasattr(np, "matmul"):
835      x = np.arange(0, 3 * 5 * 2 * 7, dtype=np.float32).reshape((3, 5, 2, 7))
836      self._testBinary(
837          lambda x, y: math_ops.matmul(x, y, adjoint_b=True),
838          x, x,
839          expected=np.matmul(x, x.transpose([0, 1, 3, 2])))
840
841  def testExpandDims(self):
842    for dtype in self.numeric_types:
843      self._testBinary(
844          array_ops.expand_dims,
845          dtype(7),
846          np.int32(0),
847          expected=np.array([7], dtype=dtype))
848      self._testBinary(
849          array_ops.expand_dims,
850          np.array([42], dtype=dtype),
851          np.int32(0),
852          expected=np.array([[42]], dtype=dtype))
853      self._testBinary(
854          array_ops.expand_dims,
855          np.array([], dtype=dtype),
856          np.int32(0),
857          expected=np.array([[]], dtype=dtype))
858      self._testBinary(
859          array_ops.expand_dims,
860          np.array([[[1, 2], [3, 4]]], dtype=dtype),
861          np.int32(0),
862          expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
863      self._testBinary(
864          array_ops.expand_dims,
865          np.array([[[1, 2], [3, 4]]], dtype=dtype),
866          np.int32(1),
867          expected=np.array([[[[1, 2], [3, 4]]]], dtype=dtype))
868      self._testBinary(
869          array_ops.expand_dims,
870          np.array([[[1, 2], [3, 4]]], dtype=dtype),
871          np.int32(2),
872          expected=np.array([[[[1, 2]], [[3, 4]]]], dtype=dtype))
873      self._testBinary(
874          array_ops.expand_dims,
875          np.array([[[1, 2], [3, 4]]], dtype=dtype),
876          np.int32(3),
877          expected=np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype))
878
879  def testPad(self):
880    for dtype in self.numeric_types:
881      self._testBinary(
882          array_ops.pad,
883          np.array(
884              [[1, 2, 3], [4, 5, 6]], dtype=dtype),
885          np.array(
886              [[1, 2], [2, 1]], dtype=np.int32),
887          expected=np.array(
888              [[0, 0, 0, 0, 0, 0],
889               [0, 0, 1, 2, 3, 0],
890               [0, 0, 4, 5, 6, 0],
891               [0, 0, 0, 0, 0, 0],
892               [0, 0, 0, 0, 0, 0]],
893              dtype=dtype))
894
895      self._testBinary(
896          lambda x, y: array_ops.pad(x, y, constant_values=7),
897          np.array(
898              [[1, 2, 3], [4, 5, 6]], dtype=dtype),
899          np.array(
900              [[0, 3], [2, 1]], dtype=np.int32),
901          expected=np.array(
902              [[7, 7, 1, 2, 3, 7],
903               [7, 7, 4, 5, 6, 7],
904               [7, 7, 7, 7, 7, 7],
905               [7, 7, 7, 7, 7, 7],
906               [7, 7, 7, 7, 7, 7]],
907              dtype=dtype))
908
909  def testMirrorPad(self):
910    mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT")
911    for dtype in self.numeric_types:
912      self._testBinary(
913          mirror_pad,
914          np.array(
915              [
916                  [1, 2, 3],  #
917                  [4, 5, 6],  #
918              ],
919              dtype=dtype),
920          np.array([[
921              1,
922              1,
923          ], [2, 2]], dtype=np.int32),
924          expected=np.array(
925              [
926                  [6, 5, 4, 5, 6, 5, 4],  #
927                  [3, 2, 1, 2, 3, 2, 1],  #
928                  [6, 5, 4, 5, 6, 5, 4],  #
929                  [3, 2, 1, 2, 3, 2, 1]
930              ],
931              dtype=dtype))
932      self._testBinary(
933          mirror_pad,
934          np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype),
935          np.array([[0, 0], [0, 0]], dtype=np.int32),
936          expected=np.array([[1, 2, 3], [4, 5, 6]], dtype=dtype))
937      self._testBinary(
938          mirror_pad,
939          np.array(
940              [
941                  [1, 2, 3],  #
942                  [4, 5, 6],  #
943                  [7, 8, 9]
944              ],
945              dtype=dtype),
946          np.array([[2, 2], [0, 0]], dtype=np.int32),
947          expected=np.array(
948              [
949                  [7, 8, 9],  #
950                  [4, 5, 6],  #
951                  [1, 2, 3],  #
952                  [4, 5, 6],  #
953                  [7, 8, 9],  #
954                  [4, 5, 6],  #
955                  [1, 2, 3]
956              ],
957              dtype=dtype))
958      self._testBinary(
959          mirror_pad,
960          np.array(
961              [
962                  [[1, 2, 3], [4, 5, 6]],
963                  [[7, 8, 9], [10, 11, 12]],
964              ], dtype=dtype),
965          np.array([[0, 0], [1, 1], [1, 1]], dtype=np.int32),
966          expected=np.array(
967              [
968                  [
969                      [5, 4, 5, 6, 5],  #
970                      [2, 1, 2, 3, 2],  #
971                      [5, 4, 5, 6, 5],  #
972                      [2, 1, 2, 3, 2],  #
973                  ],
974                  [
975                      [11, 10, 11, 12, 11],  #
976                      [8, 7, 8, 9, 8],  #
977                      [11, 10, 11, 12, 11],  #
978                      [8, 7, 8, 9, 8],  #
979                  ]
980              ],
981              dtype=dtype))
982
983  def testReshape(self):
984    for dtype in self.numeric_types:
985      self._testBinary(
986          array_ops.reshape,
987          np.array([], dtype=dtype),
988          np.array([0, 4], dtype=np.int32),
989          expected=np.zeros(shape=[0, 4], dtype=dtype))
990      self._testBinary(
991          array_ops.reshape,
992          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
993          np.array([2, 3], dtype=np.int32),
994          expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
995      self._testBinary(
996          array_ops.reshape,
997          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
998          np.array([3, 2], dtype=np.int32),
999          expected=np.array([[0, 1], [2, 3], [4, 5]], dtype=dtype))
1000      self._testBinary(
1001          array_ops.reshape,
1002          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
1003          np.array([-1, 6], dtype=np.int32),
1004          expected=np.array([[0, 1, 2, 3, 4, 5]], dtype=dtype))
1005      self._testBinary(
1006          array_ops.reshape,
1007          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
1008          np.array([6, -1], dtype=np.int32),
1009          expected=np.array([[0], [1], [2], [3], [4], [5]], dtype=dtype))
1010      self._testBinary(
1011          array_ops.reshape,
1012          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
1013          np.array([2, -1], dtype=np.int32),
1014          expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
1015      self._testBinary(
1016          array_ops.reshape,
1017          np.array([0, 1, 2, 3, 4, 5], dtype=dtype),
1018          np.array([-1, 3], dtype=np.int32),
1019          expected=np.array([[0, 1, 2], [3, 4, 5]], dtype=dtype))
1020
1021  def testSplit(self):
1022    for dtype in self.numeric_types:
1023      for axis in [0, -3]:
1024        self._testBinary(
1025            lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x),
1026            np.int32(axis),
1027            np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
1028                     dtype=dtype),
1029            expected=[
1030                np.array([[[1], [2]]], dtype=dtype),
1031                np.array([[[3], [4]]], dtype=dtype),
1032                np.array([[[5], [6]]], dtype=dtype),
1033            ],
1034            equality_test=self.ListsAreClose)
1035
1036      for axis in [1, -2]:
1037        self._testBinary(
1038            lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x),
1039            np.int32(axis),
1040            np.array([[[1], [2]], [[3], [4]], [[5], [6]]],
1041                     dtype=dtype),
1042            expected=[
1043                np.array([[[1]], [[3]], [[5]]], dtype=dtype),
1044                np.array([[[2]], [[4]], [[6]]], dtype=dtype),
1045            ],
1046            equality_test=self.ListsAreClose)
1047
1048  def testTile(self):
1049    for dtype in self.numeric_types:
1050      self._testBinary(
1051          array_ops.tile,
1052          np.array([[6]], dtype=dtype),
1053          np.array([1, 2], dtype=np.int32),
1054          expected=np.array([[6, 6]], dtype=dtype))
1055      self._testBinary(
1056          array_ops.tile,
1057          np.array([[1], [2]], dtype=dtype),
1058          np.array([1, 2], dtype=np.int32),
1059          expected=np.array([[1, 1], [2, 2]], dtype=dtype))
1060      self._testBinary(
1061          array_ops.tile,
1062          np.array([[1, 2], [3, 4]], dtype=dtype),
1063          np.array([3, 2], dtype=np.int32),
1064          expected=np.array(
1065              [[1, 2, 1, 2],
1066               [3, 4, 3, 4],
1067               [1, 2, 1, 2],
1068               [3, 4, 3, 4],
1069               [1, 2, 1, 2],
1070               [3, 4, 3, 4]],
1071              dtype=dtype))
1072      self._testBinary(
1073          array_ops.tile,
1074          np.array([[1, 2], [3, 4]], dtype=dtype),
1075          np.array([1, 1], dtype=np.int32),
1076          expected=np.array(
1077              [[1, 2],
1078               [3, 4]],
1079              dtype=dtype))
1080      self._testBinary(
1081          array_ops.tile,
1082          np.array([[1, 2]], dtype=dtype),
1083          np.array([3, 1], dtype=np.int32),
1084          expected=np.array(
1085              [[1, 2],
1086               [1, 2],
1087               [1, 2]],
1088              dtype=dtype))
1089
1090  def testTranspose(self):
1091    for dtype in self.numeric_types:
1092      self._testBinary(
1093          array_ops.transpose,
1094          np.zeros(shape=[1, 0, 4], dtype=dtype),
1095          np.array([1, 2, 0], dtype=np.int32),
1096          expected=np.zeros(shape=[0, 4, 1], dtype=dtype))
1097      self._testBinary(
1098          array_ops.transpose,
1099          np.array([[1, 2], [3, 4]], dtype=dtype),
1100          np.array([0, 1], dtype=np.int32),
1101          expected=np.array([[1, 2], [3, 4]], dtype=dtype))
1102      self._testBinary(
1103          array_ops.transpose,
1104          np.array([[1, 2], [3, 4]], dtype=dtype),
1105          np.array([1, 0], dtype=np.int32),
1106          expected=np.array([[1, 3], [2, 4]], dtype=dtype))
1107
1108  def testCross(self):
1109    for dtype in self.float_types:
1110      self._testBinary(
1111          gen_math_ops.cross,
1112          np.zeros((4, 3), dtype=dtype),
1113          np.zeros((4, 3), dtype=dtype),
1114          expected=np.zeros((4, 3), dtype=dtype))
1115      self._testBinary(
1116          gen_math_ops.cross,
1117          np.array([1, 2, 3], dtype=dtype),
1118          np.array([4, 5, 6], dtype=dtype),
1119          expected=np.array([-3, 6, -3], dtype=dtype))
1120      self._testBinary(
1121          gen_math_ops.cross,
1122          np.array([[1, 2, 3], [10, 11, 12]], dtype=dtype),
1123          np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype),
1124          expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype))
1125
1126  def testBroadcastArgs(self):
1127    self._testBinary(array_ops.broadcast_dynamic_shape,
1128                     np.array([2, 3, 5], dtype=np.int32),
1129                     np.array([1], dtype=np.int32),
1130                     expected=np.array([2, 3, 5], dtype=np.int32))
1131
1132    self._testBinary(array_ops.broadcast_dynamic_shape,
1133                     np.array([1], dtype=np.int32),
1134                     np.array([2, 3, 5], dtype=np.int32),
1135                     expected=np.array([2, 3, 5], dtype=np.int32))
1136
1137    self._testBinary(array_ops.broadcast_dynamic_shape,
1138                     np.array([2, 3, 5], dtype=np.int32),
1139                     np.array([5], dtype=np.int32),
1140                     expected=np.array([2, 3, 5], dtype=np.int32))
1141
1142    self._testBinary(array_ops.broadcast_dynamic_shape,
1143                     np.array([5], dtype=np.int32),
1144                     np.array([2, 3, 5], dtype=np.int32),
1145                     expected=np.array([2, 3, 5], dtype=np.int32))
1146
1147    self._testBinary(array_ops.broadcast_dynamic_shape,
1148                     np.array([2, 3, 5], dtype=np.int32),
1149                     np.array([3, 5], dtype=np.int32),
1150                     expected=np.array([2, 3, 5], dtype=np.int32))
1151
1152    self._testBinary(array_ops.broadcast_dynamic_shape,
1153                     np.array([3, 5], dtype=np.int32),
1154                     np.array([2, 3, 5], dtype=np.int32),
1155                     expected=np.array([2, 3, 5], dtype=np.int32))
1156
1157    self._testBinary(array_ops.broadcast_dynamic_shape,
1158                     np.array([2, 3, 5], dtype=np.int32),
1159                     np.array([3, 1], dtype=np.int32),
1160                     expected=np.array([2, 3, 5], dtype=np.int32))
1161
1162    self._testBinary(array_ops.broadcast_dynamic_shape,
1163                     np.array([3, 1], dtype=np.int32),
1164                     np.array([2, 3, 5], dtype=np.int32),
1165                     expected=np.array([2, 3, 5], dtype=np.int32))
1166
1167    self._testBinary(array_ops.broadcast_dynamic_shape,
1168                     np.array([2, 1, 5], dtype=np.int32),
1169                     np.array([3, 1], dtype=np.int32),
1170                     expected=np.array([2, 3, 5], dtype=np.int32))
1171
1172    self._testBinary(array_ops.broadcast_dynamic_shape,
1173                     np.array([3, 1], dtype=np.int32),
1174                     np.array([2, 1, 5], dtype=np.int32),
1175                     expected=np.array([2, 3, 5], dtype=np.int32))
1176
1177    with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError,
1178                                             "Incompatible shapes"):
1179      self._testBinary(array_ops.broadcast_dynamic_shape,
1180                       np.array([1, 2, 3], dtype=np.int32),
1181                       np.array([4, 5, 6], dtype=np.int32),
1182                       expected=None)
1183
1184  def testMatrixSetDiag(self):
1185    for dtype in self.numeric_types:
1186      # Square
1187      self._testBinary(
1188          array_ops.matrix_set_diag,
1189          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]],
1190                   dtype=dtype),
1191          np.array([1.0, 2.0, 3.0], dtype=dtype),
1192          expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]],
1193                            dtype=dtype))
1194
1195      self._testBinary(
1196          array_ops.matrix_set_diag,
1197          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]],
1198                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]],
1199                   dtype=dtype),
1200          np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype),
1201          expected=np.array(
1202              [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]],
1203               [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]],
1204              dtype=dtype))
1205
1206      # Rectangular
1207      self._testBinary(
1208          array_ops.matrix_set_diag,
1209          np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype),
1210          np.array([3.0, 4.0], dtype=dtype),
1211          expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype))
1212
1213      self._testBinary(
1214          array_ops.matrix_set_diag,
1215          np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype),
1216          np.array([3.0, 4.0], dtype=dtype),
1217          expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype))
1218
1219      self._testBinary(
1220          array_ops.matrix_set_diag,
1221          np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]],
1222                    [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype),
1223          np.array([[-1.0, -2.0], [-4.0, -5.0]],
1224                   dtype=dtype),
1225          expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]],
1226                             [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]],
1227                            dtype=dtype))
1228
if __name__ == "__main__":
  # Run the tests through googletest when this file is executed as a script.
  googletest.main()
1231