• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- AsmPrinter.cpp - MLIR Assembly Printer Implementation --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the MLIR AsmPrinter class, which is used to implement
10 // the various print() methods on the core IR objects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/IR/AffineExpr.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/AsmState.h"
17 #include "mlir/IR/Attributes.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/OpImplementation.h"
24 #include "mlir/IR/Operation.h"
25 #include "llvm/ADT/APFloat.h"
26 #include "llvm/ADT/DenseMap.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/ScopedHashTable.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/ADT/SmallString.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/StringSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/Regex.h"
37 #include "llvm/Support/SaveAndRestore.h"
38 using namespace mlir;
39 using namespace mlir::detail;
40 
print(raw_ostream & os) const41 void Identifier::print(raw_ostream &os) const { os << str(); }
42 
dump() const43 void Identifier::dump() const { print(llvm::errs()); }
44 
print(raw_ostream & os) const45 void OperationName::print(raw_ostream &os) const { os << getStringRef(); }
46 
dump() const47 void OperationName::dump() const { print(llvm::errs()); }
48 
~DialectAsmPrinter()49 DialectAsmPrinter::~DialectAsmPrinter() {}
50 
~OpAsmPrinter()51 OpAsmPrinter::~OpAsmPrinter() {}
52 
53 //===--------------------------------------------------------------------===//
54 // Operation OpAsm interface.
55 //===--------------------------------------------------------------------===//
56 
57 /// The OpAsmOpInterface, see OpAsmInterface.td for more details.
58 #include "mlir/IR/OpAsmInterface.cpp.inc"
59 
60 //===----------------------------------------------------------------------===//
61 // OpPrintingFlags
62 //===----------------------------------------------------------------------===//
63 
64 namespace {
65 /// This struct contains command line options that can be used to initialize
66 /// various bits of the AsmPrinter. This uses a struct wrapper to avoid the need
67 /// for global command line options.
68 struct AsmPrinterOptions {
69   llvm::cl::opt<int64_t> printElementsAttrWithHexIfLarger{
70       "mlir-print-elementsattrs-with-hex-if-larger",
71       llvm::cl::desc(
72           "Print DenseElementsAttrs with a hex string that have "
73           "more elements than the given upper limit (use -1 to disable)")};
74 
75   llvm::cl::opt<unsigned> elideElementsAttrIfLarger{
76       "mlir-elide-elementsattrs-if-larger",
77       llvm::cl::desc("Elide ElementsAttrs with \"...\" that have "
78                      "more elements than the given upper limit")};
79 
80   llvm::cl::opt<bool> printDebugInfoOpt{
81       "mlir-print-debuginfo", llvm::cl::init(false),
82       llvm::cl::desc("Print debug info in MLIR output")};
83 
84   llvm::cl::opt<bool> printPrettyDebugInfoOpt{
85       "mlir-pretty-debuginfo", llvm::cl::init(false),
86       llvm::cl::desc("Print pretty debug info in MLIR output")};
87 
88   // Use the generic op output form in the operation printer even if the custom
89   // form is defined.
90   llvm::cl::opt<bool> printGenericOpFormOpt{
91       "mlir-print-op-generic", llvm::cl::init(false),
92       llvm::cl::desc("Print the generic op form"), llvm::cl::Hidden};
93 
94   llvm::cl::opt<bool> printLocalScopeOpt{
95       "mlir-print-local-scope", llvm::cl::init(false),
96       llvm::cl::desc("Print assuming in local scope by default"),
97       llvm::cl::Hidden};
98 };
99 } // end anonymous namespace
100 
101 static llvm::ManagedStatic<AsmPrinterOptions> clOptions;
102 
103 /// Register a set of useful command-line options that can be used to configure
104 /// various flags within the AsmPrinter.
registerAsmPrinterCLOptions()105 void mlir::registerAsmPrinterCLOptions() {
106   // Make sure that the options struct has been initialized.
107   *clOptions;
108 }
109 
110 /// Initialize the printing flags with default supplied by the cl::opts above.
OpPrintingFlags()111 OpPrintingFlags::OpPrintingFlags()
112     : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false),
113       printGenericOpFormFlag(false), printLocalScope(false) {
114   // Initialize based upon command line options, if they are available.
115   if (!clOptions.isConstructed())
116     return;
117   if (clOptions->elideElementsAttrIfLarger.getNumOccurrences())
118     elementsAttrElementLimit = clOptions->elideElementsAttrIfLarger;
119   printDebugInfoFlag = clOptions->printDebugInfoOpt;
120   printDebugInfoPrettyFormFlag = clOptions->printPrettyDebugInfoOpt;
121   printGenericOpFormFlag = clOptions->printGenericOpFormOpt;
122   printLocalScope = clOptions->printLocalScopeOpt;
123 }
124 
125 /// Enable the elision of large elements attributes, by printing a '...'
126 /// instead of the element data, when the number of elements is greater than
127 /// `largeElementLimit`. Note: The IR generated with this option is not
128 /// parsable.
129 OpPrintingFlags &
elideLargeElementsAttrs(int64_t largeElementLimit)130 OpPrintingFlags::elideLargeElementsAttrs(int64_t largeElementLimit) {
131   elementsAttrElementLimit = largeElementLimit;
132   return *this;
133 }
134 
135 /// Enable printing of debug information. If 'prettyForm' is set to true,
136 /// debug information is printed in a more readable 'pretty' form.
enableDebugInfo(bool prettyForm)137 OpPrintingFlags &OpPrintingFlags::enableDebugInfo(bool prettyForm) {
138   printDebugInfoFlag = true;
139   printDebugInfoPrettyFormFlag = prettyForm;
140   return *this;
141 }
142 
143 /// Always print operations in the generic form.
printGenericOpForm()144 OpPrintingFlags &OpPrintingFlags::printGenericOpForm() {
145   printGenericOpFormFlag = true;
146   return *this;
147 }
148 
149 /// Use local scope when printing the operation. This allows for using the
150 /// printer in a more localized and thread-safe setting, but may not necessarily
151 /// be identical of what the IR will look like when dumping the full module.
useLocalScope()152 OpPrintingFlags &OpPrintingFlags::useLocalScope() {
153   printLocalScope = true;
154   return *this;
155 }
156 
157 /// Return if the given ElementsAttr should be elided.
shouldElideElementsAttr(ElementsAttr attr) const158 bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const {
159   return elementsAttrElementLimit.hasValue() &&
160          *elementsAttrElementLimit < int64_t(attr.getNumElements()) &&
161          !attr.isa<SplatElementsAttr>();
162 }
163 
164 /// Return the size limit for printing large ElementsAttr.
getLargeElementsAttrLimit() const165 Optional<int64_t> OpPrintingFlags::getLargeElementsAttrLimit() const {
166   return elementsAttrElementLimit;
167 }
168 
169 /// Return if debug information should be printed.
shouldPrintDebugInfo() const170 bool OpPrintingFlags::shouldPrintDebugInfo() const {
171   return printDebugInfoFlag;
172 }
173 
174 /// Return if debug information should be printed in the pretty form.
shouldPrintDebugInfoPrettyForm() const175 bool OpPrintingFlags::shouldPrintDebugInfoPrettyForm() const {
176   return printDebugInfoPrettyFormFlag;
177 }
178 
179 /// Return if operations should be printed in the generic form.
shouldPrintGenericOpForm() const180 bool OpPrintingFlags::shouldPrintGenericOpForm() const {
181   return printGenericOpFormFlag;
182 }
183 
184 /// Return if the printer should use local scope when dumping the IR.
shouldUseLocalScope() const185 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
186 
187 /// Returns true if an ElementsAttr with the given number of elements should be
188 /// printed with hex.
shouldPrintElementsAttrWithHex(int64_t numElements)189 static bool shouldPrintElementsAttrWithHex(int64_t numElements) {
190   // Check to see if a command line option was provided for the limit.
191   if (clOptions.isConstructed()) {
192     if (clOptions->printElementsAttrWithHexIfLarger.getNumOccurrences()) {
193       // -1 is used to disable hex printing.
194       if (clOptions->printElementsAttrWithHexIfLarger == -1)
195         return false;
196       return numElements > clOptions->printElementsAttrWithHexIfLarger;
197     }
198   }
199 
200   // Otherwise, default to printing with hex if the number of elements is >100.
201   return numElements > 100;
202 }
203 
204 //===----------------------------------------------------------------------===//
205 // NewLineCounter
206 //===----------------------------------------------------------------------===//
207 
208 namespace {
/// A simple formatter that emits a newline when streamed into a raw_ostream,
/// and counts the newlines emitted that way. This should be used whenever the
/// printer emits a newline.
struct NewLineCounter {
  /// The current line number: starts at 1 and is incremented for every newline
  /// streamed through this counter.
  unsigned curLine{1};
};
215 } // end anonymous namespace
216 
operator <<(raw_ostream & os,NewLineCounter & newLine)217 static raw_ostream &operator<<(raw_ostream &os, NewLineCounter &newLine) {
218   ++newLine.curLine;
219   return os << '\n';
220 }
221 
222 //===----------------------------------------------------------------------===//
223 // AliasInitializer
224 //===----------------------------------------------------------------------===//
225 
226 namespace {
227 /// This class represents a specific instance of a symbol Alias.
228 class SymbolAlias {
229 public:
SymbolAlias(StringRef name,bool isDeferrable)230   SymbolAlias(StringRef name, bool isDeferrable)
231       : name(name), suffixIndex(0), hasSuffixIndex(false),
232         isDeferrable(isDeferrable) {}
SymbolAlias(StringRef name,uint32_t suffixIndex,bool isDeferrable)233   SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
234       : name(name), suffixIndex(suffixIndex), hasSuffixIndex(true),
235         isDeferrable(isDeferrable) {}
236 
237   /// Print this alias to the given stream.
print(raw_ostream & os) const238   void print(raw_ostream &os) const {
239     os << name;
240     if (hasSuffixIndex)
241       os << suffixIndex;
242   }
243 
244   /// Returns true if this alias supports deferred resolution when parsing.
canBeDeferred() const245   bool canBeDeferred() const { return isDeferrable; }
246 
247 private:
248   /// The main name of the alias.
249   StringRef name;
250   /// The optional suffix index of the alias, if multiple aliases had the same
251   /// name.
252   uint32_t suffixIndex : 30;
253   /// A flag indicating whether this alias has a suffix or not.
254   bool hasSuffixIndex : 1;
255   /// A flag indicating whether this alias may be deferred or not.
256   bool isDeferrable : 1;
257 };
258 
259 /// This class represents a utility that initializes the set of attribute and
260 /// type aliases, without the need to store the extra information within the
261 /// main AliasState class or pass it around via function arguments.
262 class AliasInitializer {
263 public:
AliasInitializer(DialectInterfaceCollection<OpAsmDialectInterface> & interfaces,llvm::BumpPtrAllocator & aliasAllocator)264   AliasInitializer(
265       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces,
266       llvm::BumpPtrAllocator &aliasAllocator)
267       : interfaces(interfaces), aliasAllocator(aliasAllocator),
268         aliasOS(aliasBuffer) {}
269 
270   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
271                   llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
272                   llvm::MapVector<Type, SymbolAlias> &typeToAlias);
273 
274   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
275   /// set to true if the originator of this attribute can resolve the alias
276   /// after parsing has completed (e.g. in the case of operation locations).
277   void visit(Attribute attr, bool canBeDeferred = false);
278 
279   /// Visit the given type to see if it has an alias.
280   void visit(Type type);
281 
282 private:
283   /// Try to generate an alias for the provided symbol. If an alias is
284   /// generated, the provided alias mapping and reverse mapping are updated.
285   /// Returns success if an alias was generated, failure otherwise.
286   template <typename T>
287   LogicalResult
288   generateAlias(T symbol,
289                 llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol);
290 
291   /// The set of asm interfaces within the context.
292   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
293 
294   /// Mapping between an alias and the set of symbols mapped to it.
295   llvm::MapVector<StringRef, std::vector<Attribute>> aliasToAttr;
296   llvm::MapVector<StringRef, std::vector<Type>> aliasToType;
297 
298   /// An allocator used for alias names.
299   llvm::BumpPtrAllocator &aliasAllocator;
300 
301   /// The set of visited attributes.
302   DenseSet<Attribute> visitedAttributes;
303 
304   /// The set of attributes that have aliases *and* can be deferred.
305   DenseSet<Attribute> deferrableAttributes;
306 
307   /// The set of visited types.
308   DenseSet<Type> visitedTypes;
309 
310   /// Storage and stream used when generating an alias.
311   SmallString<32> aliasBuffer;
312   llvm::raw_svector_ostream aliasOS;
313 };
314 
315 /// This class implements a dummy OpAsmPrinter that doesn't print any output,
316 /// and merely collects the attributes and types that *would* be printed in a
317 /// normal print invocation so that we can generate proper aliases. This allows
318 /// for us to generate aliases only for the attributes and types that would be
319 /// in the output, and trims down unnecessary output.
320 class DummyAliasOperationPrinter : private OpAsmPrinter {
321 public:
DummyAliasOperationPrinter(const OpPrintingFlags & flags,AliasInitializer & initializer)322   explicit DummyAliasOperationPrinter(const OpPrintingFlags &flags,
323                                       AliasInitializer &initializer)
324       : printerFlags(flags), initializer(initializer) {}
325 
326   /// Print the given operation.
print(Operation * op)327   void print(Operation *op) {
328     // Visit the operation location.
329     if (printerFlags.shouldPrintDebugInfo())
330       initializer.visit(op->getLoc(), /*canBeDeferred=*/true);
331 
332     // If requested, always print the generic form.
333     if (!printerFlags.shouldPrintGenericOpForm()) {
334       // Check to see if this is a known operation.  If so, use the registered
335       // custom printer hook.
336       if (auto *opInfo = op->getAbstractOperation()) {
337         opInfo->printAssembly(op, *this);
338         return;
339       }
340     }
341 
342     // Otherwise print with the generic assembly form.
343     printGenericOp(op);
344   }
345 
346 private:
347   /// Print the given operation in the generic form.
printGenericOp(Operation * op)348   void printGenericOp(Operation *op) override {
349     // Consider nested opertions for aliases.
350     if (op->getNumRegions() != 0) {
351       for (Region &region : op->getRegions())
352         printRegion(region, /*printEntryBlockArgs=*/true,
353                     /*printBlockTerminators=*/true);
354     }
355 
356     // Visit all the types used in the operation.
357     for (Type type : op->getOperandTypes())
358       printType(type);
359     for (Type type : op->getResultTypes())
360       printType(type);
361 
362     // Consider the attributes of the operation for aliases.
363     for (const NamedAttribute &attr : op->getAttrs())
364       printAttribute(attr.second);
365   }
366 
367   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
368   /// block are not printed. If 'printBlockTerminator' is false, the terminator
369   /// operation of the block is not printed.
print(Block * block,bool printBlockArgs=true,bool printBlockTerminator=true)370   void print(Block *block, bool printBlockArgs = true,
371              bool printBlockTerminator = true) {
372     // Consider the types of the block arguments for aliases if 'printBlockArgs'
373     // is set to true.
374     if (printBlockArgs) {
375       for (Type type : block->getArgumentTypes())
376         printType(type);
377     }
378 
379     // Consider the operations within this block, ignoring the terminator if
380     // requested.
381     auto range = llvm::make_range(
382         block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
383     for (Operation &op : range)
384       print(&op);
385   }
386 
387   /// Print the given region.
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)388   void printRegion(Region &region, bool printEntryBlockArgs,
389                    bool printBlockTerminators) override {
390     if (region.empty())
391       return;
392 
393     auto *entryBlock = &region.front();
394     print(entryBlock, printEntryBlockArgs, printBlockTerminators);
395     for (Block &b : llvm::drop_begin(region, 1))
396       print(&b);
397   }
398 
399   /// Consider the given type to be printed for an alias.
printType(Type type)400   void printType(Type type) override { initializer.visit(type); }
401 
402   /// Consider the given attribute to be printed for an alias.
printAttribute(Attribute attr)403   void printAttribute(Attribute attr) override { initializer.visit(attr); }
printAttributeWithoutType(Attribute attr)404   void printAttributeWithoutType(Attribute attr) override {
405     printAttribute(attr);
406   }
407 
408   /// Print the given set of attributes with names not included within
409   /// 'elidedAttrs'.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})410   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
411                              ArrayRef<StringRef> elidedAttrs = {}) override {
412     // Filter out any attributes that shouldn't be included.
413     SmallVector<NamedAttribute, 8> filteredAttrs(
__anon2591390f0402(NamedAttribute attr) 414         llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
415           return !llvm::is_contained(elidedAttrs, attr.first.strref());
416         }));
417     for (const NamedAttribute &attr : filteredAttrs)
418       printAttribute(attr.second);
419   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})420   void printOptionalAttrDictWithKeyword(
421       ArrayRef<NamedAttribute> attrs,
422       ArrayRef<StringRef> elidedAttrs = {}) override {
423     printOptionalAttrDict(attrs, elidedAttrs);
424   }
425 
426   /// Return 'nulls' as the output stream, this will ignore any data fed to it.
getStream() const427   raw_ostream &getStream() const override { return llvm::nulls(); }
428 
429   /// The following are hooks of `OpAsmPrinter` that are not necessary for
430   /// determining potential aliases.
printAffineMapOfSSAIds(AffineMapAttr,ValueRange)431   void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {}
printOperand(Value)432   void printOperand(Value) override {}
printOperand(Value,raw_ostream & os)433   void printOperand(Value, raw_ostream &os) override {
434     // Users expect the output string to have at least the prefixed % to signal
435     // a value name. To maintain this invariant, emit a name even if it is
436     // guaranteed to go unused.
437     os << "%";
438   }
printSymbolName(StringRef)439   void printSymbolName(StringRef) override {}
printSuccessor(Block *)440   void printSuccessor(Block *) override {}
printSuccessorAndUseList(Block *,ValueRange)441   void printSuccessorAndUseList(Block *, ValueRange) override {}
shadowRegionArgs(Region &,ValueRange)442   void shadowRegionArgs(Region &, ValueRange) override {}
443 
444   /// The printer flags to use when determining potential aliases.
445   const OpPrintingFlags &printerFlags;
446 
447   /// The initializer to use when identifying aliases.
448   AliasInitializer &initializer;
449 };
450 } // end anonymous namespace
451 
452 /// Sanitize the given name such that it can be used as a valid identifier. If
453 /// the string needs to be modified in any way, the provided buffer is used to
454 /// store the new copy,
sanitizeIdentifier(StringRef name,SmallString<16> & buffer,StringRef allowedPunctChars="$._-",bool allowTrailingDigit=true)455 static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
456                                     StringRef allowedPunctChars = "$._-",
457                                     bool allowTrailingDigit = true) {
458   assert(!name.empty() && "Shouldn't have an empty name here");
459 
460   auto copyNameToBuffer = [&] {
461     for (char ch : name) {
462       if (llvm::isAlnum(ch) || allowedPunctChars.contains(ch))
463         buffer.push_back(ch);
464       else if (ch == ' ')
465         buffer.push_back('_');
466       else
467         buffer.append(llvm::utohexstr((unsigned char)ch));
468     }
469   };
470 
471   // Check to see if this name is valid. If it starts with a digit, then it
472   // could conflict with the autogenerated numeric ID's, so add an underscore
473   // prefix to avoid problems.
474   if (isdigit(name[0])) {
475     buffer.push_back('_');
476     copyNameToBuffer();
477     return buffer;
478   }
479 
480   // If the name ends with a trailing digit, add a '_' to avoid potential
481   // conflicts with autogenerated ID's.
482   if (!allowTrailingDigit && isdigit(name.back())) {
483     copyNameToBuffer();
484     buffer.push_back('_');
485     return buffer;
486   }
487 
488   // Check to see that the name consists of only valid identifier characters.
489   for (char ch : name) {
490     if (!llvm::isAlnum(ch) && !allowedPunctChars.contains(ch)) {
491       copyNameToBuffer();
492       return buffer;
493     }
494   }
495 
496   // If there are no invalid characters, return the original name.
497   return name;
498 }
499 
500 /// Given a collection of aliases and symbols, initialize a mapping from a
501 /// symbol to a given alias.
502 template <typename T>
503 static void
initializeAliases(llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol,llvm::MapVector<T,SymbolAlias> & symbolToAlias,DenseSet<T> * deferrableAliases=nullptr)504 initializeAliases(llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol,
505                   llvm::MapVector<T, SymbolAlias> &symbolToAlias,
506                   DenseSet<T> *deferrableAliases = nullptr) {
507   std::vector<std::pair<StringRef, std::vector<T>>> aliases =
508       aliasToSymbol.takeVector();
509   llvm::array_pod_sort(aliases.begin(), aliases.end(),
510                        [](const auto *lhs, const auto *rhs) {
511                          return lhs->first.compare(rhs->first);
512                        });
513 
514   for (auto &it : aliases) {
515     // If there is only one instance for this alias, use the name directly.
516     if (it.second.size() == 1) {
517       T symbol = it.second.front();
518       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
519       symbolToAlias.insert({symbol, SymbolAlias(it.first, isDeferrable)});
520       continue;
521     }
522     // Otherwise, add the index to the name.
523     for (int i = 0, e = it.second.size(); i < e; ++i) {
524       T symbol = it.second[i];
525       bool isDeferrable = deferrableAliases && deferrableAliases->count(symbol);
526       symbolToAlias.insert({symbol, SymbolAlias(it.first, i, isDeferrable)});
527     }
528   }
529 }
530 
initialize(Operation * op,const OpPrintingFlags & printerFlags,llvm::MapVector<Attribute,SymbolAlias> & attrToAlias,llvm::MapVector<Type,SymbolAlias> & typeToAlias)531 void AliasInitializer::initialize(
532     Operation *op, const OpPrintingFlags &printerFlags,
533     llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
534     llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
535   // Use a dummy printer when walking the IR so that we can collect the
536   // attributes/types that will actually be used during printing when
537   // considering aliases.
538   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
539   aliasPrinter.print(op);
540 
541   // Initialize the aliases sorted by name.
542   initializeAliases(aliasToAttr, attrToAlias, &deferrableAttributes);
543   initializeAliases(aliasToType, typeToAlias);
544 }
545 
visit(Attribute attr,bool canBeDeferred)546 void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
547   if (!visitedAttributes.insert(attr).second) {
548     // If this attribute already has an alias and this instance can't be
549     // deferred, make sure that the alias isn't deferred.
550     if (!canBeDeferred)
551       deferrableAttributes.erase(attr);
552     return;
553   }
554 
555   // Try to generate an alias for this attribute.
556   if (succeeded(generateAlias(attr, aliasToAttr))) {
557     if (canBeDeferred)
558       deferrableAttributes.insert(attr);
559     return;
560   }
561 
562   if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
563     for (Attribute element : arrayAttr.getValue())
564       visit(element);
565   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
566     for (const NamedAttribute &attr : dictAttr)
567       visit(attr.second);
568   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
569     visit(typeAttr.getValue());
570   }
571 }
572 
visit(Type type)573 void AliasInitializer::visit(Type type) {
574   if (!visitedTypes.insert(type).second)
575     return;
576 
577   // Try to generate an alias for this type.
578   if (succeeded(generateAlias(type, aliasToType)))
579     return;
580 
581   // Visit several subtypes that contain types or atttributes.
582   if (auto funcType = type.dyn_cast<FunctionType>()) {
583     // Visit input and result types for functions.
584     for (auto input : funcType.getInputs())
585       visit(input);
586     for (auto result : funcType.getResults())
587       visit(result);
588   } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
589     visit(shapedType.getElementType());
590 
591     // Visit affine maps in memref type.
592     if (auto memref = type.dyn_cast<MemRefType>())
593       for (auto map : memref.getAffineMaps())
594         visit(AffineMapAttr::get(map));
595   }
596 }
597 
598 template <typename T>
generateAlias(T symbol,llvm::MapVector<StringRef,std::vector<T>> & aliasToSymbol)599 LogicalResult AliasInitializer::generateAlias(
600     T symbol, llvm::MapVector<StringRef, std::vector<T>> &aliasToSymbol) {
601   SmallString<16> tempBuffer;
602   for (const auto &interface : interfaces) {
603     interface.getAlias(symbol, aliasOS);
604     StringRef name = aliasOS.str();
605     if (name.empty())
606       continue;
607     name = sanitizeIdentifier(name, tempBuffer, /*allowedPunctChars=*/"$_-",
608                               /*allowTrailingDigit=*/false);
609     name = name.copy(aliasAllocator);
610 
611     aliasToSymbol[name].push_back(symbol);
612     aliasBuffer.clear();
613     return success();
614   }
615   return failure();
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // AliasState
620 //===----------------------------------------------------------------------===//
621 
622 namespace {
623 /// This class manages the state for type and attribute aliases.
624 class AliasState {
625 public:
626   // Initialize the internal aliases.
627   void
628   initialize(Operation *op, const OpPrintingFlags &printerFlags,
629              DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
630 
631   /// Get an alias for the given attribute if it has one and print it in `os`.
632   /// Returns success if an alias was printed, failure otherwise.
633   LogicalResult getAlias(Attribute attr, raw_ostream &os) const;
634 
635   /// Get an alias for the given type if it has one and print it in `os`.
636   /// Returns success if an alias was printed, failure otherwise.
637   LogicalResult getAlias(Type ty, raw_ostream &os) const;
638 
639   /// Print all of the referenced aliases that can not be resolved in a deferred
640   /// manner.
printNonDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const641   void printNonDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
642     printAliases(os, newLine, /*isDeferred=*/false);
643   }
644 
645   /// Print all of the referenced aliases that support deferred resolution.
printDeferredAliases(raw_ostream & os,NewLineCounter & newLine) const646   void printDeferredAliases(raw_ostream &os, NewLineCounter &newLine) const {
647     printAliases(os, newLine, /*isDeferred=*/true);
648   }
649 
650 private:
651   /// Print all of the referenced aliases that support the provided resolution
652   /// behavior.
653   void printAliases(raw_ostream &os, NewLineCounter &newLine,
654                     bool isDeferred) const;
655 
656   /// Mapping between attribute and alias.
657   llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
658   /// Mapping between type and alias.
659   llvm::MapVector<Type, SymbolAlias> typeToAlias;
660 
661   /// An allocator used for alias names.
662   llvm::BumpPtrAllocator aliasAllocator;
663 };
664 } // end anonymous namespace
665 
initialize(Operation * op,const OpPrintingFlags & printerFlags,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)666 void AliasState::initialize(
667     Operation *op, const OpPrintingFlags &printerFlags,
668     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
669   AliasInitializer initializer(interfaces, aliasAllocator);
670   initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
671 }
672 
getAlias(Attribute attr,raw_ostream & os) const673 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
674   auto it = attrToAlias.find(attr);
675   if (it == attrToAlias.end())
676     return failure();
677   it->second.print(os << '#');
678   return success();
679 }
680 
getAlias(Type ty,raw_ostream & os) const681 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
682   auto it = typeToAlias.find(ty);
683   if (it == typeToAlias.end())
684     return failure();
685 
686   it->second.print(os << '!');
687   return success();
688 }
689 
printAliases(raw_ostream & os,NewLineCounter & newLine,bool isDeferred) const690 void AliasState::printAliases(raw_ostream &os, NewLineCounter &newLine,
691                               bool isDeferred) const {
692   auto filterFn = [=](const auto &aliasIt) {
693     return aliasIt.second.canBeDeferred() == isDeferred;
694   };
695   for (const auto &it : llvm::make_filter_range(attrToAlias, filterFn)) {
696     it.second.print(os << '#');
697     os << " = " << it.first << newLine;
698   }
699   for (const auto &it : llvm::make_filter_range(typeToAlias, filterFn)) {
700     it.second.print(os << '!');
701     os << " = " << it.first << newLine;
702   }
703 }
704 
705 //===----------------------------------------------------------------------===//
706 // SSANameState
707 //===----------------------------------------------------------------------===//
708 
709 namespace {
710 /// This class manages the state of SSA value names.
711 class SSANameState {
712 public:
713   /// A sentinel value used for values with names set.
714   enum : unsigned { NameSentinel = ~0U };
715 
716   SSANameState(Operation *op,
717                DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
718 
719   /// Print the SSA identifier for the given value to 'stream'. If
720   /// 'printResultNo' is true, it also presents the result number ('#' number)
721   /// of this value.
722   void printValueID(Value value, bool printResultNo, raw_ostream &stream) const;
723 
724   /// Return the result indices for each of the result groups registered by this
725   /// operation, or empty if none exist.
726   ArrayRef<int> getOpResultGroups(Operation *op);
727 
728   /// Get the ID for the given block.
729   unsigned getBlockID(Block *block);
730 
731   /// Renumber the arguments for the specified region to the same names as the
732   /// SSA values in namesToUse. See OperationPrinter::shadowRegionArgs for
733   /// details.
734   void shadowRegionArgs(Region &region, ValueRange namesToUse);
735 
736 private:
737   /// Number the SSA values within the given IR unit.
738   void numberValuesInRegion(
739       Region &region,
740       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
741   void numberValuesInBlock(
742       Block &block,
743       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
744   void numberValuesInOp(
745       Operation &op,
746       DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
747 
748   /// Given a result of an operation 'result', find the result group head
749   /// 'lookupValue' and the result of 'result' within that group in
750   /// 'lookupResultNo'. 'lookupResultNo' is only filled in if the result group
751   /// has more than 1 result.
752   void getResultIDAndNumber(OpResult result, Value &lookupValue,
753                             Optional<int> &lookupResultNo) const;
754 
755   /// Set a special value name for the given value.
756   void setValueName(Value value, StringRef name);
757 
758   /// Uniques the given value name within the printer. If the given name
759   /// conflicts, it is automatically renamed.
760   StringRef uniqueValueName(StringRef name);
761 
762   /// This is the value ID for each SSA value. If this returns NameSentinel,
763   /// then the valueID has an entry in valueNames.
764   DenseMap<Value, unsigned> valueIDs;
765   DenseMap<Value, StringRef> valueNames;
766 
767   /// This is a map of operations that contain multiple named result groups,
768   /// i.e. there may be multiple names for the results of the operation. The
769   /// value of this map are the result numbers that start a result group.
770   DenseMap<Operation *, SmallVector<int, 1>> opResultGroups;
771 
772   /// This is the block ID for each block in the current.
773   DenseMap<Block *, unsigned> blockIDs;
774 
775   /// This keeps track of all of the non-numeric names that are in flight,
776   /// allowing us to check for duplicates.
777   /// Note: the value of the map is unused.
778   llvm::ScopedHashTable<StringRef, char> usedNames;
779   llvm::BumpPtrAllocator usedNameAllocator;
780 
781   /// This is the next value ID to assign in numbering.
782   unsigned nextValueID = 0;
783   /// This is the next ID to assign to a region entry block argument.
784   unsigned nextArgumentID = 0;
785   /// This is the next ID to assign when a name conflict is detected.
786   unsigned nextConflictID = 0;
787 };
788 } // end anonymous namespace
789 
SSANameState(Operation * op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)790 SSANameState::SSANameState(
791     Operation *op,
792     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
793   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
794   numberValuesInOp(*op, interfaces);
795 
796   for (auto &region : op->getRegions())
797     numberValuesInRegion(region, interfaces);
798 }
799 
printValueID(Value value,bool printResultNo,raw_ostream & stream) const800 void SSANameState::printValueID(Value value, bool printResultNo,
801                                 raw_ostream &stream) const {
802   if (!value) {
803     stream << "<<NULL>>";
804     return;
805   }
806 
807   Optional<int> resultNo;
808   auto lookupValue = value;
809 
810   // If this is an operation result, collect the head lookup value of the result
811   // group and the result number of 'result' within that group.
812   if (OpResult result = value.dyn_cast<OpResult>())
813     getResultIDAndNumber(result, lookupValue, resultNo);
814 
815   auto it = valueIDs.find(lookupValue);
816   if (it == valueIDs.end()) {
817     stream << "<<UNKNOWN SSA VALUE>>";
818     return;
819   }
820 
821   stream << '%';
822   if (it->second != NameSentinel) {
823     stream << it->second;
824   } else {
825     auto nameIt = valueNames.find(lookupValue);
826     assert(nameIt != valueNames.end() && "Didn't have a name entry?");
827     stream << nameIt->second;
828   }
829 
830   if (resultNo.hasValue() && printResultNo)
831     stream << '#' << resultNo;
832 }
833 
getOpResultGroups(Operation * op)834 ArrayRef<int> SSANameState::getOpResultGroups(Operation *op) {
835   auto it = opResultGroups.find(op);
836   return it == opResultGroups.end() ? ArrayRef<int>() : it->second;
837 }
838 
getBlockID(Block * block)839 unsigned SSANameState::getBlockID(Block *block) {
840   auto it = blockIDs.find(block);
841   return it != blockIDs.end() ? it->second : NameSentinel;
842 }
843 
shadowRegionArgs(Region & region,ValueRange namesToUse)844 void SSANameState::shadowRegionArgs(Region &region, ValueRange namesToUse) {
845   assert(!region.empty() && "cannot shadow arguments of an empty region");
846   assert(region.getNumArguments() == namesToUse.size() &&
847          "incorrect number of names passed in");
848   assert(region.getParentOp()->isKnownIsolatedFromAbove() &&
849          "only KnownIsolatedFromAbove ops can shadow names");
850 
851   SmallVector<char, 16> nameStr;
852   for (unsigned i = 0, e = namesToUse.size(); i != e; ++i) {
853     auto nameToUse = namesToUse[i];
854     if (nameToUse == nullptr)
855       continue;
856     auto nameToReplace = region.getArgument(i);
857 
858     nameStr.clear();
859     llvm::raw_svector_ostream nameStream(nameStr);
860     printValueID(nameToUse, /*printResultNo=*/true, nameStream);
861 
862     // Entry block arguments should already have a pretty "arg" name.
863     assert(valueIDs[nameToReplace] == NameSentinel);
864 
865     // Use the name without the leading %.
866     auto name = StringRef(nameStream.str()).drop_front();
867 
868     // Overwrite the name.
869     valueNames[nameToReplace] = name.copy(usedNameAllocator);
870   }
871 }
872 
numberValuesInRegion(Region & region,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)873 void SSANameState::numberValuesInRegion(
874     Region &region,
875     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
876   // Save the current value ids to allow for numbering values in sibling regions
877   // the same.
878   llvm::SaveAndRestore<unsigned> valueIDSaver(nextValueID);
879   llvm::SaveAndRestore<unsigned> argumentIDSaver(nextArgumentID);
880   llvm::SaveAndRestore<unsigned> conflictIDSaver(nextConflictID);
881 
882   // Push a new used names scope.
883   llvm::ScopedHashTable<StringRef, char>::ScopeTy usedNamesScope(usedNames);
884 
885   // Number the values within this region in a breadth-first order.
886   unsigned nextBlockID = 0;
887   for (auto &block : region) {
888     // Each block gets a unique ID, and all of the operations within it get
889     // numbered as well.
890     blockIDs[&block] = nextBlockID++;
891     numberValuesInBlock(block, interfaces);
892   }
893 
894   // After that we traverse the nested regions.
895   // TODO: Rework this loop to not use recursion.
896   for (auto &block : region) {
897     for (auto &op : block)
898       for (auto &nestedRegion : op.getRegions())
899         numberValuesInRegion(nestedRegion, interfaces);
900   }
901 }
902 
numberValuesInBlock(Block & block,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)903 void SSANameState::numberValuesInBlock(
904     Block &block,
905     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
906   auto setArgNameFn = [&](Value arg, StringRef name) {
907     assert(!valueIDs.count(arg) && "arg numbered multiple times");
908     assert(arg.cast<BlockArgument>().getOwner() == &block &&
909            "arg not defined in 'block'");
910     setValueName(arg, name);
911   };
912 
913   bool isEntryBlock = block.isEntryBlock();
914   if (isEntryBlock) {
915     if (auto *op = block.getParentOp()) {
916       if (auto asmInterface = interfaces.getInterfaceFor(op->getDialect()))
917         asmInterface->getAsmBlockArgumentNames(&block, setArgNameFn);
918     }
919   }
920 
921   // Number the block arguments. We give entry block arguments a special name
922   // 'arg'.
923   SmallString<32> specialNameBuffer(isEntryBlock ? "arg" : "");
924   llvm::raw_svector_ostream specialName(specialNameBuffer);
925   for (auto arg : block.getArguments()) {
926     if (valueIDs.count(arg))
927       continue;
928     if (isEntryBlock) {
929       specialNameBuffer.resize(strlen("arg"));
930       specialName << nextArgumentID++;
931     }
932     setValueName(arg, specialName.str());
933   }
934 
935   // Number the operations in this block.
936   for (auto &op : block)
937     numberValuesInOp(op, interfaces);
938 }
939 
numberValuesInOp(Operation & op,DialectInterfaceCollection<OpAsmDialectInterface> & interfaces)940 void SSANameState::numberValuesInOp(
941     Operation &op,
942     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
943   unsigned numResults = op.getNumResults();
944   if (numResults == 0)
945     return;
946   Value resultBegin = op.getResult(0);
947 
948   // Function used to set the special result names for the operation.
949   SmallVector<int, 2> resultGroups(/*Size=*/1, /*Value=*/0);
950   auto setResultNameFn = [&](Value result, StringRef name) {
951     assert(!valueIDs.count(result) && "result numbered multiple times");
952     assert(result.getDefiningOp() == &op && "result not defined by 'op'");
953     setValueName(result, name);
954 
955     // Record the result number for groups not anchored at 0.
956     if (int resultNo = result.cast<OpResult>().getResultNumber())
957       resultGroups.push_back(resultNo);
958   };
959   if (OpAsmOpInterface asmInterface = dyn_cast<OpAsmOpInterface>(&op))
960     asmInterface.getAsmResultNames(setResultNameFn);
961   else if (auto *asmInterface = interfaces.getInterfaceFor(op.getDialect()))
962     asmInterface->getAsmResultNames(&op, setResultNameFn);
963 
964   // If the first result wasn't numbered, give it a default number.
965   if (valueIDs.try_emplace(resultBegin, nextValueID).second)
966     ++nextValueID;
967 
968   // If this operation has multiple result groups, mark it.
969   if (resultGroups.size() != 1) {
970     llvm::array_pod_sort(resultGroups.begin(), resultGroups.end());
971     opResultGroups.try_emplace(&op, std::move(resultGroups));
972   }
973 }
974 
getResultIDAndNumber(OpResult result,Value & lookupValue,Optional<int> & lookupResultNo) const975 void SSANameState::getResultIDAndNumber(OpResult result, Value &lookupValue,
976                                         Optional<int> &lookupResultNo) const {
977   Operation *owner = result.getOwner();
978   if (owner->getNumResults() == 1)
979     return;
980   int resultNo = result.getResultNumber();
981 
982   // If this operation has multiple result groups, we will need to find the
983   // one corresponding to this result.
984   auto resultGroupIt = opResultGroups.find(owner);
985   if (resultGroupIt == opResultGroups.end()) {
986     // If not, just use the first result.
987     lookupResultNo = resultNo;
988     lookupValue = owner->getResult(0);
989     return;
990   }
991 
992   // Find the correct index using a binary search, as the groups are ordered.
993   ArrayRef<int> resultGroups = resultGroupIt->second;
994   auto it = llvm::upper_bound(resultGroups, resultNo);
995   int groupResultNo = 0, groupSize = 0;
996 
997   // If there are no smaller elements, the last result group is the lookup.
998   if (it == resultGroups.end()) {
999     groupResultNo = resultGroups.back();
1000     groupSize = static_cast<int>(owner->getNumResults()) - resultGroups.back();
1001   } else {
1002     // Otherwise, the previous element is the lookup.
1003     groupResultNo = *std::prev(it);
1004     groupSize = *it - groupResultNo;
1005   }
1006 
1007   // We only record the result number for a group of size greater than 1.
1008   if (groupSize != 1)
1009     lookupResultNo = resultNo - groupResultNo;
1010   lookupValue = owner->getResult(groupResultNo);
1011 }
1012 
setValueName(Value value,StringRef name)1013 void SSANameState::setValueName(Value value, StringRef name) {
1014   // If the name is empty, the value uses the default numbering.
1015   if (name.empty()) {
1016     valueIDs[value] = nextValueID++;
1017     return;
1018   }
1019 
1020   valueIDs[value] = NameSentinel;
1021   valueNames[value] = uniqueValueName(name);
1022 }
1023 
uniqueValueName(StringRef name)1024 StringRef SSANameState::uniqueValueName(StringRef name) {
1025   SmallString<16> tmpBuffer;
1026   name = sanitizeIdentifier(name, tmpBuffer);
1027 
1028   // Check to see if this name is already unique.
1029   if (!usedNames.count(name)) {
1030     name = name.copy(usedNameAllocator);
1031   } else {
1032     // Otherwise, we had a conflict - probe until we find a unique name. This
1033     // is guaranteed to terminate (and usually in a single iteration) because it
1034     // generates new names by incrementing nextConflictID.
1035     SmallString<64> probeName(name);
1036     probeName.push_back('_');
1037     while (true) {
1038       probeName += llvm::utostr(nextConflictID++);
1039       if (!usedNames.count(probeName)) {
1040         name = StringRef(probeName).copy(usedNameAllocator);
1041         break;
1042       }
1043       probeName.resize(name.size() + 1);
1044     }
1045   }
1046 
1047   usedNames.insert(name, char());
1048   return name;
1049 }
1050 
1051 //===----------------------------------------------------------------------===//
1052 // AsmState
1053 //===----------------------------------------------------------------------===//
1054 
1055 namespace mlir {
1056 namespace detail {
1057 class AsmStateImpl {
1058 public:
AsmStateImpl(Operation * op,AsmState::LocationMap * locationMap)1059   explicit AsmStateImpl(Operation *op, AsmState::LocationMap *locationMap)
1060       : interfaces(op->getContext()), nameState(op, interfaces),
1061         locationMap(locationMap) {}
1062 
1063   /// Initialize the alias state to enable the printing of aliases.
initializeAliases(Operation * op,const OpPrintingFlags & printerFlags)1064   void initializeAliases(Operation *op, const OpPrintingFlags &printerFlags) {
1065     aliasState.initialize(op, printerFlags, interfaces);
1066   }
1067 
1068   /// Get an instance of the OpAsmDialectInterface for the given dialect, or
1069   /// null if one wasn't registered.
getOpAsmInterface(Dialect * dialect)1070   const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
1071     return interfaces.getInterfaceFor(dialect);
1072   }
1073 
1074   /// Get the state used for aliases.
getAliasState()1075   AliasState &getAliasState() { return aliasState; }
1076 
1077   /// Get the state used for SSA names.
getSSANameState()1078   SSANameState &getSSANameState() { return nameState; }
1079 
1080   /// Register the location, line and column, within the buffer that the given
1081   /// operation was printed at.
registerOperationLocation(Operation * op,unsigned line,unsigned col)1082   void registerOperationLocation(Operation *op, unsigned line, unsigned col) {
1083     if (locationMap)
1084       (*locationMap)[op] = std::make_pair(line, col);
1085   }
1086 
1087 private:
1088   /// Collection of OpAsm interfaces implemented in the context.
1089   DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
1090 
1091   /// The state used for attribute and type aliases.
1092   AliasState aliasState;
1093 
1094   /// The state used for SSA value names.
1095   SSANameState nameState;
1096 
1097   /// An optional location map to be populated.
1098   AsmState::LocationMap *locationMap;
1099 };
1100 } // end namespace detail
1101 } // end namespace mlir
1102 
AsmState(Operation * op,LocationMap * locationMap)1103 AsmState::AsmState(Operation *op, LocationMap *locationMap)
1104     : impl(std::make_unique<AsmStateImpl>(op, locationMap)) {}
~AsmState()1105 AsmState::~AsmState() {}
1106 
1107 //===----------------------------------------------------------------------===//
1108 // ModulePrinter
1109 //===----------------------------------------------------------------------===//
1110 
1111 namespace {
1112 class ModulePrinter {
1113 public:
ModulePrinter(raw_ostream & os,OpPrintingFlags flags=llvm::None,AsmStateImpl * state=nullptr)1114   ModulePrinter(raw_ostream &os, OpPrintingFlags flags = llvm::None,
1115                 AsmStateImpl *state = nullptr)
1116       : os(os), printerFlags(flags), state(state) {}
ModulePrinter(ModulePrinter & printer)1117   explicit ModulePrinter(ModulePrinter &printer)
1118       : os(printer.os), printerFlags(printer.printerFlags),
1119         state(printer.state) {}
1120 
1121   /// Returns the output stream of the printer.
getStream()1122   raw_ostream &getStream() { return os; }
1123 
1124   template <typename Container, typename UnaryFunctor>
interleaveComma(const Container & c,UnaryFunctor each_fn) const1125   inline void interleaveComma(const Container &c, UnaryFunctor each_fn) const {
1126     llvm::interleaveComma(c, os, each_fn);
1127   }
1128 
1129   /// This enum describes the different kinds of elision for the type of an
1130   /// attribute when printing it.
1131   enum class AttrTypeElision {
1132     /// The type must not be elided,
1133     Never,
1134     /// The type may be elided when it matches the default used in the parser
1135     /// (for example i64 is the default for integer attributes).
1136     May,
1137     /// The type must be elided.
1138     Must
1139   };
1140 
1141   /// Print the given attribute.
1142   void printAttribute(Attribute attr,
1143                       AttrTypeElision typeElision = AttrTypeElision::Never);
1144 
1145   void printType(Type type);
1146 
1147   /// Print the given location to the stream. If `allowAlias` is true, this
1148   /// allows for the internal location to use an attribute alias.
1149   void printLocation(LocationAttr loc, bool allowAlias = false);
1150 
1151   void printAffineMap(AffineMap map);
1152   void
1153   printAffineExpr(AffineExpr expr,
1154                   function_ref<void(unsigned, bool)> printValueName = nullptr);
1155   void printAffineConstraint(AffineExpr expr, bool isEq);
1156   void printIntegerSet(IntegerSet set);
1157 
1158 protected:
1159   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1160                              ArrayRef<StringRef> elidedAttrs = {},
1161                              bool withKeyword = false);
1162   void printNamedAttribute(NamedAttribute attr);
1163   void printTrailingLocation(Location loc);
1164   void printLocationInternal(LocationAttr loc, bool pretty = false);
1165 
1166   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1167   /// used instead of individual elements when the elements attr is large.
1168   void printDenseElementsAttr(DenseElementsAttr attr, bool allowHex);
1169 
1170   /// Print a dense string elements attribute.
1171   void printDenseStringElementsAttr(DenseStringElementsAttr attr);
1172 
1173   /// Print a dense elements attribute. If 'allowHex' is true, a hex string is
1174   /// used instead of individual elements when the elements attr is large.
1175   void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1176                                      bool allowHex);
1177 
1178   void printDialectAttribute(Attribute attr);
1179   void printDialectType(Type type);
1180 
1181   /// This enum is used to represent the binding strength of the enclosing
1182   /// context that an AffineExprStorage is being printed in, so we can
1183   /// intelligently produce parens.
1184   enum class BindingStrength {
1185     Weak,   // + and -
1186     Strong, // All other binary operators.
1187   };
1188   void printAffineExprInternal(
1189       AffineExpr expr, BindingStrength enclosingTightness,
1190       function_ref<void(unsigned, bool)> printValueName = nullptr);
1191 
1192   /// The output stream for the printer.
1193   raw_ostream &os;
1194 
1195   /// A set of flags to control the printer's behavior.
1196   OpPrintingFlags printerFlags;
1197 
1198   /// An optional printer state for the module.
1199   AsmStateImpl *state;
1200 
1201   /// A tracker for the number of new lines emitted during printing.
1202   NewLineCounter newLine;
1203 };
1204 } // end anonymous namespace
1205 
printTrailingLocation(Location loc)1206 void ModulePrinter::printTrailingLocation(Location loc) {
1207   // Check to see if we are printing debug information.
1208   if (!printerFlags.shouldPrintDebugInfo())
1209     return;
1210 
1211   os << " ";
1212   printLocation(loc, /*allowAlias=*/true);
1213 }
1214 
printLocationInternal(LocationAttr loc,bool pretty)1215 void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
1216   TypeSwitch<LocationAttr>(loc)
1217       .Case<OpaqueLoc>([&](OpaqueLoc loc) {
1218         printLocationInternal(loc.getFallbackLocation(), pretty);
1219       })
1220       .Case<UnknownLoc>([&](UnknownLoc loc) {
1221         if (pretty)
1222           os << "[unknown]";
1223         else
1224           os << "unknown";
1225       })
1226       .Case<FileLineColLoc>([&](FileLineColLoc loc) {
1227         StringRef mayQuote = pretty ? "" : "\"";
1228         os << mayQuote << loc.getFilename() << mayQuote << ':' << loc.getLine()
1229            << ':' << loc.getColumn();
1230       })
1231       .Case<NameLoc>([&](NameLoc loc) {
1232         os << '\"' << loc.getName() << '\"';
1233 
1234         // Print the child if it isn't unknown.
1235         auto childLoc = loc.getChildLoc();
1236         if (!childLoc.isa<UnknownLoc>()) {
1237           os << '(';
1238           printLocationInternal(childLoc, pretty);
1239           os << ')';
1240         }
1241       })
1242       .Case<CallSiteLoc>([&](CallSiteLoc loc) {
1243         Location caller = loc.getCaller();
1244         Location callee = loc.getCallee();
1245         if (!pretty)
1246           os << "callsite(";
1247         printLocationInternal(callee, pretty);
1248         if (pretty) {
1249           if (callee.isa<NameLoc>()) {
1250             if (caller.isa<FileLineColLoc>()) {
1251               os << " at ";
1252             } else {
1253               os << newLine << " at ";
1254             }
1255           } else {
1256             os << newLine << " at ";
1257           }
1258         } else {
1259           os << " at ";
1260         }
1261         printLocationInternal(caller, pretty);
1262         if (!pretty)
1263           os << ")";
1264       })
1265       .Case<FusedLoc>([&](FusedLoc loc) {
1266         if (!pretty)
1267           os << "fused";
1268         if (Attribute metadata = loc.getMetadata())
1269           os << '<' << metadata << '>';
1270         os << '[';
1271         interleave(
1272             loc.getLocations(),
1273             [&](Location loc) { printLocationInternal(loc, pretty); },
1274             [&]() { os << ", "; });
1275         os << ']';
1276       });
1277 }
1278 
1279 /// Print a floating point value in a way that the parser will be able to
1280 /// round-trip losslessly.
printFloatValue(const APFloat & apValue,raw_ostream & os)1281 static void printFloatValue(const APFloat &apValue, raw_ostream &os) {
1282   // We would like to output the FP constant value in exponential notation,
1283   // but we cannot do this if doing so will lose precision.  Check here to
1284   // make sure that we only output it in exponential format if we can parse
1285   // the value back and get the same value.
1286   bool isInf = apValue.isInfinity();
1287   bool isNaN = apValue.isNaN();
1288   if (!isInf && !isNaN) {
1289     SmallString<128> strValue;
1290     apValue.toString(strValue, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
1291                      /*TruncateZero=*/false);
1292 
1293     // Check to make sure that the stringized number is not some string like
1294     // "Inf" or NaN, that atof will accept, but the lexer will not.  Check
1295     // that the string matches the "[-+]?[0-9]" regex.
1296     assert(((strValue[0] >= '0' && strValue[0] <= '9') ||
1297             ((strValue[0] == '-' || strValue[0] == '+') &&
1298              (strValue[1] >= '0' && strValue[1] <= '9'))) &&
1299            "[-+]?[0-9] regex does not match!");
1300 
1301     // Parse back the stringized version and check that the value is equal
1302     // (i.e., there is no precision loss).
1303     if (APFloat(apValue.getSemantics(), strValue).bitwiseIsEqual(apValue)) {
1304       os << strValue;
1305       return;
1306     }
1307 
1308     // If it is not, use the default format of APFloat instead of the
1309     // exponential notation.
1310     strValue.clear();
1311     apValue.toString(strValue);
1312 
1313     // Make sure that we can parse the default form as a float.
1314     if (StringRef(strValue).contains('.')) {
1315       os << strValue;
1316       return;
1317     }
1318   }
1319 
1320   // Print special values in hexadecimal format. The sign bit should be included
1321   // in the literal.
1322   SmallVector<char, 16> str;
1323   APInt apInt = apValue.bitcastToAPInt();
1324   apInt.toString(str, /*Radix=*/16, /*Signed=*/false,
1325                  /*formatAsCLiteral=*/true);
1326   os << str;
1327 }
1328 
printLocation(LocationAttr loc,bool allowAlias)1329 void ModulePrinter::printLocation(LocationAttr loc, bool allowAlias) {
1330   if (printerFlags.shouldPrintDebugInfoPrettyForm())
1331     return printLocationInternal(loc, /*pretty=*/true);
1332 
1333   os << "loc(";
1334   if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
1335     printLocationInternal(loc);
1336   os << ')';
1337 }
1338 
1339 /// Returns true if the given dialect symbol data is simple enough to print in
1340 /// the pretty form, i.e. without the enclosing "".
isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName)1341 static bool isDialectSymbolSimpleEnoughForPrettyForm(StringRef symName) {
1342   // The name must start with an identifier.
1343   if (symName.empty() || !isalpha(symName.front()))
1344     return false;
1345 
1346   // Ignore all the characters that are valid in an identifier in the symbol
1347   // name.
1348   symName = symName.drop_while(
1349       [](char c) { return llvm::isAlnum(c) || c == '.' || c == '_'; });
1350   if (symName.empty())
1351     return true;
1352 
1353   // If we got to an unexpected character, then it must be a <>.  Check those
1354   // recursively.
1355   if (symName.front() != '<' || symName.back() != '>')
1356     return false;
1357 
1358   SmallVector<char, 8> nestedPunctuation;
1359   do {
1360     // If we ran out of characters, then we had a punctuation mismatch.
1361     if (symName.empty())
1362       return false;
1363 
1364     auto c = symName.front();
1365     symName = symName.drop_front();
1366 
1367     switch (c) {
1368     // We never allow null characters. This is an EOF indicator for the lexer
1369     // which we could handle, but isn't important for any known dialect.
1370     case '\0':
1371       return false;
1372     case '<':
1373     case '[':
1374     case '(':
1375     case '{':
1376       nestedPunctuation.push_back(c);
1377       continue;
1378     case '-':
1379       // Treat `->` as a special token.
1380       if (!symName.empty() && symName.front() == '>') {
1381         symName = symName.drop_front();
1382         continue;
1383       }
1384       break;
1385     // Reject types with mismatched brackets.
1386     case '>':
1387       if (nestedPunctuation.pop_back_val() != '<')
1388         return false;
1389       break;
1390     case ']':
1391       if (nestedPunctuation.pop_back_val() != '[')
1392         return false;
1393       break;
1394     case ')':
1395       if (nestedPunctuation.pop_back_val() != '(')
1396         return false;
1397       break;
1398     case '}':
1399       if (nestedPunctuation.pop_back_val() != '{')
1400         return false;
1401       break;
1402     default:
1403       continue;
1404     }
1405 
1406     // We're done when the punctuation is fully matched.
1407   } while (!nestedPunctuation.empty());
1408 
1409   // If there were extra characters, then we failed.
1410   return symName.empty();
1411 }
1412 
1413 /// Print the given dialect symbol to the stream.
printDialectSymbol(raw_ostream & os,StringRef symPrefix,StringRef dialectName,StringRef symString)1414 static void printDialectSymbol(raw_ostream &os, StringRef symPrefix,
1415                                StringRef dialectName, StringRef symString) {
1416   os << symPrefix << dialectName;
1417 
1418   // If this symbol name is simple enough, print it directly in pretty form,
1419   // otherwise, we print it as an escaped string.
1420   if (isDialectSymbolSimpleEnoughForPrettyForm(symString)) {
1421     os << '.' << symString;
1422     return;
1423   }
1424 
1425   // TODO: escape the symbol name, it could contain " characters.
1426   os << "<\"" << symString << "\">";
1427 }
1428 
1429 /// Returns true if the given string can be represented as a bare identifier.
isBareIdentifier(StringRef name)1430 static bool isBareIdentifier(StringRef name) {
1431   assert(!name.empty() && "invalid name");
1432 
1433   // By making this unsigned, the value passed in to isalnum will always be
1434   // in the range 0-255. This is important when building with MSVC because
1435   // its implementation will assert. This situation can arise when dealing
1436   // with UTF-8 multibyte characters.
1437   unsigned char firstChar = static_cast<unsigned char>(name[0]);
1438   if (!isalpha(firstChar) && firstChar != '_')
1439     return false;
1440   return llvm::all_of(name.drop_front(), [](unsigned char c) {
1441     return isalnum(c) || c == '_' || c == '$' || c == '.';
1442   });
1443 }
1444 
1445 /// Print the given string as a symbol reference. A symbol reference is
1446 /// represented as a string prefixed with '@'. The reference is surrounded with
1447 /// ""'s and escaped if it has any special or non-printable characters in it.
printSymbolReference(StringRef symbolRef,raw_ostream & os)1448 static void printSymbolReference(StringRef symbolRef, raw_ostream &os) {
1449   assert(!symbolRef.empty() && "expected valid symbol reference");
1450 
1451   // If the symbol can be represented as a bare identifier, write it directly.
1452   if (isBareIdentifier(symbolRef)) {
1453     os << '@' << symbolRef;
1454     return;
1455   }
1456 
1457   // Otherwise, output the reference wrapped in quotes with proper escaping.
1458   os << "@\"";
1459   printEscapedString(symbolRef, os);
1460   os << '"';
1461 }
1462 
1463 // Print out a valid ElementsAttr that is succinct and can represent any
1464 // potential shape/type, for use when eliding a large ElementsAttr.
1465 //
1466 // We choose to use an opaque ElementsAttr literal with conspicuous content to
1467 // hopefully alert readers to the fact that this has been elided.
1468 //
1469 // Unfortunately, neither of the strings of an opaque ElementsAttr literal will
1470 // accept the string "elided". The first string must be a registered dialect
1471 // name and the latter must be a hex constant.
printElidedElementsAttr(raw_ostream & os)1472 static void printElidedElementsAttr(raw_ostream &os) {
1473   os << R"(opaque<"", "0xDEADBEEF">)";
1474 }
1475 
printAttribute(Attribute attr,AttrTypeElision typeElision)1476 void ModulePrinter::printAttribute(Attribute attr,
1477                                    AttrTypeElision typeElision) {
1478   if (!attr) {
1479     os << "<<NULL ATTRIBUTE>>";
1480     return;
1481   }
1482 
1483   // Try to print an alias for this attribute.
1484   if (state && succeeded(state->getAliasState().getAlias(attr, os)))
1485     return;
1486 
1487   auto attrType = attr.getType();
1488   if (auto opaqueAttr = attr.dyn_cast<OpaqueAttr>()) {
1489     printDialectSymbol(os, "#", opaqueAttr.getDialectNamespace(),
1490                        opaqueAttr.getAttrData());
1491   } else if (attr.isa<UnitAttr>()) {
1492     os << "unit";
1493     return;
1494   } else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
1495     os << '{';
1496     interleaveComma(dictAttr.getValue(),
1497                     [&](NamedAttribute attr) { printNamedAttribute(attr); });
1498     os << '}';
1499 
1500   } else if (auto intAttr = attr.dyn_cast<IntegerAttr>()) {
1501     if (attrType.isSignlessInteger(1)) {
1502       os << (intAttr.getValue().getBoolValue() ? "true" : "false");
1503 
1504       // Boolean integer attributes always elides the type.
1505       return;
1506     }
1507 
1508     // Only print attributes as unsigned if they are explicitly unsigned or are
1509     // signless 1-bit values.  Indexes, signed values, and multi-bit signless
1510     // values print as signed.
1511     bool isUnsigned =
1512         attrType.isUnsignedInteger() || attrType.isSignlessInteger(1);
1513     intAttr.getValue().print(os, !isUnsigned);
1514 
1515     // IntegerAttr elides the type if I64.
1516     if (typeElision == AttrTypeElision::May && attrType.isSignlessInteger(64))
1517       return;
1518 
1519   } else if (auto floatAttr = attr.dyn_cast<FloatAttr>()) {
1520     printFloatValue(floatAttr.getValue(), os);
1521 
1522     // FloatAttr elides the type if F64.
1523     if (typeElision == AttrTypeElision::May && attrType.isF64())
1524       return;
1525 
1526   } else if (auto strAttr = attr.dyn_cast<StringAttr>()) {
1527     os << '"';
1528     printEscapedString(strAttr.getValue(), os);
1529     os << '"';
1530 
1531   } else if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
1532     os << '[';
1533     interleaveComma(arrayAttr.getValue(), [&](Attribute attr) {
1534       printAttribute(attr, AttrTypeElision::May);
1535     });
1536     os << ']';
1537 
1538   } else if (auto affineMapAttr = attr.dyn_cast<AffineMapAttr>()) {
1539     os << "affine_map<";
1540     affineMapAttr.getValue().print(os);
1541     os << '>';
1542 
1543     // AffineMap always elides the type.
1544     return;
1545 
1546   } else if (auto integerSetAttr = attr.dyn_cast<IntegerSetAttr>()) {
1547     os << "affine_set<";
1548     integerSetAttr.getValue().print(os);
1549     os << '>';
1550 
1551     // IntegerSet always elides the type.
1552     return;
1553 
1554   } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
1555     printType(typeAttr.getValue());
1556 
1557   } else if (auto refAttr = attr.dyn_cast<SymbolRefAttr>()) {
1558     printSymbolReference(refAttr.getRootReference(), os);
1559     for (FlatSymbolRefAttr nestedRef : refAttr.getNestedReferences()) {
1560       os << "::";
1561       printSymbolReference(nestedRef.getValue(), os);
1562     }
1563 
1564   } else if (auto opaqueAttr = attr.dyn_cast<OpaqueElementsAttr>()) {
1565     if (printerFlags.shouldElideElementsAttr(opaqueAttr)) {
1566       printElidedElementsAttr(os);
1567     } else {
1568       os << "opaque<\"" << opaqueAttr.getDialect()->getNamespace() << "\", ";
1569       os << '"' << "0x" << llvm::toHex(opaqueAttr.getValue()) << "\">";
1570     }
1571 
1572   } else if (auto intOrFpEltAttr = attr.dyn_cast<DenseIntOrFPElementsAttr>()) {
1573     if (printerFlags.shouldElideElementsAttr(intOrFpEltAttr)) {
1574       printElidedElementsAttr(os);
1575     } else {
1576       os << "dense<";
1577       printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true);
1578       os << '>';
1579     }
1580 
1581   } else if (auto strEltAttr = attr.dyn_cast<DenseStringElementsAttr>()) {
1582     if (printerFlags.shouldElideElementsAttr(strEltAttr)) {
1583       printElidedElementsAttr(os);
1584     } else {
1585       os << "dense<";
1586       printDenseStringElementsAttr(strEltAttr);
1587       os << '>';
1588     }
1589 
1590   } else if (auto sparseEltAttr = attr.dyn_cast<SparseElementsAttr>()) {
1591     if (printerFlags.shouldElideElementsAttr(sparseEltAttr.getIndices()) ||
1592         printerFlags.shouldElideElementsAttr(sparseEltAttr.getValues())) {
1593       printElidedElementsAttr(os);
1594     } else {
1595       os << "sparse<";
1596       DenseIntElementsAttr indices = sparseEltAttr.getIndices();
1597       if (indices.getNumElements() != 0) {
1598         printDenseIntOrFPElementsAttr(indices, /*allowHex=*/false);
1599         os << ", ";
1600         printDenseElementsAttr(sparseEltAttr.getValues(), /*allowHex=*/true);
1601       }
1602       os << '>';
1603     }
1604 
1605   } else if (auto locAttr = attr.dyn_cast<LocationAttr>()) {
1606     printLocation(locAttr);
1607 
1608   } else {
1609     return printDialectAttribute(attr);
1610   }
1611 
1612   // Don't print the type if we must elide it, or if it is a None type.
1613   if (typeElision != AttrTypeElision::Must && !attrType.isa<NoneType>()) {
1614     os << " : ";
1615     printType(attrType);
1616   }
1617 }
1618 
1619 /// Print the integer element of a DenseElementsAttr.
printDenseIntElement(const APInt & value,raw_ostream & os,bool isSigned)1620 static void printDenseIntElement(const APInt &value, raw_ostream &os,
1621                                  bool isSigned) {
1622   if (value.getBitWidth() == 1)
1623     os << (value.getBoolValue() ? "true" : "false");
1624   else
1625     value.print(os, isSigned);
1626 }
1627 
1628 static void
printDenseElementsAttrImpl(bool isSplat,ShapedType type,raw_ostream & os,function_ref<void (unsigned)> printEltFn)1629 printDenseElementsAttrImpl(bool isSplat, ShapedType type, raw_ostream &os,
1630                            function_ref<void(unsigned)> printEltFn) {
1631   // Special case for 0-d and splat tensors.
1632   if (isSplat)
1633     return printEltFn(0);
1634 
1635   // Special case for degenerate tensors.
1636   auto numElements = type.getNumElements();
1637   if (numElements == 0)
1638     return;
1639 
1640   // We use a mixed-radix counter to iterate through the shape. When we bump a
1641   // non-least-significant digit, we emit a close bracket. When we next emit an
1642   // element we re-open all closed brackets.
1643 
1644   // The mixed-radix counter, with radices in 'shape'.
1645   int64_t rank = type.getRank();
1646   SmallVector<unsigned, 4> counter(rank, 0);
1647   // The number of brackets that have been opened and not closed.
1648   unsigned openBrackets = 0;
1649 
1650   auto shape = type.getShape();
1651   auto bumpCounter = [&] {
1652     // Bump the least significant digit.
1653     ++counter[rank - 1];
1654     // Iterate backwards bubbling back the increment.
1655     for (unsigned i = rank - 1; i > 0; --i)
1656       if (counter[i] >= shape[i]) {
1657         // Index 'i' is rolled over. Bump (i-1) and close a bracket.
1658         counter[i] = 0;
1659         ++counter[i - 1];
1660         --openBrackets;
1661         os << ']';
1662       }
1663   };
1664 
1665   for (unsigned idx = 0, e = numElements; idx != e; ++idx) {
1666     if (idx != 0)
1667       os << ", ";
1668     while (openBrackets++ < rank)
1669       os << '[';
1670     openBrackets = rank;
1671     printEltFn(idx);
1672     bumpCounter();
1673   }
1674   while (openBrackets-- > 0)
1675     os << ']';
1676 }
1677 
printDenseElementsAttr(DenseElementsAttr attr,bool allowHex)1678 void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr,
1679                                            bool allowHex) {
1680   if (auto stringAttr = attr.dyn_cast<DenseStringElementsAttr>())
1681     return printDenseStringElementsAttr(stringAttr);
1682 
1683   printDenseIntOrFPElementsAttr(attr.cast<DenseIntOrFPElementsAttr>(),
1684                                 allowHex);
1685 }
1686 
printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,bool allowHex)1687 void ModulePrinter::printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr,
1688                                                   bool allowHex) {
1689   auto type = attr.getType();
1690   auto elementType = type.getElementType();
1691 
1692   // Check to see if we should format this attribute as a hex string.
1693   auto numElements = type.getNumElements();
1694   if (!attr.isSplat() && allowHex &&
1695       shouldPrintElementsAttrWithHex(numElements)) {
1696     ArrayRef<char> rawData = attr.getRawData();
1697     if (llvm::support::endian::system_endianness() ==
1698         llvm::support::endianness::big) {
1699       // Convert endianess in big-endian(BE) machines. `rawData` is BE in BE
1700       // machines. It is converted here to print in LE format.
1701       SmallVector<char, 64> outDataVec(rawData.size());
1702       MutableArrayRef<char> convRawData(outDataVec);
1703       DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1704           rawData, convRawData, type);
1705       os << '"' << "0x"
1706          << llvm::toHex(StringRef(convRawData.data(), convRawData.size()))
1707          << "\"";
1708     } else {
1709       os << '"' << "0x"
1710          << llvm::toHex(StringRef(rawData.data(), rawData.size())) << "\"";
1711     }
1712 
1713     return;
1714   }
1715 
1716   if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1717     Type complexElementType = complexTy.getElementType();
1718     // Note: The if and else below had a common lambda function which invoked
1719     // printDenseElementsAttrImpl. This lambda was hitting a bug in gcc 9.1,9.2
1720     // and hence was replaced.
1721     if (complexElementType.isa<IntegerType>()) {
1722       bool isSigned = !complexElementType.isUnsignedInteger();
1723       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1724         auto complexValue = *(attr.getComplexIntValues().begin() + index);
1725         os << "(";
1726         printDenseIntElement(complexValue.real(), os, isSigned);
1727         os << ",";
1728         printDenseIntElement(complexValue.imag(), os, isSigned);
1729         os << ")";
1730       });
1731     } else {
1732       printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1733         auto complexValue = *(attr.getComplexFloatValues().begin() + index);
1734         os << "(";
1735         printFloatValue(complexValue.real(), os);
1736         os << ",";
1737         printFloatValue(complexValue.imag(), os);
1738         os << ")";
1739       });
1740     }
1741   } else if (elementType.isIntOrIndex()) {
1742     bool isSigned = !elementType.isUnsignedInteger();
1743     auto intValues = attr.getIntValues();
1744     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1745       printDenseIntElement(*(intValues.begin() + index), os, isSigned);
1746     });
1747   } else {
1748     assert(elementType.isa<FloatType>() && "unexpected element type");
1749     auto floatValues = attr.getFloatValues();
1750     printDenseElementsAttrImpl(attr.isSplat(), type, os, [&](unsigned index) {
1751       printFloatValue(*(floatValues.begin() + index), os);
1752     });
1753   }
1754 }
1755 
printDenseStringElementsAttr(DenseStringElementsAttr attr)1756 void ModulePrinter::printDenseStringElementsAttr(DenseStringElementsAttr attr) {
1757   ArrayRef<StringRef> data = attr.getRawStringData();
1758   auto printFn = [&](unsigned index) {
1759     os << "\"";
1760     printEscapedString(data[index], os);
1761     os << "\"";
1762   };
1763   printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn);
1764 }
1765 
printType(Type type)1766 void ModulePrinter::printType(Type type) {
1767   if (!type) {
1768     os << "<<NULL TYPE>>";
1769     return;
1770   }
1771 
1772   // Try to print an alias for this type.
1773   if (state && succeeded(state->getAliasState().getAlias(type, os)))
1774     return;
1775 
1776   TypeSwitch<Type>(type)
1777       .Case<OpaqueType>([&](OpaqueType opaqueTy) {
1778         printDialectSymbol(os, "!", opaqueTy.getDialectNamespace(),
1779                            opaqueTy.getTypeData());
1780       })
1781       .Case<IndexType>([&](Type) { os << "index"; })
1782       .Case<BFloat16Type>([&](Type) { os << "bf16"; })
1783       .Case<Float16Type>([&](Type) { os << "f16"; })
1784       .Case<Float32Type>([&](Type) { os << "f32"; })
1785       .Case<Float64Type>([&](Type) { os << "f64"; })
1786       .Case<IntegerType>([&](IntegerType integerTy) {
1787         if (integerTy.isSigned())
1788           os << 's';
1789         else if (integerTy.isUnsigned())
1790           os << 'u';
1791         os << 'i' << integerTy.getWidth();
1792       })
1793       .Case<FunctionType>([&](FunctionType funcTy) {
1794         os << '(';
1795         interleaveComma(funcTy.getInputs(), [&](Type ty) { printType(ty); });
1796         os << ") -> ";
1797         ArrayRef<Type> results = funcTy.getResults();
1798         if (results.size() == 1 && !results[0].isa<FunctionType>()) {
1799           os << results[0];
1800         } else {
1801           os << '(';
1802           interleaveComma(results, [&](Type ty) { printType(ty); });
1803           os << ')';
1804         }
1805       })
1806       .Case<VectorType>([&](VectorType vectorTy) {
1807         os << "vector<";
1808         for (int64_t dim : vectorTy.getShape())
1809           os << dim << 'x';
1810         os << vectorTy.getElementType() << '>';
1811       })
1812       .Case<RankedTensorType>([&](RankedTensorType tensorTy) {
1813         os << "tensor<";
1814         for (int64_t dim : tensorTy.getShape()) {
1815           if (ShapedType::isDynamic(dim))
1816             os << '?';
1817           else
1818             os << dim;
1819           os << 'x';
1820         }
1821         os << tensorTy.getElementType() << '>';
1822       })
1823       .Case<UnrankedTensorType>([&](UnrankedTensorType tensorTy) {
1824         os << "tensor<*x";
1825         printType(tensorTy.getElementType());
1826         os << '>';
1827       })
1828       .Case<MemRefType>([&](MemRefType memrefTy) {
1829         os << "memref<";
1830         for (int64_t dim : memrefTy.getShape()) {
1831           if (ShapedType::isDynamic(dim))
1832             os << '?';
1833           else
1834             os << dim;
1835           os << 'x';
1836         }
1837         printType(memrefTy.getElementType());
1838         for (auto map : memrefTy.getAffineMaps()) {
1839           os << ", ";
1840           printAttribute(AffineMapAttr::get(map));
1841         }
1842         // Only print the memory space if it is the non-default one.
1843         if (memrefTy.getMemorySpace())
1844           os << ", " << memrefTy.getMemorySpace();
1845         os << '>';
1846       })
1847       .Case<UnrankedMemRefType>([&](UnrankedMemRefType memrefTy) {
1848         os << "memref<*x";
1849         printType(memrefTy.getElementType());
1850         // Only print the memory space if it is the non-default one.
1851         if (memrefTy.getMemorySpace())
1852           os << ", " << memrefTy.getMemorySpace();
1853         os << '>';
1854       })
1855       .Case<ComplexType>([&](ComplexType complexTy) {
1856         os << "complex<";
1857         printType(complexTy.getElementType());
1858         os << '>';
1859       })
1860       .Case<TupleType>([&](TupleType tupleTy) {
1861         os << "tuple<";
1862         interleaveComma(tupleTy.getTypes(),
1863                         [&](Type type) { printType(type); });
1864         os << '>';
1865       })
1866       .Case<NoneType>([&](Type) { os << "none"; })
1867       .Default([&](Type type) { return printDialectType(type); });
1868 }
1869 
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs,bool withKeyword)1870 void ModulePrinter::printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
1871                                           ArrayRef<StringRef> elidedAttrs,
1872                                           bool withKeyword) {
1873   // If there are no attributes, then there is nothing to be done.
1874   if (attrs.empty())
1875     return;
1876 
1877   // Filter out any attributes that shouldn't be included.
1878   SmallVector<NamedAttribute, 8> filteredAttrs(
1879       llvm::make_filter_range(attrs, [&](NamedAttribute attr) {
1880         return !llvm::is_contained(elidedAttrs, attr.first.strref());
1881       }));
1882 
1883   // If there are no attributes left to print after filtering, then we're done.
1884   if (filteredAttrs.empty())
1885     return;
1886 
1887   // Print the 'attributes' keyword if necessary.
1888   if (withKeyword)
1889     os << " attributes";
1890 
1891   // Otherwise, print them all out in braces.
1892   os << " {";
1893   interleaveComma(filteredAttrs,
1894                   [&](NamedAttribute attr) { printNamedAttribute(attr); });
1895   os << '}';
1896 }
1897 
printNamedAttribute(NamedAttribute attr)1898 void ModulePrinter::printNamedAttribute(NamedAttribute attr) {
1899   if (isBareIdentifier(attr.first)) {
1900     os << attr.first;
1901   } else {
1902     os << '"';
1903     printEscapedString(attr.first.strref(), os);
1904     os << '"';
1905   }
1906 
1907   // Pretty printing elides the attribute value for unit attributes.
1908   if (attr.second.isa<UnitAttr>())
1909     return;
1910 
1911   os << " = ";
1912   printAttribute(attr.second);
1913 }
1914 
1915 //===----------------------------------------------------------------------===//
1916 // CustomDialectAsmPrinter
1917 //===----------------------------------------------------------------------===//
1918 
1919 namespace {
1920 /// This class provides the main specialization of the DialectAsmPrinter that is
1921 /// used to provide support for print attributes and types. This hooks allows
1922 /// for dialects to hook into the main ModulePrinter.
1923 struct CustomDialectAsmPrinter : public DialectAsmPrinter {
1924 public:
CustomDialectAsmPrinter__anon2591390f3611::CustomDialectAsmPrinter1925   CustomDialectAsmPrinter(ModulePrinter &printer) : printer(printer) {}
~CustomDialectAsmPrinter__anon2591390f3611::CustomDialectAsmPrinter1926   ~CustomDialectAsmPrinter() override {}
1927 
getStream__anon2591390f3611::CustomDialectAsmPrinter1928   raw_ostream &getStream() const override { return printer.getStream(); }
1929 
1930   /// Print the given attribute to the stream.
printAttribute__anon2591390f3611::CustomDialectAsmPrinter1931   void printAttribute(Attribute attr) override { printer.printAttribute(attr); }
1932 
1933   /// Print the given floating point value in a stablized form.
printFloat__anon2591390f3611::CustomDialectAsmPrinter1934   void printFloat(const APFloat &value) override {
1935     printFloatValue(value, getStream());
1936   }
1937 
1938   /// Print the given type to the stream.
printType__anon2591390f3611::CustomDialectAsmPrinter1939   void printType(Type type) override { printer.printType(type); }
1940 
1941   /// The main module printer.
1942   ModulePrinter &printer;
1943 };
1944 } // end anonymous namespace
1945 
printDialectAttribute(Attribute attr)1946 void ModulePrinter::printDialectAttribute(Attribute attr) {
1947   auto &dialect = attr.getDialect();
1948 
1949   // Ask the dialect to serialize the attribute to a string.
1950   std::string attrName;
1951   {
1952     llvm::raw_string_ostream attrNameStr(attrName);
1953     ModulePrinter subPrinter(attrNameStr, printerFlags, state);
1954     CustomDialectAsmPrinter printer(subPrinter);
1955     dialect.printAttribute(attr, printer);
1956   }
1957   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
1958 }
1959 
printDialectType(Type type)1960 void ModulePrinter::printDialectType(Type type) {
1961   auto &dialect = type.getDialect();
1962 
1963   // Ask the dialect to serialize the type to a string.
1964   std::string typeName;
1965   {
1966     llvm::raw_string_ostream typeNameStr(typeName);
1967     ModulePrinter subPrinter(typeNameStr, printerFlags, state);
1968     CustomDialectAsmPrinter printer(subPrinter);
1969     dialect.printType(type, printer);
1970   }
1971   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
1972 }
1973 
1974 //===----------------------------------------------------------------------===//
1975 // Affine expressions and maps
1976 //===----------------------------------------------------------------------===//
1977 
printAffineExpr(AffineExpr expr,function_ref<void (unsigned,bool)> printValueName)1978 void ModulePrinter::printAffineExpr(
1979     AffineExpr expr, function_ref<void(unsigned, bool)> printValueName) {
1980   printAffineExprInternal(expr, BindingStrength::Weak, printValueName);
1981 }
1982 
printAffineExprInternal(AffineExpr expr,BindingStrength enclosingTightness,function_ref<void (unsigned,bool)> printValueName)1983 void ModulePrinter::printAffineExprInternal(
1984     AffineExpr expr, BindingStrength enclosingTightness,
1985     function_ref<void(unsigned, bool)> printValueName) {
1986   const char *binopSpelling = nullptr;
1987   switch (expr.getKind()) {
1988   case AffineExprKind::SymbolId: {
1989     unsigned pos = expr.cast<AffineSymbolExpr>().getPosition();
1990     if (printValueName)
1991       printValueName(pos, /*isSymbol=*/true);
1992     else
1993       os << 's' << pos;
1994     return;
1995   }
1996   case AffineExprKind::DimId: {
1997     unsigned pos = expr.cast<AffineDimExpr>().getPosition();
1998     if (printValueName)
1999       printValueName(pos, /*isSymbol=*/false);
2000     else
2001       os << 'd' << pos;
2002     return;
2003   }
2004   case AffineExprKind::Constant:
2005     os << expr.cast<AffineConstantExpr>().getValue();
2006     return;
2007   case AffineExprKind::Add:
2008     binopSpelling = " + ";
2009     break;
2010   case AffineExprKind::Mul:
2011     binopSpelling = " * ";
2012     break;
2013   case AffineExprKind::FloorDiv:
2014     binopSpelling = " floordiv ";
2015     break;
2016   case AffineExprKind::CeilDiv:
2017     binopSpelling = " ceildiv ";
2018     break;
2019   case AffineExprKind::Mod:
2020     binopSpelling = " mod ";
2021     break;
2022   }
2023 
2024   auto binOp = expr.cast<AffineBinaryOpExpr>();
2025   AffineExpr lhsExpr = binOp.getLHS();
2026   AffineExpr rhsExpr = binOp.getRHS();
2027 
2028   // Handle tightly binding binary operators.
2029   if (binOp.getKind() != AffineExprKind::Add) {
2030     if (enclosingTightness == BindingStrength::Strong)
2031       os << '(';
2032 
2033     // Pretty print multiplication with -1.
2034     auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>();
2035     if (rhsConst && binOp.getKind() == AffineExprKind::Mul &&
2036         rhsConst.getValue() == -1) {
2037       os << "-";
2038       printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2039       if (enclosingTightness == BindingStrength::Strong)
2040         os << ')';
2041       return;
2042     }
2043 
2044     printAffineExprInternal(lhsExpr, BindingStrength::Strong, printValueName);
2045 
2046     os << binopSpelling;
2047     printAffineExprInternal(rhsExpr, BindingStrength::Strong, printValueName);
2048 
2049     if (enclosingTightness == BindingStrength::Strong)
2050       os << ')';
2051     return;
2052   }
2053 
2054   // Print out special "pretty" forms for add.
2055   if (enclosingTightness == BindingStrength::Strong)
2056     os << '(';
2057 
2058   // Pretty print addition to a product that has a negative operand as a
2059   // subtraction.
2060   if (auto rhs = rhsExpr.dyn_cast<AffineBinaryOpExpr>()) {
2061     if (rhs.getKind() == AffineExprKind::Mul) {
2062       AffineExpr rrhsExpr = rhs.getRHS();
2063       if (auto rrhs = rrhsExpr.dyn_cast<AffineConstantExpr>()) {
2064         if (rrhs.getValue() == -1) {
2065           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2066                                   printValueName);
2067           os << " - ";
2068           if (rhs.getLHS().getKind() == AffineExprKind::Add) {
2069             printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2070                                     printValueName);
2071           } else {
2072             printAffineExprInternal(rhs.getLHS(), BindingStrength::Weak,
2073                                     printValueName);
2074           }
2075 
2076           if (enclosingTightness == BindingStrength::Strong)
2077             os << ')';
2078           return;
2079         }
2080 
2081         if (rrhs.getValue() < -1) {
2082           printAffineExprInternal(lhsExpr, BindingStrength::Weak,
2083                                   printValueName);
2084           os << " - ";
2085           printAffineExprInternal(rhs.getLHS(), BindingStrength::Strong,
2086                                   printValueName);
2087           os << " * " << -rrhs.getValue();
2088           if (enclosingTightness == BindingStrength::Strong)
2089             os << ')';
2090           return;
2091         }
2092       }
2093     }
2094   }
2095 
2096   // Pretty print addition to a negative number as a subtraction.
2097   if (auto rhsConst = rhsExpr.dyn_cast<AffineConstantExpr>()) {
2098     if (rhsConst.getValue() < 0) {
2099       printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2100       os << " - " << -rhsConst.getValue();
2101       if (enclosingTightness == BindingStrength::Strong)
2102         os << ')';
2103       return;
2104     }
2105   }
2106 
2107   printAffineExprInternal(lhsExpr, BindingStrength::Weak, printValueName);
2108 
2109   os << " + ";
2110   printAffineExprInternal(rhsExpr, BindingStrength::Weak, printValueName);
2111 
2112   if (enclosingTightness == BindingStrength::Strong)
2113     os << ')';
2114 }
2115 
printAffineConstraint(AffineExpr expr,bool isEq)2116 void ModulePrinter::printAffineConstraint(AffineExpr expr, bool isEq) {
2117   printAffineExprInternal(expr, BindingStrength::Weak);
2118   isEq ? os << " == 0" : os << " >= 0";
2119 }
2120 
printAffineMap(AffineMap map)2121 void ModulePrinter::printAffineMap(AffineMap map) {
2122   // Dimension identifiers.
2123   os << '(';
2124   for (int i = 0; i < (int)map.getNumDims() - 1; ++i)
2125     os << 'd' << i << ", ";
2126   if (map.getNumDims() >= 1)
2127     os << 'd' << map.getNumDims() - 1;
2128   os << ')';
2129 
2130   // Symbolic identifiers.
2131   if (map.getNumSymbols() != 0) {
2132     os << '[';
2133     for (unsigned i = 0; i < map.getNumSymbols() - 1; ++i)
2134       os << 's' << i << ", ";
2135     if (map.getNumSymbols() >= 1)
2136       os << 's' << map.getNumSymbols() - 1;
2137     os << ']';
2138   }
2139 
2140   // Result affine expressions.
2141   os << " -> (";
2142   interleaveComma(map.getResults(),
2143                   [&](AffineExpr expr) { printAffineExpr(expr); });
2144   os << ')';
2145 }
2146 
printIntegerSet(IntegerSet set)2147 void ModulePrinter::printIntegerSet(IntegerSet set) {
2148   // Dimension identifiers.
2149   os << '(';
2150   for (unsigned i = 1; i < set.getNumDims(); ++i)
2151     os << 'd' << i - 1 << ", ";
2152   if (set.getNumDims() >= 1)
2153     os << 'd' << set.getNumDims() - 1;
2154   os << ')';
2155 
2156   // Symbolic identifiers.
2157   if (set.getNumSymbols() != 0) {
2158     os << '[';
2159     for (unsigned i = 0; i < set.getNumSymbols() - 1; ++i)
2160       os << 's' << i << ", ";
2161     if (set.getNumSymbols() >= 1)
2162       os << 's' << set.getNumSymbols() - 1;
2163     os << ']';
2164   }
2165 
2166   // Print constraints.
2167   os << " : (";
2168   int numConstraints = set.getNumConstraints();
2169   for (int i = 1; i < numConstraints; ++i) {
2170     printAffineConstraint(set.getConstraint(i - 1), set.isEq(i - 1));
2171     os << ", ";
2172   }
2173   if (numConstraints >= 1)
2174     printAffineConstraint(set.getConstraint(numConstraints - 1),
2175                           set.isEq(numConstraints - 1));
2176   os << ')';
2177 }
2178 
2179 //===----------------------------------------------------------------------===//
2180 // OperationPrinter
2181 //===----------------------------------------------------------------------===//
2182 
2183 namespace {
2184 /// This class contains the logic for printing operations, regions, and blocks.
2185 class OperationPrinter : public ModulePrinter, private OpAsmPrinter {
2186 public:
OperationPrinter(raw_ostream & os,OpPrintingFlags flags,AsmStateImpl & state)2187   explicit OperationPrinter(raw_ostream &os, OpPrintingFlags flags,
2188                             AsmStateImpl &state)
2189       : ModulePrinter(os, flags, &state) {}
2190 
2191   /// Print the given top-level operation.
2192   void printTopLevelOperation(Operation *op);
2193 
2194   /// Print the given operation with its indent and location.
2195   void print(Operation *op);
2196   /// Print the bare location, not including indentation/location/etc.
2197   void printOperation(Operation *op);
2198   /// Print the given operation in the generic form.
2199   void printGenericOp(Operation *op) override;
2200 
2201   /// Print the name of the given block.
2202   void printBlockName(Block *block);
2203 
2204   /// Print the given block. If 'printBlockArgs' is false, the arguments of the
2205   /// block are not printed. If 'printBlockTerminator' is false, the terminator
2206   /// operation of the block is not printed.
2207   void print(Block *block, bool printBlockArgs = true,
2208              bool printBlockTerminator = true);
2209 
2210   /// Print the ID of the given value, optionally with its result number.
2211   void printValueID(Value value, bool printResultNo = true,
2212                     raw_ostream *streamOverride = nullptr) const;
2213 
2214   //===--------------------------------------------------------------------===//
2215   // OpAsmPrinter methods
2216   //===--------------------------------------------------------------------===//
2217 
2218   /// Return the current stream of the printer.
getStream() const2219   raw_ostream &getStream() const override { return os; }
2220 
2221   /// Print the given type.
printType(Type type)2222   void printType(Type type) override { ModulePrinter::printType(type); }
2223 
2224   /// Print the given attribute.
printAttribute(Attribute attr)2225   void printAttribute(Attribute attr) override {
2226     ModulePrinter::printAttribute(attr);
2227   }
2228 
2229   /// Print the given attribute without its type. The corresponding parser must
2230   /// provide a valid type for the attribute.
printAttributeWithoutType(Attribute attr)2231   void printAttributeWithoutType(Attribute attr) override {
2232     ModulePrinter::printAttribute(attr, AttrTypeElision::Must);
2233   }
2234 
2235   /// Print the ID for the given value.
printOperand(Value value)2236   void printOperand(Value value) override { printValueID(value); }
printOperand(Value value,raw_ostream & os)2237   void printOperand(Value value, raw_ostream &os) override {
2238     printValueID(value, /*printResultNo=*/true, &os);
2239   }
2240 
2241   /// Print an optional attribute dictionary with a given set of elided values.
printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2242   void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
2243                              ArrayRef<StringRef> elidedAttrs = {}) override {
2244     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs);
2245   }
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,ArrayRef<StringRef> elidedAttrs={})2246   void printOptionalAttrDictWithKeyword(
2247       ArrayRef<NamedAttribute> attrs,
2248       ArrayRef<StringRef> elidedAttrs = {}) override {
2249     ModulePrinter::printOptionalAttrDict(attrs, elidedAttrs,
2250                                          /*withKeyword=*/true);
2251   }
2252 
2253   /// Print the given successor.
2254   void printSuccessor(Block *successor) override;
2255 
2256   /// Print an operation successor with the operands used for the block
2257   /// arguments.
2258   void printSuccessorAndUseList(Block *successor,
2259                                 ValueRange succOperands) override;
2260 
2261   /// Print the given region.
2262   void printRegion(Region &region, bool printEntryBlockArgs,
2263                    bool printBlockTerminators) override;
2264 
2265   /// Renumber the arguments for the specified region to the same names as the
2266   /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
2267   /// operations. If any entry in namesToUse is null, the corresponding
2268   /// argument name is left alone.
shadowRegionArgs(Region & region,ValueRange namesToUse)2269   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
2270     state->getSSANameState().shadowRegionArgs(region, namesToUse);
2271   }
2272 
2273   /// Print the given affine map with the symbol and dimension operands printed
2274   /// inline with the map.
2275   void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2276                               ValueRange operands) override;
2277 
2278   /// Print the given string as a symbol reference.
printSymbolName(StringRef symbolRef)2279   void printSymbolName(StringRef symbolRef) override {
2280     ::printSymbolReference(symbolRef, os);
2281   }
2282 
2283 private:
2284   /// The number of spaces used for indenting nested operations.
2285   const static unsigned indentWidth = 2;
2286 
2287   // This is the current indentation level for nested structures.
2288   unsigned currentIndent = 0;
2289 };
2290 } // end anonymous namespace
2291 
printTopLevelOperation(Operation * op)2292 void OperationPrinter::printTopLevelOperation(Operation *op) {
2293   // Output the aliases at the top level that can't be deferred.
2294   state->getAliasState().printNonDeferredAliases(os, newLine);
2295 
2296   // Print the module.
2297   print(op);
2298   os << newLine;
2299 
2300   // Output the aliases at the top level that can be deferred.
2301   state->getAliasState().printDeferredAliases(os, newLine);
2302 }
2303 
print(Operation * op)2304 void OperationPrinter::print(Operation *op) {
2305   // Track the location of this operation.
2306   state->registerOperationLocation(op, newLine.curLine, currentIndent);
2307 
2308   os.indent(currentIndent);
2309   printOperation(op);
2310   printTrailingLocation(op->getLoc());
2311 }
2312 
printOperation(Operation * op)2313 void OperationPrinter::printOperation(Operation *op) {
2314   if (size_t numResults = op->getNumResults()) {
2315     auto printResultGroup = [&](size_t resultNo, size_t resultCount) {
2316       printValueID(op->getResult(resultNo), /*printResultNo=*/false);
2317       if (resultCount > 1)
2318         os << ':' << resultCount;
2319     };
2320 
2321     // Check to see if this operation has multiple result groups.
2322     ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
2323     if (!resultGroups.empty()) {
2324       // Interleave the groups excluding the last one, this one will be handled
2325       // separately.
2326       interleaveComma(llvm::seq<int>(0, resultGroups.size() - 1), [&](int i) {
2327         printResultGroup(resultGroups[i],
2328                          resultGroups[i + 1] - resultGroups[i]);
2329       });
2330       os << ", ";
2331       printResultGroup(resultGroups.back(), numResults - resultGroups.back());
2332 
2333     } else {
2334       printResultGroup(/*resultNo=*/0, /*resultCount=*/numResults);
2335     }
2336 
2337     os << " = ";
2338   }
2339 
2340   // If requested, always print the generic form.
2341   if (!printerFlags.shouldPrintGenericOpForm()) {
2342     // Check to see if this is a known operation.  If so, use the registered
2343     // custom printer hook.
2344     if (auto *opInfo = op->getAbstractOperation()) {
2345       opInfo->printAssembly(op, *this);
2346       return;
2347     }
2348   }
2349 
2350   // Otherwise print with the generic assembly form.
2351   printGenericOp(op);
2352 }
2353 
printGenericOp(Operation * op)2354 void OperationPrinter::printGenericOp(Operation *op) {
2355   os << '"';
2356   printEscapedString(op->getName().getStringRef(), os);
2357   os << "\"(";
2358   interleaveComma(op->getOperands(), [&](Value value) { printValueID(value); });
2359   os << ')';
2360 
2361   // For terminators, print the list of successors and their operands.
2362   if (op->getNumSuccessors() != 0) {
2363     os << '[';
2364     interleaveComma(op->getSuccessors(),
2365                     [&](Block *successor) { printBlockName(successor); });
2366     os << ']';
2367   }
2368 
2369   // Print regions.
2370   if (op->getNumRegions() != 0) {
2371     os << " (";
2372     interleaveComma(op->getRegions(), [&](Region &region) {
2373       printRegion(region, /*printEntryBlockArgs=*/true,
2374                   /*printBlockTerminators=*/true);
2375     });
2376     os << ')';
2377   }
2378 
2379   auto attrs = op->getAttrs();
2380   printOptionalAttrDict(attrs);
2381 
2382   // Print the type signature of the operation.
2383   os << " : ";
2384   printFunctionalType(op);
2385 }
2386 
printBlockName(Block * block)2387 void OperationPrinter::printBlockName(Block *block) {
2388   auto id = state->getSSANameState().getBlockID(block);
2389   if (id != SSANameState::NameSentinel)
2390     os << "^bb" << id;
2391   else
2392     os << "^INVALIDBLOCK";
2393 }
2394 
print(Block * block,bool printBlockArgs,bool printBlockTerminator)2395 void OperationPrinter::print(Block *block, bool printBlockArgs,
2396                              bool printBlockTerminator) {
2397   // Print the block label and argument list if requested.
2398   if (printBlockArgs) {
2399     os.indent(currentIndent);
2400     printBlockName(block);
2401 
2402     // Print the argument list if non-empty.
2403     if (!block->args_empty()) {
2404       os << '(';
2405       interleaveComma(block->getArguments(), [&](BlockArgument arg) {
2406         printValueID(arg);
2407         os << ": ";
2408         printType(arg.getType());
2409       });
2410       os << ')';
2411     }
2412     os << ':';
2413 
2414     // Print out some context information about the predecessors of this block.
2415     if (!block->getParent()) {
2416       os << "  // block is not in a region!";
2417     } else if (block->hasNoPredecessors()) {
2418       os << "  // no predecessors";
2419     } else if (auto *pred = block->getSinglePredecessor()) {
2420       os << "  // pred: ";
2421       printBlockName(pred);
2422     } else {
2423       // We want to print the predecessors in increasing numeric order, not in
2424       // whatever order the use-list is in, so gather and sort them.
2425       SmallVector<std::pair<unsigned, Block *>, 4> predIDs;
2426       for (auto *pred : block->getPredecessors())
2427         predIDs.push_back({state->getSSANameState().getBlockID(pred), pred});
2428       llvm::array_pod_sort(predIDs.begin(), predIDs.end());
2429 
2430       os << "  // " << predIDs.size() << " preds: ";
2431 
2432       interleaveComma(predIDs, [&](std::pair<unsigned, Block *> pred) {
2433         printBlockName(pred.second);
2434       });
2435     }
2436     os << newLine;
2437   }
2438 
2439   currentIndent += indentWidth;
2440   auto range = llvm::make_range(
2441       block->begin(), std::prev(block->end(), printBlockTerminator ? 0 : 1));
2442   for (auto &op : range) {
2443     print(&op);
2444     os << newLine;
2445   }
2446   currentIndent -= indentWidth;
2447 }
2448 
printValueID(Value value,bool printResultNo,raw_ostream * streamOverride) const2449 void OperationPrinter::printValueID(Value value, bool printResultNo,
2450                                     raw_ostream *streamOverride) const {
2451   state->getSSANameState().printValueID(value, printResultNo,
2452                                         streamOverride ? *streamOverride : os);
2453 }
2454 
printSuccessor(Block * successor)2455 void OperationPrinter::printSuccessor(Block *successor) {
2456   printBlockName(successor);
2457 }
2458 
printSuccessorAndUseList(Block * successor,ValueRange succOperands)2459 void OperationPrinter::printSuccessorAndUseList(Block *successor,
2460                                                 ValueRange succOperands) {
2461   printBlockName(successor);
2462   if (succOperands.empty())
2463     return;
2464 
2465   os << '(';
2466   interleaveComma(succOperands,
2467                   [this](Value operand) { printValueID(operand); });
2468   os << " : ";
2469   interleaveComma(succOperands,
2470                   [this](Value operand) { printType(operand.getType()); });
2471   os << ')';
2472 }
2473 
printRegion(Region & region,bool printEntryBlockArgs,bool printBlockTerminators)2474 void OperationPrinter::printRegion(Region &region, bool printEntryBlockArgs,
2475                                    bool printBlockTerminators) {
2476   os << " {" << newLine;
2477   if (!region.empty()) {
2478     auto *entryBlock = &region.front();
2479     print(entryBlock, printEntryBlockArgs && entryBlock->getNumArguments() != 0,
2480           printBlockTerminators);
2481     for (auto &b : llvm::drop_begin(region.getBlocks(), 1))
2482       print(&b);
2483   }
2484   os.indent(currentIndent) << "}";
2485 }
2486 
printAffineMapOfSSAIds(AffineMapAttr mapAttr,ValueRange operands)2487 void OperationPrinter::printAffineMapOfSSAIds(AffineMapAttr mapAttr,
2488                                               ValueRange operands) {
2489   AffineMap map = mapAttr.getValue();
2490   unsigned numDims = map.getNumDims();
2491   auto printValueName = [&](unsigned pos, bool isSymbol) {
2492     unsigned index = isSymbol ? numDims + pos : pos;
2493     assert(index < operands.size());
2494     if (isSymbol)
2495       os << "symbol(";
2496     printValueID(operands[index]);
2497     if (isSymbol)
2498       os << ')';
2499   };
2500 
2501   interleaveComma(map.getResults(), [&](AffineExpr expr) {
2502     printAffineExpr(expr, printValueName);
2503   });
2504 }
2505 
2506 //===----------------------------------------------------------------------===//
2507 // print and dump methods
2508 //===----------------------------------------------------------------------===//
2509 
print(raw_ostream & os) const2510 void Attribute::print(raw_ostream &os) const {
2511   ModulePrinter(os).printAttribute(*this);
2512 }
2513 
dump() const2514 void Attribute::dump() const {
2515   print(llvm::errs());
2516   llvm::errs() << "\n";
2517 }
2518 
print(raw_ostream & os)2519 void Type::print(raw_ostream &os) { ModulePrinter(os).printType(*this); }
2520 
dump()2521 void Type::dump() { print(llvm::errs()); }
2522 
dump() const2523 void AffineMap::dump() const {
2524   print(llvm::errs());
2525   llvm::errs() << "\n";
2526 }
2527 
dump() const2528 void IntegerSet::dump() const {
2529   print(llvm::errs());
2530   llvm::errs() << "\n";
2531 }
2532 
print(raw_ostream & os) const2533 void AffineExpr::print(raw_ostream &os) const {
2534   if (!expr) {
2535     os << "<<NULL AFFINE EXPR>>";
2536     return;
2537   }
2538   ModulePrinter(os).printAffineExpr(*this);
2539 }
2540 
dump() const2541 void AffineExpr::dump() const {
2542   print(llvm::errs());
2543   llvm::errs() << "\n";
2544 }
2545 
print(raw_ostream & os) const2546 void AffineMap::print(raw_ostream &os) const {
2547   if (!map) {
2548     os << "<<NULL AFFINE MAP>>";
2549     return;
2550   }
2551   ModulePrinter(os).printAffineMap(*this);
2552 }
2553 
print(raw_ostream & os) const2554 void IntegerSet::print(raw_ostream &os) const {
2555   ModulePrinter(os).printIntegerSet(*this);
2556 }
2557 
print(raw_ostream & os)2558 void Value::print(raw_ostream &os) {
2559   if (auto *op = getDefiningOp())
2560     return op->print(os);
2561   // TODO: Improve this.
2562   BlockArgument arg = this->cast<BlockArgument>();
2563   os << "<block argument> of type '" << arg.getType()
2564      << "' at index: " << arg.getArgNumber() << '\n';
2565 }
print(raw_ostream & os,AsmState & state)2566 void Value::print(raw_ostream &os, AsmState &state) {
2567   if (auto *op = getDefiningOp())
2568     return op->print(os, state);
2569 
2570   // TODO: Improve this.
2571   BlockArgument arg = this->cast<BlockArgument>();
2572   os << "<block argument> of type '" << arg.getType()
2573      << "' at index: " << arg.getArgNumber() << '\n';
2574 }
2575 
dump()2576 void Value::dump() {
2577   print(llvm::errs());
2578   llvm::errs() << "\n";
2579 }
2580 
printAsOperand(raw_ostream & os,AsmState & state)2581 void Value::printAsOperand(raw_ostream &os, AsmState &state) {
2582   // TODO: This doesn't necessarily capture all potential cases.
2583   // Currently, region arguments can be shadowed when printing the main
2584   // operation. If the IR hasn't been printed, this will produce the old SSA
2585   // name and not the shadowed name.
2586   state.getImpl().getSSANameState().printValueID(*this, /*printResultNo=*/true,
2587                                                  os);
2588 }
2589 
print(raw_ostream & os,OpPrintingFlags flags)2590 void Operation::print(raw_ostream &os, OpPrintingFlags flags) {
2591   // If this is a top level operation, we also print aliases.
2592   if (!getParent() && !flags.shouldUseLocalScope()) {
2593     AsmState state(this);
2594     state.getImpl().initializeAliases(this, flags);
2595     print(os, state, flags);
2596     return;
2597   }
2598 
2599   // Find the operation to number from based upon the provided flags.
2600   Operation *printedOp = this;
2601   bool shouldUseLocalScope = flags.shouldUseLocalScope();
2602   do {
2603     // If we are printing local scope, stop at the first operation that is
2604     // isolated from above.
2605     if (shouldUseLocalScope && printedOp->isKnownIsolatedFromAbove())
2606       break;
2607 
2608     // Otherwise, traverse up to the next parent.
2609     Operation *parentOp = printedOp->getParentOp();
2610     if (!parentOp)
2611       break;
2612     printedOp = parentOp;
2613   } while (true);
2614 
2615   AsmState state(printedOp);
2616   print(os, state, flags);
2617 }
print(raw_ostream & os,AsmState & state,OpPrintingFlags flags)2618 void Operation::print(raw_ostream &os, AsmState &state, OpPrintingFlags flags) {
2619   OperationPrinter printer(os, flags, state.getImpl());
2620   if (!getParent() && !flags.shouldUseLocalScope())
2621     printer.printTopLevelOperation(this);
2622   else
2623     printer.print(this);
2624 }
2625 
dump()2626 void Operation::dump() {
2627   print(llvm::errs(), OpPrintingFlags().useLocalScope());
2628   llvm::errs() << "\n";
2629 }
2630 
print(raw_ostream & os)2631 void Block::print(raw_ostream &os) {
2632   Operation *parentOp = getParentOp();
2633   if (!parentOp) {
2634     os << "<<UNLINKED BLOCK>>\n";
2635     return;
2636   }
2637   // Get the top-level op.
2638   while (auto *nextOp = parentOp->getParentOp())
2639     parentOp = nextOp;
2640 
2641   AsmState state(parentOp);
2642   print(os, state);
2643 }
print(raw_ostream & os,AsmState & state)2644 void Block::print(raw_ostream &os, AsmState &state) {
2645   OperationPrinter(os, /*flags=*/llvm::None, state.getImpl()).print(this);
2646 }
2647 
dump()2648 void Block::dump() { print(llvm::errs()); }
2649 
2650 /// Print out the name of the block without printing its body.
printAsOperand(raw_ostream & os,bool printType)2651 void Block::printAsOperand(raw_ostream &os, bool printType) {
2652   Operation *parentOp = getParentOp();
2653   if (!parentOp) {
2654     os << "<<UNLINKED BLOCK>>\n";
2655     return;
2656   }
2657   AsmState state(parentOp);
2658   printAsOperand(os, state);
2659 }
printAsOperand(raw_ostream & os,AsmState & state)2660 void Block::printAsOperand(raw_ostream &os, AsmState &state) {
2661   OperationPrinter printer(os, /*flags=*/llvm::None, state.getImpl());
2662   printer.printBlockName(this);
2663 }
2664