• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2     pybind11/std_bind.h: Binding generators for STL data types
3 
4     Copyright (c) 2016 Sergey Lyskov and Wenzel Jakob
5 
6     All rights reserved. Use of this source code is governed by a
7     BSD-style license that can be found in the LICENSE file.
8 */
9 
10 #pragma once
11 
12 #include "detail/common.h"
13 #include "operators.h"
14 
15 #include <algorithm>
16 #include <sstream>
17 
18 PYBIND11_NAMESPACE_BEGIN(PYBIND11_NAMESPACE)
19 PYBIND11_NAMESPACE_BEGIN(detail)
20 
21 /* SFINAE helper class used by 'is_comparable */
22 template <typename T>  struct container_traits {
23     template <typename T2> static std::true_type test_comparable(decltype(std::declval<const T2 &>() == std::declval<const T2 &>())*);
24     template <typename T2> static std::false_type test_comparable(...);
25     template <typename T2> static std::true_type test_value(typename T2::value_type *);
26     template <typename T2> static std::false_type test_value(...);
27     template <typename T2> static std::true_type test_pair(typename T2::first_type *, typename T2::second_type *);
28     template <typename T2> static std::false_type test_pair(...);
29 
30     static constexpr const bool is_comparable = std::is_same<std::true_type, decltype(test_comparable<T>(nullptr))>::value;
31     static constexpr const bool is_pair = std::is_same<std::true_type, decltype(test_pair<T>(nullptr, nullptr))>::value;
32     static constexpr const bool is_vector = std::is_same<std::true_type, decltype(test_value<T>(nullptr))>::value;
33     static constexpr const bool is_element = !is_pair && !is_vector;
34 };
35 
36 /* Default: is_comparable -> std::false_type */
37 template <typename T, typename SFINAE = void>
38 struct is_comparable : std::false_type { };
39 
40 /* For non-map data structures, check whether operator== can be instantiated */
41 template <typename T>
42 struct is_comparable<
43     T, enable_if_t<container_traits<T>::is_element &&
44                    container_traits<T>::is_comparable>>
45     : std::true_type { };
46 
47 /* For a vector/map data structure, recursively check the value type (which is std::pair for maps) */
48 template <typename T>
49 struct is_comparable<T, enable_if_t<container_traits<T>::is_vector>> {
50     static constexpr const bool value =
51         is_comparable<typename T::value_type>::value;
52 };
53 
54 /* For pairs, recursively check the two data types */
55 template <typename T>
56 struct is_comparable<T, enable_if_t<container_traits<T>::is_pair>> {
57     static constexpr const bool value =
58         is_comparable<typename T::first_type>::value &&
59         is_comparable<typename T::second_type>::value;
60 };
61 
/* Fallback functions: do-nothing overloads that are picked whenever the SFINAE-constrained
   overloads below are not viable for the given <Vector, Class_> pair. */
template <typename, typename, typename... Ignored>
void vector_if_copy_constructible(const Ignored &...) {}

template <typename, typename, typename... Ignored>
void vector_if_equal_operator(const Ignored &...) {}

template <typename, typename, typename... Ignored>
void vector_if_insertion_operator(const Ignored &...) {}

