#include #include #include #include #include #include #include namespace py = pybind11; namespace pybind11 { namespace detail { #define ITEM_TYPE_CASTER(T, Name) \ template <> \ struct type_caster::Item> { \ public: \ using Item = typename torch::OrderedDict::Item; \ using PairCaster = make_caster>; \ PYBIND11_TYPE_CASTER(Item, _("Ordered" #Name "DictItem")); \ bool load(handle src, bool convert) { \ return PairCaster().load(src, convert); \ } \ static handle cast(Item src, return_value_policy policy, handle parent) { \ return PairCaster::cast( \ src.pair(), std::move(policy), std::move(parent)); \ } \ } // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) ITEM_TYPE_CASTER(torch::Tensor, Tensor); // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) ITEM_TYPE_CASTER(std::shared_ptr, Module); } // namespace detail } // namespace pybind11 namespace torch { namespace python { namespace { template void bind_ordered_dict(py::module module, const char* dict_name) { using ODict = OrderedDict; // clang-format off py::class_(module, dict_name) .def("items", &ODict::items) .def("keys", &ODict::keys) .def("values", &ODict::values) .def("__iter__", [](const ODict& dict) { return py::make_iterator(dict.begin(), dict.end()); }, py::keep_alive<0, 1>()) .def("__len__", &ODict::size) .def("__contains__", &ODict::contains) .def("__getitem__", [](const ODict& dict, const std::string& key) { return dict[key]; }) .def("__getitem__", [](const ODict& dict, size_t index) { return dict[index]; }); // clang-format on } } // namespace void init_bindings(PyObject* module) { py::module m = py::handle(module).cast(); py::module cpp = m.def_submodule("cpp"); bind_ordered_dict(cpp, "OrderedTensorDict"); bind_ordered_dict>(cpp, "OrderedModuleDict"); py::module nn = cpp.def_submodule("nn"); add_module_bindings( py::class_>(nn, "Module")); } } // namespace python } // namespace torch