• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# RUN: %PYTHON %s | FileCheck %s
2
3import gc
4from mlir.ir import *
5
def run(f):
  """Execute one test case, then verify that no MLIR contexts leaked.

  Prints a 'TEST: <name>' banner (preceded by a blank line) that the file's
  label directives anchor on, invokes ``f``, forces a garbage collection and
  asserts that ``Context._get_live_count()`` reports zero live contexts.
  """
  banner = "\nTEST:"
  print(banner, f.__name__)
  f()
  # Force a collection so objects kept alive only by reference cycles are
  # freed before the live-context count is taken.
  gc.collect()
  leaked = Context._get_live_count()
  assert leaked == 0
11
12
13# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
  """Parse a string attribute and print it after dropping the context ref.

  The local reference to the context is cleared and a collection forced
  before printing, so the attribute itself must keep its context usable.
  """
  with Context() as context:
    attr = Attribute.parse('"hello"')
  assert attr.context is context
  # Release our own handle; from here on only `attr` refers to the context.
  context = None
  gc.collect()
  # CHECK: "hello"
  print(str(attr))
  # CHECK: Attribute("hello")
  print(repr(attr))

run(testParsePrint)
26
27
28# CHECK-LABEL: TEST: testParseError
29# TODO: Hook the diagnostic manager to capture a more meaningful error
30# message.
def testParseError():
  """Parsing an unknown attribute must raise ValueError quoting the input."""
  with Context():
    try:
      # The result is never needed: this call is expected to raise.
      Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as err:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", err)
    else:
      print("Exception not produced")

run(testParseError)
42
43
44# CHECK-LABEL: TEST: testAttrEq
def testAttrEq():
  """Attributes parsed from identical text compare equal; others do not."""
  with Context():
    first = Attribute.parse('"attr1"')
    different = Attribute.parse('"attr2"')
    same_text = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", first == first)
    # CHECK: a1 == a2: False
    print("a1 == a2:", first == different)
    # CHECK: a1 == a3: True
    print("a1 == a3:", first == same_text)
    # `== None` is deliberate here: it exercises the binding's __eq__.
    # CHECK: a1 == None: False
    print("a1 == None:", first == None)

run(testAttrEq)
60
61
62# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
def testAttrEqDoesNotRaise():
  """Comparing an Attribute with non-Attribute operands yields a bool."""
  with Context():
    attr = Attribute.parse('"attr1"')
    # Each comparison below must print a result rather than raise.
    # CHECK: False
    print(attr == "foo")
    # CHECK: False
    print(attr == None)
    # CHECK: True
    print(attr != None)

run(testAttrEqDoesNotRaise)
75
76
77# CHECK-LABEL: TEST: testAttrCapsule
def testAttrCapsule():
  """Round-trip an attribute through its C API capsule."""
  with Context() as context:
    original = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  capsule = original._CAPIPtr
  print(capsule)
  rebuilt = Attribute._CAPICreate(capsule)
  # The rebuilt wrapper must equal the original and share its context object.
  assert rebuilt == original
  assert rebuilt.context is context

run(testAttrCapsule)
89
90
91# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
  """Cast generic attributes to StringAttr and reject incompatible ones."""
  with Context():
    generic = Attribute.parse('"attr1"')
    as_string = StringAttr(generic)
    # Constructing a StringAttr from an existing StringAttr must also work.
    StringAttr(as_string)
    # CHECK: Attribute("attr1")
    print(repr(as_string))
    try:
      # A float attribute is not a string attribute: the cast must fail.
      StringAttr(Attribute.parse("1.0"))
    except ValueError as err:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", err)
    else:
      print("Exception not produced")

