1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
14
15 #include "mlir/Dialect/SPIRV/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
19 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinOps.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/FunctionImplementation.h"
24 #include "mlir/IR/OpImplementation.h"
25 #include "mlir/Interfaces/CallInterfaces.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/bit.h"
28
29 using namespace mlir;
30
// TODO: generate these strings using ODS.
// Attribute names shared by the custom parsers, printers and verifiers in this
// file. `constexpr` already implies `const`, so each is a `const char[N]`.
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
57
58 //===----------------------------------------------------------------------===//
59 // Common utility functions
60 //===----------------------------------------------------------------------===//
61
62 /// Returns true if the given op is a function-like op or nested in a
63 /// function-like op without a module-like op in the middle.
isNestedInFunctionLikeOp(Operation * op)64 static bool isNestedInFunctionLikeOp(Operation *op) {
65 if (!op)
66 return false;
67 if (op->hasTrait<OpTrait::SymbolTable>())
68 return false;
69 if (op->hasTrait<OpTrait::FunctionLike>())
70 return true;
71 return isNestedInFunctionLikeOp(op->getParentOp());
72 }
73
74 /// Returns true if the given op is an module-like op that maintains a symbol
75 /// table.
isDirectInModuleLikeOp(Operation * op)76 static bool isDirectInModuleLikeOp(Operation *op) {
77 return op && op->hasTrait<OpTrait::SymbolTable>();
78 }
79
extractValueFromConstOp(Operation * op,int32_t & value)80 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
81 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
82 if (!constOp) {
83 return failure();
84 }
85 auto valueAttr = constOp.value();
86 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
87 if (!integerValueAttr) {
88 return failure();
89 }
90 value = integerValueAttr.getInt();
91 return success();
92 }
93
94 template <typename Ty>
95 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)96 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
97 function_ref<StringRef(Ty)> stringifyFn) {
98 if (enumValues.empty()) {
99 return nullptr;
100 }
101 SmallVector<StringRef, 1> enumValStrs;
102 enumValStrs.reserve(enumValues.size());
103 for (auto val : enumValues) {
104 enumValStrs.emplace_back(stringifyFn(val));
105 }
106 return builder.getStrArrayAttr(enumValStrs);
107 }
108
109 /// Parses the next string attribute in `parser` as an enumerant of the given
110 /// `EnumClass`.
111 template <typename EnumClass>
112 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())113 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
114 StringRef attrName = spirv::attributeName<EnumClass>()) {
115 Attribute attrVal;
116 NamedAttrList attr;
117 auto loc = parser.getCurrentLocation();
118 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
119 attrName, attr)) {
120 return failure();
121 }
122 if (!attrVal.isa<StringAttr>()) {
123 return parser.emitError(loc, "expected ")
124 << attrName << " attribute specified as string";
125 }
126 auto attrOptional =
127 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
128 if (!attrOptional) {
129 return parser.emitError(loc, "invalid ")
130 << attrName << " attribute specification: " << attrVal;
131 }
132 value = attrOptional.getValue();
133 return success();
134 }
135
136 /// Parses the next string attribute in `parser` as an enumerant of the given
137 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
138 /// attribute with the enum class's name as attribute name.
139 template <typename EnumClass>
140 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())141 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
142 StringRef attrName = spirv::attributeName<EnumClass>()) {
143 if (parseEnumStrAttr(value, parser)) {
144 return failure();
145 }
146 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
147 llvm::bit_cast<int32_t>(value)));
148 return success();
149 }
150
151 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
152 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
153 /// the enum class's name as attribute name.
154 template <typename EnumClass>
155 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())156 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
157 OperationState &state,
158 StringRef attrName = spirv::attributeName<EnumClass>()) {
159 if (parseEnumKeywordAttr(value, parser)) {
160 return failure();
161 }
162 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
163 llvm::bit_cast<int32_t>(value)));
164 return success();
165 }
166
167 /// Parses Function, Selection and Loop control attributes. If no control is
168 /// specified, "None" is used as a default.
169 template <typename EnumClass>
170 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())171 parseControlAttribute(OpAsmParser &parser, OperationState &state,
172 StringRef attrName = spirv::attributeName<EnumClass>()) {
173 if (succeeded(parser.parseOptionalKeyword(kControl))) {
174 EnumClass control;
175 if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
176 parser.parseRParen())
177 return failure();
178 return success();
179 }
180 // Set control to "None" otherwise.
181 Builder builder = parser.getBuilder();
182 state.addAttribute(attrName, builder.getI32IntegerAttr(0));
183 return success();
184 }
185
186 /// Parses optional memory access attributes attached to a memory access
187 /// operand/pointer. Specifically, parses the following syntax:
188 /// (`[` memory-access `]`)?
189 /// where:
190 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
191 /// integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)192 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
193 OperationState &state) {
194 // Parse an optional list of attributes staring with '['
195 if (parser.parseOptionalLSquare()) {
196 // Nothing to do
197 return success();
198 }
199
200 spirv::MemoryAccess memoryAccessAttr;
201 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
202 kMemoryAccessAttrName)) {
203 return failure();
204 }
205
206 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
207 // Parse integer attribute for alignment.
208 Attribute alignmentAttr;
209 Type i32Type = parser.getBuilder().getIntegerType(32);
210 if (parser.parseComma() ||
211 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
212 state.attributes)) {
213 return failure();
214 }
215 }
216 return parser.parseRSquare();
217 }
218
219 // TODO Make sure to merge this and the previous function into one template
220 // parameterized by memory access attribute name and alignment. Doing so now
221 // results in VS2017 in producing an internal error (at the call site) that's
222 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)223 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
224 OperationState &state) {
225 // Parse an optional list of attributes staring with '['
226 if (parser.parseOptionalLSquare()) {
227 // Nothing to do
228 return success();
229 }
230
231 spirv::MemoryAccess memoryAccessAttr;
232 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
233 kSourceMemoryAccessAttrName)) {
234 return failure();
235 }
236
237 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
238 // Parse integer attribute for alignment.
239 Attribute alignmentAttr;
240 Type i32Type = parser.getBuilder().getIntegerType(32);
241 if (parser.parseComma() ||
242 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
243 state.attributes)) {
244 return failure();
245 }
246 }
247 return parser.parseRSquare();
248 }
249
250 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)251 static void printMemoryAccessAttribute(
252 MemoryOpTy memoryOp, OpAsmPrinter &printer,
253 SmallVectorImpl<StringRef> &elidedAttrs,
254 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
255 Optional<uint32_t> alignmentAttrValue = None) {
256 // Print optional memory access attribute.
257 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
258 : memoryOp.memory_access())) {
259 elidedAttrs.push_back(kMemoryAccessAttrName);
260
261 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
262
263 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
264 // Print integer alignment attribute.
265 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
266 : memoryOp.alignment())) {
267 elidedAttrs.push_back(kAlignmentAttrName);
268 printer << ", " << alignment;
269 }
270 }
271 printer << "]";
272 }
273 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
274 }
275
276 // TODO Make sure to merge this and the previous function into one template
277 // parameterized by memory access attribute name and alignment. Doing so now
278 // results in VS2017 in producing an internal error (at the call site) that's
279 // not detailed enough to understand what is happening.
280 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)281 static void printSourceMemoryAccessAttribute(
282 MemoryOpTy memoryOp, OpAsmPrinter &printer,
283 SmallVectorImpl<StringRef> &elidedAttrs,
284 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
285 Optional<uint32_t> alignmentAttrValue = None) {
286
287 printer << ", ";
288
289 // Print optional memory access attribute.
290 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
291 : memoryOp.memory_access())) {
292 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
293
294 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
295
296 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
297 // Print integer alignment attribute.
298 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
299 : memoryOp.alignment())) {
300 elidedAttrs.push_back(kSourceAlignmentAttrName);
301 printer << ", " << alignment;
302 }
303 }
304 printer << "]";
305 }
306 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
307 }
308
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)309 static LogicalResult verifyCastOp(Operation *op,
310 bool requireSameBitWidth = true,
311 bool skipBitWidthCheck = false) {
312 // Some CastOps have no limit on bit widths for result and operand type.
313 if (skipBitWidthCheck)
314 return success();
315
316 Type operandType = op->getOperand(0).getType();
317 Type resultType = op->getResult(0).getType();
318
319 // ODS checks that result type and operand type have the same shape.
320 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
321 operandType = vectorType.getElementType();
322 resultType = resultType.cast<VectorType>().getElementType();
323 }
324
325 if (auto coopMatrixType =
326 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
327 operandType = coopMatrixType.getElementType();
328 resultType =
329 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
330 }
331
332 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
333 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
334 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
335
336 if (requireSameBitWidth) {
337 if (!isSameBitWidth) {
338 return op->emitOpError(
339 "expected the same bit widths for operand type and result "
340 "type, but provided ")
341 << operandType << " and " << resultType;
342 }
343 return success();
344 }
345
346 if (isSameBitWidth) {
347 return op->emitOpError(
348 "expected the different bit widths for operand type and result "
349 "type, but provided ")
350 << operandType << " and " << resultType;
351 }
352 return success();
353 }
354
355 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)356 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
357 // ODS checks for attributes values. Just need to verify that if the
358 // memory-access attribute is Aligned, then the alignment attribute must be
359 // present.
360 auto *op = memoryOp.getOperation();
361 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
362 if (!memAccessAttr) {
363 // Alignment attribute shouldn't be present if memory access attribute is
364 // not present.
365 if (op->getAttr(kAlignmentAttrName)) {
366 return memoryOp.emitOpError(
367 "invalid alignment specification without aligned memory access "
368 "specification");
369 }
370 return success();
371 }
372
373 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
374 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
375
376 if (!memAccess) {
377 return memoryOp.emitOpError("invalid memory access specifier: ")
378 << memAccessVal;
379 }
380
381 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
382 if (!op->getAttr(kAlignmentAttrName)) {
383 return memoryOp.emitOpError("missing alignment value");
384 }
385 } else {
386 if (op->getAttr(kAlignmentAttrName)) {
387 return memoryOp.emitOpError(
388 "invalid alignment specification with non-aligned memory access "
389 "specification");
390 }
391 }
392 return success();
393 }
394
395 // TODO Make sure to merge this and the previous function into one template
396 // parameterized by memory access attribute name and alignment. Doing so now
397 // results in VS2017 in producing an internal error (at the call site) that's
398 // not detailed enough to understand what is happening.
399 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)400 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
401 // ODS checks for attributes values. Just need to verify that if the
402 // memory-access attribute is Aligned, then the alignment attribute must be
403 // present.
404 auto *op = memoryOp.getOperation();
405 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
406 if (!memAccessAttr) {
407 // Alignment attribute shouldn't be present if memory access attribute is
408 // not present.
409 if (op->getAttr(kSourceAlignmentAttrName)) {
410 return memoryOp.emitOpError(
411 "invalid alignment specification without aligned memory access "
412 "specification");
413 }
414 return success();
415 }
416
417 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
418 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
419
420 if (!memAccess) {
421 return memoryOp.emitOpError("invalid memory access specifier: ")
422 << memAccessVal;
423 }
424
425 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
426 if (!op->getAttr(kSourceAlignmentAttrName)) {
427 return memoryOp.emitOpError("missing alignment value");
428 }
429 } else {
430 if (op->getAttr(kSourceAlignmentAttrName)) {
431 return memoryOp.emitOpError(
432 "invalid alignment specification with non-aligned memory access "
433 "specification");
434 }
435 }
436 return success();
437 }
438
439 template <typename BarrierOp>
verifyMemorySemantics(BarrierOp op)440 static LogicalResult verifyMemorySemantics(BarrierOp op) {
441 // According to the SPIR-V specification:
442 // "Despite being a mask and allowing multiple bits to be combined, it is
443 // invalid for more than one of these four bits to be set: Acquire, Release,
444 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
445 // Release semantics is done by setting the AcquireRelease bit, not by setting
446 // two bits."
447 auto memorySemantics = op.memory_semantics();
448 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
449 spirv::MemorySemantics::Release |
450 spirv::MemorySemantics::AcquireRelease |
451 spirv::MemorySemantics::SequentiallyConsistent;
452
453 auto bitCount = llvm::countPopulation(
454 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
455 if (bitCount > 1) {
456 return op.emitError("expected at most one of these four memory constraints "
457 "to be set: `Acquire`, `Release`,"
458 "`AcquireRelease` or `SequentiallyConsistent`");
459 }
460 return success();
461 }
462
463 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)464 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
465 Value val) {
466 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
467 // type of the pointer and the type of the value are the same
468 //
469 // TODO: Check that the value type satisfies restrictions of
470 // SPIR-V OpLoad/OpStore operations
471 if (val.getType() !=
472 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
473 return op.emitOpError("mismatch in result type and pointer type");
474 }
475 return success();
476 }
477
478 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)479 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
480 Value ptr, Value val) {
481 auto valType = val.getType();
482 if (auto valVecTy = valType.dyn_cast<VectorType>())
483 valType = valVecTy.getElementType();
484
485 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
486 return op.emitOpError("mismatch in result type and pointer type");
487 }
488 return success();
489 }
490
parseVariableDecorations(OpAsmParser & parser,OperationState & state)491 static ParseResult parseVariableDecorations(OpAsmParser &parser,
492 OperationState &state) {
493 auto builtInName = llvm::convertToSnakeFromCamelCase(
494 stringifyDecoration(spirv::Decoration::BuiltIn));
495 if (succeeded(parser.parseOptionalKeyword("bind"))) {
496 Attribute set, binding;
497 // Parse optional descriptor binding
498 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
499 stringifyDecoration(spirv::Decoration::DescriptorSet));
500 auto bindingName = llvm::convertToSnakeFromCamelCase(
501 stringifyDecoration(spirv::Decoration::Binding));
502 Type i32Type = parser.getBuilder().getIntegerType(32);
503 if (parser.parseLParen() ||
504 parser.parseAttribute(set, i32Type, descriptorSetName,
505 state.attributes) ||
506 parser.parseComma() ||
507 parser.parseAttribute(binding, i32Type, bindingName,
508 state.attributes) ||
509 parser.parseRParen()) {
510 return failure();
511 }
512 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
513 StringAttr builtIn;
514 if (parser.parseLParen() ||
515 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
516 parser.parseRParen()) {
517 return failure();
518 }
519 }
520
521 // Parse other attributes
522 if (parser.parseOptionalAttrDict(state.attributes))
523 return failure();
524
525 return success();
526 }
527
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)528 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
529 SmallVectorImpl<StringRef> &elidedAttrs) {
530 // Print optional descriptor binding
531 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
532 stringifyDecoration(spirv::Decoration::DescriptorSet));
533 auto bindingName = llvm::convertToSnakeFromCamelCase(
534 stringifyDecoration(spirv::Decoration::Binding));
535 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
536 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
537 if (descriptorSet && binding) {
538 elidedAttrs.push_back(descriptorSetName);
539 elidedAttrs.push_back(bindingName);
540 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
541 << ")";
542 }
543
544 // Print BuiltIn attribute if present
545 auto builtInName = llvm::convertToSnakeFromCamelCase(
546 stringifyDecoration(spirv::Decoration::BuiltIn));
547 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
548 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
549 elidedAttrs.push_back(builtInName);
550 }
551
552 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
553 }
554
555 // Get bit width of types.
getBitWidth(Type type)556 static unsigned getBitWidth(Type type) {
557 if (type.isa<spirv::PointerType>()) {
558 // Just return 64 bits for pointer types for now.
559 // TODO: Make sure not caller relies on the actual pointer width value.
560 return 64;
561 }
562
563 if (type.isIntOrFloat())
564 return type.getIntOrFloatBitWidth();
565
566 if (auto vectorType = type.dyn_cast<VectorType>()) {
567 assert(vectorType.getElementType().isIntOrFloat());
568 return vectorType.getNumElements() *
569 vectorType.getElementType().getIntOrFloatBitWidth();
570 }
571 llvm_unreachable("unhandled bit width computation for type");
572 }
573
574 /// Walks the given type hierarchy with the given indices, potentially down
575 /// to component granularity, to select an element type. Returns null type and
576 /// emits errors with the given loc on failure.
577 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)578 getElementType(Type type, ArrayRef<int32_t> indices,
579 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
580 if (indices.empty()) {
581 emitErrorFn("expected at least one index for spv.CompositeExtract");
582 return nullptr;
583 }
584
585 for (auto index : indices) {
586 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
587 if (cType.hasCompileTimeKnownNumElements() &&
588 (index < 0 ||
589 static_cast<uint64_t>(index) >= cType.getNumElements())) {
590 emitErrorFn("index ") << index << " out of bounds for " << type;
591 return nullptr;
592 }
593 type = cType.getElementType(index);
594 } else {
595 emitErrorFn("cannot extract from non-composite type ")
596 << type << " with index " << index;
597 return nullptr;
598 }
599 }
600 return type;
601 }
602
603 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)604 getElementType(Type type, Attribute indices,
605 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
606 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
607 if (!indicesArrayAttr) {
608 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
609 return nullptr;
610 }
611 if (!indicesArrayAttr.size()) {
612 emitErrorFn("expected at least one index for spv.CompositeExtract");
613 return nullptr;
614 }
615
616 SmallVector<int32_t, 2> indexVals;
617 for (auto indexAttr : indicesArrayAttr) {
618 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
619 if (!indexIntAttr) {
620 emitErrorFn("expected an 32-bit integer for index, but found '")
621 << indexAttr << "'";
622 return nullptr;
623 }
624 indexVals.push_back(indexIntAttr.getInt());
625 }
626 return getElementType(type, indexVals, emitErrorFn);
627 }
628
getElementType(Type type,Attribute indices,Location loc)629 static Type getElementType(Type type, Attribute indices, Location loc) {
630 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
631 return ::mlir::emitError(loc, err);
632 };
633 return getElementType(type, indices, errorFn);
634 }
635
getElementType(Type type,Attribute indices,OpAsmParser & parser,llvm::SMLoc loc)636 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
637 llvm::SMLoc loc) {
638 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
639 return parser.emitError(loc, err);
640 };
641 return getElementType(type, indices, errorFn);
642 }
643
644 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)645 static inline bool isMergeBlock(Block &block) {
646 return !block.empty() && std::next(block.begin()) == block.end() &&
647 isa<spirv::MergeOp>(block.front());
648 }
649
650 //===----------------------------------------------------------------------===//
651 // Common parsers and printers
652 //===----------------------------------------------------------------------===//
653
654 // Parses an atomic update op. If the update op does not take a value (like
655 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)656 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
657 OperationState &state, bool hasValue) {
658 spirv::Scope scope;
659 spirv::MemorySemantics memoryScope;
660 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
661 OpAsmParser::OperandType ptrInfo, valueInfo;
662 Type type;
663 llvm::SMLoc loc;
664 if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
665 parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
666 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
667 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
668 return failure();
669
670 auto ptrType = type.dyn_cast<spirv::PointerType>();
671 if (!ptrType)
672 return parser.emitError(loc, "expected pointer type");
673
674 SmallVector<Type, 2> operandTypes;
675 operandTypes.push_back(ptrType);
676 if (hasValue)
677 operandTypes.push_back(ptrType.getPointeeType());
678 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
679 state.operands))
680 return failure();
681 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
682 }
683
684 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)685 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
686 printer << op->getName() << " \"";
687 auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
688 printer << spirv::stringifyScope(
689 static_cast<spirv::Scope>(scopeAttr.getInt()))
690 << "\" \"";
691 auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
692 printer << spirv::stringifyMemorySemantics(
693 static_cast<spirv::MemorySemantics>(
694 memorySemanticsAttr.getInt()))
695 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
696 }
697
698 // Verifies an atomic update op.
verifyAtomicUpdateOp(Operation * op)699 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
700 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
701 auto elementType = ptrType.getPointeeType();
702 if (!elementType.isa<IntegerType>())
703 return op->emitOpError(
704 "pointer operand must point to an integer value, found ")
705 << elementType;
706
707 if (op->getNumOperands() > 1) {
708 auto valueType = op->getOperand(1).getType();
709 if (valueType != elementType)
710 return op->emitOpError("expected value to have the same type as the "
711 "pointer operand's pointee type ")
712 << elementType << ", but found " << valueType;
713 }
714 return success();
715 }
716
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)717 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
718 OperationState &state) {
719 spirv::Scope executionScope;
720 spirv::GroupOperation groupOperation;
721 OpAsmParser::OperandType valueInfo;
722 if (parseEnumStrAttr(executionScope, parser, state,
723 kExecutionScopeAttrName) ||
724 parseEnumStrAttr(groupOperation, parser, state,
725 kGroupOperationAttrName) ||
726 parser.parseOperand(valueInfo))
727 return failure();
728
729 Optional<OpAsmParser::OperandType> clusterSizeInfo;
730 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
731 clusterSizeInfo = OpAsmParser::OperandType();
732 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
733 parser.parseRParen())
734 return failure();
735 }
736
737 Type resultType;
738 if (parser.parseColonType(resultType))
739 return failure();
740
741 if (parser.resolveOperand(valueInfo, resultType, state.operands))
742 return failure();
743
744 if (clusterSizeInfo.hasValue()) {
745 Type i32Type = parser.getBuilder().getIntegerType(32);
746 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
747 return failure();
748 }
749
750 return parser.addTypeToList(resultType, state.types);
751 }
752
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)753 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
754 OpAsmPrinter &printer) {
755 printer << groupOp->getName() << " \""
756 << stringifyScope(static_cast<spirv::Scope>(
757 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
758 .getInt()))
759 << "\" \""
760 << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
761 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
762 .getInt()))
763 << "\" " << groupOp->getOperand(0);
764
765 if (groupOp->getNumOperands() > 1)
766 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
767 printer << " : " << groupOp->getResult(0).getType();
768 }
769
verifyGroupNonUniformArithmeticOp(Operation * groupOp)770 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
771 spirv::Scope scope = static_cast<spirv::Scope>(
772 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
773 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
774 return groupOp->emitOpError(
775 "execution scope must be 'Workgroup' or 'Subgroup'");
776
777 spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
778 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
779 if (operation == spirv::GroupOperation::ClusteredReduce &&
780 groupOp->getNumOperands() == 1)
781 return groupOp->emitOpError("cluster size operand must be provided for "
782 "'ClusteredReduce' group operation");
783 if (groupOp->getNumOperands() > 1) {
784 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
785 int32_t clusterSize = 0;
786
787 // TODO: support specialization constant here.
788 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
789 return groupOp->emitOpError(
790 "cluster size operand must come from a constant op");
791
792 if (!llvm::isPowerOf2_32(clusterSize))
793 return groupOp->emitOpError(
794 "cluster size operand must be a power of two");
795 }
796 return success();
797 }
798
parseUnaryOp(OpAsmParser & parser,OperationState & state)799 static ParseResult parseUnaryOp(OpAsmParser &parser, OperationState &state) {
800 OpAsmParser::OperandType operandInfo;
801 Type type;
802 if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
803 parser.resolveOperands(operandInfo, type, state.operands)) {
804 return failure();
805 }
806 state.addTypes(type);
807 return success();
808 }
809
printUnaryOp(Operation * unaryOp,OpAsmPrinter & printer)810 static void printUnaryOp(Operation *unaryOp, OpAsmPrinter &printer) {
811 printer << unaryOp->getName() << ' ' << unaryOp->getOperand(0) << " : "
812 << unaryOp->getOperand(0).getType();
813 }
814
815 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Builder & builder,Type operandType)816 static Type getUnaryOpResultType(Builder &builder, Type operandType) {
817 Type resultType = builder.getIntegerType(1);
818 if (auto vecType = operandType.dyn_cast<VectorType>()) {
819 return VectorType::get(vecType.getNumElements(), resultType);
820 }
821 return resultType;
822 }
823
parseLogicalUnaryOp(OpAsmParser & parser,OperationState & state)824 static ParseResult parseLogicalUnaryOp(OpAsmParser &parser,
825 OperationState &state) {
826 OpAsmParser::OperandType operandInfo;
827 Type type;
828 if (parser.parseOperand(operandInfo) || parser.parseColonType(type) ||
829 parser.resolveOperand(operandInfo, type, state.operands)) {
830 return failure();
831 }
832 state.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
833 return success();
834 }
835
parseLogicalBinaryOp(OpAsmParser & parser,OperationState & result)836 static ParseResult parseLogicalBinaryOp(OpAsmParser &parser,
837 OperationState &result) {
838 SmallVector<OpAsmParser::OperandType, 2> ops;
839 Type type;
840 if (parser.parseOperandList(ops, 2) || parser.parseColonType(type) ||
841 parser.resolveOperands(ops, type, result.operands)) {
842 return failure();
843 }
844 result.addTypes(getUnaryOpResultType(parser.getBuilder(), type));
845 return success();
846 }
847
printLogicalOp(Operation * logicalOp,OpAsmPrinter & printer)848 static void printLogicalOp(Operation *logicalOp, OpAsmPrinter &printer) {
849 printer << logicalOp->getName() << ' ' << logicalOp->getOperands() << " : "
850 << logicalOp->getOperand(0).getType();
851 }
852
parseShiftOp(OpAsmParser & parser,OperationState & state)853 static ParseResult parseShiftOp(OpAsmParser &parser, OperationState &state) {
854 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
855 Type baseType;
856 Type shiftType;
857 auto loc = parser.getCurrentLocation();
858
859 if (parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
860 parser.parseType(baseType) || parser.parseComma() ||
861 parser.parseType(shiftType) ||
862 parser.resolveOperands(operandInfo, {baseType, shiftType}, loc,
863 state.operands)) {
864 return failure();
865 }
866 state.addTypes(baseType);
867 return success();
868 }
869
printShiftOp(Operation * op,OpAsmPrinter & printer)870 static void printShiftOp(Operation *op, OpAsmPrinter &printer) {
871 Value base = op->getOperand(0);
872 Value shift = op->getOperand(1);
873 printer << op->getName() << ' ' << base << ", " << shift << " : "
874 << base.getType() << ", " << shift.getType();
875 }
876
verifyShiftOp(Operation * op)877 static LogicalResult verifyShiftOp(Operation *op) {
878 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
879 return op->emitError("expected the same type for the first operand and "
880 "result, but provided ")
881 << op->getOperand(0).getType() << " and "
882 << op->getResult(0).getType();
883 }
884 return success();
885 }
886
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)887 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
888 Value lhs, Value rhs) {
889 assert(lhs.getType() == rhs.getType());
890
891 Type boolType = builder.getI1Type();
892 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
893 boolType = VectorType::get(vecType.getShape(), boolType);
894 state.addTypes(boolType);
895
896 state.addOperands({lhs, rhs});
897 }
898
899 //===----------------------------------------------------------------------===//
900 // spv.AccessChainOp
901 //===----------------------------------------------------------------------===//
902
getElementPtrType(Type type,ValueRange indices,Location baseLoc)903 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
904 auto ptrType = type.dyn_cast<spirv::PointerType>();
905 if (!ptrType) {
906 emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
907 "to composite type, but provided ")
908 << type;
909 return nullptr;
910 }
911
912 auto resultType = ptrType.getPointeeType();
913 auto resultStorageClass = ptrType.getStorageClass();
914 int32_t index = 0;
915
916 for (auto indexSSA : indices) {
917 auto cType = resultType.dyn_cast<spirv::CompositeType>();
918 if (!cType) {
919 emitError(baseLoc,
920 "'spv.AccessChain' op cannot extract from non-composite type ")
921 << resultType << " with index " << index;
922 return nullptr;
923 }
924 index = 0;
925 if (resultType.isa<spirv::StructType>()) {
926 Operation *op = indexSSA.getDefiningOp();
927 if (!op) {
928 emitError(baseLoc, "'spv.AccessChain' op index must be an "
929 "integer spv.constant to access "
930 "element of spv.struct");
931 return nullptr;
932 }
933
934 // TODO: this should be relaxed to allow
935 // integer literals of other bitwidths.
936 if (failed(extractValueFromConstOp(op, index))) {
937 emitError(baseLoc,
938 "'spv.AccessChain' index must be an integer spv.constant to "
939 "access element of spv.struct, but provided ")
940 << op->getName();
941 return nullptr;
942 }
943 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
944 emitError(baseLoc, "'spv.AccessChain' op index ")
945 << index << " out of bounds for " << resultType;
946 return nullptr;
947 }
948 }
949 resultType = cType.getElementType(index);
950 }
951 return spirv::PointerType::get(resultType, resultStorageClass);
952 }
953
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)954 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
955 Value basePtr, ValueRange indices) {
956 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
957 assert(type && "Unable to deduce return type based on basePtr and indices");
958 build(builder, state, type, basePtr, indices);
959 }
960
parseAccessChainOp(OpAsmParser & parser,OperationState & state)961 static ParseResult parseAccessChainOp(OpAsmParser &parser,
962 OperationState &state) {
963 OpAsmParser::OperandType ptrInfo;
964 SmallVector<OpAsmParser::OperandType, 4> indicesInfo;
965 Type type;
966 auto loc = parser.getCurrentLocation();
967 SmallVector<Type, 4> indicesTypes;
968
969 if (parser.parseOperand(ptrInfo) ||
970 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
971 parser.parseColonType(type) ||
972 parser.resolveOperand(ptrInfo, type, state.operands)) {
973 return failure();
974 }
975
976 // Check that the provided indices list is not empty before parsing their
977 // type list.
978 if (indicesInfo.empty()) {
979 return emitError(state.location, "'spv.AccessChain' op expected at "
980 "least one index ");
981 }
982
983 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
984 return failure();
985
986 // Check that the indices types list is not empty and that it has a one-to-one
987 // mapping to the provided indices.
988 if (indicesTypes.size() != indicesInfo.size()) {
989 return emitError(state.location, "'spv.AccessChain' op indices "
990 "types' count must be equal to indices "
991 "info count");
992 }
993
994 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
995 return failure();
996
997 auto resultType = getElementPtrType(
998 type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
999 if (!resultType) {
1000 return failure();
1001 }
1002
1003 state.addTypes(resultType);
1004 return success();
1005 }
1006
print(spirv::AccessChainOp op,OpAsmPrinter & printer)1007 static void print(spirv::AccessChainOp op, OpAsmPrinter &printer) {
1008 printer << spirv::AccessChainOp::getOperationName() << ' ' << op.base_ptr()
1009 << '[' << op.indices() << "] : " << op.base_ptr().getType() << ", "
1010 << op.indices().getTypes();
1011 }
1012
verify(spirv::AccessChainOp accessChainOp)1013 static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
1014 SmallVector<Value, 4> indices(accessChainOp.indices().begin(),
1015 accessChainOp.indices().end());
1016 auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1017 indices, accessChainOp.getLoc());
1018 if (!resultType) {
1019 return failure();
1020 }
1021
1022 auto providedResultType =
1023 accessChainOp.getType().dyn_cast<spirv::PointerType>();
1024 if (!providedResultType) {
1025 return accessChainOp.emitOpError(
1026 "result type must be a pointer, but provided")
1027 << providedResultType;
1028 }
1029
1030 if (resultType != providedResultType) {
1031 return accessChainOp.emitOpError("invalid result type: expected ")
1032 << resultType << ", but provided " << providedResultType;
1033 }
1034
1035 return success();
1036 }
1037
1038 //===----------------------------------------------------------------------===//
1039 // spv.mlir.addressof
1040 //===----------------------------------------------------------------------===//
1041
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1042 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1043 spirv::GlobalVariableOp var) {
1044 build(builder, state, var.type(), builder.getSymbolRefAttr(var));
1045 }
1046
verify(spirv::AddressOfOp addressOfOp)1047 static LogicalResult verify(spirv::AddressOfOp addressOfOp) {
1048 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1049 SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(),
1050 addressOfOp.variable()));
1051 if (!varOp) {
1052 return addressOfOp.emitOpError("expected spv.globalVariable symbol");
1053 }
1054 if (addressOfOp.pointer().getType() != varOp.type()) {
1055 return addressOfOp.emitOpError(
1056 "result type mismatch with the referenced global variable's type");
1057 }
1058 return success();
1059 }
1060
1061 //===----------------------------------------------------------------------===//
1062 // spv.AtomicCompareExchangeWeak
1063 //===----------------------------------------------------------------------===//
1064
parseAtomicCompareExchangeWeakOp(OpAsmParser & parser,OperationState & state)1065 static ParseResult parseAtomicCompareExchangeWeakOp(OpAsmParser &parser,
1066 OperationState &state) {
1067 spirv::Scope memoryScope;
1068 spirv::MemorySemantics equalSemantics, unequalSemantics;
1069 SmallVector<OpAsmParser::OperandType, 3> operandInfo;
1070 Type type;
1071 if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1072 parseEnumStrAttr(equalSemantics, parser, state,
1073 kEqualSemanticsAttrName) ||
1074 parseEnumStrAttr(unequalSemantics, parser, state,
1075 kUnequalSemanticsAttrName) ||
1076 parser.parseOperandList(operandInfo, 3))
1077 return failure();
1078
1079 auto loc = parser.getCurrentLocation();
1080 if (parser.parseColonType(type))
1081 return failure();
1082
1083 auto ptrType = type.dyn_cast<spirv::PointerType>();
1084 if (!ptrType)
1085 return parser.emitError(loc, "expected pointer type");
1086
1087 if (parser.resolveOperands(
1088 operandInfo,
1089 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1090 parser.getNameLoc(), state.operands))
1091 return failure();
1092
1093 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1094 }
1095
print(spirv::AtomicCompareExchangeWeakOp atomOp,OpAsmPrinter & printer)1096 static void print(spirv::AtomicCompareExchangeWeakOp atomOp,
1097 OpAsmPrinter &printer) {
1098 printer << spirv::AtomicCompareExchangeWeakOp::getOperationName() << " \""
1099 << stringifyScope(atomOp.memory_scope()) << "\" \""
1100 << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1101 << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1102 << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1103 }
1104
verify(spirv::AtomicCompareExchangeWeakOp atomOp)1105 static LogicalResult verify(spirv::AtomicCompareExchangeWeakOp atomOp) {
1106 // According to the spec:
1107 // "The type of Value must be the same as Result Type. The type of the value
1108 // pointed to by Pointer must be the same as Result Type. This type must also
1109 // match the type of Comparator."
1110 if (atomOp.getType() != atomOp.value().getType())
1111 return atomOp.emitOpError("value operand must have the same type as the op "
1112 "result, but found ")
1113 << atomOp.value().getType() << " vs " << atomOp.getType();
1114
1115 if (atomOp.getType() != atomOp.comparator().getType())
1116 return atomOp.emitOpError(
1117 "comparator operand must have the same type as the op "
1118 "result, but found ")
1119 << atomOp.comparator().getType() << " vs " << atomOp.getType();
1120
1121 Type pointeeType =
1122 atomOp.pointer().getType().cast<spirv::PointerType>().getPointeeType();
1123 if (atomOp.getType() != pointeeType)
1124 return atomOp.emitOpError(
1125 "pointer operand's pointee type must have the same "
1126 "as the op result type, but found ")
1127 << pointeeType << " vs " << atomOp.getType();
1128
1129 // TODO: Unequal cannot be set to Release or Acquire and Release.
1130 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1131
1132 return success();
1133 }
1134
1135 //===----------------------------------------------------------------------===//
1136 // spv.BitcastOp
1137 //===----------------------------------------------------------------------===//
1138
verify(spirv::BitcastOp bitcastOp)1139 static LogicalResult verify(spirv::BitcastOp bitcastOp) {
1140 // TODO: The SPIR-V spec validation rules are different for different
1141 // versions.
1142 auto operandType = bitcastOp.operand().getType();
1143 auto resultType = bitcastOp.result().getType();
1144 if (operandType == resultType) {
1145 return bitcastOp.emitError(
1146 "result type must be different from operand type");
1147 }
1148 if (operandType.isa<spirv::PointerType>() &&
1149 !resultType.isa<spirv::PointerType>()) {
1150 return bitcastOp.emitError(
1151 "unhandled bit cast conversion from pointer type to non-pointer type");
1152 }
1153 if (!operandType.isa<spirv::PointerType>() &&
1154 resultType.isa<spirv::PointerType>()) {
1155 return bitcastOp.emitError(
1156 "unhandled bit cast conversion from non-pointer type to pointer type");
1157 }
1158 auto operandBitWidth = getBitWidth(operandType);
1159 auto resultBitWidth = getBitWidth(resultType);
1160 if (operandBitWidth != resultBitWidth) {
1161 return bitcastOp.emitOpError("mismatch in result type bitwidth ")
1162 << resultBitWidth << " and operand type bitwidth "
1163 << operandBitWidth;
1164 }
1165 return success();
1166 }
1167
1168 //===----------------------------------------------------------------------===//
1169 // spv.BranchOp
1170 //===----------------------------------------------------------------------===//
1171
1172 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1173 spirv::BranchOp::getMutableSuccessorOperands(unsigned index) {
1174 assert(index == 0 && "invalid successor index");
1175 return targetOperandsMutable();
1176 }
1177
1178 //===----------------------------------------------------------------------===//
1179 // spv.BranchConditionalOp
1180 //===----------------------------------------------------------------------===//
1181
1182 Optional<MutableOperandRange>
getMutableSuccessorOperands(unsigned index)1183 spirv::BranchConditionalOp::getMutableSuccessorOperands(unsigned index) {
1184 assert(index < 2 && "invalid successor index");
1185 return index == kTrueIndex ? trueTargetOperandsMutable()
1186 : falseTargetOperandsMutable();
1187 }
1188
parseBranchConditionalOp(OpAsmParser & parser,OperationState & state)1189 static ParseResult parseBranchConditionalOp(OpAsmParser &parser,
1190 OperationState &state) {
1191 auto &builder = parser.getBuilder();
1192 OpAsmParser::OperandType condInfo;
1193 Block *dest;
1194
1195 // Parse the condition.
1196 Type boolTy = builder.getI1Type();
1197 if (parser.parseOperand(condInfo) ||
1198 parser.resolveOperand(condInfo, boolTy, state.operands))
1199 return failure();
1200
1201 // Parse the optional branch weights.
1202 if (succeeded(parser.parseOptionalLSquare())) {
1203 IntegerAttr trueWeight, falseWeight;
1204 NamedAttrList weights;
1205
1206 auto i32Type = builder.getIntegerType(32);
1207 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1208 parser.parseComma() ||
1209 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1210 parser.parseRSquare())
1211 return failure();
1212
1213 state.addAttribute(kBranchWeightAttrName,
1214 builder.getArrayAttr({trueWeight, falseWeight}));
1215 }
1216
1217 // Parse the true branch.
1218 SmallVector<Value, 4> trueOperands;
1219 if (parser.parseComma() ||
1220 parser.parseSuccessorAndUseList(dest, trueOperands))
1221 return failure();
1222 state.addSuccessors(dest);
1223 state.addOperands(trueOperands);
1224
1225 // Parse the false branch.
1226 SmallVector<Value, 4> falseOperands;
1227 if (parser.parseComma() ||
1228 parser.parseSuccessorAndUseList(dest, falseOperands))
1229 return failure();
1230 state.addSuccessors(dest);
1231 state.addOperands(falseOperands);
1232 state.addAttribute(
1233 spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1234 builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1235 static_cast<int32_t>(falseOperands.size())}));
1236
1237 return success();
1238 }
1239
print(spirv::BranchConditionalOp branchOp,OpAsmPrinter & printer)1240 static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) {
1241 printer << spirv::BranchConditionalOp::getOperationName() << ' '
1242 << branchOp.condition();
1243
1244 if (auto weights = branchOp.branch_weights()) {
1245 printer << " [";
1246 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1247 printer << a.cast<IntegerAttr>().getInt();
1248 });
1249 printer << "]";
1250 }
1251
1252 printer << ", ";
1253 printer.printSuccessorAndUseList(branchOp.getTrueBlock(),
1254 branchOp.getTrueBlockArguments());
1255 printer << ", ";
1256 printer.printSuccessorAndUseList(branchOp.getFalseBlock(),
1257 branchOp.getFalseBlockArguments());
1258 }
1259
verify(spirv::BranchConditionalOp branchOp)1260 static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
1261 if (auto weights = branchOp.branch_weights()) {
1262 if (weights->getValue().size() != 2) {
1263 return branchOp.emitOpError("must have exactly two branch weights");
1264 }
1265 if (llvm::all_of(*weights, [](Attribute attr) {
1266 return attr.cast<IntegerAttr>().getValue().isNullValue();
1267 }))
1268 return branchOp.emitOpError("branch weights cannot both be zero");
1269 }
1270
1271 return success();
1272 }
1273
1274 //===----------------------------------------------------------------------===//
1275 // spv.CompositeConstruct
1276 //===----------------------------------------------------------------------===//
1277
parseCompositeConstructOp(OpAsmParser & parser,OperationState & state)1278 static ParseResult parseCompositeConstructOp(OpAsmParser &parser,
1279 OperationState &state) {
1280 SmallVector<OpAsmParser::OperandType, 4> operands;
1281 Type type;
1282 auto loc = parser.getCurrentLocation();
1283
1284 if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1285 return failure();
1286 }
1287 auto cType = type.dyn_cast<spirv::CompositeType>();
1288 if (!cType) {
1289 return parser.emitError(
1290 loc, "result type must be a composite type, but provided ")
1291 << type;
1292 }
1293
1294 if (cType.hasCompileTimeKnownNumElements() &&
1295 operands.size() != cType.getNumElements()) {
1296 return parser.emitError(loc, "has incorrect number of operands: expected ")
1297 << cType.getNumElements() << ", but provided " << operands.size();
1298 }
1299 // TODO: Add support for constructing a vector type from the vector operands.
1300 // According to the spec: "for constructing a vector, the operands may
1301 // also be vectors with the same component type as the Result Type component
1302 // type".
1303 SmallVector<Type, 4> elementTypes;
1304 elementTypes.reserve(operands.size());
1305 for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1306 elementTypes.push_back(cType.getElementType(index));
1307 }
1308 state.addTypes(type);
1309 return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1310 }
1311
print(spirv::CompositeConstructOp compositeConstructOp,OpAsmPrinter & printer)1312 static void print(spirv::CompositeConstructOp compositeConstructOp,
1313 OpAsmPrinter &printer) {
1314 printer << spirv::CompositeConstructOp::getOperationName() << " "
1315 << compositeConstructOp.constituents() << " : "
1316 << compositeConstructOp.getResult().getType();
1317 }
1318
verify(spirv::CompositeConstructOp compositeConstructOp)1319 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
1320 auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
1321 SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
1322
1323 if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1324 if (constituents.size() != 1)
1325 return compositeConstructOp.emitError(
1326 "has incorrect number of operands: expected ")
1327 << "1, but provided " << constituents.size();
1328 } else if (constituents.size() != cType.getNumElements()) {
1329 return compositeConstructOp.emitError(
1330 "has incorrect number of operands: expected ")
1331 << cType.getNumElements() << ", but provided "
1332 << constituents.size();
1333 }
1334
1335 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1336 if (constituents[index].getType() != cType.getElementType(index)) {
1337 return compositeConstructOp.emitError(
1338 "operand type mismatch: expected operand type ")
1339 << cType.getElementType(index) << ", but provided "
1340 << constituents[index].getType();
1341 }
1342 }
1343
1344 return success();
1345 }
1346
1347 //===----------------------------------------------------------------------===//
1348 // spv.CompositeExtractOp
1349 //===----------------------------------------------------------------------===//
1350
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1351 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1352 Value composite,
1353 ArrayRef<int32_t> indices) {
1354 auto indexAttr = builder.getI32ArrayAttr(indices);
1355 auto elementType =
1356 getElementType(composite.getType(), indexAttr, state.location);
1357 if (!elementType) {
1358 return;
1359 }
1360 build(builder, state, elementType, composite, indexAttr);
1361 }
1362
parseCompositeExtractOp(OpAsmParser & parser,OperationState & state)1363 static ParseResult parseCompositeExtractOp(OpAsmParser &parser,
1364 OperationState &state) {
1365 OpAsmParser::OperandType compositeInfo;
1366 Attribute indicesAttr;
1367 Type compositeType;
1368 llvm::SMLoc attrLocation;
1369
1370 if (parser.parseOperand(compositeInfo) ||
1371 parser.getCurrentLocation(&attrLocation) ||
1372 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1373 parser.parseColonType(compositeType) ||
1374 parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1375 return failure();
1376 }
1377
1378 Type resultType =
1379 getElementType(compositeType, indicesAttr, parser, attrLocation);
1380 if (!resultType) {
1381 return failure();
1382 }
1383 state.addTypes(resultType);
1384 return success();
1385 }
1386
print(spirv::CompositeExtractOp compositeExtractOp,OpAsmPrinter & printer)1387 static void print(spirv::CompositeExtractOp compositeExtractOp,
1388 OpAsmPrinter &printer) {
1389 printer << spirv::CompositeExtractOp::getOperationName() << ' '
1390 << compositeExtractOp.composite() << compositeExtractOp.indices()
1391 << " : " << compositeExtractOp.composite().getType();
1392 }
1393
verify(spirv::CompositeExtractOp compExOp)1394 static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
1395 auto indicesArrayAttr = compExOp.indices().dyn_cast<ArrayAttr>();
1396 auto resultType = getElementType(compExOp.composite().getType(),
1397 indicesArrayAttr, compExOp.getLoc());
1398 if (!resultType)
1399 return failure();
1400
1401 if (resultType != compExOp.getType()) {
1402 return compExOp.emitOpError("invalid result type: expected ")
1403 << resultType << " but provided " << compExOp.getType();
1404 }
1405
1406 return success();
1407 }
1408
1409 //===----------------------------------------------------------------------===//
1410 // spv.CompositeInsert
1411 //===----------------------------------------------------------------------===//
1412
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1413 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1414 Value object, Value composite,
1415 ArrayRef<int32_t> indices) {
1416 auto indexAttr = builder.getI32ArrayAttr(indices);
1417 build(builder, state, composite.getType(), object, composite, indexAttr);
1418 }
1419
parseCompositeInsertOp(OpAsmParser & parser,OperationState & state)1420 static ParseResult parseCompositeInsertOp(OpAsmParser &parser,
1421 OperationState &state) {
1422 SmallVector<OpAsmParser::OperandType, 2> operands;
1423 Type objectType, compositeType;
1424 Attribute indicesAttr;
1425 auto loc = parser.getCurrentLocation();
1426
1427 return failure(
1428 parser.parseOperandList(operands, 2) ||
1429 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1430 parser.parseColonType(objectType) ||
1431 parser.parseKeywordType("into", compositeType) ||
1432 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1433 state.operands) ||
1434 parser.addTypesToList(compositeType, state.types));
1435 }
1436
verify(spirv::CompositeInsertOp compositeInsertOp)1437 static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) {
1438 auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast<ArrayAttr>();
1439 auto objectType =
1440 getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr,
1441 compositeInsertOp.getLoc());
1442 if (!objectType)
1443 return failure();
1444
1445 if (objectType != compositeInsertOp.object().getType()) {
1446 return compositeInsertOp.emitOpError("object operand type should be ")
1447 << objectType << ", but found "
1448 << compositeInsertOp.object().getType();
1449 }
1450
1451 if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) {
1452 return compositeInsertOp.emitOpError("result type should be the same as "
1453 "the composite type, but found ")
1454 << compositeInsertOp.composite().getType() << " vs "
1455 << compositeInsertOp.getType();
1456 }
1457
1458 return success();
1459 }
1460
print(spirv::CompositeInsertOp compositeInsertOp,OpAsmPrinter & printer)1461 static void print(spirv::CompositeInsertOp compositeInsertOp,
1462 OpAsmPrinter &printer) {
1463 printer << spirv::CompositeInsertOp::getOperationName() << " "
1464 << compositeInsertOp.object() << ", " << compositeInsertOp.composite()
1465 << compositeInsertOp.indices() << " : "
1466 << compositeInsertOp.object().getType() << " into "
1467 << compositeInsertOp.composite().getType();
1468 }
1469
1470 //===----------------------------------------------------------------------===//
1471 // spv.constant
1472 //===----------------------------------------------------------------------===//
1473
parseConstantOp(OpAsmParser & parser,OperationState & state)1474 static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &state) {
1475 Attribute value;
1476 if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1477 return failure();
1478
1479 Type type = value.getType();
1480 if (type.isa<NoneType, TensorType>()) {
1481 if (parser.parseColonType(type))
1482 return failure();
1483 }
1484
1485 return parser.addTypeToList(type, state.types);
1486 }
1487
print(spirv::ConstantOp constOp,OpAsmPrinter & printer)1488 static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) {
1489 printer << spirv::ConstantOp::getOperationName() << ' ' << constOp.value();
1490 if (constOp.getType().isa<spirv::ArrayType>())
1491 printer << " : " << constOp.getType();
1492 }
1493
verify(spirv::ConstantOp constOp)1494 static LogicalResult verify(spirv::ConstantOp constOp) {
1495 auto opType = constOp.getType();
1496 auto value = constOp.value();
1497 auto valueType = value.getType();
1498
1499 // ODS already generates checks to make sure the result type is valid. We just
1500 // need to additionally check that the value's attribute type is consistent
1501 // with the result type.
1502 if (value.isa<IntegerAttr, FloatAttr>()) {
1503 if (valueType != opType)
1504 return constOp.emitOpError("result type (")
1505 << opType << ") does not match value type (" << valueType << ")";
1506 return success();
1507 }
1508 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1509 if (valueType == opType)
1510 return success();
1511 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1512 auto shapedType = valueType.dyn_cast<ShapedType>();
1513 if (!arrayType) {
1514 return constOp.emitOpError(
1515 "must have spv.array result type for array value");
1516 }
1517
1518 int numElements = arrayType.getNumElements();
1519 auto opElemType = arrayType.getElementType();
1520 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1521 numElements *= t.getNumElements();
1522 opElemType = t.getElementType();
1523 }
1524 if (!opElemType.isIntOrFloat())
1525 return constOp.emitOpError("only support nested array result type");
1526
1527 auto valueElemType = shapedType.getElementType();
1528 if (valueElemType != opElemType) {
1529 return constOp.emitOpError("result element type (")
1530 << opElemType << ") does not match value element type ("
1531 << valueElemType << ")";
1532 }
1533
1534 if (numElements != shapedType.getNumElements()) {
1535 return constOp.emitOpError("result number of elements (")
1536 << numElements << ") does not match value number of elements ("
1537 << shapedType.getNumElements() << ")";
1538 }
1539 return success();
1540 }
1541 if (auto attayAttr = value.dyn_cast<ArrayAttr>()) {
1542 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1543 if (!arrayType)
1544 return constOp.emitOpError(
1545 "must have spv.array result type for array value");
1546 Type elemType = arrayType.getElementType();
1547 for (Attribute element : attayAttr.getValue()) {
1548 if (element.getType() != elemType)
1549 return constOp.emitOpError("has array element whose type (")
1550 << element.getType()
1551 << ") does not match the result element type (" << elemType
1552 << ')';
1553 }
1554 return success();
1555 }
1556 return constOp.emitOpError("cannot have value of type ") << valueType;
1557 }
1558
isBuildableWith(Type type)1559 bool spirv::ConstantOp::isBuildableWith(Type type) {
1560 // Must be valid SPIR-V type first.
1561 if (!type.isa<spirv::SPIRVType>())
1562 return false;
1563
1564 if (isa<SPIRVDialect>(type.getDialect())) {
1565 // TODO: support constant struct
1566 return type.isa<spirv::ArrayType>();
1567 }
1568
1569 return true;
1570 }
1571
getZero(Type type,Location loc,OpBuilder & builder)1572 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1573 OpBuilder &builder) {
1574 if (auto intType = type.dyn_cast<IntegerType>()) {
1575 unsigned width = intType.getWidth();
1576 if (width == 1)
1577 return builder.create<spirv::ConstantOp>(loc, type,
1578 builder.getBoolAttr(false));
1579 return builder.create<spirv::ConstantOp>(
1580 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1581 }
1582
1583 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1584 }
1585
getOne(Type type,Location loc,OpBuilder & builder)1586 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1587 OpBuilder &builder) {
1588 if (auto intType = type.dyn_cast<IntegerType>()) {
1589 unsigned width = intType.getWidth();
1590 if (width == 1)
1591 return builder.create<spirv::ConstantOp>(loc, type,
1592 builder.getBoolAttr(true));
1593 return builder.create<spirv::ConstantOp>(
1594 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1595 }
1596
1597 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1598 }
1599
1600 //===----------------------------------------------------------------------===//
1601 // spv.EntryPoint
1602 //===----------------------------------------------------------------------===//
1603
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)1604 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
1605 spirv::ExecutionModel executionModel,
1606 spirv::FuncOp function,
1607 ArrayRef<Attribute> interfaceVars) {
1608 build(builder, state,
1609 builder.getI32IntegerAttr(static_cast<int32_t>(executionModel)),
1610 builder.getSymbolRefAttr(function),
1611 builder.getArrayAttr(interfaceVars));
1612 }
1613
parseEntryPointOp(OpAsmParser & parser,OperationState & state)1614 static ParseResult parseEntryPointOp(OpAsmParser &parser,
1615 OperationState &state) {
1616 spirv::ExecutionModel execModel;
1617 SmallVector<OpAsmParser::OperandType, 0> identifiers;
1618 SmallVector<Type, 0> idTypes;
1619 SmallVector<Attribute, 4> interfaceVars;
1620
1621 FlatSymbolRefAttr fn;
1622 if (parseEnumStrAttr(execModel, parser, state) ||
1623 parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
1624 return failure();
1625 }
1626
1627 if (!parser.parseOptionalComma()) {
1628 // Parse the interface variables
1629 do {
1630 // The name of the interface variable attribute isnt important
1631 auto attrName = "var_symbol";
1632 FlatSymbolRefAttr var;
1633 NamedAttrList attrs;
1634 if (parser.parseAttribute(var, Type(), attrName, attrs)) {
1635 return failure();
1636 }
1637 interfaceVars.push_back(var);
1638 } while (!parser.parseOptionalComma());
1639 }
1640 state.addAttribute(kInterfaceAttrName,
1641 parser.getBuilder().getArrayAttr(interfaceVars));
1642 return success();
1643 }
1644
print(spirv::EntryPointOp entryPointOp,OpAsmPrinter & printer)1645 static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) {
1646 printer << spirv::EntryPointOp::getOperationName() << " \""
1647 << stringifyExecutionModel(entryPointOp.execution_model()) << "\" ";
1648 printer.printSymbolName(entryPointOp.fn());
1649 auto interfaceVars = entryPointOp.interface().getValue();
1650 if (!interfaceVars.empty()) {
1651 printer << ", ";
1652 llvm::interleaveComma(interfaceVars, printer);
1653 }
1654 }
1655
verify(spirv::EntryPointOp entryPointOp)1656 static LogicalResult verify(spirv::EntryPointOp entryPointOp) {
1657 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
1658 // verification.
1659 return success();
1660 }
1661
1662 //===----------------------------------------------------------------------===//
1663 // spv.ExecutionMode
1664 //===----------------------------------------------------------------------===//
1665
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)1666 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
1667 spirv::FuncOp function,
1668 spirv::ExecutionMode executionMode,
1669 ArrayRef<int32_t> params) {
1670 build(builder, state, builder.getSymbolRefAttr(function),
1671 builder.getI32IntegerAttr(static_cast<int32_t>(executionMode)),
1672 builder.getI32ArrayAttr(params));
1673 }
1674
parseExecutionModeOp(OpAsmParser & parser,OperationState & state)1675 static ParseResult parseExecutionModeOp(OpAsmParser &parser,
1676 OperationState &state) {
1677 spirv::ExecutionMode execMode;
1678 Attribute fn;
1679 if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
1680 parseEnumStrAttr(execMode, parser, state)) {
1681 return failure();
1682 }
1683
1684 SmallVector<int32_t, 4> values;
1685 Type i32Type = parser.getBuilder().getIntegerType(32);
1686 while (!parser.parseOptionalComma()) {
1687 NamedAttrList attr;
1688 Attribute value;
1689 if (parser.parseAttribute(value, i32Type, "value", attr)) {
1690 return failure();
1691 }
1692 values.push_back(value.cast<IntegerAttr>().getInt());
1693 }
1694 state.addAttribute(kValuesAttrName,
1695 parser.getBuilder().getI32ArrayAttr(values));
1696 return success();
1697 }
1698
print(spirv::ExecutionModeOp execModeOp,OpAsmPrinter & printer)1699 static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) {
1700 printer << spirv::ExecutionModeOp::getOperationName() << " ";
1701 printer.printSymbolName(execModeOp.fn());
1702 printer << " \"" << stringifyExecutionMode(execModeOp.execution_mode())
1703 << "\"";
1704 auto values = execModeOp.values();
1705 if (!values.size())
1706 return;
1707 printer << ", ";
1708 llvm::interleaveComma(values, printer, [&](Attribute a) {
1709 printer << a.cast<IntegerAttr>().getInt();
1710 });
1711 }
1712
1713 //===----------------------------------------------------------------------===//
1714 // spv.func
1715 //===----------------------------------------------------------------------===//
1716
parseFuncOp(OpAsmParser & parser,OperationState & state)1717 static ParseResult parseFuncOp(OpAsmParser &parser, OperationState &state) {
1718 SmallVector<OpAsmParser::OperandType, 4> entryArgs;
1719 SmallVector<NamedAttrList, 4> argAttrs;
1720 SmallVector<NamedAttrList, 4> resultAttrs;
1721 SmallVector<Type, 4> argTypes;
1722 SmallVector<Type, 4> resultTypes;
1723 auto &builder = parser.getBuilder();
1724
1725 // Parse the name as a symbol.
1726 StringAttr nameAttr;
1727 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1728 state.attributes))
1729 return failure();
1730
1731 // Parse the function signature.
1732 bool isVariadic = false;
1733 if (impl::parseFunctionSignature(parser, /*allowVariadic=*/false, entryArgs,
1734 argTypes, argAttrs, isVariadic, resultTypes,
1735 resultAttrs))
1736 return failure();
1737
1738 auto fnType = builder.getFunctionType(argTypes, resultTypes);
1739 state.addAttribute(impl::getTypeAttrName(), TypeAttr::get(fnType));
1740
1741 // Parse the optional function control keyword.
1742 spirv::FunctionControl fnControl;
1743 if (parseEnumStrAttr(fnControl, parser, state))
1744 return failure();
1745
1746 // If additional attributes are present, parse them.
1747 if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
1748 return failure();
1749
1750 // Add the attributes to the function arguments.
1751 assert(argAttrs.size() == argTypes.size());
1752 assert(resultAttrs.size() == resultTypes.size());
1753 impl::addArgAndResultAttrs(builder, state, argAttrs, resultAttrs);
1754
1755 // Parse the optional function body.
1756 auto *body = state.addRegion();
1757 OptionalParseResult result = parser.parseOptionalRegion(
1758 *body, entryArgs, entryArgs.empty() ? ArrayRef<Type>() : argTypes);
1759 return failure(result.hasValue() && failed(*result));
1760 }
1761
print(spirv::FuncOp fnOp,OpAsmPrinter & printer)1762 static void print(spirv::FuncOp fnOp, OpAsmPrinter &printer) {
1763 // Print function name, signature, and control.
1764 printer << spirv::FuncOp::getOperationName() << " ";
1765 printer.printSymbolName(fnOp.sym_name());
1766 auto fnType = fnOp.getType();
1767 impl::printFunctionSignature(printer, fnOp, fnType.getInputs(),
1768 /*isVariadic=*/false, fnType.getResults());
1769 printer << " \"" << spirv::stringifyFunctionControl(fnOp.function_control())
1770 << "\"";
1771 impl::printFunctionAttributes(
1772 printer, fnOp, fnType.getNumInputs(), fnType.getNumResults(),
1773 {spirv::attributeName<spirv::FunctionControl>()});
1774
1775 // Print the body if this is not an external function.
1776 Region &body = fnOp.body();
1777 if (!body.empty())
1778 printer.printRegion(body, /*printEntryBlockArgs=*/false,
1779 /*printBlockTerminators=*/true);
1780 }
1781
verifyType()1782 LogicalResult spirv::FuncOp::verifyType() {
1783 auto type = getTypeAttr().getValue();
1784 if (!type.isa<FunctionType>())
1785 return emitOpError("requires '" + getTypeAttrName() +
1786 "' attribute of function type");
1787 if (getType().getNumResults() > 1)
1788 return emitOpError("cannot have more than one result");
1789 return success();
1790 }
1791
verifyBody()1792 LogicalResult spirv::FuncOp::verifyBody() {
1793 FunctionType fnType = getType();
1794
1795 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
1796 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
1797 if (fnType.getNumResults() != 0)
1798 return retOp.emitOpError("cannot be used in functions returning value");
1799 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
1800 if (fnType.getNumResults() != 1)
1801 return retOp.emitOpError(
1802 "returns 1 value but enclosing function requires ")
1803 << fnType.getNumResults() << " results";
1804
1805 auto retOperandType = retOp.value().getType();
1806 auto fnResultType = fnType.getResult(0);
1807 if (retOperandType != fnResultType)
1808 return retOp.emitOpError(" return value's type (")
1809 << retOperandType << ") mismatch with function's result type ("
1810 << fnResultType << ")";
1811 }
1812 return WalkResult::advance();
1813 });
1814
1815 // TODO: verify other bits like linkage type.
1816
1817 return failure(walkResult.wasInterrupted());
1818 }
1819
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)1820 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
1821 StringRef name, FunctionType type,
1822 spirv::FunctionControl control,
1823 ArrayRef<NamedAttribute> attrs) {
1824 state.addAttribute(SymbolTable::getSymbolAttrName(),
1825 builder.getStringAttr(name));
1826 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
1827 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
1828 builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
1829 state.attributes.append(attrs.begin(), attrs.end());
1830 state.addRegion();
1831 }
1832
1833 // CallableOpInterface
getCallableRegion()1834 Region *spirv::FuncOp::getCallableRegion() {
1835 return isExternal() ? nullptr : &body();
1836 }
1837
1838 // CallableOpInterface
getCallableResults()1839 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
1840 return getType().getResults();
1841 }
1842
1843 //===----------------------------------------------------------------------===//
1844 // spv.FunctionCall
1845 //===----------------------------------------------------------------------===//
1846
verify(spirv::FunctionCallOp functionCallOp)1847 static LogicalResult verify(spirv::FunctionCallOp functionCallOp) {
1848 auto fnName = functionCallOp.callee();
1849
1850 auto funcOp =
1851 dyn_cast_or_null<spirv::FuncOp>(SymbolTable::lookupNearestSymbolFrom(
1852 functionCallOp->getParentOp(), fnName));
1853 if (!funcOp) {
1854 return functionCallOp.emitOpError("callee function '")
1855 << fnName << "' not found in nearest symbol table";
1856 }
1857
1858 auto functionType = funcOp.getType();
1859
1860 if (functionCallOp.getNumResults() > 1) {
1861 return functionCallOp.emitOpError(
1862 "expected callee function to have 0 or 1 result, but provided ")
1863 << functionCallOp.getNumResults();
1864 }
1865
1866 if (functionType.getNumInputs() != functionCallOp.getNumOperands()) {
1867 return functionCallOp.emitOpError(
1868 "has incorrect number of operands for callee: expected ")
1869 << functionType.getNumInputs() << ", but provided "
1870 << functionCallOp.getNumOperands();
1871 }
1872
1873 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
1874 if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) {
1875 return functionCallOp.emitOpError(
1876 "operand type mismatch: expected operand type ")
1877 << functionType.getInput(i) << ", but provided "
1878 << functionCallOp.getOperand(i).getType() << " for operand number "
1879 << i;
1880 }
1881 }
1882
1883 if (functionType.getNumResults() != functionCallOp.getNumResults()) {
1884 return functionCallOp.emitOpError(
1885 "has incorrect number of results has for callee: expected ")
1886 << functionType.getNumResults() << ", but provided "
1887 << functionCallOp.getNumResults();
1888 }
1889
1890 if (functionCallOp.getNumResults() &&
1891 (functionCallOp.getResult(0).getType() != functionType.getResult(0))) {
1892 return functionCallOp.emitOpError("result type mismatch: expected ")
1893 << functionType.getResult(0) << ", but provided "
1894 << functionCallOp.getResult(0).getType();
1895 }
1896
1897 return success();
1898 }
1899
getCallableForCallee()1900 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
1901 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
1902 }
1903
getArgOperands()1904 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
1905 return arguments();
1906 }
1907
1908 //===----------------------------------------------------------------------===//
1909 // spv.globalVariable
1910 //===----------------------------------------------------------------------===//
1911
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)1912 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1913 Type type, StringRef name,
1914 unsigned descriptorSet, unsigned binding) {
1915 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
1916 nullptr);
1917 state.addAttribute(
1918 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
1919 builder.getI32IntegerAttr(descriptorSet));
1920 state.addAttribute(
1921 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
1922 builder.getI32IntegerAttr(binding));
1923 }
1924
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)1925 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
1926 Type type, StringRef name,
1927 spirv::BuiltIn builtin) {
1928 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name),
1929 nullptr);
1930 state.addAttribute(
1931 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
1932 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
1933 }
1934
parseGlobalVariableOp(OpAsmParser & parser,OperationState & state)1935 static ParseResult parseGlobalVariableOp(OpAsmParser &parser,
1936 OperationState &state) {
1937 // Parse variable name.
1938 StringAttr nameAttr;
1939 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
1940 state.attributes)) {
1941 return failure();
1942 }
1943
1944 // Parse optional initializer
1945 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
1946 FlatSymbolRefAttr initSymbol;
1947 if (parser.parseLParen() ||
1948 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
1949 state.attributes) ||
1950 parser.parseRParen())
1951 return failure();
1952 }
1953
1954 if (parseVariableDecorations(parser, state)) {
1955 return failure();
1956 }
1957
1958 Type type;
1959 auto loc = parser.getCurrentLocation();
1960 if (parser.parseColonType(type)) {
1961 return failure();
1962 }
1963 if (!type.isa<spirv::PointerType>()) {
1964 return parser.emitError(loc, "expected spv.ptr type");
1965 }
1966 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
1967
1968 return success();
1969 }
1970
print(spirv::GlobalVariableOp varOp,OpAsmPrinter & printer)1971 static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) {
1972 auto *op = varOp.getOperation();
1973 SmallVector<StringRef, 4> elidedAttrs{
1974 spirv::attributeName<spirv::StorageClass>()};
1975 printer << spirv::GlobalVariableOp::getOperationName();
1976
1977 // Print variable name.
1978 printer << ' ';
1979 printer.printSymbolName(varOp.sym_name());
1980 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
1981
1982 // Print optional initializer
1983 if (auto initializer = varOp.initializer()) {
1984 printer << " " << kInitializerAttrName << '(';
1985 printer.printSymbolName(initializer.getValue());
1986 printer << ')';
1987 elidedAttrs.push_back(kInitializerAttrName);
1988 }
1989
1990 elidedAttrs.push_back(kTypeAttrName);
1991 printVariableDecorations(op, printer, elidedAttrs);
1992 printer << " : " << varOp.type();
1993 }
1994
verify(spirv::GlobalVariableOp varOp)1995 static LogicalResult verify(spirv::GlobalVariableOp varOp) {
1996 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
1997 // object. It cannot be Generic. It must be the same as the Storage Class
1998 // operand of the Result Type."
1999 // Also, Function storage class is reserved by spv.Variable.
2000 auto storageClass = varOp.storageClass();
2001 if (storageClass == spirv::StorageClass::Generic ||
2002 storageClass == spirv::StorageClass::Function) {
2003 return varOp.emitOpError("storage class cannot be '")
2004 << stringifyStorageClass(storageClass) << "'";
2005 }
2006
2007 if (auto init =
2008 varOp->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2009 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2010 varOp->getParentOp(), init.getValue());
2011 // TODO: Currently only variable initialization with specialization
2012 // constants and other variables is supported. They could be normal
2013 // constants in the module scope as well.
2014 if (!initOp ||
2015 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2016 return varOp.emitOpError("initializer must be result of a "
2017 "spv.specConstant or spv.globalVariable op");
2018 }
2019 }
2020
2021 return success();
2022 }
2023
2024 //===----------------------------------------------------------------------===//
2025 // spv.GroupBroadcast
2026 //===----------------------------------------------------------------------===//
2027
verify(spirv::GroupBroadcastOp broadcastOp)2028 static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) {
2029 spirv::Scope scope = broadcastOp.execution_scope();
2030 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2031 return broadcastOp.emitOpError(
2032 "execution scope must be 'Workgroup' or 'Subgroup'");
2033
2034 if (auto localIdTy = broadcastOp.localid().getType().dyn_cast<VectorType>())
2035 if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3))
2036 return broadcastOp.emitOpError("localid is a vector and can be with only "
2037 " 2 or 3 components, actual number is ")
2038 << localIdTy.getNumElements();
2039
2040 return success();
2041 }
2042
2043 //===----------------------------------------------------------------------===//
2044 // spv.GroupNonUniformBallotOp
2045 //===----------------------------------------------------------------------===//
2046
verify(spirv::GroupNonUniformBallotOp ballotOp)2047 static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) {
2048 spirv::Scope scope = ballotOp.execution_scope();
2049 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2050 return ballotOp.emitOpError(
2051 "execution scope must be 'Workgroup' or 'Subgroup'");
2052
2053 return success();
2054 }
2055
2056 //===----------------------------------------------------------------------===//
2057 // spv.GroupNonUniformBroadcast
2058 //===----------------------------------------------------------------------===//
2059
verify(spirv::GroupNonUniformBroadcastOp broadcastOp)2060 static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) {
2061 spirv::Scope scope = broadcastOp.execution_scope();
2062 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2063 return broadcastOp.emitOpError(
2064 "execution scope must be 'Workgroup' or 'Subgroup'");
2065
2066 // SPIR-V spec: "Before version 1.5, Id must come from a
2067 // constant instruction.
2068 auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext());
2069 if (auto spirvModule = broadcastOp->getParentOfType<spirv::ModuleOp>())
2070 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2071
2072 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2073 auto *idOp = broadcastOp.id().getDefiningOp();
2074 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2075 spirv::ReferenceOfOp>(idOp)) // for spec constant
2076 return broadcastOp.emitOpError("id must be the result of a constant op");
2077 }
2078
2079 return success();
2080 }
2081
2082 //===----------------------------------------------------------------------===//
2083 // spv.SubgroupBlockReadINTEL
2084 //===----------------------------------------------------------------------===//
2085
parseSubgroupBlockReadINTELOp(OpAsmParser & parser,OperationState & state)2086 static ParseResult parseSubgroupBlockReadINTELOp(OpAsmParser &parser,
2087 OperationState &state) {
2088 // Parse the storage class specification
2089 spirv::StorageClass storageClass;
2090 OpAsmParser::OperandType ptrInfo;
2091 Type elementType;
2092 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2093 parser.parseColon() || parser.parseType(elementType)) {
2094 return failure();
2095 }
2096
2097 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2098 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2099 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2100
2101 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2102 return failure();
2103 }
2104
2105 state.addTypes(elementType);
2106 return success();
2107 }
2108
print(spirv::SubgroupBlockReadINTELOp blockReadOp,OpAsmPrinter & printer)2109 static void print(spirv::SubgroupBlockReadINTELOp blockReadOp,
2110 OpAsmPrinter &printer) {
2111 SmallVector<StringRef, 4> elidedAttrs;
2112 printer << spirv::SubgroupBlockReadINTELOp::getOperationName() << " "
2113 << blockReadOp.ptr();
2114 printer << " : " << blockReadOp.getType();
2115 }
2116
verify(spirv::SubgroupBlockReadINTELOp blockReadOp)2117 static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) {
2118 if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(),
2119 blockReadOp.value())))
2120 return failure();
2121
2122 return success();
2123 }
2124
2125 //===----------------------------------------------------------------------===//
2126 // spv.SubgroupBlockWriteINTEL
2127 //===----------------------------------------------------------------------===//
2128
parseSubgroupBlockWriteINTELOp(OpAsmParser & parser,OperationState & state)2129 static ParseResult parseSubgroupBlockWriteINTELOp(OpAsmParser &parser,
2130 OperationState &state) {
2131 // Parse the storage class specification
2132 spirv::StorageClass storageClass;
2133 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2134 auto loc = parser.getCurrentLocation();
2135 Type elementType;
2136 if (parseEnumStrAttr(storageClass, parser) ||
2137 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2138 parser.parseType(elementType)) {
2139 return failure();
2140 }
2141
2142 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2143 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2144 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2145
2146 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2147 state.operands)) {
2148 return failure();
2149 }
2150 return success();
2151 }
2152
print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,OpAsmPrinter & printer)2153 static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp,
2154 OpAsmPrinter &printer) {
2155 SmallVector<StringRef, 4> elidedAttrs;
2156 printer << spirv::SubgroupBlockWriteINTELOp::getOperationName() << " "
2157 << blockWriteOp.ptr() << ", " << blockWriteOp.value();
2158 printer << " : " << blockWriteOp.value().getType();
2159 }
2160
verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp)2161 static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) {
2162 if (failed(verifyBlockReadWritePtrAndValTypes(
2163 blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value())))
2164 return failure();
2165
2166 return success();
2167 }
2168
2169 //===----------------------------------------------------------------------===//
2170 // spv.GroupNonUniformElectOp
2171 //===----------------------------------------------------------------------===//
2172
build(OpBuilder & builder,OperationState & state,spirv::Scope scope)2173 void spirv::GroupNonUniformElectOp::build(OpBuilder &builder,
2174 OperationState &state,
2175 spirv::Scope scope) {
2176 build(builder, state, builder.getI1Type(), scope);
2177 }
2178
verify(spirv::GroupNonUniformElectOp groupOp)2179 static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
2180 spirv::Scope scope = groupOp.execution_scope();
2181 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2182 return groupOp.emitOpError(
2183 "execution scope must be 'Workgroup' or 'Subgroup'");
2184
2185 return success();
2186 }
2187
2188 //===----------------------------------------------------------------------===//
2189 // spv.LoadOp
2190 //===----------------------------------------------------------------------===//
2191
build(OpBuilder & builder,OperationState & state,Value basePtr,IntegerAttr memory_access,IntegerAttr alignment)2192 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2193 Value basePtr, IntegerAttr memory_access,
2194 IntegerAttr alignment) {
2195 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2196 build(builder, state, ptrType.getPointeeType(), basePtr, memory_access,
2197 alignment);
2198 }
2199
parseLoadOp(OpAsmParser & parser,OperationState & state)2200 static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &state) {
2201 // Parse the storage class specification
2202 spirv::StorageClass storageClass;
2203 OpAsmParser::OperandType ptrInfo;
2204 Type elementType;
2205 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2206 parseMemoryAccessAttributes(parser, state) ||
2207 parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2208 parser.parseType(elementType)) {
2209 return failure();
2210 }
2211
2212 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2213 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2214 return failure();
2215 }
2216
2217 state.addTypes(elementType);
2218 return success();
2219 }
2220
print(spirv::LoadOp loadOp,OpAsmPrinter & printer)2221 static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) {
2222 auto *op = loadOp.getOperation();
2223 SmallVector<StringRef, 4> elidedAttrs;
2224 StringRef sc = stringifyStorageClass(
2225 loadOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2226 printer << spirv::LoadOp::getOperationName() << " \"" << sc << "\" "
2227 << loadOp.ptr();
2228
2229 printMemoryAccessAttribute(loadOp, printer, elidedAttrs);
2230
2231 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2232 printer << " : " << loadOp.getType();
2233 }
2234
verify(spirv::LoadOp loadOp)2235 static LogicalResult verify(spirv::LoadOp loadOp) {
2236 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2237 // type with fixed size; i.e., it cannot be, nor include, any
2238 // OpTypeRuntimeArray types."
2239 if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(),
2240 loadOp.value()))) {
2241 return failure();
2242 }
2243 return verifyMemoryAccessAttribute(loadOp);
2244 }
2245
2246 //===----------------------------------------------------------------------===//
2247 // spv.loop
2248 //===----------------------------------------------------------------------===//
2249
build(OpBuilder & builder,OperationState & state)2250 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2251 state.addAttribute("loop_control",
2252 builder.getI32IntegerAttr(
2253 static_cast<uint32_t>(spirv::LoopControl::None)));
2254 state.addRegion();
2255 }
2256
parseLoopOp(OpAsmParser & parser,OperationState & state)2257 static ParseResult parseLoopOp(OpAsmParser &parser, OperationState &state) {
2258 if (parseControlAttribute<spirv::LoopControl>(parser, state))
2259 return failure();
2260 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2261 /*argTypes=*/{});
2262 }
2263
print(spirv::LoopOp loopOp,OpAsmPrinter & printer)2264 static void print(spirv::LoopOp loopOp, OpAsmPrinter &printer) {
2265 auto *op = loopOp.getOperation();
2266
2267 printer << spirv::LoopOp::getOperationName();
2268 auto control = loopOp.loop_control();
2269 if (control != spirv::LoopControl::None)
2270 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2271 printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2272 /*printBlockTerminators=*/true);
2273 }
2274
2275 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2276 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2277 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2278 // Check that there is only one op in the `srcBlock`.
2279 if (!llvm::hasSingleElement(srcBlock))
2280 return false;
2281
2282 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2283 return branchOp && branchOp.getSuccessor() == &dstBlock;
2284 }
2285
verify(spirv::LoopOp loopOp)2286 static LogicalResult verify(spirv::LoopOp loopOp) {
2287 auto *op = loopOp.getOperation();
2288
2289 // We need to verify that the blocks follow the following layout:
2290 //
2291 // +-------------+
2292 // | entry block |
2293 // +-------------+
2294 // |
2295 // v
2296 // +-------------+
2297 // | loop header | <-----+
2298 // +-------------+ |
2299 // |
2300 // ... |
2301 // \ | / |
2302 // v |
2303 // +---------------+ |
2304 // | loop continue | -----+
2305 // +---------------+
2306 //
2307 // ...
2308 // \ | /
2309 // v
2310 // +-------------+
2311 // | merge block |
2312 // +-------------+
2313
2314 auto ®ion = op->getRegion(0);
2315 // Allow empty region as a degenerated case, which can come from
2316 // optimizations.
2317 if (region.empty())
2318 return success();
2319
2320 // The last block is the merge block.
2321 Block &merge = region.back();
2322 if (!isMergeBlock(merge))
2323 return loopOp.emitOpError(
2324 "last block must be the merge block with only one 'spv.mlir.merge' op");
2325
2326 if (std::next(region.begin()) == region.end())
2327 return loopOp.emitOpError(
2328 "must have an entry block branching to the loop header block");
2329 // The first block is the entry block.
2330 Block &entry = region.front();
2331
2332 if (std::next(region.begin(), 2) == region.end())
2333 return loopOp.emitOpError(
2334 "must have a loop header block branched from the entry block");
2335 // The second block is the loop header block.
2336 Block &header = *std::next(region.begin(), 1);
2337
2338 if (!hasOneBranchOpTo(entry, header))
2339 return loopOp.emitOpError(
2340 "entry block must only have one 'spv.Branch' op to the second block");
2341
2342 if (std::next(region.begin(), 3) == region.end())
2343 return loopOp.emitOpError(
2344 "requires a loop continue block branching to the loop header block");
2345 // The second to last block is the loop continue block.
2346 Block &cont = *std::prev(region.end(), 2);
2347
2348 // Make sure that we have a branch from the loop continue block to the loop
2349 // header block.
2350 if (llvm::none_of(
2351 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
2352 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
2353 return loopOp.emitOpError("second to last block must be the loop continue "
2354 "block that branches to the loop header block");
2355
2356 // Make sure that no other blocks (except the entry and loop continue block)
2357 // branches to the loop header block.
2358 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
2359 std::prev(region.end(), 2))) {
2360 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
2361 if (block.getSuccessor(i) == &header) {
2362 return loopOp.emitOpError("can only have the entry and loop continue "
2363 "block branching to the loop header block");
2364 }
2365 }
2366 }
2367
2368 return success();
2369 }
2370
getEntryBlock()2371 Block *spirv::LoopOp::getEntryBlock() {
2372 assert(!body().empty() && "op region should not be empty!");
2373 return &body().front();
2374 }
2375
getHeaderBlock()2376 Block *spirv::LoopOp::getHeaderBlock() {
2377 assert(!body().empty() && "op region should not be empty!");
2378 // The second block is the loop header block.
2379 return &*std::next(body().begin());
2380 }
2381
getContinueBlock()2382 Block *spirv::LoopOp::getContinueBlock() {
2383 assert(!body().empty() && "op region should not be empty!");
2384 // The second to last block is the loop continue block.
2385 return &*std::prev(body().end(), 2);
2386 }
2387
getMergeBlock()2388 Block *spirv::LoopOp::getMergeBlock() {
2389 assert(!body().empty() && "op region should not be empty!");
2390 // The last block is the loop merge block.
2391 return &body().back();
2392 }
2393
addEntryAndMergeBlock()2394 void spirv::LoopOp::addEntryAndMergeBlock() {
2395 assert(body().empty() && "entry and merge block already exist");
2396 body().push_back(new Block());
2397 auto *mergeBlock = new Block();
2398 body().push_back(mergeBlock);
2399 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2400
2401 // Add a spv.mlir.merge op into the merge block.
2402 builder.create<spirv::MergeOp>(getLoc());
2403 }
2404
2405 //===----------------------------------------------------------------------===//
2406 // spv.mlir.merge
2407 //===----------------------------------------------------------------------===//
2408
verify(spirv::MergeOp mergeOp)2409 static LogicalResult verify(spirv::MergeOp mergeOp) {
2410 auto *parentOp = mergeOp->getParentOp();
2411 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
2412 return mergeOp.emitOpError(
2413 "expected parent op to be 'spv.selection' or 'spv.loop'");
2414
2415 Block &parentLastBlock = mergeOp->getParentRegion()->back();
2416 if (mergeOp.getOperation() != parentLastBlock.getTerminator())
2417 return mergeOp.emitOpError(
2418 "can only be used in the last block of 'spv.selection' or 'spv.loop'");
2419 return success();
2420 }
2421
2422 //===----------------------------------------------------------------------===//
2423 // spv.module
2424 //===----------------------------------------------------------------------===//
2425
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)2426 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2427 Optional<StringRef> name) {
2428 ensureTerminator(*state.addRegion(), builder, state.location);
2429 if (name) {
2430 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2431 builder.getStringAttr(*name));
2432 }
2433 }
2434
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<StringRef> name)2435 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
2436 spirv::AddressingModel addressingModel,
2437 spirv::MemoryModel memoryModel,
2438 Optional<StringRef> name) {
2439 state.addAttribute(
2440 "addressing_model",
2441 builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
2442 state.addAttribute("memory_model", builder.getI32IntegerAttr(
2443 static_cast<int32_t>(memoryModel)));
2444 ensureTerminator(*state.addRegion(), builder, state.location);
2445 if (name) {
2446 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
2447 builder.getStringAttr(*name));
2448 }
2449 }
2450
parseModuleOp(OpAsmParser & parser,OperationState & state)2451 static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
2452 Region *body = state.addRegion();
2453
2454 // If the name is present, parse it.
2455 StringAttr nameAttr;
2456 parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2457 state.attributes);
2458
2459 // Parse attributes
2460 spirv::AddressingModel addrModel;
2461 spirv::MemoryModel memoryModel;
2462 if (parseEnumKeywordAttr(addrModel, parser, state) ||
2463 parseEnumKeywordAttr(memoryModel, parser, state))
2464 return failure();
2465
2466 if (succeeded(parser.parseOptionalKeyword("requires"))) {
2467 spirv::VerCapExtAttr vceTriple;
2468 if (parser.parseAttribute(vceTriple,
2469 spirv::ModuleOp::getVCETripleAttrName(),
2470 state.attributes))
2471 return failure();
2472 }
2473
2474 if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2475 return failure();
2476
2477 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2478 return failure();
2479
2480 spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
2481 return success();
2482 }
2483
print(spirv::ModuleOp moduleOp,OpAsmPrinter & printer)2484 static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
2485 printer << spirv::ModuleOp::getOperationName();
2486
2487 if (Optional<StringRef> name = moduleOp.getName()) {
2488 printer << ' ';
2489 printer.printSymbolName(*name);
2490 }
2491
2492 SmallVector<StringRef, 2> elidedAttrs;
2493
2494 printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
2495 << " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
2496 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
2497 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
2498 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
2499 SymbolTable::getSymbolAttrName()});
2500
2501 if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
2502 printer << " requires " << *triple;
2503 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
2504 }
2505
2506 printer.printOptionalAttrDictWithKeyword(moduleOp.getAttrs(), elidedAttrs);
2507 printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
2508 /*printBlockTerminators=*/false);
2509 }
2510
verify(spirv::ModuleOp moduleOp)2511 static LogicalResult verify(spirv::ModuleOp moduleOp) {
2512 auto &op = *moduleOp.getOperation();
2513 auto *dialect = op.getDialect();
2514 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
2515 entryPoints;
2516 SymbolTable table(moduleOp);
2517
2518 for (auto &op : moduleOp.getBlock()) {
2519 if (op.getDialect() != dialect)
2520 return op.emitError("'spv.module' can only contain spv.* ops");
2521
2522 // For EntryPoint op, check that the function and execution model is not
2523 // duplicated in EntryPointOps. Also verify that the interface specified
2524 // comes from globalVariables here to make this check cheaper.
2525 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
2526 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
2527 if (!funcOp) {
2528 return entryPointOp.emitError("function '")
2529 << entryPointOp.fn() << "' not found in 'spv.module'";
2530 }
2531 if (auto interface = entryPointOp.interface()) {
2532 for (Attribute varRef : interface) {
2533 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
2534 if (!varSymRef) {
2535 return entryPointOp.emitError(
2536 "expected symbol reference for interface "
2537 "specification instead of '")
2538 << varRef;
2539 }
2540 auto variableOp =
2541 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
2542 if (!variableOp) {
2543 return entryPointOp.emitError("expected spv.globalVariable "
2544 "symbol reference instead of'")
2545 << varSymRef << "'";
2546 }
2547 }
2548 }
2549
2550 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
2551 funcOp, entryPointOp.execution_model());
2552 auto entryPtIt = entryPoints.find(key);
2553 if (entryPtIt != entryPoints.end()) {
2554 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
2555 }
2556 entryPoints[key] = entryPointOp;
2557 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
2558 if (funcOp.isExternal())
2559 return op.emitError("'spv.module' cannot contain external functions");
2560
2561 // TODO: move this check to spv.func.
2562 for (auto &block : funcOp)
2563 for (auto &op : block) {
2564 if (op.getDialect() != dialect)
2565 return op.emitError(
2566 "functions in 'spv.module' can only contain spv.* ops");
2567 }
2568 }
2569 }
2570
2571 return success();
2572 }
2573
2574 //===----------------------------------------------------------------------===//
2575 // spv.mlir.referenceof
2576 //===----------------------------------------------------------------------===//
2577
verify(spirv::ReferenceOfOp referenceOfOp)2578 static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2579 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2580 referenceOfOp->getParentOp(), referenceOfOp.spec_const());
2581 Type constType;
2582
2583 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2584 if (specConstOp)
2585 constType = specConstOp.default_value().getType();
2586
2587 auto specConstCompositeOp =
2588 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2589 if (specConstCompositeOp)
2590 constType = specConstCompositeOp.type();
2591
2592 if (!specConstOp && !specConstCompositeOp)
2593 return referenceOfOp.emitOpError(
2594 "expected spv.specConstant or spv.SpecConstantComposite symbol");
2595
2596 if (referenceOfOp.reference().getType() != constType)
2597 return referenceOfOp.emitOpError("result type mismatch with the referenced "
2598 "specialization constant's type");
2599
2600 return success();
2601 }
2602
2603 //===----------------------------------------------------------------------===//
2604 // spv.Return
2605 //===----------------------------------------------------------------------===//
2606
verify(spirv::ReturnOp returnOp)2607 static LogicalResult verify(spirv::ReturnOp returnOp) {
2608 // Verification is performed in spv.func op.
2609 return success();
2610 }
2611
2612 //===----------------------------------------------------------------------===//
2613 // spv.ReturnValue
2614 //===----------------------------------------------------------------------===//
2615
verify(spirv::ReturnValueOp retValOp)2616 static LogicalResult verify(spirv::ReturnValueOp retValOp) {
2617 // Verification is performed in spv.func op.
2618 return success();
2619 }
2620
2621 //===----------------------------------------------------------------------===//
2622 // spv.Select
2623 //===----------------------------------------------------------------------===//
2624
build(OpBuilder & builder,OperationState & state,Value cond,Value trueValue,Value falseValue)2625 void spirv::SelectOp::build(OpBuilder &builder, OperationState &state,
2626 Value cond, Value trueValue, Value falseValue) {
2627 build(builder, state, trueValue.getType(), cond, trueValue, falseValue);
2628 }
2629
verify(spirv::SelectOp op)2630 static LogicalResult verify(spirv::SelectOp op) {
2631 if (auto conditionTy = op.condition().getType().dyn_cast<VectorType>()) {
2632 auto resultVectorTy = op.result().getType().dyn_cast<VectorType>();
2633 if (!resultVectorTy) {
2634 return op.emitOpError("result expected to be of vector type when "
2635 "condition is of vector type");
2636 }
2637 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
2638 return op.emitOpError("result should have the same number of elements as "
2639 "the condition when condition is of vector type");
2640 }
2641 }
2642 return success();
2643 }
2644
2645 //===----------------------------------------------------------------------===//
2646 // spv.selection
2647 //===----------------------------------------------------------------------===//
2648
parseSelectionOp(OpAsmParser & parser,OperationState & state)2649 static ParseResult parseSelectionOp(OpAsmParser &parser,
2650 OperationState &state) {
2651 if (parseControlAttribute<spirv::SelectionControl>(parser, state))
2652 return failure();
2653 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2654 /*argTypes=*/{});
2655 }
2656
print(spirv::SelectionOp selectionOp,OpAsmPrinter & printer)2657 static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) {
2658 auto *op = selectionOp.getOperation();
2659
2660 printer << spirv::SelectionOp::getOperationName();
2661 auto control = selectionOp.selection_control();
2662 if (control != spirv::SelectionControl::None)
2663 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
2664 printer.printRegion(op->getRegion(0), /*printEntryBlockArgs=*/false,
2665 /*printBlockTerminators=*/true);
2666 }
2667
verify(spirv::SelectionOp selectionOp)2668 static LogicalResult verify(spirv::SelectionOp selectionOp) {
2669 auto *op = selectionOp.getOperation();
2670
2671 // We need to verify that the blocks follow the following layout:
2672 //
2673 // +--------------+
2674 // | header block |
2675 // +--------------+
2676 // / | \
2677 // ...
2678 //
2679 //
2680 // +---------+ +---------+ +---------+
2681 // | case #0 | | case #1 | | case #2 | ...
2682 // +---------+ +---------+ +---------+
2683 //
2684 //
2685 // ...
2686 // \ | /
2687 // v
2688 // +-------------+
2689 // | merge block |
2690 // +-------------+
2691
2692 auto ®ion = op->getRegion(0);
2693 // Allow empty region as a degenerated case, which can come from
2694 // optimizations.
2695 if (region.empty())
2696 return success();
2697
2698 // The last block is the merge block.
2699 if (!isMergeBlock(region.back()))
2700 return selectionOp.emitOpError(
2701 "last block must be the merge block with only one 'spv.mlir.merge' op");
2702
2703 if (std::next(region.begin()) == region.end())
2704 return selectionOp.emitOpError("must have a selection header block");
2705
2706 return success();
2707 }
2708
getHeaderBlock()2709 Block *spirv::SelectionOp::getHeaderBlock() {
2710 assert(!body().empty() && "op region should not be empty!");
2711 // The first block is the loop header block.
2712 return &body().front();
2713 }
2714
getMergeBlock()2715 Block *spirv::SelectionOp::getMergeBlock() {
2716 assert(!body().empty() && "op region should not be empty!");
2717 // The last block is the loop merge block.
2718 return &body().back();
2719 }
2720
addMergeBlock()2721 void spirv::SelectionOp::addMergeBlock() {
2722 assert(body().empty() && "entry and merge block already exist");
2723 auto *mergeBlock = new Block();
2724 body().push_back(mergeBlock);
2725 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
2726
2727 // Add a spv.mlir.merge op into the merge block.
2728 builder.create<spirv::MergeOp>(getLoc());
2729 }
2730
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)2731 spirv::SelectionOp spirv::SelectionOp::createIfThen(
2732 Location loc, Value condition,
2733 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
2734 auto selectionControl = builder.getI32IntegerAttr(
2735 static_cast<uint32_t>(spirv::SelectionControl::None));
2736 auto selectionOp = builder.create<spirv::SelectionOp>(loc, selectionControl);
2737
2738 selectionOp.addMergeBlock();
2739 Block *mergeBlock = selectionOp.getMergeBlock();
2740 Block *thenBlock = nullptr;
2741
2742 // Build the "then" block.
2743 {
2744 OpBuilder::InsertionGuard guard(builder);
2745 thenBlock = builder.createBlock(mergeBlock);
2746 thenBody(builder);
2747 builder.create<spirv::BranchOp>(loc, mergeBlock);
2748 }
2749
2750 // Build the header block.
2751 {
2752 OpBuilder::InsertionGuard guard(builder);
2753 builder.createBlock(thenBlock);
2754 builder.create<spirv::BranchConditionalOp>(
2755 loc, condition, thenBlock,
2756 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
2757 /*falseArguments=*/ArrayRef<Value>());
2758 }
2759
2760 return selectionOp;
2761 }
2762
2763 //===----------------------------------------------------------------------===//
2764 // spv.specConstant
2765 //===----------------------------------------------------------------------===//
2766
parseSpecConstantOp(OpAsmParser & parser,OperationState & state)2767 static ParseResult parseSpecConstantOp(OpAsmParser &parser,
2768 OperationState &state) {
2769 StringAttr nameAttr;
2770 Attribute valueAttr;
2771
2772 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2773 state.attributes))
2774 return failure();
2775
2776 // Parse optional spec_id.
2777 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
2778 IntegerAttr specIdAttr;
2779 if (parser.parseLParen() ||
2780 parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
2781 parser.parseRParen())
2782 return failure();
2783 }
2784
2785 if (parser.parseEqual() ||
2786 parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
2787 return failure();
2788
2789 return success();
2790 }
2791
print(spirv::SpecConstantOp constOp,OpAsmPrinter & printer)2792 static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) {
2793 printer << spirv::SpecConstantOp::getOperationName() << ' ';
2794 printer.printSymbolName(constOp.sym_name());
2795 if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2796 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
2797 printer << " = " << constOp.default_value();
2798 }
2799
verify(spirv::SpecConstantOp constOp)2800 static LogicalResult verify(spirv::SpecConstantOp constOp) {
2801 if (auto specID = constOp->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
2802 if (specID.getValue().isNegative())
2803 return constOp.emitOpError("SpecId cannot be negative");
2804
2805 auto value = constOp.default_value();
2806 if (value.isa<IntegerAttr, FloatAttr>()) {
2807 // Make sure bitwidth is allowed.
2808 if (!value.getType().isa<spirv::SPIRVType>())
2809 return constOp.emitOpError("default value bitwidth disallowed");
2810 return success();
2811 }
2812 return constOp.emitOpError(
2813 "default value can only be a bool, integer, or float scalar");
2814 }
2815
2816 //===----------------------------------------------------------------------===//
2817 // spv.StoreOp
2818 //===----------------------------------------------------------------------===//
2819
parseStoreOp(OpAsmParser & parser,OperationState & state)2820 static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &state) {
2821 // Parse the storage class specification
2822 spirv::StorageClass storageClass;
2823 SmallVector<OpAsmParser::OperandType, 2> operandInfo;
2824 auto loc = parser.getCurrentLocation();
2825 Type elementType;
2826 if (parseEnumStrAttr(storageClass, parser) ||
2827 parser.parseOperandList(operandInfo, 2) ||
2828 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
2829 parser.parseType(elementType)) {
2830 return failure();
2831 }
2832
2833 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2834 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2835 state.operands)) {
2836 return failure();
2837 }
2838 return success();
2839 }
2840
print(spirv::StoreOp storeOp,OpAsmPrinter & printer)2841 static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) {
2842 auto *op = storeOp.getOperation();
2843 SmallVector<StringRef, 4> elidedAttrs;
2844 StringRef sc = stringifyStorageClass(
2845 storeOp.ptr().getType().cast<spirv::PointerType>().getStorageClass());
2846 printer << spirv::StoreOp::getOperationName() << " \"" << sc << "\" "
2847 << storeOp.ptr() << ", " << storeOp.value();
2848
2849 printMemoryAccessAttribute(storeOp, printer, elidedAttrs);
2850
2851 printer << " : " << storeOp.value().getType();
2852 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2853 }
2854
verify(spirv::StoreOp storeOp)2855 static LogicalResult verify(spirv::StoreOp storeOp) {
2856 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
2857 // OpTypePointer whose Type operand is the same as the type of Object."
2858 if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(),
2859 storeOp.value()))) {
2860 return failure();
2861 }
2862 return verifyMemoryAccessAttribute(storeOp);
2863 }
2864
2865 //===----------------------------------------------------------------------===//
2866 // spv.Unreachable
2867 //===----------------------------------------------------------------------===//
2868
verify(spirv::UnreachableOp unreachableOp)2869 static LogicalResult verify(spirv::UnreachableOp unreachableOp) {
2870 auto *op = unreachableOp.getOperation();
2871 auto *block = op->getBlock();
2872 // Fast track: if this is in entry block, its invalid. Otherwise, if no
2873 // predecessors, it's valid.
2874 if (block->isEntryBlock())
2875 return unreachableOp.emitOpError("cannot be used in reachable block");
2876 if (block->hasNoPredecessors())
2877 return success();
2878
2879 // TODO: further verification needs to analyze reachability from
2880 // the entry block.
2881
2882 return success();
2883 }
2884
2885 //===----------------------------------------------------------------------===//
2886 // spv.Variable
2887 //===----------------------------------------------------------------------===//
2888
parseVariableOp(OpAsmParser & parser,OperationState & state)2889 static ParseResult parseVariableOp(OpAsmParser &parser, OperationState &state) {
2890 // Parse optional initializer
2891 Optional<OpAsmParser::OperandType> initInfo;
2892 if (succeeded(parser.parseOptionalKeyword("init"))) {
2893 initInfo = OpAsmParser::OperandType();
2894 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
2895 parser.parseRParen())
2896 return failure();
2897 }
2898
2899 if (parseVariableDecorations(parser, state)) {
2900 return failure();
2901 }
2902
2903 // Parse result pointer type
2904 Type type;
2905 if (parser.parseColon())
2906 return failure();
2907 auto loc = parser.getCurrentLocation();
2908 if (parser.parseType(type))
2909 return failure();
2910
2911 auto ptrType = type.dyn_cast<spirv::PointerType>();
2912 if (!ptrType)
2913 return parser.emitError(loc, "expected spv.ptr type");
2914 state.addTypes(ptrType);
2915
2916 // Resolve the initializer operand
2917 if (initInfo) {
2918 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
2919 state.operands))
2920 return failure();
2921 }
2922
2923 auto attr = parser.getBuilder().getI32IntegerAttr(
2924 llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
2925 state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
2926
2927 return success();
2928 }
2929
print(spirv::VariableOp varOp,OpAsmPrinter & printer)2930 static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) {
2931 SmallVector<StringRef, 4> elidedAttrs{
2932 spirv::attributeName<spirv::StorageClass>()};
2933 printer << spirv::VariableOp::getOperationName();
2934
2935 // Print optional initializer
2936 if (varOp.getNumOperands() != 0)
2937 printer << " init(" << varOp.initializer() << ")";
2938
2939 printVariableDecorations(varOp, printer, elidedAttrs);
2940 printer << " : " << varOp.getType();
2941 }
2942
verify(spirv::VariableOp varOp)2943 static LogicalResult verify(spirv::VariableOp varOp) {
2944 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2945 // object. It cannot be Generic. It must be the same as the Storage Class
2946 // operand of the Result Type."
2947 if (varOp.storage_class() != spirv::StorageClass::Function) {
2948 return varOp.emitOpError(
2949 "can only be used to model function-level variables. Use "
2950 "spv.globalVariable for module-level variables.");
2951 }
2952
2953 auto pointerType = varOp.pointer().getType().cast<spirv::PointerType>();
2954 if (varOp.storage_class() != pointerType.getStorageClass())
2955 return varOp.emitOpError(
2956 "storage class must match result pointer's storage class");
2957
2958 if (varOp.getNumOperands() != 0) {
2959 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
2960 // a global (module scope) OpVariable instruction".
2961 auto *initOp = varOp.getOperand(0).getDefiningOp();
2962 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
2963 spirv::ReferenceOfOp, // for spec constant
2964 spirv::AddressOfOp>(initOp))
2965 return varOp.emitOpError("initializer must be the result of a "
2966 "constant or spv.globalVariable op");
2967 }
2968
2969 // TODO: generate these strings using ODS.
2970 auto *op = varOp.getOperation();
2971 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
2972 stringifyDecoration(spirv::Decoration::DescriptorSet));
2973 auto bindingName = llvm::convertToSnakeFromCamelCase(
2974 stringifyDecoration(spirv::Decoration::Binding));
2975 auto builtInName = llvm::convertToSnakeFromCamelCase(
2976 stringifyDecoration(spirv::Decoration::BuiltIn));
2977
2978 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
2979 if (op->getAttr(attr))
2980 return varOp.emitOpError("cannot have '")
2981 << attr << "' attribute (only allowed in spv.globalVariable)";
2982 }
2983
2984 return success();
2985 }
2986
2987 //===----------------------------------------------------------------------===//
2988 // spv.CooperativeMatrixLoadNV
2989 //===----------------------------------------------------------------------===//
2990
parseCooperativeMatrixLoadNVOp(OpAsmParser & parser,OperationState & state)2991 static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser,
2992 OperationState &state) {
2993 SmallVector<OpAsmParser::OperandType, 3> operandInfo;
2994 Type strideType = parser.getBuilder().getIntegerType(32);
2995 Type columnMajorType = parser.getBuilder().getIntegerType(1);
2996 Type ptrType;
2997 Type elementType;
2998 if (parser.parseOperandList(operandInfo, 3) ||
2999 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3000 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3001 return failure();
3002 }
3003 if (parser.resolveOperands(operandInfo,
3004 {ptrType, strideType, columnMajorType},
3005 parser.getNameLoc(), state.operands)) {
3006 return failure();
3007 }
3008
3009 state.addTypes(elementType);
3010 return success();
3011 }
3012
print(spirv::CooperativeMatrixLoadNVOp M,OpAsmPrinter & printer)3013 static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) {
3014 printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " "
3015 << M.pointer() << ", " << M.stride() << ", " << M.columnmajor();
3016 // Print optional memory access attribute.
3017 if (auto memAccess = M.memory_access())
3018 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3019 printer << " : " << M.pointer().getType() << " as " << M.getType();
3020 }
3021
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3022 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3023 Type coopMatrix) {
3024 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3025 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3026 return op->emitError(
3027 "Pointer must point to a scalar or vector type but provided ")
3028 << pointeeType;
3029 spirv::StorageClass storage =
3030 pointer.cast<spirv::PointerType>().getStorageClass();
3031 if (storage != spirv::StorageClass::Workgroup &&
3032 storage != spirv::StorageClass::StorageBuffer &&
3033 storage != spirv::StorageClass::PhysicalStorageBuffer)
3034 return op->emitError(
3035 "Pointer storage class must be Workgroup, StorageBuffer or "
3036 "PhysicalStorageBufferEXT but provided ")
3037 << stringifyStorageClass(storage);
3038 return success();
3039 }
3040
3041 //===----------------------------------------------------------------------===//
3042 // spv.CooperativeMatrixStoreNV
3043 //===----------------------------------------------------------------------===//
3044
parseCooperativeMatrixStoreNVOp(OpAsmParser & parser,OperationState & state)3045 static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser,
3046 OperationState &state) {
3047 SmallVector<OpAsmParser::OperandType, 4> operandInfo;
3048 Type strideType = parser.getBuilder().getIntegerType(32);
3049 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3050 Type ptrType;
3051 Type elementType;
3052 if (parser.parseOperandList(operandInfo, 4) ||
3053 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3054 parser.parseType(ptrType) || parser.parseComma() ||
3055 parser.parseType(elementType)) {
3056 return failure();
3057 }
3058 if (parser.resolveOperands(
3059 operandInfo, {ptrType, elementType, strideType, columnMajorType},
3060 parser.getNameLoc(), state.operands)) {
3061 return failure();
3062 }
3063
3064 return success();
3065 }
3066
print(spirv::CooperativeMatrixStoreNVOp coopMatrix,OpAsmPrinter & printer)3067 static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix,
3068 OpAsmPrinter &printer) {
3069 printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " "
3070 << coopMatrix.pointer() << ", " << coopMatrix.object() << ", "
3071 << coopMatrix.stride() << ", " << coopMatrix.columnmajor();
3072 // Print optional memory access attribute.
3073 if (auto memAccess = coopMatrix.memory_access())
3074 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3075 printer << " : " << coopMatrix.pointer().getType() << ", "
3076 << coopMatrix.getOperand(1).getType();
3077 }
3078
3079 //===----------------------------------------------------------------------===//
3080 // spv.CooperativeMatrixMulAddNV
3081 //===----------------------------------------------------------------------===//
3082
3083 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3084 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3085 if (op.c().getType() != op.result().getType())
3086 return op.emitOpError("result and third operand must have the same type");
3087 auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3088 auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3089 auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3090 auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3091 if (typeA.getRows() != typeR.getRows() ||
3092 typeA.getColumns() != typeB.getRows() ||
3093 typeB.getColumns() != typeR.getColumns())
3094 return op.emitOpError("matrix size must match");
3095 if (typeR.getScope() != typeA.getScope() ||
3096 typeR.getScope() != typeB.getScope() ||
3097 typeR.getScope() != typeC.getScope())
3098 return op.emitOpError("matrix scope must match");
3099 if (typeA.getElementType() != typeB.getElementType() ||
3100 typeR.getElementType() != typeC.getElementType())
3101 return op.emitOpError("matrix element type must match");
3102 return success();
3103 }
3104
3105 //===----------------------------------------------------------------------===//
3106 // spv.MatrixTimesScalar
3107 //===----------------------------------------------------------------------===//
3108
verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op)3109 static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) {
3110 // We already checked that result and matrix are both of matrix type in the
3111 // auto-generated verify method.
3112
3113 auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3114 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3115
3116 // Check that the scalar type is the same as the matrix element type.
3117 if (op.scalar().getType() != inputMatrix.getElementType())
3118 return op.emitError("input matrix components' type and scaling value must "
3119 "have the same type");
3120
3121 // Note that the next three checks could be done using the AllTypesMatch
3122 // trait in the Op definition file but it generates a vague error message.
3123
3124 // Check that the input and result matrices have the same columns' count
3125 if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3126 return op.emitError("input and result matrices must have the same "
3127 "number of columns");
3128
3129 // Check that the input and result matrices' have the same rows count
3130 if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3131 return op.emitError("input and result matrices' columns must have "
3132 "the same size");
3133
3134 // Check that the input and result matrices' have the same component type
3135 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3136 return op.emitError("input and result matrices' columns must have "
3137 "the same component type");
3138
3139 return success();
3140 }
3141
3142 //===----------------------------------------------------------------------===//
3143 // spv.CopyMemory
3144 //===----------------------------------------------------------------------===//
3145
print(spirv::CopyMemoryOp copyMemory,OpAsmPrinter & printer)3146 static void print(spirv::CopyMemoryOp copyMemory, OpAsmPrinter &printer) {
3147 auto *op = copyMemory.getOperation();
3148 printer << spirv::CopyMemoryOp::getOperationName() << ' ';
3149
3150 StringRef targetStorageClass =
3151 stringifyStorageClass(copyMemory.target()
3152 .getType()
3153 .cast<spirv::PointerType>()
3154 .getStorageClass());
3155 printer << " \"" << targetStorageClass << "\" " << copyMemory.target()
3156 << ", ";
3157
3158 StringRef sourceStorageClass =
3159 stringifyStorageClass(copyMemory.source()
3160 .getType()
3161 .cast<spirv::PointerType>()
3162 .getStorageClass());
3163 printer << " \"" << sourceStorageClass << "\" " << copyMemory.source();
3164
3165 SmallVector<StringRef, 4> elidedAttrs;
3166 printMemoryAccessAttribute(copyMemory, printer, elidedAttrs);
3167 printSourceMemoryAccessAttribute(copyMemory, printer, elidedAttrs,
3168 copyMemory.source_memory_access(),
3169 copyMemory.source_alignment());
3170
3171 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3172
3173 Type pointeeType =
3174 copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3175 printer << " : " << pointeeType;
3176 }
3177
parseCopyMemoryOp(OpAsmParser & parser,OperationState & state)3178 static ParseResult parseCopyMemoryOp(OpAsmParser &parser,
3179 OperationState &state) {
3180 spirv::StorageClass targetStorageClass;
3181 OpAsmParser::OperandType targetPtrInfo;
3182
3183 spirv::StorageClass sourceStorageClass;
3184 OpAsmParser::OperandType sourcePtrInfo;
3185
3186 Type elementType;
3187
3188 if (parseEnumStrAttr(targetStorageClass, parser) ||
3189 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3190 parseEnumStrAttr(sourceStorageClass, parser) ||
3191 parser.parseOperand(sourcePtrInfo) ||
3192 parseMemoryAccessAttributes(parser, state)) {
3193 return failure();
3194 }
3195
3196 if (!parser.parseOptionalComma()) {
3197 // Parse 2nd memory access attributes.
3198 if (parseSourceMemoryAccessAttributes(parser, state)) {
3199 return failure();
3200 }
3201 }
3202
3203 if (parser.parseColon() || parser.parseType(elementType))
3204 return failure();
3205
3206 if (parser.parseOptionalAttrDict(state.attributes))
3207 return failure();
3208
3209 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3210 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3211
3212 if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3213 parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3214 return failure();
3215 }
3216
3217 return success();
3218 }
3219
verifyCopyMemory(spirv::CopyMemoryOp copyMemory)3220 static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) {
3221 Type targetType =
3222 copyMemory.target().getType().cast<spirv::PointerType>().getPointeeType();
3223
3224 Type sourceType =
3225 copyMemory.source().getType().cast<spirv::PointerType>().getPointeeType();
3226
3227 if (targetType != sourceType) {
3228 return copyMemory.emitOpError(
3229 "both operands must be pointers to the same type");
3230 }
3231
3232 if (failed(verifyMemoryAccessAttribute(copyMemory))) {
3233 return failure();
3234 }
3235
3236 // TODO - According to the spec:
3237 //
3238 // If two masks are present, the first applies to Target and cannot include
3239 // MakePointerVisible, and the second applies to Source and cannot include
3240 // MakePointerAvailable.
3241 //
3242 // Add such verification here.
3243
3244 return verifySourceMemoryAccessAttribute(copyMemory);
3245 }
3246
3247 //===----------------------------------------------------------------------===//
3248 // spv.Transpose
3249 //===----------------------------------------------------------------------===//
3250
verifyTranspose(spirv::TransposeOp op)3251 static LogicalResult verifyTranspose(spirv::TransposeOp op) {
3252 auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>();
3253 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3254
3255 // Verify that the input and output matrices have correct shapes.
3256 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3257 return op.emitError("input matrix rows count must be equal to "
3258 "output matrix columns count");
3259
3260 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3261 return op.emitError("input matrix columns count must be equal to "
3262 "output matrix rows count");
3263
3264 // Verify that the input and output matrices have the same component type
3265 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3266 return op.emitError("input and output matrices must have the same "
3267 "component type");
3268
3269 return success();
3270 }
3271
3272 //===----------------------------------------------------------------------===//
3273 // spv.MatrixTimesMatrix
3274 //===----------------------------------------------------------------------===//
3275
verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op)3276 static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
3277 auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
3278 auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
3279 auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
3280
3281 // left matrix columns' count and right matrix rows' count must be equal
3282 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
3283 return op.emitError("left matrix columns' count must be equal to "
3284 "the right matrix rows' count");
3285
3286 // right and result matrices columns' count must be the same
3287 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
3288 return op.emitError(
3289 "right and result matrices must have equal columns' count");
3290
3291 // right and result matrices component type must be the same
3292 if (rightMatrix.getElementType() != resultMatrix.getElementType())
3293 return op.emitError("right and result matrices' component type must"
3294 " be the same");
3295
3296 // left and result matrices component type must be the same
3297 if (leftMatrix.getElementType() != resultMatrix.getElementType())
3298 return op.emitError("left and result matrices' component type"
3299 " must be the same");
3300
3301 // left and result matrices rows count must be the same
3302 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
3303 return op.emitError("left and result matrices must have equal rows'"
3304 " count");
3305
3306 return success();
3307 }
3308
3309 //===----------------------------------------------------------------------===//
3310 // spv.specConstantComposite
3311 //===----------------------------------------------------------------------===//
3312
parseSpecConstantCompositeOp(OpAsmParser & parser,OperationState & state)3313 static ParseResult parseSpecConstantCompositeOp(OpAsmParser &parser,
3314 OperationState &state) {
3315
3316 StringAttr compositeName;
3317 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
3318 state.attributes))
3319 return failure();
3320
3321 if (parser.parseLParen())
3322 return failure();
3323
3324 SmallVector<Attribute, 4> constituents;
3325
3326 do {
3327 // The name of the constituent attribute isn't important
3328 const char *attrName = "spec_const";
3329 FlatSymbolRefAttr specConstRef;
3330 NamedAttrList attrs;
3331
3332 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
3333 return failure();
3334
3335 constituents.push_back(specConstRef);
3336 } while (!parser.parseOptionalComma());
3337
3338 if (parser.parseRParen())
3339 return failure();
3340
3341 state.addAttribute(kCompositeSpecConstituentsName,
3342 parser.getBuilder().getArrayAttr(constituents));
3343
3344 Type type;
3345 if (parser.parseColonType(type))
3346 return failure();
3347
3348 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
3349
3350 return success();
3351 }
3352
print(spirv::SpecConstantCompositeOp op,OpAsmPrinter & printer)3353 static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) {
3354 printer << spirv::SpecConstantCompositeOp::getOperationName() << " ";
3355 printer.printSymbolName(op.sym_name());
3356 printer << " (";
3357 auto constituents = op.constituents().getValue();
3358
3359 if (!constituents.empty())
3360 llvm::interleaveComma(constituents, printer);
3361
3362 printer << ") : " << op.type();
3363 }
3364
verify(spirv::SpecConstantCompositeOp constOp)3365 static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) {
3366 auto cType = constOp.type().dyn_cast<spirv::CompositeType>();
3367 auto constituents = constOp.constituents().getValue();
3368
3369 if (!cType)
3370 return constOp.emitError(
3371 "result type must be a composite type, but provided ")
3372 << constOp.type();
3373
3374 if (cType.isa<spirv::CooperativeMatrixNVType>())
3375 return constOp.emitError("unsupported composite type ") << cType;
3376 else if (constituents.size() != cType.getNumElements())
3377 return constOp.emitError("has incorrect number of operands: expected ")
3378 << cType.getNumElements() << ", but provided "
3379 << constituents.size();
3380
3381 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
3382 auto constituent = constituents[index].dyn_cast<FlatSymbolRefAttr>();
3383
3384 auto constituentSpecConstOp =
3385 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
3386 constOp->getParentOp(), constituent.getValue()));
3387
3388 if (constituentSpecConstOp.default_value().getType() !=
3389 cType.getElementType(index))
3390 return constOp.emitError("has incorrect types of operands: expected ")
3391 << cType.getElementType(index) << ", but provided "
3392 << constituentSpecConstOp.default_value().getType();
3393 }
3394
3395 return success();
3396 }
3397
3398 //===----------------------------------------------------------------------===//
3399 // spv.mlir.yield
3400 //===----------------------------------------------------------------------===//
3401
verify(spirv::YieldOp yieldOp)3402 static LogicalResult verify(spirv::YieldOp yieldOp) {
3403 Operation *parentOp = yieldOp->getParentOp();
3404
3405 if (!parentOp || !isa<spirv::SpecConstantOperationOp>(parentOp))
3406 return yieldOp.emitOpError(
3407 "expected parent op to be 'spv.SpecConstantOperation'");
3408
3409 Block &block = parentOp->getRegion(0).getBlocks().front();
3410 Operation &enclosedOp = block.getOperations().front();
3411
3412 if (yieldOp.getOperand().getDefiningOp() != &enclosedOp)
3413 return yieldOp.emitOpError(
3414 "expected operand to be defined by preceeding op");
3415
3416 return success();
3417 }
3418
parseSpecConstantOperationOp(OpAsmParser & parser,OperationState & state)3419 static ParseResult parseSpecConstantOperationOp(OpAsmParser &parser,
3420 OperationState &state) {
3421 // TODO: For now, only generic form is supported.
3422 return failure();
3423 }
3424
print(spirv::SpecConstantOperationOp op,OpAsmPrinter & printer)3425 static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) {
3426 // TODO
3427 printer.printGenericOp(op);
3428 }
3429
verify(spirv::SpecConstantOperationOp constOp)3430 static LogicalResult verify(spirv::SpecConstantOperationOp constOp) {
3431 Block &block = constOp.getRegion().getBlocks().front();
3432
3433 if (block.getOperations().size() != 2)
3434 return constOp.emitOpError("expected exactly 2 nested ops");
3435
3436 Operation &yieldOp = block.getOperations().back();
3437
3438 if (!isa<spirv::YieldOp>(yieldOp))
3439 return constOp.emitOpError("expected terminator to be a yield op");
3440
3441 Operation &enclosedOp = block.getOperations().front();
3442
3443 // TODO Add a `UsableInSpecConstantOp` trait and mark ops from the list below
3444 // with it instead.
3445 if (!isa<spirv::SConvertOp, spirv::UConvertOp, spirv::FConvertOp,
3446 spirv::SNegateOp, spirv::NotOp, spirv::IAddOp, spirv::ISubOp,
3447 spirv::IMulOp, spirv::UDivOp, spirv::SDivOp, spirv::UModOp,
3448 spirv::SRemOp, spirv::SModOp, spirv::ShiftRightLogicalOp,
3449 spirv::ShiftRightArithmeticOp, spirv::ShiftLeftLogicalOp,
3450 spirv::BitwiseOrOp, spirv::BitwiseXorOp, spirv::BitwiseAndOp,
3451 spirv::CompositeExtractOp, spirv::CompositeInsertOp,
3452 spirv::LogicalOrOp, spirv::LogicalAndOp, spirv::LogicalNotOp,
3453 spirv::LogicalEqualOp, spirv::LogicalNotEqualOp, spirv::SelectOp,
3454 spirv::IEqualOp, spirv::INotEqualOp, spirv::ULessThanOp,
3455 spirv::SLessThanOp, spirv::UGreaterThanOp, spirv::SGreaterThanOp,
3456 spirv::ULessThanEqualOp, spirv::SLessThanEqualOp,
3457 spirv::UGreaterThanEqualOp, spirv::SGreaterThanEqualOp>(enclosedOp))
3458 return constOp.emitOpError("invalid enclosed op");
3459
3460 if (enclosedOp.getNumOperands() != constOp.getOperands().size())
3461 return constOp.emitOpError("invalid number of operands; expected ")
3462 << enclosedOp.getNumOperands() << ", actual "
3463 << constOp.getOperands().size();
3464
3465 if (enclosedOp.getNumOperands() != constOp.getRegion().getNumArguments())
3466 return constOp.emitOpError("invalid number of region arguments; expected ")
3467 << enclosedOp.getNumOperands() << ", actual "
3468 << constOp.getRegion().getNumArguments();
3469
3470 for (auto operand : constOp.getOperands())
3471 if (!isa<spirv::ConstantOp, spirv::SpecConstantOp,
3472 spirv::SpecConstantCompositeOp, spirv::SpecConstantOperationOp>(
3473 operand.getDefiningOp()))
3474 return constOp.emitOpError("invalid operand");
3475
3476 return success();
3477 }
3478
3479 namespace mlir {
3480 namespace spirv {
3481
3482 // TableGen'erated operation interfaces for querying versions, extensions, and
3483 // capabilities.
3484 #include "mlir/Dialect/SPIRV/SPIRVAvailability.cpp.inc"
3485 } // namespace spirv
3486 } // namespace mlir
3487
// TableGen'erated operation definitions.
3489 #define GET_OP_CLASSES
3490 #include "mlir/Dialect/SPIRV/SPIRVOps.cpp.inc"
3491
3492 namespace mlir {
3493 namespace spirv {
3494 // TableGen'erated operation availability interface implementations.
3495 #include "mlir/Dialect/SPIRV/SPIRVOpAvailabilityImpl.inc"
3496
3497 } // namespace spirv
3498 } // namespace mlir
3499