# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable if not available.

import gc
from mlir.ir import *
import numpy as np

def run(f):
  """Print a "TEST: <name>" banner, call `f`, then check for leaked contexts.

  The banner is what each test's label directive anchors on. After `f`
  returns, a full garbage collection is forced and
  `Context._get_live_count()` is asserted to be zero (presumably meaning
  the test left no live MLIR Context behind).
  """
  print("\nTEST:", f.__name__)
  f()
  gc.collect()
  assert Context._get_live_count() == 0

################################################################################
# Tests of the array/buffer .get() factory method on unsupported dtype.
################################################################################

# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
def testGetDenseElementsUnsupported():
  """A numpy string array cannot be converted and must raise ValueError."""
  with Context():
    array = np.array([["hello", "goodbye"]])
    try:
      DenseElementsAttr.get(array)
    except ValueError as e:
      # CHECK: unimplemented array format conversion from format:
      print(e)

run(testGetDenseElementsUnsupported)

################################################################################
# Splats.
################################################################################

# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
def testGetDenseElementsSplatInt():
  """Splat a single i32 IntegerAttr across a static 2x3x4 tensor type."""
  with Context(), Location.unknown():
    t = IntegerType.get_signless(32)
    element = IntegerAttr.get(t, 555)
    shaped_type = RankedTensorType.get((2, 3, 4), t)
    attr = DenseElementsAttr.get_splat(shaped_type, element)
    # CHECK: dense<555> : tensor<2x3x4xi32>
    print(attr)
    # CHECK: is_splat: True
    print("is_splat:", attr.is_splat)

run(testGetDenseElementsSplatInt)


# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
def testGetDenseElementsSplatFloat():
  """Splat a single f32 FloatAttr across a static 2x3x4 tensor type."""
  with Context(), Location.unknown():
    t = F32Type.get()
    element = FloatAttr.get(t, 1.2)
    shaped_type = RankedTensorType.get((2, 3, 4), t)
    attr = DenseElementsAttr.get_splat(shaped_type, element)
    # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
    print(attr)

run(testGetDenseElementsSplatFloat)


# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
def testGetDenseElementsSplatErrors():
  """get_splat must reject bad shaped types and element-type mismatches.

  Three failure modes are exercised, each expected to raise ValueError:
  a non-shaped type, an unranked tensor type, and an element attribute
  whose type (f64) differs from the shaped type's element type (f32).
  """
  with Context(), Location.unknown():
    t = F32Type.get()
    other_t = F64Type.get()
    element = FloatAttr.get(t, 1.2)
    other_element = FloatAttr.get(other_t, 1.2)
    shaped_type = RankedTensorType.get((2, 3, 4), t)
    # Unranked (tensor<*xf32>), hence not a static shape.
    unranked_type = UnrankedTensorType.get(t)
    non_shaped_type = t

    try:
      DenseElementsAttr.get_splat(non_shaped_type, element)
    except ValueError as e:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
      print(e)

    try:
      DenseElementsAttr.get_splat(unranked_type, element)
    except ValueError as e:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
      print(e)

    try:
      DenseElementsAttr.get_splat(shaped_type, other_element)
    except ValueError as e:
      # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
      print(e)

run(testGetDenseElementsSplatErrors)


################################################################################
# Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################

### float and double arrays.
# CHECK-LABEL: TEST: testGetDenseElementsF32
def testGetDenseElementsF32():
  """Build a 2x3 f32 DenseElementsAttr from numpy and convert it back."""
  with Context():
    values = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
    print(dense)
    # CHECK: is_splat: False
    print("is_splat:", dense.is_splat)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(round_tripped)

run(testGetDenseElementsF32)


# CHECK-LABEL: TEST: testGetDenseElementsF64
def testGetDenseElementsF64():
  """Build a 2x3 f64 DenseElementsAttr from numpy and convert it back."""
  with Context():
    values = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(round_tripped)

run(testGetDenseElementsF64)


### 32-bit integer arrays.

# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
def testGetDenseElementsI32Signless():
  """int32 input with the default signedness maps to a signless i32 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsI32Signless)


# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
def testGetDenseElementsUI32Signless():
  """uint32 input with the default signedness also maps to signless i32."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsUI32Signless)


# CHECK-LABEL: TEST: testGetDenseElementsI32
def testGetDenseElementsI32():
  """int32 input with signless=False maps to a signed si32 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsI32)


# CHECK-LABEL: TEST: testGetDenseElementsUI32
def testGetDenseElementsUI32():
  """uint32 input with signless=False maps to an unsigned ui32 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsUI32)


### 64-bit integer arrays.

# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
def testGetDenseElementsI64Signless():
  """int64 input with the default signedness maps to a signless i64 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsI64Signless)


# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
def testGetDenseElementsUI64Signless():
  """uint64 input with the default signedness also maps to signless i64."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsUI64Signless)


# CHECK-LABEL: TEST: testGetDenseElementsI64
def testGetDenseElementsI64():
  """int64 input with signless=False maps to a signed si64 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsI64)


# CHECK-LABEL: TEST: testGetDenseElementsUI64
def testGetDenseElementsUI64():
  """uint64 input with signless=False maps to an unsigned ui64 tensor."""
  with Context():
    values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
    print(dense)
    round_tripped = np.array(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(round_tripped)

run(testGetDenseElementsUI64)