• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===-- FIRAttr.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "flang/Optimizer/Dialect/FIRAttr.h"
#include "flang/Optimizer/Dialect/FIRDialect.h"
#include "flang/Optimizer/Support/KindMapping.h"
#include "mlir/IR/AttributeSupport.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Support/Error.h"
16 
17 using namespace fir;
18 
19 namespace fir {
20 namespace detail {
21 
22 struct RealAttributeStorage : public mlir::AttributeStorage {
23   using KeyTy = std::pair<int, llvm::APFloat>;
24 
RealAttributeStoragefir::detail::RealAttributeStorage25   RealAttributeStorage(int kind, const llvm::APFloat &value)
26       : kind(kind), value(value) {}
RealAttributeStoragefir::detail::RealAttributeStorage27   RealAttributeStorage(const KeyTy &key)
28       : RealAttributeStorage(key.first, key.second) {}
29 
hashKeyfir::detail::RealAttributeStorage30   static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
31 
operator ==fir::detail::RealAttributeStorage32   bool operator==(const KeyTy &key) const {
33     return key.first == kind &&
34            key.second.compare(value) == llvm::APFloatBase::cmpEqual;
35   }
36 
37   static RealAttributeStorage *
constructfir::detail::RealAttributeStorage38   construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
39     return new (allocator.allocate<RealAttributeStorage>())
40         RealAttributeStorage(key);
41   }
42 
getFKindfir::detail::RealAttributeStorage43   int getFKind() const { return kind; }
getValuefir::detail::RealAttributeStorage44   llvm::APFloat getValue() const { return value; }
45 
46 private:
47   int kind;
48   llvm::APFloat value;
49 };
50 
51 /// An attribute representing a reference to a type.
52 struct TypeAttributeStorage : public mlir::AttributeStorage {
53   using KeyTy = mlir::Type;
54 
TypeAttributeStoragefir::detail::TypeAttributeStorage55   TypeAttributeStorage(mlir::Type value) : value(value) {
56     assert(value && "must not be of Type null");
57   }
58 
59   /// Key equality function.
operator ==fir::detail::TypeAttributeStorage60   bool operator==(const KeyTy &key) const { return key == value; }
61 
62   /// Construct a new storage instance.
63   static TypeAttributeStorage *
constructfir::detail::TypeAttributeStorage64   construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
65     return new (allocator.allocate<TypeAttributeStorage>())
66         TypeAttributeStorage(key);
67   }
68 
getTypefir::detail::TypeAttributeStorage69   mlir::Type getType() const { return value; }
70 
71 private:
72   mlir::Type value;
73 };
74 } // namespace detail
75 
get(mlir::Type value)76 ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
77   return Base::get(value.getContext(), value);
78 }
79 
getType() const80 mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
81 
get(mlir::Type value)82 SubclassAttr SubclassAttr::get(mlir::Type value) {
83   return Base::get(value.getContext(), value);
84 }
85 
getType() const86 mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
87 
88 using AttributeUniquer = mlir::detail::AttributeUniquer;
89 
get(mlir::MLIRContext * ctxt)90 ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
91   return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
92 }
93 
get(mlir::MLIRContext * ctxt)94 UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
95   return AttributeUniquer::get<UpperBoundAttr>(ctxt);
96 }
97 
get(mlir::MLIRContext * ctxt)98 LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
99   return AttributeUniquer::get<LowerBoundAttr>(ctxt);
100 }
101 
get(mlir::MLIRContext * ctxt)102 PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
103   return AttributeUniquer::get<PointIntervalAttr>(ctxt);
104 }
105 
106 // RealAttr
107 
get(mlir::MLIRContext * ctxt,const RealAttr::ValueType & key)108 RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
109                        const RealAttr::ValueType &key) {
110   return Base::get(ctxt, key);
111 }
112 
getFKind() const113 int RealAttr::getFKind() const { return getImpl()->getFKind(); }
114 
getValue() const115 llvm::APFloat RealAttr::getValue() const { return getImpl()->getValue(); }
116 
117 // FIR attribute parsing
118 
119 namespace {
parseFirRealAttr(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)120 mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
121                                  mlir::DialectAsmParser &parser,
122                                  mlir::Type type) {
123   int kind = 0;
124   if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
125     parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
126     return {};
127   }
128   KindMapping kindMap(dialect->getContext());
129   llvm::APFloat value(0.);
130   if (parser.parseOptionalKeyword("i")) {
131     // `i` not present, so literal float must be present
132     double dontCare;
133     if (parser.parseFloat(dontCare) || parser.parseGreater()) {
134       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
135       return {};
136     }
137     auto fltStr = parser.getFullSymbolSpec()
138                       .drop_until([](char c) { return c == ','; })
139                       .drop_front()
140                       .drop_while([](char c) { return c == ' ' || c == '\t'; })
141                       .take_until([](char c) {
142                         return c == '>' || c == ' ' || c == '\t';
143                       });
144     value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
145   } else {
146     // `i` is present, so literal bitstring (hex) must be present
147     llvm::StringRef hex;
148     if (parser.parseKeyword(&hex) || parser.parseGreater()) {
149       parser.emitError(parser.getNameLoc(), "expected real constant '>'");
150       return {};
151     }
152     auto bits = llvm::APInt(kind * 8, hex.drop_front(), 16);
153     value = llvm::APFloat(kindMap.getFloatSemantics(kind), bits);
154   }
155   return RealAttr::get(dialect->getContext(), {kind, value});
156 }
157 } // namespace
158 
parseFirAttribute(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)159 mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
160                                   mlir::DialectAsmParser &parser,
161                                   mlir::Type type) {
162   auto loc = parser.getNameLoc();
163   llvm::StringRef attrName;
164   if (parser.parseKeyword(&attrName)) {
165     parser.emitError(loc, "expected an attribute name");
166     return {};
167   }
168 
169   if (attrName == ExactTypeAttr::getAttrName()) {
170     mlir::Type type;
171     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
172       parser.emitError(loc, "expected a type");
173       return {};
174     }
175     return ExactTypeAttr::get(type);
176   }
177   if (attrName == SubclassAttr::getAttrName()) {
178     mlir::Type type;
179     if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
180       parser.emitError(loc, "expected a subtype");
181       return {};
182     }
183     return SubclassAttr::get(type);
184   }
185   if (attrName == PointIntervalAttr::getAttrName())
186     return PointIntervalAttr::get(dialect->getContext());
187   if (attrName == LowerBoundAttr::getAttrName())
188     return LowerBoundAttr::get(dialect->getContext());
189   if (attrName == UpperBoundAttr::getAttrName())
190     return UpperBoundAttr::get(dialect->getContext());
191   if (attrName == ClosedIntervalAttr::getAttrName())
192     return ClosedIntervalAttr::get(dialect->getContext());
193   if (attrName == RealAttr::getAttrName())
194     return parseFirRealAttr(dialect, parser, type);
195 
196   parser.emitError(loc, "unknown FIR attribute: ") << attrName;
197   return {};
198 }
199 
200 // FIR attribute pretty printer
201 
printFirAttribute(FIROpsDialect * dialect,mlir::Attribute attr,mlir::DialectAsmPrinter & p)202 void printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
203                        mlir::DialectAsmPrinter &p) {
204   auto &os = p.getStream();
205   if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
206     os << fir::ExactTypeAttr::getAttrName() << '<';
207     p.printType(exact.getType());
208     os << '>';
209   } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
210     os << fir::SubclassAttr::getAttrName() << '<';
211     p.printType(sub.getType());
212     os << '>';
213   } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
214     os << fir::PointIntervalAttr::getAttrName();
215   } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
216     os << fir::ClosedIntervalAttr::getAttrName();
217   } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
218     os << fir::LowerBoundAttr::getAttrName();
219   } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
220     os << fir::UpperBoundAttr::getAttrName();
221   } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
222     os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
223     llvm::SmallString<40> ss;
224     a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
225     os << ss << '>';
226   } else {
227     llvm_unreachable("attribute pretty-printer is not implemented");
228   }
229 }
230 
231 } // namespace fir
232