1 //===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 10 #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 11 12 #include "mlir-c/Support.h" 13 #include "llvm/ADT/Optional.h" 14 #include "llvm/ADT/Twine.h" 15 16 #include <pybind11/pybind11.h> 17 #include <pybind11/stl.h> 18 19 namespace mlir { 20 namespace python { 21 22 // Sets a python error, ready to be thrown to return control back to the 23 // python runtime. 24 // Correct usage: 25 // throw SetPyError(PyExc_ValueError, "Foobar'd"); 26 pybind11::error_already_set SetPyError(PyObject *excClass, 27 const llvm::Twine &message); 28 29 /// CRTP template for special wrapper types that are allowed to be passed in as 30 /// 'None' function arguments and can be resolved by some global mechanic if 31 /// so. Such types will raise an error if this global resolution fails, and 32 /// it is actually illegal for them to ever be unresolved. From a user 33 /// perspective, they behave like a smart ptr to the underlying type (i.e. 34 /// 'get' method and operator-> overloaded). 35 /// 36 /// Derived types must provide a method, which is called when an environmental 37 /// resolution is required. It must raise an exception if resolution fails: 38 /// static ReferrentTy &resolve() 39 /// 40 /// They must also provide a parameter description that will be used in 41 /// error messages about mismatched types: 42 /// static constexpr const char kTypeDescription[] = "<Description>"; 43 44 template <typename DerivedTy, typename T> 45 class Defaulting { 46 public: 47 using ReferrentTy = T; 48 /// Type casters require the type to be default constructible, but using 49 /// such an instance is illegal. 50 Defaulting() = default; Defaulting(ReferrentTy & referrent)51 Defaulting(ReferrentTy &referrent) : referrent(&referrent) {} 52 get()53 ReferrentTy *get() const { return referrent; } 54 ReferrentTy *operator->() { return referrent; } 55 56 private: 57 ReferrentTy *referrent = nullptr; 58 }; 59 60 } // namespace python 61 } // namespace mlir 62 63 namespace pybind11 { 64 namespace detail { 65 66 template <typename DefaultingTy> 67 struct MlirDefaultingCaster { 68 PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); 69 loadMlirDefaultingCaster70 bool load(pybind11::handle src, bool) { 71 if (src.is_none()) { 72 // Note that we do want an exception to propagate from here as it will be 73 // the most informative. 74 value = DefaultingTy{DefaultingTy::resolve()}; 75 return true; 76 } 77 78 // Unlike many casters that chain, these casters are expected to always 79 // succeed, so instead of doing an isinstance check followed by a cast, 80 // just cast in one step and handle the exception. Returning false (vs 81 // letting the exception propagate) causes higher level signature parsing 82 // code to produce nice error messages (other than "Cannot cast..."). 83 try { 84 value = DefaultingTy{ 85 pybind11::cast<typename DefaultingTy::ReferrentTy &>(src)}; 86 return true; 87 } catch (std::exception &) { 88 return false; 89 } 90 } 91 castMlirDefaultingCaster92 static handle cast(DefaultingTy src, return_value_policy policy, 93 handle parent) { 94 return pybind11::cast(src, policy); 95 } 96 }; 97 98 template <typename T> 99 struct type_caster<llvm::Optional<T>> : optional_caster<llvm::Optional<T>> {}; 100 } // namespace detail 101 } // namespace pybind11 102 103 //------------------------------------------------------------------------------ 104 // Conversion utilities. 105 //------------------------------------------------------------------------------ 106 107 namespace mlir { 108 109 /// Accumulates into a python string from a method that accepts an 110 /// MlirStringCallback. 111 struct PyPrintAccumulator { 112 pybind11::list parts; 113 114 void *getUserData() { return this; } 115 116 MlirStringCallback getCallback() { 117 return [](MlirStringRef part, void *userData) { 118 PyPrintAccumulator *printAccum = 119 static_cast<PyPrintAccumulator *>(userData); 120 pybind11::str pyPart(part.data, 121 part.length); // Decodes as UTF-8 by default. 122 printAccum->parts.append(std::move(pyPart)); 123 }; 124 } 125 126 pybind11::str join() { 127 pybind11::str delim("", 0); 128 return delim.attr("join")(parts); 129 } 130 }; 131 132 /// Accumulates int a python file-like object, either writing text (default) 133 /// or binary. 134 class PyFileAccumulator { 135 public: 136 PyFileAccumulator(pybind11::object fileObject, bool binary) 137 : pyWriteFunction(fileObject.attr("write")), binary(binary) {} 138 139 void *getUserData() { return this; } 140 141 MlirStringCallback getCallback() { 142 return [](MlirStringRef part, void *userData) { 143 pybind11::gil_scoped_acquire(); 144 PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData); 145 if (accum->binary) { 146 // Note: Still has to copy and not avoidable with this API. 147 pybind11::bytes pyBytes(part.data, part.length); 148 accum->pyWriteFunction(pyBytes); 149 } else { 150 pybind11::str pyStr(part.data, 151 part.length); // Decodes as UTF-8 by default. 152 accum->pyWriteFunction(pyStr); 153 } 154 }; 155 } 156 157 private: 158 pybind11::object pyWriteFunction; 159 bool binary; 160 }; 161 162 /// Accumulates into a python string from a method that is expected to make 163 /// one (no more, no less) call to the callback (asserts internally on 164 /// violation). 165 struct PySinglePartStringAccumulator { 166 void *getUserData() { return this; } 167 168 MlirStringCallback getCallback() { 169 return [](MlirStringRef part, void *userData) { 170 PySinglePartStringAccumulator *accum = 171 static_cast<PySinglePartStringAccumulator *>(userData); 172 assert(!accum->invoked && 173 "PySinglePartStringAccumulator called back multiple times"); 174 accum->invoked = true; 175 accum->value = pybind11::str(part.data, part.length); 176 }; 177 } 178 179 pybind11::str takeValue() { 180 assert(invoked && "PySinglePartStringAccumulator not called back"); 181 return std::move(value); 182 } 183 184 private: 185 pybind11::str value; 186 bool invoked = false; 187 }; 188 189 /// A CRTP base class for pseudo-containers willing to support Python-type 190 /// slicing access on top of indexed access. Calling ::bind on this class 191 /// will define `__len__` as well as `__getitem__` with integer and slice 192 /// arguments. 193 /// 194 /// This is intended for pseudo-containers that can refer to arbitrary slices of 195 /// underlying storage indexed by a single integer. Indexing those with an 196 /// integer produces an instance of ElementTy. Indexing those with a slice 197 /// produces a new instance of Derived, which can be sliced further. 198 /// 199 /// A derived class must provide the following: 200 /// - a `static const char *pyClassName ` field containing the name of the 201 /// Python class to bind; 202 /// - an instance method `intptr_t getNumElements()` that returns the number 203 /// of elements in the backing container (NOT that of the slice); 204 /// - an instance method `ElementTy getElement(intptr_t)` that returns a 205 /// single element at the given index. 206 /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that 207 /// constructs a new instance of the derived pseudo-container with the 208 /// given slice parameters (to be forwarded to the Sliceable constructor). 209 /// 210 /// A derived class may additionally define: 211 /// - a `static void bindDerived(ClassTy &)` method to bind additional methods 212 /// the python class. 213 template <typename Derived, typename ElementTy> 214 class Sliceable { 215 protected: 216 using ClassTy = pybind11::class_<Derived>; 217 218 public: 219 explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) 220 : startIndex(startIndex), length(length), step(step) { 221 assert(length >= 0 && "expected non-negative slice length"); 222 } 223 224 /// Returns the length of the slice. 225 intptr_t dunderLen() const { return length; } 226 227 /// Returns the element at the given slice index. Supports negative indices 228 /// by taking elements in inverse order. Throws if the index is out of bounds. 229 ElementTy dunderGetItem(intptr_t index) { 230 // Negative indices mean we count from the end. 231 if (index < 0) 232 index = length + index; 233 if (index < 0 || index >= length) { 234 throw python::SetPyError(PyExc_IndexError, 235 "attempt to access out of bounds"); 236 } 237 238 // Compute the linear index given the current slice properties. 239 int linearIndex = index * step + startIndex; 240 assert(linearIndex >= 0 && 241 linearIndex < static_cast<Derived *>(this)->getNumElements() && 242 "linear index out of bounds, the slice is ill-formed"); 243 return static_cast<Derived *>(this)->getElement(linearIndex); 244 } 245 246 /// Returns a new instance of the pseudo-container restricted to the given 247 /// slice. 248 Derived dunderGetItemSlice(pybind11::slice slice) { 249 ssize_t start, stop, extraStep, sliceLength; 250 if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) { 251 throw python::SetPyError(PyExc_IndexError, 252 "attempt to access out of bounds"); 253 } 254 return static_cast<Derived *>(this)->slice(startIndex + start * step, 255 sliceLength, step * extraStep); 256 } 257 258 /// Binds the indexing and length methods in the Python class. 259 static void bind(pybind11::module &m) { 260 auto clazz = pybind11::class_<Derived>(m, Derived::pyClassName) 261 .def("__len__", &Sliceable::dunderLen) 262 .def("__getitem__", &Sliceable::dunderGetItem) 263 .def("__getitem__", &Sliceable::dunderGetItemSlice); 264 Derived::bindDerived(clazz); 265 } 266 267 /// Hook for derived classes willing to bind more methods. 268 static void bindDerived(ClassTy &) {} 269 270 private: 271 intptr_t startIndex; 272 intptr_t length; 273 intptr_t step; 274 }; 275 276 } // namespace mlir 277 278 #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H 279