template <typename, typename, typename... Ignored>
void vector_modifiers(const Ignored &...) {}
67 
68 template<typename Vector, typename Class_>
69 void vector_if_copy_constructible(enable_if_t<is_copy_constructible<Vector>::value, Class_> &cl) {
70     cl.def(init<const Vector &>(), "Copy constructor");
71 }
72 
73 template<typename Vector, typename Class_>
74 void vector_if_equal_operator(enable_if_t<is_comparable<Vector>::value, Class_> &cl) {
75     using T = typename Vector::value_type;
76 
77     cl.def(self == self);
78     cl.def(self != self);
79 
80     cl.def("count",
81         [](const Vector &v, const T &x) {
82             return std::count(v.begin(), v.end(), x);
83         },
84         arg("x"),
85         "Return the number of times ``x`` appears in the list"
86     );
87 
88     cl.def("remove", [](Vector &v, const T &x) {
89             auto p = std::find(v.begin(), v.end(), x);
90             if (p != v.end())
91                 v.erase(p);
92             else
93                 throw value_error();
94         },
95         arg("x"),
96         "Remove the first item from the list whose value is x. "
97         "It is an error if there is no such item."
98     );
99 
100     cl.def("__contains__",
101         [](const Vector &v, const T &x) {
102             return std::find(v.begin(), v.end(), x) != v.end();
103         },
104         arg("x"),
105         "Return true the container contains ``x``"
106     );
107 }
108 
109 // Vector modifiers -- requires a copyable vector_type:
110 // (Technically, some of these (pop and __delitem__) don't actually require copyability, but it seems
111 // silly to allow deletion but not insertion, so include them here too.)
112 template <typename Vector, typename Class_>
113 void vector_modifiers(enable_if_t<is_copy_constructible<typename Vector::value_type>::value, Class_> &cl) {
114     using T = typename Vector::value_type;
115     using SizeType = typename Vector::size_type;
116     using DiffType = typename Vector::difference_type;
117 
118     auto wrap_i = [](DiffType i, SizeType n) {
119         if (i < 0)
120             i += n;
121         if (i < 0 || (SizeType)i >= n)
122             throw index_error();
123         return i;
124     };
125 
126     cl.def("append",
127            [](Vector &v, const T &value) { v.push_back(value); },
128            arg("x"),
129            "Add an item to the end of the list");
130 
131     cl.def(init([](iterable it) {
132         auto v = std::unique_ptr<Vector>(new Vector());
133         v->reserve(len_hint(it));
134         for (handle h : it)
135            v->push_back(h.cast<T>());
136         return v.release();
137     }));
138 
139     cl.def("clear",
140         [](Vector &v) {
141             v.clear();
142         },
143         "Clear the contents"
144     );
145 
146     cl.def("extend",
147        [](Vector &v, const Vector &src) {
148            v.insert(v.end(), src.begin(), src.end());
149        },
150        arg("L"),
151        "Extend the list by appending all the items in the given list"
152     );
153 
154     cl.def("extend",
155        [](Vector &v, iterable it) {
156            const size_t old_size = v.size();
157            v.reserve(old_size + len_hint(it));
158            try {
159                for (handle h : it) {
160                    v.push_back(h.cast<T>());
161                }
162            } catch (const cast_error &) {
163                v.erase(v.begin() + static_cast<typename Vector::difference_type>(old_size), v.end());
164                try {
165                    v.shrink_to_fit();
166                } catch (const std::exception &) {
167                    // Do nothing
168                }
169                throw;
170            }
171        },
172        arg("L"),
173        "Extend the list by appending all the items in the given list"
174     );
175 
176     cl.def("insert",
177         [](Vector &v, DiffType i, const T &x) {
178             // Can't use wrap_i; i == v.size() is OK
179             if (i < 0)
180                 i += v.size();
181             if (i < 0 || (SizeType)i > v.size())
182                 throw index_error();
183             v.insert(v.begin() + i, x);
184         },
185         arg("i") , arg("x"),
186         "Insert an item at a given position."
187     );
188 
189     cl.def("pop",
190         [](Vector &v) {
191             if (v.empty())
192                 throw index_error();
193             T t = v.back();
194             v.pop_back();
195             return t;
196         },
197         "Remove and return the last item"
198     );
199 
200     cl.def("pop",
201         [wrap_i](Vector &v, DiffType i) {
202             i = wrap_i(i, v.size());
203             T t = v[(SizeType) i];
204             v.erase(v.begin() + i);
205             return t;
206         },
207         arg("i"),
208         "Remove and return the item at index ``i``"
209     );
210 
211     cl.def("__setitem__",
212         [wrap_i](Vector &v, DiffType i, const T &t) {
213             i = wrap_i(i, v.size());
214             v[(SizeType)i] = t;
215         }
216     );
217 
218     /// Slicing protocol
219     cl.def("__getitem__",
220         [](const Vector &v, slice slice) -> Vector * {
221             size_t start, stop, step, slicelength;
222 
223             if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
224                 throw error_already_set();
225 
226             auto *seq = new Vector();
227             seq->reserve((size_t) slicelength);
228 
229             for (size_t i=0; i<slicelength; ++i) {
230                 seq->push_back(v[start]);
231                 start += step;
232             }
233             return seq;
234         },
235         arg("s"),
236         "Retrieve list elements using a slice object"
237     );
238 
239     cl.def("__setitem__",
240         [](Vector &v, slice slice,  const Vector &value) {
241             size_t start, stop, step, slicelength;
242             if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
243                 throw error_already_set();
244 
245             if (slicelength != value.size())
246                 throw std::runtime_error("Left and right hand size of slice assignment have different sizes!");
247 
248             for (size_t i=0; i<slicelength; ++i) {
249                 v[start] = value[i];
250                 start += step;
251             }
252         },
253         "Assign list elements using a slice object"
254     );
255 
256     cl.def("__delitem__",
257         [wrap_i](Vector &v, DiffType i) {
258             i = wrap_i(i, v.size());
259             v.erase(v.begin() + i);
260         },
261         "Delete the list elements at index ``i``"
262     );
263 
264     cl.def("__delitem__",
265         [](Vector &v, slice slice) {
266             size_t start, stop, step, slicelength;
267 
268             if (!slice.compute(v.size(), &start, &stop, &step, &slicelength))
269                 throw error_already_set();
270 
271             if (step == 1 && false) {
272                 v.erase(v.begin() + (DiffType) start, v.begin() + DiffType(start + slicelength));
273             } else {
274                 for (size_t i = 0; i < slicelength; ++i) {
275                     v.erase(v.begin() + DiffType(start));
276                     start += step - 1;
277                 }
278             }
279         },
280         "Delete list elements using a slice object"
281     );
282 
283 }
284 
285 // If the type has an operator[] that doesn't return a reference (most notably std::vector<bool>),
286 // we have to access by copying; otherwise we return by reference.
287 template <typename Vector> using vector_needs_copy = negation<
288     std::is_same<decltype(std::declval<Vector>()[typename Vector::size_type()]), typename Vector::value_type &>>;
289 
290 // The usual case: access and iterate by reference
291 template <typename Vector, typename Class_>
292 void vector_accessor(enable_if_t<!vector_needs_copy<Vector>::value, Class_> &cl) {
293     using T = typename Vector::value_type;
294     using SizeType = typename Vector::size_type;
295     using DiffType = typename Vector::difference_type;
296     using ItType   = typename Vector::iterator;
297 
298     auto wrap_i = [](DiffType i, SizeType n) {
299         if (i < 0)
300             i += n;
301         if (i < 0 || (SizeType)i >= n)
302             throw index_error();
303         return i;
304     };
305 
306     cl.def("__getitem__",
307         [wrap_i](Vector &v, DiffType i) -> T & {
308             i = wrap_i(i, v.size());
309             return v[(SizeType)i];
310         },
311         return_value_policy::reference_internal // ref + keepalive
312     );
313 
314     cl.def("__iter__",
315            [](Vector &v) {
316                return make_iterator<
317                    return_value_policy::reference_internal, ItType, ItType, T&>(
318                    v.begin(), v.end());
319            },
320            keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
321     );
322 }
323 
324 // The case for special objects, like std::vector<bool>, that have to be returned-by-copy:
325 template <typename Vector, typename Class_>
326 void vector_accessor(enable_if_t<vector_needs_copy<Vector>::value, Class_> &cl) {
327     using T = typename Vector::value_type;
328     using SizeType = typename Vector::size_type;
329     using DiffType = typename Vector::difference_type;
330     using ItType   = typename Vector::iterator;
331     cl.def("__getitem__",
332         [](const Vector &v, DiffType i) -> T {
333             if (i < 0 && (i += v.size()) < 0)
334                 throw index_error();
335             if ((SizeType)i >= v.size())
336                 throw index_error();
337             return v[(SizeType)i];
338         }
339     );
340 
341     cl.def("__iter__",
342            [](Vector &v) {
343                return make_iterator<
344                    return_value_policy::copy, ItType, ItType, T>(
345                    v.begin(), v.end());
346            },
347            keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
348     );
349 }
350 
// Bind __repr__ as "<name>[a, b, c]" when value_type can be streamed into a std::ostream; the
// trailing decltype drops this overload otherwise, leaving the no-op fallback.
template <typename Vector, typename Class_>
auto vector_if_insertion_operator(Class_ &cl, std::string const &name)
    -> decltype(std::declval<std::ostream &>() << std::declval<typename Vector::value_type>(), void()) {
    using size_type = typename Vector::size_type;

    cl.def("__repr__",
        [name](Vector &v) {
            std::ostringstream out;
            out << name << '[';
            const char *separator = "";
            for (size_type idx = 0; idx < v.size(); ++idx) {
                out << separator << v[idx];
                separator = ", ";
            }
            out << ']';
            return out.str();
        },
        "Return the canonical string representation of this list."
    );
}
370 
371 // Provide the buffer interface for vectors if we have data() and we have a format for it
372 // GCC seems to have "void std::vector<bool>::data()" - doing SFINAE on the existence of data() is insufficient, we need to check it returns an appropriate pointer
373 template <typename Vector, typename = void>
374 struct vector_has_data_and_format : std::false_type {};
375 template <typename Vector>
376 struct vector_has_data_and_format<Vector, enable_if_t<std::is_same<decltype(format_descriptor<typename Vector::value_type>::format(), std::declval<Vector>().data()), typename Vector::value_type*>::value>> : std::true_type {};
377 
378 // [workaround(intel)] Separate function required here
379 // Workaround as the Intel compiler does not compile the enable_if_t part below
380 // (tested with icc (ICC) 2021.1 Beta 20200827)
381 template <typename... Args>
382 constexpr bool args_any_are_buffer() {
383     return detail::any_of<std::is_same<Args, buffer_protocol>...>::value;
384 }
385 
386 // [workaround(intel)] Separate function required here
387 // [workaround(msvc)] Can't use constexpr bool in return type
388 
389 // Add the buffer interface to a vector
390 template <typename Vector, typename Class_, typename... Args>
391 void vector_buffer_impl(Class_& cl, std::true_type) {
392     using T = typename Vector::value_type;
393 
394     static_assert(vector_has_data_and_format<Vector>::value, "There is not an appropriate format descriptor for this vector");
395 
396     // numpy.h declares this for arbitrary types, but it may raise an exception and crash hard at runtime if PYBIND11_NUMPY_DTYPE hasn't been called, so check here
397     format_descriptor<T>::format();
398 
399     cl.def_buffer([](Vector& v) -> buffer_info {
400         return buffer_info(v.data(), static_cast<ssize_t>(sizeof(T)), format_descriptor<T>::format(), 1, {v.size()}, {sizeof(T)});
401     });
402 
403     cl.def(init([](buffer buf) {
404         auto info = buf.request();
405         if (info.ndim != 1 || info.strides[0] % static_cast<ssize_t>(sizeof(T)))
406             throw type_error("Only valid 1D buffers can be copied to a vector");
407         if (!detail::compare_buffer_info<T>::compare(info) || (ssize_t) sizeof(T) != info.itemsize)
408             throw type_error("Format mismatch (Python: " + info.format + " C++: " + format_descriptor<T>::format() + ")");
409 
410         T *p = static_cast<T*>(info.ptr);
411         ssize_t step = info.strides[0] / static_cast<ssize_t>(sizeof(T));
412         T *end = p + info.shape[0] * step;
413         if (step == 1) {
414             return Vector(p, end);
415         }
416         else {
417             Vector vec;
418             vec.reserve((size_t) info.shape[0]);
419             for (; p != end; p += step)
420                 vec.push_back(*p);
421             return vec;
422         }
423     }));
424 
425     return;
426 }
427 
// Buffer support was not requested (no buffer_protocol among the extras): nothing to bind.
template <typename Vector, typename Class_, typename... Args>
void vector_buffer_impl(Class_ & /*cl*/, std::false_type /*wants_buffer*/) {}
430 
431 template <typename Vector, typename Class_, typename... Args>
432 void vector_buffer(Class_& cl) {
433     vector_buffer_impl<Vector, Class_, Args...>(cl, detail::any_of<std::is_same<Args, buffer_protocol>...>{});
434 }
435 
436 PYBIND11_NAMESPACE_END(detail)
437 
438 //
439 // std::vector
440 //
441 template <typename Vector, typename holder_type = std::unique_ptr<Vector>, typename... Args>
442 class_<Vector, holder_type> bind_vector(handle scope, std::string const &name, Args&&... args) {
443     using Class_ = class_<Vector, holder_type>;
444 
445     // If the value_type is unregistered (e.g. a converting type) or is itself registered
446     // module-local then make the vector binding module-local as well:
447     using vtype = typename Vector::value_type;
448     auto vtype_info = detail::get_type_info(typeid(vtype));
449     bool local = !vtype_info || vtype_info->module_local;
450 
451     Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
452 
453     // Declare the buffer interface if a buffer_protocol() is passed in
454     detail::vector_buffer<Vector, Class_, Args...>(cl);
455 
456     cl.def(init<>());
457 
458     // Register copy constructor (if possible)
459     detail::vector_if_copy_constructible<Vector, Class_>(cl);
460 
461     // Register comparison-related operators and functions (if possible)
462     detail::vector_if_equal_operator<Vector, Class_>(cl);
463 
464     // Register stream insertion operator (if possible)
465     detail::vector_if_insertion_operator<Vector, Class_>(cl, name);
466 
467     // Modifiers require copyable vector value type
468     detail::vector_modifiers<Vector, Class_>(cl);
469 
470     // Accessor and iterator; return by value if copyable, otherwise we return by ref + keep-alive
471     detail::vector_accessor<Vector, Class_>(cl);
472 
473     cl.def("__bool__",
474         [](const Vector &v) -> bool {
475             return !v.empty();
476         },
477         "Check whether the list is nonempty"
478     );
479 
480     cl.def("__len__", &Vector::size);
481 
482 
483 
484 
485 #if 0
486     // C++ style functions deprecated, leaving it here as an example
487     cl.def(init<size_type>());
488 
489     cl.def("resize",
490          (void (Vector::*) (size_type count)) & Vector::resize,
491          "changes the number of elements stored");
492 
493     cl.def("erase",
494         [](Vector &v, SizeType i) {
495         if (i >= v.size())
496             throw index_error();
497         v.erase(v.begin() + i);
498     }, "erases element at index ``i``");
499 
500     cl.def("empty",         &Vector::empty,         "checks whether the container is empty");
501     cl.def("size",          &Vector::size,          "returns the number of elements");
502     cl.def("push_back", (void (Vector::*)(const T&)) &Vector::push_back, "adds an element to the end");
503     cl.def("pop_back",                               &Vector::pop_back, "removes the last element");
504 
505     cl.def("max_size",      &Vector::max_size,      "returns the maximum possible number of elements");
506     cl.def("reserve",       &Vector::reserve,       "reserves storage");
507     cl.def("capacity",      &Vector::capacity,      "returns the number of elements that can be held in currently allocated storage");
508     cl.def("shrink_to_fit", &Vector::shrink_to_fit, "reduces memory usage by freeing unused memory");
509 
510     cl.def("clear", &Vector::clear, "clears the contents");
511     cl.def("swap",   &Vector::swap, "swaps the contents");
512 
513     cl.def("front", [](Vector &v) {
514         if (v.size()) return v.front();
515         else throw index_error();
516     }, "access the first element");
517 
518     cl.def("back", [](Vector &v) {
519         if (v.size()) return v.back();
520         else throw index_error();
521     }, "access the last element ");
522 
523 #endif
524 
525     return cl;
526 }
527 
528 
529 
530 //
531 // std::map, std::unordered_map
532 //
533 
534 PYBIND11_NAMESPACE_BEGIN(detail)
535 
/* Fallback functions: do-nothing overloads picked when the SFINAE-constrained map overloads
   below are not viable for the given <Map, Class_> pair. */
