• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for DLPack functions."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from absl.testing import parameterized
21import numpy as np
22
23
24from tensorflow.python.dlpack import dlpack
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import ops
29from tensorflow.python.platform import test
30from tensorflow.python.ops import array_ops
31
# NumPy integer scalar types (signed, then unsigned) used by the tests.
int_dtypes = [
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
]
# NumPy floating-point scalar types used by the tests.
float_dtypes = [np.float16, np.float32, np.float64]
# NumPy complex scalar types; note these are NOT included in `dlpack_dtypes`.
complex_dtypes = [np.complex64, np.complex128]
# Every dtype the round-trip test is parameterized over: all integer and
# floating types above, plus TensorFlow's bfloat16.
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]

# Shapes for the round-trip test: a scalar, small ranks 1-3, and two
# zero-element shapes.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]
41
42
def FormatShapeAndDtype(shape, dtype):
  """Returns a readable test-name suffix such as "_int8[2,3]".

  Args:
    shape: Iterable of dimension sizes; rendered comma-separated in brackets.
    dtype: A NumPy scalar type (e.g. `np.int8`), or a dtype object exposing a
      `.name` attribute (e.g. a NumPy or TF dtype). Anything else falls back
      to `str(dtype)`.

  Returns:
    A string of the form "_<dtype name>[<d0>,<d1>,...]".
  """
  # `str(np.int8)` is "<class 'numpy.int8'>", which would put spaces, quotes
  # and angle brackets into generated test names; prefer the bare type name.
  if isinstance(dtype, type):
    dtype_name = dtype.__name__
  else:
    dtype_name = getattr(dtype, "name", None) or str(dtype)
  return "_{}[{}]".format(dtype_name, ",".join(map(str, shape)))
45
46
def GetNamedTestParameters():
  """Builds the parameter dicts for `parameterized.named_parameters`.

  Returns:
    A list with one dict per (dtype, shape) pair drawn from `dlpack_dtypes`
    and `testcase_shapes`, in dtype-major order. Each dict has the keys
    "testcase_name", "dtype" and "shape".
  """
  return [
      {
          "testcase_name": FormatShapeAndDtype(shape, dtype),
          "dtype": dtype,
          "shape": shape,
      }
      for dtype in dlpack_dtypes
      for shape in testcase_shapes
  ]
57
58
class DLPackTest(parameterized.TestCase, test.TestCase):
  """Tests for `dlpack.to_dlpack` / `dlpack.from_dlpack`."""

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """Round-trips a tensor through a DLPack capsule; checks values/device."""
    # A local RandomState seeded with 42 produces the same values as
    # `np.random.seed(42)` did, without mutating NumPy's global RNG state.
    np_array = np.random.RandomState(42).randint(0, 10, shape)
    # copy to gpu if available
    tf_tensor = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
    tf_tensor_device = tf_tensor.device
    tf_tensor_dtype = tf_tensor.dtype
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule is expected to stay usable after the source tensor is gone.
    del tf_tensor
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)
    if tf_tensor_dtype == dtypes.int32:
      # int32 tensor is always on cpu for now
      self.assertEqual(tf_tensor2.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")
    else:
      self.assertEqual(tf_tensor_device, tf_tensor2.device)

  def testTensorsCanBeConsumedOnceOnly(self):
    """A DLPack capsule can be passed to `from_dlpack` at most once."""
    np_array = np.random.RandomState(42).randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    del tf_tensor  # The capsule should still be usable.
    _ = dlpack.from_dlpack(dlcapsule)

    # A second consumption of the same capsule must raise.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(dlcapsule)

  def testUnsupportedTypeToDLPack(self):
    """`to_dlpack` raises for dtypes that DLPack cannot represent."""
    for unsupported_dtype in (dtypes.qint16, dtypes.complex64):
      with self.subTest(dtype=unsupported_dtype):
        with self.assertRaisesRegex(Exception,
                                    ".* is not supported by dlpack"):
          tf_tensor = constant_op.constant([[1, 4], [5, 2]],
                                           dtype=unsupported_dtype)
          _ = dlpack.to_dlpack(tf_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """`to_dlpack` rejects arguments that are not TF tensors."""
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "The argument to `to_dlpack` must be a TF tensor, not Python object"):
      dlpack.to_dlpack([1])
114      dlpack.to_dlpack([1])
115
116
if __name__ == "__main__":
  # Switch to eager execution first, then hand control to the TF test runner.
  ops.enable_eager_execution()
  test.main()
120