# RUN: %PYTHON %s | FileCheck %s
# Note that this is separate from ir_attributes.py since it depends on numpy,
# and we may want to disable it if numpy is not available.
4
5import gc
6from mlir.ir import *
7import numpy as np
8
def run(f):
  """Runs one test function and checks that no MLIR Context outlives it.

  Prints a "TEST: <name>" banner (the text matched by the CHECK-LABEL
  directives in this file), calls `f`, forces a garbage collection, and
  asserts that the live Context count has dropped to zero.

  Args:
    f: Zero-argument test callable.

  Returns:
    `f` itself, so that `run` can also be applied as a decorator (`@run`).

  Raises:
    AssertionError: if a Context is still alive after the collection.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect before counting so that unreachable cycles do not show up as leaks.
  gc.collect()
  assert Context._get_live_count() == 0
  return f
14
################################################################################
# Tests of the array/buffer .get() factory method on an unsupported dtype.
################################################################################
18
# CHECK-LABEL: TEST: testGetDenseElementsUnsupported
def testGetDenseElementsUnsupported():
  """Checks that DenseElementsAttr.get() rejects a numpy array of strings.

  The ValueError message is printed so FileCheck can match it. If no
  exception is raised nothing is printed and the CHECK line below fails.
  """
  with Context():
    array = np.array([["hello", "goodbye"]])
    try:
      # Only the raised error matters here; the result is deliberately unused.
      DenseElementsAttr.get(array)
    except ValueError as e:
      # CHECK: unimplemented array format conversion from format:
      print(e)

run(testGetDenseElementsUnsupported)
29
################################################################################
# Splats.
################################################################################
33
# CHECK-LABEL: TEST: testGetDenseElementsSplatInt
def testGetDenseElementsSplatInt():
  """Builds an i32 splat over a 2x3x4 tensor and prints it and its is_splat."""
  with Context():
    with Location.unknown():
      i32 = IntegerType.get_signless(32)
      splat_value = IntegerAttr.get(i32, 555)
      tensor_type = RankedTensorType.get((2, 3, 4), i32)
      splat = DenseElementsAttr.get_splat(tensor_type, splat_value)
      # CHECK: dense<555> : tensor<2x3x4xi32>
      print(splat)
      # CHECK: is_splat: True
      print("is_splat:", splat.is_splat)

run(testGetDenseElementsSplatInt)
47
48
# CHECK-LABEL: TEST: testGetDenseElementsSplatFloat
def testGetDenseElementsSplatFloat():
  """Builds an f32 splat of 1.2 over a 2x3x4 tensor and prints it."""
  with Context():
    with Location.unknown():
      f32 = F32Type.get()
      splat_value = FloatAttr.get(f32, 1.2)
      tensor_type = RankedTensorType.get((2, 3, 4), f32)
      splat = DenseElementsAttr.get_splat(tensor_type, splat_value)
      # CHECK: dense<1.200000e+00> : tensor<2x3x4xf32>
      print(splat)

run(testGetDenseElementsSplatFloat)
60
61
# CHECK-LABEL: TEST: testGetDenseElementsSplatErrors
def testGetDenseElementsSplatErrors():
  """Prints the ValueError raised by each invalid get_splat() call.

  Three cases are exercised: a type that is not a shaped type, an unranked
  tensor type, and an element attribute whose type differs from the element
  type of the shaped type.
  """
  with Context(), Location.unknown():
    f32 = F32Type.get()
    f64 = F64Type.get()
    f32_value = FloatAttr.get(f32, 1.2)
    f64_value = FloatAttr.get(f64, 1.2)
    static_tensor = RankedTensorType.get((2, 3, 4), f32)
    unranked_tensor = UnrankedTensorType.get(f32)

    # Case 1: the bare element type is passed where a shaped type is expected.
    try:
      DenseElementsAttr.get_splat(f32, f32_value)
    except ValueError as err:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(f32)
      print(err)

    # Case 2: an unranked tensor type.
    try:
      DenseElementsAttr.get_splat(unranked_tensor, f32_value)
    except ValueError as err:
      # CHECK: Expected a static ShapedType for the shaped_type parameter: Type(tensor<*xf32>)
      print(err)

    # Case 3: an f64 element for an f32 tensor.
    try:
      DenseElementsAttr.get_splat(static_tensor, f64_value)
    except ValueError as err:
      # CHECK: Shaped element type and attribute type must be equal: shaped=Type(tensor<2x3x4xf32>), element=Attribute(1.200000e+00 : f64)
      print(err)

run(testGetDenseElementsSplatErrors)
92
93
################################################################################
# Tests of the array/buffer .get() factory method, in all of its permutations.
################################################################################

### float and double arrays.
99
# CHECK-LABEL: TEST: testGetDenseElementsF32
def testGetDenseElementsF32():
  """Converts a 2x3 float32 numpy array to a DenseElementsAttr and back.

  Prints the attribute, its is_splat flag, and the result of np.array() on it.
  """
  values = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float32)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf32>
    print(dense)
    # CHECK: is_splat: False
    print("is_splat:", dense.is_splat)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(np.array(dense))

run(testGetDenseElementsF32)
114
115
# CHECK-LABEL: TEST: testGetDenseElementsF64
def testGetDenseElementsF64():
  """Converts a 2x3 float64 numpy array to a DenseElementsAttr and back.

  Prints the attribute and the result of np.array() on it.
  """
  values = np.array([[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]], dtype=np.float64)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1.100000e+00, 2.200000e+00, 3.300000e+00], [4.400000e+00, 5.500000e+00, 6.600000e+00]]> : tensor<2x3xf64>
    print(dense)
    # CHECK: {{\[}}[1.1 2.2 3.3]
    # CHECK: {{\[}}4.4 5.5 6.6]]
    print(np.array(dense))

run(testGetDenseElementsF64)
128
129
### 32 bit integer arrays.
# CHECK-LABEL: TEST: testGetDenseElementsI32Signless
def testGetDenseElementsI32Signless():
  """Converts an int32 numpy array without passing `signless`.

  The expected element type is the signless i32; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsI32Signless)
