# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np


from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.ops import array_ops

int_dtypes = [
    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
    np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
complex_dtypes = [np.complex64, np.complex128]
# Dtypes exercised by the parameterized round-trip test.
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]

# Shapes exercised by the round-trip test: a scalar, shapes with a
# zero-sized dimension, and ordinary 1-D to 3-D shapes.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]


def FormatShapeAndDtype(shape, dtype):
  """Builds a test-case name suffix of the form `_<dtype>[d0,d1,...]`.

  Args:
    shape: Iterable of dimension sizes.
    dtype: Dtype object; its `str()` form is embedded in the name.

  Returns:
    The formatted suffix string.
  """
  dims = ",".join(str(dim) for dim in shape)
  return "_{}[{}]".format(str(dtype), dims)


def GetNamedTestParameters():
  """Returns one named-parameter dict per (dtype, shape) combination.

  Entries are ordered with `dlpack_dtypes` as the outer iteration and
  `testcase_shapes` as the inner one.
  """
  return [
      {
          "testcase_name": FormatShapeAndDtype(shape, dtype),
          "dtype": dtype,
          "shape": shape
      }
      for dtype in dlpack_dtypes
      for shape in testcase_shapes
  ]


class DLPackTest(parameterized.TestCase, test.TestCase):
  """Tests for `dlpack.to_dlpack` and `dlpack.from_dlpack`."""

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """Exports a tensor to a capsule, re-imports it, and compares the result."""
    np.random.seed(42)
    expected = np.random.randint(0, 10, shape)
    # copy to gpu if available
    source = array_ops.identity(constant_op.constant(expected, dtype=dtype))
    # Record what is needed from the source tensor before it is deleted.
    source_device = source.device
    source_dtype = source.dtype
    capsule = dlpack.to_dlpack(source)
    del source  # the capsule should remain usable without the tensor
    restored = dlpack.from_dlpack(capsule)
    self.assertAllClose(expected, restored)
    if source_dtype != dtypes.int32:
      self.assertEqual(source_device, restored.device)
    else:
      # int32 tensor is always on cpu for now
      self.assertEqual(restored.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")

  def testTensorsCanBeConsumedOnceOnly(self):
    """Checks that importing the same capsule a second time raises."""
    np.random.seed(42)
    values = np.random.randint(0, 10, (2, 3, 4))
    source = constant_op.constant(values, dtype=np.float32)
    capsule = dlpack.to_dlpack(source)
    del source  # the capsule should remain usable without the tensor
    _ = dlpack.from_dlpack(capsule)

    # The capsule was consumed by the import above; a second import must fail.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(capsule)

  def testUnsupportedTypeToDLPack(self):
    """Checks that qint16 and complex64 tensors are rejected."""
    # Tensor construction stays inside each `with` block so that an error
    # raised at either step is matched against the expected message.
    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      quantized = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
      _ = dlpack.to_dlpack(quantized)

    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      complex_tensor = constant_op.constant([[1, 4], [5, 2]],
                                            dtype=dtypes.complex64)
      _ = dlpack.to_dlpack(complex_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """Checks that passing a plain Python list to `to_dlpack` raises."""
    expected_message = (
        "The argument to `to_dlpack` must be a TF tensor, not Python object")
    with self.assertRaisesRegex(errors.InvalidArgumentError, expected_message):
      dlpack.to_dlpack([1])


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()