1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Endian.h"
20
21 using namespace mlir;
22 using namespace mlir::detail;
23
24 /// Parse an arbitrary attribute.
25 ///
26 /// attribute-value ::= `unit`
27 /// | bool-literal
28 /// | integer-literal (`:` (index-type | integer-type))?
29 /// | float-literal (`:` float-type)?
30 /// | string-literal (`:` type)?
31 /// | type
32 /// | `[` (attribute-value (`,` attribute-value)*)? `]`
33 /// | `{` (attribute-entry (`,` attribute-entry)*)? `}`
34 /// | symbol-ref-id (`::` symbol-ref-id)*
35 /// | `dense` `<` attribute-value `>` `:`
36 /// (tensor-type | vector-type)
37 /// | `sparse` `<` attribute-value `,` attribute-value `>`
38 /// `:` (tensor-type | vector-type)
39 /// | `opaque` `<` dialect-namespace `,` hex-string-literal
40 /// `>` `:` (tensor-type | vector-type)
41 /// | extended-attribute
42 ///
parseAttribute(Type type)43 Attribute Parser::parseAttribute(Type type) {
44 switch (getToken().getKind()) {
45 // Parse an AffineMap or IntegerSet attribute.
46 case Token::kw_affine_map: {
47 consumeToken(Token::kw_affine_map);
48
49 AffineMap map;
50 if (parseToken(Token::less, "expected '<' in affine map") ||
51 parseAffineMapReference(map) ||
52 parseToken(Token::greater, "expected '>' in affine map"))
53 return Attribute();
54 return AffineMapAttr::get(map);
55 }
56 case Token::kw_affine_set: {
57 consumeToken(Token::kw_affine_set);
58
59 IntegerSet set;
60 if (parseToken(Token::less, "expected '<' in integer set") ||
61 parseIntegerSetReference(set) ||
62 parseToken(Token::greater, "expected '>' in integer set"))
63 return Attribute();
64 return IntegerSetAttr::get(set);
65 }
66
67 // Parse an array attribute.
68 case Token::l_square: {
69 consumeToken(Token::l_square);
70
71 SmallVector<Attribute, 4> elements;
72 auto parseElt = [&]() -> ParseResult {
73 elements.push_back(parseAttribute());
74 return elements.back() ? success() : failure();
75 };
76
77 if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
78 return nullptr;
79 return builder.getArrayAttr(elements);
80 }
81
82 // Parse a boolean attribute.
83 case Token::kw_false:
84 consumeToken(Token::kw_false);
85 return builder.getBoolAttr(false);
86 case Token::kw_true:
87 consumeToken(Token::kw_true);
88 return builder.getBoolAttr(true);
89
90 // Parse a dense elements attribute.
91 case Token::kw_dense:
92 return parseDenseElementsAttr(type);
93
94 // Parse a dictionary attribute.
95 case Token::l_brace: {
96 NamedAttrList elements;
97 if (parseAttributeDict(elements))
98 return nullptr;
99 return elements.getDictionary(getContext());
100 }
101
102 // Parse an extended attribute, i.e. alias or dialect attribute.
103 case Token::hash_identifier:
104 return parseExtendedAttr(type);
105
106 // Parse floating point and integer attributes.
107 case Token::floatliteral:
108 return parseFloatAttr(type, /*isNegative=*/false);
109 case Token::integer:
110 return parseDecOrHexAttr(type, /*isNegative=*/false);
111 case Token::minus: {
112 consumeToken(Token::minus);
113 if (getToken().is(Token::integer))
114 return parseDecOrHexAttr(type, /*isNegative=*/true);
115 if (getToken().is(Token::floatliteral))
116 return parseFloatAttr(type, /*isNegative=*/true);
117
118 return (emitError("expected constant integer or floating point value"),
119 nullptr);
120 }
121
122 // Parse a location attribute.
123 case Token::kw_loc: {
124 consumeToken(Token::kw_loc);
125
126 LocationAttr locAttr;
127 if (parseToken(Token::l_paren, "expected '(' in inline location") ||
128 parseLocationInstance(locAttr) ||
129 parseToken(Token::r_paren, "expected ')' in inline location"))
130 return Attribute();
131 return locAttr;
132 }
133
134 // Parse an opaque elements attribute.
135 case Token::kw_opaque:
136 return parseOpaqueElementsAttr(type);
137
138 // Parse a sparse elements attribute.
139 case Token::kw_sparse:
140 return parseSparseElementsAttr(type);
141
142 // Parse a string attribute.
143 case Token::string: {
144 auto val = getToken().getStringValue();
145 consumeToken(Token::string);
146 // Parse the optional trailing colon type if one wasn't explicitly provided.
147 if (!type && consumeIf(Token::colon) && !(type = parseType()))
148 return Attribute();
149
150 return type ? StringAttr::get(val, type)
151 : StringAttr::get(val, getContext());
152 }
153
154 // Parse a symbol reference attribute.
155 case Token::at_identifier: {
156 std::string nameStr = getToken().getSymbolReference();
157 consumeToken(Token::at_identifier);
158
159 // Parse any nested references.
160 std::vector<FlatSymbolRefAttr> nestedRefs;
161 while (getToken().is(Token::colon)) {
162 // Check for the '::' prefix.
163 const char *curPointer = getToken().getLoc().getPointer();
164 consumeToken(Token::colon);
165 if (!consumeIf(Token::colon)) {
166 state.lex.resetPointer(curPointer);
167 consumeToken();
168 break;
169 }
170 // Parse the reference itself.
171 auto curLoc = getToken().getLoc();
172 if (getToken().isNot(Token::at_identifier)) {
173 emitError(curLoc, "expected nested symbol reference identifier");
174 return Attribute();
175 }
176
177 std::string nameStr = getToken().getSymbolReference();
178 consumeToken(Token::at_identifier);
179 nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
180 }
181
182 return builder.getSymbolRefAttr(nameStr, nestedRefs);
183 }
184
185 // Parse a 'unit' attribute.
186 case Token::kw_unit:
187 consumeToken(Token::kw_unit);
188 return builder.getUnitAttr();
189
190 default:
191 // Parse a type attribute.
192 if (Type type = parseType())
193 return TypeAttr::get(type);
194 return nullptr;
195 }
196 }
197
198 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)199 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
200 Type type) {
201 switch (getToken().getKind()) {
202 case Token::at_identifier:
203 case Token::floatliteral:
204 case Token::integer:
205 case Token::hash_identifier:
206 case Token::kw_affine_map:
207 case Token::kw_affine_set:
208 case Token::kw_dense:
209 case Token::kw_false:
210 case Token::kw_loc:
211 case Token::kw_opaque:
212 case Token::kw_sparse:
213 case Token::kw_true:
214 case Token::kw_unit:
215 case Token::l_brace:
216 case Token::l_square:
217 case Token::minus:
218 case Token::string:
219 attribute = parseAttribute(type);
220 return success(attribute != nullptr);
221
222 default:
223 // Parse an optional type attribute.
224 Type type;
225 OptionalParseResult result = parseOptionalType(type);
226 if (result.hasValue() && succeeded(*result))
227 attribute = TypeAttr::get(type);
228 return result;
229 }
230 }
parseOptionalAttribute(ArrayAttr & attribute,Type type)231 OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
232 Type type) {
233 return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
234 }
parseOptionalAttribute(StringAttr & attribute,Type type)235 OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
236 Type type) {
237 return parseOptionalAttributeWithToken(Token::string, attribute, type);
238 }
239
240 /// Attribute dictionary.
241 ///
242 /// attribute-dict ::= `{` `}`
243 /// | `{` attribute-entry (`,` attribute-entry)* `}`
244 /// attribute-entry ::= (bare-id | string-literal) `=` attribute-value
245 ///
parseAttributeDict(NamedAttrList & attributes)246 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
247 if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
248 return failure();
249
250 llvm::SmallDenseSet<Identifier> seenKeys;
251 auto parseElt = [&]() -> ParseResult {
252 // The name of an attribute can either be a bare identifier, or a string.
253 Optional<Identifier> nameId;
254 if (getToken().is(Token::string))
255 nameId = builder.getIdentifier(getToken().getStringValue());
256 else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
257 getToken().isKeyword())
258 nameId = builder.getIdentifier(getTokenSpelling());
259 else
260 return emitError("expected attribute name");
261 if (!seenKeys.insert(*nameId).second)
262 return emitError("duplicate key '")
263 << *nameId << "' in dictionary attribute";
264 consumeToken();
265
266 // Lazy load a dialect in the context if there is a possible namespace.
267 auto splitName = nameId->strref().split('.');
268 if (!splitName.second.empty())
269 getContext()->getOrLoadDialect(splitName.first);
270
271 // Try to parse the '=' for the attribute value.
272 if (!consumeIf(Token::equal)) {
273 // If there is no '=', we treat this as a unit attribute.
274 attributes.push_back({*nameId, builder.getUnitAttr()});
275 return success();
276 }
277
278 auto attr = parseAttribute();
279 if (!attr)
280 return failure();
281 attributes.push_back({*nameId, attr});
282 return success();
283 };
284
285 if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
286 return failure();
287
288 return success();
289 }
290
291 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)292 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
293 auto val = getToken().getFloatingPointValue();
294 if (!val.hasValue())
295 return (emitError("floating point value too large for attribute"), nullptr);
296 consumeToken(Token::floatliteral);
297 if (!type) {
298 // Default to F64 when no type is specified.
299 if (!consumeIf(Token::colon))
300 type = builder.getF64Type();
301 else if (!(type = parseType()))
302 return nullptr;
303 }
304 if (!type.isa<FloatType>())
305 return (emitError("floating point value not valid for specified type"),
306 nullptr);
307 return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
308 }
309
310 /// Construct a float attribute bitwise equivalent to the integer literal.
buildHexadecimalFloatLiteral(Parser * p,FloatType type,uint64_t value)311 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
312 uint64_t value) {
313 if (type.isF64())
314 return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
315
316 APInt apInt(type.getWidth(), value);
317 if (apInt != value) {
318 p->emitError("hexadecimal float constant out of range for type");
319 return llvm::None;
320 }
321 return APFloat(type.getFloatSemantics(), apInt);
322 }
323
324 /// Construct an APint from a parsed value, a known attribute type and
325 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)326 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
327 StringRef spelling) {
328 // Parse the integer value into an APInt that is big enough to hold the value.
329 APInt result;
330 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
331 if (spelling.getAsInteger(isHex ? 0 : 10, result))
332 return llvm::None;
333
334 // Extend or truncate the bitwidth to the right size.
335 unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
336 : type.getIntOrFloatBitWidth();
337 if (width > result.getBitWidth()) {
338 result = result.zext(width);
339 } else if (width < result.getBitWidth()) {
340 // The parser can return an unnecessarily wide result with leading zeros.
341 // This isn't a problem, but truncating off bits is bad.
342 if (result.countLeadingZeros() < result.getBitWidth() - width)
343 return llvm::None;
344
345 result = result.trunc(width);
346 }
347
348 if (isNegative) {
349 // The value is negative, we have an overflow if the sign bit is not set
350 // in the negated apInt.
351 result.negate();
352 if (!result.isSignBitSet())
353 return llvm::None;
354 } else if ((type.isSignedInteger() || type.isIndex()) &&
355 result.isSignBitSet()) {
356 // The value is a positive signed integer or index,
357 // we have an overflow if the sign bit is set.
358 return llvm::None;
359 }
360
361 return result;
362 }
363
364 /// Parse a decimal or a hexadecimal literal, which can be either an integer
365 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)366 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
367 // Remember if the literal is hexadecimal.
368 StringRef spelling = getToken().getSpelling();
369 auto loc = state.curToken.getLoc();
370 bool isHex = spelling.size() > 1 && spelling[1] == 'x';
371
372 consumeToken(Token::integer);
373 if (!type) {
374 // Default to i64 if not type is specified.
375 if (!consumeIf(Token::colon))
376 type = builder.getIntegerType(64);
377 else if (!(type = parseType()))
378 return nullptr;
379 }
380
381 if (auto floatType = type.dyn_cast<FloatType>()) {
382 if (isNegative)
383 return emitError(
384 loc,
385 "hexadecimal float literal should not have a leading minus"),
386 nullptr;
387 if (!isHex) {
388 emitError(loc, "unexpected decimal integer literal for a float attribute")
389 .attachNote()
390 << "add a trailing dot to make the literal a float";
391 return nullptr;
392 }
393
394 auto val = Token::getUInt64IntegerValue(spelling);
395 if (!val.hasValue())
396 return emitError("integer constant out of range for attribute"), nullptr;
397
398 // Construct a float attribute bitwise equivalent to the integer literal.
399 Optional<APFloat> apVal =
400 buildHexadecimalFloatLiteral(this, floatType, *val);
401 return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
402 }
403
404 if (!type.isa<IntegerType, IndexType>())
405 return emitError(loc, "integer literal not valid for specified type"),
406 nullptr;
407
408 if (isNegative && type.isUnsignedInteger()) {
409 emitError(loc,
410 "negative integer literal not valid for unsigned integer type");
411 return nullptr;
412 }
413
414 Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
415 if (!apInt)
416 return emitError(loc, "integer constant out of range for attribute"),
417 nullptr;
418 return builder.getIntegerAttr(type, *apInt);
419 }
420
421 //===----------------------------------------------------------------------===//
422 // TensorLiteralParser
423 //===----------------------------------------------------------------------===//
424
425 /// Parse elements values stored within a hex string. On success, the values are
426 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)427 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
428 std::string &result) {
429 if (Optional<std::string> value = tok.getHexStringValue()) {
430 result = std::move(*value);
431 return success();
432 }
433 return parser.emitError(
434 tok.getLoc(), "expected string containing hex digits starting with `0x`");
435 }
436
437 namespace {
438 /// This class implements a parser for TensorLiterals. A tensor literal is
439 /// either a single element (e.g, 5) or a multi-dimensional list of elements
440 /// (e.g., [[5, 5]]).
441 class TensorLiteralParser {
442 public:
TensorLiteralParser(Parser & p)443 TensorLiteralParser(Parser &p) : p(p) {}
444
445 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
446 /// may also parse a tensor literal that is store as a hex string.
447 ParseResult parse(bool allowHex);
448
449 /// Build a dense attribute instance with the parsed elements and the given
450 /// shaped type.
451 DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);
452
getShape() const453 ArrayRef<int64_t> getShape() const { return shape; }
454
455 private:
456 /// Get the parsed elements for an integer attribute.
457 ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
458 std::vector<APInt> &intValues);
459
460 /// Get the parsed elements for a float attribute.
461 ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
462 std::vector<APFloat> &floatValues);
463
464 /// Build a Dense String attribute for the given type.
465 DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);
466
467 /// Build a Dense attribute with hex data for the given type.
468 DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);
469
470 /// Parse a single element, returning failure if it isn't a valid element
471 /// literal. For example:
472 /// parseElement(1) -> Success, 1
473 /// parseElement([1]) -> Failure
474 ParseResult parseElement();
475
476 /// Parse a list of either lists or elements, returning the dimensions of the
477 /// parsed sub-tensors in dims. For example:
478 /// parseList([1, 2, 3]) -> Success, [3]
479 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
480 /// parseList([[1, 2], 3]) -> Failure
481 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
482 ParseResult parseList(SmallVectorImpl<int64_t> &dims);
483
484 /// Parse a literal that was printed as a hex string.
485 ParseResult parseHexElements();
486
487 Parser &p;
488
489 /// The shape inferred from the parsed elements.
490 SmallVector<int64_t, 4> shape;
491
492 /// Storage used when parsing elements, this is a pair of <is_negated, token>.
493 std::vector<std::pair<bool, Token>> storage;
494
495 /// Storage used when parsing elements that were stored as hex values.
496 Optional<Token> hexStorage;
497 };
498 } // end anonymous namespace
499
500 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
501 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)502 ParseResult TensorLiteralParser::parse(bool allowHex) {
503 // If hex is allowed, check for a string literal.
504 if (allowHex && p.getToken().is(Token::string)) {
505 hexStorage = p.getToken();
506 p.consumeToken(Token::string);
507 return success();
508 }
509 // Otherwise, parse a list or an individual element.
510 if (p.getToken().is(Token::l_square))
511 return parseList(shape);
512 return parseElement();
513 }
514
515 /// Build a dense attribute instance with the parsed elements and the given
516 /// shaped type.
getAttr(llvm::SMLoc loc,ShapedType type)517 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
518 ShapedType type) {
519 Type eltType = type.getElementType();
520
521 // Check to see if we parse the literal from a hex string.
522 if (hexStorage.hasValue() &&
523 (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
524 return getHexAttr(loc, type);
525
526 // Check that the parsed storage size has the same number of elements to the
527 // type, or is a known splat.
528 if (!shape.empty() && getShape() != type.getShape()) {
529 p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
530 << "]) does not match type ([" << type.getShape() << "])";
531 return nullptr;
532 }
533
534 // Handle complex types in the specific element type cases below.
535 bool isComplex = false;
536 if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
537 eltType = complexTy.getElementType();
538 isComplex = true;
539 }
540
541 // Handle integer and index types.
542 if (eltType.isIntOrIndex()) {
543 std::vector<APInt> intValues;
544 if (failed(getIntAttrElements(loc, eltType, intValues)))
545 return nullptr;
546 if (isComplex) {
547 // If this is a complex, treat the parsed values as complex values.
548 auto complexData = llvm::makeArrayRef(
549 reinterpret_cast<std::complex<APInt> *>(intValues.data()),
550 intValues.size() / 2);
551 return DenseElementsAttr::get(type, complexData);
552 }
553 return DenseElementsAttr::get(type, intValues);
554 }
555 // Handle floating point types.
556 if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
557 std::vector<APFloat> floatValues;
558 if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
559 return nullptr;
560 if (isComplex) {
561 // If this is a complex, treat the parsed values as complex values.
562 auto complexData = llvm::makeArrayRef(
563 reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
564 floatValues.size() / 2);
565 return DenseElementsAttr::get(type, complexData);
566 }
567 return DenseElementsAttr::get(type, floatValues);
568 }
569
570 // Other types are assumed to be string representations.
571 return getStringAttr(loc, type, type.getElementType());
572 }
573
574 /// Build a Dense Integer attribute for the given type.
575 ParseResult
getIntAttrElements(llvm::SMLoc loc,Type eltTy,std::vector<APInt> & intValues)576 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
577 std::vector<APInt> &intValues) {
578 intValues.reserve(storage.size());
579 bool isUintType = eltTy.isUnsignedInteger();
580 for (const auto &signAndToken : storage) {
581 bool isNegative = signAndToken.first;
582 const Token &token = signAndToken.second;
583 auto tokenLoc = token.getLoc();
584
585 if (isNegative && isUintType) {
586 return p.emitError(tokenLoc)
587 << "expected unsigned integer elements, but parsed negative value";
588 }
589
590 // Check to see if floating point values were parsed.
591 if (token.is(Token::floatliteral)) {
592 return p.emitError(tokenLoc)
593 << "expected integer elements, but parsed floating-point";
594 }
595
596 assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
597 "unexpected token type");
598 if (token.isAny(Token::kw_true, Token::kw_false)) {
599 if (!eltTy.isInteger(1)) {
600 return p.emitError(tokenLoc)
601 << "expected i1 type for 'true' or 'false' values";
602 }
603 APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
604 intValues.push_back(apInt);
605 continue;
606 }
607
608 // Create APInt values for each element with the correct bitwidth.
609 Optional<APInt> apInt =
610 buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
611 if (!apInt)
612 return p.emitError(tokenLoc, "integer constant out of range for type");
613 intValues.push_back(*apInt);
614 }
615 return success();
616 }
617
618 /// Build a Dense Float attribute for the given type.
619 ParseResult
getFloatAttrElements(llvm::SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)620 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
621 std::vector<APFloat> &floatValues) {
622 floatValues.reserve(storage.size());
623 for (const auto &signAndToken : storage) {
624 bool isNegative = signAndToken.first;
625 const Token &token = signAndToken.second;
626
627 // Handle hexadecimal float literals.
628 if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
629 if (isNegative) {
630 return p.emitError(token.getLoc())
631 << "hexadecimal float literal should not have a leading minus";
632 }
633 auto val = token.getUInt64IntegerValue();
634 if (!val.hasValue()) {
635 return p.emitError(
636 "hexadecimal float constant out of range for attribute");
637 }
638 Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
639 if (!apVal)
640 return failure();
641 floatValues.push_back(*apVal);
642 continue;
643 }
644
645 // Check to see if any decimal integers or booleans were parsed.
646 if (!token.is(Token::floatliteral))
647 return p.emitError()
648 << "expected floating-point elements, but parsed integer";
649
650 // Build the float values from tokens.
651 auto val = token.getFloatingPointValue();
652 if (!val.hasValue())
653 return p.emitError("floating point value too large for attribute");
654
655 APFloat apVal(isNegative ? -*val : *val);
656 if (!eltTy.isF64()) {
657 bool unused;
658 apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
659 &unused);
660 }
661 floatValues.push_back(apVal);
662 }
663 return success();
664 }
665
666 /// Build a Dense String attribute for the given type.
getStringAttr(llvm::SMLoc loc,ShapedType type,Type eltTy)667 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
668 ShapedType type,
669 Type eltTy) {
670 if (hexStorage.hasValue()) {
671 auto stringValue = hexStorage.getValue().getStringValue();
672 return DenseStringElementsAttr::get(type, {stringValue});
673 }
674
675 std::vector<std::string> stringValues;
676 std::vector<StringRef> stringRefValues;
677 stringValues.reserve(storage.size());
678 stringRefValues.reserve(storage.size());
679
680 for (auto val : storage) {
681 stringValues.push_back(val.second.getStringValue());
682 stringRefValues.push_back(stringValues.back());
683 }
684
685 return DenseStringElementsAttr::get(type, stringRefValues);
686 }
687
688 /// Build a Dense attribute with hex data for the given type.
getHexAttr(llvm::SMLoc loc,ShapedType type)689 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
690 ShapedType type) {
691 Type elementType = type.getElementType();
692 if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
693 p.emitError(loc)
694 << "expected floating-point, integer, or complex element type, got "
695 << elementType;
696 return nullptr;
697 }
698
699 std::string data;
700 if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
701 return nullptr;
702
703 ArrayRef<char> rawData(data.data(), data.size());
704 bool detectedSplat = false;
705 if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
706 p.emitError(loc) << "elements hex data size is invalid for provided type: "
707 << type;
708 return nullptr;
709 }
710
711 if (llvm::support::endian::system_endianness() ==
712 llvm::support::endianness::big) {
713 // Convert endianess in big-endian(BE) machines. `rawData` is
714 // little-endian(LE) because HEX in raw data of dense element attribute
715 // is always LE format. It is converted into BE here to be used in BE
716 // machines.
717 SmallVector<char, 64> outDataVec(rawData.size());
718 MutableArrayRef<char> convRawData(outDataVec);
719 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720 rawData, convRawData, type);
721 return DenseElementsAttr::getFromRawBuffer(type, convRawData,
722 detectedSplat);
723 }
724
725 return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
726 }
727
parseElement()728 ParseResult TensorLiteralParser::parseElement() {
729 switch (p.getToken().getKind()) {
730 // Parse a boolean element.
731 case Token::kw_true:
732 case Token::kw_false:
733 case Token::floatliteral:
734 case Token::integer:
735 storage.emplace_back(/*isNegative=*/false, p.getToken());
736 p.consumeToken();
737 break;
738
739 // Parse a signed integer or a negative floating-point element.
740 case Token::minus:
741 p.consumeToken(Token::minus);
742 if (!p.getToken().isAny(Token::floatliteral, Token::integer))
743 return p.emitError("expected integer or floating point literal");
744 storage.emplace_back(/*isNegative=*/true, p.getToken());
745 p.consumeToken();
746 break;
747
748 case Token::string:
749 storage.emplace_back(/*isNegative=*/false, p.getToken());
750 p.consumeToken();
751 break;
752
753 // Parse a complex element of the form '(' element ',' element ')'.
754 case Token::l_paren:
755 p.consumeToken(Token::l_paren);
756 if (parseElement() ||
757 p.parseToken(Token::comma, "expected ',' between complex elements") ||
758 parseElement() ||
759 p.parseToken(Token::r_paren, "expected ')' after complex elements"))
760 return failure();
761 break;
762
763 default:
764 return p.emitError("expected element literal of primitive type");
765 }
766
767 return success();
768 }
769
770 /// Parse a list of either lists or elements, returning the dimensions of the
771 /// parsed sub-tensors in dims. For example:
772 /// parseList([1, 2, 3]) -> Success, [3]
773 /// parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
774 /// parseList([[1, 2], 3]) -> Failure
775 /// parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)776 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
777 p.consumeToken(Token::l_square);
778
779 auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
780 const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
781 if (prevDims == newDims)
782 return success();
783 return p.emitError("tensor literal is invalid; ranks are not consistent "
784 "between elements");
785 };
786
787 bool first = true;
788 SmallVector<int64_t, 4> newDims;
789 unsigned size = 0;
790 auto parseCommaSeparatedList = [&]() -> ParseResult {
791 SmallVector<int64_t, 4> thisDims;
792 if (p.getToken().getKind() == Token::l_square) {
793 if (parseList(thisDims))
794 return failure();
795 } else if (parseElement()) {
796 return failure();
797 }
798 ++size;
799 if (!first)
800 return checkDims(newDims, thisDims);
801 newDims = thisDims;
802 first = false;
803 return success();
804 };
805 if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
806 return failure();
807
808 // Return the sublists' dimensions with 'size' prepended.
809 dims.clear();
810 dims.push_back(size);
811 dims.append(newDims.begin(), newDims.end());
812 return success();
813 }
814
815 //===----------------------------------------------------------------------===//
816 // ElementsAttr Parser
817 //===----------------------------------------------------------------------===//
818
819 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)820 Attribute Parser::parseDenseElementsAttr(Type attrType) {
821 auto attribLoc = getToken().getLoc();
822 consumeToken(Token::kw_dense);
823 if (parseToken(Token::less, "expected '<' after 'dense'"))
824 return nullptr;
825
826 // Parse the literal data if necessary.
827 TensorLiteralParser literalParser(*this);
828 if (!consumeIf(Token::greater)) {
829 if (literalParser.parse(/*allowHex=*/true) ||
830 parseToken(Token::greater, "expected '>'"))
831 return nullptr;
832 }
833
834 // If the type is specified `parseElementsLiteralType` will not parse a type.
835 // Use the attribute location as the location for error reporting in that
836 // case.
837 auto loc = attrType ? attribLoc : getToken().getLoc();
838 auto type = parseElementsLiteralType(attrType);
839 if (!type)
840 return nullptr;
841 return literalParser.getAttr(loc, type);
842 }
843
844 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)845 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
846 consumeToken(Token::kw_opaque);
847 if (parseToken(Token::less, "expected '<' after 'opaque'"))
848 return nullptr;
849
850 if (getToken().isNot(Token::string))
851 return (emitError("expected dialect namespace"), nullptr);
852
853 auto name = getToken().getStringValue();
854 // Lazy load a dialect in the context if there is a possible namespace.
855 Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
856
857 // TODO: Allow for having an unknown dialect on an opaque
858 // attribute. Otherwise, it can't be roundtripped without having the dialect
859 // registered.
860 if (!dialect)
861 return (emitError("no registered dialect with namespace '" + name + "'"),
862 nullptr);
863 consumeToken(Token::string);
864
865 if (parseToken(Token::comma, "expected ','"))
866 return nullptr;
867
868 Token hexTok = getToken();
869 if (parseToken(Token::string, "elements hex string should start with '0x'") ||
870 parseToken(Token::greater, "expected '>'"))
871 return nullptr;
872 auto type = parseElementsLiteralType(attrType);
873 if (!type)
874 return nullptr;
875
876 std::string data;
877 if (parseElementAttrHexValues(*this, hexTok, data))
878 return nullptr;
879 return OpaqueElementsAttr::get(dialect, type, data);
880 }
881
882 /// Shaped type for elements attribute.
883 ///
884 /// elements-literal-type ::= vector-type | ranked-tensor-type
885 ///
886 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)887 ShapedType Parser::parseElementsLiteralType(Type type) {
888 // If the user didn't provide a type, parse the colon type for the literal.
889 if (!type) {
890 if (parseToken(Token::colon, "expected ':'"))
891 return nullptr;
892 if (!(type = parseType()))
893 return nullptr;
894 }
895
896 if (!type.isa<RankedTensorType, VectorType>()) {
897 emitError("elements literal must be a ranked tensor or vector type");
898 return nullptr;
899 }
900
901 auto sType = type.cast<ShapedType>();
902 if (!sType.hasStaticShape())
903 return (emitError("elements literal type must have static shape"), nullptr);
904
905 return sType;
906 }
907
908 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)909 Attribute Parser::parseSparseElementsAttr(Type attrType) {
910 consumeToken(Token::kw_sparse);
911 if (parseToken(Token::less, "Expected '<' after 'sparse'"))
912 return nullptr;
913
914 // Check for the case where all elements are sparse. The indices are
915 // represented by a 2-dimensional shape where the second dimension is the rank
916 // of the type.
917 Type indiceEltType = builder.getIntegerType(64);
918 if (consumeIf(Token::greater)) {
919 ShapedType type = parseElementsLiteralType(attrType);
920 if (!type)
921 return nullptr;
922
923 // Construct the sparse elements attr using zero element indice/value
924 // attributes.
925 ShapedType indicesType =
926 RankedTensorType::get({0, type.getRank()}, indiceEltType);
927 ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
928 return SparseElementsAttr::get(
929 type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
930 DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
931 }
932
933 /// Parse the indices. We don't allow hex values here as we may need to use
934 /// the inferred shape.
935 auto indicesLoc = getToken().getLoc();
936 TensorLiteralParser indiceParser(*this);
937 if (indiceParser.parse(/*allowHex=*/false))
938 return nullptr;
939
940 if (parseToken(Token::comma, "expected ','"))
941 return nullptr;
942
943 /// Parse the values.
944 auto valuesLoc = getToken().getLoc();
945 TensorLiteralParser valuesParser(*this);
946 if (valuesParser.parse(/*allowHex=*/true))
947 return nullptr;
948
949 if (parseToken(Token::greater, "expected '>'"))
950 return nullptr;
951
952 auto type = parseElementsLiteralType(attrType);
953 if (!type)
954 return nullptr;
955
956 // If the indices are a splat, i.e. the literal parser parsed an element and
957 // not a list, we set the shape explicitly. The indices are represented by a
958 // 2-dimensional shape where the second dimension is the rank of the type.
959 // Given that the parsed indices is a splat, we know that we only have one
960 // indice and thus one for the first dimension.
961 ShapedType indicesType;
962 if (indiceParser.getShape().empty()) {
963 indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
964 } else {
965 // Otherwise, set the shape to the one parsed by the literal parser.
966 indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
967 }
968 auto indices = indiceParser.getAttr(indicesLoc, indicesType);
969
970 // If the values are a splat, set the shape explicitly based on the number of
971 // indices. The number of indices is encoded in the first dimension of the
972 // indice shape type.
973 auto valuesEltType = type.getElementType();
974 ShapedType valuesType =
975 valuesParser.getShape().empty()
976 ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
977 : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
978 auto values = valuesParser.getAttr(valuesLoc, valuesType);
979
980 /// Sanity check.
981 if (valuesType.getRank() != 1)
982 return (emitError("expected 1-d tensor for values"), nullptr);
983
984 auto sameShape = (indicesType.getRank() == 1) ||
985 (type.getRank() == indicesType.getDimSize(1));
986 auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
987 if (!sameShape || !sameElementNum) {
988 emitError() << "expected shape ([" << type.getShape()
989 << "]); inferred shape of indices literal (["
990 << indicesType.getShape()
991 << "]); inferred shape of values literal (["
992 << valuesType.getShape() << "])";
993 return nullptr;
994 }
995
996 // Build the sparse elements attribute by the indices and values.
997 return SparseElementsAttr::get(type, indices, values);
998 }
999