1 //===- IRModules.cpp - IR Submodules of pybind module ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "IRModules.h"
10
11 #include "Globals.h"
12 #include "PybindUtils.h"
13
14 #include "mlir-c/Bindings/Python/Interop.h"
15 #include "mlir-c/BuiltinAttributes.h"
16 #include "mlir-c/BuiltinTypes.h"
17 #include "mlir-c/Registration.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include <pybind11/stl.h>
20
21 namespace py = pybind11;
22 using namespace mlir;
23 using namespace mlir::python;
24
25 using llvm::SmallVector;
26 using llvm::StringRef;
27 using llvm::Twine;
28
29 //------------------------------------------------------------------------------
30 // Docstrings (trivial, non-duplicated docstrings are included inline).
31 //------------------------------------------------------------------------------
32
// Docstring describing how a type is parsed from its textual assembly form.
static const char kContextParseTypeDocstring[] =
    R"(Parses the assembly form of a type.

Returns a Type object or raises a ValueError if the type cannot be parsed.

See also: https://mlir.llvm.org/docs/LangRef/#type-system
)";

// Docstring for building a Location from a file name, line and column.
static const char kContextGetFileLocationDocstring[] =
    R"(Gets a Location representing a file, line and column)";

// Docstring describing how a whole module is parsed from its textual assembly
// form.
static const char kModuleParseDocstring[] =
    R"(Parses a module's assembly format from a string.

Returns a new MlirModule or raises a ValueError if the parsing fails.

See also: https://mlir.llvm.org/docs/LangRef/
)";
51
// Docstring for the generic operation factory (presumably bound to
// PyOperation::create). The Args section should list every keyword that
// factory accepts, including `operands`.
static const char kOperationCreateDocstring[] =
    R"(Creates a new operation.

Args:
  name: Operation name (e.g. "dialect.operation").
  results: Sequence of Type representing op result types.
  operands: Sequence of Value representing op operands.
  attributes: Dict of str:Attribute.
  successors: List of Block for the operation's successors.
  regions: Number of regions to create.
  location: A Location object (defaults to resolve from context manager).
  ip: An InsertionPoint (defaults to resolve from context manager or set to
    False to disable insertion, even with an insertion point set in the
    context manager).
Returns:
  A new "detached" Operation object. Detached operations can be added
  to blocks, which causes them to become "attached."
)";
69
// Docstring for printing an operation to a file-like object. Keyword names
// mirror the parameters of PyOperationBase::print in snake_case.
static const char kOperationPrintDocstring[] =
    R"(Prints the assembly form of the operation to a file like object.

Args:
  file: The file like object to write to. Defaults to sys.stdout.
  binary: Whether to write bytes (True) or str (False). Defaults to False.
  large_elements_limit: Elide elements attributes whose number of elements
    exceeds this limit. Defaults to None (no limit).
  enable_debug_info: Whether to print debug/location information. Defaults
    to False.
  pretty_debug_info: Whether to format debug information for easier reading
    by a human (warning: the result is unparseable).
  print_generic_op_form: Whether to print the generic assembly forms of all
    ops. Defaults to False.
  use_local_scope: Whether to print in a way that is more optimized for
    multi-threaded access but may not be consistent with how the overall
    module prints.
)";
88
// Docstring for returning an operation's assembly form as a bytes/str value
// instead of writing it to a stream.
static const char kOperationGetAsmDocstring[] =
    R"(Gets the assembly form of the operation with all options available.

Args:
  binary: Whether to return a bytes (True) or str (False) object. Defaults to
    False.
  ... others ...: See the print() method for common keyword arguments for
    configuring the printout.
Returns:
  Either a bytes or str object, depending on the setting of the 'binary'
  argument.
)";

// Docstring for the default-options string conversion of an operation.
static const char kOperationStrDunderDocstring[] =
    R"(Gets the assembly form of the operation with default options.

If more advanced control over the assembly formatting or I/O options is needed,
use the dedicated print or get_asm method, which supports keyword arguments to
customize behavior.
)";

// Shared docstring for the various dump() debugging helpers.
static const char kDumpDocstring[] =
    R"(Dumps a debug representation of the object to stderr.)";

// Docstring for BlockList.append (see PyBlockList::bind).
static const char kAppendBlockDocstring[] =
    R"(Appends a new block, with argument types as positional args.

Returns:
  The created block.
)";

// Docstring for the string conversion of a Value.
static const char kValueDunderStrDocstring[] =
    R"(Returns the string form of the value.

