1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Operation.h"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/ScopedHashTable.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Regex.h"
37 #include "llvm/Support/SaveAndRestore.h"
38 using namespace mlir;
39 using namespace mlir::detail;
40
print(raw_ostream & os) const41 void Identifier::print(raw_ostream &os) const { os << str(); }
42
dump() const43 void Identifier::dump() const { print(llvm::errs()); }
44
print(raw_ostream & os) const45 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
46
dump() const47 void OperationName::dump() const { print(llvm::errs()); }
48
~DialectAsmPrinter()49 DialectAsmPrinter::~DialectAsmPrinter() {}
50
~OpAsmPrinter()51 OpAsmPrinter::~OpAsmPrinter() {}
52
53 //===--------------------------------------------------------------------===//
54 // Operation OpAsm interface.
55 //===--------------------------------------------------------------------===//
56
57 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
58 #include "mlir/IR/OpAsmInterface.cpp.inc"
59
60 //===----------------------------------------------------------------------===//
61 // OpPrintingFlags
62 //===----------------------------------------------------------------------===//
63
64 namespace {
65 /// This struct contains command line options that can be used to initialize
66 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
67 /// for global command line options.
68 struct AsmPrinterOptions {
69 llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
70 "mlir-print-elementsattrs-with-hex-if-larger",
71 llvm::cl::desc(
72 "Print DenseElementsAttrs with a hex string that have "
73 "more elements than the given upper limit (use -1 to disable)")};
74
75 llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
76 "mlir-elide-elementsattrs-if-larger",
77 llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
78 "more elements than the given upper limit")};
79
80 llvm::cl::opt<bool> printDebugInfoOpt{
81 "mlir-print-debuginfo", llvm::cl::init(false),
82 llvm::cl::desc("Print debug info in MLIR output")};
83
84 llvm::cl::opt<bool> printPrettyDebugInfoOpt{
85 "mlir-pretty-debuginfo", llvm::cl::init(false),
86 llvm::cl::desc("Print pretty debug info in MLIR output")};
87
88 // Use the generic op output form in the operation printer even if the custom
89 // form is defined.
90 llvm::cl::opt<bool> printGenericOpFormOpt{
91 "mlir-print-op-generic", llvm::cl::init(false),
92 llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
93
94 llvm::cl::opt<bool> printLocalScopeOpt{
95 "mlir-print-local-scope", llvm::cl::init(false),
96 llvm::cl::desc("Print assuming in local scope by default"),
97 llvm::cl::Hidden};
98 };
99 } // end anonymous namespace
100
101 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
102
103 /// Register a set of useful command-line options that can be used to configure
104 /// various flags within the AsmPrinter.
registerAsmPrinterCLOptions()105 void mlir::registerAsmPrinterCLOptions() {
106 // Make sure that the options struct has been initialized.
107 *clOptions;
108 }
109
110 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()111 OpPrintingFlags::OpPrintingFlags()
112 : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
113 printGenericOpFormFlag(false), printLocalScope(false) {
114 // Initialize based upon command line options, if they are available.
115 if (!clOptions.isConstructed())
116 return;
117 if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
118 elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
119 printDebugInfoFlag = clOptions->printDebugInfoOpt;
120 printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
121 printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
122 printLocalScope = clOptions->printLocalScopeOpt;
123 }
124
125 /// Enable the elision of large elements attributes, by printing a '...'
126 /// instead of the element data, when the number of elements is greater than
127 /// `largeElementLimit`. Note: The IR generated with this option is not
128 /// parsable.
129 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)130 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
131 elementsAttrElementLimit = largeElementLimit;
132 return *this;
133 }
134
135 /// Enable printing of debug information. If 'prettyForm' is set to true,
136 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)137 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
138 printDebugInfoFlag = true;
139 printDebugInfoPrettyFormFlag = prettyForm;
140 return *this;
141 }
142
143 /// Always print operations in the generic form.
printGenericOpForm()144 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
145 printGenericOpFormFlag = true;
146 return *this;
147 }
148
149 /// Use local scope when printing the operation. This allows for using the
150 /// printer in a more localized and thread-safe setting, but may not necessarily
151 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()152 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
153 printLocalScope = true;
154 return *this;
155 }
156
157 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const158 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
159 return elementsAttrElementLimit.hasValue() &&
160 *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
161 !attr.isa<SplatElementsAttr>();
162 }
163
164 /// Return the size limit for printing large ElementsAttr.
getLargeElementsAttrLimit() const165 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
166 return elementsAttrElementLimit;
167 }
168
169 /// Return if debug information should be printed.
shouldPrintDebugInfo() const170 bool OpPrintingFlags::shouldPrintDebugInfo() const {
171 return printDebugInfoFlag;
172 }
173
174 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const175 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
176 return printDebugInfoPrettyFormFlag;
177 }
178
179 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const180 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
181 return printGenericOpFormFlag;
182 }
183
184 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const185 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
186
187 /// Returns true if an ElementsAttr with the given number of elements should be
188 /// printed with hex.
shouldPrintElementsAttrWithHex(int64_t numElements)189 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
190 // Check to see if a command line option was provided for the limit.
191 if (clOptions.isConstructed()) {
192 if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
193 // -1 is used to disable hex printing.
194 if (clOptions->printElementsAttrWithHexIfLarger == -1)
195 return false;
196 return numElements > clOptions->printElementsAttrWithHexIfLarger;
197 }
198 }
199
200 // Otherwise, default to printing with hex if the number of elements is >100.
201 return numElements > 100;
202 }
203
204 //===----------------------------------------------------------------------===//
205 // NewLineCounter
206 //===----------------------------------------------------------------------===//
207
namespace {
/// Newline emitter that keeps a running line count: streaming an instance
/// into a raw_ostream writes '\n' and increments `curLine`. Route every
/// printer newline through it so the count stays accurate.
struct NewLineCounter {
  /// Current line number, starting at 1.
  unsigned curLine{1};
};
} // end anonymous namespace
216
operator <<(raw_ostream & os,NewLineCounter & newLine)217 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
218 ++newLine.curLine;
219 return os << '\n';
220 }
221
222 //===----------------------------------------------------------------------===//
223 // AliasInitializer
224 //===----------------------------------------------------------------------===//
225
226 namespace {
227 /// This class represents a specific instance of a symbol Alias.
228 class SymbolAlias {
229 public:
SymbolAlias(StringRef name,bool isDeferrable)230 SymbolAlias(StringRef name, bool isDeferrable)
231 : name(name), suffixIndex(0), hasSuffixIndex(false),
232 isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name,uint32_t suffixIndex,bool isDeferrable)233 SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
234 : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
235 isDeferrable(isDeferrable) {}
236
237 /// Print this alias to the given stream.
print(raw_ostream & os) const238 void print(raw_ostream &os) const {
239 os << name;
240 if (hasSuffixIndex)
241 os << suffixIndex;
242 }
243
244 /// Returns true if this alias supports deferred resolution when parsing.
canBeDeferred() const245 bool canBeDeferred() const { return isDeferrable; }
246
247 private:
248 /// The main name of the alias.
249 StringRef name;
250 /// The optional suffix index of the alias, if multiple aliases had the same
251 /// name.
252 uint32_t suffixIndex : 30;
253 /// A flag indicating whether this alias has a suffix or not.
254 bool hasSuffixIndex : 1;
255 /// A flag indicating whether this alias may be deferred or not.
256 bool isDeferrable : 1;
257 };
258
259 /// This class represents a utility that initializes the set of attribute and
260 /// type aliases, without the need to store the extra information within the
261 /// main AliasState class or pass it around via function arguments.
262 class AliasInitializer {
263 public:
AliasInitializer(DialectInterfaceCollection<OpAsmDialectInterface> & interfaces,llvm::BumpPtrAllocator & aliasAllocator)264 AliasInitializer(
265 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
266 llvm::BumpPtrAllocator &aliasAllocator)
267 : interfaces(interfaces), aliasAllocator(aliasAllocator),
268 aliasOS(aliasBuffer) {}
269
270 void initialize(Operation *op, const OpPrintingFlags &printerFlags,
271 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
272 llvm::MapVector<Type, SymbolAlias> &typeToAlias);
273
274 /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
275 /// set to true if the originator of this attribute can resolve the alias
276 /// after parsing has completed (e.g. in the case of operation locations).
277 void visit(Attribute attr, bool canBeDeferred = false);
278
279 /// Visit the given type to see if it has an alias.
280 void visit(Type type);
281
282 private:
283 /// Try to generate an alias for the provided symbol. If an alias is
284 /// generated, the provided alias mapping and reverse mapping are updated.
285 /// Returns success if an alias was generated, failure otherwise.
286 template <typename T>
287 LogicalResult
288 generateAlias(T symbol,
289 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
290
291 /// The set of asm interfaces within the context.
292 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
293
294 /// Mapping between an alias and the set of symbols mapped to it.
295 llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
296 llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
297
298 /// An allocator used for alias names.
299 llvm::BumpPtrAllocator &aliasAllocator;
300
301 /// The set of visited attributes.
302 DenseSet<Attribute> visitedAttributes;
303
304 /// The set of attributes that have aliases *and* can be deferred.
305 DenseSet<Attribute> deferrableAttributes;
306
307 /// The set of visited types.
308 DenseSet<Type> visitedTypes;
309
310 /// Storage and stream used when generating an alias.
311 SmallString<32> aliasBuffer;
312 llvm::raw_svector_ostream aliasOS;
313 };
314
315 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
316 /// and merely collects the attributes and types that *would* be printed in a
317 /// normal print invocation so that we can generate proper aliases. This allows
318 /// for us to generate aliases only for the attributes and types that would be
319 /// in the output, and trims down unnecessary output.
320 class DummyAliasOperationPrinter : private OpAsmPrinter {
321 public:
DummyAliasOperationPrinter(const OpPrintingFlags & flags,AliasInitializer & initializer)322 explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags,
323 AliasInitializer &initializer)
324 : printerFlags(flags), initializer(initializer) {}
325
326 /// Print the given operation.
print(Operation * op)327 void print(Operation *op) {
328 // Visit the operation location.
329 if (printerFlags.shouldPrintDebugInfo())
330 initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
331
332 // If requested, always print the generic form.
333 if (!printerFlags.shouldPrintGenericOpForm()) {
334 // Check to see if this is a known operation. If so, use the registered
335 // custom printer hook.
336 if (auto *opInfo = op->getAbstractOperation()) {
337 opInfo->printAssembly(op, *this);
338 return;
339 }
340 }
341
342 // Otherwise print with the generic assembly form.
343 printGenericOp(op);
344 }
345
346 private:
347 /// Print the given operation in the generic form.
printGenericOp(Operation * op)348 void printGenericOp(Operation *op) override {
349 // Consider nested opertions for aliases.
350 if (op->getNumRegions() != 0) {
351 for (Region ®ion : op->getRegions())
352 printRegion(region, /*printEntryBlockArgs=*/true,
353 /*printBlockTerminators=*/true);
354 }
355
356 // Visit all the types used in the operation.
357 for (Type type : op->getOperandTypes())
358 printType(type);
359 for (Type type : op->getResultTypes())
360 printType(type);
361
362 // Consider the attributes of the operation for aliases.
363 for (const NamedAttribute &attr : op->getAttrs())
364 printAttribute(attr.second);
365 }
366
367 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
368 /// block are not printed. If 'printBlockTerminator' is false, the terminator
369 /// operation of the block is not printed.
print(Block * block,bool printBlockArgs=true,bool printBlockTerminator=true)370 void print(Block *block, bool printBlockArgs = true,
371 bool printBlockTerminator = true) {
372 // Consider the types of the block arguments for aliases if 'printBlockArgs'
373 // is set to true.
374 if (printBlockArgs) {
375 for (Type type : block->getArgumentTypes())
376 printType(type);
377 }
378
379 // Consider the operations within this block, ignoring the terminator if
380 // requested.
381 auto range = llvm::make_range(
382 block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
383 for (Operation &op : range)
384 print(&op);
385 }
386
387 /// Print the given region.
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)388 void printRegion(Region ®ion, bool printEntryBlockArgs,
389 bool printBlockTerminators) override {
390 if (region.empty())
391 return;
392
393 auto *entryBlock = ®ion.front();
394 print(entryBlock, printEntryBlockArgs, printBlockTerminators);
395 for (Block &b : llvm::drop_begin(region, 1))
396 print(&b);
397 }
398
399 /// Consider the given type to be printed for an alias.
printType(Type type)400 void printType(Type type) override { initializer.visit(type); }
401
402 /// Consider the given attribute to be printed for an alias.
printAttribute(Attribute attr)403 void printAttribute(Attribute attr) override { initializer.visit(attr); }
printAttributeWithoutType(Attribute attr)404 void printAttributeWithoutType(Attribute attr) override {
405 printAttribute(attr);
406 }
407
408 /// Print the given set of attributes with names not included within
409 /// 'elidedAttrs'.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})410 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
411 ArrayRef<StringRef> elidedAttrs = {}) override {
412 // Filter out any attributes that shouldn't be included.
413 SmallVector<NamedAttribute, 8> filteredAttrs(
__anon2591390f0402(NamedAttribute attr) 414 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
415 return !llvm::is_contained(elidedAttrs, attr.first.strref());
416 }));
417 for (const NamedAttribute &attr : filteredAttrs)
418 printAttribute(attr.second);
419 }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})420 void printOptionalAttrDictWithKeyword(
421 ArrayRef<NamedAttribute> attrs,
422 ArrayRef<StringRef> elidedAttrs = {}) override {
423 printOptionalAttrDict(attrs, elidedAttrs);
424 }
425
426 /// Return 'nulls' as the output stream, this will ignore any data fed to it.
getStream() const427 raw_ostream &getStream() const override { return llvm::nulls(); }
428
429 /// The following are hooks of `OpAsmPrinter` that are not necessary for
430 /// determining potential aliases.
printAffineMapOfSSAIds(AffineMapAttr,ValueRange)431 void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
printOperand(Value)432 void printOperand(Value) override {}
printOperand(Value,raw_ostream & os)433 void printOperand(Value, raw_ostream &os) override {
434 // Users expect the output string to have at least the prefixed % to signal
435 // a value name. To maintain this invariant, emit a name even if it is
436 // guaranteed to go unused.
437 os << "%";
438 }
printSymbolName(StringRef)439 void printSymbolName(StringRef) override {}
printSuccessor(Block *)440 void printSuccessor(Block *) override {}
printSuccessorAndUseList(Block *,ValueRange)441 void printSuccessorAndUseList(Block *, ValueRange) override {}
shadowRegionArgs(Region &,ValueRange)442 void shadowRegionArgs(Region &, ValueRange) override {}
443
444 /// The printer flags to use when determining potential aliases.
445 const OpPrintingFlags &printerFlags;
446
447 /// The initializer to use when identifying aliases.
448 AliasInitializer &initializer;
449 };
450 } // end anonymous namespace
451
452 /// Sanitize the given name such that it can be used as a valid identifier. If
453 /// the string needs to be modified in any way, the provided buffer is used to
454 /// store the new copy,
sanitizeIdentifier(StringRef name,SmallString<16> & buffer,StringRef allowedPunctChars="$._-",bool allowTrailingDigit=true)455 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
456 StringRef allowedPunctChars = "$._-",
457 bool allowTrailingDigit = true) {
458 assert(!name.empty() && "Shouldn't have an empty name here");
459
460 auto copyNameToBuffer = [&] {
461 for (char ch : name) {
462 if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
463 buffer.push_back(ch);
464 else if (ch == ' ')
465 buffer.push_back('_');
466 else
467 buffer.append(llvm::utohexstr((unsigned char)ch));
468 }
469 };
470
471 // Check to see if this name is valid. If it starts with a digit, then it
472 // could conflict with the autogenerated numeric ID's, so add an underscore
473 // prefix to avoid problems.
474 if (isdigit(name[0])) {
475 buffer.push_back('_');
476 copyNameToBuffer();
477 return buffer;
478 }
479
480 // If the name ends with a trailing digit, add a '_' to avoid potential
481 // conflicts with autogenerated ID's.
482 if (!allowTrailingDigit && isdigit(name.back())) {
483 copyNameToBuffer();
484 buffer.push_back('_');
485 return buffer;
486 }
487
488 // Check to see that the name consists of only valid identifier characters.
489 for (char ch : name) {
490 if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
491 copyNameToBuffer();
492 return buffer;
493 }
494 }
495
496 // If there are no invalid characters, return the original name.
497 return name;
498 }
499
500 /// Given a collection of aliases and symbols, initialize a mapping from a
501 /// symbol to a given alias.
502 template <typename T>
503 static void
initializeAliases(llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol,llvm::MapVector<T,SymbolAlias> & symbolToAlias,DenseSet<T> * deferrableAliases=nullptr)504 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
505 llvm::MapVector<T, SymbolAlias> &symbolToAlias,
506 DenseSet<T> *deferrableAliases = nullptr) {
507 std::vector<std::pair<StringRef, std::vector<T>>> aliases =
508 aliasToSymbol.takeVector();
509 llvm::array_pod_sort(aliases.begin(), aliases.end(),
510 [](const auto *lhs, const auto *rhs) {
511 return lhs->first.compare(rhs->first);
512 });
513
514 for (auto &it : aliases) {
515 // If there is only one instance for this alias, use the name directly.
516 if (it.second.size() == 1) {
517 T symbol = it.second.front();
518 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
519 symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
520 continue;
521 }
522 // Otherwise, add the index to the name.
523 for (int i = 0, e = it.second.size(); i < e; ++i) {
524 T symbol = it.second[i];
525 bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
526 symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
527 }
528 }
529 }
530
initialize(Operation * op,const OpPrintingFlags & printerFlags,llvm::MapVector<Attribute,SymbolAlias> & attrToAlias,llvm::MapVector<Type,SymbolAlias> & typeToAlias)531 void AliasInitializer::initialize(
532 Operation *op, const OpPrintingFlags &printerFlags,
533 llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
534 llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
535 // Use a dummy printer when walking the IR so that we can collect the
536 // attributes/types that will actually be used during printing when
537 // considering aliases.
538 DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
539 aliasPrinter.print(op);
540
541 // Initialize the aliases sorted by name.
542 initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
543 initializeAliases(aliasToType, typeToAlias);
544 }
545
visit(Attribute attr,bool canBeDeferred)546 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
547 if (!visitedAttributes.insert(attr).second) {
548 // If this attribute already has an alias and this instance can't be
549 // deferred, make sure that the alias isn't deferred.
550 if (!canBeDeferred)
551 deferrableAttributes.erase(attr);
552 return;
553 }
554
555 // Try to generate an alias for this attribute.
556 if (succeeded(generateAlias(attr, aliasToAttr))) {
557 if (canBeDeferred)
558 deferrableAttributes.insert(attr);
559 return;
560 }
561
562 if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
563 for (Attribute element : arrayAttr.getValue())
564 visit(element);
565 } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
566 for (const NamedAttribute &attr : dictAttr)
567 visit(attr.second);
568 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
569 visit(typeAttr.getValue());
570 }
571 }
572
visit(Type type)573 void AliasInitializer::visit(Type type) {
574 if (!visitedTypes.insert(type).second)
575 return;
576
577 // Try to generate an alias for this type.
578 if (succeeded(generateAlias(type, aliasToType)))
579 return;
580
581 // Visit several subtypes that contain types or atttributes.
582 if (auto funcType = type.dyn_cast<FunctionType>()) {
583 // Visit input and result types for functions.
584 for (auto input : funcType.getInputs())
585 visit(input);
586 for (auto result : funcType.getResults())
587 visit(result);
588 } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
589 visit(shapedType.getElementType());
590
591 // Visit affine maps in memref type.
592 if (auto memref = type.dyn_cast<MemRefType>())
593 for (auto map : memref.getAffineMaps())
594 visit(AffineMapAttr::get(map));
595 }
596 }
597
598 template <typename T>
generateAlias(T symbol,llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol)599 LogicalResult AliasInitializer::generateAlias(
600 T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
601 SmallString<16> tempBuffer;
602 for (const auto &interface : interfaces) {
603 interface.getAlias(symbol, aliasOS);
604 StringRef name = aliasOS.str();
605 if (name.empty())
606 continue;
607 name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
608 /*allowTrailingDigit=*/false);
609 name = name.copy(aliasAllocator);
610
611 aliasToSymbol[name].push_back(symbol);
612 aliasBuffer.clear();
613 return success();
614 }
615 return failure();
616 }
617
618 //===----------------------------------------------------------------------===//
619 // AliasState
620 //===----------------------------------------------------------------------===//
621
622 namespace {
623 /// This class manages the state for type and attribute aliases.
624 class AliasState {
625 public:
626 // Initialize the internal aliases.
627 void
628 initialize(Operation *op, const OpPrintingFlags &printerFlags,
629 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
630
631 /// Get an alias for the given attribute if it has one and print it in `os`.
632 /// Returns success if an alias was printed, failure otherwise.
633 LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
634
635 /// Get an alias for the given type if it has one and print it in `os`.
636 /// Returns success if an alias was printed, failure otherwise.
637 LogicalResult getAlias(Type ty, raw_ostream &os) const;
638
639 /// Print all of the referenced aliases that can not be resolved in a deferred
640 /// manner.
printNonDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const641 void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
642 printAliases(os, newLine, /*isDeferred=*/false);
643 }
644
645 /// Print all of the referenced aliases that support deferred resolution.
printDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const646 void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
647 printAliases(os, newLine, /*isDeferred=*/true);
648 }
649
650 private:
651 /// Print all of the referenced aliases that support the provided resolution
652 /// behavior.
653 void printAliases(raw_ostream &os, NewLineCounter &newLine,
654 bool isDeferred) const;
655
656 /// Mapping between attribute and alias.
657 llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
658 /// Mapping between type and alias.
659 llvm::MapVector<Type, SymbolAlias> typeToAlias;
660
661 /// An allocator used for alias names.
662 llvm::BumpPtrAllocator aliasAllocator;
663 };
664 } // end anonymous namespace
665
initialize(Operation * op,const OpPrintingFlags & printerFlags,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)666 void AliasState::initialize(
667 Operation *op, const OpPrintingFlags &printerFlags,
668 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
669 AliasInitializer initializer(interfaces, aliasAllocator);
670 initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
671 }
672
getAlias(Attribute attr,raw_ostream & os) const673 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
674 auto it = attrToAlias.find(attr);
675 if (it == attrToAlias.end())
676 return failure();
677 it->second.print(os << '#');
678 return success();
679 }
680
getAlias(Type ty,raw_ostream & os) const681 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
682 auto it = typeToAlias.find(ty);
683 if (it == typeToAlias.end())
684 return failure();
685
686 it->second.print(os << '!');
687 return success();
688 }
689
printAliases(raw_ostream & os,NewLineCounter & newLine,bool isDeferred) const690 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
691 bool isDeferred) const {
692 auto filterFn = [=](const auto &aliasIt) {
693 return aliasIt.second.canBeDeferred() == isDeferred;
694 };
695 for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
696 it.second.print(os << '#');
697 os << " = " << it.first << newLine;
698 }
699 for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
700 it.second.print(os << '!');
701 os << " = " << it.first << newLine;
702 }
703 }
704
705 //===----------------------------------------------------------------------===//
706 // SSANameState
707 //===----------------------------------------------------------------------===//
708
709 namespace {
710 /// This class manages the state of SSA value names.
711 class SSANameState {
712 public:
713 /// A sentinel value used for values with names set.
714 enum : unsigned { NameSentinel = ~0U };
715
716 SSANameState(Operation *op,
717 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
718
719 /// Print the SSA identifier for the given value to 'stream'. If
720 /// 'printResultNo' is true, it also presents the result number ('#' number)
721 /// of this value.
722 void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
723
724 /// Return the result indices for each of the result groups registered by this
725 /// operation, or empty if none exist.
726 ArrayRef<int> getOpResultGroups(Operation *op);
727
728 /// Get the ID for the given block.
729 unsigned getBlockID(Block *block);
730
731 /// Renumber the arguments for the specified region to the same names as the
732 /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
733 /// details.
734 void shadowRegionArgs(Region ®ion, ValueRange namesToUse);
735
736 private:
737 /// Number the SSA values within the given IR unit.
738 void numberValuesInRegion(
739 Region ®ion,
740 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
741 void numberValuesInBlock(
742 Block &block,
743 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
744 void numberValuesInOp(
745 Operation &op,
746 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
747
748 /// Given a result of an operation 'result', find the result group head
749 /// 'lookupValue' and the result of 'result' within that group in
750 /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
751 /// has more than 1 result.
752 void getResultIDAndNumber(OpResult result, Value &lookupValue,
753 Optional<int> &lookupResultNo) const;
754
755 /// Set a special value name for the given value.
756 void setValueName(Value value, StringRef name);
757
758 /// Uniques the given value name within the printer. If the given name
759 /// conflicts, it is automatically renamed.
760 StringRef uniqueValueName(StringRef name);
761
762 /// This is the value ID for each SSA value. If this returns NameSentinel,
763 /// then the valueID has an entry in valueNames.
764 DenseMap<Value, unsigned> valueIDs;
765 DenseMap<Value, StringRef> valueNames;
766
767 /// This is a map of operations that contain multiple named result groups,
768 /// i.e. there may be multiple names for the results of the operation. The
769 /// value of this map are the result numbers that start a result group.
770 DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
771
772 /// This is the block ID for each block in the current.
773 DenseMap<Block *, unsigned> blockIDs;
774
775 /// This keeps track of all of the non-numeric names that are in flight,
776 /// allowing us to check for duplicates.
777 /// Note: the value of the map is unused.
778 llvm::ScopedHashTable<StringRef, char> usedNames;
779 llvm::BumpPtrAllocator usedNameAllocator;
780
781 /// This is the next value ID to assign in numbering.
782 unsigned nextValueID = 0;
783 /// This is the next ID to assign to a region entry block argument.
784 unsigned nextArgumentID = 0;
785 /// This is the next ID to assign when a name conflict is detected.
786 unsigned nextConflictID = 0;
787 };
788 } // end anonymous namespace
789
SSANameState(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)790 SSANameState::SSANameState(
791 Operation *op,
792 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
793 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
794 numberValuesInOp(*op, interfaces);
795
796 for (auto ®ion : op->getRegions())
797 numberValuesInRegion(region, interfaces);
798 }
799
printValueID(Value value,bool printResultNo,raw_ostream & stream) const800 void SSANameState::printValueID(Value value, bool printResultNo,
801 raw_ostream &stream) const {
802 if (!value) {
803 stream << "<<NULL>>";
804 return;
805 }
806
807 Optional<int> resultNo;
808 auto lookupValue = value;
809
810 // If this is an operation result, collect the head lookup value of the result
811 // group and the result number of 'result' within that group.
812 if (OpResult result = value.dyn_cast<OpResult>())
813 getResultIDAndNumber(result, lookupValue, resultNo);
814
815 auto it = valueIDs.find(lookupValue);
816 if (it == valueIDs.end()) {
817 stream << "<<UNKNOWN SSA VALUE>>";
818 return;
819 }
820
821 stream << '%';
822 if (it->second != NameSentinel) {
823 stream << it->second;
824 } else {
825 auto nameIt = valueNames.find(lookupValue);
826 assert(nameIt != valueNames.end() && "Didn't have a name entry?");
827 stream << nameIt->second;
828 }
829
830 if (resultNo.hasValue() && printResultNo)
831 stream << '#' << resultNo;
832 }
833
getOpResultGroups(Operation * op)834 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
835 auto it = opResultGroups.find(op);
836 return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
837 }
838
getBlockID(Block * block)839 unsigned SSANameState::getBlockID(Block *block) {
840 auto it = blockIDs.find(block);
841 return it != blockIDs.end() ? it->second : NameSentinel;
842 }
843
shadowRegionArgs(Region & region,ValueRange namesToUse)844 void SSANameState::shadowRegionArgs(Region ®ion, ValueRange namesToUse) {
845 assert(!region.empty() && "cannot shadow arguments of an empty region");
846 assert(region.getNumArguments() == namesToUse.size() &&
847 "incorrect number of names passed in");
848 assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
849 "only KnownIsolatedFromAbove ops can shadow names");
850
851 SmallVector<char, 16> nameStr;
852 for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
853 auto nameToUse = namesToUse[i];
854 if (nameToUse == nullptr)
855 continue;
856 auto nameToReplace = region.getArgument(i);
857
858 nameStr.clear();
859 llvm::raw_svector_ostream nameStream(nameStr);
860 printValueID(nameToUse, /*printResultNo=*/true, nameStream);
861
862 // Entry block arguments should already have a pretty "arg" name.
863 assert(valueIDs[nameToReplace] == NameSentinel);
864
865 // Use the name without the leading %.
866 auto name = StringRef(nameStream.str()).drop_front();
867
868 // Overwrite the name.
869 valueNames[nameToReplace] = name.copy(usedNameAllocator);
870 }
871 }
872
numberValuesInRegion(Region & region,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)873 void SSANameState::numberValuesInRegion(
874 Region ®ion,
875 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
876 // Save the current value ids to allow for numbering values in sibling regions
877 // the same.
878 llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
879 llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
880 llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
881
882 // Push a new used names scope.
883 llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
884
885 // Number the values within this region in a breadth-first order.
886 unsigned nextBlockID = 0;
887 for (auto &block : region) {
888 // Each block gets a unique ID, and all of the operations within it get
889 // numbered as well.
890 blockIDs[&block] = nextBlockID++;
891 numberValuesInBlock(block, interfaces);
892 }
893
894 // After that we traverse the nested regions.
895 // TODO: Rework this loop to not use recursion.
896 for (auto &block : region) {
897 for (auto &op : block)
898 for (auto &nestedRegion : op.getRegions())
899 numberValuesInRegion(nestedRegion, interfaces);
900 }
901 }
902
numberValuesInBlock(Block & block,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)903 void SSANameState::numberValuesInBlock(
904 Block &block,
905 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
906 auto setArgNameFn = [&](Value arg, StringRef name) {
907 assert(!valueIDs.count(arg) && "arg numbered multiple times");
908 assert(arg.cast<BlockArgument>().getOwner() == &block &&
909 "arg not defined in 'block'");
910 setValueName(arg, name);
911 };
912
913 bool isEntryBlock = block.isEntryBlock();
914 if (isEntryBlock) {
915 if (auto *op = block.getParentOp()) {
916 if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
917 asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
918 }
919 }
920
921 // Number the block arguments. We give entry block arguments a special name
922 // 'arg'.
923 SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
924 llvm::raw_svector_ostream specialName(specialNameBuffer);
925 for (auto arg : block.getArguments()) {
926 if (valueIDs.count(arg))
927 continue;
928 if (isEntryBlock) {
929 specialNameBuffer.resize(strlen("arg"));
930 specialName << nextArgumentID++;
931 }
932 setValueName(arg, specialName.str());
933 }
934
935 // Number the operations in this block.
936 for (auto &op : block)
937 numberValuesInOp(op, interfaces);
938 }
939
numberValuesInOp(Operation & op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)940 void SSANameState::numberValuesInOp(
941 Operation &op,
942 DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
943 unsigned numResults = op.getNumResults();
944 if (numResults == 0)
945 return;
946 Value resultBegin = op.getResult(0);
947
948 // Function used to set the special result names for the operation.
949 SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
950 auto setResultNameFn = [&](Value result, StringRef name) {
951 assert(!valueIDs.count(result) && "result numbered multiple times");
952 assert(result.getDefiningOp() == &op && "result not defined by 'op'");
953 setValueName(result, name);
954
955 // Record the result number for groups not anchored at 0.
956 if (int resultNo = result.cast<OpResult>().getResultNumber())
957 resultGroups.push_back(resultNo);
958 };
959 if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
960 asmInterface.getAsmResultNames(setResultNameFn);
961 else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
962 asmInterface->getAsmResultNames(&op, setResultNameFn);
963
964 // If the first result wasn't numbered, give it a default number.
965 if (valueIDs.try_emplace(resultBegin, nextValueID).second)
966 ++nextValueID;
967
968 // If this operation has multiple result groups, mark it.
969 if (resultGroups.size() != 1) {
970 llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
971 opResultGroups.try_emplace(&op, std::move(resultGroups));
972 }
973 }
974
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const975 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
976 Optional<int> &lookupResultNo) const {
977 Operation *owner = result.getOwner();
978 if (owner->getNumResults() == 1)
979 return;
980 int resultNo = result.getResultNumber();
981
982 // If this operation has multiple result groups, we will need to find the
983 // one corresponding to this result.
984 auto resultGroupIt = opResultGroups.find(owner);
985 if (resultGroupIt == opResultGroups.end()) {
986 // If not, just use the first result.
987 lookupResultNo = resultNo;
988 lookupValue = owner->getResult(0);
989 return;
990 }
991
992 // Find the correct index using a binary search, as the groups are ordered.
993 ArrayRef<int> resultGroups = resultGroupIt->second;
994 auto it = llvm::upper_bound(resultGroups, resultNo);
995 int groupResultNo = 0, groupSize = 0;
996
997 // If there are no smaller elements, the last result group is the lookup.
998 if (it == resultGroups.end()) {
999 groupResultNo = resultGroups.back();
1000 groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1001 } else {
1002 // Otherwise, the previous element is the lookup.
1003 groupResultNo = *std::prev(it);
1004 groupSize = *it - groupResultNo;
1005 }
1006
1007 // We only record the result number for a group of size greater than 1.
1008 if (groupSize != 1)
1009 lookupResultNo = resultNo - groupResultNo;
1010 lookupValue = owner->getResult(groupResultNo);
1011 }
1012
setValueName(Value value,StringRef name)1013 void SSANameState::setValueName(Value value, StringRef name) {
1014 // If the name is empty, the value uses the default numbering.
1015 if (name.empty()) {
1016 valueIDs[value] = nextValueID++;
1017 return;
1018 }
1019
1020 valueIDs[value] = NameSentinel;
1021 valueNames[value] = uniqueValueName(name);
1022 }
1023
uniqueValueName(StringRef name)1024 StringRef SSANameState::uniqueValueName(StringRef name) {
1025 SmallString<16> tmpBuffer;
1026 name = sanitizeIdentifier(name, tmpBuffer);
1027
1028 // Check to see if this name is already unique.
1029 if (!usedNames.count(name)) {
1030 name = name.copy(usedNameAllocator);
1031 } else {
1032 // Otherwise, we had a conflict - probe until we find a unique name. This
1033 // is guaranteed to terminate (and usually in a single iteration) because it
1034 // generates new names by incrementing nextConflictID.
1035 SmallString<64> probeName(name);
1036 probeName.push_back('_');
1037 while (true) {
1038 probeName += llvm::utostr(nextConflictID++);
1039 if (!usedNames.count(probeName)) {
1040 name = StringRef(probeName).copy(usedNameAllocator);
1041 break;
1042 }
1043 probeName.resize(name.size() + 1);
1044 }
1045 }
1046
1047 usedNames.insert(name, char());
1048 return name;
1049 }
1050
1051 //===----------------------------------------------------------------------===//
1052 // AsmState
1053 //===----------------------------------------------------------------------===//
1054
1055 namespace mlir {
1056 namespace detail {
1057 class AsmStateImpl {
1058 public:
AsmStateImpl(Operation * op,AsmState::LocationMap * locationMap)1059 explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
1060 : interfaces(op->getContext()), nameState(op, interfaces),
1061 locationMap(locationMap) {}
1062
1063 /// Initialize the alias state to enable the printing of aliases.
initializeAliases(Operation * op,const OpPrintingFlags & printerFlags)1064 void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) {
1065 aliasState.initialize(op, printerFlags, interfaces);
1066 }
1067
1068 /// Get an instance of the OpAsmDialectInterface for the given dialect, or
1069 /// null if one wasn't registered.
getOpAsmInterface(Dialect * dialect)1070 const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
1071 return interfaces.getInterfaceFor(dialect);
1072 }
1073
1074 /// Get the state used for aliases.
getAliasState()1075 AliasState &getAliasState() { return aliasState; }
1076
1077 /// Get the state used for SSA names.
getSSANameState()1078 SSANameState &getSSANameState() { return nameState; }
1079
1080 /// Register the location, line and column, within the buffer that the given
1081 /// operation was printed at.
registerOperationLocation(Operation * op,unsigned line,unsigned col)1082 void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1083 if (locationMap)
1084 (*locationMap)[op] = std::make_pair(line, col);
1085 }
1086
1087 private:
1088 /// Collection of OpAsm interfaces implemented in the context.
1089 DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
1090
1091 /// The state used for attribute and type aliases.
1092 AliasState aliasState;
1093
1094 /// The state used for SSA value names.
1095 SSANameState nameState;
1096
1097 /// An optional location map to be populated.
1098 AsmState::LocationMap *locationMap;
1099 };
1100 } // end namespace detail
1101 } // end namespace mlir
1102
AsmState(Operation * op,LocationMap * locationMap)1103 AsmState::AsmState(Operation *op, LocationMap *locationMap)
1104 : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}
~AsmState()1105 AsmState::~AsmState() {}
1106
1107 //===----------------------------------------------------------------------===//
1108 // ModulePrinter
1109 //===----------------------------------------------------------------------===//
1110
1111 namespace {
1112 class ModulePrinter {
1113 public:
ModulePrinter(raw_ostream & os,OpPrintingFlags flags=llvm::None,AsmStateImpl * state=nullptr)1114 ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
1115 AsmStateImpl *state = nullptr)
1116 : os(os), printerFlags(flags), state(state) {}
ModulePrinter(ModulePrinter & printer)1117 explicit ModulePrinter(ModulePrinter &printer)
1118 : os(printer.os), printerFlags(printer.printerFlags),
1119 state(printer.state) {}
1120
1121 /// Returns the output stream of the printer.
getStream()1122 raw_ostream &getStream() { return os; }
1123
1124 template <typename Container, typename UnaryFunctor>
interleaveComma(const Container & c,UnaryFunctor each_fn) const1125 inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
1126 llvm::interleaveComma(c, os, each_fn);
1127 }
1128
1129 /// This enum describes the different kinds of elision for the type of an
1130 /// attribute when printing it.
1131 enum class AttrTypeElision {
1132 /// The type must not be elided,
1133 Never,
1134 /// The type may be elided when it matches the default used in the parser
1135 /// (for example i64 is the default for integer attributes).
1136 May,
1137 /// The type must be elided.
1138 Must
1139 };
1140
1141 /// Print the given attribute.
1142 void printAttribute(Attribute attr,
1143 AttrTypeElision typeElision = AttrTypeElision::Never);
1144
1145 void printType(Type type);
1146
1147 /// Print the given location to the stream. If `allowAlias` is true, this
1148 /// allows for the internal location to use an attribute alias.
1149 void printLocation(LocationAttr loc, bool allowAlias = false);
1150
1151 void printAffineMap(AffineMap map);
1152 void
1153 printAffineExpr(AffineExpr expr,
1154 function_ref<void(unsigned, bool)> printValueName = nullptr);
1155 void printAffineConstraint(AffineExpr expr, bool isEq);
1156 void printIntegerSet(IntegerSet set);
1157
1158 protected:
1159 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1160 ArrayRef<StringRef> elidedAttrs = {},
1161 bool withKeyword = false);
1162 void printNamedAttribute(NamedAttribute attr);
1163 void printTrailingLocation(Location loc);
1164 void printLocationInternal(LocationAttr loc, bool pretty = false);
1165
1166 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1167 /// used instead of individual elements when the elements attr is large.
1168 void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
1169
1170 /// Print a dense string elements attribute.
1171 void printDenseStringElementsAttr(DenseStringElementsAttr attr);
1172
1173 /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1174 /// used instead of individual elements when the elements attr is large.
1175 void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1176 bool allowHex);
1177
1178 void printDialectAttribute(Attribute attr);
1179 void printDialectType(Type type);
1180
1181 /// This enum is used to represent the binding strength of the enclosing
1182 /// context that an AffineExprStorage is being printed in, so we can
1183 /// intelligently produce parens.
1184 enum class BindingStrength {
1185 Weak, // + and -
1186 Strong, // All other binary operators.
1187 };
1188 void printAffineExprInternal(
1189 AffineExpr expr, BindingStrength enclosingTightness,
1190 function_ref<void(unsigned, bool)> printValueName = nullptr);
1191
1192 /// The output stream for the printer.
1193 raw_ostream &os;
1194
1195 /// A set of flags to control the printer's behavior.
1196 OpPrintingFlags printerFlags;
1197
1198 /// An optional printer state for the module.
1199 AsmStateImpl *state;
1200
1201 /// A tracker for the number of new lines emitted during printing.
1202 NewLineCounter newLine;
1203 };
1204 } // end anonymous namespace
1205
printTrailingLocation(Location loc)1206 void ModulePrinter::printTrailingLocation(Location loc) {
1207 // Check to see if we are printing debug information.
1208 if (!printerFlags.shouldPrintDebugInfo())
1209 return;
1210
1211 os << " ";
1212 printLocation(loc, /*allowAlias=*/true);
1213 }
1214
printLocationInternal(LocationAttr loc,bool pretty)1215 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
1216 TypeSwitch<LocationAttr>(loc)
1217 .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1218 printLocationInternal(loc.getFallbackLocation(), pretty);
1219 })
1220 .Case<UnknownLoc>([&](UnknownLoc loc) {
1221 if (pretty)
1222 os << "[unknown]";
1223 else
1224 os << "unknown";
1225 })
1226 .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1227 StringRef mayQuote = pretty ? "" : "\"";
1228 os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
1229 << ':' << loc.getColumn();
1230 })
1231 .Case<NameLoc>([&](NameLoc loc) {
1232 os << '\"' << loc.getName() << '\"';
1233
1234 // Print the child if it isn't unknown.
1235 auto childLoc = loc.getChildLoc();
1236 if (!childLoc.isa<UnknownLoc>()) {
1237 os << '(';
1238 printLocationInternal(childLoc, pretty);
1239 os << ')';
1240 }
1241 })
1242 .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1243 Location caller = loc.getCaller();
1244 Location callee = loc.getCallee();
1245 if (!pretty)
1246 os << "callsite(";
1247 printLocationInternal(callee, pretty);
1248 if (pretty) {
1249 if (callee.isa<NameLoc>()) {
1250 if (caller.isa<FileLineColLoc>()) {
1251 os << " at ";
1252 } else {
1253 os << newLine << " at ";
1254 }
1255 } else {
1256 os << newLine << " at ";
1257 }
1258 } else {
1259 os << " at ";
1260 }
1261 printLocationInternal(caller, pretty);
1262 if (!pretty)
1263 os << ")";
1264 })
1265 .Case<FusedLoc>([&](FusedLoc loc) {
1266 if (!pretty)
1267 os << "fused";
1268 if (Attribute metadata = loc.getMetadata())
1269 os << '<' << metadata << '>';
1270 os << '[';
1271 interleave(
1272 loc.getLocations(),
1273 [&](Location loc) { printLocationInternal(loc, pretty); },
1274 [&]() { os << ", "; });
1275 os << ']';
1276 });
1277 }
1278
1279 /// Print a floating point value in a way that the parser will be able to
1280 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)1281 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1282 // We would like to output the FP constant value in exponential notation,
1283 // but we cannot do this if doing so will lose precision. Check here to
1284 // make sure that we only output it in exponential format if we can parse
1285 // the value back and get the same value.
1286 bool isInf = apValue.isInfinity();
1287 bool isNaN = apValue.isNaN();
1288 if (!isInf && !isNaN) {
1289 SmallString<128> strValue;
1290 apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1291 /*TruncateZero=*/false);
1292
1293 // Check to make sure that the stringized number is not some string like
1294 // "Inf" or NaN, that atof will accept, but the lexer will not. Check
1295 // that the string matches the "[-+]?[0-9]" regex.
1296 assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1297 ((strValue[0] == '-' || strValue[0] == '+') &&
1298 (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1299 "[-+]?[0-9] regex does not match!");
1300
1301 // Parse back the stringized version and check that the value is equal
1302 // (i.e., there is no precision loss).
1303 if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1304 os << strValue;
1305 return;
1306 }
1307
1308 // If it is not, use the default format of APFloat instead of the
1309 // exponential notation.
1310 strValue.clear();
1311 apValue.toString(strValue);
1312
1313 // Make sure that we can parse the default form as a float.
1314 if (StringRef(strValue).contains('.')) {
1315 os << strValue;
1316 return;
1317 }
1318 }
1319
1320 // Print special values in hexadecimal format. The sign bit should be included
1321 // in the literal.
1322 SmallVector<char, 16> str;
1323 APInt apInt = apValue.bitcastToAPInt();
1324 apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1325 /*formatAsCLiteral=*/true);
1326 os << str;
1327 }
1328
printLocation(LocationAttr loc,bool allowAlias)1329 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
1330 if (printerFlags.shouldPrintDebugInfoPrettyForm())
1331 return printLocationInternal(loc, /*pretty=*/true);
1332
1333 os << "loc(";
1334 if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1335 printLocationInternal(loc);
1336 os << ')';
1337 }
1338
1339 /// Returns true if the given dialect symbol data is simple enough to print in
1340 /// the pretty form, i.e. without the enclosing "".
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1341 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1342 // The name must start with an identifier.
1343 if (symName.empty() || !isalpha(symName.front()))
1344 return false;
1345
1346 // Ignore all the characters that are valid in an identifier in the symbol
1347 // name.
1348 symName = symName.drop_while(
1349 [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1350 if (symName.empty())
1351 return true;
1352
1353 // If we got to an unexpected character, then it must be a <>. Check those
1354 // recursively.
1355 if (symName.front() != '<' || symName.back() != '>')
1356 return false;
1357
1358 SmallVector<char, 8> nestedPunctuation;
1359 do {
1360 // If we ran out of characters, then we had a punctuation mismatch.
1361 if (symName.empty())
1362 return false;
1363
1364 auto c = symName.front();
1365 symName = symName.drop_front();
1366
1367 switch (c) {
1368 // We never allow null characters. This is an EOF indicator for the lexer
1369 // which we could handle, but isn't important for any known dialect.
1370 case '\0':
1371 return false;
1372 case '<':
1373 case '[':
1374 case '(':
1375 case '{':
1376 nestedPunctuation.push_back(c);
1377 continue;
1378 case '-':
1379 // Treat `->` as a special token.
1380 if (!symName.empty() && symName.front() == '>') {
1381 symName = symName.drop_front();
1382 continue;
1383 }
1384 break;
1385 // Reject types with mismatched brackets.
1386 case '>':
1387 if (nestedPunctuation.pop_back_val() != '<')
1388 return false;
1389 break;
1390 case ']':
1391 if (nestedPunctuation.pop_back_val() != '[')
1392 return false;
1393 break;
1394 case ')':
1395 if (nestedPunctuation.pop_back_val() != '(')
1396 return false;
1397 break;
1398 case '}':
1399 if (nestedPunctuation.pop_back_val() != '{')
1400 return false;
1401 break;
1402 default:
1403 continue;
1404 }
1405
1406 // We're done when the punctuation is fully matched.
1407 } while (!nestedPunctuation.empty());
1408
1409 // If there were extra characters, then we failed.
1410 return symName.empty();
1411 }
1412
1413 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1414 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1415 StringRef dialectName, StringRef symString) {
1416 os << symPrefix << dialectName;
1417
1418 // If this symbol name is simple enough, print it directly in pretty form,
1419 // otherwise, we print it as an escaped string.
1420 if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1421 os << '.' << symString;
1422 return;
1423 }
1424
1425 // TODO: escape the symbol name, it could contain " characters.
1426 os << "<\"" << symString << "\">";
1427 }
1428
1429 /// Returns true if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1430 static bool isBareIdentifier(StringRef name) {
1431 assert(!name.empty() && "invalid name");
1432
1433 // By making this unsigned, the value passed in to isalnum will always be
1434 // in the range 0-255. This is important when building with MSVC because
1435 // its implementation will assert. This situation can arise when dealing
1436 // with UTF-8 multibyte characters.
1437 unsigned char firstChar = static_cast<unsigned char>(name[0]);
1438 if (!isalpha(firstChar) && firstChar != '_')
1439 return false;
1440 return llvm::all_of(name.drop_front(), [](unsigned char c) {
1441 return isalnum(c) || c == '_' || c == '$' || c == '.';
1442 });
1443 }
1444
1445 /// Print the given string as a symbol reference. A symbol reference is
1446 /// represented as a string prefixed with '@'. The reference is surrounded with
1447 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1448 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1449 assert(!symbolRef.empty() && "expected valid symbol reference");
1450
1451 // If the symbol can be represented as a bare identifier, write it directly.
1452 if (isBareIdentifier(symbolRef)) {
1453 os << '@' << symbolRef;
1454 return;
1455 }
1456
1457 // Otherwise, output the reference wrapped in quotes with proper escaping.
1458 os << "@\"";
1459 printEscapedString(symbolRef, os);
1460 os << '"';
1461 }
1462
1463 // Print out a valid ElementsAttr that is succinct and can represent any
1464 // potential shape/type, for use when eliding a large ElementsAttr.
1465 //
1466 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1467 // hopefully alert readers to the fact that this has been elided.
1468 //
1469 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1470 // accept the string "elided". The first string must be a registered dialect
1471 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1472 static void printElidedElementsAttr(raw_ostream &os) {
1473 os << R"(opaque<"", "0xDEADBEEF">)";
1474 }
1475
printAttribute(Attribute attr,AttrTypeElision typeElision)1476 void ModulePrinter::printAttribute(Attribute attr,
1477 AttrTypeElision typeElision) {
1478 if (!attr) {
1479 os << "<<NULL ATTRIBUTE>>";
1480 return;
1481 }
1482
1483 // Try to print an alias for this attribute.
1484 if (state && succeeded(state->getAliasState().getAlias(attr, os)))
1485 return;
1486
1487 auto attrType = attr.getType();
1488 if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1489 printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1490 opaqueAttr.getAttrData());
1491 } else if (attr.isa<UnitAttr>()) {
1492 os << "unit";
1493 return;
1494 } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1495 os << '{';
1496 interleaveComma(dictAttr.getValue(),
1497 [&](NamedAttribute attr) { printNamedAttribute(attr); });
1498 os << '}';
1499
1500 } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1501 if (attrType.isSignlessInteger(1)) {
1502 os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1503
1504 // Boolean integer attributes always elides the type.
1505 return;
1506 }
1507
1508 // Only print attributes as unsigned if they are explicitly unsigned or are
1509 // signless 1-bit values. Indexes, signed values, and multi-bit signless
1510 // values print as signed.
1511 bool isUnsigned =
1512 attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1513 intAttr.getValue().print(os, !isUnsigned);
1514
1515 // IntegerAttr elides the type if I64.
1516 if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1517 return;
1518
1519 } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1520 printFloatValue(floatAttr.getValue(), os);
1521
1522 // FloatAttr elides the type if F64.
1523 if (typeElision == AttrTypeElision::May && attrType.isF64())
1524 return;
1525
1526 } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1527 os << '"';
1528 printEscapedString(strAttr.getValue(), os);
1529 os << '"';
1530
1531 } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1532 os << '[';
1533 interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1534 printAttribute(attr, AttrTypeElision::May);
1535 });
1536 os << ']';
1537
1538 } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1539 os << "affine_map<";
1540 affineMapAttr.getValue().print(os);
1541 os << '>';
1542
1543 // AffineMap always elides the type.
1544 return;
1545
1546 } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1547 os << "affine_set<";
1548 integerSetAttr.getValue().print(os);
1549 os << '>';
1550
1551 // IntegerSet always elides the type.
1552 return;
1553
1554 } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1555 printType(typeAttr.getValue());
1556
1557 } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1558 printSymbolReference(refAttr.getRootReference(), os);
1559 for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1560 os << "::";
1561 printSymbolReference(nestedRef.getValue(), os);
1562 }
1563
1564 } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1565 if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1566 printElidedElementsAttr(os);
1567 } else {
1568 os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
1569 os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
1570 }
1571
1572 } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1573 if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1574 printElidedElementsAttr(os);
1575 } else {
1576 os << "dense<";
1577 printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1578 os << '>';
1579 }
1580
1581 } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1582 if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1583 printElidedElementsAttr(os);
1584 } else {
1585 os << "dense<";
1586 printDenseStringElementsAttr(strEltAttr);
1587 os << '>';
1588 }
1589
1590 } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1591 if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1592 printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1593 printElidedElementsAttr(os);
1594 } else {
1595 os << "sparse<";
1596 DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1597 if (indices.getNumElements() != 0) {
1598 printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1599 os << ", ";
1600 printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1601 }
1602 os << '>';
1603 }
1604
1605 } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1606 printLocation(locAttr);
1607
1608 } else {
1609 return printDialectAttribute(attr);
1610 }
1611
1612 // Don't print the type if we must elide it, or if it is a None type.
1613 if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1614 os << " : ";
1615 printType(attrType);
1616 }
1617 }
1618
1619 /// Print the integer element of a DenseElementsAttr.
printDenseIntElement(const APInt & value,raw_ostream & os,bool isSigned)1620 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1621 bool isSigned) {
1622 if (value.getBitWidth() == 1)
1623 os << (value.getBoolValue() ? "true" : "false");
1624 else
1625 value.print(os, isSigned);
1626 }
1627
/// Print the elements of a dense elements attribute of shaped type `type` as a
/// nested bracketed list, e.g. `[[a, b], [c, d]]` for a 2x2 shape.
/// `printEltFn(i)` is called to print the i-th element in row-major (flattened)
/// order; this function only emits the brackets and the ", " separators.
///
/// If `isSplat` is true, only element 0 is printed and no brackets are
/// emitted. If the shape has zero elements, nothing is printed.
static void
printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
                           function_ref<void(unsigned)> printEltFn) {
  // Special case for 0-d and splat tensors.
  // NOTE(review): only `isSplat` is checked here; the code below indexes
  // counter[rank - 1], which would be out of bounds for a non-splat rank-0
  // input. Presumably 0-d attributes always report isSplat() — confirm.
  if (isSplat)
    return printEltFn(0);

  // Special case for degenerate tensors.
  auto numElements = type.getNumElements();
  if (numElements == 0)
    return;

  // We use a mixed-radix counter to iterate through the shape. When we bump a
  // non-least-significant digit, we emit a close bracket. When we next emit an
  // element we re-open all closed brackets.

  // The mixed-radix counter, with radices in 'shape'.
  int64_t rank = type.getRank();
  SmallVector<unsigned, 4> counter(rank, 0);
  // The number of brackets that have been opened and not closed.
  unsigned openBrackets = 0;

  auto shape = type.getShape();
  // Advance the counter by one element. Every inner digit that rolls over
  // closes one bracket. The outermost digit (index 0) is never rolled over
  // here; its bracket is closed by the trailing loop after the last element.
  auto bumpCounter = [&] {
    // Bump the least significant digit.
    ++counter[rank - 1];
    // Iterate backwards bubbling back the increment.
    for (unsigned i = rank - 1; i > 0; --i)
      if (counter[i] >= shape[i]) {
        // Index 'i' is rolled over. Bump (i-1) and close a bracket.
        counter[i] = 0;
        ++counter[i - 1];
        --openBrackets;
        os << ']';
      }
  };

  for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
    if (idx != 0)
      os << ", ";
    // Re-open every bracket closed since the previous element. The
    // post-increment overshoots by one when the loop exits, hence the reset to
    // `rank` right after.
    while (openBrackets++ < rank)
      os << '[';
    openBrackets = rank;
    printEltFn(idx);
    bumpCounter();
  }
  // Close whatever brackets are still open (at least the outermost one).
  while (openBrackets-- > 0)
    os << ']';
}
1677
printDenseElementsAttr(DenseElementsAttr attr,bool allowHex)1678 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1679 bool allowHex) {
1680 if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1681 return printDenseStringElementsAttr(stringAttr);
1682
1683 printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1684 allowHex);
1685 }
1686
printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,bool allowHex)1687 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1688 bool allowHex) {
1689 auto type = attr.getType();
1690 auto elementType = type.getElementType();
1691
1692 // Check to see if we should format this attribute as a hex string.
1693 auto numElements = type.getNumElements();
1694 if (!attr.isSplat() && allowHex &&
1695 shouldPrintElementsAttrWithHex(numElements)) {
1696 ArrayRef<char> rawData = attr.getRawData();
1697 if (llvm::support::endian::system_endianness() ==
1698 llvm::support::endianness::big) {
1699 // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1700 // machines. It is converted here to print in LE format.
1701 SmallVector<char, 64> outDataVec(rawData.size());
1702 MutableArrayRef<char> convRawData(outDataVec);
1703 DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1704 rawData, convRawData, type);
1705 os << '"' << "0x"
1706 << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1707 << "\"";
1708 } else {
1709 os << '"' << "0x"
1710 << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1711 }
1712
1713 return;
1714 }
1715
1716 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1717 Type complexElementType = complexTy.getElementType();
1718 // Note: The if and else below had a common lambda function which invoked
1719 // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1720 // and hence was replaced.
1721 if (complexElementType.isa<IntegerType>()) {
1722 bool isSigned = !complexElementType.isUnsignedInteger();
1723 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1724 auto complexValue = *(attr.getComplexIntValues().begin() + index);
1725 os << "(";
1726 printDenseIntElement(complexValue.real(), os, isSigned);
1727 os << ",";
1728 printDenseIntElement(complexValue.imag(), os, isSigned);
1729 os << ")";
1730 });
1731 } else {
1732 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1733 auto complexValue = *(attr.getComplexFloatValues().begin() + index);
1734 os << "(";
1735 printFloatValue(complexValue.real(), os);
1736 os << ",";
1737 printFloatValue(complexValue.imag(), os);
1738 os << ")";
1739 });
1740 }
1741 } else if (elementType.isIntOrIndex()) {
1742 bool isSigned = !elementType.isUnsignedInteger();
1743 auto intValues = attr.getIntValues();
1744 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1745 printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1746 });
1747 } else {
1748 assert(elementType.isa<FloatType>() && "unexpected element type");
1749 auto floatValues = attr.getFloatValues();
1750 printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1751 printFloatValue(*(floatValues.begin() + index), os);
1752 });
1753 }
1754 }
1755
printDenseStringElementsAttr(DenseStringElementsAttr attr)1756 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1757 ArrayRef<StringRef> data = attr.getRawStringData();
1758 auto printFn = [&](unsigned index) {
1759 os << "\"";
1760 printEscapedString(data[index], os);
1761 os << "\"";
1762 };
1763 printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1764 }
1765
printType(Type type)1766 void ModulePrinter::printType(Type type) {
1767 if (!type) {
1768 os << "<<NULL TYPE>>";
1769 return;
1770 }
1771
1772 // Try to print an alias for this type.
1773 if (state && succeeded(state->getAliasState().getAlias(type, os)))
1774 return;
1775
1776 TypeSwitch<Type>(type)
1777 .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1778 printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1779 opaqueTy.getTypeData());
1780 })
1781 .Case<IndexType>([&](Type) { os << "index"; })
1782 .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1783 .Case<Float16Type>([&](Type) { os << "f16"; })
1784 .Case<Float32Type>([&](Type) { os << "f32"; })
1785 .Case<Float64Type>([&](Type) { os << "f64"; })
1786 .Case<IntegerType>([&](IntegerType integerTy) {
1787 if (integerTy.isSigned())
1788 os << 's';
1789 else if (integerTy.isUnsigned())
1790 os << 'u';
1791 os << 'i' << integerTy.getWidth();
1792 })
1793 .Case<FunctionType>([&](FunctionType funcTy) {
1794 os << '(';
1795 interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1796 os << ") -> ";
1797 ArrayRef<Type> results = funcTy.getResults();
1798 if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1799 os << results[0];
1800 } else {
1801 os << '(';
1802 interleaveComma(results, [&](Type ty) { printType(ty); });
1803 os << ')';
1804 }
1805 })
1806 .Case<VectorType>([&](VectorType vectorTy) {
1807 os << "vector<";
1808 for (int64_t dim : vectorTy.getShape())
1809 os << dim << 'x';
1810 os << vectorTy.getElementType() << '>';
1811 })
1812 .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1813 os << "tensor<";
1814 for (int64_t dim : tensorTy.getShape()) {
1815 if (ShapedType::isDynamic(dim))
1816 os << '?';
1817 else
1818 os << dim;
1819 os << 'x';
1820 }
1821 os << tensorTy.getElementType() << '>';
1822 })
1823 .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1824 os << "tensor<*x";
1825 printType(tensorTy.getElementType());
1826 os << '>';
1827 })
1828 .Case<MemRefType>([&](MemRefType memrefTy) {
1829 os << "memref<";
1830 for (int64_t dim : memrefTy.getShape()) {
1831 if (ShapedType::isDynamic(dim))
1832 os << '?';
1833 else
1834 os << dim;
1835 os << 'x';
1836 }
1837 printType(memrefTy.getElementType());
1838 for (auto map : memrefTy.getAffineMaps()) {
1839 os << ", ";
1840 printAttribute(AffineMapAttr::get(map));
1841 }
1842 // Only print the memory space if it is the non-default one.
1843 if (memrefTy.getMemorySpace())
1844 os << ", " << memrefTy.getMemorySpace();
1845 os << '>';
1846 })
1847 .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1848 os << "memref<*x";
1849 printType(memrefTy.getElementType());
1850 // Only print the memory space if it is the non-default one.
1851 if (memrefTy.getMemorySpace())
1852 os << ", " << memrefTy.getMemorySpace();
1853 os << '>';
1854 })
1855 .Case<ComplexType>([&](ComplexType complexTy) {
1856 os << "complex<";
1857 printType(complexTy.getElementType());
1858 os << '>';
1859 })
1860 .Case<TupleType>([&](TupleType tupleTy) {
1861 os << "tuple<";
1862 interleaveComma(tupleTy.getTypes(),
1863 [&](Type type) { printType(type); });
1864 os << '>';
1865 })
1866 .Case<NoneType>([&](Type) { os << "none"; })
1867 .Default([&](Type type) { return printDialectType(type); });
1868 }
1869
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)1870 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1871 ArrayRef<StringRef> elidedAttrs,
1872 bool withKeyword) {
1873 // If there are no attributes, then there is nothing to be done.
1874 if (attrs.empty())
1875 return;
1876
1877 // Filter out any attributes that shouldn't be included.
1878 SmallVector<NamedAttribute, 8> filteredAttrs(
1879 llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1880 return !llvm::is_contained(elidedAttrs, attr.first.strref());
1881 }));
1882
1883 // If there are no attributes left to print after filtering, then we're done.
1884 if (filteredAttrs.empty())
1885 return;
1886
1887 // Print the 'attributes' keyword if necessary.
1888 if (withKeyword)
1889 os << " attributes";
1890
1891 // Otherwise, print them all out in braces.
1892 os << " {";
1893 interleaveComma(filteredAttrs,
1894 [&](NamedAttribute attr) { printNamedAttribute(attr); });
1895 os << '}';
1896 }
1897
printNamedAttribute(NamedAttribute attr)1898 void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
1899 if (isBareIdentifier(attr.first)) {
1900 os << attr.first;
1901 } else {
1902 os << '"';
1903 printEscapedString(attr.first.strref(), os);
1904 os << '"';
1905 }
1906
1907 // Pretty printing elides the attribute value for unit attributes.
1908 if (attr.second.isa<UnitAttr>())
1909 return;
1910
1911 os << " = ";
1912 printAttribute(attr.second);
1913 }
1914
1915 //===----------------------------------------------------------------------===//
1916 // CustomDialectAsmPrinter
1917 //===----------------------------------------------------------------------===//
1918
1919 namespace {
1920 /// This class provides the main specialization of the DialectAsmPrinter that is
1921 /// used to provide support for print attributes and types. This hooks allows
1922 /// for dialects to hook into the main ModulePrinter.
1923 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1924 public:
CustomDialectAsmPrinter__anon2591390f3611::CustomDialectAsmPrinter1925 CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter__anon2591390f3611::CustomDialectAsmPrinter1926 ~CustomDialectAsmPrinter() override {}
1927
getStream__anon2591390f3611::CustomDialectAsmPrinter1928 raw_ostream &getStream() const override { return printer.getStream(); }
1929
1930 /// Print the given attribute to the stream.
printAttribute__anon2591390f3611::CustomDialectAsmPrinter1931 void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1932
1933 /// Print the given floating point value in a stablized form.
printFloat__anon2591390f3611::CustomDialectAsmPrinter1934 void printFloat(const APFloat &value) override {
1935 printFloatValue(value, getStream());
1936 }
1937
1938 /// Print the given type to the stream.
printType__anon2591390f3611::CustomDialectAsmPrinter1939 void printType(Type type) override { printer.printType(type); }
1940
1941 /// The main module printer.
1942 ModulePrinter &printer;
1943 };
1944 } // end anonymous namespace
1945
printDialectAttribute(Attribute attr)1946 void ModulePrinter::printDialectAttribute(Attribute attr) {
1947 auto &dialect = attr.getDialect();
1948
1949 // Ask the dialect to serialize the attribute to a string.
1950 std::string attrName;
1951 {
1952 llvm::raw_string_ostream attrNameStr(attrName);
1953 ModulePrinter subPrinter(attrNameStr, printerFlags, state);
1954 CustomDialectAsmPrinter printer(subPrinter);
1955 dialect.printAttribute(attr, printer);
1956 }
1957 printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
1958 }
1959
printDialectType(Type type)1960 void ModulePrinter::printDialectType(Type type) {
1961 auto &dialect = type.getDialect();
1962
1963 // Ask the dialect to serialize the type to a string.
1964 std::string typeName;
1965 {
1966 llvm::raw_string_ostream typeNameStr(typeName);
1967 ModulePrinter subPrinter(typeNameStr, printerFlags, state);
1968 CustomDialectAsmPrinter printer(subPrinter);
1969 dialect.printType(type, printer);
1970 }
1971 printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
1972 }
1973
1974 //===----------------------------------------------------------------------===//
1975 // Affine expressions and maps
1976 //===----------------------------------------------------------------------===//
1977
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)1978 void ModulePrinter::printAffineExpr(
1979 AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
1980 printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
1981 }
1982
/// Recursive worker for printAffineExpr.
///
/// `enclosingTightness` describes the context the expression is printed in.
/// When it is BindingStrength::Strong, a binary expression printed here is
/// wrapped in parentheses. Operands of the non-additive operators (*, floordiv,
/// ceildiv, mod) are printed in a Strong context, and operands of '+' in a Weak
/// context.
///
/// If `printValueName` is non-null, it is called to print dimension/symbol
/// identifiers, with `isSymbol` set accordingly. Otherwise they are printed as
/// `d<pos>` / `s<pos>`.
///
/// Some forms are pretty-printed:
///   x * -1          ->  -x
///   lhs + y * -1    ->  lhs - y       (y parenthesized if it is an addition)
///   lhs + y * -c    ->  lhs - y * c   (c > 1)
///   lhs + -c        ->  lhs - c       (c > 0)
///
/// NOTE(review): `-rrhs.getValue()` and `-rhsConst.getValue()` overflow (UB)
/// if the constant is INT64_MIN, assuming getValue() returns int64_t. Confirm
/// whether such constants can reach this printer.
void ModulePrinter::printAffineExprInternal(
    AffineExpr expr, BindingStrength enclosingTightness,
    function_ref<void(unsigned, bool)> printValueName) {
  const char *binopSpelling = nullptr;
  switch (expr.getKind()) {
  case AffineExprKind::SymbolId: {
    unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/true);
    else
      os << 's' << pos;
    return;
  }
  case AffineExprKind::DimId: {
    unsigned pos = expr.cast<AffineDimExpr>().getPosition();
    if (printValueName)
      printValueName(pos, /*isSymbol=*/false);
    else
      os << 'd' << pos;
    return;
  }
  case AffineExprKind::Constant:
    os << expr.cast<AffineConstantExpr>().getValue();
    return;
  case AffineExprKind::Add:
    binopSpelling = " + ";
    break;
  case AffineExprKind::Mul:
    binopSpelling = " * ";
    break;
  case AffineExprKind::FloorDiv:
    binopSpelling = " floordiv ";
    break;
  case AffineExprKind::CeilDiv:
    binopSpelling = " ceildiv ";
    break;
  case AffineExprKind::Mod:
    binopSpelling = " mod ";
    break;
  }

  // Every remaining kind is a binary operation.
  auto binOp = expr.cast<AffineBinaryOpExpr>();
  AffineExpr lhsExpr = binOp.getLHS();
  AffineExpr rhsExpr = binOp.getRHS();

  // Handle tightly binding binary operators.
  if (binOp.getKind() != AffineExprKind::Add) {
    if (enclosingTightness == BindingStrength::Strong)
      os << '(';

    // Pretty print multiplication with -1.
    auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
    if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
        rhsConst.getValue() == -1) {
      os << "-";
      printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }

    printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);

    os << binopSpelling;
    printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);

    if (enclosingTightness == BindingStrength::Strong)
      os << ')';
    return;
  }

  // Print out special "pretty" forms for add.
  if (enclosingTightness == BindingStrength::Strong)
    os << '(';

  // Pretty print addition to a product that has a negative operand as a
  // subtraction.
  if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
    if (rhs.getKind() == AffineExprKind::Mul) {
      AffineExpr rrhsExpr = rhs.getRHS();
      if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
        // lhs + y * -1  ->  lhs - y
        if (rrhs.getValue() == -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          // An additive `y` needs parentheses after the minus sign.
          if (rhs.getLHS().getKind() == AffineExprKind::Add) {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                    printValueName);
          } else {
            printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
                                    printValueName);
          }

          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }

        // lhs + y * -c  ->  lhs - y * c  (c > 1)
        if (rrhs.getValue() < -1) {
          printAffineExprInternal(lhsExpr, BindingStrength::Weak,
                                  printValueName);
          os << " - ";
          printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
                                  printValueName);
          os << " * " << -rrhs.getValue();
          if (enclosingTightness == BindingStrength::Strong)
            os << ')';
          return;
        }
      }
    }
  }

  // Pretty print addition to a negative number as a subtraction.
  if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
    if (rhsConst.getValue() < 0) {
      printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
      os << " - " << -rhsConst.getValue();
      if (enclosingTightness == BindingStrength::Strong)
        os << ')';
      return;
    }
  }

  // Plain addition.
  printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);

  os << " + ";
  printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);

  if (enclosingTightness == BindingStrength::Strong)
    os << ')';
}
2115
printAffineConstraint(AffineExpr expr,bool isEq)2116 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
2117 printAffineExprInternal(expr, BindingStrength::Weak);
2118 isEq ? os << " == 0" : os << " >= 0";
2119 }
2120
printAffineMap(AffineMap map)2121 void ModulePrinter::printAffineMap(AffineMap map) {
2122 // Dimension identifiers.
2123 os << '(';
2124 for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2125 os << 'd' << i << ", ";
2126 if (map.getNumDims() >= 1)
2127 os << 'd' << map.getNumDims() - 1;
2128 os << ')';
2129
2130 // Symbolic identifiers.
2131 if (map.getNumSymbols() != 0) {
2132 os << '[';
2133 for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2134 os << 's' << i << ", ";
2135 if (map.getNumSymbols() >= 1)
2136 os << 's' << map.getNumSymbols() - 1;
2137 os << ']';
2138 }
2139
2140 // Result affine expressions.
2141 os << " -> (";
2142 interleaveComma(map.getResults(),
2143 [&](AffineExpr expr) { printAffineExpr(expr); });
2144 os << ')';
2145 }
2146
printIntegerSet(IntegerSet set)2147 void ModulePrinter::printIntegerSet(IntegerSet set) {
2148 // Dimension identifiers.
2149 os << '(';
2150 for (unsigned i = 1; i < set.getNumDims(); ++i)
2151 os << 'd' << i - 1 << ", ";
2152 if (set.getNumDims() >= 1)
2153 os << 'd' << set.getNumDims() - 1;
2154 os << ')';
2155
2156 // Symbolic identifiers.
2157 if (set.getNumSymbols() != 0) {
2158 os << '[';
2159 for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2160 os << 's' << i << ", ";
2161 if (set.getNumSymbols() >= 1)
2162 os << 's' << set.getNumSymbols() - 1;
2163 os << ']';
2164 }
2165
2166 // Print constraints.
2167 os << " : (";
2168 int numConstraints = set.getNumConstraints();
2169 for (int i = 1; i < numConstraints; ++i) {
2170 printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2171 os << ", ";
2172 }
2173 if (numConstraints >= 1)
2174 printAffineConstraint(set.getConstraint(numConstraints - 1),
2175 set.isEq(numConstraints - 1));
2176 os << ')';
2177 }
2178
2179 //===----------------------------------------------------------------------===//
2180 // OperationPrinter
2181 //===----------------------------------------------------------------------===//
2182
2183 namespace {
2184 /// This class contains the logic for printing operations, regions, and blocks.
2185 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
2186 public:
OperationPrinter(raw_ostream & os,OpPrintingFlags flags,AsmStateImpl & state)2187 explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2188 AsmStateImpl &state)
2189 : ModulePrinter(os, flags, &state) {}
2190
2191 /// Print the given top-level operation.
2192 void printTopLevelOperation(Operation *op);
2193
2194 /// Print the given operation with its indent and location.
2195 void print(Operation *op);
2196 /// Print the bare location, not including indentation/location/etc.
2197 void printOperation(Operation *op);
2198 /// Print the given operation in the generic form.
2199 void printGenericOp(Operation *op) override;
2200
2201 /// Print the name of the given block.
2202 void printBlockName(Block *block);
2203
2204 /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2205 /// block are not printed. If 'printBlockTerminator' is false, the terminator
2206 /// operation of the block is not printed.
2207 void print(Block *block, bool printBlockArgs = true,
2208 bool printBlockTerminator = true);
2209
2210 /// Print the ID of the given value, optionally with its result number.
2211 void printValueID(Value value, bool printResultNo = true,
2212 raw_ostream *streamOverride = nullptr) const;
2213
2214 //===--------------------------------------------------------------------===//
2215 // OpAsmPrinter methods
2216 //===--------------------------------------------------------------------===//
2217
2218 /// Return the current stream of the printer.
getStream() const2219 raw_ostream &getStream() const override { return os; }
2220
2221 /// Print the given type.
printType(Type type)2222 void printType(Type type) override { ModulePrinter::printType(type); }
2223
2224 /// Print the given attribute.
printAttribute(Attribute attr)2225 void printAttribute(Attribute attr) override {
2226 ModulePrinter::printAttribute(attr);
2227 }
2228
2229 /// Print the given attribute without its type. The corresponding parser must
2230 /// provide a valid type for the attribute.
printAttributeWithoutType(Attribute attr)2231 void printAttributeWithoutType(Attribute attr) override {
2232 ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
2233 }
2234
2235 /// Print the ID for the given value.
printOperand(Value value)2236 void printOperand(Value value) override { printValueID(value); }
printOperand(Value value,raw_ostream & os)2237 void printOperand(Value value, raw_ostream &os) override {
2238 printValueID(value, /*printResultNo=*/true, &os);
2239 }
2240
2241 /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2242 void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2243 ArrayRef<StringRef> elidedAttrs = {}) override {
2244 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
2245 }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2246 void printOptionalAttrDictWithKeyword(
2247 ArrayRef<NamedAttribute> attrs,
2248 ArrayRef<StringRef> elidedAttrs = {}) override {
2249 ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
2250 /*withKeyword=*/true);
2251 }
2252
2253 /// Print the given successor.
2254 void printSuccessor(Block *successor) override;
2255
2256 /// Print an operation successor with the operands used for the block
2257 /// arguments.
2258 void printSuccessorAndUseList(Block *successor,
2259 ValueRange succOperands) override;
2260
2261 /// Print the given region.
2262 void printRegion(Region ®ion, bool printEntryBlockArgs,
2263 bool printBlockTerminators) override;
2264
2265 /// Renumber the arguments for the specified region to the same names as the
2266 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2267 /// operations. If any entry in namesToUse is null, the corresponding
2268 /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)2269 void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override {
2270 state->getSSANameState().shadowRegionArgs(region, namesToUse);
2271 }
2272
2273 /// Print the given affine map with the symbol and dimension operands printed
2274 /// inline with the map.
2275 void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2276 ValueRange operands) override;
2277
2278 /// Print the given string as a symbol reference.
printSymbolName(StringRef symbolRef)2279 void printSymbolName(StringRef symbolRef) override {
2280 ::printSymbolReference(symbolRef, os);
2281 }
2282
2283 private:
2284 /// The number of spaces used for indenting nested operations.
2285 const static unsigned indentWidth = 2;
2286
2287 // This is the current indentation level for nested structures.
2288 unsigned currentIndent = 0;
2289 };
2290 } // end anonymous namespace
2291
printTopLevelOperation(Operation * op)2292 void OperationPrinter::printTopLevelOperation(Operation *op) {
2293 // Output the aliases at the top level that can't be deferred.
2294 state->getAliasState().printNonDeferredAliases(os, newLine);
2295
2296 // Print the module.
2297 print(op);
2298 os << newLine;
2299
2300 // Output the aliases at the top level that can be deferred.
2301 state->getAliasState().printDeferredAliases(os, newLine);
2302 }
2303
print(Operation * op)2304 void OperationPrinter::print(Operation *op) {
2305 // Track the location of this operation.
2306 state->registerOperationLocation(op, newLine.curLine, currentIndent);
2307
2308 os.indent(currentIndent);
2309 printOperation(op);
2310 printTrailingLocation(op->getLoc());
2311 }
2312
printOperation(Operation * op)2313 void OperationPrinter::printOperation(Operation *op) {
2314 if (size_t numResults = op->getNumResults()) {
2315 auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2316 printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2317 if (resultCount > 1)
2318 os << ':' << resultCount;
2319 };
2320
2321 // Check to see if this operation has multiple result groups.
2322 ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2323 if (!resultGroups.empty()) {
2324 // Interleave the groups excluding the last one, this one will be handled
2325 // separately.
2326 interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2327 printResultGroup(resultGroups[i],
2328 resultGroups[i + 1] - resultGroups[i]);
2329 });
2330 os << ", ";
2331 printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2332
2333 } else {
2334 printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2335 }
2336
2337 os << " = ";
2338 }
2339
2340 // If requested, always print the generic form.
2341 if (!printerFlags.shouldPrintGenericOpForm()) {
2342 // Check to see if this is a known operation. If so, use the registered
2343 // custom printer hook.
2344 if (auto *opInfo = op->getAbstractOperation()) {
2345 opInfo->printAssembly(op, *this);
2346 return;
2347 }
2348 }
2349
2350 // Otherwise print with the generic assembly form.
2351 printGenericOp(op);
2352 }
2353
printGenericOp(Operation * op)2354 void OperationPrinter::printGenericOp(Operation *op) {
2355 os << '"';
2356 printEscapedString(op->getName().getStringRef(), os);
2357 os << "\"(";
2358 interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2359 os << ')';
2360
2361 // For terminators, print the list of successors and their operands.
2362 if (op->getNumSuccessors() != 0) {
2363 os << '[';
2364 interleaveComma(op->getSuccessors(),
2365 [&](Block *successor) { printBlockName(successor); });
2366 os << ']';
2367 }
2368
2369 // Print regions.
2370 if (op->getNumRegions() != 0) {
2371 os << " (";
2372 interleaveComma(op->getRegions(), [&](Region ®ion) {
2373 printRegion(region, /*printEntryBlockArgs=*/true,
2374 /*printBlockTerminators=*/true);
2375 });
2376 os << ')';
2377 }
2378
2379 auto attrs = op->getAttrs();
2380 printOptionalAttrDict(attrs);
2381
2382 // Print the type signature of the operation.
2383 os << " : ";
2384 printFunctionalType(op);
2385 }
2386
printBlockName(Block * block)2387 void OperationPrinter::printBlockName(Block *block) {
2388 auto id = state->getSSANameState().getBlockID(block);
2389 if (id != SSANameState::NameSentinel)
2390 os << "^bb" << id;
2391 else
2392 os << "^INVALIDBLOCK";
2393 }
2394
print(Block * block,bool printBlockArgs,bool printBlockTerminator)2395 void OperationPrinter::print(Block *block, bool printBlockArgs,
2396 bool printBlockTerminator) {
2397 // Print the block label and argument list if requested.
2398 if (printBlockArgs) {
2399 os.indent(currentIndent);
2400 printBlockName(block);
2401
2402 // Print the argument list if non-empty.
2403 if (!block->args_empty()) {
2404 os << '(';
2405 interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2406 printValueID(arg);
2407 os << ": ";
2408 printType(arg.getType());
2409 });
2410 os << ')';
2411 }
2412 os << ':';
2413
2414 // Print out some context information about the predecessors of this block.
2415 if (!block->getParent()) {
2416 os << " // block is not in a region!";
2417 } else if (block->hasNoPredecessors()) {
2418 os << " // no predecessors";
2419 } else if (auto *pred = block->getSinglePredecessor()) {
2420 os << " // pred: ";
2421 printBlockName(pred);
2422 } else {
2423 // We want to print the predecessors in increasing numeric order, not in
2424 // whatever order the use-list is in, so gather and sort them.
2425 SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2426 for (auto *pred : block->getPredecessors())
2427 predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2428 llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2429
2430 os << " // " << predIDs.size() << " preds: ";
2431
2432 interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2433 printBlockName(pred.second);
2434 });
2435 }
2436 os << newLine;
2437 }
2438
2439 currentIndent += indentWidth;
2440 auto range = llvm::make_range(
2441 block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
2442 for (auto &op : range) {
2443 print(&op);
2444 os << newLine;
2445 }
2446 currentIndent -= indentWidth;
2447 }
2448
printValueID(Value value,bool printResultNo,raw_ostream * streamOverride) const2449 void OperationPrinter::printValueID(Value value, bool printResultNo,
2450 raw_ostream *streamOverride) const {
2451 state->getSSANameState().printValueID(value, printResultNo,
2452 streamOverride ? *streamOverride : os);
2453 }
2454
printSuccessor(Block * successor)2455 void OperationPrinter::printSuccessor(Block *successor) {
2456 printBlockName(successor);
2457 }
2458
printSuccessorAndUseList(Block * successor,ValueRange succOperands)2459 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2460 ValueRange succOperands) {
2461 printBlockName(successor);
2462 if (succOperands.empty())
2463 return;
2464
2465 os << '(';
2466 interleaveComma(succOperands,
2467 [this](Value operand) { printValueID(operand); });
2468 os << " : ";
2469 interleaveComma(succOperands,
2470 [this](Value operand) { printType(operand.getType()); });
2471 os << ')';
2472 }
2473
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)2474 void OperationPrinter::printRegion(Region ®ion, bool printEntryBlockArgs,
2475 bool printBlockTerminators) {
2476 os << " {" << newLine;
2477 if (!region.empty()) {
2478 auto *entryBlock = ®ion.front();
2479 print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
2480 printBlockTerminators);
2481 for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2482 print(&b);
2483 }
2484 os.indent(currentIndent) << "}";
2485 }
2486
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)2487 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2488 ValueRange operands) {
2489 AffineMap map = mapAttr.getValue();
2490 unsigned numDims = map.getNumDims();
2491 auto printValueName = [&](unsigned pos, bool isSymbol) {
2492 unsigned index = isSymbol ? numDims + pos : pos;
2493 assert(index < operands.size());
2494 if (isSymbol)
2495 os << "symbol(";
2496 printValueID(operands[index]);
2497 if (isSymbol)
2498 os << ')';
2499 };
2500
2501 interleaveComma(map.getResults(), [&](AffineExpr expr) {
2502 printAffineExpr(expr, printValueName);
2503 });
2504 }
2505
2506 //===----------------------------------------------------------------------===//
2507 // print and dump methods
2508 //===----------------------------------------------------------------------===//
2509
print(raw_ostream & os) const2510 void Attribute::print(raw_ostream &os) const {
2511 ModulePrinter(os).printAttribute(*this);
2512 }
2513
dump() const2514 void Attribute::dump() const {
2515 print(llvm::errs());
2516 llvm::errs() << "\n";
2517 }
2518
print(raw_ostream & os)2519 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2520
dump()2521 void Type::dump() { print(llvm::errs()); }
2522
dump() const2523 void AffineMap::dump() const {
2524 print(llvm::errs());
2525 llvm::errs() << "\n";
2526 }
2527
dump() const2528 void IntegerSet::dump() const {
2529 print(llvm::errs());
2530 llvm::errs() << "\n";
2531 }
2532
print(raw_ostream & os) const2533 void AffineExpr::print(raw_ostream &os) const {
2534 if (!expr) {
2535 os << "<<NULL AFFINE EXPR>>";
2536 return;
2537 }
2538 ModulePrinter(os).printAffineExpr(*this);
2539 }
2540
dump() const2541 void AffineExpr::dump() const {
2542 print(llvm::errs());
2543 llvm::errs() << "\n";
2544 }
2545
print(raw_ostream & os) const2546 void AffineMap::print(raw_ostream &os) const {
2547 if (!map) {
2548 os << "<<NULL AFFINE MAP>>";
2549 return;
2550 }
2551 ModulePrinter(os).printAffineMap(*this);
2552 }
2553
print(raw_ostream & os) const2554 void IntegerSet::print(raw_ostream &os) const {
2555 ModulePrinter(os).printIntegerSet(*this);
2556 }
2557
print(raw_ostream & os)2558 void Value::print(raw_ostream &os) {
2559 if (auto *op = getDefiningOp())
2560 return op->print(os);
2561 // TODO: Improve this.
2562 BlockArgument arg = this->cast<BlockArgument>();
2563 os << "<block argument> of type '" << arg.getType()
2564 << "' at index: " << arg.getArgNumber() << '\n';
2565 }
print(raw_ostream & os,AsmState & state)2566 void Value::print(raw_ostream &os, AsmState &state) {
2567 if (auto *op = getDefiningOp())
2568 return op->print(os, state);
2569
2570 // TODO: Improve this.
2571 BlockArgument arg = this->cast<BlockArgument>();
2572 os << "<block argument> of type '" << arg.getType()
2573 << "' at index: " << arg.getArgNumber() << '\n';
2574 }
2575
dump()2576 void Value::dump() {
2577 print(llvm::errs());
2578 llvm::errs() << "\n";
2579 }
2580
printAsOperand(raw_ostream & os,AsmState & state)2581 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2582 // TODO: This doesn't necessarily capture all potential cases.
2583 // Currently, region arguments can be shadowed when printing the main
2584 // operation. If the IR hasn't been printed, this will produce the old SSA
2585 // name and not the shadowed name.
2586 state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2587 os);
2588 }
2589
print(raw_ostream & os,OpPrintingFlags flags)2590 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2591 // If this is a top level operation, we also print aliases.
2592 if (!getParent() && !flags.shouldUseLocalScope()) {
2593 AsmState state(this);
2594 state.getImpl().initializeAliases(this, flags);
2595 print(os, state, flags);
2596 return;
2597 }
2598
2599 // Find the operation to number from based upon the provided flags.
2600 Operation *printedOp = this;
2601 bool shouldUseLocalScope = flags.shouldUseLocalScope();
2602 do {
2603 // If we are printing local scope, stop at the first operation that is
2604 // isolated from above.
2605 if (shouldUseLocalScope && printedOp->isKnownIsolatedFromAbove())
2606 break;
2607
2608 // Otherwise, traverse up to the next parent.
2609 Operation *parentOp = printedOp->getParentOp();
2610 if (!parentOp)
2611 break;
2612 printedOp = parentOp;
2613 } while (true);
2614
2615 AsmState state(printedOp);
2616 print(os, state, flags);
2617 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2618 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2619 OperationPrinter printer(os, flags, state.getImpl());
2620 if (!getParent() && !flags.shouldUseLocalScope())
2621 printer.printTopLevelOperation(this);
2622 else
2623 printer.print(this);
2624 }
2625
dump()2626 void Operation::dump() {
2627 print(llvm::errs(), OpPrintingFlags().useLocalScope());
2628 llvm::errs() << "\n";
2629 }
2630
print(raw_ostream & os)2631 void Block::print(raw_ostream &os) {
2632 Operation *parentOp = getParentOp();
2633 if (!parentOp) {
2634 os << "<<UNLINKED BLOCK>>\n";
2635 return;
2636 }
2637 // Get the top-level op.
2638 while (auto *nextOp = parentOp->getParentOp())
2639 parentOp = nextOp;
2640
2641 AsmState state(parentOp);
2642 print(os, state);
2643 }
print(raw_ostream & os,AsmState & state)2644 void Block::print(raw_ostream &os, AsmState &state) {
2645 OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2646 }
2647
dump()2648 void Block::dump() { print(llvm::errs()); }
2649
2650 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)2651 void Block::printAsOperand(raw_ostream &os, bool printType) {
2652 Operation *parentOp = getParentOp();
2653 if (!parentOp) {
2654 os << "<<UNLINKED BLOCK>>\n";
2655 return;
2656 }
2657 AsmState state(parentOp);
2658 printAsOperand(os, state);
2659 }
printAsOperand(raw_ostream & os,AsmState & state)2660 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2661 OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2662 printer.printBlockName(this);
2663 }
2664