1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16
17 using namespace mlir;
18 using namespace mlir::detail;
19
20 /// Optionally parse a type.
parseOptionalType(Type & type)21 OptionalParseResult Parser::parseOptionalType(Type &type) {
22 // There are many different starting tokens for a type, check them here.
23 switch (getToken().getKind()) {
24 case Token::l_paren:
25 case Token::kw_memref:
26 case Token::kw_tensor:
27 case Token::kw_complex:
28 case Token::kw_tuple:
29 case Token::kw_vector:
30 case Token::inttype:
31 case Token::kw_bf16:
32 case Token::kw_f16:
33 case Token::kw_f32:
34 case Token::kw_f64:
35 case Token::kw_index:
36 case Token::kw_none:
37 case Token::exclamation_identifier:
38 return failure(!(type = parseType()));
39
40 default:
41 return llvm::None;
42 }
43 }
44
45 /// Parse an arbitrary type.
46 ///
47 /// type ::= function-type
48 /// | non-function-type
49 ///
parseType()50 Type Parser::parseType() {
51 if (getToken().is(Token::l_paren))
52 return parseFunctionType();
53 return parseNonFunctionType();
54 }
55
56 /// Parse a function result type.
57 ///
58 /// function-result-type ::= type-list-parens
59 /// | non-function-type
60 ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)61 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
62 if (getToken().is(Token::l_paren))
63 return parseTypeListParens(elements);
64
65 Type t = parseNonFunctionType();
66 if (!t)
67 return failure();
68 elements.push_back(t);
69 return success();
70 }
71
72 /// Parse a list of types without an enclosing parenthesis. The list must have
73 /// at least one member.
74 ///
75 /// type-list-no-parens ::= type (`,` type)*
76 ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)77 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
78 auto parseElt = [&]() -> ParseResult {
79 auto elt = parseType();
80 elements.push_back(elt);
81 return elt ? success() : failure();
82 };
83
84 return parseCommaSeparatedList(parseElt);
85 }
86
87 /// Parse a parenthesized list of types.
88 ///
89 /// type-list-parens ::= `(` `)`
90 /// | `(` type-list-no-parens `)`
91 ///
parseTypeListParens(SmallVectorImpl<Type> & elements)92 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
93 if (parseToken(Token::l_paren, "expected '('"))
94 return failure();
95
96 // Handle empty lists.
97 if (getToken().is(Token::r_paren))
98 return consumeToken(), success();
99
100 if (parseTypeListNoParens(elements) ||
101 parseToken(Token::r_paren, "expected ')'"))
102 return failure();
103 return success();
104 }
105
106 /// Parse a complex type.
107 ///
108 /// complex-type ::= `complex` `<` type `>`
109 ///
parseComplexType()110 Type Parser::parseComplexType() {
111 consumeToken(Token::kw_complex);
112
113 // Parse the '<'.
114 if (parseToken(Token::less, "expected '<' in complex type"))
115 return nullptr;
116
117 llvm::SMLoc elementTypeLoc = getToken().getLoc();
118 auto elementType = parseType();
119 if (!elementType ||
120 parseToken(Token::greater, "expected '>' in complex type"))
121 return nullptr;
122 if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
123 return emitError(elementTypeLoc, "invalid element type for complex"),
124 nullptr;
125
126 return ComplexType::get(elementType);
127 }
128
129 /// Parse a function type.
130 ///
131 /// function-type ::= type-list-parens `->` function-result-type
132 ///
parseFunctionType()133 Type Parser::parseFunctionType() {
134 assert(getToken().is(Token::l_paren));
135
136 SmallVector<Type, 4> arguments, results;
137 if (parseTypeListParens(arguments) ||
138 parseToken(Token::arrow, "expected '->' in function type") ||
139 parseFunctionResultTypes(results))
140 return nullptr;
141
142 return builder.getFunctionType(arguments, results);
143 }
144
145 /// Parse the offset and strides from a strided layout specification.
146 ///
147 /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
148 ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)149 ParseResult Parser::parseStridedLayout(int64_t &offset,
150 SmallVectorImpl<int64_t> &strides) {
151 // Parse offset.
152 consumeToken(Token::kw_offset);
153 if (!consumeIf(Token::colon))
154 return emitError("expected colon after `offset` keyword");
155 auto maybeOffset = getToken().getUnsignedIntegerValue();
156 bool question = getToken().is(Token::question);
157 if (!maybeOffset && !question)
158 return emitError("invalid offset");
159 offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
160 : MemRefType::getDynamicStrideOrOffset();
161 consumeToken();
162
163 if (!consumeIf(Token::comma))
164 return emitError("expected comma after offset value");
165
166 // Parse stride list.
167 if (!consumeIf(Token::kw_strides))
168 return emitError("expected `strides` keyword after offset specification");
169 if (!consumeIf(Token::colon))
170 return emitError("expected colon after `strides` keyword");
171 if (failed(parseStrideList(strides)))
172 return emitError("invalid braces-enclosed stride list");
173 if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
174 return emitError("invalid memref stride");
175
176 return success();
177 }
178
179 /// Parse a memref type.
180 ///
181 /// memref-type ::= ranked-memref-type | unranked-memref-type
182 ///
183 /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type
184 /// (`,` semi-affine-map-composition)? (`,`
185 /// memory-space)? `>`
186 ///
187 /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
188 ///
189 /// semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
190 /// memory-space ::= integer-literal /* | TODO: address-space-id */
191 ///
parseMemRefType()192 Type Parser::parseMemRefType() {
193 consumeToken(Token::kw_memref);
194
195 if (parseToken(Token::less, "expected '<' in memref type"))
196 return nullptr;
197
198 bool isUnranked;
199 SmallVector<int64_t, 4> dimensions;
200
201 if (consumeIf(Token::star)) {
202 // This is an unranked memref type.
203 isUnranked = true;
204 if (parseXInDimensionList())
205 return nullptr;
206
207 } else {
208 isUnranked = false;
209 if (parseDimensionListRanked(dimensions))
210 return nullptr;
211 }
212
213 // Parse the element type.
214 auto typeLoc = getToken().getLoc();
215 auto elementType = parseType();
216 if (!elementType)
217 return nullptr;
218
219 // Check that memref is formed from allowed types.
220 if (!elementType.isIntOrIndexOrFloat() &&
221 !elementType.isa<VectorType, ComplexType>())
222 return emitError(typeLoc, "invalid memref element type"), nullptr;
223
224 // Parse semi-affine-map-composition.
225 SmallVector<AffineMap, 2> affineMapComposition;
226 Optional<unsigned> memorySpace;
227 unsigned numDims = dimensions.size();
228
229 auto parseElt = [&]() -> ParseResult {
230 // Check for the memory space.
231 if (getToken().is(Token::integer)) {
232 if (memorySpace)
233 return emitError("multiple memory spaces specified in memref type");
234 memorySpace = getToken().getUnsignedIntegerValue();
235 if (!memorySpace.hasValue())
236 return emitError("invalid memory space in memref type");
237 consumeToken(Token::integer);
238 return success();
239 }
240 if (isUnranked)
241 return emitError("cannot have affine map for unranked memref type");
242 if (memorySpace)
243 return emitError("expected memory space to be last in memref type");
244
245 AffineMap map;
246 llvm::SMLoc mapLoc = getToken().getLoc();
247 if (getToken().is(Token::kw_offset)) {
248 int64_t offset;
249 SmallVector<int64_t, 4> strides;
250 if (failed(parseStridedLayout(offset, strides)))
251 return failure();
252 // Construct strided affine map.
253 map = makeStridedLinearLayoutMap(strides, offset, state.context);
254 } else {
255 // Parse an affine map attribute.
256 auto affineMap = parseAttribute();
257 if (!affineMap)
258 return failure();
259 auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
260 if (!affineMapAttr)
261 return emitError("expected affine map in memref type");
262 map = affineMapAttr.getValue();
263 }
264
265 if (map.getNumDims() != numDims) {
266 size_t i = affineMapComposition.size();
267 return emitError(mapLoc, "memref affine map dimension mismatch between ")
268 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
269 << " and affine map" << i + 1 << ": " << numDims
270 << " != " << map.getNumDims();
271 }
272 numDims = map.getNumResults();
273 affineMapComposition.push_back(map);
274 return success();
275 };
276
277 // Parse a list of mappings and address space if present.
278 if (!consumeIf(Token::greater)) {
279 // Parse comma separated list of affine maps, followed by memory space.
280 if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
281 parseCommaSeparatedListUntil(Token::greater, parseElt,
282 /*allowEmptyList=*/false)) {
283 return nullptr;
284 }
285 }
286
287 if (isUnranked)
288 return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
289
290 return MemRefType::get(dimensions, elementType, affineMapComposition,
291 memorySpace.getValueOr(0));
292 }
293
294 /// Parse any type except the function type.
295 ///
296 /// non-function-type ::= integer-type
297 /// | index-type
298 /// | float-type
299 /// | extended-type
300 /// | vector-type
301 /// | tensor-type
302 /// | memref-type
303 /// | complex-type
304 /// | tuple-type
305 /// | none-type
306 ///
307 /// index-type ::= `index`
308 /// float-type ::= `f16` | `bf16` | `f32` | `f64`
309 /// none-type ::= `none`
310 ///
parseNonFunctionType()311 Type Parser::parseNonFunctionType() {
312 switch (getToken().getKind()) {
313 default:
314 return (emitError("expected non-function type"), nullptr);
315 case Token::kw_memref:
316 return parseMemRefType();
317 case Token::kw_tensor:
318 return parseTensorType();
319 case Token::kw_complex:
320 return parseComplexType();
321 case Token::kw_tuple:
322 return parseTupleType();
323 case Token::kw_vector:
324 return parseVectorType();
325 // integer-type
326 case Token::inttype: {
327 auto width = getToken().getIntTypeBitwidth();
328 if (!width.hasValue())
329 return (emitError("invalid integer width"), nullptr);
330 if (width.getValue() > IntegerType::kMaxWidth) {
331 emitError(getToken().getLoc(), "integer bitwidth is limited to ")
332 << IntegerType::kMaxWidth << " bits";
333 return nullptr;
334 }
335
336 IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
337 if (Optional<bool> signedness = getToken().getIntTypeSignedness())
338 signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
339
340 consumeToken(Token::inttype);
341 return IntegerType::get(width.getValue(), signSemantics, getContext());
342 }
343
344 // float-type
345 case Token::kw_bf16:
346 consumeToken(Token::kw_bf16);
347 return builder.getBF16Type();
348 case Token::kw_f16:
349 consumeToken(Token::kw_f16);
350 return builder.getF16Type();
351 case Token::kw_f32:
352 consumeToken(Token::kw_f32);
353 return builder.getF32Type();
354 case Token::kw_f64:
355 consumeToken(Token::kw_f64);
356 return builder.getF64Type();
357
358 // index-type
359 case Token::kw_index:
360 consumeToken(Token::kw_index);
361 return builder.getIndexType();
362
363 // none-type
364 case Token::kw_none:
365 consumeToken(Token::kw_none);
366 return builder.getNoneType();
367
368 // extended type
369 case Token::exclamation_identifier:
370 return parseExtendedType();
371 }
372 }
373
374 /// Parse a tensor type.
375 ///
376 /// tensor-type ::= `tensor` `<` dimension-list type `>`
377 /// dimension-list ::= dimension-list-ranked | `*x`
378 ///
parseTensorType()379 Type Parser::parseTensorType() {
380 consumeToken(Token::kw_tensor);
381
382 if (parseToken(Token::less, "expected '<' in tensor type"))
383 return nullptr;
384
385 bool isUnranked;
386 SmallVector<int64_t, 4> dimensions;
387
388 if (consumeIf(Token::star)) {
389 // This is an unranked tensor type.
390 isUnranked = true;
391
392 if (parseXInDimensionList())
393 return nullptr;
394
395 } else {
396 isUnranked = false;
397 if (parseDimensionListRanked(dimensions))
398 return nullptr;
399 }
400
401 // Parse the element type.
402 auto elementTypeLoc = getToken().getLoc();
403 auto elementType = parseType();
404 if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
405 return nullptr;
406 if (!TensorType::isValidElementType(elementType))
407 return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
408
409 if (isUnranked)
410 return UnrankedTensorType::get(elementType);
411 return RankedTensorType::get(dimensions, elementType);
412 }
413
414 /// Parse a tuple type.
415 ///
416 /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
417 ///
parseTupleType()418 Type Parser::parseTupleType() {
419 consumeToken(Token::kw_tuple);
420
421 // Parse the '<'.
422 if (parseToken(Token::less, "expected '<' in tuple type"))
423 return nullptr;
424
425 // Check for an empty tuple by directly parsing '>'.
426 if (consumeIf(Token::greater))
427 return TupleType::get(getContext());
428
429 // Parse the element types and the '>'.
430 SmallVector<Type, 4> types;
431 if (parseTypeListNoParens(types) ||
432 parseToken(Token::greater, "expected '>' in tuple type"))
433 return nullptr;
434
435 return TupleType::get(types, getContext());
436 }
437
438 /// Parse a vector type.
439 ///
440 /// vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
441 /// non-empty-static-dimension-list ::= decimal-literal `x`
442 /// static-dimension-list
443 /// static-dimension-list ::= (decimal-literal `x`)*
444 ///
parseVectorType()445 VectorType Parser::parseVectorType() {
446 consumeToken(Token::kw_vector);
447
448 if (parseToken(Token::less, "expected '<' in vector type"))
449 return nullptr;
450
451 SmallVector<int64_t, 4> dimensions;
452 if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
453 return nullptr;
454 if (dimensions.empty())
455 return (emitError("expected dimension size in vector type"), nullptr);
456 if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
457 return emitError(getToken().getLoc(),
458 "vector types must have positive constant sizes"),
459 nullptr;
460
461 // Parse the element type.
462 auto typeLoc = getToken().getLoc();
463 auto elementType = parseType();
464 if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
465 return nullptr;
466 if (!VectorType::isValidElementType(elementType))
467 return emitError(typeLoc, "vector elements must be int or float type"),
468 nullptr;
469
470 return VectorType::get(dimensions, elementType);
471 }
472
473 /// Parse a dimension list of a tensor or memref type. This populates the
474 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
475 /// errors out on `?` otherwise.
476 ///
477 /// dimension-list-ranked ::= (dimension `x`)*
478 /// dimension ::= `?` | decimal-literal
479 ///
480 /// When `allowDynamic` is not set, this is used to parse:
481 ///
482 /// static-dimension-list ::= (decimal-literal `x`)*
483 ParseResult
parseDimensionListRanked(SmallVectorImpl<int64_t> & dimensions,bool allowDynamic)484 Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
485 bool allowDynamic) {
486 while (getToken().isAny(Token::integer, Token::question)) {
487 if (consumeIf(Token::question)) {
488 if (!allowDynamic)
489 return emitError("expected static shape");
490 dimensions.push_back(-1);
491 } else {
492 // Hexadecimal integer literals (starting with `0x`) are not allowed in
493 // aggregate type declarations. Therefore, `0xf32` should be processed as
494 // a sequence of separate elements `0`, `x`, `f32`.
495 if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
496 // We can get here only if the token is an integer literal. Hexadecimal
497 // integer literals can only start with `0x` (`1x` wouldn't lex as a
498 // literal, just `1` would, at which point we don't get into this
499 // branch).
500 assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
501 dimensions.push_back(0);
502 state.lex.resetPointer(getTokenSpelling().data() + 1);
503 consumeToken();
504 } else {
505 // Make sure this integer value is in bound and valid.
506 auto dimension = getToken().getUnsignedIntegerValue();
507 if (!dimension.hasValue())
508 return emitError("invalid dimension");
509 dimensions.push_back((int64_t)dimension.getValue());
510 consumeToken(Token::integer);
511 }
512 }
513
514 // Make sure we have an 'x' or something like 'xbf32'.
515 if (parseXInDimensionList())
516 return failure();
517 }
518
519 return success();
520 }
521
522 /// Parse an 'x' token in a dimension list, handling the case where the x is
523 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
524 /// token.
parseXInDimensionList()525 ParseResult Parser::parseXInDimensionList() {
526 if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
527 return emitError("expected 'x' in dimension list");
528
529 // If we had a prefix of 'x', lex the next token immediately after the 'x'.
530 if (getTokenSpelling().size() != 1)
531 state.lex.resetPointer(getTokenSpelling().data() + 1);
532
533 // Consume the 'x'.
534 consumeToken(Token::bare_identifier);
535
536 return success();
537 }
538
539 // Parse a comma-separated list of dimensions, possibly empty:
540 // stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)541 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
542 if (!consumeIf(Token::l_square))
543 return failure();
544 // Empty list early exit.
545 if (consumeIf(Token::r_square))
546 return success();
547 while (true) {
548 if (consumeIf(Token::question)) {
549 dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
550 } else {
551 // This must be an integer value.
552 int64_t val;
553 if (getToken().getSpelling().getAsInteger(10, val))
554 return emitError("invalid integer value: ") << getToken().getSpelling();
555 // Make sure it is not the one value for `?`.
556 if (ShapedType::isDynamic(val))
557 return emitError("invalid integer value: ")
558 << getToken().getSpelling()
559 << ", use `?` to specify a dynamic dimension";
560 dimensions.push_back(val);
561 consumeToken(Token::integer);
562 }
563 if (!consumeIf(Token::comma))
564 break;
565 }
566 if (!consumeIf(Token::r_square))
567 return failure();
568 return success();
569 }
570