• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DLPack functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23
24from tensorflow.python.dlpack import dlpack
25from tensorflow.python.eager import context
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.platform import test
31from tensorflow.python.ops import array_ops
32
# NumPy integer types: the signed widths followed by the unsigned widths.
int_dtypes = ([np.int8, np.int16, np.int32, np.int64] +
              [np.uint8, np.uint16, np.uint32, np.uint64])
# NumPy floating-point types.
float_dtypes = [np.float16, np.float32, np.float64]
# NumPy complex types. These are deliberately left out of `dlpack_dtypes`;
# testUnsupportedTypeToDLPack expects complex64 to be rejected by to_dlpack.
# NOTE(review): this list is not referenced elsewhere in this file.
complex_dtypes = [np.complex64, np.complex128]
# Every dtype exercised by the round-trip test: all integer and float types
# above plus TF's bfloat16.
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]

# Shapes used by the round-trip test: a scalar, small ranks, and two
# zero-element shapes.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]
42
43
def FormatShapeAndDtype(shape, dtype):
  """Returns a readable test-name suffix such as "_float32[2,3]".

  Args:
    shape: Iterable of dimension sizes; each one is rendered with `str`.
    dtype: A NumPy scalar type (e.g. `np.float32`), an object exposing a
      string `name` attribute (e.g. a TF `DType`), or anything else, which
      falls back to `str(dtype)`.

  Returns:
    A string of the form "_<dtype name>[<comma-separated dims>]".
  """
  # Prefer a short dtype name. `str(np.float32)` is "<class 'numpy.float32'>",
  # which would put spaces, quotes and angle brackets into generated test
  # names and make them awkward to select with test filters.
  dtype_name = getattr(dtype, "name", None)
  if not isinstance(dtype_name, str):
    dtype_name = getattr(dtype, "__name__", None)
  if not isinstance(dtype_name, str):
    dtype_name = str(dtype)
  return "_{}[{}]".format(dtype_name, ",".join(map(str, shape)))
46
47
def GetNamedTestParameters():
  """Builds absl named parameters for every (dtype, shape) combination.

  Returns:
    A list of dicts with keys "testcase_name", "dtype" and "shape". There is
    one dict per pairing of an entry in `dlpack_dtypes` (outer loop) with an
    entry in `testcase_shapes` (inner loop).
  """
  return [
      {
          "testcase_name": FormatShapeAndDtype(shape, dtype),
          "dtype": dtype,
          "shape": shape,
      }
      for dtype in dlpack_dtypes
      for shape in testcase_shapes
  ]
58
59
class DLPackTest(parameterized.TestCase, test.TestCase):
  """Tests for converting TF tensors to and from DLPack capsules."""

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """Round-trips a tensor through DLPack and checks values and device."""
    # A local generator keeps the data reproducible without reseeding the
    # global NumPy RNG that other tests share. RandomState(42).randint draws
    # the same values as np.random.seed(42) followed by np.random.randint.
    np_array = np.random.RandomState(42).randint(0, 10, shape)
    # identity() copies the tensor to the GPU if one is available.
    tf_tensor = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
    tf_tensor_device = tf_tensor.device
    tf_tensor_dtype = tf_tensor.dtype
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule must remain usable after the source tensor is released.
    del tf_tensor
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)
    if tf_tensor_dtype == dtypes.int32:
      # int32 tensors are always placed on the CPU for now.
      self.assertEqual(tf_tensor2.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")
    else:
      # Other dtypes must come back on the device they started on.
      self.assertEqual(tf_tensor_device, tf_tensor2.device)

  def testTensorsCanBeConsumedOnceOnly(self):
    """A DLPack capsule can be passed to from_dlpack at most once."""
    np_array = np.random.RandomState(42).randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    del tf_tensor  # The capsule must stay valid without the source tensor.
    _ = dlpack.from_dlpack(dlcapsule)

    # Consuming the same capsule a second time must be rejected.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(dlcapsule)

  def testDLPackFromWithoutContextInitialization(self):
    """from_dlpack works on a capsule created before a context reset."""
    tf_tensor = constant_op.constant(1)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # Resetting the context doesn't cause an error.
    # NOTE(review): the context is not restored afterwards; confirm later
    # tests do not depend on state from the previous context.
    context._reset_context()
    _ = dlpack.from_dlpack(dlcapsule)

  def testUnsupportedTypeToDLPack(self):
    """to_dlpack rejects dtypes that DLPack cannot represent."""
    for unsupported_dtype in (dtypes.qint16, dtypes.complex64):
      with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
        tf_tensor = constant_op.constant([[1, 4], [5, 2]],
                                         dtype=unsupported_dtype)
        _ = dlpack.to_dlpack(tf_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """to_dlpack raises InvalidArgumentError for non-tensor arguments."""
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "The argument to `to_dlpack` must be a TF tensor, not Python object"):
      dlpack.to_dlpack([1])
124
if __name__ == "__main__":
  # Switch on eager execution first; this has to happen before test.main()
  # hands control to the test runner.
  ops.enable_eager_execution()
  test.main()
128