• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf numpy array methods."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import itertools
22import sys
23import numpy as np
24from six.moves import range
25from six.moves import zip
26
27from tensorflow.python.eager import context
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import config
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import indexed_slices
33from tensorflow.python.framework import ops
34from tensorflow.python.framework import tensor_shape
35from tensorflow.python.framework import tensor_spec
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops.numpy_ops import np_array_ops
38from tensorflow.python.ops.numpy_ops import np_arrays
39from tensorflow.python.ops.numpy_ops import np_math_ops
40from tensorflow.python.platform import test
41
42
# Set once set_up_virtual_devices() has configured the logical CPUs, so later
# calls skip the configuration step.
_virtual_devices_ready = False


def set_up_virtual_devices():
  """Splits the first physical CPU into two logical CPU devices.

  Tests in this file run work under `ops.device('CPU:1')`, which needs a
  second logical CPU. Every call after the first successful one is a no-op.
  """
  global _virtual_devices_ready
  if not _virtual_devices_ready:
    first_cpu = config.list_physical_devices('CPU')[0]
    two_logical_cpus = [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ]
    config.set_logical_device_configuration(first_cpu, two_logical_cpus)
    _virtual_devices_ready = True
57
58
class ArrayCreationTest(test.TestCase):
  """Checks tf-numpy array-creation functions against their NumPy versions."""

  def setUp(self):
    super(ArrayCreationTest, self).setUp()
    set_up_virtual_devices()
    python_shapes = [
        0, 1, 2, (), (1,), (2,), (1, 2, 3), [], [1], [2], [1, 2, 3]
    ]
    # Every shape is also tried as an int np.ndarray, an int tf-numpy ndarray
    # and a TensorShape.
    self.shape_transforms = [
        lambda x: x, lambda x: np.array(x, dtype=int),
        lambda x: np_array_ops.array(x, dtype=int), tensor_shape.TensorShape
    ]

    self.all_shapes = []
    for fn in self.shape_transforms:
      self.all_shapes.extend([fn(s) for s in python_shapes])

    if sys.version_info.major == 3:
      # There is a bug of np.empty (and alike) in Python 3 causing a crash when
      # the `shape` argument is an np_arrays.ndarray scalar (or tf.Tensor
      # scalar).
      def not_ndarray_scalar(s):
        return not (isinstance(s, np_arrays.ndarray) and s.ndim == 0)

      self.all_shapes = list(filter(not_ndarray_scalar, self.all_shapes))

    self.all_types = [
        int, float, np.int16, np.int32, np.int64, np.float16, np.float32,
        np.float64
    ]

    source_array_data = [
        1,
        5.5,
        7,
        (),
        (8, 10.),
        ((), ()),
        ((1, 4), (2, 8)),
        [],
        [7],
        [8, 10.],
        [[], []],
        [[1, 4], [2, 8]],
        ([], []),
        ([1, 4], [2, 8]),
        [(), ()],
        [(1, 4), (2, 8)],
    ]

    # Every source value is also tried as a tf.Tensor, an np.ndarray and a
    # tf-numpy ndarray.
    self.array_transforms = [
        lambda x: x,
        ops.convert_to_tensor,
        np.array,
        np_array_ops.array,
    ]
    self.all_arrays = []
    for fn in self.array_transforms:
      self.all_arrays.extend([fn(s) for s in source_array_data])

  def testEmpty(self):
    # empty() contents are undefined, so only shape and dtype are compared.
    for s in self.all_shapes:
      actual = np_array_ops.empty(s)
      expected = np.empty(s)
      msg = 'shape: {}'.format(s)
      self.match_shape(actual, expected, msg)
      self.match_dtype(actual, expected, msg)

    for s, t in itertools.product(self.all_shapes, self.all_types):
      actual = np_array_ops.empty(s, t)
      expected = np.empty(s, t)
      msg = 'shape: {}, dtype: {}'.format(s, t)
      self.match_shape(actual, expected, msg)
      self.match_dtype(actual, expected, msg)

  def testEmptyLike(self):
    # empty_like() contents are undefined, so only shape and dtype are
    # compared.
    for a in self.all_arrays:
      actual = np_array_ops.empty_like(a)
      expected = np.empty_like(a)
      msg = 'array: {}'.format(a)
      self.match_shape(actual, expected, msg)
      self.match_dtype(actual, expected, msg)

    for a, t in itertools.product(self.all_arrays, self.all_types):
      actual = np_array_ops.empty_like(a, t)
      expected = np.empty_like(a, t)
      msg = 'array: {} type: {}'.format(a, t)
      self.match_shape(actual, expected, msg)
      self.match_dtype(actual, expected, msg)

  def testZeros(self):
    for s in self.all_shapes:
      actual = np_array_ops.zeros(s)
      expected = np.zeros(s)
      msg = 'shape: {}'.format(s)
      self.match(actual, expected, msg)

    for s, t in itertools.product(self.all_shapes, self.all_types):
      actual = np_array_ops.zeros(s, t)
      expected = np.zeros(s, t)
      msg = 'shape: {}, dtype: {}'.format(s, t)
      self.match(actual, expected, msg)

  def testZerosLike(self):
    for a in self.all_arrays:
      actual = np_array_ops.zeros_like(a)
      expected = np.zeros_like(a)
      msg = 'array: {}'.format(a)
      self.match(actual, expected, msg)

    for a, t in itertools.product(self.all_arrays, self.all_types):
      actual = np_array_ops.zeros_like(a, t)
      expected = np.zeros_like(a, t)
      msg = 'array: {} type: {}'.format(a, t)
      self.match(actual, expected, msg)

  def testOnes(self):
    for s in self.all_shapes:
      actual = np_array_ops.ones(s)
      expected = np.ones(s)
      msg = 'shape: {}'.format(s)
      self.match(actual, expected, msg)

    for s, t in itertools.product(self.all_shapes, self.all_types):
      actual = np_array_ops.ones(s, t)
      expected = np.ones(s, t)
      msg = 'shape: {}, dtype: {}'.format(s, t)
      self.match(actual, expected, msg)

  def testOnesLike(self):
    for a in self.all_arrays:
      actual = np_array_ops.ones_like(a)
      expected = np.ones_like(a)
      msg = 'array: {}'.format(a)
      self.match(actual, expected, msg)

    for a, t in itertools.product(self.all_arrays, self.all_types):
      actual = np_array_ops.ones_like(a, t)
      expected = np.ones_like(a, t)
      msg = 'array: {} type: {}'.format(a, t)
      self.match(actual, expected, msg)

  def testEye(self):
    n_max = 3
    m_max = 3

    for n in range(1, n_max + 1):
      self.match(np_array_ops.eye(n), np.eye(n))
      for k in range(-n, n + 1):
        self.match(np_array_ops.eye(n, k=k), np.eye(n, k=k))
      for m in range(1, m_max + 1):
        self.match(np_array_ops.eye(n, m), np.eye(n, m))
        for k in range(-n, m):
          self.match(np_array_ops.eye(n, k=k), np.eye(n, k=k))
          self.match(np_array_ops.eye(n, m, k), np.eye(n, m, k))

    for dtype in self.all_types:
      for n in range(1, n_max + 1):
        self.match(np_array_ops.eye(n, dtype=dtype), np.eye(n, dtype=dtype))
        for k in range(-n, n + 1):
          self.match(
              np_array_ops.eye(n, k=k, dtype=dtype),
              np.eye(n, k=k, dtype=dtype))
        for m in range(1, m_max + 1):
          self.match(
              np_array_ops.eye(n, m, dtype=dtype), np.eye(n, m, dtype=dtype))
          for k in range(-n, m):
            self.match(
                np_array_ops.eye(n, k=k, dtype=dtype),
                np.eye(n, k=k, dtype=dtype))
            self.match(
                np_array_ops.eye(n, m, k, dtype=dtype),
                np.eye(n, m, k, dtype=dtype))

  def testIdentity(self):
    n_max = 3

    for n in range(1, n_max + 1):
      self.match(np_array_ops.identity(n), np.identity(n))

    for dtype in self.all_types:
      for n in range(1, n_max + 1):
        self.match(
            np_array_ops.identity(n, dtype=dtype), np.identity(n, dtype=dtype))

  def testFull(self):
    # List of 2-tuples of fill value and shape.
    data = [
        (5, ()),
        (5, (7,)),
        (5., (7,)),
        ([5, 8], (2,)),
        ([5, 8], (3, 2)),
        ([[5], [8]], (2, 3)),
        ([[5], [8]], (3, 2, 5)),
        ([[5.], [8.]], (3, 2, 5)),
        ([[3, 4], [5, 6], [7, 8]], (3, 3, 2)),
    ]
    for f, s in data:
      for fn1, fn2 in itertools.product(self.array_transforms,
                                        self.shape_transforms):
        fill_value = fn1(f)
        shape = fn2(s)
        self.match(
            np_array_ops.full(shape, fill_value), np.full(shape, fill_value))
        for dtype in self.all_types:
          self.match(
              np_array_ops.full(shape, fill_value, dtype=dtype),
              np.full(shape, fill_value, dtype=dtype))

  def testFullLike(self):
    # List of 2-tuples of fill value and shape.
    data = [
        (5, ()),
        (5, (7,)),
        (5., (7,)),
        ([5, 8], (2,)),
        ([5, 8], (3, 2)),
        ([[5], [8]], (2, 3)),
        ([[5], [8]], (3, 2, 5)),
        ([[5.], [8.]], (3, 2, 5)),
    ]
    zeros_builders = [np_array_ops.zeros, np.zeros]
    for f, s in data:
      for fn1, fn2, arr_dtype in itertools.product(self.array_transforms,
                                                   zeros_builders,
                                                   self.all_types):
        fill_value = fn1(f)
        arr = fn2(s, arr_dtype)
        self.match(
            np_array_ops.full_like(arr, fill_value),
            np.full_like(arr, fill_value))
        for dtype in self.all_types:
          self.match(
              np_array_ops.full_like(arr, fill_value, dtype=dtype),
              np.full_like(arr, fill_value, dtype=dtype))

  def testArray(self):
    ndmins = [0, 1, 2, 5]
    for a, dtype, ndmin, copy in itertools.product(self.all_arrays,
                                                   self.all_types, ndmins,
                                                   [True, False]):
      self.match(
          np_array_ops.array(a, dtype=dtype, ndmin=ndmin, copy=copy),
          np.array(a, dtype=dtype, ndmin=ndmin, copy=copy))

    zeros_list = np_array_ops.zeros(5)

    def test_copy_equal_false():
      # With copy=False and no other constraint, the same instance is returned.
      self.assertIs(np_array_ops.array(zeros_list, copy=False), zeros_list)

      # A new instance is returned if ndmin is not already satisfied.
      self.assertIsNot(
          np_array_ops.array(zeros_list, copy=False, ndmin=2),
          zeros_list)
      self.assertIs(
          np_array_ops.array(zeros_list, copy=False, ndmin=1),
          zeros_list)

      # A new instance is returned if the requested dtype differs.
      self.assertIsNot(
          np_array_ops.array(zeros_list, copy=False, dtype=int),
          zeros_list)
      self.assertIs(
          np_array_ops.array(zeros_list, copy=False, dtype=float),
          zeros_list)

    test_copy_equal_false()
    with ops.device('CPU:1'):
      test_copy_equal_false()

    # copy=True places the result on the device active at call time.
    self.assertNotIn('CPU:1', zeros_list.backing_device)
    with ops.device('CPU:1'):
      self.assertIn(
          'CPU:1', np_array_ops.array(zeros_list, copy=True).backing_device)
      self.assertIn(
          'CPU:1', np_array_ops.array(np.array(0), copy=True).backing_device)

  def testAsArray(self):
    for a, dtype in itertools.product(self.all_arrays, self.all_types):
      self.match(
          np_array_ops.asarray(a, dtype=dtype), np.asarray(a, dtype=dtype))

    zeros_list = np_array_ops.zeros(5)
    # Same instance is returned if no dtype is specified and input is ndarray.
    self.assertIs(np_array_ops.asarray(zeros_list), zeros_list)
    with ops.device('CPU:1'):
      self.assertIs(np_array_ops.asarray(zeros_list), zeros_list)
    # Different instance is returned if dtype is specified and input is ndarray.
    self.assertIsNot(np_array_ops.asarray(zeros_list, dtype=int), zeros_list)

  def testAsAnyArray(self):
    for a, dtype in itertools.product(self.all_arrays, self.all_types):
      self.match(
          np_array_ops.asanyarray(a, dtype=dtype),
          np.asanyarray(a, dtype=dtype))
    zeros_list = np_array_ops.zeros(5)
    # Same instance is returned if no dtype is specified and input is ndarray.
    self.assertIs(np_array_ops.asanyarray(zeros_list), zeros_list)
    with ops.device('CPU:1'):
      self.assertIs(np_array_ops.asanyarray(zeros_list), zeros_list)
    # Different instance is returned if dtype is specified and input is ndarray.
    self.assertIsNot(np_array_ops.asanyarray(zeros_list, dtype=int), zeros_list)

  def testAsContiguousArray(self):
    for a, dtype in itertools.product(self.all_arrays, self.all_types):
      self.match(
          np_array_ops.ascontiguousarray(a, dtype=dtype),
          np.ascontiguousarray(a, dtype=dtype))

  def testARange(self):
    int_values = np.arange(-3, 3).tolist()
    float_values = np.arange(-3.5, 3.5).tolist()
    all_values = int_values + float_values
    for dtype in self.all_types:
      for start in all_values:
        msg = 'dtype:{} start:{}'.format(dtype, start)
        self.match(np_array_ops.arange(start), np.arange(start), msg=msg)
        self.match(
            np_array_ops.arange(start, dtype=dtype),
            np.arange(start, dtype=dtype),
            msg=msg)
        for stop in all_values:
          msg = 'dtype:{} start:{} stop:{}'.format(dtype, start, stop)
          self.match(
              np_array_ops.arange(start, stop), np.arange(start, stop), msg=msg)
          # TODO(srbs): Investigate and remove check.
          # There are some bugs when start or stop is float and dtype is int.
          if not isinstance(start, float) and not isinstance(stop, float):
            self.match(
                np_array_ops.arange(start, stop, dtype=dtype),
                np.arange(start, stop, dtype=dtype),
                msg=msg)
          # Note: We intentionally do not test with float values for step
          # because numpy.arange itself returns inconsistent results. e.g.
          # np.arange(0.5, 3, step=0.5, dtype=int) returns
          # array([0, 1, 2, 3, 4])
          for step in int_values:
            msg = 'dtype:{} start:{} stop:{} step:{}'.format(
                dtype, start, stop, step)
            if not step:
              # A zero step must be rejected. Each call gets its own
              # assertRaises context: the first raise leaves the `with` block,
              # so a second call inside the same context would never run.
              with self.assertRaises(ValueError, msg=msg):
                np_array_ops.arange(start, stop, step)
              if not isinstance(start, float) and not isinstance(stop, float):
                with self.assertRaises(ValueError, msg=msg):
                  np_array_ops.arange(start, stop, step, dtype=dtype)
            else:
              self.match(
                  np_array_ops.arange(start, stop, step),
                  np.arange(start, stop, step),
                  msg=msg)
              if not isinstance(start, float) and not isinstance(stop, float):
                self.match(
                    np_array_ops.arange(start, stop, step, dtype=dtype),
                    np.arange(start, stop, step, dtype=dtype),
                    msg=msg)

  def testDiag(self):
    array_transforms = [
        lambda x: x,  # Identity,
        ops.convert_to_tensor,
        np.array,
        lambda x: np.array(x, dtype=np.float32),
        lambda x: np.array(x, dtype=np.float64),
        np_array_ops.array,
        lambda x: np_array_ops.array(x, dtype=np.float32),
        lambda x: np_array_ops.array(x, dtype=np.float64)
    ]

    def run_test(arr):
      for fn in array_transforms:
        # Apply each transform to the original input. Reassigning `arr` here
        # would chain the transforms, so e.g. np_array_ops.array would only
        # ever see an already-converted float64 np.ndarray.
        arg = fn(arr)
        self.match(
            np_array_ops.diag(arg), np.diag(arg), msg='diag({})'.format(arg))
        for k in range(-3, 3):
          self.match(
              np_array_ops.diag(arg, k),
              np.diag(arg, k),
              msg='diag({}, k={})'.format(arg, k))

    # 2-d arrays.
    run_test(np.arange(9).reshape((3, 3)).tolist())
    run_test(np.arange(6).reshape((2, 3)).tolist())
    run_test(np.arange(6).reshape((3, 2)).tolist())
    run_test(np.arange(3).reshape((1, 3)).tolist())
    run_test(np.arange(3).reshape((3, 1)).tolist())
    run_test([[5]])
    run_test([[]])
    run_test([[], []])

    # 1-d arrays.
    run_test([])
    run_test([1])
    run_test([1, 2])

  def testDiagFlat(self):
    array_transforms = [
        lambda x: x,  # Identity,
        ops.convert_to_tensor,
        np.array,
        lambda x: np.array(x, dtype=np.float32),
        lambda x: np.array(x, dtype=np.float64),
        np_array_ops.array,
        lambda x: np_array_ops.array(x, dtype=np.float32),
        lambda x: np_array_ops.array(x, dtype=np.float64)
    ]

    def run_test(arr):
      for fn in array_transforms:
        # Apply each transform to the original input rather than chaining
        # them by reassigning `arr`.
        arg = fn(arr)
        self.match(
            np_array_ops.diagflat(arg),
            np.diagflat(arg),
            msg='diagflat({})'.format(arg))
        for k in range(-3, 3):
          self.match(
              np_array_ops.diagflat(arg, k),
              np.diagflat(arg, k),
              msg='diagflat({}, k={})'.format(arg, k))

    # 1-d arrays.
    run_test([])
    run_test([1])
    run_test([1, 2])
    # 2-d arrays.
    run_test([[]])
    run_test([[5]])
    run_test([[], []])
    run_test(np.arange(4).reshape((2, 2)).tolist())
    run_test(np.arange(2).reshape((2, 1)).tolist())
    run_test(np.arange(2).reshape((1, 2)).tolist())
    # 3-d arrays
    run_test(np.arange(8).reshape((2, 2, 2)).tolist())

  def match_shape(self, actual, expected, msg=None):
    """Asserts `actual` and `expected` have equal `.shape`."""
    if msg:
      msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
          msg, expected.shape, actual.shape)
    self.assertEqual(actual.shape, expected.shape, msg=msg)

  def match_dtype(self, actual, expected, msg=None):
    """Asserts `actual` and `expected` have equal `.dtype`."""
    if msg:
      msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format(
          msg, expected.dtype, actual.dtype)
    self.assertEqual(actual.dtype, expected.dtype, msg=msg)

  def match(self, actual, expected, msg=None, almost=False, decimal=7):
    """Asserts `actual` is a tf-numpy ndarray matching `expected`.

    Args:
      actual: Result of the tf-numpy function under test.
      expected: Result of the reference NumPy function.
      msg: Optional extra context appended to every failure message.
      almost: If True, compare values approximately instead of exactly.
      decimal: Number of decimal places compared when `almost` is True.
    """
    msg_ = 'Expected: {} Actual: {}'.format(expected, actual)
    if msg:
      msg = '{} {}'.format(msg_, msg)
    else:
      msg = msg_
    self.assertIsInstance(actual, np_arrays.ndarray)
    self.match_dtype(actual, expected, msg)
    self.match_shape(actual, expected, msg)
    # Forward `msg` to the value comparisons too, so value mismatches report
    # which inputs produced them.
    if not almost:
      if not actual.shape.rank:
        self.assertEqual(actual.tolist(), expected.tolist(), msg=msg)
      else:
        self.assertSequenceEqual(actual.tolist(), expected.tolist(), msg=msg)
    else:
      np.testing.assert_almost_equal(
          actual.tolist(), expected.tolist(), decimal=decimal, err_msg=msg)

  def testIndexedSlices(self):
    dtype = dtypes.int64
    iss = indexed_slices.IndexedSlices(
        values=np_array_ops.ones([2, 3], dtype=dtype),
        indices=constant_op.constant([1, 9]),
        dense_shape=[10, 3])
    a = np_array_ops.array(iss, copy=False)
    # The IndexedSlices should densify to ones at rows 1 and 9.
    expected = array_ops.scatter_nd([[1], [9]],
                                    array_ops.ones([2, 3], dtype=dtype),
                                    [10, 3])
    self.assertAllEqual(expected, a)
