# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for DLPack functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized
import numpy as np


from tensorflow.python.dlpack import dlpack
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
from tensorflow.python.ops import array_ops

int_dtypes = [
    np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32,
    np.uint64
]
float_dtypes = [np.float16, np.float32, np.float64]
# Complex types are deliberately not part of `dlpack_dtypes`:
# testUnsupportedTypeToDLPack expects complex64 to be rejected by to_dlpack.
complex_dtypes = [np.complex64, np.complex128]
# Every dtype exercised by the parameterized round-trip test.
dlpack_dtypes = int_dtypes + float_dtypes + [dtypes.bfloat16]

# Shapes for the round-trip test: a scalar, small dense shapes, and
# zero-element shapes.
testcase_shapes = [(), (1,), (2, 3), (2, 0), (0, 7), (4, 1, 2)]


def FormatShapeAndDtype(shape, dtype):
  """Returns a test-name suffix of the form `_<dtype>[<d0>,<d1>,...]`.

  Args:
    shape: Tuple of ints describing the tensor shape.
    dtype: The dtype; rendered with `str()`.

  Returns:
    A string such as `_<dtype>[2,3]`, or `_<dtype>[]` for a scalar shape.
  """
  return "_{}[{}]".format(str(dtype), ",".join(map(str, shape)))


def GetNamedTestParameters():
  """Builds one named parameter set per (dtype, shape) combination.

  Returns:
    A list of dicts with keys `testcase_name`, `dtype` and `shape`, suitable
    for `parameterized.named_parameters`. Entries are ordered by dtype first,
    then by shape.
  """
  return [{
      "testcase_name": FormatShapeAndDtype(shape, dtype),
      "dtype": dtype,
      "shape": shape
  } for dtype in dlpack_dtypes for shape in testcase_shapes]


class DLPackTest(parameterized.TestCase, test.TestCase):

  @parameterized.named_parameters(GetNamedTestParameters())
  def testRoundTrip(self, dtype, shape):
    """to_dlpack followed by from_dlpack preserves values and device."""
    # A local RandomState seeded with 42 yields the same values as seeding the
    # global NumPy RNG, without mutating global state shared with other tests.
    np_array = np.random.RandomState(42).randint(0, 10, shape)
    # identity() copies the tensor to the default device (a GPU if available).
    tf_tensor = array_ops.identity(constant_op.constant(np_array, dtype=dtype))
    tf_tensor_device = tf_tensor.device
    tf_tensor_dtype = tf_tensor.dtype
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule must stay valid after the source tensor is released.
    del tf_tensor
    tf_tensor2 = dlpack.from_dlpack(dlcapsule)
    self.assertAllClose(np_array, tf_tensor2)
    if tf_tensor_dtype == dtypes.int32:
      # int32 tensors are currently always placed on CPU, so the round-tripped
      # tensor is expected there regardless of the source device.
      self.assertEqual(tf_tensor2.device,
                       "/job:localhost/replica:0/task:0/device:CPU:0")
    else:
      self.assertEqual(tf_tensor_device, tf_tensor2.device)

  def testTensorsCanBeConsumedOnceOnly(self):
    """A DLPack capsule may be passed to from_dlpack at most once."""
    np_array = np.random.RandomState(42).randint(0, 10, (2, 3, 4))
    tf_tensor = constant_op.constant(np_array, dtype=np.float32)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # The capsule must stay valid after the source tensor is released.
    del tf_tensor
    _ = dlpack.from_dlpack(dlcapsule)

    # A capsule can be consumed only once; a second from_dlpack must fail.
    with self.assertRaisesRegex(
        Exception, ".*a DLPack tensor may be consumed at most once.*"):
      dlpack.from_dlpack(dlcapsule)

  def testDLPackFromWithoutContextInitialization(self):
    """from_dlpack succeeds after the eager context has been reset."""
    tf_tensor = constant_op.constant(1)
    dlcapsule = dlpack.to_dlpack(tf_tensor)
    # Resetting the context doesn't cause an error.
    context._reset_context()  # pylint: disable=protected-access
    _ = dlpack.from_dlpack(dlcapsule)

  def testUnsupportedTypeToDLPack(self):
    """to_dlpack rejects dtypes that have no DLPack representation."""
    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      tf_tensor = constant_op.constant([[1, 4], [5, 2]], dtype=dtypes.qint16)
      _ = dlpack.to_dlpack(tf_tensor)

    with self.assertRaisesRegex(Exception, ".* is not supported by dlpack"):
      tf_tensor = constant_op.constant(
          [[1, 4], [5, 2]], dtype=dtypes.complex64)
      _ = dlpack.to_dlpack(tf_tensor)

  def testMustPassTensorArgumentToDLPack(self):
    """to_dlpack raises InvalidArgumentError for non-tensor arguments."""
    with self.assertRaisesRegex(
        errors.InvalidArgumentError,
        "The argument to `to_dlpack` must be a TF tensor, not Python object"):
      dlpack.to_dlpack([1])


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()