143
144
# CHECK-LABEL: TEST: testGetDenseElementsUI32Signless
def testGetDenseElementsUI32Signless():
  """Converts a uint32 numpy array without passing `signless`.

  The expected element type is the signless i32; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsUI32Signless)
157
# CHECK-LABEL: TEST: testGetDenseElementsI32
def testGetDenseElementsI32():
  """Converts an int32 numpy array with signless=False.

  The expected element type is the signed si32; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)
  with Context():
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsI32)
170
171
# CHECK-LABEL: TEST: testGetDenseElementsUI32
def testGetDenseElementsUI32():
  """Converts a uint32 numpy array with signless=False.

  The expected element type is the unsigned ui32; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint32)
  with Context():
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui32>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsUI32)
184
185
### 64 bit integer arrays.
# CHECK-LABEL: TEST: testGetDenseElementsI64Signless
def testGetDenseElementsI64Signless():
  """Converts an int64 numpy array without passing `signless`.

  The expected element type is the signless i64; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsI64Signless)
199
200
# CHECK-LABEL: TEST: testGetDenseElementsUI64Signless
def testGetDenseElementsUI64Signless():
  """Converts a uint64 numpy array without passing `signless`.

  The expected element type is the signless i64; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
  with Context():
    dense = DenseElementsAttr.get(values)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsUI64Signless)
213
# CHECK-LABEL: TEST: testGetDenseElementsI64
def testGetDenseElementsI64():
  """Converts an int64 numpy array with signless=False.

  The expected element type is the signed si64; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64)
  with Context():
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xsi64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsI64)
226
227
# CHECK-LABEL: TEST: testGetDenseElementsUI64
def testGetDenseElementsUI64():
  """Converts a uint64 numpy array with signless=False.

  The expected element type is the unsigned ui64; the attribute and its
  np.array() conversion are printed.
  """
  values = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint64)
  with Context():
    dense = DenseElementsAttr.get(values, signless=False)
    # CHECK: dense<{{\[}}[1, 2, 3], [4, 5, 6]]> : tensor<2x3xui64>
    print(dense)
    # CHECK: {{\[}}[1 2 3]
    # CHECK: {{\[}}4 5 6]]
    print(np.array(dense))

run(testGetDenseElementsUI64)
240
241