# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *


def run(f):
  """Run one test function and verify that it leaked no MLIR Context.

  Prints a "TEST: <name>" header line (the anchor the LABEL directives in
  this file match on), calls ``f``, forces a garbage collection, and asserts
  that ``Context._get_live_count()`` reports zero live contexts afterwards.
  """
  print("\nTEST:", f.__name__)
  f()
  # Collect first so contexts held only by reference cycles are released
  # before the live count is inspected.
  gc.collect()
  assert Context._get_live_count() == 0


# CHECK-LABEL: TEST: testParsePrint
def testParsePrint():
  """Parse a string attribute and print it via str() and repr()."""
  with Context() as ctx:
    t = Attribute.parse('"hello"')
  assert t.context is ctx
  # Drop the local Context reference and force a collection before using the
  # attribute again.
  ctx = None
  gc.collect()
  # CHECK: "hello"
  print(str(t))
  # CHECK: Attribute("hello")
  print(repr(t))

run(testParsePrint)


# CHECK-LABEL: TEST: testParseError
# TODO: Hook the diagnostic manager to capture a more meaningful error
# message.
def testParseError():
  """Parsing malformed attribute text must raise ValueError."""
  with Context():
    try:
      Attribute.parse("BAD_ATTR_DOES_NOT_EXIST")
    except ValueError as e:
      # CHECK: Unable to parse attribute: 'BAD_ATTR_DOES_NOT_EXIST'
      print("testParseError:", e)
    else:
      print("Exception not produced")

run(testParseError)


# CHECK-LABEL: TEST: testAttrEq
def testAttrEq():
  """Attributes parsed from identical text compare equal; others do not."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    a2 = Attribute.parse('"attr2"')
    a3 = Attribute.parse('"attr1"')
    # CHECK: a1 == a1: True
    print("a1 == a1:", a1 == a1)
    # CHECK: a1 == a2: False
    print("a1 == a2:", a1 == a2)
    # CHECK: a1 == a3: True
    print("a1 == a3:", a1 == a3)
    # `==` (not `is`) is deliberate: it exercises Attribute.__eq__ with None.
    # CHECK: a1 == None: False
    print("a1 == None:", a1 == None)  # noqa: E711

run(testAttrEq)


# CHECK-LABEL: TEST: testAttrEqDoesNotRaise
def testAttrEqDoesNotRaise():
  """Comparing an Attribute with non-attributes yields a bool, not an error."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    not_an_attr = "foo"
    # CHECK: False
    print(a1 == not_an_attr)
    # `==`/`!=` against None are deliberate: they exercise __eq__/__ne__.
    # CHECK: False
    print(a1 == None)  # noqa: E711
    # CHECK: True
    print(a1 != None)  # noqa: E711

run(testAttrEqDoesNotRaise)


# CHECK-LABEL: TEST: testAttrCapsule
def testAttrCapsule():
  """Round-trip an Attribute through its C-API capsule."""
  with Context() as ctx:
    a1 = Attribute.parse('"attr1"')
  # CHECK: mlir.ir.Attribute._CAPIPtr
  attr_capsule = a1._CAPIPtr
  print(attr_capsule)
  # Rebuilding from the capsule must give an equal attribute bound to the
  # same Context object.
  a2 = Attribute._CAPICreate(attr_capsule)
  assert a2 == a1
  assert a2.context is ctx
run(testAttrCapsule)


# CHECK-LABEL: TEST: testStandardAttrCasts
def testStandardAttrCasts():
  """Cast a generic Attribute to StringAttr and reject an invalid cast."""
  with Context():
    a1 = Attribute.parse('"attr1"')
    astr = StringAttr(a1)
    # Casting an already-typed StringAttr must also be accepted.
    StringAttr(astr)
    # CHECK: Attribute("attr1")
    print(repr(astr))
    try:
      StringAttr(Attribute.parse("1.0"))
    except ValueError as e:
      # CHECK: ValueError: Cannot cast attribute to StringAttr (from Attribute(1.000000e+00 : f64))
      print("ValueError:", e)
    else:
      print("Exception not produced")

run(testStandardAttrCasts)


# CHECK-LABEL: TEST: testFloatAttr
def testFloatAttr():
  """Read a parsed FloatAttr and exercise the FloatAttr factory methods."""
  with Context(), Location.unknown():
    fattr = FloatAttr(Attribute.parse("42.0 : f32"))
    # CHECK: fattr value: 42.0
    print("fattr value:", fattr.value)

    # Test factory methods.
    # CHECK: default_get: 4.200000e+01 : f32
    print("default_get:", FloatAttr.get(
        F32Type.get(), 42.0))
    # CHECK: f32_get: 4.200000e+01 : f32
    print("f32_get:", FloatAttr.get_f32(42.0))
    # CHECK: f64_get: 4.200000e+01 : f64
    print("f64_get:", FloatAttr.get_f64(42.0))
    # A non-floating-point element type must be rejected.
    try:
      FloatAttr.get(IntegerType.get_signless(32), 42)
    except ValueError as e:
      # CHECK: invalid 'Type(i32)' and expected floating point type.
      print(e)
    else:
      print("Exception not produced")