If the value is a block argument, this is the assembly form of its type and the
position in the argument list. If the value is an operation result, this is
equivalent to printing the operation that produced it.
)";
127
128 //------------------------------------------------------------------------------
129 // Utilities.
130 //------------------------------------------------------------------------------
131
132 /// Checks whether the given type is an integer or float type.
mlirTypeIsAIntegerOrFloat(MlirType type)133 static int mlirTypeIsAIntegerOrFloat(MlirType type) {
134 return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
135 mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
136 }
137
138 static py::object
createCustomDialectWrapper(const std::string & dialectNamespace,py::object dialectDescriptor)139 createCustomDialectWrapper(const std::string &dialectNamespace,
140 py::object dialectDescriptor) {
141 auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace);
142 if (!dialectClass) {
143 // Use the base class.
144 return py::cast(PyDialect(std::move(dialectDescriptor)));
145 }
146
147 // Create the custom implementation.
148 return (*dialectClass)(std::move(dialectDescriptor));
149 }
150
toMlirStringRef(const std::string & s)151 static MlirStringRef toMlirStringRef(const std::string &s) {
152 return mlirStringRefCreate(s.data(), s.size());
153 }
154
155 //------------------------------------------------------------------------------
156 // Collections.
157 //------------------------------------------------------------------------------
158
159 namespace {
160
161 class PyRegionIterator {
162 public:
PyRegionIterator(PyOperationRef operation)163 PyRegionIterator(PyOperationRef operation)
164 : operation(std::move(operation)) {}
165
dunderIter()166 PyRegionIterator &dunderIter() { return *this; }
167
dunderNext()168 PyRegion dunderNext() {
169 operation->checkValid();
170 if (nextIndex >= mlirOperationGetNumRegions(operation->get())) {
171 throw py::stop_iteration();
172 }
173 MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++);
174 return PyRegion(operation, region);
175 }
176
bind(py::module & m)177 static void bind(py::module &m) {
178 py::class_<PyRegionIterator>(m, "RegionIterator")
179 .def("__iter__", &PyRegionIterator::dunderIter)
180 .def("__next__", &PyRegionIterator::dunderNext);
181 }
182
183 private:
184 PyOperationRef operation;
185 int nextIndex = 0;
186 };
187
188 /// Regions of an op are fixed length and indexed numerically so are represented
189 /// with a sequence-like container.
190 class PyRegionList {
191 public:
PyRegionList(PyOperationRef operation)192 PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {}
193
dunderLen()194 intptr_t dunderLen() {
195 operation->checkValid();
196 return mlirOperationGetNumRegions(operation->get());
197 }
198
dunderGetItem(intptr_t index)199 PyRegion dunderGetItem(intptr_t index) {
200 // dunderLen checks validity.
201 if (index < 0 || index >= dunderLen()) {
202 throw SetPyError(PyExc_IndexError,
203 "attempt to access out of bounds region");
204 }
205 MlirRegion region = mlirOperationGetRegion(operation->get(), index);
206 return PyRegion(operation, region);
207 }
208
bind(py::module & m)209 static void bind(py::module &m) {
210 py::class_<PyRegionList>(m, "ReqionSequence")
211 .def("__len__", &PyRegionList::dunderLen)
212 .def("__getitem__", &PyRegionList::dunderGetItem);
213 }
214
215 private:
216 PyOperationRef operation;
217 };
218
219 class PyBlockIterator {
220 public:
PyBlockIterator(PyOperationRef operation,MlirBlock next)221 PyBlockIterator(PyOperationRef operation, MlirBlock next)
222 : operation(std::move(operation)), next(next) {}
223
dunderIter()224 PyBlockIterator &dunderIter() { return *this; }
225
dunderNext()226 PyBlock dunderNext() {
227 operation->checkValid();
228 if (mlirBlockIsNull(next)) {
229 throw py::stop_iteration();
230 }
231
232 PyBlock returnBlock(operation, next);
233 next = mlirBlockGetNextInRegion(next);
234 return returnBlock;
235 }
236
bind(py::module & m)237 static void bind(py::module &m) {
238 py::class_<PyBlockIterator>(m, "BlockIterator")
239 .def("__iter__", &PyBlockIterator::dunderIter)
240 .def("__next__", &PyBlockIterator::dunderNext);
241 }
242
243 private:
244 PyOperationRef operation;
245 MlirBlock next;
246 };
247
248 /// Blocks are exposed by the C-API as a forward-only linked list. In Python,
249 /// we present them as a more full-featured list-like container but optimize
250 /// it for forward iteration. Blocks are always owned by a region.
251 class PyBlockList {
252 public:
PyBlockList(PyOperationRef operation,MlirRegion region)253 PyBlockList(PyOperationRef operation, MlirRegion region)
254 : operation(std::move(operation)), region(region) {}
255
dunderIter()256 PyBlockIterator dunderIter() {
257 operation->checkValid();
258 return PyBlockIterator(operation, mlirRegionGetFirstBlock(region));
259 }
260
dunderLen()261 intptr_t dunderLen() {
262 operation->checkValid();
263 intptr_t count = 0;
264 MlirBlock block = mlirRegionGetFirstBlock(region);
265 while (!mlirBlockIsNull(block)) {
266 count += 1;
267 block = mlirBlockGetNextInRegion(block);
268 }
269 return count;
270 }
271
dunderGetItem(intptr_t index)272 PyBlock dunderGetItem(intptr_t index) {
273 operation->checkValid();
274 if (index < 0) {
275 throw SetPyError(PyExc_IndexError,
276 "attempt to access out of bounds block");
277 }
278 MlirBlock block = mlirRegionGetFirstBlock(region);
279 while (!mlirBlockIsNull(block)) {
280 if (index == 0) {
281 return PyBlock(operation, block);
282 }
283 block = mlirBlockGetNextInRegion(block);
284 index -= 1;
285 }
286 throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block");
287 }
288
appendBlock(py::args pyArgTypes)289 PyBlock appendBlock(py::args pyArgTypes) {
290 operation->checkValid();
291 llvm::SmallVector<MlirType, 4> argTypes;
292 argTypes.reserve(pyArgTypes.size());
293 for (auto &pyArg : pyArgTypes) {
294 argTypes.push_back(pyArg.cast<PyType &>());
295 }
296
297 MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data());
298 mlirRegionAppendOwnedBlock(region, block);
299 return PyBlock(operation, block);
300 }
301
bind(py::module & m)302 static void bind(py::module &m) {
303 py::class_<PyBlockList>(m, "BlockList")
304 .def("__getitem__", &PyBlockList::dunderGetItem)
305 .def("__iter__", &PyBlockList::dunderIter)
306 .def("__len__", &PyBlockList::dunderLen)
307 .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring);
308 }
309
310 private:
311 PyOperationRef operation;
312 MlirRegion region;
313 };
314
315 class PyOperationIterator {
316 public:
PyOperationIterator(PyOperationRef parentOperation,MlirOperation next)317 PyOperationIterator(PyOperationRef parentOperation, MlirOperation next)
318 : parentOperation(std::move(parentOperation)), next(next) {}
319
dunderIter()320 PyOperationIterator &dunderIter() { return *this; }
321
dunderNext()322 py::object dunderNext() {
323 parentOperation->checkValid();
324 if (mlirOperationIsNull(next)) {
325 throw py::stop_iteration();
326 }
327
328 PyOperationRef returnOperation =
329 PyOperation::forOperation(parentOperation->getContext(), next);
330 next = mlirOperationGetNextInBlock(next);
331 return returnOperation->createOpView();
332 }
333
bind(py::module & m)334 static void bind(py::module &m) {
335 py::class_<PyOperationIterator>(m, "OperationIterator")
336 .def("__iter__", &PyOperationIterator::dunderIter)
337 .def("__next__", &PyOperationIterator::dunderNext);
338 }
339
340 private:
341 PyOperationRef parentOperation;
342 MlirOperation next;
343 };
344
345 /// Operations are exposed by the C-API as a forward-only linked list. In
346 /// Python, we present them as a more full-featured list-like container but
347 /// optimize it for forward iteration. Iterable operations are always owned
348 /// by a block.
349 class PyOperationList {
350 public:
PyOperationList(PyOperationRef parentOperation,MlirBlock block)351 PyOperationList(PyOperationRef parentOperation, MlirBlock block)
352 : parentOperation(std::move(parentOperation)), block(block) {}
353
dunderIter()354 PyOperationIterator dunderIter() {
355 parentOperation->checkValid();
356 return PyOperationIterator(parentOperation,
357 mlirBlockGetFirstOperation(block));
358 }
359
dunderLen()360 intptr_t dunderLen() {
361 parentOperation->checkValid();
362 intptr_t count = 0;
363 MlirOperation childOp = mlirBlockGetFirstOperation(block);
364 while (!mlirOperationIsNull(childOp)) {
365 count += 1;
366 childOp = mlirOperationGetNextInBlock(childOp);
367 }
368 return count;
369 }
370
dunderGetItem(intptr_t index)371 py::object dunderGetItem(intptr_t index) {
372 parentOperation->checkValid();
373 if (index < 0) {
374 throw SetPyError(PyExc_IndexError,
375 "attempt to access out of bounds operation");
376 }
377 MlirOperation childOp = mlirBlockGetFirstOperation(block);
378 while (!mlirOperationIsNull(childOp)) {
379 if (index == 0) {
380 return PyOperation::forOperation(parentOperation->getContext(), childOp)
381 ->createOpView();
382 }
383 childOp = mlirOperationGetNextInBlock(childOp);
384 index -= 1;
385 }
386 throw SetPyError(PyExc_IndexError,
387 "attempt to access out of bounds operation");
388 }
389
bind(py::module & m)390 static void bind(py::module &m) {
391 py::class_<PyOperationList>(m, "OperationList")
392 .def("__getitem__", &PyOperationList::dunderGetItem)
393 .def("__iter__", &PyOperationList::dunderIter)
394 .def("__len__", &PyOperationList::dunderLen);
395 }
396
397 private:
398 PyOperationRef parentOperation;
399 MlirBlock block;
400 };
401
402 } // namespace
403
404 //------------------------------------------------------------------------------
405 // PyMlirContext
406 //------------------------------------------------------------------------------
407
PyMlirContext(MlirContext context)408 PyMlirContext::PyMlirContext(MlirContext context) : context(context) {
409 py::gil_scoped_acquire acquire;
410 auto &liveContexts = getLiveContexts();
411 liveContexts[context.ptr] = this;
412 }
413
~PyMlirContext()414 PyMlirContext::~PyMlirContext() {
415 // Note that the only public way to construct an instance is via the
416 // forContext method, which always puts the associated handle into
417 // liveContexts.
418 py::gil_scoped_acquire acquire;
419 getLiveContexts().erase(context.ptr);
420 mlirContextDestroy(context);
421 }
422
getCapsule()423 py::object PyMlirContext::getCapsule() {
424 return py::reinterpret_steal<py::object>(mlirPythonContextToCapsule(get()));
425 }
426
createFromCapsule(py::object capsule)427 py::object PyMlirContext::createFromCapsule(py::object capsule) {
428 MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr());
429 if (mlirContextIsNull(rawContext))
430 throw py::error_already_set();
431 return forContext(rawContext).releaseObject();
432 }
433
createNewContextForInit()434 PyMlirContext *PyMlirContext::createNewContextForInit() {
435 MlirContext context = mlirContextCreate();
436 mlirRegisterAllDialects(context);
437 return new PyMlirContext(context);
438 }
439
/// Returns a reference to the Python wrapper registered for `context`. If the
/// live-context map has no entry for it yet, a new wrapper is created.
/// Lookup and creation both happen with the GIL held.
PyMlirContextRef PyMlirContext::forContext(MlirContext context) {
  py::gil_scoped_acquire acquire;
  auto &liveContexts = getLiveContexts();
  auto it = liveContexts.find(context.ptr);
  if (it == liveContexts.end()) {
    // Create.
    // NOTE(review): the PyMlirContext constructor already inserts itself into
    // liveContexts, so the explicit insertion below appears redundant.
    PyMlirContext *unownedContextWrapper = new PyMlirContext(context);
    // NOTE(review): py::cast on a raw pointer with the default
    // (automatic_reference) policy maps to `reference`, i.e. Python does not
    // take ownership. Nothing visible here frees this wrapper (or the context
    // its destructor destroys). Confirm that this is intended.
    py::object pyRef = py::cast(unownedContextWrapper);
    assert(pyRef && "cast to py::object failed");
    liveContexts[context.ptr] = unownedContextWrapper;
    return PyMlirContextRef(unownedContextWrapper, std::move(pyRef));
  }
  // Use existing.
  py::object pyRef = py::cast(it->second);
  return PyMlirContextRef(it->second, std::move(pyRef));
}
456
getLiveContexts()457 PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() {
458 static LiveContextMap liveContexts;
459 return liveContexts;
460 }
461
getLiveCount()462 size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); }
463
getLiveOperationCount()464 size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); }
465
getLiveModuleCount()466 size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }
467
contextEnter()468 pybind11::object PyMlirContext::contextEnter() {
469 return PyThreadContextEntry::pushContext(*this);
470 }
471
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)472 void PyMlirContext::contextExit(pybind11::object excType,
473 pybind11::object excVal,
474 pybind11::object excTb) {
475 PyThreadContextEntry::popContext(*this);
476 }
477
resolve()478 PyMlirContext &DefaultingPyMlirContext::resolve() {
479 PyMlirContext *context = PyThreadContextEntry::getDefaultContext();
480 if (!context) {
481 throw SetPyError(
482 PyExc_RuntimeError,
483 "An MLIR function requires a Context but none was provided in the call "
484 "or from the surrounding environment. Either pass to the function with "
485 "a 'context=' argument or establish a default using 'with Context():'");
486 }
487 return *context;
488 }
489
490 //------------------------------------------------------------------------------
491 // PyThreadContextEntry management
492 //------------------------------------------------------------------------------
493
getStack()494 std::vector<PyThreadContextEntry> &PyThreadContextEntry::getStack() {
495 static thread_local std::vector<PyThreadContextEntry> stack;
496 return stack;
497 }
498
getTopOfStack()499 PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() {
500 auto &stack = getStack();
501 if (stack.empty())
502 return nullptr;
503 return &stack.back();
504 }
505
push(FrameKind frameKind,py::object context,py::object insertionPoint,py::object location)506 void PyThreadContextEntry::push(FrameKind frameKind, py::object context,
507 py::object insertionPoint,
508 py::object location) {
509 auto &stack = getStack();
510 stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint),
511 std::move(location));
512 // If the new stack has more than one entry and the context of the new top
513 // entry matches the previous, copy the insertionPoint and location from the
514 // previous entry if missing from the new top entry.
515 if (stack.size() > 1) {
516 auto &prev = *(stack.rbegin() + 1);
517 auto ¤t = stack.back();
518 if (current.context.is(prev.context)) {
519 // Default non-context objects from the previous entry.
520 if (!current.insertionPoint)
521 current.insertionPoint = prev.insertionPoint;
522 if (!current.location)
523 current.location = prev.location;
524 }
525 }
526 }
527
getContext()528 PyMlirContext *PyThreadContextEntry::getContext() {
529 if (!context)
530 return nullptr;
531 return py::cast<PyMlirContext *>(context);
532 }
533
getInsertionPoint()534 PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() {
535 if (!insertionPoint)
536 return nullptr;
537 return py::cast<PyInsertionPoint *>(insertionPoint);
538 }
539
getLocation()540 PyLocation *PyThreadContextEntry::getLocation() {
541 if (!location)
542 return nullptr;
543 return py::cast<PyLocation *>(location);
544 }
545
getDefaultContext()546 PyMlirContext *PyThreadContextEntry::getDefaultContext() {
547 auto *tos = getTopOfStack();
548 return tos ? tos->getContext() : nullptr;
549 }
550
getDefaultInsertionPoint()551 PyInsertionPoint *PyThreadContextEntry::getDefaultInsertionPoint() {
552 auto *tos = getTopOfStack();
553 return tos ? tos->getInsertionPoint() : nullptr;
554 }
555
getDefaultLocation()556 PyLocation *PyThreadContextEntry::getDefaultLocation() {
557 auto *tos = getTopOfStack();
558 return tos ? tos->getLocation() : nullptr;
559 }
560
pushContext(PyMlirContext & context)561 py::object PyThreadContextEntry::pushContext(PyMlirContext &context) {
562 py::object contextObj = py::cast(context);
563 push(FrameKind::Context, /*context=*/contextObj,
564 /*insertionPoint=*/py::object(),
565 /*location=*/py::object());
566 return contextObj;
567 }
568
popContext(PyMlirContext & context)569 void PyThreadContextEntry::popContext(PyMlirContext &context) {
570 auto &stack = getStack();
571 if (stack.empty())
572 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
573 auto &tos = stack.back();
574 if (tos.frameKind != FrameKind::Context && tos.getContext() != &context)
575 throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit");
576 stack.pop_back();
577 }
578
579 py::object
pushInsertionPoint(PyInsertionPoint & insertionPoint)580 PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) {
581 py::object contextObj =
582 insertionPoint.getBlock().getParentOperation()->getContext().getObject();
583 py::object insertionPointObj = py::cast(insertionPoint);
584 push(FrameKind::InsertionPoint,
585 /*context=*/contextObj,
586 /*insertionPoint=*/insertionPointObj,
587 /*location=*/py::object());
588 return insertionPointObj;
589 }
590
popInsertionPoint(PyInsertionPoint & insertionPoint)591 void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) {
592 auto &stack = getStack();
593 if (stack.empty())
594 throw SetPyError(PyExc_RuntimeError,
595 "Unbalanced InsertionPoint enter/exit");
596 auto &tos = stack.back();
597 if (tos.frameKind != FrameKind::InsertionPoint &&
598 tos.getInsertionPoint() != &insertionPoint)
599 throw SetPyError(PyExc_RuntimeError,
600 "Unbalanced InsertionPoint enter/exit");
601 stack.pop_back();
602 }
603
pushLocation(PyLocation & location)604 py::object PyThreadContextEntry::pushLocation(PyLocation &location) {
605 py::object contextObj = location.getContext().getObject();
606 py::object locationObj = py::cast(location);
607 push(FrameKind::Location, /*context=*/contextObj,
608 /*insertionPoint=*/py::object(),
609 /*location=*/locationObj);
610 return locationObj;
611 }
612
popLocation(PyLocation & location)613 void PyThreadContextEntry::popLocation(PyLocation &location) {
614 auto &stack = getStack();
615 if (stack.empty())
616 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
617 auto &tos = stack.back();
618 if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location)
619 throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit");
620 stack.pop_back();
621 }
622
623 //------------------------------------------------------------------------------
624 // PyDialect, PyDialectDescriptor, PyDialects
625 //------------------------------------------------------------------------------
626
getDialectForKey(const std::string & key,bool attrError)627 MlirDialect PyDialects::getDialectForKey(const std::string &key,
628 bool attrError) {
629 // If the "std" dialect was asked for, substitute the empty namespace :(
630 static const std::string emptyKey;
631 const std::string *canonKey = key == "std" ? &emptyKey : &key;
632 MlirDialect dialect = mlirContextGetOrLoadDialect(
633 getContext()->get(), {canonKey->data(), canonKey->size()});
634 if (mlirDialectIsNull(dialect)) {
635 throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError,
636 Twine("Dialect '") + key + "' not found");
637 }
638 return dialect;
639 }
640
641 //------------------------------------------------------------------------------
642 // PyLocation
643 //------------------------------------------------------------------------------
644
getCapsule()645 py::object PyLocation::getCapsule() {
646 return py::reinterpret_steal<py::object>(mlirPythonLocationToCapsule(*this));
647 }
648
createFromCapsule(py::object capsule)649 PyLocation PyLocation::createFromCapsule(py::object capsule) {
650 MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr());
651 if (mlirLocationIsNull(rawLoc))
652 throw py::error_already_set();
653 return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)),
654 rawLoc);
655 }
656
contextEnter()657 py::object PyLocation::contextEnter() {
658 return PyThreadContextEntry::pushLocation(*this);
659 }
660
contextExit(py::object excType,py::object excVal,py::object excTb)661 void PyLocation::contextExit(py::object excType, py::object excVal,
662 py::object excTb) {
663 PyThreadContextEntry::popLocation(*this);
664 }
665
resolve()666 PyLocation &DefaultingPyLocation::resolve() {
667 auto *location = PyThreadContextEntry::getDefaultLocation();
668 if (!location) {
669 throw SetPyError(
670 PyExc_RuntimeError,
671 "An MLIR function requires a Location but none was provided in the "
672 "call or from the surrounding environment. Either pass to the function "
673 "with a 'loc=' argument or establish a default using 'with loc:'");
674 }
675 return *location;
676 }
677
678 //------------------------------------------------------------------------------
679 // PyModule
680 //------------------------------------------------------------------------------
681
PyModule(PyMlirContextRef contextRef,MlirModule module)682 PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
683 : BaseContextObject(std::move(contextRef)), module(module) {}
684
~PyModule()685 PyModule::~PyModule() {
686 py::gil_scoped_acquire acquire;
687 auto &liveModules = getContext()->liveModules;
688 assert(liveModules.count(module.ptr) == 1 &&
689 "destroying module not in live map");
690 liveModules.erase(module.ptr);
691 mlirModuleDestroy(module);
692 }
693
forModule(MlirModule module)694 PyModuleRef PyModule::forModule(MlirModule module) {
695 MlirContext context = mlirModuleGetContext(module);
696 PyMlirContextRef contextRef = PyMlirContext::forContext(context);
697
698 py::gil_scoped_acquire acquire;
699 auto &liveModules = contextRef->liveModules;
700 auto it = liveModules.find(module.ptr);
701 if (it == liveModules.end()) {
702 // Create.
703 PyModule *unownedModule = new PyModule(std::move(contextRef), module);
704 // Note that the default return value policy on cast is automatic_reference,
705 // which does not take ownership (delete will not be called).
706 // Just be explicit.
707 py::object pyRef =
708 py::cast(unownedModule, py::return_value_policy::take_ownership);
709 unownedModule->handle = pyRef;
710 liveModules[module.ptr] =
711 std::make_pair(unownedModule->handle, unownedModule);
712 return PyModuleRef(unownedModule, std::move(pyRef));
713 }
714 // Use existing.
715 PyModule *existing = it->second.second;
716 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
717 return PyModuleRef(existing, std::move(pyRef));
718 }
719
createFromCapsule(py::object capsule)720 py::object PyModule::createFromCapsule(py::object capsule) {
721 MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr());
722 if (mlirModuleIsNull(rawModule))
723 throw py::error_already_set();
724 return forModule(rawModule).releaseObject();
725 }
726
getCapsule()727 py::object PyModule::getCapsule() {
728 return py::reinterpret_steal<py::object>(mlirPythonModuleToCapsule(get()));
729 }
730
731 //------------------------------------------------------------------------------
732 // PyOperation
733 //------------------------------------------------------------------------------
734
PyOperation(PyMlirContextRef contextRef,MlirOperation operation)735 PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation)
736 : BaseContextObject(std::move(contextRef)), operation(operation) {}
737
~PyOperation()738 PyOperation::~PyOperation() {
739 auto &liveOperations = getContext()->liveOperations;
740 assert(liveOperations.count(operation.ptr) == 1 &&
741 "destroying operation not in live map");
742 liveOperations.erase(operation.ptr);
743 if (!isAttached()) {
744 mlirOperationDestroy(operation);
745 }
746 }
747
createInstance(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)748 PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
749 MlirOperation operation,
750 py::object parentKeepAlive) {
751 auto &liveOperations = contextRef->liveOperations;
752 // Create.
753 PyOperation *unownedOperation =
754 new PyOperation(std::move(contextRef), operation);
755 // Note that the default return value policy on cast is automatic_reference,
756 // which does not take ownership (delete will not be called).
757 // Just be explicit.
758 py::object pyRef =
759 py::cast(unownedOperation, py::return_value_policy::take_ownership);
760 unownedOperation->handle = pyRef;
761 if (parentKeepAlive) {
762 unownedOperation->parentKeepAlive = std::move(parentKeepAlive);
763 }
764 liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation);
765 return PyOperationRef(unownedOperation, std::move(pyRef));
766 }
767
forOperation(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)768 PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
769 MlirOperation operation,
770 py::object parentKeepAlive) {
771 auto &liveOperations = contextRef->liveOperations;
772 auto it = liveOperations.find(operation.ptr);
773 if (it == liveOperations.end()) {
774 // Create.
775 return createInstance(std::move(contextRef), operation,
776 std::move(parentKeepAlive));
777 }
778 // Use existing.
779 PyOperation *existing = it->second.second;
780 py::object pyRef = py::reinterpret_borrow<py::object>(it->second.first);
781 return PyOperationRef(existing, std::move(pyRef));
782 }
783
createDetached(PyMlirContextRef contextRef,MlirOperation operation,py::object parentKeepAlive)784 PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
785 MlirOperation operation,
786 py::object parentKeepAlive) {
787 auto &liveOperations = contextRef->liveOperations;
788 assert(liveOperations.count(operation.ptr) == 0 &&
789 "cannot create detached operation that already exists");
790 (void)liveOperations;
791
792 PyOperationRef created = createInstance(std::move(contextRef), operation,
793 std::move(parentKeepAlive));
794 created->attached = false;
795 return created;
796 }
797
checkValid() const798 void PyOperation::checkValid() const {
799 if (!valid) {
800 throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated");
801 }
802 }
803
print(py::object fileObject,bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)804 void PyOperationBase::print(py::object fileObject, bool binary,
805 llvm::Optional<int64_t> largeElementsLimit,
806 bool enableDebugInfo, bool prettyDebugInfo,
807 bool printGenericOpForm, bool useLocalScope) {
808 PyOperation &operation = getOperation();
809 operation.checkValid();
810 if (fileObject.is_none())
811 fileObject = py::module::import("sys").attr("stdout");
812
813 if (!printGenericOpForm && !mlirOperationVerify(operation)) {
814 fileObject.attr("write")("// Verification failed, printing generic form\n");
815 printGenericOpForm = true;
816 }
817
818 MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate();
819 if (largeElementsLimit)
820 mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit);
821 if (enableDebugInfo)
822 mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo);
823 if (printGenericOpForm)
824 mlirOpPrintingFlagsPrintGenericOpForm(flags);
825
826 PyFileAccumulator accum(fileObject, binary);
827 py::gil_scoped_release();
828 mlirOperationPrintWithFlags(operation, flags, accum.getCallback(),
829 accum.getUserData());
830 mlirOpPrintingFlagsDestroy(flags);
831 }
832
getAsm(bool binary,llvm::Optional<int64_t> largeElementsLimit,bool enableDebugInfo,bool prettyDebugInfo,bool printGenericOpForm,bool useLocalScope)833 py::object PyOperationBase::getAsm(bool binary,
834 llvm::Optional<int64_t> largeElementsLimit,
835 bool enableDebugInfo, bool prettyDebugInfo,
836 bool printGenericOpForm,
837 bool useLocalScope) {
838 py::object fileObject;
839 if (binary) {
840 fileObject = py::module::import("io").attr("BytesIO")();
841 } else {
842 fileObject = py::module::import("io").attr("StringIO")();
843 }
844 print(fileObject, /*binary=*/binary,
845 /*largeElementsLimit=*/largeElementsLimit,
846 /*enableDebugInfo=*/enableDebugInfo,
847 /*prettyDebugInfo=*/prettyDebugInfo,
848 /*printGenericOpForm=*/printGenericOpForm,
849 /*useLocalScope=*/useLocalScope);
850
851 return fileObject.attr("getvalue")();
852 }
853
getParentOperation()854 PyOperationRef PyOperation::getParentOperation() {
855 if (!isAttached())
856 throw SetPyError(PyExc_ValueError, "Detached operations have no parent");
857 MlirOperation operation = mlirOperationGetParentOperation(get());
858 if (mlirOperationIsNull(operation))
859 throw SetPyError(PyExc_ValueError, "Operation has no parent.");
860 return PyOperation::forOperation(getContext(), operation);
861 }
862
getBlock()863 PyBlock PyOperation::getBlock() {
864 PyOperationRef parentOperation = getParentOperation();
865 MlirBlock block = mlirOperationGetBlock(get());
866 assert(!mlirBlockIsNull(block) && "Attached operation has null parent");
867 return PyBlock{std::move(parentOperation), block};
868 }
869
create(std::string name,llvm::Optional<std::vector<PyValue * >> operands,llvm::Optional<std::vector<PyType * >> results,llvm::Optional<py::dict> attributes,llvm::Optional<std::vector<PyBlock * >> successors,int regions,DefaultingPyLocation location,py::object maybeIp)870 py::object PyOperation::create(
871 std::string name, llvm::Optional<std::vector<PyValue *>> operands,
872 llvm::Optional<std::vector<PyType *>> results,
873 llvm::Optional<py::dict> attributes,
874 llvm::Optional<std::vector<PyBlock *>> successors, int regions,
875 DefaultingPyLocation location, py::object maybeIp) {
876 llvm::SmallVector<MlirValue, 4> mlirOperands;
877 llvm::SmallVector<MlirType, 4> mlirResults;
878 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
879 llvm::SmallVector<std::pair<std::string, MlirAttribute>, 4> mlirAttributes;
880
881 // General parameter validation.
882 if (regions < 0)
883 throw SetPyError(PyExc_ValueError, "number of regions must be >= 0");
884
885 // Unpack/validate operands.
886 if (operands) {
887 mlirOperands.reserve(operands->size());
888 for (PyValue *operand : *operands) {
889 if (!operand)
890 throw SetPyError(PyExc_ValueError, "operand value cannot be None");
891 mlirOperands.push_back(operand->get());
892 }
893 }
894
895 // Unpack/validate results.
896 if (results) {
897 mlirResults.reserve(results->size());
898 for (PyType *result : *results) {
899 // TODO: Verify result type originate from the same context.
900 if (!result)
901 throw SetPyError(PyExc_ValueError, "result type cannot be None");
902 mlirResults.push_back(*result);
903 }
904 }
905 // Unpack/validate attributes.
906 if (attributes) {
907 mlirAttributes.reserve(attributes->size());
908 for (auto &it : *attributes) {
909 std::string key;
910 try {
911 key = it.first.cast<std::string>();
912 } catch (py::cast_error &err) {
913 std::string msg = "Invalid attribute key (not a string) when "
914 "attempting to create the operation \"" +
915 name + "\" (" + err.what() + ")";
916 throw py::cast_error(msg);
917 }
918 try {
919 auto &attribute = it.second.cast<PyAttribute &>();
920 // TODO: Verify attribute originates from the same context.
921 mlirAttributes.emplace_back(std::move(key), attribute);
922 } catch (py::reference_cast_error &) {
923 // This exception seems thrown when the value is "None".
924 std::string msg =
925 "Found an invalid (`None`?) attribute value for the key \"" + key +
926 "\" when attempting to create the operation \"" + name + "\"";
927 throw py::cast_error(msg);
928 } catch (py::cast_error &err) {
929 std::string msg = "Invalid attribute value for the key \"" + key +
930 "\" when attempting to create the operation \"" +
931 name + "\" (" + err.what() + ")";
932 throw py::cast_error(msg);
933 }
934 }
935 }
936 // Unpack/validate successors.
937 if (successors) {
938 llvm::SmallVector<MlirBlock, 4> mlirSuccessors;
939 mlirSuccessors.reserve(successors->size());
940 for (auto *successor : *successors) {
941 // TODO: Verify successor originate from the same context.
942 if (!successor)
943 throw SetPyError(PyExc_ValueError, "successor block cannot be None");
944 mlirSuccessors.push_back(successor->get());
945 }
946 }
947
948 // Apply unpacked/validated to the operation state. Beyond this
949 // point, exceptions cannot be thrown or else the state will leak.
950 MlirOperationState state =
951 mlirOperationStateGet(toMlirStringRef(name), location);
952 if (!mlirOperands.empty())
953 mlirOperationStateAddOperands(&state, mlirOperands.size(),
954 mlirOperands.data());
955 if (!mlirResults.empty())
956 mlirOperationStateAddResults(&state, mlirResults.size(),
957 mlirResults.data());
958 if (!mlirAttributes.empty()) {
959 // Note that the attribute names directly reference bytes in
960 // mlirAttributes, so that vector must not be changed from here
961 // on.
962 llvm::SmallVector<MlirNamedAttribute, 4> mlirNamedAttributes;
963 mlirNamedAttributes.reserve(mlirAttributes.size());
964 for (auto &it : mlirAttributes)
965 mlirNamedAttributes.push_back(
966 mlirNamedAttributeGet(toMlirStringRef(it.first), it.second));
967 mlirOperationStateAddAttributes(&state, mlirNamedAttributes.size(),
968 mlirNamedAttributes.data());
969 }
970 if (!mlirSuccessors.empty())
971 mlirOperationStateAddSuccessors(&state, mlirSuccessors.size(),
972 mlirSuccessors.data());
973 if (regions) {
974 llvm::SmallVector<MlirRegion, 4> mlirRegions;
975 mlirRegions.resize(regions);
976 for (int i = 0; i < regions; ++i)
977 mlirRegions[i] = mlirRegionCreate();
978 mlirOperationStateAddOwnedRegions(&state, mlirRegions.size(),
979 mlirRegions.data());
980 }
981
982 // Construct the operation.
983 MlirOperation operation = mlirOperationCreate(&state);
984 PyOperationRef created =
985 PyOperation::createDetached(location->getContext(), operation);
986
987 // InsertPoint active?
988 if (!maybeIp.is(py::cast(false))) {
989 PyInsertionPoint *ip;
990 if (maybeIp.is_none()) {
991 ip = PyThreadContextEntry::getDefaultInsertionPoint();
992 } else {
993 ip = py::cast<PyInsertionPoint *>(maybeIp);
994 }
995 if (ip)
996 ip->insert(*created.get());
997 }
998
999 return created->createOpView();
1000 }
1001
createOpView()1002 py::object PyOperation::createOpView() {
1003 MlirIdentifier ident = mlirOperationGetName(get());
1004 MlirStringRef identStr = mlirIdentifierStr(ident);
1005 auto opViewClass = PyGlobals::get().lookupRawOpViewClass(
1006 StringRef(identStr.data, identStr.length));
1007 if (opViewClass)
1008 return (*opViewClass)(getRef().getObject());
1009 return py::cast(PyOpView(getRef().getObject()));
1010 }
1011
PyOpView(py::object operationObject)1012 PyOpView::PyOpView(py::object operationObject)
1013 // Casting through the PyOperationBase base-class and then back to the
1014 // Operation lets us accept any PyOperationBase subclass.
1015 : operation(py::cast<PyOperationBase &>(operationObject).getOperation()),
1016 operationObject(operation.getRef().getObject()) {}
1017
createRawSubclass(py::object userClass)1018 py::object PyOpView::createRawSubclass(py::object userClass) {
1019 // This is... a little gross. The typical pattern is to have a pure python
1020 // class that extends OpView like:
1021 // class AddFOp(_cext.ir.OpView):
1022 // def __init__(self, loc, lhs, rhs):
1023 // operation = loc.context.create_operation(
1024 // "addf", lhs, rhs, results=[lhs.type])
1025 // super().__init__(operation)
1026 //
1027 // I.e. The goal of the user facing type is to provide a nice constructor
1028 // that has complete freedom for the op under construction. This is at odds
1029 // with our other desire to sometimes create this object by just passing an
1030 // operation (to initialize the base class). We could do *arg and **kwargs
1031 // munging to try to make it work, but instead, we synthesize a new class
1032 // on the fly which extends this user class (AddFOp in this example) and
1033 // *give it* the base class's __init__ method, thus bypassing the
1034 // intermediate subclass's __init__ method entirely. While slightly,
1035 // underhanded, this is safe/legal because the type hierarchy has not changed
1036 // (we just added a new leaf) and we aren't mucking around with __new__.
1037 // Typically, this new class will be stored on the original as "_Raw" and will
1038 // be used for casts and other things that need a variant of the class that
1039 // is initialized purely from an operation.
1040 py::object parentMetaclass =
1041 py::reinterpret_borrow<py::object>((PyObject *)&PyType_Type);
1042 py::dict attributes;
1043 // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from
1044 // now.
1045 // auto opViewType = py::type::of<PyOpView>();
1046 auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true);
1047 attributes["__init__"] = opViewType.attr("__init__");
1048 py::str origName = userClass.attr("__name__");
1049 py::str newName = py::str("_") + origName;
1050 return parentMetaclass(newName, py::make_tuple(userClass), attributes);
1051 }
1052
1053 //------------------------------------------------------------------------------
1054 // PyInsertionPoint.
1055 //------------------------------------------------------------------------------
1056
PyInsertionPoint(PyBlock & block)1057 PyInsertionPoint::PyInsertionPoint(PyBlock &block) : block(block) {}
1058
PyInsertionPoint(PyOperationBase & beforeOperationBase)1059 PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase)
1060 : refOperation(beforeOperationBase.getOperation().getRef()),
1061 block((*refOperation)->getBlock()) {}
1062
insert(PyOperationBase & operationBase)1063 void PyInsertionPoint::insert(PyOperationBase &operationBase) {
1064 PyOperation &operation = operationBase.getOperation();
1065 if (operation.isAttached())
1066 throw SetPyError(PyExc_ValueError,
1067 "Attempt to insert operation that is already attached");
1068 block.getParentOperation()->checkValid();
1069 MlirOperation beforeOp = {nullptr};
1070 if (refOperation) {
1071 // Insert before operation.
1072 (*refOperation)->checkValid();
1073 beforeOp = (*refOperation)->get();
1074 }
1075 mlirBlockInsertOwnedOperationBefore(block.get(), beforeOp, operation);
1076 operation.setAttached();
1077 }
1078
atBlockBegin(PyBlock & block)1079 PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) {
1080 MlirOperation firstOp = mlirBlockGetFirstOperation(block.get());
1081 if (mlirOperationIsNull(firstOp)) {
1082 // Just insert at end.
1083 return PyInsertionPoint(block);
1084 }
1085
1086 // Insert before first op.
1087 PyOperationRef firstOpRef = PyOperation::forOperation(
1088 block.getParentOperation()->getContext(), firstOp);
1089 return PyInsertionPoint{block, std::move(firstOpRef)};
1090 }
1091
atBlockTerminator(PyBlock & block)1092 PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) {
1093 MlirOperation terminator = mlirBlockGetTerminator(block.get());
1094 if (mlirOperationIsNull(terminator))
1095 throw SetPyError(PyExc_ValueError, "Block has no terminator");
1096 PyOperationRef terminatorOpRef = PyOperation::forOperation(
1097 block.getParentOperation()->getContext(), terminator);
1098 return PyInsertionPoint{block, std::move(terminatorOpRef)};
1099 }
1100
contextEnter()1101 py::object PyInsertionPoint::contextEnter() {
1102 return PyThreadContextEntry::pushInsertionPoint(*this);
1103 }
1104
contextExit(pybind11::object excType,pybind11::object excVal,pybind11::object excTb)1105 void PyInsertionPoint::contextExit(pybind11::object excType,
1106 pybind11::object excVal,
1107 pybind11::object excTb) {
1108 PyThreadContextEntry::popInsertionPoint(*this);
1109 }
1110
1111 //------------------------------------------------------------------------------
1112 // PyAttribute.
1113 //------------------------------------------------------------------------------
1114
operator ==(const PyAttribute & other)1115 bool PyAttribute::operator==(const PyAttribute &other) {
1116 return mlirAttributeEqual(attr, other.attr);
1117 }
1118
getCapsule()1119 py::object PyAttribute::getCapsule() {
1120 return py::reinterpret_steal<py::object>(mlirPythonAttributeToCapsule(*this));
1121 }
1122
createFromCapsule(py::object capsule)1123 PyAttribute PyAttribute::createFromCapsule(py::object capsule) {
1124 MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr());
1125 if (mlirAttributeIsNull(rawAttr))
1126 throw py::error_already_set();
1127 return PyAttribute(
1128 PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr);
1129 }
1130
1131 //------------------------------------------------------------------------------
1132 // PyNamedAttribute.
1133 //------------------------------------------------------------------------------
1134
PyNamedAttribute(MlirAttribute attr,std::string ownedName)1135 PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName)
1136 : ownedName(new std::string(std::move(ownedName))) {
1137 namedAttr = mlirNamedAttributeGet(toMlirStringRef(*this->ownedName), attr);
1138 }
1139
1140 //------------------------------------------------------------------------------
1141 // PyType.
1142 //------------------------------------------------------------------------------
1143
operator ==(const PyType & other)1144 bool PyType::operator==(const PyType &other) {
1145 return mlirTypeEqual(type, other.type);
1146 }
1147
getCapsule()1148 py::object PyType::getCapsule() {
1149 return py::reinterpret_steal<py::object>(mlirPythonTypeToCapsule(*this));
1150 }
1151
createFromCapsule(py::object capsule)1152 PyType PyType::createFromCapsule(py::object capsule) {
1153 MlirType rawType = mlirPythonCapsuleToType(capsule.ptr());
1154 if (mlirTypeIsNull(rawType))
1155 throw py::error_already_set();
1156 return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)),
1157 rawType);
1158 }
1159
1160 //------------------------------------------------------------------------------
// PyValue and subclasses.
1162 //------------------------------------------------------------------------------
1163
1164 namespace {
1165 /// CRTP base class for Python MLIR values that subclass Value and should be
1166 /// castable from it. The value hierarchy is one level deep and is not supposed
1167 /// to accommodate other levels unless core MLIR changes.
1168 template <typename DerivedTy>
1169 class PyConcreteValue : public PyValue {
1170 public:
1171 // Derived classes must define statics for:
1172 // IsAFunctionTy isaFunction
1173 // const char *pyClassName
1174 // and redefine bindDerived.
1175 using ClassTy = py::class_<DerivedTy, PyValue>;
1176 using IsAFunctionTy = bool (*)(MlirValue);
1177
1178 PyConcreteValue() = default;
PyConcreteValue(PyOperationRef operationRef,MlirValue value)1179 PyConcreteValue(PyOperationRef operationRef, MlirValue value)
1180 : PyValue(operationRef, value) {}
PyConcreteValue(PyValue & orig)1181 PyConcreteValue(PyValue &orig)
1182 : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {}
1183
1184 /// Attempts to cast the original value to the derived type and throws on
1185 /// type mismatches.
castFrom(PyValue & orig)1186 static MlirValue castFrom(PyValue &orig) {
1187 if (!DerivedTy::isaFunction(orig.get())) {
1188 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1189 throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") +
1190 DerivedTy::pyClassName +
1191 " (from " + origRepr + ")");
1192 }
1193 return orig.get();
1194 }
1195
1196 /// Binds the Python module objects to functions of this class.
bind(py::module & m)1197 static void bind(py::module &m) {
1198 auto cls = ClassTy(m, DerivedTy::pyClassName);
1199 cls.def(py::init<PyValue &>(), py::keep_alive<0, 1>());
1200 DerivedTy::bindDerived(cls);
1201 }
1202
1203 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1204 static void bindDerived(ClassTy &m) {}
1205 };
1206
1207 /// Python wrapper for MlirBlockArgument.
1208 class PyBlockArgument : public PyConcreteValue<PyBlockArgument> {
1209 public:
1210 static constexpr IsAFunctionTy isaFunction = mlirValueIsABlockArgument;
1211 static constexpr const char *pyClassName = "BlockArgument";
1212 using PyConcreteValue::PyConcreteValue;
1213
bindDerived(ClassTy & c)1214 static void bindDerived(ClassTy &c) {
1215 c.def_property_readonly("owner", [](PyBlockArgument &self) {
1216 return PyBlock(self.getParentOperation(),
1217 mlirBlockArgumentGetOwner(self.get()));
1218 });
1219 c.def_property_readonly("arg_number", [](PyBlockArgument &self) {
1220 return mlirBlockArgumentGetArgNumber(self.get());
1221 });
1222 c.def("set_type", [](PyBlockArgument &self, PyType type) {
1223 return mlirBlockArgumentSetType(self.get(), type);
1224 });
1225 }
1226 };
1227
1228 /// Python wrapper for MlirOpResult.
1229 class PyOpResult : public PyConcreteValue<PyOpResult> {
1230 public:
1231 static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult;
1232 static constexpr const char *pyClassName = "OpResult";
1233 using PyConcreteValue::PyConcreteValue;
1234
bindDerived(ClassTy & c)1235 static void bindDerived(ClassTy &c) {
1236 c.def_property_readonly("owner", [](PyOpResult &self) {
1237 assert(
1238 mlirOperationEqual(self.getParentOperation()->get(),
1239 mlirOpResultGetOwner(self.get())) &&
1240 "expected the owner of the value in Python to match that in the IR");
1241 return self.getParentOperation();
1242 });
1243 c.def_property_readonly("result_number", [](PyOpResult &self) {
1244 return mlirOpResultGetResultNumber(self.get());
1245 });
1246 }
1247 };
1248
1249 /// A list of block arguments. Internally, these are stored as consecutive
1250 /// elements, random access is cheap. The argument list is associated with the
1251 /// operation that contains the block (detached blocks are not allowed in
1252 /// Python bindings) and extends its lifetime.
1253 class PyBlockArgumentList {
1254 public:
PyBlockArgumentList(PyOperationRef operation,MlirBlock block)1255 PyBlockArgumentList(PyOperationRef operation, MlirBlock block)
1256 : operation(std::move(operation)), block(block) {}
1257
1258 /// Returns the length of the block argument list.
dunderLen()1259 intptr_t dunderLen() {
1260 operation->checkValid();
1261 return mlirBlockGetNumArguments(block);
1262 }
1263
1264 /// Returns `index`-th element of the block argument list.
dunderGetItem(intptr_t index)1265 PyBlockArgument dunderGetItem(intptr_t index) {
1266 if (index < 0 || index >= dunderLen()) {
1267 throw SetPyError(PyExc_IndexError,
1268 "attempt to access out of bounds region");
1269 }
1270 PyValue value(operation, mlirBlockGetArgument(block, index));
1271 return PyBlockArgument(value);
1272 }
1273
1274 /// Defines a Python class in the bindings.
bind(py::module & m)1275 static void bind(py::module &m) {
1276 py::class_<PyBlockArgumentList>(m, "BlockArgumentList")
1277 .def("__len__", &PyBlockArgumentList::dunderLen)
1278 .def("__getitem__", &PyBlockArgumentList::dunderGetItem);
1279 }
1280
1281 private:
1282 PyOperationRef operation;
1283 MlirBlock block;
1284 };
1285
1286 /// A list of operation operands. Internally, these are stored as consecutive
1287 /// elements, random access is cheap. The result list is associated with the
1288 /// operation whose results these are, and extends the lifetime of this
1289 /// operation.
1290 class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
1291 public:
1292 static constexpr const char *pyClassName = "OpOperandList";
1293
PyOpOperandList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1294 PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
1295 intptr_t length = -1, intptr_t step = 1)
1296 : Sliceable(startIndex,
1297 length == -1 ? mlirOperationGetNumOperands(operation->get())
1298 : length,
1299 step),
1300 operation(operation) {}
1301
getNumElements()1302 intptr_t getNumElements() {
1303 operation->checkValid();
1304 return mlirOperationGetNumOperands(operation->get());
1305 }
1306
getElement(intptr_t pos)1307 PyValue getElement(intptr_t pos) {
1308 return PyValue(operation, mlirOperationGetOperand(operation->get(), pos));
1309 }
1310
slice(intptr_t startIndex,intptr_t length,intptr_t step)1311 PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1312 return PyOpOperandList(operation, startIndex, length, step);
1313 }
1314
1315 private:
1316 PyOperationRef operation;
1317 };
1318
1319 /// A list of operation results. Internally, these are stored as consecutive
1320 /// elements, random access is cheap. The result list is associated with the
1321 /// operation whose results these are, and extends the lifetime of this
1322 /// operation.
1323 class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
1324 public:
1325 static constexpr const char *pyClassName = "OpResultList";
1326
PyOpResultList(PyOperationRef operation,intptr_t startIndex=0,intptr_t length=-1,intptr_t step=1)1327 PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
1328 intptr_t length = -1, intptr_t step = 1)
1329 : Sliceable(startIndex,
1330 length == -1 ? mlirOperationGetNumResults(operation->get())
1331 : length,
1332 step),
1333 operation(operation) {}
1334
getNumElements()1335 intptr_t getNumElements() {
1336 operation->checkValid();
1337 return mlirOperationGetNumResults(operation->get());
1338 }
1339
getElement(intptr_t index)1340 PyOpResult getElement(intptr_t index) {
1341 PyValue value(operation, mlirOperationGetResult(operation->get(), index));
1342 return PyOpResult(value);
1343 }
1344
slice(intptr_t startIndex,intptr_t length,intptr_t step)1345 PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) {
1346 return PyOpResultList(operation, startIndex, length, step);
1347 }
1348
1349 private:
1350 PyOperationRef operation;
1351 };
1352
1353 /// A list of operation attributes. Can be indexed by name, producing
1354 /// attributes, or by index, producing named attributes.
1355 class PyOpAttributeMap {
1356 public:
PyOpAttributeMap(PyOperationRef operation)1357 PyOpAttributeMap(PyOperationRef operation) : operation(operation) {}
1358
dunderGetItemNamed(const std::string & name)1359 PyAttribute dunderGetItemNamed(const std::string &name) {
1360 MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(),
1361 toMlirStringRef(name));
1362 if (mlirAttributeIsNull(attr)) {
1363 throw SetPyError(PyExc_KeyError,
1364 "attempt to access a non-existent attribute");
1365 }
1366 return PyAttribute(operation->getContext(), attr);
1367 }
1368
dunderGetItemIndexed(intptr_t index)1369 PyNamedAttribute dunderGetItemIndexed(intptr_t index) {
1370 if (index < 0 || index >= dunderLen()) {
1371 throw SetPyError(PyExc_IndexError,
1372 "attempt to access out of bounds attribute");
1373 }
1374 MlirNamedAttribute namedAttr =
1375 mlirOperationGetAttribute(operation->get(), index);
1376 return PyNamedAttribute(namedAttr.attribute,
1377 std::string(namedAttr.name.data));
1378 }
1379
dunderSetItem(const std::string & name,PyAttribute attr)1380 void dunderSetItem(const std::string &name, PyAttribute attr) {
1381 mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name),
1382 attr);
1383 }
1384
dunderDelItem(const std::string & name)1385 void dunderDelItem(const std::string &name) {
1386 int removed = mlirOperationRemoveAttributeByName(operation->get(),
1387 toMlirStringRef(name));
1388 if (!removed)
1389 throw SetPyError(PyExc_KeyError,
1390 "attempt to delete a non-existent attribute");
1391 }
1392
dunderLen()1393 intptr_t dunderLen() {
1394 return mlirOperationGetNumAttributes(operation->get());
1395 }
1396
dunderContains(const std::string & name)1397 bool dunderContains(const std::string &name) {
1398 return !mlirAttributeIsNull(mlirOperationGetAttributeByName(
1399 operation->get(), toMlirStringRef(name)));
1400 }
1401
bind(py::module & m)1402 static void bind(py::module &m) {
1403 py::class_<PyOpAttributeMap>(m, "OpAttributeMap")
1404 .def("__contains__", &PyOpAttributeMap::dunderContains)
1405 .def("__len__", &PyOpAttributeMap::dunderLen)
1406 .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed)
1407 .def("__getitem__", &PyOpAttributeMap::dunderGetItemIndexed)
1408 .def("__setitem__", &PyOpAttributeMap::dunderSetItem)
1409 .def("__delitem__", &PyOpAttributeMap::dunderDelItem);
1410 }
1411
1412 private:
1413 PyOperationRef operation;
1414 };
1415
1416 } // end namespace
1417
1418 //------------------------------------------------------------------------------
1419 // Builtin attribute subclasses.
1420 //------------------------------------------------------------------------------
1421
1422 namespace {
1423
1424 /// CRTP base classes for Python attributes that subclass Attribute and should
1425 /// be castable from it (i.e. via something like StringAttr(attr)).
1426 /// By default, attribute class hierarchies are one level deep (i.e. a
1427 /// concrete attribute class extends PyAttribute); however, intermediate
1428 /// python-visible base classes can be modeled by specifying a BaseTy.
1429 template <typename DerivedTy, typename BaseTy = PyAttribute>
1430 class PyConcreteAttribute : public BaseTy {
1431 public:
1432 // Derived classes must define statics for:
1433 // IsAFunctionTy isaFunction
1434 // const char *pyClassName
1435 using ClassTy = py::class_<DerivedTy, BaseTy>;
1436 using IsAFunctionTy = bool (*)(MlirAttribute);
1437
1438 PyConcreteAttribute() = default;
PyConcreteAttribute(PyMlirContextRef contextRef,MlirAttribute attr)1439 PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr)
1440 : BaseTy(std::move(contextRef), attr) {}
PyConcreteAttribute(PyAttribute & orig)1441 PyConcreteAttribute(PyAttribute &orig)
1442 : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {}
1443
castFrom(PyAttribute & orig)1444 static MlirAttribute castFrom(PyAttribute &orig) {
1445 if (!DerivedTy::isaFunction(orig)) {
1446 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1447 throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") +
1448 DerivedTy::pyClassName +
1449 " (from " + origRepr + ")");
1450 }
1451 return orig;
1452 }
1453
bind(py::module & m)1454 static void bind(py::module &m) {
1455 auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol());
1456 cls.def(py::init<PyAttribute &>(), py::keep_alive<0, 1>());
1457 DerivedTy::bindDerived(cls);
1458 }
1459
1460 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)1461 static void bindDerived(ClassTy &m) {}
1462 };
1463
1464 /// Float Point Attribute subclass - FloatAttr.
1465 class PyFloatAttribute : public PyConcreteAttribute<PyFloatAttribute> {
1466 public:
1467 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat;
1468 static constexpr const char *pyClassName = "FloatAttr";
1469 using PyConcreteAttribute::PyConcreteAttribute;
1470
bindDerived(ClassTy & c)1471 static void bindDerived(ClassTy &c) {
1472 c.def_static(
1473 "get",
1474 [](PyType &type, double value, DefaultingPyLocation loc) {
1475 MlirAttribute attr = mlirFloatAttrDoubleGetChecked(type, value, loc);
1476 // TODO: Rework error reporting once diagnostic engine is exposed
1477 // in C API.
1478 if (mlirAttributeIsNull(attr)) {
1479 throw SetPyError(PyExc_ValueError,
1480 Twine("invalid '") +
1481 py::repr(py::cast(type)).cast<std::string>() +
1482 "' and expected floating point type.");
1483 }
1484 return PyFloatAttribute(type.getContext(), attr);
1485 },
1486 py::arg("type"), py::arg("value"), py::arg("loc") = py::none(),
1487 "Gets an uniqued float point attribute associated to a type");
1488 c.def_static(
1489 "get_f32",
1490 [](double value, DefaultingPyMlirContext context) {
1491 MlirAttribute attr = mlirFloatAttrDoubleGet(
1492 context->get(), mlirF32TypeGet(context->get()), value);
1493 return PyFloatAttribute(context->getRef(), attr);
1494 },
1495 py::arg("value"), py::arg("context") = py::none(),
1496 "Gets an uniqued float point attribute associated to a f32 type");
1497 c.def_static(
1498 "get_f64",
1499 [](double value, DefaultingPyMlirContext context) {
1500 MlirAttribute attr = mlirFloatAttrDoubleGet(
1501 context->get(), mlirF64TypeGet(context->get()), value);
1502 return PyFloatAttribute(context->getRef(), attr);
1503 },
1504 py::arg("value"), py::arg("context") = py::none(),
1505 "Gets an uniqued float point attribute associated to a f64 type");
1506 c.def_property_readonly(
1507 "value",
1508 [](PyFloatAttribute &self) {
1509 return mlirFloatAttrGetValueDouble(self);
1510 },
1511 "Returns the value of the float point attribute");
1512 }
1513 };
1514
1515 /// Integer Attribute subclass - IntegerAttr.
1516 class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
1517 public:
1518 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger;
1519 static constexpr const char *pyClassName = "IntegerAttr";
1520 using PyConcreteAttribute::PyConcreteAttribute;
1521
bindDerived(ClassTy & c)1522 static void bindDerived(ClassTy &c) {
1523 c.def_static(
1524 "get",
1525 [](PyType &type, int64_t value) {
1526 MlirAttribute attr = mlirIntegerAttrGet(type, value);
1527 return PyIntegerAttribute(type.getContext(), attr);
1528 },
1529 py::arg("type"), py::arg("value"),
1530 "Gets an uniqued integer attribute associated to a type");
1531 c.def_property_readonly(
1532 "value",
1533 [](PyIntegerAttribute &self) {
1534 return mlirIntegerAttrGetValueInt(self);
1535 },
1536 "Returns the value of the integer attribute");
1537 }
1538 };
1539
1540 /// Bool Attribute subclass - BoolAttr.
1541 class PyBoolAttribute : public PyConcreteAttribute<PyBoolAttribute> {
1542 public:
1543 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool;
1544 static constexpr const char *pyClassName = "BoolAttr";
1545 using PyConcreteAttribute::PyConcreteAttribute;
1546
bindDerived(ClassTy & c)1547 static void bindDerived(ClassTy &c) {
1548 c.def_static(
1549 "get",
1550 [](bool value, DefaultingPyMlirContext context) {
1551 MlirAttribute attr = mlirBoolAttrGet(context->get(), value);
1552 return PyBoolAttribute(context->getRef(), attr);
1553 },
1554 py::arg("value"), py::arg("context") = py::none(),
1555 "Gets an uniqued bool attribute");
1556 c.def_property_readonly(
1557 "value",
1558 [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); },
1559 "Returns the value of the bool attribute");
1560 }
1561 };
1562
1563 class PyStringAttribute : public PyConcreteAttribute<PyStringAttribute> {
1564 public:
1565 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString;
1566 static constexpr const char *pyClassName = "StringAttr";
1567 using PyConcreteAttribute::PyConcreteAttribute;
1568
bindDerived(ClassTy & c)1569 static void bindDerived(ClassTy &c) {
1570 c.def_static(
1571 "get",
1572 [](std::string value, DefaultingPyMlirContext context) {
1573 MlirAttribute attr =
1574 mlirStringAttrGet(context->get(), toMlirStringRef(value));
1575 return PyStringAttribute(context->getRef(), attr);
1576 },
1577 py::arg("value"), py::arg("context") = py::none(),
1578 "Gets a uniqued string attribute");
1579 c.def_static(
1580 "get_typed",
1581 [](PyType &type, std::string value) {
1582 MlirAttribute attr =
1583 mlirStringAttrTypedGet(type, toMlirStringRef(value));
1584 return PyStringAttribute(type.getContext(), attr);
1585 },
1586
1587 "Gets a uniqued string attribute associated to a type");
1588 c.def_property_readonly(
1589 "value",
1590 [](PyStringAttribute &self) {
1591 MlirStringRef stringRef = mlirStringAttrGetValue(self);
1592 return py::str(stringRef.data, stringRef.length);
1593 },
1594 "Returns the value of the string attribute");
1595 }
1596 };
1597
1598 // TODO: Support construction of bool elements.
1599 // TODO: Support construction of string elements.
1600 class PyDenseElementsAttribute
1601 : public PyConcreteAttribute<PyDenseElementsAttribute> {
1602 public:
1603 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements;
1604 static constexpr const char *pyClassName = "DenseElementsAttr";
1605 using PyConcreteAttribute::PyConcreteAttribute;
1606
1607 static PyDenseElementsAttribute
getFromBuffer(py::buffer array,bool signless,DefaultingPyMlirContext contextWrapper)1608 getFromBuffer(py::buffer array, bool signless,
1609 DefaultingPyMlirContext contextWrapper) {
1610 // Request a contiguous view. In exotic cases, this will cause a copy.
1611 int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT;
1612 Py_buffer *view = new Py_buffer();
1613 if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) {
1614 delete view;
1615 throw py::error_already_set();
1616 }
1617 py::buffer_info arrayInfo(view);
1618
1619 MlirContext context = contextWrapper->get();
1620 // Switch on the types that can be bulk loaded between the Python and
1621 // MLIR-C APIs.
1622 // See: https://docs.python.org/3/library/struct.html#format-characters
1623 if (arrayInfo.format == "f") {
1624 // f32
1625 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize");
1626 return PyDenseElementsAttribute(
1627 contextWrapper->getRef(),
1628 bulkLoad(context, mlirDenseElementsAttrFloatGet,
1629 mlirF32TypeGet(context), arrayInfo));
1630 } else if (arrayInfo.format == "d") {
1631 // f64
1632 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize");
1633 return PyDenseElementsAttribute(
1634 contextWrapper->getRef(),
1635 bulkLoad(context, mlirDenseElementsAttrDoubleGet,
1636 mlirF64TypeGet(context), arrayInfo));
1637 } else if (isSignedIntegerFormat(arrayInfo.format)) {
1638 if (arrayInfo.itemsize == 4) {
1639 // i32
1640 MlirType elementType = signless ? mlirIntegerTypeGet(context, 32)
1641 : mlirIntegerTypeSignedGet(context, 32);
1642 return PyDenseElementsAttribute(contextWrapper->getRef(),
1643 bulkLoad(context,
1644 mlirDenseElementsAttrInt32Get,
1645 elementType, arrayInfo));
1646 } else if (arrayInfo.itemsize == 8) {
1647 // i64
1648 MlirType elementType = signless ? mlirIntegerTypeGet(context, 64)
1649 : mlirIntegerTypeSignedGet(context, 64);
1650 return PyDenseElementsAttribute(contextWrapper->getRef(),
1651 bulkLoad(context,
1652 mlirDenseElementsAttrInt64Get,
1653 elementType, arrayInfo));
1654 }
1655 } else if (isUnsignedIntegerFormat(arrayInfo.format)) {
1656 if (arrayInfo.itemsize == 4) {
1657 // unsigned i32
1658 MlirType elementType = signless
1659 ? mlirIntegerTypeGet(context, 32)
1660 : mlirIntegerTypeUnsignedGet(context, 32);
1661 return PyDenseElementsAttribute(contextWrapper->getRef(),
1662 bulkLoad(context,
1663 mlirDenseElementsAttrUInt32Get,
1664 elementType, arrayInfo));
1665 } else if (arrayInfo.itemsize == 8) {
1666 // unsigned i64
1667 MlirType elementType = signless
1668 ? mlirIntegerTypeGet(context, 64)
1669 : mlirIntegerTypeUnsignedGet(context, 64);
1670 return PyDenseElementsAttribute(contextWrapper->getRef(),
1671 bulkLoad(context,
1672 mlirDenseElementsAttrUInt64Get,
1673 elementType, arrayInfo));
1674 }
1675 }
1676
1677 // TODO: Fall back to string-based get.
1678 std::string message = "unimplemented array format conversion from format: ";
1679 message.append(arrayInfo.format);
1680 throw SetPyError(PyExc_ValueError, message);
1681 }
1682
getSplat(PyType shapedType,PyAttribute & elementAttr)1683 static PyDenseElementsAttribute getSplat(PyType shapedType,
1684 PyAttribute &elementAttr) {
1685 auto contextWrapper =
1686 PyMlirContext::forContext(mlirTypeGetContext(shapedType));
1687 if (!mlirAttributeIsAInteger(elementAttr) &&
1688 !mlirAttributeIsAFloat(elementAttr)) {
1689 std::string message = "Illegal element type for DenseElementsAttr: ";
1690 message.append(py::repr(py::cast(elementAttr)));
1691 throw SetPyError(PyExc_ValueError, message);
1692 }
1693 if (!mlirTypeIsAShaped(shapedType) ||
1694 !mlirShapedTypeHasStaticShape(shapedType)) {
1695 std::string message =
1696 "Expected a static ShapedType for the shaped_type parameter: ";
1697 message.append(py::repr(py::cast(shapedType)));
1698 throw SetPyError(PyExc_ValueError, message);
1699 }
1700 MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType);
1701 MlirType attrType = mlirAttributeGetType(elementAttr);
1702 if (!mlirTypeEqual(shapedElementType, attrType)) {
1703 std::string message =
1704 "Shaped element type and attribute type must be equal: shaped=";
1705 message.append(py::repr(py::cast(shapedType)));
1706 message.append(", element=");
1707 message.append(py::repr(py::cast(elementAttr)));
1708 throw SetPyError(PyExc_ValueError, message);
1709 }
1710
1711 MlirAttribute elements =
1712 mlirDenseElementsAttrSplatGet(shapedType, elementAttr);
1713 return PyDenseElementsAttribute(contextWrapper->getRef(), elements);
1714 }
1715
dunderLen()1716 intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); }
1717
accessBuffer()1718 py::buffer_info accessBuffer() {
1719 MlirType shapedType = mlirAttributeGetType(*this);
1720 MlirType elementType = mlirShapedTypeGetElementType(shapedType);
1721
1722 if (mlirTypeIsAF32(elementType)) {
1723 // f32
1724 return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue);
1725 } else if (mlirTypeIsAF64(elementType)) {
1726 // f64
1727 return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue);
1728 } else if (mlirTypeIsAInteger(elementType) &&
1729 mlirIntegerTypeGetWidth(elementType) == 32) {
1730 if (mlirIntegerTypeIsSignless(elementType) ||
1731 mlirIntegerTypeIsSigned(elementType)) {
1732 // i32
1733 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value);
1734 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
1735 // unsigned i32
1736 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value);
1737 }
1738 } else if (mlirTypeIsAInteger(elementType) &&
1739 mlirIntegerTypeGetWidth(elementType) == 64) {
1740 if (mlirIntegerTypeIsSignless(elementType) ||
1741 mlirIntegerTypeIsSigned(elementType)) {
1742 // i64
1743 return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value);
1744 } else if (mlirIntegerTypeIsUnsigned(elementType)) {
1745 // unsigned i64
1746 return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value);
1747 }
1748 }
1749
1750 std::string message = "unimplemented array format.";
1751 throw SetPyError(PyExc_ValueError, message);
1752 }
1753
bindDerived(ClassTy & c)1754 static void bindDerived(ClassTy &c) {
1755 c.def("__len__", &PyDenseElementsAttribute::dunderLen)
1756 .def_static("get", PyDenseElementsAttribute::getFromBuffer,
1757 py::arg("array"), py::arg("signless") = true,
1758 py::arg("context") = py::none(),
1759 "Gets from a buffer or ndarray")
1760 .def_static("get_splat", PyDenseElementsAttribute::getSplat,
1761 py::arg("shaped_type"), py::arg("element_attr"),
1762 "Gets a DenseElementsAttr where all values are the same")
1763 .def_property_readonly("is_splat",
1764 [](PyDenseElementsAttribute &self) -> bool {
1765 return mlirDenseElementsAttrIsSplat(self);
1766 })
1767 .def_buffer(&PyDenseElementsAttribute::accessBuffer);
1768 }
1769
1770 private:
1771 template <typename ElementTy>
1772 static MlirAttribute
bulkLoad(MlirContext context,MlirAttribute (* ctor)(MlirType,intptr_t,ElementTy *),MlirType mlirElementType,py::buffer_info & arrayInfo)1773 bulkLoad(MlirContext context,
1774 MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *),
1775 MlirType mlirElementType, py::buffer_info &arrayInfo) {
1776 SmallVector<int64_t, 4> shape(arrayInfo.shape.begin(),
1777 arrayInfo.shape.begin() + arrayInfo.ndim);
1778 auto shapedType =
1779 mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType);
1780 intptr_t numElements = arrayInfo.size;
1781 const ElementTy *contents = static_cast<const ElementTy *>(arrayInfo.ptr);
1782 return ctor(shapedType, numElements, contents);
1783 }
1784
isUnsignedIntegerFormat(const std::string & format)1785 static bool isUnsignedIntegerFormat(const std::string &format) {
1786 if (format.empty())
1787 return false;
1788 char code = format[0];
1789 return code == 'I' || code == 'B' || code == 'H' || code == 'L' ||
1790 code == 'Q';
1791 }
1792
isSignedIntegerFormat(const std::string & format)1793 static bool isSignedIntegerFormat(const std::string &format) {
1794 if (format.empty())
1795 return false;
1796 char code = format[0];
1797 return code == 'i' || code == 'b' || code == 'h' || code == 'l' ||
1798 code == 'q';
1799 }
1800
1801 template <typename Type>
bufferInfo(MlirType shapedType,Type (* value)(MlirAttribute,intptr_t))1802 py::buffer_info bufferInfo(MlirType shapedType,
1803 Type (*value)(MlirAttribute, intptr_t)) {
1804 intptr_t rank = mlirShapedTypeGetRank(shapedType);
1805 // Prepare the data for the buffer_info.
1806 // Buffer is configured for read-only access below.
1807 Type *data = static_cast<Type *>(
1808 const_cast<void *>(mlirDenseElementsAttrGetRawData(*this)));
1809 // Prepare the shape for the buffer_info.
1810 SmallVector<intptr_t, 4> shape;
1811 for (intptr_t i = 0; i < rank; ++i)
1812 shape.push_back(mlirShapedTypeGetDimSize(shapedType, i));
1813 // Prepare the strides for the buffer_info.
1814 SmallVector<intptr_t, 4> strides;
1815 intptr_t strideFactor = 1;
1816 for (intptr_t i = 1; i < rank; ++i) {
1817 strideFactor = 1;
1818 for (intptr_t j = i; j < rank; ++j) {
1819 strideFactor *= mlirShapedTypeGetDimSize(shapedType, j);
1820 }
1821 strides.push_back(sizeof(Type) * strideFactor);
1822 }
1823 strides.push_back(sizeof(Type));
1824 return py::buffer_info(data, sizeof(Type),
1825 py::format_descriptor<Type>::format(), rank, shape,
1826 strides, /*readonly=*/true);
1827 }
1828 }; // namespace
1829
1830 /// Refinement of the PyDenseElementsAttribute for attributes containing integer
1831 /// (and boolean) values. Supports element access.
1832 class PyDenseIntElementsAttribute
1833 : public PyConcreteAttribute<PyDenseIntElementsAttribute,
1834 PyDenseElementsAttribute> {
1835 public:
1836 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements;
1837 static constexpr const char *pyClassName = "DenseIntElementsAttr";
1838 using PyConcreteAttribute::PyConcreteAttribute;
1839
1840 /// Returns the element at the given linear position. Asserts if the index is
1841 /// out of range.
dunderGetItem(intptr_t pos)1842 py::int_ dunderGetItem(intptr_t pos) {
1843 if (pos < 0 || pos >= dunderLen()) {
1844 throw SetPyError(PyExc_IndexError,
1845 "attempt to access out of bounds element");
1846 }
1847
1848 MlirType type = mlirAttributeGetType(*this);
1849 type = mlirShapedTypeGetElementType(type);
1850 assert(mlirTypeIsAInteger(type) &&
1851 "expected integer element type in dense int elements attribute");
1852 // Dispatch element extraction to an appropriate C function based on the
1853 // elemental type of the attribute. py::int_ is implicitly constructible
1854 // from any C++ integral type and handles bitwidth correctly.
1855 // TODO: consider caching the type properties in the constructor to avoid
1856 // querying them on each element access.
1857 unsigned width = mlirIntegerTypeGetWidth(type);
1858 bool isUnsigned = mlirIntegerTypeIsUnsigned(type);
1859 if (isUnsigned) {
1860 if (width == 1) {
1861 return mlirDenseElementsAttrGetBoolValue(*this, pos);
1862 }
1863 if (width == 32) {
1864 return mlirDenseElementsAttrGetUInt32Value(*this, pos);
1865 }
1866 if (width == 64) {
1867 return mlirDenseElementsAttrGetUInt64Value(*this, pos);
1868 }
1869 } else {
1870 if (width == 1) {
1871 return mlirDenseElementsAttrGetBoolValue(*this, pos);
1872 }
1873 if (width == 32) {
1874 return mlirDenseElementsAttrGetInt32Value(*this, pos);
1875 }
1876 if (width == 64) {
1877 return mlirDenseElementsAttrGetInt64Value(*this, pos);
1878 }
1879 }
1880 throw SetPyError(PyExc_TypeError, "Unsupported integer type");
1881 }
1882
bindDerived(ClassTy & c)1883 static void bindDerived(ClassTy &c) {
1884 c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem);
1885 }
1886 };
1887
1888 /// Refinement of PyDenseElementsAttribute for attributes containing
1889 /// floating-point values. Supports element access.
1890 class PyDenseFPElementsAttribute
1891 : public PyConcreteAttribute<PyDenseFPElementsAttribute,
1892 PyDenseElementsAttribute> {
1893 public:
1894 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements;
1895 static constexpr const char *pyClassName = "DenseFPElementsAttr";
1896 using PyConcreteAttribute::PyConcreteAttribute;
1897
dunderGetItem(intptr_t pos)1898 py::float_ dunderGetItem(intptr_t pos) {
1899 if (pos < 0 || pos >= dunderLen()) {
1900 throw SetPyError(PyExc_IndexError,
1901 "attempt to access out of bounds element");
1902 }
1903
1904 MlirType type = mlirAttributeGetType(*this);
1905 type = mlirShapedTypeGetElementType(type);
1906 // Dispatch element extraction to an appropriate C function based on the
1907 // elemental type of the attribute. py::float_ is implicitly constructible
1908 // from float and double.
1909 // TODO: consider caching the type properties in the constructor to avoid
1910 // querying them on each element access.
1911 if (mlirTypeIsAF32(type)) {
1912 return mlirDenseElementsAttrGetFloatValue(*this, pos);
1913 }
1914 if (mlirTypeIsAF64(type)) {
1915 return mlirDenseElementsAttrGetDoubleValue(*this, pos);
1916 }
1917 throw SetPyError(PyExc_TypeError, "Unsupported floating-point type");
1918 }
1919
bindDerived(ClassTy & c)1920 static void bindDerived(ClassTy &c) {
1921 c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem);
1922 }
1923 };
1924
1925 class PyTypeAttribute : public PyConcreteAttribute<PyTypeAttribute> {
1926 public:
1927 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType;
1928 static constexpr const char *pyClassName = "TypeAttr";
1929 using PyConcreteAttribute::PyConcreteAttribute;
1930
bindDerived(ClassTy & c)1931 static void bindDerived(ClassTy &c) {
1932 c.def_static(
1933 "get",
1934 [](PyType value, DefaultingPyMlirContext context) {
1935 MlirAttribute attr = mlirTypeAttrGet(value.get());
1936 return PyTypeAttribute(context->getRef(), attr);
1937 },
1938 py::arg("value"), py::arg("context") = py::none(),
1939 "Gets a uniqued Type attribute");
1940 c.def_property_readonly("value", [](PyTypeAttribute &self) {
1941 return PyType(self.getContext()->getRef(),
1942 mlirTypeAttrGetValue(self.get()));
1943 });
1944 }
1945 };
1946
1947 /// Unit Attribute subclass. Unit attributes don't have values.
1948 class PyUnitAttribute : public PyConcreteAttribute<PyUnitAttribute> {
1949 public:
1950 static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit;
1951 static constexpr const char *pyClassName = "UnitAttr";
1952 using PyConcreteAttribute::PyConcreteAttribute;
1953
bindDerived(ClassTy & c)1954 static void bindDerived(ClassTy &c) {
1955 c.def_static(
1956 "get",
1957 [](DefaultingPyMlirContext context) {
1958 return PyUnitAttribute(context->getRef(),
1959 mlirUnitAttrGet(context->get()));
1960 },
1961 py::arg("context") = py::none(), "Create a Unit attribute.");
1962 }
1963 };
1964
1965 } // namespace
1966
1967 //------------------------------------------------------------------------------
1968 // Builtin type subclasses.
1969 //------------------------------------------------------------------------------
1970
1971 namespace {
1972
1973 /// CRTP base classes for Python types that subclass Type and should be
1974 /// castable from it (i.e. via something like IntegerType(t)).
1975 /// By default, type class hierarchies are one level deep (i.e. a
1976 /// concrete type class extends PyType); however, intermediate python-visible
1977 /// base classes can be modeled by specifying a BaseTy.
1978 template <typename DerivedTy, typename BaseTy = PyType>
1979 class PyConcreteType : public BaseTy {
1980 public:
1981 // Derived classes must define statics for:
1982 // IsAFunctionTy isaFunction
1983 // const char *pyClassName
1984 using ClassTy = py::class_<DerivedTy, BaseTy>;
1985 using IsAFunctionTy = bool (*)(MlirType);
1986
1987 PyConcreteType() = default;
PyConcreteType(PyMlirContextRef contextRef,MlirType t)1988 PyConcreteType(PyMlirContextRef contextRef, MlirType t)
1989 : BaseTy(std::move(contextRef), t) {}
PyConcreteType(PyType & orig)1990 PyConcreteType(PyType &orig)
1991 : PyConcreteType(orig.getContext(), castFrom(orig)) {}
1992
castFrom(PyType & orig)1993 static MlirType castFrom(PyType &orig) {
1994 if (!DerivedTy::isaFunction(orig)) {
1995 auto origRepr = py::repr(py::cast(orig)).cast<std::string>();
1996 throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") +
1997 DerivedTy::pyClassName +
1998 " (from " + origRepr + ")");
1999 }
2000 return orig;
2001 }
2002
bind(py::module & m)2003 static void bind(py::module &m) {
2004 auto cls = ClassTy(m, DerivedTy::pyClassName);
2005 cls.def(py::init<PyType &>(), py::keep_alive<0, 1>());
2006 DerivedTy::bindDerived(cls);
2007 }
2008
2009 /// Implemented by derived classes to add methods to the Python subclass.
bindDerived(ClassTy & m)2010 static void bindDerived(ClassTy &m) {}
2011 };
2012
2013 class PyIntegerType : public PyConcreteType<PyIntegerType> {
2014 public:
2015 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
2016 static constexpr const char *pyClassName = "IntegerType";
2017 using PyConcreteType::PyConcreteType;
2018
bindDerived(ClassTy & c)2019 static void bindDerived(ClassTy &c) {
2020 c.def_static(
2021 "get_signless",
2022 [](unsigned width, DefaultingPyMlirContext context) {
2023 MlirType t = mlirIntegerTypeGet(context->get(), width);
2024 return PyIntegerType(context->getRef(), t);
2025 },
2026 py::arg("width"), py::arg("context") = py::none(),
2027 "Create a signless integer type");
2028 c.def_static(
2029 "get_signed",
2030 [](unsigned width, DefaultingPyMlirContext context) {
2031 MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
2032 return PyIntegerType(context->getRef(), t);
2033 },
2034 py::arg("width"), py::arg("context") = py::none(),
2035 "Create a signed integer type");
2036 c.def_static(
2037 "get_unsigned",
2038 [](unsigned width, DefaultingPyMlirContext context) {
2039 MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
2040 return PyIntegerType(context->getRef(), t);
2041 },
2042 py::arg("width"), py::arg("context") = py::none(),
2043 "Create an unsigned integer type");
2044 c.def_property_readonly(
2045 "width",
2046 [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
2047 "Returns the width of the integer type");
2048 c.def_property_readonly(
2049 "is_signless",
2050 [](PyIntegerType &self) -> bool {
2051 return mlirIntegerTypeIsSignless(self);
2052 },
2053 "Returns whether this is a signless integer");
2054 c.def_property_readonly(
2055 "is_signed",
2056 [](PyIntegerType &self) -> bool {
2057 return mlirIntegerTypeIsSigned(self);
2058 },
2059 "Returns whether this is a signed integer");
2060 c.def_property_readonly(
2061 "is_unsigned",
2062 [](PyIntegerType &self) -> bool {
2063 return mlirIntegerTypeIsUnsigned(self);
2064 },
2065 "Returns whether this is an unsigned integer");
2066 }
2067 };
2068
2069 /// Index Type subclass - IndexType.
2070 class PyIndexType : public PyConcreteType<PyIndexType> {
2071 public:
2072 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
2073 static constexpr const char *pyClassName = "IndexType";
2074 using PyConcreteType::PyConcreteType;
2075
bindDerived(ClassTy & c)2076 static void bindDerived(ClassTy &c) {
2077 c.def_static(
2078 "get",
2079 [](DefaultingPyMlirContext context) {
2080 MlirType t = mlirIndexTypeGet(context->get());
2081 return PyIndexType(context->getRef(), t);
2082 },
2083 py::arg("context") = py::none(), "Create a index type.");
2084 }
2085 };
2086
2087 /// Floating Point Type subclass - BF16Type.
2088 class PyBF16Type : public PyConcreteType<PyBF16Type> {
2089 public:
2090 static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
2091 static constexpr const char *pyClassName = "BF16Type";
2092 using PyConcreteType::PyConcreteType;
2093
bindDerived(ClassTy & c)2094 static void bindDerived(ClassTy &c) {
2095 c.def_static(
2096 "get",
2097 [](DefaultingPyMlirContext context) {
2098 MlirType t = mlirBF16TypeGet(context->get());
2099 return PyBF16Type(context->getRef(), t);
2100 },
2101 py::arg("context") = py::none(), "Create a bf16 type.");
2102 }
2103 };
2104
2105 /// Floating Point Type subclass - F16Type.
2106 class PyF16Type : public PyConcreteType<PyF16Type> {
2107 public:
2108 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
2109 static constexpr const char *pyClassName = "F16Type";
2110 using PyConcreteType::PyConcreteType;
2111
bindDerived(ClassTy & c)2112 static void bindDerived(ClassTy &c) {
2113 c.def_static(
2114 "get",
2115 [](DefaultingPyMlirContext context) {
2116 MlirType t = mlirF16TypeGet(context->get());
2117 return PyF16Type(context->getRef(), t);
2118 },
2119 py::arg("context") = py::none(), "Create a f16 type.");
2120 }
2121 };
2122
2123 /// Floating Point Type subclass - F32Type.
2124 class PyF32Type : public PyConcreteType<PyF32Type> {
2125 public:
2126 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
2127 static constexpr const char *pyClassName = "F32Type";
2128 using PyConcreteType::PyConcreteType;
2129
bindDerived(ClassTy & c)2130 static void bindDerived(ClassTy &c) {
2131 c.def_static(
2132 "get",
2133 [](DefaultingPyMlirContext context) {
2134 MlirType t = mlirF32TypeGet(context->get());
2135 return PyF32Type(context->getRef(), t);
2136 },
2137 py::arg("context") = py::none(), "Create a f32 type.");
2138 }
2139 };
2140
2141 /// Floating Point Type subclass - F64Type.
2142 class PyF64Type : public PyConcreteType<PyF64Type> {
2143 public:
2144 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
2145 static constexpr const char *pyClassName = "F64Type";
2146 using PyConcreteType::PyConcreteType;
2147
bindDerived(ClassTy & c)2148 static void bindDerived(ClassTy &c) {
2149 c.def_static(
2150 "get",
2151 [](DefaultingPyMlirContext context) {
2152 MlirType t = mlirF64TypeGet(context->get());
2153 return PyF64Type(context->getRef(), t);
2154 },
2155 py::arg("context") = py::none(), "Create a f64 type.");
2156 }
2157 };
2158
2159 /// None Type subclass - NoneType.
2160 class PyNoneType : public PyConcreteType<PyNoneType> {
2161 public:
2162 static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
2163 static constexpr const char *pyClassName = "NoneType";
2164 using PyConcreteType::PyConcreteType;
2165
bindDerived(ClassTy & c)2166 static void bindDerived(ClassTy &c) {
2167 c.def_static(
2168 "get",
2169 [](DefaultingPyMlirContext context) {
2170 MlirType t = mlirNoneTypeGet(context->get());
2171 return PyNoneType(context->getRef(), t);
2172 },
2173 py::arg("context") = py::none(), "Create a none type.");
2174 }
2175 };
2176
2177 /// Complex Type subclass - ComplexType.
2178 class PyComplexType : public PyConcreteType<PyComplexType> {
2179 public:
2180 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
2181 static constexpr const char *pyClassName = "ComplexType";
2182 using PyConcreteType::PyConcreteType;
2183
bindDerived(ClassTy & c)2184 static void bindDerived(ClassTy &c) {
2185 c.def_static(
2186 "get",
2187 [](PyType &elementType) {
2188 // The element must be a floating point or integer scalar type.
2189 if (mlirTypeIsAIntegerOrFloat(elementType)) {
2190 MlirType t = mlirComplexTypeGet(elementType);
2191 return PyComplexType(elementType.getContext(), t);
2192 }
2193 throw SetPyError(
2194 PyExc_ValueError,
2195 Twine("invalid '") +
2196 py::repr(py::cast(elementType)).cast<std::string>() +
2197 "' and expected floating point or integer type.");
2198 },
2199 "Create a complex type");
2200 c.def_property_readonly(
2201 "element_type",
2202 [](PyComplexType &self) -> PyType {
2203 MlirType t = mlirComplexTypeGetElementType(self);
2204 return PyType(self.getContext(), t);
2205 },
2206 "Returns element type.");
2207 }
2208 };
2209
2210 class PyShapedType : public PyConcreteType<PyShapedType> {
2211 public:
2212 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
2213 static constexpr const char *pyClassName = "ShapedType";
2214 using PyConcreteType::PyConcreteType;
2215
bindDerived(ClassTy & c)2216 static void bindDerived(ClassTy &c) {
2217 c.def_property_readonly(
2218 "element_type",
2219 [](PyShapedType &self) {
2220 MlirType t = mlirShapedTypeGetElementType(self);
2221 return PyType(self.getContext(), t);
2222 },
2223 "Returns the element type of the shaped type.");
2224 c.def_property_readonly(
2225 "has_rank",
2226 [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
2227 "Returns whether the given shaped type is ranked.");
2228 c.def_property_readonly(
2229 "rank",
2230 [](PyShapedType &self) {
2231 self.requireHasRank();
2232 return mlirShapedTypeGetRank(self);
2233 },
2234 "Returns the rank of the given ranked shaped type.");
2235 c.def_property_readonly(
2236 "has_static_shape",
2237 [](PyShapedType &self) -> bool {
2238 return mlirShapedTypeHasStaticShape(self);
2239 },
2240 "Returns whether the given shaped type has a static shape.");
2241 c.def(
2242 "is_dynamic_dim",
2243 [](PyShapedType &self, intptr_t dim) -> bool {
2244 self.requireHasRank();
2245 return mlirShapedTypeIsDynamicDim(self, dim);
2246 },
2247 "Returns whether the dim-th dimension of the given shaped type is "
2248 "dynamic.");
2249 c.def(
2250 "get_dim_size",
2251 [](PyShapedType &self, intptr_t dim) {
2252 self.requireHasRank();
2253 return mlirShapedTypeGetDimSize(self, dim);
2254 },
2255 "Returns the dim-th dimension of the given ranked shaped type.");
2256 c.def_static(
2257 "is_dynamic_size",
2258 [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
2259 "Returns whether the given dimension size indicates a dynamic "
2260 "dimension.");
2261 c.def(
2262 "is_dynamic_stride_or_offset",
2263 [](PyShapedType &self, int64_t val) -> bool {
2264 self.requireHasRank();
2265 return mlirShapedTypeIsDynamicStrideOrOffset(val);
2266 },
2267 "Returns whether the given value is used as a placeholder for dynamic "
2268 "strides and offsets in shaped types.");
2269 }
2270
2271 private:
requireHasRank()2272 void requireHasRank() {
2273 if (!mlirShapedTypeHasRank(*this)) {
2274 throw SetPyError(
2275 PyExc_ValueError,
2276 "calling this method requires that the type has a rank.");
2277 }
2278 }
2279 };
2280
2281 /// Vector Type subclass - VectorType.
2282 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
2283 public:
2284 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
2285 static constexpr const char *pyClassName = "VectorType";
2286 using PyConcreteType::PyConcreteType;
2287
bindDerived(ClassTy & c)2288 static void bindDerived(ClassTy &c) {
2289 c.def_static(
2290 "get",
2291 [](std::vector<int64_t> shape, PyType &elementType,
2292 DefaultingPyLocation loc) {
2293 MlirType t = mlirVectorTypeGetChecked(shape.size(), shape.data(),
2294 elementType, loc);
2295 // TODO: Rework error reporting once diagnostic engine is exposed
2296 // in C API.
2297 if (mlirTypeIsNull(t)) {
2298 throw SetPyError(
2299 PyExc_ValueError,
2300 Twine("invalid '") +
2301 py::repr(py::cast(elementType)).cast<std::string>() +
2302 "' and expected floating point or integer type.");
2303 }
2304 return PyVectorType(elementType.getContext(), t);
2305 },
2306 py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(),
2307 "Create a vector type");
2308 }
2309 };
2310
2311 /// Ranked Tensor Type subclass - RankedTensorType.
2312 class PyRankedTensorType
2313 : public PyConcreteType<PyRankedTensorType, PyShapedType> {
2314 public:
2315 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2316 static constexpr const char *pyClassName = "RankedTensorType";
2317 using PyConcreteType::PyConcreteType;
2318
bindDerived(ClassTy & c)2319 static void bindDerived(ClassTy &c) {
2320 c.def_static(
2321 "get",
2322 [](std::vector<int64_t> shape, PyType &elementType,
2323 DefaultingPyLocation loc) {
2324 MlirType t = mlirRankedTensorTypeGetChecked(
2325 shape.size(), shape.data(), elementType, loc);
2326 // TODO: Rework error reporting once diagnostic engine is exposed
2327 // in C API.
2328 if (mlirTypeIsNull(t)) {
2329 throw SetPyError(
2330 PyExc_ValueError,
2331 Twine("invalid '") +
2332 py::repr(py::cast(elementType)).cast<std::string>() +
2333 "' and expected floating point, integer, vector or "
2334 "complex "
2335 "type.");
2336 }
2337 return PyRankedTensorType(elementType.getContext(), t);
2338 },
2339 py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(),
2340 "Create a ranked tensor type");
2341 }
2342 };
2343
2344 /// Unranked Tensor Type subclass - UnrankedTensorType.
2345 class PyUnrankedTensorType
2346 : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
2347 public:
2348 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
2349 static constexpr const char *pyClassName = "UnrankedTensorType";
2350 using PyConcreteType::PyConcreteType;
2351
bindDerived(ClassTy & c)2352 static void bindDerived(ClassTy &c) {
2353 c.def_static(
2354 "get",
2355 [](PyType &elementType, DefaultingPyLocation loc) {
2356 MlirType t = mlirUnrankedTensorTypeGetChecked(elementType, loc);
2357 // TODO: Rework error reporting once diagnostic engine is exposed
2358 // in C API.
2359 if (mlirTypeIsNull(t)) {
2360 throw SetPyError(
2361 PyExc_ValueError,
2362 Twine("invalid '") +
2363 py::repr(py::cast(elementType)).cast<std::string>() +
2364 "' and expected floating point, integer, vector or "
2365 "complex "
2366 "type.");
2367 }
2368 return PyUnrankedTensorType(elementType.getContext(), t);
2369 },
2370 py::arg("element_type"), py::arg("loc") = py::none(),
2371 "Create a unranked tensor type");
2372 }
2373 };
2374
2375 /// Ranked MemRef Type subclass - MemRefType.
2376 class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
2377 public:
2378 static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
2379 static constexpr const char *pyClassName = "MemRefType";
2380 using PyConcreteType::PyConcreteType;
2381
bindDerived(ClassTy & c)2382 static void bindDerived(ClassTy &c) {
2383 // TODO: Add mlirMemRefTypeGet and mlirMemRefTypeGetAffineMap binding
2384 // once the affine map binding is completed.
2385 c.def_static(
2386 "get_contiguous_memref",
2387 // TODO: Make the location optional and create a default location.
2388 [](PyType &elementType, std::vector<int64_t> shape,
2389 unsigned memorySpace, DefaultingPyLocation loc) {
2390 MlirType t = mlirMemRefTypeContiguousGetChecked(
2391 elementType, shape.size(), shape.data(), memorySpace, loc);
2392 // TODO: Rework error reporting once diagnostic engine is exposed
2393 // in C API.
2394 if (mlirTypeIsNull(t)) {
2395 throw SetPyError(
2396 PyExc_ValueError,
2397 Twine("invalid '") +
2398 py::repr(py::cast(elementType)).cast<std::string>() +
2399 "' and expected floating point, integer, vector or "
2400 "complex "
2401 "type.");
2402 }
2403 return PyMemRefType(elementType.getContext(), t);
2404 },
2405 py::arg("element_type"), py::arg("shape"), py::arg("memory_space"),
2406 py::arg("loc") = py::none(), "Create a memref type")
2407 .def_property_readonly(
2408 "num_affine_maps",
2409 [](PyMemRefType &self) -> intptr_t {
2410 return mlirMemRefTypeGetNumAffineMaps(self);
2411 },
2412 "Returns the number of affine layout maps in the given MemRef "
2413 "type.")
2414 .def_property_readonly(
2415 "memory_space",
2416 [](PyMemRefType &self) -> unsigned {
2417 return mlirMemRefTypeGetMemorySpace(self);
2418 },
2419 "Returns the memory space of the given MemRef type.");
2420 }
2421 };
2422
2423 /// Unranked MemRef Type subclass - UnrankedMemRefType.
2424 class PyUnrankedMemRefType
2425 : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
2426 public:
2427 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
2428 static constexpr const char *pyClassName = "UnrankedMemRefType";
2429 using PyConcreteType::PyConcreteType;
2430
bindDerived(ClassTy & c)2431 static void bindDerived(ClassTy &c) {
2432 c.def_static(
2433 "get",
2434 [](PyType &elementType, unsigned memorySpace,
2435 DefaultingPyLocation loc) {
2436 MlirType t =
2437 mlirUnrankedMemRefTypeGetChecked(elementType, memorySpace, loc);
2438 // TODO: Rework error reporting once diagnostic engine is exposed
2439 // in C API.
2440 if (mlirTypeIsNull(t)) {
2441 throw SetPyError(
2442 PyExc_ValueError,
2443 Twine("invalid '") +
2444 py::repr(py::cast(elementType)).cast<std::string>() +
2445 "' and expected floating point, integer, vector or "
2446 "complex "
2447 "type.");
2448 }
2449 return PyUnrankedMemRefType(elementType.getContext(), t);
2450 },
2451 py::arg("element_type"), py::arg("memory_space"),
2452 py::arg("loc") = py::none(), "Create a unranked memref type")
2453 .def_property_readonly(
2454 "memory_space",
2455 [](PyUnrankedMemRefType &self) -> unsigned {
2456 return mlirUnrankedMemrefGetMemorySpace(self);
2457 },
2458 "Returns the memory space of the given Unranked MemRef type.");
2459 }
2460 };
2461
2462 /// Tuple Type subclass - TupleType.
2463 class PyTupleType : public PyConcreteType<PyTupleType> {
2464 public:
2465 static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
2466 static constexpr const char *pyClassName = "TupleType";
2467 using PyConcreteType::PyConcreteType;
2468
bindDerived(ClassTy & c)2469 static void bindDerived(ClassTy &c) {
2470 c.def_static(
2471 "get_tuple",
2472 [](py::list elementList, DefaultingPyMlirContext context) {
2473 intptr_t num = py::len(elementList);
2474 // Mapping py::list to SmallVector.
2475 SmallVector<MlirType, 4> elements;
2476 for (auto element : elementList)
2477 elements.push_back(element.cast<PyType>());
2478 MlirType t = mlirTupleTypeGet(context->get(), num, elements.data());
2479 return PyTupleType(context->getRef(), t);
2480 },
2481 py::arg("elements"), py::arg("context") = py::none(),
2482 "Create a tuple type");
2483 c.def(
2484 "get_type",
2485 [](PyTupleType &self, intptr_t pos) -> PyType {
2486 MlirType t = mlirTupleTypeGetType(self, pos);
2487 return PyType(self.getContext(), t);
2488 },
2489 "Returns the pos-th type in the tuple type.");
2490 c.def_property_readonly(
2491 "num_types",
2492 [](PyTupleType &self) -> intptr_t {
2493 return mlirTupleTypeGetNumTypes(self);
2494 },
2495 "Returns the number of types contained in a tuple.");
2496 }
2497 };
2498
2499 /// Function type.
2500 class PyFunctionType : public PyConcreteType<PyFunctionType> {
2501 public:
2502 static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
2503 static constexpr const char *pyClassName = "FunctionType";
2504 using PyConcreteType::PyConcreteType;
2505
bindDerived(ClassTy & c)2506 static void bindDerived(ClassTy &c) {
2507 c.def_static(
2508 "get",
2509 [](std::vector<PyType> inputs, std::vector<PyType> results,
2510 DefaultingPyMlirContext context) {
2511 SmallVector<MlirType, 4> inputsRaw(inputs.begin(), inputs.end());
2512 SmallVector<MlirType, 4> resultsRaw(results.begin(), results.end());
2513 MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(),
2514 inputsRaw.data(), resultsRaw.size(),
2515 resultsRaw.data());
2516 return PyFunctionType(context->getRef(), t);
2517 },
2518 py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(),
2519 "Gets a FunctionType from a list of input and result types");
2520 c.def_property_readonly(
2521 "inputs",
2522 [](PyFunctionType &self) {
2523 MlirType t = self;
2524 auto contextRef = self.getContext();
2525 py::list types;
2526 for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
2527 ++i) {
2528 types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i)));
2529 }
2530 return types;
2531 },
2532 "Returns the list of input types in the FunctionType.");
2533 c.def_property_readonly(
2534 "results",
2535 [](PyFunctionType &self) {
2536 auto contextRef = self.getContext();
2537 py::list types;
2538 for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
2539 ++i) {
2540 types.append(
2541 PyType(contextRef, mlirFunctionTypeGetResult(self, i)));
2542 }
2543 return types;
2544 },
2545 "Returns the list of result types in the FunctionType.");
2546 }
2547 };
2548
2549 } // namespace
2550
2551 //------------------------------------------------------------------------------
2552 // Populates the pybind11 IR submodule.
2553 //------------------------------------------------------------------------------
2554
populateIRSubmodule(py::module & m)2555 void mlir::python::populateIRSubmodule(py::module &m) {
2556 //----------------------------------------------------------------------------
2557 // Mapping of MlirContext
2558 //----------------------------------------------------------------------------
2559 py::class_<PyMlirContext>(m, "Context")
2560 .def(py::init<>(&PyMlirContext::createNewContextForInit))
2561 .def_static("_get_live_count", &PyMlirContext::getLiveCount)
2562 .def("_get_context_again",
2563 [](PyMlirContext &self) {
2564 PyMlirContextRef ref = PyMlirContext::forContext(self.get());
2565 return ref.releaseObject();
2566 })
2567 .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
2568 .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
2569 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2570 &PyMlirContext::getCapsule)
2571 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
2572 .def("__enter__", &PyMlirContext::contextEnter)
2573 .def("__exit__", &PyMlirContext::contextExit)
2574 .def_property_readonly_static(
2575 "current",
2576 [](py::object & /*class*/) {
2577 auto *context = PyThreadContextEntry::getDefaultContext();
2578 if (!context)
2579 throw SetPyError(PyExc_ValueError, "No current Context");
2580 return context;
2581 },
2582 "Gets the Context bound to the current thread or raises ValueError")
2583 .def_property_readonly(
2584 "dialects",
2585 [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2586 "Gets a container for accessing dialects by name")
2587 .def_property_readonly(
2588 "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); },
2589 "Alias for 'dialect'")
2590 .def(
2591 "get_dialect_descriptor",
2592 [=](PyMlirContext &self, std::string &name) {
2593 MlirDialect dialect = mlirContextGetOrLoadDialect(
2594 self.get(), {name.data(), name.size()});
2595 if (mlirDialectIsNull(dialect)) {
2596 throw SetPyError(PyExc_ValueError,
2597 Twine("Dialect '") + name + "' not found");
2598 }
2599 return PyDialectDescriptor(self.getRef(), dialect);
2600 },
2601 "Gets or loads a dialect by name, returning its descriptor object")
2602 .def_property(
2603 "allow_unregistered_dialects",
2604 [](PyMlirContext &self) -> bool {
2605 return mlirContextGetAllowUnregisteredDialects(self.get());
2606 },
2607 [](PyMlirContext &self, bool value) {
2608 mlirContextSetAllowUnregisteredDialects(self.get(), value);
2609 });
2610
2611 //----------------------------------------------------------------------------
2612 // Mapping of PyDialectDescriptor
2613 //----------------------------------------------------------------------------
2614 py::class_<PyDialectDescriptor>(m, "DialectDescriptor")
2615 .def_property_readonly("namespace",
2616 [](PyDialectDescriptor &self) {
2617 MlirStringRef ns =
2618 mlirDialectGetNamespace(self.get());
2619 return py::str(ns.data, ns.length);
2620 })
2621 .def("__repr__", [](PyDialectDescriptor &self) {
2622 MlirStringRef ns = mlirDialectGetNamespace(self.get());
2623 std::string repr("<DialectDescriptor ");
2624 repr.append(ns.data, ns.length);
2625 repr.append(">");
2626 return repr;
2627 });
2628
2629 //----------------------------------------------------------------------------
2630 // Mapping of PyDialects
2631 //----------------------------------------------------------------------------
2632 py::class_<PyDialects>(m, "Dialects")
2633 .def("__getitem__",
2634 [=](PyDialects &self, std::string keyName) {
2635 MlirDialect dialect =
2636 self.getDialectForKey(keyName, /*attrError=*/false);
2637 py::object descriptor =
2638 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2639 return createCustomDialectWrapper(keyName, std::move(descriptor));
2640 })
2641 .def("__getattr__", [=](PyDialects &self, std::string attrName) {
2642 MlirDialect dialect =
2643 self.getDialectForKey(attrName, /*attrError=*/true);
2644 py::object descriptor =
2645 py::cast(PyDialectDescriptor{self.getContext(), dialect});
2646 return createCustomDialectWrapper(attrName, std::move(descriptor));
2647 });
2648
2649 //----------------------------------------------------------------------------
2650 // Mapping of PyDialect
2651 //----------------------------------------------------------------------------
2652 py::class_<PyDialect>(m, "Dialect")
2653 .def(py::init<py::object>(), "descriptor")
2654 .def_property_readonly(
2655 "descriptor", [](PyDialect &self) { return self.getDescriptor(); })
2656 .def("__repr__", [](py::object self) {
2657 auto clazz = self.attr("__class__");
2658 return py::str("<Dialect ") +
2659 self.attr("descriptor").attr("namespace") + py::str(" (class ") +
2660 clazz.attr("__module__") + py::str(".") +
2661 clazz.attr("__name__") + py::str(")>");
2662 });
2663
2664 //----------------------------------------------------------------------------
2665 // Mapping of Location
2666 //----------------------------------------------------------------------------
2667 py::class_<PyLocation>(m, "Location")
2668 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule)
2669 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule)
2670 .def("__enter__", &PyLocation::contextEnter)
2671 .def("__exit__", &PyLocation::contextExit)
2672 .def("__eq__",
2673 [](PyLocation &self, PyLocation &other) -> bool {
2674 return mlirLocationEqual(self, other);
2675 })
2676 .def("__eq__", [](PyLocation &self, py::object other) { return false; })
2677 .def_property_readonly_static(
2678 "current",
2679 [](py::object & /*class*/) {
2680 auto *loc = PyThreadContextEntry::getDefaultLocation();
2681 if (!loc)
2682 throw SetPyError(PyExc_ValueError, "No current Location");
2683 return loc;
2684 },
2685 "Gets the Location bound to the current thread or raises ValueError")
2686 .def_static(
2687 "unknown",
2688 [](DefaultingPyMlirContext context) {
2689 return PyLocation(context->getRef(),
2690 mlirLocationUnknownGet(context->get()));
2691 },
2692 py::arg("context") = py::none(),
2693 "Gets a Location representing an unknown location")
2694 .def_static(
2695 "file",
2696 [](std::string filename, int line, int col,
2697 DefaultingPyMlirContext context) {
2698 return PyLocation(
2699 context->getRef(),
2700 mlirLocationFileLineColGet(
2701 context->get(), toMlirStringRef(filename), line, col));
2702 },
2703 py::arg("filename"), py::arg("line"), py::arg("col"),
2704 py::arg("context") = py::none(), kContextGetFileLocationDocstring)
2705 .def_property_readonly(
2706 "context",
2707 [](PyLocation &self) { return self.getContext().getObject(); },
2708 "Context that owns the Location")
2709 .def("__repr__", [](PyLocation &self) {
2710 PyPrintAccumulator printAccum;
2711 mlirLocationPrint(self, printAccum.getCallback(),
2712 printAccum.getUserData());
2713 return printAccum.join();
2714 });
2715
2716 //----------------------------------------------------------------------------
2717 // Mapping of Module
2718 //----------------------------------------------------------------------------
2719 py::class_<PyModule>(m, "Module")
2720 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
2721 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
2722 .def_static(
2723 "parse",
2724 [](const std::string moduleAsm, DefaultingPyMlirContext context) {
2725 MlirModule module = mlirModuleCreateParse(
2726 context->get(), toMlirStringRef(moduleAsm));
2727 // TODO: Rework error reporting once diagnostic engine is exposed
2728 // in C API.
2729 if (mlirModuleIsNull(module)) {
2730 throw SetPyError(
2731 PyExc_ValueError,
2732 "Unable to parse module assembly (see diagnostics)");
2733 }
2734 return PyModule::forModule(module).releaseObject();
2735 },
2736 py::arg("asm"), py::arg("context") = py::none(),
2737 kModuleParseDocstring)
2738 .def_static(
2739 "create",
2740 [](DefaultingPyLocation loc) {
2741 MlirModule module = mlirModuleCreateEmpty(loc);
2742 return PyModule::forModule(module).releaseObject();
2743 },
2744 py::arg("loc") = py::none(), "Creates an empty module")
2745 .def_property_readonly(
2746 "context",
2747 [](PyModule &self) { return self.getContext().getObject(); },
2748 "Context that created the Module")
2749 .def_property_readonly(
2750 "operation",
2751 [](PyModule &self) {
2752 return PyOperation::forOperation(self.getContext(),
2753 mlirModuleGetOperation(self.get()),
2754 self.getRef().releaseObject())
2755 .releaseObject();
2756 },
2757 "Accesses the module as an operation")
2758 .def_property_readonly(
2759 "body",
2760 [](PyModule &self) {
2761 PyOperationRef module_op = PyOperation::forOperation(
2762 self.getContext(), mlirModuleGetOperation(self.get()),
2763 self.getRef().releaseObject());
2764 PyBlock returnBlock(module_op, mlirModuleGetBody(self.get()));
2765 return returnBlock;
2766 },
2767 "Return the block for this module")
2768 .def(
2769 "dump",
2770 [](PyModule &self) {
2771 mlirOperationDump(mlirModuleGetOperation(self.get()));
2772 },
2773 kDumpDocstring)
2774 .def(
2775 "__str__",
2776 [](PyModule &self) {
2777 MlirOperation operation = mlirModuleGetOperation(self.get());
2778 PyPrintAccumulator printAccum;
2779 mlirOperationPrint(operation, printAccum.getCallback(),
2780 printAccum.getUserData());
2781 return printAccum.join();
2782 },
2783 kOperationStrDunderDocstring);
2784
2785 //----------------------------------------------------------------------------
2786 // Mapping of Operation.
2787 //----------------------------------------------------------------------------
2788 py::class_<PyOperationBase>(m, "_OperationBase")
2789 .def("__eq__",
2790 [](PyOperationBase &self, PyOperationBase &other) {
2791 return &self.getOperation() == &other.getOperation();
2792 })
2793 .def("__eq__",
2794 [](PyOperationBase &self, py::object other) { return false; })
2795 .def_property_readonly("attributes",
2796 [](PyOperationBase &self) {
2797 return PyOpAttributeMap(
2798 self.getOperation().getRef());
2799 })
2800 .def_property_readonly("operands",
2801 [](PyOperationBase &self) {
2802 return PyOpOperandList(
2803 self.getOperation().getRef());
2804 })
2805 .def_property_readonly("regions",
2806 [](PyOperationBase &self) {
2807 return PyRegionList(
2808 self.getOperation().getRef());
2809 })
2810 .def_property_readonly(
2811 "results",
2812 [](PyOperationBase &self) {
2813 return PyOpResultList(self.getOperation().getRef());
2814 },
2815 "Returns the list of Operation results.")
2816 .def_property_readonly(
2817 "result",
2818 [](PyOperationBase &self) {
2819 auto &operation = self.getOperation();
2820 auto numResults = mlirOperationGetNumResults(operation);
2821 if (numResults != 1) {
2822 auto name = mlirIdentifierStr(mlirOperationGetName(operation));
2823 throw SetPyError(
2824 PyExc_ValueError,
2825 Twine("Cannot call .result on operation ") +
2826 StringRef(name.data, name.length) + " which has " +
2827 Twine(numResults) +
2828 " results (it is only valid for operations with a "
2829 "single result)");
2830 }
2831 return PyOpResult(operation.getRef(),
2832 mlirOperationGetResult(operation, 0));
2833 },
2834 "Shortcut to get an op result if it has only one (throws an error "
2835 "otherwise).")
2836 .def("__iter__",
2837 [](PyOperationBase &self) {
2838 return PyRegionIterator(self.getOperation().getRef());
2839 })
2840 .def(
2841 "__str__",
2842 [](PyOperationBase &self) {
2843 return self.getAsm(/*binary=*/false,
2844 /*largeElementsLimit=*/llvm::None,
2845 /*enableDebugInfo=*/false,
2846 /*prettyDebugInfo=*/false,
2847 /*printGenericOpForm=*/false,
2848 /*useLocalScope=*/false);
2849 },
2850 "Returns the assembly form of the operation.")
2851 .def("print", &PyOperationBase::print,
2852 // Careful: Lots of arguments must match up with print method.
2853 py::arg("file") = py::none(), py::arg("binary") = false,
2854 py::arg("large_elements_limit") = py::none(),
2855 py::arg("enable_debug_info") = false,
2856 py::arg("pretty_debug_info") = false,
2857 py::arg("print_generic_op_form") = false,
2858 py::arg("use_local_scope") = false, kOperationPrintDocstring)
2859 .def("get_asm", &PyOperationBase::getAsm,
2860 // Careful: Lots of arguments must match up with get_asm method.
2861 py::arg("binary") = false,
2862 py::arg("large_elements_limit") = py::none(),
2863 py::arg("enable_debug_info") = false,
2864 py::arg("pretty_debug_info") = false,
2865 py::arg("print_generic_op_form") = false,
2866 py::arg("use_local_scope") = false, kOperationGetAsmDocstring);
2867
2868 py::class_<PyOperation, PyOperationBase>(m, "Operation")
2869 .def_static("create", &PyOperation::create, py::arg("name"),
2870 py::arg("operands") = py::none(),
2871 py::arg("results") = py::none(),
2872 py::arg("attributes") = py::none(),
2873 py::arg("successors") = py::none(), py::arg("regions") = 0,
2874 py::arg("loc") = py::none(), py::arg("ip") = py::none(),
2875 kOperationCreateDocstring)
2876 .def_property_readonly(
2877 "context",
2878 [](PyOperation &self) { return self.getContext().getObject(); },
2879 "Context that owns the Operation")
2880 .def_property_readonly("opview", &PyOperation::createOpView);
2881
2882 py::class_<PyOpView, PyOperationBase>(m, "OpView")
2883 .def(py::init<py::object>())
2884 .def_property_readonly("operation", &PyOpView::getOperationObject)
2885 .def_property_readonly(
2886 "context",
2887 [](PyOpView &self) {
2888 return self.getOperation().getContext().getObject();
2889 },
2890 "Context that owns the Operation")
2891 .def("__str__",
2892 [](PyOpView &self) { return py::str(self.getOperationObject()); });
2893
2894 //----------------------------------------------------------------------------
2895 // Mapping of PyRegion.
2896 //----------------------------------------------------------------------------
2897 py::class_<PyRegion>(m, "Region")
2898 .def_property_readonly(
2899 "blocks",
2900 [](PyRegion &self) {
2901 return PyBlockList(self.getParentOperation(), self.get());
2902 },
2903 "Returns a forward-optimized sequence of blocks.")
2904 .def(
2905 "__iter__",
2906 [](PyRegion &self) {
2907 self.checkValid();
2908 MlirBlock firstBlock = mlirRegionGetFirstBlock(self.get());
2909 return PyBlockIterator(self.getParentOperation(), firstBlock);
2910 },
2911 "Iterates over blocks in the region.")
2912 .def("__eq__",
2913 [](PyRegion &self, PyRegion &other) {
2914 return self.get().ptr == other.get().ptr;
2915 })
2916 .def("__eq__", [](PyRegion &self, py::object &other) { return false; });
2917
2918 //----------------------------------------------------------------------------
2919 // Mapping of PyBlock.
2920 //----------------------------------------------------------------------------
2921 py::class_<PyBlock>(m, "Block")
2922 .def_property_readonly(
2923 "arguments",
2924 [](PyBlock &self) {
2925 return PyBlockArgumentList(self.getParentOperation(), self.get());
2926 },
2927 "Returns a list of block arguments.")
2928 .def_property_readonly(
2929 "operations",
2930 [](PyBlock &self) {
2931 return PyOperationList(self.getParentOperation(), self.get());
2932 },
2933 "Returns a forward-optimized sequence of operations.")
2934 .def(
2935 "__iter__",
2936 [](PyBlock &self) {
2937 self.checkValid();
2938 MlirOperation firstOperation =
2939 mlirBlockGetFirstOperation(self.get());
2940 return PyOperationIterator(self.getParentOperation(),
2941 firstOperation);
2942 },
2943 "Iterates over operations in the block.")
2944 .def("__eq__",
2945 [](PyBlock &self, PyBlock &other) {
2946 return self.get().ptr == other.get().ptr;
2947 })
2948 .def("__eq__", [](PyBlock &self, py::object &other) { return false; })
2949 .def(
2950 "__str__",
2951 [](PyBlock &self) {
2952 self.checkValid();
2953 PyPrintAccumulator printAccum;
2954 mlirBlockPrint(self.get(), printAccum.getCallback(),
2955 printAccum.getUserData());
2956 return printAccum.join();
2957 },
2958 "Returns the assembly form of the block.");
2959
2960 //----------------------------------------------------------------------------
2961 // Mapping of PyInsertionPoint.
2962 //----------------------------------------------------------------------------
2963
2964 py::class_<PyInsertionPoint>(m, "InsertionPoint")
2965 .def(py::init<PyBlock &>(), py::arg("block"),
2966 "Inserts after the last operation but still inside the block.")
2967 .def("__enter__", &PyInsertionPoint::contextEnter)
2968 .def("__exit__", &PyInsertionPoint::contextExit)
2969 .def_property_readonly_static(
2970 "current",
2971 [](py::object & /*class*/) {
2972 auto *ip = PyThreadContextEntry::getDefaultInsertionPoint();
2973 if (!ip)
2974 throw SetPyError(PyExc_ValueError, "No current InsertionPoint");
2975 return ip;
2976 },
2977 "Gets the InsertionPoint bound to the current thread or raises "
2978 "ValueError if none has been set")
2979 .def(py::init<PyOperationBase &>(), py::arg("beforeOperation"),
2980 "Inserts before a referenced operation.")
2981 .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin,
2982 py::arg("block"), "Inserts at the beginning of the block.")
2983 .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator,
2984 py::arg("block"), "Inserts before the block terminator.")
2985 .def("insert", &PyInsertionPoint::insert, py::arg("operation"),
2986 "Inserts an operation.");
2987
2988 //----------------------------------------------------------------------------
2989 // Mapping of PyAttribute.
2990 //----------------------------------------------------------------------------
2991 py::class_<PyAttribute>(m, "Attribute")
2992 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
2993 &PyAttribute::getCapsule)
2994 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule)
2995 .def_static(
2996 "parse",
2997 [](std::string attrSpec, DefaultingPyMlirContext context) {
2998 MlirAttribute type = mlirAttributeParseGet(
2999 context->get(), toMlirStringRef(attrSpec));
3000 // TODO: Rework error reporting once diagnostic engine is exposed
3001 // in C API.
3002 if (mlirAttributeIsNull(type)) {
3003 throw SetPyError(PyExc_ValueError,
3004 Twine("Unable to parse attribute: '") +
3005 attrSpec + "'");
3006 }
3007 return PyAttribute(context->getRef(), type);
3008 },
3009 py::arg("asm"), py::arg("context") = py::none(),
3010 "Parses an attribute from an assembly form")
3011 .def_property_readonly(
3012 "context",
3013 [](PyAttribute &self) { return self.getContext().getObject(); },
3014 "Context that owns the Attribute")
3015 .def_property_readonly("type",
3016 [](PyAttribute &self) {
3017 return PyType(self.getContext()->getRef(),
3018 mlirAttributeGetType(self));
3019 })
3020 .def(
3021 "get_named",
3022 [](PyAttribute &self, std::string name) {
3023 return PyNamedAttribute(self, std::move(name));
3024 },
3025 py::keep_alive<0, 1>(), "Binds a name to the attribute")
3026 .def("__eq__",
3027 [](PyAttribute &self, PyAttribute &other) { return self == other; })
3028 .def("__eq__", [](PyAttribute &self, py::object &other) { return false; })
3029 .def(
3030 "dump", [](PyAttribute &self) { mlirAttributeDump(self); },
3031 kDumpDocstring)
3032 .def(
3033 "__str__",
3034 [](PyAttribute &self) {
3035 PyPrintAccumulator printAccum;
3036 mlirAttributePrint(self, printAccum.getCallback(),
3037 printAccum.getUserData());
3038 return printAccum.join();
3039 },
3040 "Returns the assembly form of the Attribute.")
3041 .def("__repr__", [](PyAttribute &self) {
3042 // Generally, assembly formats are not printed for __repr__ because
3043 // this can cause exceptionally long debug output and exceptions.
3044 // However, attribute values are generally considered useful and are
3045 // printed. This may need to be re-evaluated if debug dumps end up
3046 // being excessive.
3047 PyPrintAccumulator printAccum;
3048 printAccum.parts.append("Attribute(");
3049 mlirAttributePrint(self, printAccum.getCallback(),
3050 printAccum.getUserData());
3051 printAccum.parts.append(")");
3052 return printAccum.join();
3053 });
3054
3055 //----------------------------------------------------------------------------
3056 // Mapping of PyNamedAttribute
3057 //----------------------------------------------------------------------------
3058 py::class_<PyNamedAttribute>(m, "NamedAttribute")
3059 .def("__repr__",
3060 [](PyNamedAttribute &self) {
3061 PyPrintAccumulator printAccum;
3062 printAccum.parts.append("NamedAttribute(");
3063 printAccum.parts.append(self.namedAttr.name.data);
3064 printAccum.parts.append("=");
3065 mlirAttributePrint(self.namedAttr.attribute,
3066 printAccum.getCallback(),
3067 printAccum.getUserData());
3068 printAccum.parts.append(")");
3069 return printAccum.join();
3070 })
3071 .def_property_readonly(
3072 "name",
3073 [](PyNamedAttribute &self) {
3074 return py::str(self.namedAttr.name.data,
3075 self.namedAttr.name.length);
3076 },
3077 "The name of the NamedAttribute binding")
3078 .def_property_readonly(
3079 "attr",
3080 [](PyNamedAttribute &self) {
3081 // TODO: When named attribute is removed/refactored, also remove
3082 // this constructor (it does an inefficient table lookup).
3083 auto contextRef = PyMlirContext::forContext(
3084 mlirAttributeGetContext(self.namedAttr.attribute));
3085 return PyAttribute(std::move(contextRef), self.namedAttr.attribute);
3086 },
3087 py::keep_alive<0, 1>(),
3088 "The underlying generic attribute of the NamedAttribute binding");
3089
3090 // Builtin attribute bindings.
3091 PyFloatAttribute::bind(m);
3092 PyIntegerAttribute::bind(m);
3093 PyBoolAttribute::bind(m);
3094 PyStringAttribute::bind(m);
3095 PyDenseElementsAttribute::bind(m);
3096 PyDenseIntElementsAttribute::bind(m);
3097 PyDenseFPElementsAttribute::bind(m);
3098 PyTypeAttribute::bind(m);
3099 PyUnitAttribute::bind(m);
3100
3101 //----------------------------------------------------------------------------
3102 // Mapping of PyType.
3103 //----------------------------------------------------------------------------
3104 py::class_<PyType>(m, "Type")
3105 .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule)
3106 .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule)
3107 .def_static(
3108 "parse",
3109 [](std::string typeSpec, DefaultingPyMlirContext context) {
3110 MlirType type =
3111 mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec));
3112 // TODO: Rework error reporting once diagnostic engine is exposed
3113 // in C API.
3114 if (mlirTypeIsNull(type)) {
3115 throw SetPyError(PyExc_ValueError,
3116 Twine("Unable to parse type: '") + typeSpec +
3117 "'");
3118 }
3119 return PyType(context->getRef(), type);
3120 },
3121 py::arg("asm"), py::arg("context") = py::none(),
3122 kContextParseTypeDocstring)
3123 .def_property_readonly(
3124 "context", [](PyType &self) { return self.getContext().getObject(); },
3125 "Context that owns the Type")
3126 .def("__eq__", [](PyType &self, PyType &other) { return self == other; })
3127 .def("__eq__", [](PyType &self, py::object &other) { return false; })
3128 .def(
3129 "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring)
3130 .def(
3131 "__str__",
3132 [](PyType &self) {
3133 PyPrintAccumulator printAccum;
3134 mlirTypePrint(self, printAccum.getCallback(),
3135 printAccum.getUserData());
3136 return printAccum.join();
3137 },
3138 "Returns the assembly form of the type.")
3139 .def("__repr__", [](PyType &self) {
3140 // Generally, assembly formats are not printed for __repr__ because
3141 // this can cause exceptionally long debug output and exceptions.
3142 // However, types are an exception as they typically have compact
3143 // assembly forms and printing them is useful.
3144 PyPrintAccumulator printAccum;
3145 printAccum.parts.append("Type(");
3146 mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData());
3147 printAccum.parts.append(")");
3148 return printAccum.join();
3149 });
3150
3151 // Builtin type bindings.
3152 PyIntegerType::bind(m);
3153 PyIndexType::bind(m);
3154 PyBF16Type::bind(m);
3155 PyF16Type::bind(m);
3156 PyF32Type::bind(m);
3157 PyF64Type::bind(m);
3158 PyNoneType::bind(m);
3159 PyComplexType::bind(m);
3160 PyShapedType::bind(m);
3161 PyVectorType::bind(m);
3162 PyRankedTensorType::bind(m);
3163 PyUnrankedTensorType::bind(m);
3164 PyMemRefType::bind(m);
3165 PyUnrankedMemRefType::bind(m);
3166 PyTupleType::bind(m);
3167 PyFunctionType::bind(m);
3168
3169 //----------------------------------------------------------------------------
3170 // Mapping of Value.
3171 //----------------------------------------------------------------------------
3172 py::class_<PyValue>(m, "Value")
3173 .def_property_readonly(
3174 "context",
3175 [](PyValue &self) { return self.getParentOperation()->getContext(); },
3176 "Context in which the value lives.")
3177 .def(
3178 "dump", [](PyValue &self) { mlirValueDump(self.get()); },
3179 kDumpDocstring)
3180 .def("__eq__",
3181 [](PyValue &self, PyValue &other) {
3182 return self.get().ptr == other.get().ptr;
3183 })
3184 .def("__eq__", [](PyValue &self, py::object other) { return false; })
3185 .def(
3186 "__str__",
3187 [](PyValue &self) {
3188 PyPrintAccumulator printAccum;
3189 printAccum.parts.append("Value(");
3190 mlirValuePrint(self.get(), printAccum.getCallback(),
3191 printAccum.getUserData());
3192 printAccum.parts.append(")");
3193 return printAccum.join();
3194 },
3195 kValueDunderStrDocstring)
3196 .def_property_readonly("type", [](PyValue &self) {
3197 return PyType(self.getParentOperation()->getContext(),
3198 mlirValueGetType(self.get()));
3199 });
3200 PyBlockArgument::bind(m);
3201 PyOpResult::bind(m);
3202
3203 // Container bindings.
3204 PyBlockArgumentList::bind(m);
3205 PyBlockIterator::bind(m);
3206 PyBlockList::bind(m);
3207 PyOperationIterator::bind(m);
3208 PyOperationList::bind(m);
3209 PyOpAttributeMap::bind(m);
3210 PyOpOperandList::bind(m);
3211 PyOpResultList::bind(m);
3212 PyRegionIterator::bind(m);
3213 PyRegionList::bind(m);
3214 }
3215