554
555class ArrayMethodsTest(test.TestCase):
556
557  def setUp(self):
558    super(ArrayMethodsTest, self).setUp()
559    set_up_virtual_devices()
560    self.array_transforms = [
561        lambda x: x,
562        ops.convert_to_tensor,
563        np.array,
564        np_array_ops.array,
565    ]
566
567  def testAllAny(self):
568
569    def run_test(arr, *args, **kwargs):
570      for fn in self.array_transforms:
571        arr = fn(arr)
572        self.match(
573            np_array_ops.all(arr, *args, **kwargs),
574            np.all(arr, *args, **kwargs))
575        self.match(
576            np_array_ops.any(arr, *args, **kwargs),
577            np.any(arr, *args, **kwargs))
578
579    run_test(0)
580    run_test(1)
581    run_test([])
582    run_test([[True, False], [True, True]])
583    run_test([[True, False], [True, True]], axis=0)
584    run_test([[True, False], [True, True]], axis=0, keepdims=True)
585    run_test([[True, False], [True, True]], axis=1)
586    run_test([[True, False], [True, True]], axis=1, keepdims=True)
587    run_test([[True, False], [True, True]], axis=(0, 1))
588    run_test([[True, False], [True, True]], axis=(0, 1), keepdims=True)
589    run_test([5.2, 3.5], axis=0)
590    run_test([1, 0], axis=0)
591
592  def testCompress(self):
593
594    def run_test(condition, arr, *args, **kwargs):
595      for fn1 in self.array_transforms:
596        for fn2 in self.array_transforms:
597          arg1 = fn1(condition)
598          arg2 = fn2(arr)
599          self.match(
600              np_array_ops.compress(arg1, arg2, *args, **kwargs),
601              np.compress(
602                  np.asarray(arg1).astype(np.bool), arg2, *args, **kwargs))
603
604    run_test([True], 5)
605    run_test([False], 5)
606    run_test([], 5)
607    run_test([True, False, True], [1, 2, 3])
608    run_test([True, False], [1, 2, 3])
609    run_test([False, True], [[1, 2], [3, 4]])
610    run_test([1, 0, 1], [1, 2, 3])
611    run_test([1, 0], [1, 2, 3])
612    run_test([0, 1], [[1, 2], [3, 4]])
613    run_test([True], [[1, 2], [3, 4]])
614    run_test([False, True], [[1, 2], [3, 4]], axis=1)
615    run_test([False, True], [[1, 2], [3, 4]], axis=0)
616    run_test([False, True], [[1, 2], [3, 4]], axis=-1)
617    run_test([False, True], [[1, 2], [3, 4]], axis=-2)
618
619  def testCopy(self):
620
621    def run_test(arr, *args, **kwargs):
622      for fn in self.array_transforms:
623        arg = fn(arr)
624        self.match(
625            np_array_ops.copy(arg, *args, **kwargs),
626            np.copy(arg, *args, **kwargs))
627
628    run_test([])
629    run_test([1, 2, 3])
630    run_test([1., 2., 3.])
631    run_test([True])
632    run_test(np.arange(9).reshape((3, 3)).tolist())
633
634    a = np_array_ops.asarray(0)
635    self.assertNotIn('CPU:1', a.backing_device)
636    with ops.device('CPU:1'):
637      self.assertIn('CPU:1', np_array_ops.array(a, copy=True)
638                    .backing_device)
639      self.assertIn('CPU:1', np_array_ops.array(np.array(0), copy=True)
640                    .backing_device)
641
642  def testCumProdAndSum(self):
643
644    def run_test(arr, *args, **kwargs):
645      for fn in self.array_transforms:
646        arg = fn(arr)
647        self.match(
648            np_array_ops.cumprod(arg, *args, **kwargs),
649            np.cumprod(arg, *args, **kwargs))
650        self.match(
651            np_array_ops.cumsum(arg, *args, **kwargs),
652            np.cumsum(arg, *args, **kwargs))
653
654    run_test([])
655    run_test([1, 2, 3])
656    run_test([1, 2, 3], dtype=float)
657    run_test([1, 2, 3], dtype=np.float32)
658    run_test([1, 2, 3], dtype=np.float64)
659    run_test([1., 2., 3.])
660    run_test([1., 2., 3.], dtype=int)
661    run_test([1., 2., 3.], dtype=np.int32)
662    run_test([1., 2., 3.], dtype=np.int64)
663    run_test([[1, 2], [3, 4]], axis=1)
664    run_test([[1, 2], [3, 4]], axis=0)
665    run_test([[1, 2], [3, 4]], axis=-1)
666    run_test([[1, 2], [3, 4]], axis=-2)
667
668  def testImag(self):
669
670    def run_test(arr, *args, **kwargs):
671      for fn in self.array_transforms:
672        arg = fn(arr)
673        self.match(
674            np_array_ops.imag(arg, *args, **kwargs),
675            # np.imag may return a scalar so we convert to a np.ndarray.
676            np.array(np.imag(arg, *args, **kwargs)))
677
678    run_test(1)
679    run_test(5.5)
680    run_test(5 + 3j)
681    run_test(3j)
682    run_test([])
683    run_test([1, 2, 3])
684    run_test([1 + 5j, 2 + 3j])
685    run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]])
686
687  def testAMaxAMin(self):
688
689    def run_test(arr, *args, **kwargs):
690      axis = kwargs.pop('axis', None)
691      for fn1 in self.array_transforms:
692        for fn2 in self.array_transforms:
693          arr_arg = fn1(arr)
694          axis_arg = fn2(axis) if axis is not None else None
695          self.match(
696              np_array_ops.amax(arr_arg, axis=axis_arg, *args, **kwargs),
697              np.amax(arr_arg, axis=axis, *args, **kwargs))
698          self.match(
699              np_array_ops.amin(arr_arg, axis=axis_arg, *args, **kwargs),
700              np.amin(arr_arg, axis=axis, *args, **kwargs))
701
702    run_test([1, 2, 3])
703    run_test([1., 2., 3.])
704    run_test([[1, 2], [3, 4]], axis=1)
705    run_test([[1, 2], [3, 4]], axis=0)
706    run_test([[1, 2], [3, 4]], axis=-1)
707    run_test([[1, 2], [3, 4]], axis=-2)
708    run_test([[1, 2], [3, 4]], axis=(0, 1))
709    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2))
710    run_test(
711        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True)
712    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0))
713    run_test(
714        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True)
715
716  def testMean(self):
717
718    def run_test(arr, *args, **kwargs):
719      axis = kwargs.pop('axis', None)
720      for fn1 in self.array_transforms:
721        for fn2 in self.array_transforms:
722          arr_arg = fn1(arr)
723          axis_arg = fn2(axis) if axis is not None else None
724          self.match(
725              np_array_ops.mean(arr_arg, axis=axis_arg, *args, **kwargs),
726              np.mean(arr_arg, axis=axis, *args, **kwargs))
727
728    run_test([1, 2, 1])
729    run_test([1., 2., 1.])
730    run_test([1., 2., 1.], dtype=int)
731    run_test([[1, 2], [3, 4]], axis=1)
732    run_test([[1, 2], [3, 4]], axis=0)
733    run_test([[1, 2], [3, 4]], axis=-1)
734    run_test([[1, 2], [3, 4]], axis=-2)
735    run_test([[1, 2], [3, 4]], axis=(0, 1))
736    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2))
737    run_test(
738        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True)
739    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0))
740    run_test(
741        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True)
742
743  def testProd(self):
744
745    def run_test(arr, *args, **kwargs):
746      for fn in self.array_transforms:
747        arg = fn(arr)
748        self.match(
749            np_array_ops.prod(arg, *args, **kwargs),
750            np.prod(arg, *args, **kwargs))
751
752    run_test([1, 2, 3])
753    run_test([1., 2., 3.])
754    run_test(np.array([1, 2, 3], dtype=np.int16))
755    run_test([[1, 2], [3, 4]], axis=1)
756    run_test([[1, 2], [3, 4]], axis=0)
757    run_test([[1, 2], [3, 4]], axis=-1)
758    run_test([[1, 2], [3, 4]], axis=-2)
759    run_test([[1, 2], [3, 4]], axis=(0, 1))
760    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2))
761    run_test(
762        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(0, 2), keepdims=True)
763    run_test(np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0))
764    run_test(
765        np.arange(8).reshape((2, 2, 2)).tolist(), axis=(2, 0), keepdims=True)
766
767  def _testReduce(self, math_fun, np_fun, name):
768    axis_transforms = [
769        lambda x: x,  # Identity,
770        ops.convert_to_tensor,
771        np.array,
772        np_array_ops.array,
773        lambda x: np_array_ops.array(x, dtype=np.float32),
774        lambda x: np_array_ops.array(x, dtype=np.float64),
775    ]
776
777    def run_test(a, **kwargs):
778      axis = kwargs.pop('axis', None)
779      for fn1 in self.array_transforms:
780        for fn2 in axis_transforms:
781          arg1 = fn1(a)
782          axis_arg = fn2(axis) if axis is not None else None
783          self.match(
784              math_fun(arg1, axis=axis_arg, **kwargs),
785              np_fun(arg1, axis=axis, **kwargs),
786              msg='{}({}, axis={}, keepdims={})'.format(name, arg1, axis,
787                                                        kwargs.get('keepdims')))
788
789    run_test(5)
790    run_test([2, 3])
791    run_test([[2, -3], [-6, 7]])
792    run_test([[2, -3], [-6, 7]], axis=0)
793    run_test([[2, -3], [-6, 7]], axis=0, keepdims=True)
794    run_test([[2, -3], [-6, 7]], axis=1)
795    run_test([[2, -3], [-6, 7]], axis=1, keepdims=True)
796    run_test([[2, -3], [-6, 7]], axis=(0, 1))
797    run_test([[2, -3], [-6, 7]], axis=(1, 0))
798
799  def testSum(self):
800    self._testReduce(np_array_ops.sum, np.sum, 'sum')
801
802  def testAmax(self):
803    self._testReduce(np_array_ops.amax, np.amax, 'amax')
804
805  def testSize(self):
806
807    def run_test(arr, axis=None):
808      onp_arr = np.array(arr)
809      self.assertEqual(np_array_ops.size(arr, axis), np.size(onp_arr, axis))
810
811    run_test(np_array_ops.array([1]))
812    run_test(np_array_ops.array([1, 2, 3, 4, 5]))
813    run_test(np_array_ops.ones((2, 3, 2)))
814    run_test(np_array_ops.ones((3, 2)))
815    run_test(np_array_ops.zeros((5, 6, 7)))
816    run_test(1)
817    run_test(np_array_ops.ones((3, 2, 1)))
818    run_test(constant_op.constant(5))
819    run_test(constant_op.constant([1, 1, 1]))
820    self.assertRaises(NotImplementedError, np_array_ops.size, np.ones((2, 2)),
821                      1)
822
823    @def_function.function(input_signature=[
824        tensor_spec.TensorSpec(dtype=dtypes.float64, shape=None)])
825    def f(arr):
826      arr = np_array_ops.asarray(arr)
827      return np_array_ops.size(arr)
828
829    self.assertEqual(f(np_array_ops.ones((3, 2))).numpy(), 6)
830
831  def testRavel(self):
832
833    def run_test(arr, *args, **kwargs):
834      for fn in self.array_transforms:
835        arg = fn(arr)
836        self.match(
837            np_array_ops.ravel(arg, *args, **kwargs),
838            np.ravel(arg, *args, **kwargs))
839
840    run_test(5)
841    run_test(5.)
842    run_test([])
843    run_test([[]])
844    run_test([[], []])
845    run_test([1, 2, 3])
846    run_test([1., 2., 3.])
847    run_test([[1, 2], [3, 4]])
848    run_test(np.arange(8).reshape((2, 2, 2)).tolist())
849
850  def testReal(self):
851
852    def run_test(arr, *args, **kwargs):
853      for fn in self.array_transforms:
854        arg = fn(arr)
855        self.match(
856            np_array_ops.real(arg, *args, **kwargs),
857            np.array(np.real(arg, *args, **kwargs)))
858
859    run_test(1)
860    run_test(5.5)
861    run_test(5 + 3j)
862    run_test(3j)
863    run_test([])
864    run_test([1, 2, 3])
865    run_test([1 + 5j, 2 + 3j])
866    run_test([[1 + 5j, 2 + 3j], [1 + 7j, 2 + 8j]])
867
868  def testRepeat(self):
869
870    def run_test(arr, repeats, *args, **kwargs):
871      for fn1 in self.array_transforms:
872        for fn2 in self.array_transforms:
873          arr_arg = fn1(arr)
874          repeats_arg = fn2(repeats)
875          self.match(
876              np_array_ops.repeat(arr_arg, repeats_arg, *args, **kwargs),
877              np.repeat(arr_arg, repeats_arg, *args, **kwargs))
878
879    run_test(1, 2)
880    run_test([1, 2], 2)
881    run_test([1, 2], [2])
882    run_test([1, 2], [1, 2])
883    run_test([[1, 2], [3, 4]], 3, axis=0)
884    run_test([[1, 2], [3, 4]], 3, axis=1)
885    run_test([[1, 2], [3, 4]], [3], axis=0)
886    run_test([[1, 2], [3, 4]], [3], axis=1)
887    run_test([[1, 2], [3, 4]], [3, 2], axis=0)
888    run_test([[1, 2], [3, 4]], [3, 2], axis=1)
889    run_test([[1, 2], [3, 4]], [3, 2], axis=-1)
890    run_test([[1, 2], [3, 4]], [3, 2], axis=-2)
891
892  def testAround(self):
893
894    def run_test(arr, *args, **kwargs):
895      for fn in self.array_transforms:
896        arg = fn(arr)
897        self.match(
898            np_array_ops.around(arg, *args, **kwargs),
899            np.around(arg, *args, **kwargs))
900
901    run_test(5.5)
902    run_test(5.567, decimals=2)
903    run_test([])
904    run_test([1.27, 2.49, 2.75], decimals=1)
905    run_test([23.6, 45.1], decimals=-1)
906
907  def testReshape(self):
908
909    def run_test(arr, newshape, *args, **kwargs):
910      for fn1 in self.array_transforms:
911        for fn2 in self.array_transforms:
912          arr_arg = fn1(arr)
913          newshape_arg = fn2(newshape)
914          self.match(
915              np_array_ops.reshape(arr_arg, newshape_arg, *args, **kwargs),
916              np.reshape(arr_arg, newshape, *args, **kwargs))
917
918    run_test(5, [-1])
919    run_test([], [-1])
920    run_test([1, 2, 3], [1, 3])
921    run_test([1, 2, 3], [3, 1])
922    run_test([1, 2, 3, 4], [2, 2])
923    run_test([1, 2, 3, 4], [2, 1, 2])
924
925  def testExpandDims(self):
926
927    def run_test(arr, axis):
928      self.match(np_array_ops.expand_dims(arr, axis), np.expand_dims(arr, axis))
929
930    run_test([1, 2, 3], 0)
931    run_test([1, 2, 3], 1)
932
933  def testSqueeze(self):
934
935    def run_test(arr, *args, **kwargs):
936      for fn in self.array_transforms:
937        arg = fn(arr)
938        # Note: np.squeeze ignores the axis arg for non-ndarray objects.
939        # This looks like a bug: https://github.com/numpy/numpy/issues/8201
940        # So we convert the arg to np.ndarray before passing to np.squeeze.
941        self.match(
942            np_array_ops.squeeze(arg, *args, **kwargs),
943            np.squeeze(np.array(arg), *args, **kwargs))
944
945    run_test(5)
946    run_test([])
947    run_test([5])
948    run_test([[1, 2, 3]])
949    run_test([[[1], [2], [3]]])
950    run_test([[[1], [2], [3]]], axis=0)
951    run_test([[[1], [2], [3]]], axis=2)
952    run_test([[[1], [2], [3]]], axis=(0, 2))
953    run_test([[[1], [2], [3]]], axis=-1)
954    run_test([[[1], [2], [3]]], axis=-3)
955
956  def testTranspose(self):
957
958    def run_test(arr, axes=None):
959      for fn1 in self.array_transforms:
960        for fn2 in self.array_transforms:
961          arr_arg = fn1(arr)
962          axes_arg = fn2(axes) if axes is not None else None
963          self.match(
964              np_array_ops.transpose(arr_arg, axes_arg),
965              np.transpose(arr_arg, axes))
966
967    run_test(5)
968    run_test([])
969    run_test([5])
970    run_test([5, 6, 7])
971    run_test(np.arange(30).reshape(2, 3, 5).tolist())
972    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 1, 2])
973    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [0, 2, 1])
974    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 0, 2])
975    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [1, 2, 0])
976    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 0, 1])
977    run_test(np.arange(30).reshape(2, 3, 5).tolist(), [2, 1, 0])
978
979  def match_shape(self, actual, expected, msg=None):
980    if msg:
981      msg = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
982          msg, expected.shape, actual.shape)
983    self.assertEqual(actual.shape, expected.shape, msg=msg)
984
985  def match_dtype(self, actual, expected, msg=None):
986    if msg:
987      msg = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format(
988          msg, expected.dtype, actual.dtype)
989    self.assertEqual(actual.dtype, expected.dtype, msg=msg)
990
991  def match(self, actual, expected, msg=None, check_dtype=True):
992    msg_ = 'Expected: {} Actual: {}'.format(expected, actual)
993    if msg:
994      msg = '{} {}'.format(msg_, msg)
995    else:
996      msg = msg_
997    self.assertIsInstance(actual, np_arrays.ndarray)
998    if check_dtype:
999      self.match_dtype(actual, expected, msg)
1000    self.match_shape(actual, expected, msg)
1001    if not actual.shape.rank:
1002      self.assertAllClose(actual.tolist(), expected.tolist())
1003    else:
1004      self.assertAllClose(actual.tolist(), expected.tolist())
1005
1006  def testPad(self):
1007    t = [[1, 2, 3], [4, 5, 6]]
1008    paddings = [[
1009        1,
1010        1,
1011    ], [2, 2]]
1012    self.assertAllEqual(
1013        np_array_ops.pad(t, paddings, 'constant'),
1014        [[0, 0, 0, 0, 0, 0, 0], [0, 0, 1, 2, 3, 0, 0], [0, 0, 4, 5, 6, 0, 0],
1015         [0, 0, 0, 0, 0, 0, 0]])
1016
1017    self.assertAllEqual(
1018        np_array_ops.pad(t, paddings, 'reflect'),
1019        [[6, 5, 4, 5, 6, 5, 4], [3, 2, 1, 2, 3, 2, 1], [6, 5, 4, 5, 6, 5, 4],
1020         [3, 2, 1, 2, 3, 2, 1]])
1021
1022    self.assertAllEqual(
1023        np_array_ops.pad(t, paddings, 'symmetric'),
1024        [[2, 1, 1, 2, 3, 3, 2], [2, 1, 1, 2, 3, 3, 2], [5, 4, 4, 5, 6, 6, 5],
1025         [5, 4, 4, 5, 6, 6, 5]])
1026
1027  def testTake(self):
1028    a = [4, 3, 5, 7, 6, 8]
1029    indices = [0, 1, 4]
1030    self.assertAllEqual([4, 3, 6], np_array_ops.take(a, indices))
1031    indices = [[0, 1], [2, 3]]
1032    self.assertAllEqual([[4, 3], [5, 7]], np_array_ops.take(a, indices))
1033    a = [[4, 3, 5], [7, 6, 8]]
1034    self.assertAllEqual([[4, 3], [5, 7]], np_array_ops.take(a, indices))
1035    a = np.random.rand(2, 16, 3)
1036    axis = 1
1037    self.assertAllEqual(
1038        np.take(a, indices, axis=axis),
1039        np_array_ops.take(a, indices, axis=axis))
1040
1041  def testWhere(self):
1042    self.assertAllEqual([[1.0, 1.0], [1.0, 1.0]],
1043                        np_array_ops.where([True], [1.0, 1.0],
1044                                           [[0, 0], [0, 0]]))
1045
1046  def testShape(self):
1047    self.assertAllEqual((1, 2), np_array_ops.shape([[0, 0]]))
1048
1049  def testSwapaxes(self):
1050    x = [[1, 2, 3]]
1051    self.assertAllEqual([[1], [2], [3]], np_array_ops.swapaxes(x, 0, 1))
1052    self.assertAllEqual([[1], [2], [3]], np_array_ops.swapaxes(x, -2, -1))
1053    x = [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
1054    self.assertAllEqual([[[0, 4], [2, 6]], [[1, 5], [3, 7]]],
1055                        np_array_ops.swapaxes(x, 0, 2))
1056    self.assertAllEqual([[[0, 4], [2, 6]], [[1, 5], [3, 7]]],
1057                        np_array_ops.swapaxes(x, -3, -1))
1058
1059  def testMoveaxis(self):
1060
1061    def _test(*args):
1062      expected = np.moveaxis(*args)
1063      raw_ans = np_array_ops.moveaxis(*args)
1064
1065      self.assertAllEqual(expected, raw_ans)
1066
1067    a = np.random.rand(1, 2, 3, 4, 5, 6)
1068
1069    # Basic
1070    _test(a, (0, 2), (3, 5))
1071    _test(a, (0, 2), (-1, -3))
1072    _test(a, (-6, -4), (3, 5))
1073    _test(a, (-6, -4), (-1, -3))
1074    _test(a, 0, 4)
1075    _test(a, -6, -2)
1076    _test(a, tuple(range(6)), tuple(range(6)))
1077    _test(a, tuple(range(6)), tuple(reversed(range(6))))
1078    _test(a, (), ())
1079
1080  def testNdim(self):
1081    self.assertAllEqual(0, np_array_ops.ndim(0.5))
1082    self.assertAllEqual(1, np_array_ops.ndim([1, 2]))
1083
1084  def testIsscalar(self):
1085    self.assertTrue(np_array_ops.isscalar(0.5))
1086    self.assertTrue(np_array_ops.isscalar(5))
1087    self.assertTrue(np_array_ops.isscalar(False))
1088    self.assertFalse(np_array_ops.isscalar([1, 2]))
1089
1090  def assertListEqual(self, a, b):
1091    self.assertAllEqual(len(a), len(b))
1092    for x, y in zip(a, b):
1093      self.assertAllEqual(x, y)
1094
1095  def testSplit(self):
1096    x = np_array_ops.arange(9)
1097    y = np_array_ops.split(x, 3)
1098    self.assertListEqual([([0, 1, 2]), ([3, 4, 5]), ([6, 7, 8])], y)
1099
1100    x = np_array_ops.arange(8)
1101    y = np_array_ops.split(x, [3, 5, 6, 10])
1102    self.assertListEqual([([0, 1, 2]), ([3, 4]), ([5]), ([6, 7]), ([])], y)
1103
1104  def testSign(self):
1105    state = np.random.RandomState(0)
1106    test_types = [np.float16, np.float32, np.float64, np.int32, np.int64,
1107                  np.complex64, np.complex128]
1108    test_shapes = [(), (1,), (2, 3, 4), (2, 3, 0, 4)]
1109
1110    for dtype in test_types:
1111      for shape in test_shapes:
1112        if np.issubdtype(dtype, np.complex):
1113          arr = (np.asarray(state.randn(*shape) * 100, dtype=dtype) +
1114                 1j * np.asarray(state.randn(*shape) * 100, dtype=dtype))
1115        else:
1116          arr = np.asarray(state.randn(*shape) * 100, dtype=dtype)
1117        self.match(np_array_ops.sign(arr), np.sign(arr))
1118
1119
class ArrayManipulationTest(test.TestCase):
  """Compares tf-numpy array manipulation ops with their NumPy counterparts."""

  def setUp(self):
    super(ArrayManipulationTest, self).setUp()
    # Every test input is fed through each of these, so an op is exercised
    # with raw Python values, tensors, NumPy arrays and tf-numpy arrays.
    transforms = [
        lambda x: x,
        ops.convert_to_tensor,
        np.array,
        np_array_ops.array,
    ]
    self.array_transforms = transforms

  def testBroadcastTo(self):
    """Compares np_array_ops.broadcast_to with np.broadcast_to."""
    cases = [
        (1, 2),
        (1, (2, 2)),
        ([1, 2], (2, 2)),
        ([[1], [2]], (2, 2)),
        ([[1, 2]], (3, 2)),
        ([[[1, 2]], [[3, 4]], [[5, 6]]], (3, 4, 2)),
    ]
    for value, shape in cases:
      for transform in self.array_transforms:
        converted = transform(value)
        self.match(
            np_array_ops.broadcast_to(converted, shape),
            np.broadcast_to(converted, shape))

  def testIx_(self):
    """Compares np_array_ops.ix_ with np.ix_ over many argument multisets."""
    possible_arys = [[True, True], [True, False], [False, False],
                     list(range(5)), np_array_ops.empty(0, dtype=np.int64)]
    # Multisets of the candidates of every size from 0 up to
    # len(possible_arys) - 1.
    all_arys = itertools.chain.from_iterable(
        itertools.combinations_with_replacement(possible_arys, size)
        for size in range(len(possible_arys)))
    for arys in all_arys:
      for tnp_out, onp_out in zip(np_array_ops.ix_(*arys), np.ix_(*arys)):
        self.match(tnp_out, onp_out)

  def match_shape(self, actual, expected, msg=None):
    """Asserts equal shapes; a truthy `msg` is expanded with both shapes."""
    if not msg:
      self.assertEqual(actual.shape, expected.shape, msg=msg)
      return
    detailed = 'Shape match failed for: {}. Expected: {} Actual: {}'.format(
        msg, expected.shape, actual.shape)
    self.assertEqual(actual.shape, expected.shape, msg=detailed)

  def match_dtype(self, actual, expected, msg=None):
    """Asserts equal dtypes; a truthy `msg` is expanded with both dtypes."""
    if not msg:
      self.assertEqual(actual.dtype, expected.dtype, msg=msg)
      return
    detailed = 'Dtype match failed for: {}. Expected: {} Actual: {}.'.format(
        msg, expected.dtype, actual.dtype)
    self.assertEqual(actual.dtype, expected.dtype, msg=detailed)

  def match(self, actual, expected, msg=None):
    """Asserts a tf-numpy result exactly matches a NumPy reference.

    Dtype and shape are compared first. Values are then compared exactly:
    with assertSequenceEqual for arrays of known nonzero rank, and with
    assertEqual otherwise (e.g. rank-0 results, whose tolist() is a scalar).
    """
    base = 'Expected: {} Actual: {}'.format(expected, actual)
    full_msg = '{} {}'.format(base, msg) if msg else base
    self.assertIsInstance(actual, np_arrays.ndarray)
    self.match_dtype(actual, expected, full_msg)
    self.match_shape(actual, expected, full_msg)
    if actual.shape.rank:
      self.assertSequenceEqual(actual.tolist(), expected.tolist())
    else:
      self.assertEqual(actual.tolist(), expected.tolist())
1182
1183
if __name__ == '__main__':
  # Script entry point: apply the global configuration below, in this
  # order, before handing control to the test runner.
  ops.enable_eager_execution()
  ops.enable_numpy_style_type_promotion()
  # Installs the NumPy-style methods the tests rely on onto tensors
  # (NOTE(review): exact set of methods is defined in np_math_ops - confirm).
  np_math_ops.enable_numpy_methods_on_tensor()
  test.main()
1189