run(testFloatAttr)


# CHECK-LABEL: TEST: testIntegerAttr
def testIntegerAttr():
  """Read a parsed IntegerAttr and build one with IntegerAttr.get."""
  with Context():
    # An integer literal with no explicit type is expected to parse as i64.
    iattr = IntegerAttr(Attribute.parse("42"))
    # CHECK: iattr value: 42
    print("iattr value:", iattr.value)
    # CHECK: iattr type: i64
    print("iattr type:", iattr.type)

    # Test factory methods.
    # CHECK: default_get: 42 : i32
    print("default_get:", IntegerAttr.get(
        IntegerType.get_signless(32), 42))

run(testIntegerAttr)


# CHECK-LABEL: TEST: testBoolAttr
def testBoolAttr():
  """Read a parsed BoolAttr and build one with BoolAttr.get."""
  with Context():
    battr = BoolAttr(Attribute.parse("true"))
    # CHECK: battr value: True
    print("battr value:", battr.value)

    # Test factory methods.
    # CHECK: default_get: true
    print("default_get:", BoolAttr.get(True))

run(testBoolAttr)


# CHECK-LABEL: TEST: testStringAttr
def testStringAttr():
  """Read a parsed StringAttr and build untyped and typed StringAttrs."""
  with Context():
    sattr = StringAttr(Attribute.parse('"stringattr"'))
    # CHECK: sattr value: stringattr
    print("sattr value:", sattr.value)

    # Test factory methods.
    # CHECK: default_get: "foobar"
    print("default_get:", StringAttr.get("foobar"))
    # CHECK: typed_get: "12345" : i32
    print("typed_get:", StringAttr.get_typed(
        IntegerType.get_signless(32), "12345"))

run(testStringAttr)


# CHECK-LABEL: TEST: testNamedAttr
def testNamedAttr():
  """Wrap an attribute in a NamedAttribute and read back its parts."""
  with Context():
    a = Attribute.parse('"stringattr"')
    # Note: the name is short, presumably kept under the bindings' small
    # object threshold -- confirm against the binding implementation.
    named = a.get_named("foobar")
    # CHECK: attr: "stringattr"
    print("attr:", named.attr)
    # CHECK: name: foobar
    print("name:", named.name)
    # CHECK: named: NamedAttribute(foobar="stringattr")
    print("named:", named)

run(testNamedAttr)


# CHECK-LABEL: TEST: testDenseIntAttr
def testDenseIntAttr():
  """Iterate dense integer elements for both an i32 and an i1 vector."""
  with Context():
    raw = Attribute.parse("dense<[[0,1,2],[3,4,5]]> : vector<2x3xi32>")
    # CHECK: attr: dense<[{{\[}}0, 1, 2], [3, 4, 5]]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    # Length counts all elements of the 2x3 shape.
    assert len(a) == 6

    # CHECK: 0 1 2 3 4 5
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i32
    print(ShapedType(a.type).element_type)

    raw = Attribute.parse("dense<[true,false,true,false]> : vector<4xi1>")
    # CHECK: attr: dense<[true, false, true, false]>
    print("attr:", raw)

    a = DenseIntElementsAttr(raw)
    assert len(a) == 4

    # i1 elements iterate as integers 1/0, not booleans.
    # CHECK: 1 0 1 0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: i1
    print(ShapedType(a.type).element_type)


run(testDenseIntAttr)


# CHECK-LABEL: TEST: testDenseFPAttr
def testDenseFPAttr():
  """Iterate dense floating-point elements of an f32 vector."""
  with Context():
    raw = Attribute.parse("dense<[0.0, 1.0, 2.0, 3.0]> : vector<4xf32>")
    # CHECK: attr: dense<[0.000000e+00, 1.000000e+00, 2.000000e+00, 3.000000e+00]>
    print("attr:", raw)

    a = DenseFPElementsAttr(raw)
    assert len(a) == 4

    # CHECK: 0.0 1.0 2.0 3.0
    for value in a:
      print(value, end=" ")
    print()

    # CHECK: f32
    print(ShapedType(a.type).element_type)


run(testDenseFPAttr)


# CHECK-LABEL: TEST: testTypeAttr
def testTypeAttr():
  """Parse a type as an attribute and unwrap it through TypeAttr."""
  with Context():
    raw = Attribute.parse("vector<4xf32>")
    # CHECK: attr: vector<4xf32>
    print("attr:", raw)
    type_attr = TypeAttr(raw)
    # CHECK: f32
    print(ShapedType(type_attr.value).element_type)


run(testTypeAttr)