//===-- FIRAttr.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9 #include "flang/Optimizer/Dialect/FIRAttr.h"
10 #include "flang/Optimizer/Dialect/FIRDialect.h"
11 #include "flang/Optimizer/Support/KindMapping.h"
12 #include "mlir/IR/AttributeSupport.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/Types.h"
15 #include "llvm/ADT/SmallString.h"
16
17 using namespace fir;
18
19 namespace fir {
20 namespace detail {
21
22 struct RealAttributeStorage : public mlir::AttributeStorage {
23 using KeyTy = std::pair<int, llvm::APFloat>;
24
RealAttributeStoragefir::detail::RealAttributeStorage25 RealAttributeStorage(int kind, const llvm::APFloat &value)
26 : kind(kind), value(value) {}
RealAttributeStoragefir::detail::RealAttributeStorage27 RealAttributeStorage(const KeyTy &key)
28 : RealAttributeStorage(key.first, key.second) {}
29
hashKeyfir::detail::RealAttributeStorage30 static unsigned hashKey(const KeyTy &key) { return llvm::hash_value(key); }
31
operator ==fir::detail::RealAttributeStorage32 bool operator==(const KeyTy &key) const {
33 return key.first == kind &&
34 key.second.compare(value) == llvm::APFloatBase::cmpEqual;
35 }
36
37 static RealAttributeStorage *
constructfir::detail::RealAttributeStorage38 construct(mlir::AttributeStorageAllocator &allocator, const KeyTy &key) {
39 return new (allocator.allocate<RealAttributeStorage>())
40 RealAttributeStorage(key);
41 }
42
getFKindfir::detail::RealAttributeStorage43 int getFKind() const { return kind; }
getValuefir::detail::RealAttributeStorage44 llvm::APFloat getValue() const { return value; }
45
46 private:
47 int kind;
48 llvm::APFloat value;
49 };
50
51 /// An attribute representing a reference to a type.
52 struct TypeAttributeStorage : public mlir::AttributeStorage {
53 using KeyTy = mlir::Type;
54
TypeAttributeStoragefir::detail::TypeAttributeStorage55 TypeAttributeStorage(mlir::Type value) : value(value) {
56 assert(value && "must not be of Type null");
57 }
58
59 /// Key equality function.
operator ==fir::detail::TypeAttributeStorage60 bool operator==(const KeyTy &key) const { return key == value; }
61
62 /// Construct a new storage instance.
63 static TypeAttributeStorage *
constructfir::detail::TypeAttributeStorage64 construct(mlir::AttributeStorageAllocator &allocator, KeyTy key) {
65 return new (allocator.allocate<TypeAttributeStorage>())
66 TypeAttributeStorage(key);
67 }
68
getTypefir::detail::TypeAttributeStorage69 mlir::Type getType() const { return value; }
70
71 private:
72 mlir::Type value;
73 };
74 } // namespace detail
75
get(mlir::Type value)76 ExactTypeAttr ExactTypeAttr::get(mlir::Type value) {
77 return Base::get(value.getContext(), value);
78 }
79
getType() const80 mlir::Type ExactTypeAttr::getType() const { return getImpl()->getType(); }
81
get(mlir::Type value)82 SubclassAttr SubclassAttr::get(mlir::Type value) {
83 return Base::get(value.getContext(), value);
84 }
85
getType() const86 mlir::Type SubclassAttr::getType() const { return getImpl()->getType(); }
87
88 using AttributeUniquer = mlir::detail::AttributeUniquer;
89
get(mlir::MLIRContext * ctxt)90 ClosedIntervalAttr ClosedIntervalAttr::get(mlir::MLIRContext *ctxt) {
91 return AttributeUniquer::get<ClosedIntervalAttr>(ctxt);
92 }
93
get(mlir::MLIRContext * ctxt)94 UpperBoundAttr UpperBoundAttr::get(mlir::MLIRContext *ctxt) {
95 return AttributeUniquer::get<UpperBoundAttr>(ctxt);
96 }
97
get(mlir::MLIRContext * ctxt)98 LowerBoundAttr LowerBoundAttr::get(mlir::MLIRContext *ctxt) {
99 return AttributeUniquer::get<LowerBoundAttr>(ctxt);
100 }
101
get(mlir::MLIRContext * ctxt)102 PointIntervalAttr PointIntervalAttr::get(mlir::MLIRContext *ctxt) {
103 return AttributeUniquer::get<PointIntervalAttr>(ctxt);
104 }
105
106 // RealAttr
107
get(mlir::MLIRContext * ctxt,const RealAttr::ValueType & key)108 RealAttr RealAttr::get(mlir::MLIRContext *ctxt,
109 const RealAttr::ValueType &key) {
110 return Base::get(ctxt, key);
111 }
112
getFKind() const113 int RealAttr::getFKind() const { return getImpl()->getFKind(); }
114
getValue() const115 llvm::APFloat RealAttr::getValue() const { return getImpl()->getValue(); }
116
117 // FIR attribute parsing
118
119 namespace {
parseFirRealAttr(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)120 mlir::Attribute parseFirRealAttr(FIROpsDialect *dialect,
121 mlir::DialectAsmParser &parser,
122 mlir::Type type) {
123 int kind = 0;
124 if (parser.parseLess() || parser.parseInteger(kind) || parser.parseComma()) {
125 parser.emitError(parser.getNameLoc(), "expected '<' kind ','");
126 return {};
127 }
128 KindMapping kindMap(dialect->getContext());
129 llvm::APFloat value(0.);
130 if (parser.parseOptionalKeyword("i")) {
131 // `i` not present, so literal float must be present
132 double dontCare;
133 if (parser.parseFloat(dontCare) || parser.parseGreater()) {
134 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
135 return {};
136 }
137 auto fltStr = parser.getFullSymbolSpec()
138 .drop_until([](char c) { return c == ','; })
139 .drop_front()
140 .drop_while([](char c) { return c == ' ' || c == '\t'; })
141 .take_until([](char c) {
142 return c == '>' || c == ' ' || c == '\t';
143 });
144 value = llvm::APFloat(kindMap.getFloatSemantics(kind), fltStr);
145 } else {
146 // `i` is present, so literal bitstring (hex) must be present
147 llvm::StringRef hex;
148 if (parser.parseKeyword(&hex) || parser.parseGreater()) {
149 parser.emitError(parser.getNameLoc(), "expected real constant '>'");
150 return {};
151 }
152 auto bits = llvm::APInt(kind * 8, hex.drop_front(), 16);
153 value = llvm::APFloat(kindMap.getFloatSemantics(kind), bits);
154 }
155 return RealAttr::get(dialect->getContext(), {kind, value});
156 }
157 } // namespace
158
parseFirAttribute(FIROpsDialect * dialect,mlir::DialectAsmParser & parser,mlir::Type type)159 mlir::Attribute parseFirAttribute(FIROpsDialect *dialect,
160 mlir::DialectAsmParser &parser,
161 mlir::Type type) {
162 auto loc = parser.getNameLoc();
163 llvm::StringRef attrName;
164 if (parser.parseKeyword(&attrName)) {
165 parser.emitError(loc, "expected an attribute name");
166 return {};
167 }
168
169 if (attrName == ExactTypeAttr::getAttrName()) {
170 mlir::Type type;
171 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
172 parser.emitError(loc, "expected a type");
173 return {};
174 }
175 return ExactTypeAttr::get(type);
176 }
177 if (attrName == SubclassAttr::getAttrName()) {
178 mlir::Type type;
179 if (parser.parseLess() || parser.parseType(type) || parser.parseGreater()) {
180 parser.emitError(loc, "expected a subtype");
181 return {};
182 }
183 return SubclassAttr::get(type);
184 }
185 if (attrName == PointIntervalAttr::getAttrName())
186 return PointIntervalAttr::get(dialect->getContext());
187 if (attrName == LowerBoundAttr::getAttrName())
188 return LowerBoundAttr::get(dialect->getContext());
189 if (attrName == UpperBoundAttr::getAttrName())
190 return UpperBoundAttr::get(dialect->getContext());
191 if (attrName == ClosedIntervalAttr::getAttrName())
192 return ClosedIntervalAttr::get(dialect->getContext());
193 if (attrName == RealAttr::getAttrName())
194 return parseFirRealAttr(dialect, parser, type);
195
196 parser.emitError(loc, "unknown FIR attribute: ") << attrName;
197 return {};
198 }
199
200 // FIR attribute pretty printer
201
printFirAttribute(FIROpsDialect * dialect,mlir::Attribute attr,mlir::DialectAsmPrinter & p)202 void printFirAttribute(FIROpsDialect *dialect, mlir::Attribute attr,
203 mlir::DialectAsmPrinter &p) {
204 auto &os = p.getStream();
205 if (auto exact = attr.dyn_cast<fir::ExactTypeAttr>()) {
206 os << fir::ExactTypeAttr::getAttrName() << '<';
207 p.printType(exact.getType());
208 os << '>';
209 } else if (auto sub = attr.dyn_cast<fir::SubclassAttr>()) {
210 os << fir::SubclassAttr::getAttrName() << '<';
211 p.printType(sub.getType());
212 os << '>';
213 } else if (attr.dyn_cast_or_null<fir::PointIntervalAttr>()) {
214 os << fir::PointIntervalAttr::getAttrName();
215 } else if (attr.dyn_cast_or_null<fir::ClosedIntervalAttr>()) {
216 os << fir::ClosedIntervalAttr::getAttrName();
217 } else if (attr.dyn_cast_or_null<fir::LowerBoundAttr>()) {
218 os << fir::LowerBoundAttr::getAttrName();
219 } else if (attr.dyn_cast_or_null<fir::UpperBoundAttr>()) {
220 os << fir::UpperBoundAttr::getAttrName();
221 } else if (auto a = attr.dyn_cast_or_null<fir::RealAttr>()) {
222 os << fir::RealAttr::getAttrName() << '<' << a.getFKind() << ", i x";
223 llvm::SmallString<40> ss;
224 a.getValue().bitcastToAPInt().toStringUnsigned(ss, 16);
225 os << ss << '>';
226 } else {
227 llvm_unreachable("attribute pretty-printer is not implemented");
228 }
229 }
230
231 } // namespace fir
232