run(testStandardAttrCasts)
108
109
110# CHECK-LABEL: TEST: testFloatAttr
def testFloatAttr():
  """Exercise FloatAttr casting, value access and its factory methods."""
  with Context(), Location.unknown():
    parsed = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", parsed.value)

    # Factory methods: explicit element type, then the f32/f64 shortcuts.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))

    # A non-floating-point type must be rejected with ValueError.
    try:
      i32 = IntegerType.get_signless(32)
      FloatAttr.get(i32, 42)
    except ValueError as err:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(err)
    else:
      print("Exception not produced")

run(testFloatAttr)
135
136
137# CHECK-LABEL: TEST: testIntegerAttr
def testIntegerAttr():
  """Exercise IntegerAttr casting, value/type access and its factory."""
  with Context():
    parsed = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", parsed.value)
    # CHECK: iattr type: i64
    print("iattr type:", parsed.type)

    # Factory method with an explicit signless 32-bit integer type.
    i32 = IntegerType.get_signless(32)
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(i32, 42))

run(testIntegerAttr)
152
153
154# CHECK-LABEL: TEST: testBoolAttr
def testBoolAttr():
  """Exercise BoolAttr casting, value access and its factory method."""
  with Context():
    battr = BoolAttr(Attribute.parse("true"))
    # The label names the BoolAttr under test (it was previously a copy of
    # the IntegerAttr test's label); the directive below must match it.
    # CHECK: battr value: True
    print("battr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))

run(testBoolAttr)
166
167
168# CHECK-LABEL: TEST: testStringAttr
def testStringAttr():
  """Exercise StringAttr casting, value access and its factory methods."""
  with Context():
    parsed = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", parsed.value)

    # Factory methods: untyped first, then typed with a signless i32.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    typed = StringAttr.get_typed(IntegerType.get_signless(32), "12345")
    print("typed_get:", typed)

run(testStringAttr)
183
184
185# CHECK-LABEL: TEST: testNamedAttr
def testNamedAttr():
  """Wrap an attribute in a NamedAttribute and inspect its parts."""
  with Context():
    value = Attribute.parse('"stringattr"')
    # Note: the name "foobar" is under the small object threshold.
    named = value.get_named("foobar")
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)

run(testNamedAttr)
198
199
200# CHECK-LABEL: TEST: testDenseIntAttr
def testDenseIntAttr():
  """Cast dense integer attributes (i32 and i1) and iterate their elements."""
  with Context():
    parsed = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", parsed)

    elements = DenseIntElementsAttr(parsed)
    assert len(elements) == 6

    # Iteration yields the 2x3 values flattened in row-major order; each value
    # is followed by a single space, then the line ends.
    # CHECK: 0 1 2 3 4 5
    print(*elements, end=" \n")

    # CHECK: i32
    print(ShapedType(elements.type).element_type)

    parsed = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", parsed)

    elements = DenseIntElementsAttr(parsed)
    assert len(elements) == 4

    # i1 elements come back as integers (1/0), not as Python booleans.
    # CHECK: 1 0 1 0
    print(*elements, end=" \n")

    # CHECK: i1
    print(ShapedType(elements.type).element_type)


run(testDenseIntAttr)
235
236
237# CHECK-LABEL: TEST: testDenseFPAttr
def testDenseFPAttr():
  """Cast a dense f32 attribute and iterate its elements."""
  with Context():
    parsed = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", parsed)

    elements = DenseFPElementsAttr(parsed)
    assert len(elements) == 4

    # Each element is followed by a single space, then the line ends.
    # CHECK: 0.0 1.0 2.0 3.0
    print(*elements, end=" \n")

    # CHECK: f32
    print(ShapedType(elements.type).element_type)


run(testDenseFPAttr)
258
259
260# CHECK-LABEL: TEST: testTypeAttr
def testTypeAttr():
  """Wrap a type in a TypeAttr and read the wrapped type back."""
  with Context():
    parsed = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", parsed)
    wrapped_type = TypeAttr(parsed).value
    # CHECK: f32
    print(ShapedType(wrapped_type).element_type)


run(testTypeAttr)
272