• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for operations in eager execution."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import gc
21import threading
22import weakref
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python.eager import context
28from tensorflow.python.eager import execute
29from tensorflow.python.eager import test
30from tensorflow.python.framework import config
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import errors_impl
34from tensorflow.python.framework import ops
35from tensorflow.python.framework import tensor_shape
36from tensorflow.python.framework import test_util
37from tensorflow.python.layers import core
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import math_ops
41from tensorflow.python.ops import random_ops
42from tensorflow.python.ops import resource_variable_ops
43from tensorflow.python.ops import sparse_ops
44
45
class OpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Tests for executing ops, operator overloads and tensor APIs eagerly."""

  def testExecuteBasic(self):
    three = constant_op.constant(3)
    five = constant_op.constant(5)
    product = three * five
    self.assertAllEqual(15, product)

  @test_util.run_gpu_only
  def testMatMulGPU(self):
    three = constant_op.constant([[3.]]).gpu()
    five = constant_op.constant([[5.]]).gpu()
    product = math_ops.matmul(three, five)
    # assertAllEqual checks shape and values; assertEqual against a numpy
    # array only passes through the truthiness of a 1-element comparison.
    self.assertAllEqual([[15.0]], product)

  def testExecuteStringAttr(self):
    """Runs an op whose attr is a string (check_numerics' `message`)."""
    three = constant_op.constant(3.0)
    checked_three = array_ops.check_numerics(three,
                                             message='just checking')
    # The input is a scalar, so compare against a scalar rather than [[3]].
    self.assertAllEqual(3.0, checked_three)

  def testExecuteFloatAttr(self):
    """Runs an op whose attr is a float (approximate_equal's `tolerance`)."""
    three = constant_op.constant(3.0)
    almost_three = constant_op.constant(2.8)
    almost_equal = math_ops.approximate_equal(
        three, almost_three, tolerance=0.3)
    self.assertTrue(almost_equal)

  def testExecuteIntAttr(self):
    """Runs add_n, whose input-list length is an int attr."""
    three = constant_op.constant(3)
    four = constant_op.constant(4)
    total = math_ops.add_n([three, four])
    self.assertAllEqual(7, total)

  def testExecuteBoolAttr(self):
    """Runs an op whose attr is a bool (matmul's `transpose_a`)."""
    three = constant_op.constant([[3]])
    five = constant_op.constant([[5]])
    product = math_ops.matmul(three, five, transpose_a=True)
    self.assertAllEqual([[15]], product)

  def testExecuteOneListOutput(self):
    split_dim = constant_op.constant(1)
    value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
    x1, x2, x3 = array_ops.split(value, 3, axis=split_dim)
    self.assertAllEqual([[0], [3]], x1)
    self.assertAllEqual([[1], [4]], x2)
    self.assertAllEqual([[2], [5]], x3)

  def testGraphMode(self):
    """Ops built under graph_mode() are added to the graph, not executed."""
    graph = ops.Graph()
    with graph.as_default(), context.graph_mode():
      array_ops.placeholder(dtypes.int32)
    self.assertLen(graph.get_operations(), 1)

  # See comments on handling of int32 tensors on GPU in
  # EagerTensor.__init__.
  @test_util.run_gpu_only
  def testInt32CPUDefault(self):
    with context.device('/gpu:0'):
      r = constant_op.constant(1) + constant_op.constant(2)
    self.assertAllEqual(r, 3)

  def testExecuteListOutputLen1(self):
    """A list-valued output of length 1 is still returned as a list."""
    split_dim = constant_op.constant(1)
    value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
    result = array_ops.split(value, 1, axis=split_dim)
    self.assertIsInstance(result, list)
    self.assertLen(result, 1)
    self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])

  def testExecuteListOutputLen0(self):
    """A list-valued output of length 0 is returned as an empty list."""
    empty = constant_op.constant([], dtype=dtypes.int32)
    result = array_ops.unstack(empty, 0)
    self.assertIsInstance(result, list)
    self.assertEmpty(result)

  def testExecuteMultipleNonListOutput(self):
    """Multiple outputs are reachable by tuple unpacking and by name."""
    x = constant_op.constant([1, 2, 3, 4, 5, 6])
    y = constant_op.constant([1, 3, 5])
    result = array_ops.listdiff(x, y)
    out, idx = result
    self.assertIs(out, result.out)
    self.assertIs(idx, result.idx)
    self.assertAllEqual([2, 4, 6], out)
    self.assertAllEqual([1, 3, 5], idx)

  def testExecuteMultipleListOutput(self):
    """An op with several list-valued outputs (sparse_split)."""
    split_dim = constant_op.constant(1, dtype=dtypes.int64)
    indices = constant_op.constant([[0, 2], [0, 4], [0, 5], [1, 0], [1, 1]],
                                   dtype=dtypes.int64)
    values = constant_op.constant([2, 3, 5, 7, 11])
    shape = constant_op.constant([2, 7], dtype=dtypes.int64)
    result = sparse_ops.gen_sparse_ops.sparse_split(
        split_dim,
        indices,
        values,
        shape,
        num_split=2)
    output_indices, output_values, output_shape = result
    self.assertLen(output_indices, 2)
    self.assertLen(output_values, 2)
    self.assertLen(output_shape, 2)
    self.assertEqual(output_indices, result.output_indices)
    self.assertEqual(output_values, result.output_values)
    self.assertEqual(output_shape, result.output_shape)
    self.assertAllEqual([[0, 2], [1, 0], [1, 1]], output_indices[0])
    self.assertAllEqual([[0, 0], [0, 1]], output_indices[1])
    self.assertAllEqual([2, 7, 11], output_values[0])
    self.assertAllEqual([3, 5], output_values[1])
    self.assertAllEqual([2, 4], output_shape[0])
    self.assertAllEqual([2, 3], output_shape[1])

  # TODO(josh11b): Test an op that has multiple outputs, some but not
  # all of which are lists. Examples: barrier_take_many (currently
  # unsupported since it uses a type list) or sdca_optimizer (I don't
  # have an example of legal inputs & outputs).

  def testComposition(self):
    x = constant_op.constant(1, dtype=dtypes.int32)
    three_x = x + x + x
    self.assertEqual(dtypes.int32, three_x.dtype)
    self.assertAllEqual(3, three_x)

  def testOperatorOverrides(self):
    """Python operators on tensors match the corresponding numpy ufuncs."""

    def ops_test(v1, v2):
      a = constant_op.constant(v1)
      b = constant_op.constant(v2)

      self.assertAllEqual((-a), np.negative(v1))
      self.assertAllEqual(abs(b), np.absolute(v2))

      self.assertAllEqual((a + b), np.add(v1, v2))
      self.assertAllEqual((a - b), np.subtract(v1, v2))
      self.assertAllEqual((a * b), np.multiply(v1, v2))
      self.assertAllEqual((a * a), np.multiply(v1, v1))

      # Only compare ** when every exponent is non-negative.
      if all(x >= 0 for x in v2):
        self.assertAllEqual((a**b), np.power(v1, v2))
      self.assertAllEqual((a / b), np.true_divide(v1, v2))

      self.assertAllEqual((a / a), np.true_divide(v1, v1))
      self.assertAllEqual((a % b), np.mod(v1, v2))

      self.assertAllEqual((a < b), np.less(v1, v2))
      self.assertAllEqual((a <= b), np.less_equal(v1, v2))
      self.assertAllEqual((a > b), np.greater(v1, v2))
      self.assertAllEqual((a >= b), np.greater_equal(v1, v2))

      # TODO(b/120678848): Remove the else branch once we enable
      # ops.Tensor._USE_EQUALITY by default.
      if ops.Tensor._USE_EQUALITY:
        self.assertAllEqual((a == b), np.equal(v1, v2))
        self.assertAllEqual((a != b), np.not_equal(v1, v2))
      else:
        self.assertAllEqual((a == b), np.equal(v1, v2)[0])
        self.assertAllEqual((a != b), np.not_equal(v1, v2)[0])

      self.assertAllEqual(v1[0], a[constant_op.constant(0)])

    ops_test([1, 4, 8], [2, 3, 5])
    ops_test([1, -4, -5], [-2, 3, -6])

  def test_basic_slice(self):
    """Basic start:stop:step slicing matches numpy, including negatives."""
    npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
    t = constant_op.constant(npt)

    self.assertAllEqual(npt[:, :, :], t[:, :, :])
    self.assertAllEqual(npt[::, ::, ::], t[::, ::, ::])
    self.assertAllEqual(npt[::1, ::1, ::1], t[::1, ::1, ::1])
    self.assertAllEqual(npt[::1, ::5, ::2], t[::1, ::5, ::2])
    self.assertAllEqual(npt[::-1, :, :], t[::-1, :, :])
    self.assertAllEqual(npt[:, ::-1, :], t[:, ::-1, :])
    self.assertAllEqual(npt[:, :, ::-1], t[:, :, ::-1])
    self.assertAllEqual(npt[-2::-1, :, ::1], t[-2::-1, :, ::1])
    self.assertAllEqual(npt[-2::-1, :, ::2], t[-2::-1, :, ::2])

  def testDegenerateSlices(self):
    """Slices selecting nothing produce the same empty result as numpy."""
    npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
    t = constant_op.constant(npt)
    # degenerate by offering a forward interval with a negative stride
    self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
    # degenerate with a reverse interval with a positive stride
    self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
    # empty interval in every dimension
    self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])

  def testEllipsis(self):
    npt = np.array(
        [[[[[1, 2], [3, 4], [5, 6]]], [[[7, 8], [9, 10], [11, 12]]]]])
    t = constant_op.constant(npt)

    self.assertAllEqual(npt[0:], t[0:])
    # implicit ellipsis
    self.assertAllEqual(npt[0:, ...], t[0:, ...])
    # ellipsis alone
    self.assertAllEqual(npt[...], t[...])
    # ellipsis at end
    self.assertAllEqual(npt[0:1, ...], t[0:1, ...])
    # ellipsis at begin
    self.assertAllEqual(npt[..., 0:1], t[..., 0:1])
    # ellipsis at middle
    self.assertAllEqual(npt[0:1, ..., 0:1], t[0:1, ..., 0:1])

  def testShrink(self):
    """Integer indices drop (shrink) the indexed axis, as in numpy."""
    npt = np.array([[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
                     [[[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]]]])
    t = constant_op.constant(npt)
    self.assertAllEqual(npt[:, :, :, :, 3], t[:, :, :, :, 3])
    self.assertAllEqual(npt[..., 3], t[..., 3])
    self.assertAllEqual(npt[:, 0], t[:, 0])
    self.assertAllEqual(npt[:, :, 0], t[:, :, 0])

  @test_util.run_gpu_only
  def testOpWithInputsOnDifferentDevices(self):
    # The GPU kernel for the Reshape op requires that the
    # shape input be on CPU.
    value = constant_op.constant([1., 2.]).gpu()
    shape = constant_op.constant([2, 1])
    reshaped = array_ops.reshape(value, shape)
    self.assertAllEqual([[1], [2]], reshaped.cpu())

  def testInt64(self):
    # Exercise Fill with an int64 `dims` tensor (the fill value is int32).
    self.assertAllEqual(
        [1.0, 1.0],
        array_ops.fill(constant_op.constant([2], dtype=dtypes.int64),
                       constant_op.constant(1)))

  @test_util.run_gpu_only
  def testOutputOnHostMemory(self):
    # The Shape op kernel on GPU places the output in host memory.
    value = constant_op.constant([1.]).gpu()
    shape = array_ops.shape(value)
    self.assertAllEqual([1], shape)

  @test_util.run_gpu_only
  def testSilentCopy(self):
    """Under the 'silent' policy a CPU tensor can be added to a GPU one."""
    # Temporarily replace the context
    # pylint: disable=protected-access
    old_context = context.context()
    context._set_context(context.Context())
    try:
      config.set_device_policy('silent')
      cpu_tensor = constant_op.constant(1.0)
      gpu_tensor = cpu_tensor.gpu()
      self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)
    finally:
      context._set_context(old_context)
    # pylint: enable=protected-access

  @test_util.run_gpu_only
  def testSoftPlacement(self):
    """With soft placement on, an add of CPU-created tensors runs on GPU:0."""
    # Temporarily replace the context
    # pylint: disable=protected-access
    old_context = context.context()
    context._set_context(context.Context())
    try:
      config.set_device_policy('silent')
      config.set_soft_device_placement(True)
      cpu_tensor = constant_op.constant(1.0)
      result = cpu_tensor + cpu_tensor
      self.assertEqual(result.device,
                       '/job:localhost/replica:0/task:0/device:GPU:0')
    finally:
      context._set_context(old_context)
    # pylint: enable=protected-access

  def testRandomUniform(self):
    scalar_shape = constant_op.constant([], dtype=dtypes.int32)

    x = random_ops.random_uniform(scalar_shape)
    self.assertEqual(0, x.shape.ndims)
    self.assertEqual(dtypes.float32, x.dtype)

    x = random_ops.random_uniform(
        scalar_shape, minval=constant_op.constant(5.),
        maxval=constant_op.constant(6.))
    self.assertLess(x, 6)
    self.assertGreaterEqual(x, 5)

  def testArgsToMatchingEagerDefault(self):
    """Checks when args_to_matching_eager uses its default dtype hint."""
    # Uses default
    ctx = context.context()
    allowed_dtypes = [dtypes.int32, dtypes.int64]

    # Follows standard int conversion rules
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
                                          dtypes.int32)
    self.assertEqual(t, dtypes.int32)
    self.assertEqual(r[0].dtype, dtypes.int32)
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
                                          dtypes.int64)
    self.assertEqual(t, dtypes.int32)
    self.assertEqual(r[0].dtype, dtypes.int32)
    # Use int64 since it is a better fit
    t, r = execute.args_to_matching_eager([[2**48]], ctx, allowed_dtypes,
                                          dtypes.int32)
    self.assertEqual(t, dtypes.int64)
    self.assertEqual(r[0].dtype, dtypes.int64)

    # When the regular tensor conversion fails, then use the default type as a
    # hint.
    # NOTE(review): uint32 is listed twice; uint64 may have been intended. The
    # uint64 assertions below show the default dtype is used even when it is
    # not in this list.
    allowed_dtypes = [dtypes.uint32, dtypes.uint32]
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
                                          dtypes.uint32)
    self.assertEqual(t, dtypes.uint32)
    self.assertEqual(r[0].dtype, dtypes.uint32)
    t, r = execute.args_to_matching_eager([[3, 4]], ctx, allowed_dtypes,
                                          dtypes.uint64)
    self.assertEqual(t, dtypes.uint64)
    self.assertEqual(r[0].dtype, dtypes.uint64)

    # With no arguments at all, the default dtype is returned.
    t, r = execute.args_to_matching_eager([], ctx, allowed_dtypes, dtypes.int64)
    self.assertEqual(t, dtypes.int64)

    # Doesn't use default
    allowed_dtypes = [dtypes.int32, dtypes.string]
    t, r = execute.args_to_matching_eager([['string', 'arg']], ctx,
                                          allowed_dtypes, dtypes.int32)
    self.assertEqual(t, dtypes.string)
    self.assertEqual(r[0].dtype, dtypes.string)

  def testFlattenLayer(self):
    flatten_layer = core.Flatten()
    x = constant_op.constant([[[-10, -20], [-30, -40]], [[10, 20], [30, 40]]])
    y = flatten_layer(x)
    self.assertAllEqual([[-10, -20, -30, -40], [10, 20, 30, 40]], y)

  def testIdentity(self):
    self.assertAllEqual(2, array_ops.identity(2))

  @test_util.run_gpu_only
  def testIdentityOnVariable(self):
    with context.device('/gpu:0'):
      v = resource_variable_ops.ResourceVariable(True)
    self.assertAllEqual(True, array_ops.identity(v))

  def testIncompatibleSetShape(self):
    x = constant_op.constant(1)
    with self.assertRaises(ValueError):
      x.set_shape((1, 2))

  def testCompatibleSetShape(self):
    x = constant_op.constant([[1, 2]])
    x.set_shape(tensor_shape.TensorShape([None, 2]))
    self.assertEqual(x.get_shape(), (1, 2))

  @parameterized.named_parameters(
      ('Tensor', lambda: constant_op.constant(1.3+1j)),
      ('Variable', lambda: resource_variable_ops.ResourceVariable(1.3+1j)))
  def testCastToPrimitiveTypesFrom(self, value_fn):
    """int()/float()/complex() work on scalar complex tensors and variables."""
    x = value_fn()
    self.assertIsInstance(int(x), int)
    self.assertEqual(int(x), 1)
    self.assertIsInstance(float(x), float)
    self.assertAllClose(float(x), 1.3)
    self.assertIsInstance(complex(x), complex)
    self.assertAllClose(complex(x), 1.3+1j)

  def testCastNonScalarToPrimitiveTypesFails(self):
    x = constant_op.constant([1.3, 2])
    with self.assertRaises(TypeError):
      int(x)
    with self.assertRaises(TypeError):
      float(x)

  def testRange(self):
    """A scalar integer tensor can be passed to the builtin range()."""
    x = constant_op.constant(2)
    self.assertEqual([0, 1], list(range(x)))

  def testFormatString(self):
    x = constant_op.constant(3.1415)
    self.assertEqual('3.14', '{:.2f}'.format(x))

  def testNoOpIsNone(self):
    self.assertIsNone(control_flow_ops.no_op())

  def testEagerContextPreservedAcrossThreads(self):
    """A fresh thread starts out eager, including under init_scope()."""
    # Exceptions raised inside a threading.Thread target are not propagated to
    # the joining thread, so collect them here and re-raise after join();
    # otherwise a failing assertion in init_fn would go unnoticed.
    thread_errors = []

    def init_fn():
      try:
        self.assertTrue(context.executing_eagerly())
        with ops.init_scope():
          self.assertTrue(context.executing_eagerly())
          context_switches = context.context().context_switches
          self.assertLen(context_switches.stack, 1)
          self.assertFalse(context_switches.stack[0].is_building_function)
          self.assertEqual(context_switches.stack[0].enter_context_fn,
                           context.eager_mode)
      except Exception as e:  # pylint: disable=broad-except
        thread_errors.append(e)

    self.assertTrue(context.executing_eagerly())
    t1 = threading.Thread(target=init_fn)
    t1.start()
    t1.join()
    if thread_errors:
      raise thread_errors[0]

  def testWeakrefEagerTensor(self):
    """Python attributes set on an EagerTensor do not keep it alive."""
    x = constant_op.constant([[1.]])
    x.at1 = constant_op.constant([[2.]])
    x.at2 = 3.
    weak_x = weakref.ref(x)
    weak_xat1 = weakref.ref(x.at1)
    del x
    self.assertIs(weak_x(), None)
    self.assertIs(weak_xat1(), None)

  def testWeakKeyDictionaryTensor(self):
    """Tensor.ref() keys drop out of a WeakKeyDictionary once deleted."""
    weak_key_dict = weakref.WeakKeyDictionary()

    strong_x = constant_op.constant([[1.]])
    strong_y = constant_op.constant([[2.]])
    strong_x_ref = strong_x.ref()
    strong_y_ref = strong_y.ref()
    weak_key_dict[strong_x_ref] = constant_op.constant([[3.]])
    weak_key_dict[strong_y_ref] = constant_op.constant([[4.]])
    strong_y.a = constant_op.constant([[5.]])
    weak_x_ref = weakref.ref(strong_x)

    del strong_x, strong_x_ref
    self.assertIs(weak_x_ref(), None)
    self.assertEqual([strong_y_ref], list(weak_key_dict))
    self.assertLen(list(weak_key_dict), 1)
    self.assertLen(weak_key_dict, 1)

    del strong_y, strong_y_ref
    self.assertEqual([], list(weak_key_dict))

  def testEagerTensorsCanBeGarbageCollected(self):
    """A reference cycle between two EagerTensors is reclaimed by gc."""
    x = constant_op.constant([[1.]])
    y = constant_op.constant([[2.]])
    x.y = y
    y.x = x
    weak_x = weakref.ref(x)
    weak_y = weakref.ref(y)
    del x
    del y
    gc.collect()
    self.assertIs(weak_x(), None)
    self.assertIs(weak_y(), None)

  @test_util.disable_tfrt(
      'b/153697193: tfrt cannot decode python stacktrace yet')
  def testAsyncExceptionStackTrace(self):
    """An async op error reports the Python frame that created the op."""

    def exception_originated_from_here():
      # Invalid shapes for matmul.
      return math_ops.matmul([[1]], [[2], [3]])

    config.set_synchronous_execution(False)
    # Restore sync execution and clear pending async errors even when an
    # assertion fails, so async mode does not leak into later tests.
    try:
      # In sync mode, an exception would have been raised here but since this
      # is in async, the exception will be raised next.
      x = exception_originated_from_here()

      with self.assertRaisesRegex(errors_impl.InvalidArgumentError,
                                  'in exception_originated_from_here'):
        x.numpy()
    finally:
      context.async_clear_error()
      config.set_synchronous_execution(True)

  def testCrossContextTensorCache(self):
    """Equal constants made in a replacement context and the original read
    back correctly."""
    # pylint: disable=protected-access
    old_context = context.context()
    old_x = constant_op.constant(9.5)
    context._set_context(context.Context())

    try:
      new_x = constant_op.constant(9.5)
      self.assertEqual(new_x.numpy(), 9.5)
    finally:
      context._set_context(old_context)
    # pylint: enable=protected-access

    self.assertEqual(old_x.numpy(), 9.5)
# Run the tests when this module is executed directly.
if __name__ == '__main__':
  test.main()
519