//===- IRModules.h - IR Submodules of pybind module -----------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H #include #include "PybindUtils.h" #include "mlir-c/IR.h" #include "llvm/ADT/DenseMap.h" namespace mlir { namespace python { class PyBlock; class PyInsertionPoint; class PyLocation; class DefaultingPyLocation; class PyMlirContext; class DefaultingPyMlirContext; class PyModule; class PyOperation; class PyType; class PyValue; /// Template for a reference to a concrete type which captures a python /// reference to its underlying python object. template class PyObjectRef { public: PyObjectRef(T *referrent, pybind11::object object) : referrent(referrent), object(std::move(object)) { assert(this->referrent && "cannot construct PyObjectRef with null referrent"); assert(this->object && "cannot construct PyObjectRef with null object"); } PyObjectRef(PyObjectRef &&other) : referrent(other.referrent), object(std::move(other.object)) { other.referrent = nullptr; assert(!other.object); } PyObjectRef(const PyObjectRef &other) : referrent(other.referrent), object(other.object /* copies */) {} ~PyObjectRef() {} int getRefCount() { if (!object) return 0; return object.ref_count(); } /// Releases the object held by this instance, returning it. /// This is the proper thing to return from a function that wants to return /// the reference. Note that this does not work from initializers. pybind11::object releaseObject() { assert(referrent && object); referrent = nullptr; auto stolen = std::move(object); return stolen; } T *get() { return referrent; } T *operator->() { assert(referrent && object); return referrent; } pybind11::object getObject() { assert(referrent && object); return object; } operator bool() const { return referrent && object; } private: T *referrent; pybind11::object object; }; /// Tracks an entry in the thread context stack. New entries are pushed onto /// here for each with block that activates a new InsertionPoint, Context or /// Location. /// /// Pushing either a Location or InsertionPoint also pushes its associated /// Context. Pushing a Context will not modify the Location or InsertionPoint /// unless if they are from a different context, in which case, they are /// cleared. class PyThreadContextEntry { public: enum class FrameKind { Context, InsertionPoint, Location, }; PyThreadContextEntry(FrameKind frameKind, pybind11::object context, pybind11::object insertionPoint, pybind11::object location) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), location(std::move(location)), frameKind(frameKind) {} /// Gets the top of stack context and return nullptr if not defined. static PyMlirContext *getDefaultContext(); /// Gets the top of stack insertion point and return nullptr if not defined. static PyInsertionPoint *getDefaultInsertionPoint(); /// Gets the top of stack location and returns nullptr if not defined. static PyLocation *getDefaultLocation(); PyMlirContext *getContext(); PyInsertionPoint *getInsertionPoint(); PyLocation *getLocation(); FrameKind getFrameKind() { return frameKind; } /// Stack management. static PyThreadContextEntry *getTopOfStack(); static pybind11::object pushContext(PyMlirContext &context); static void popContext(PyMlirContext &context); static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); static void popInsertionPoint(PyInsertionPoint &insertionPoint); static pybind11::object pushLocation(PyLocation &location); static void popLocation(PyLocation &location); /// Gets the thread local stack. static std::vector &getStack(); private: static void push(FrameKind frameKind, pybind11::object context, pybind11::object insertionPoint, pybind11::object location); /// An object reference to the PyContext. pybind11::object context; /// An object reference to the current insertion point. pybind11::object insertionPoint; /// An object reference to the current location. pybind11::object location; // The kind of push that was performed. FrameKind frameKind; }; /// Wrapper around MlirContext. using PyMlirContextRef = PyObjectRef; class PyMlirContext { public: PyMlirContext() = delete; PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; /// For the case of a python __init__ (py::init) method, pybind11 is quite /// strict about needing to return a pointer that is not yet associated to /// an py::object. Since the forContext() method acts like a pool, possibly /// returning a recycled context, it does not satisfy this need. The usual /// way in python to accomplish such a thing is to override __new__, but /// that is also not supported by pybind11. Instead, we use this entry /// point which always constructs a fresh context (which cannot alias an /// existing one because it is fresh). static PyMlirContext *createNewContextForInit(); /// Returns a context reference for the singleton PyMlirContext wrapper for /// the given context. static PyMlirContextRef forContext(MlirContext context); ~PyMlirContext(); /// Accesses the underlying MlirContext. MlirContext get() { return context; } /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. PyMlirContextRef getRef() { return PyMlirContextRef(this, pybind11::cast(this)); } /// Gets a capsule wrapping the void* within the MlirContext. pybind11::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. static pybind11::object createFromCapsule(pybind11::object capsule); /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); /// Gets the count of live operations associated with this context. /// Used for testing. size_t getLiveOperationCount(); /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); /// Enter and exit the context manager. pybind11::object contextEnter(); void contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb); private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an // extension mechanism on the MlirContext for stashing user pointers. // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. using LiveContextMap = llvm::DenseMap; static LiveContextMap &getLiveContexts(); // Interns all live modules associated with this context. Modules tracked // in this map are valid. When a module is invalidated, it is removed // from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveModuleMap = llvm::DenseMap>; LiveModuleMap liveModules; // Interns all live operations associated with this context. Operations // tracked in this map are valid. When an operation is invalidated, it is // removed from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveOperationMap = llvm::DenseMap>; LiveOperationMap liveOperations; MlirContext context; friend class PyModule; friend class PyOperation; }; /// Used in function arguments when None should resolve to the current context /// manager set instance. class DefaultingPyMlirContext : public Defaulting { public: using Defaulting::Defaulting; static constexpr const char kTypeDescription[] = "[ThreadContextAware] mlir.ir.Context"; static PyMlirContext &resolve(); }; /// Base class for all objects that directly or indirectly depend on an /// MlirContext. The lifetime of the context will extend at least to the /// lifetime of these instances. /// Immutable objects that depend on a context extend this directly. class BaseContextObject { public: BaseContextObject(PyMlirContextRef ref) : contextRef(std::move(ref)) { assert(this->contextRef && "context object constructed with null context ref"); } /// Accesses the context reference. PyMlirContextRef &getContext() { return contextRef; } private: PyMlirContextRef contextRef; }; /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in /// order to differentiate it from the `Dialect` base class which is extended by /// plugins which extend dialect functionality through extension python code. /// This should be seen as the "low-level" object and `Dialect` as the /// high-level, user facing object. class PyDialectDescriptor : public BaseContextObject { public: PyDialectDescriptor(PyMlirContextRef contextRef, MlirDialect dialect) : BaseContextObject(std::move(contextRef)), dialect(dialect) {} MlirDialect get() { return dialect; } private: MlirDialect dialect; }; /// User-level object for accessing dialects with dotted syntax such as: /// ctx.dialect.std class PyDialects : public BaseContextObject { public: PyDialects(PyMlirContextRef contextRef) : BaseContextObject(std::move(contextRef)) {} MlirDialect getDialectForKey(const std::string &key, bool attrError); }; /// User-level dialect object. For dialects that have a registered extension, /// this will be the base class of the extension dialect type. For un-extended, /// objects of this type will be returned directly. class PyDialect { public: PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} pybind11::object getDescriptor() { return descriptor; } private: pybind11::object descriptor; }; /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: PyLocation(PyMlirContextRef contextRef, MlirLocation loc) : BaseContextObject(std::move(contextRef)), loc(loc) {} operator MlirLocation() const { return loc; } MlirLocation get() const { return loc; } /// Enter and exit the context manager. pybind11::object contextEnter(); void contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb); /// Gets a capsule wrapping the void* within the MlirContext. pybind11::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. static PyLocation createFromCapsule(pybind11::object capsule); private: MlirLocation loc; }; /// Used in function arguments when None should resolve to the current context /// manager set instance. class DefaultingPyLocation : public Defaulting { public: using Defaulting::Defaulting; static constexpr const char kTypeDescription[] = "[ThreadContextAware] mlir.ir.Location"; static PyLocation &resolve(); operator MlirLocation() const { return *get(); } }; /// Wrapper around MlirModule. /// This is the top-level, user-owned object that contains regions/ops/blocks. class PyModule; using PyModuleRef = PyObjectRef; class PyModule : public BaseContextObject { public: /// Returns a PyModule reference for the given MlirModule. This may return /// a pre-existing or new object. static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; ~PyModule(); /// Gets the backing MlirModule. MlirModule get() { return module; } /// Gets a strong reference to this module. PyModuleRef getRef() { return PyModuleRef(this, pybind11::reinterpret_borrow(handle)); } /// Gets a capsule wrapping the void* within the MlirModule. /// Note that the module does not (yet) provide a corresponding factory for /// constructing from a capsule as that would require uniquing PyModule /// instances, which is not currently done. pybind11::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. /// Note that PyModule instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirModule /// is taken by calling this function. static pybind11::object createFromCapsule(pybind11::object capsule); private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; pybind11::handle handle; }; /// Base class for PyOperation and PyOpView which exposes the primary, user /// visible methods for manipulating it. class PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. void print(pybind11::object fileObject, bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); pybind11::object getAsm(bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; /// Wrapper around PyOperation. /// Operations exist in either an attached (dependent) or detached (top-level) /// state. In the detached state (as on creation), an operation is owned by /// the creator and its lifetime extends either until its reference count /// drops to zero or it is attached to a parent, at which point its lifetime /// is bounded by its top-level parent reference. class PyOperation; using PyOperationRef = PyObjectRef; class PyOperation : public PyOperationBase, public BaseContextObject { public: ~PyOperation(); PyOperation &getOperation() override { return *this; } /// Returns a PyOperation for the given MlirOperation, optionally associating /// it with a parentKeepAlive. static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); /// Gets the backing operation. operator MlirOperation() const { return get(); } MlirOperation get() const { checkValid(); return operation; } PyOperationRef getRef() { return PyOperationRef( this, pybind11::reinterpret_borrow(handle)); } bool isAttached() { return attached; } void setAttached() { assert(!attached && "operation already attached"); attached = true; } void checkValid() const; /// Gets the owning block or raises an exception if the operation has no /// owning block. PyBlock getBlock(); /// Gets the parent operation or raises an exception if the operation has /// no parent. PyOperationRef getParentOperation(); /// Creates an operation. See corresponding python docstring. static pybind11::object create(std::string name, llvm::Optional> operands, llvm::Optional> results, llvm::Optional attributes, llvm::Optional> successors, int regions, DefaultingPyLocation location, pybind11::object ip); /// Creates an OpView suitable for this operation. pybind11::object createOpView(); private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive); MlirOperation operation; pybind11::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or // Module. // TODO: As implemented, this facility is only sufficient for modeling the // trivial module parent back-reference. Generalize this to also account for // transitions from detached to attached and address TODOs in the // ir_operation.py regarding testing corresponding lifetime guarantees. pybind11::object parentKeepAlive; bool attached = true; bool valid = true; }; /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for /// providing more instance-specific accessors and serve as the base class for /// custom ODS-style operation classes. Since this class is subclass on the /// python side, it must present an __init__ method that operates in pure /// python types. class PyOpView : public PyOperationBase { public: PyOpView(pybind11::object operationObject); PyOperation &getOperation() override { return operation; } static pybind11::object createRawSubclass(pybind11::object userClass); pybind11::object getOperationObject() { return operationObject; } private: PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. }; /// Wrapper around an MlirRegion. /// Regions are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached regions. class PyRegion { public: PyRegion(PyOperationRef parentOperation, MlirRegion region) : parentOperation(std::move(parentOperation)), region(region) { assert(!mlirRegionIsNull(region) && "python region cannot be null"); } MlirRegion get() { return region; } PyOperationRef &getParentOperation() { return parentOperation; } void checkValid() { return parentOperation->checkValid(); } private: PyOperationRef parentOperation; MlirRegion region; }; /// Wrapper around an MlirBlock. /// Blocks are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached blocks. class PyBlock { public: PyBlock(PyOperationRef parentOperation, MlirBlock block) : parentOperation(std::move(parentOperation)), block(block) { assert(!mlirBlockIsNull(block) && "python block cannot be null"); } MlirBlock get() { return block; } PyOperationRef &getParentOperation() { return parentOperation; } void checkValid() { return parentOperation->checkValid(); } private: PyOperationRef parentOperation; MlirBlock block; }; /// An insertion point maintains a pointer to a Block and a reference operation. /// Calls to insert() will insert a new operation before the /// reference operation. If the reference operation is null, then appends to /// the end of the block. class PyInsertionPoint { public: /// Creates an insertion point positioned after the last operation in the /// block, but still inside the block. PyInsertionPoint(PyBlock &block); /// Creates an insertion point positioned before a reference operation. PyInsertionPoint(PyOperationBase &beforeOperationBase); /// Shortcut to create an insertion point at the beginning of the block. static PyInsertionPoint atBlockBegin(PyBlock &block); /// Shortcut to create an insertion point before the block terminator. static PyInsertionPoint atBlockTerminator(PyBlock &block); /// Inserts an operation. void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. pybind11::object contextEnter(); void contextExit(pybind11::object excType, pybind11::object excVal, pybind11::object excTb); PyBlock &getBlock() { return block; } private: // Trampoline constructor that avoids null initializing members while // looking up parents. PyInsertionPoint(PyBlock block, llvm::Optional refOperation) : refOperation(std::move(refOperation)), block(std::move(block)) {} llvm::Optional refOperation; PyBlock block; }; /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. class PyAttribute : public BaseContextObject { public: PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) : BaseContextObject(std::move(contextRef)), attr(attr) {} bool operator==(const PyAttribute &other); operator MlirAttribute() const { return attr; } MlirAttribute get() const { return attr; } /// Gets a capsule wrapping the void* within the MlirContext. pybind11::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. static PyAttribute createFromCapsule(pybind11::object capsule); private: MlirAttribute attr; }; /// Represents a Python MlirNamedAttr, carrying an optional owned name. /// TODO: Refactor this and the C-API to be based on an Identifier owned /// by the context so as to avoid ownership issues here. class PyNamedAttribute { public: /// Constructs a PyNamedAttr that retains an owned name. This should be /// used in any code that originates an MlirNamedAttribute from a python /// string. /// The lifetime of the PyNamedAttr must extend to the lifetime of the /// passed attribute. PyNamedAttribute(MlirAttribute attr, std::string ownedName); MlirNamedAttribute namedAttr; private: // Since the MlirNamedAttr contains an internal pointer to the actual // memory of the owned string, it must be heap allocated to remain valid. // Otherwise, strings that fit within the small object optimization threshold // will have their memory address change as the containing object is moved, // resulting in an invalid aliased pointer. std::unique_ptr ownedName; }; /// Wrapper around the generic MlirType. /// The lifetime of a type is bound by the PyContext that created it. class PyType : public BaseContextObject { public: PyType(PyMlirContextRef contextRef, MlirType type) : BaseContextObject(std::move(contextRef)), type(type) {} bool operator==(const PyType &other); operator MlirType() const { return type; } MlirType get() const { return type; } /// Gets a capsule wrapping the void* within the MlirContext. pybind11::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. static PyType createFromCapsule(pybind11::object capsule); private: MlirType type; }; /// Wrapper around the generic MlirValue. /// Values are managed completely by the operation that resulted in their /// definition. For op result value, this is the operation that defines the /// value. For block argument values, this is the operation that contains the /// block to which the value is an argument (blocks cannot be detached in Python /// bindings so such operation always exists). class PyValue { public: PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(parentOperation), value(value) {} MlirValue get() { return value; } PyOperationRef &getParentOperation() { return parentOperation; } void checkValid() { return parentOperation->checkValid(); } private: PyOperationRef parentOperation; MlirValue value; }; void populateIRSubmodule(pybind11::module &m); } // namespace python } // namespace mlir namespace pybind11 { namespace detail { template <> struct type_caster : MlirDefaultingCaster {}; template <> struct type_caster : MlirDefaultingCaster {}; } // namespace detail } // namespace pybind11 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H