template <typename, typename, typename... Ignored>
void map_if_insertion_operator(const Ignored &...) {}

template <typename, typename, typename... Ignored>
void map_assignment(const Ignored &...) {}
539 
540 // Map assignment when copy-assignable: just copy the value
541 template <typename Map, typename Class_>
542 void map_assignment(enable_if_t<is_copy_assignable<typename Map::mapped_type>::value, Class_> &cl) {
543     using KeyType = typename Map::key_type;
544     using MappedType = typename Map::mapped_type;
545 
546     cl.def("__setitem__",
547            [](Map &m, const KeyType &k, const MappedType &v) {
548                auto it = m.find(k);
549                if (it != m.end()) it->second = v;
550                else m.emplace(k, v);
551            }
552     );
553 }
554 
555 // Not copy-assignable, but still copy-constructible: we can update the value by erasing and reinserting
556 template<typename Map, typename Class_>
557 void map_assignment(enable_if_t<
558         !is_copy_assignable<typename Map::mapped_type>::value &&
559         is_copy_constructible<typename Map::mapped_type>::value,
560         Class_> &cl) {
561     using KeyType = typename Map::key_type;
562     using MappedType = typename Map::mapped_type;
563 
564     cl.def("__setitem__",
565            [](Map &m, const KeyType &k, const MappedType &v) {
566                // We can't use m[k] = v; because value type might not be default constructable
567                auto r = m.emplace(k, v);
568                if (!r.second) {
569                    // value type is not copy assignable so the only way to insert it is to erase it first...
570                    m.erase(r.first);
571                    m.emplace(k, v);
572                }
573            }
574     );
575 }
576 
577 
// Bind __repr__ as "<name>{k1: v1, k2: v2}" when both key_type and mapped_type can be streamed
// into a std::ostream; the trailing decltype drops this overload otherwise.
template <typename Map, typename Class_>
auto map_if_insertion_operator(Class_ &cl, std::string const &name)
    -> decltype(std::declval<std::ostream &>() << std::declval<typename Map::key_type>()
                                               << std::declval<typename Map::mapped_type>(),
                void()) {
    cl.def("__repr__",
        [name](Map &m) {
            std::ostringstream out;
            out << name << '{';
            const char *separator = "";
            for (auto const &entry : m) {
                out << separator << entry.first << ": " << entry.second;
                separator = ", ";
            }
            out << '}';
            return out.str();
        },
        "Return the canonical string representation of this map."
    );
}
598 
599 
600 PYBIND11_NAMESPACE_END(detail)
601 
602 template <typename Map, typename holder_type = std::unique_ptr<Map>, typename... Args>
603 class_<Map, holder_type> bind_map(handle scope, const std::string &name, Args&&... args) {
604     using KeyType = typename Map::key_type;
605     using MappedType = typename Map::mapped_type;
606     using Class_ = class_<Map, holder_type>;
607 
608     // If either type is a non-module-local bound type then make the map binding non-local as well;
609     // otherwise (e.g. both types are either module-local or converting) the map will be
610     // module-local.
611     auto tinfo = detail::get_type_info(typeid(MappedType));
612     bool local = !tinfo || tinfo->module_local;
613     if (local) {
614         tinfo = detail::get_type_info(typeid(KeyType));
615         local = !tinfo || tinfo->module_local;
616     }
617 
618     Class_ cl(scope, name.c_str(), pybind11::module_local(local), std::forward<Args>(args)...);
619 
620     cl.def(init<>());
621 
622     // Register stream insertion operator (if possible)
623     detail::map_if_insertion_operator<Map, Class_>(cl, name);
624 
625     cl.def("__bool__",
626         [](const Map &m) -> bool { return !m.empty(); },
627         "Check whether the map is nonempty"
628     );
629 
630     cl.def("__iter__",
631            [](Map &m) { return make_key_iterator(m.begin(), m.end()); },
632            keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
633     );
634 
635     cl.def("items",
636            [](Map &m) { return make_iterator(m.begin(), m.end()); },
637            keep_alive<0, 1>() /* Essential: keep list alive while iterator exists */
638     );
639 
640     cl.def("__getitem__",
641         [](Map &m, const KeyType &k) -> MappedType & {
642             auto it = m.find(k);
643             if (it == m.end())
644               throw key_error();
645            return it->second;
646         },
647         return_value_policy::reference_internal // ref + keepalive
648     );
649 
650     cl.def("__contains__",
651         [](Map &m, const KeyType &k) -> bool {
652             auto it = m.find(k);
653             if (it == m.end())
654               return false;
655            return true;
656         }
657     );
658 
659     // Assignment provided only if the type is copyable
660     detail::map_assignment<Map, Class_>(cl);
661 
662     cl.def("__delitem__",
663            [](Map &m, const KeyType &k) {
664                auto it = m.find(k);
665                if (it == m.end())
666                    throw key_error();
667                m.erase(it);
668            }
669     );
670 
671     cl.def("__len__", &Map::size);
672 
673     return cl;
674 }
675 
676 PYBIND11_NAMESPACE_END(PYBIND11_NAMESPACE)
677