#include #include #include #include #include #include namespace py = pybind11; namespace torch::jit { std::optional maybeConvertToString(const py::object& obj) { if (obj.is_none()) { return std::nullopt; } std::stringstream ss; ss << py::str(obj); return ss.str(); } struct SourceRangeFactory { SourceRangeFactory( const std::string& text, const py::object& filename, size_t file_lineno, size_t leading_whitespace_chars) : source_(std::make_shared( text, maybeConvertToString(filename), file_lineno)), leading_whitespace_chars_(leading_whitespace_chars) {} SourceRange create(int line, int start_col, int end_col) { auto [start_byte_offset, end_byte_offset] = line_col_to_byte_offs( line, start_col + leading_whitespace_chars_, end_col + leading_whitespace_chars_); return SourceRange(source_, start_byte_offset, end_byte_offset); } std::tuple line_col_to_byte_offs( int line, size_t start_col, size_t end_col) { // lines are counted from 1. line--; auto line_start = source_->offset_for_line(line); return std::make_tuple( line_start + start_col, line_start + end_col); } std::shared_ptr source_; std::vector line_len_prefix_sum_; size_t leading_whitespace_chars_; }; template List wrap_list(const SourceRange& fallback_pos, std::vector&& vec) { if (vec.empty()) return List::create(fallback_pos, std::move(vec)); return List::create(vec.front().range(), std::move(vec)); } template Maybe wrap_maybe(const SourceRange& fallback_pos, T* val) { return val ? Maybe::create(val->range(), *val) : Maybe::create(fallback_pos); } void initTreeViewBindings(PyObject* module) { auto _C = py::handle(module).cast(); auto m = _C.def_submodule("_jit_tree_views"); py::class_(m, "SourceRange") .def( "highlight", [](const SourceRange& self) { std::ostringstream stream; self.highlight(stream); return stream.str(); }) .def("__repr__", [](const SourceRange& self) { return self.str(); }) .def( "__str__", [](const SourceRange& self) { return "SourceRange at:\n" + self.str(); }) .def_property_readonly("start", &SourceRange::start) .def_property_readonly("end", &SourceRange::end); py::class_(m, "SourceRangeFactory") .def(py::init()) .def("make_range", &SourceRangeFactory::create) .def( "make_raw_range", [](const SourceRangeFactory& self, size_t start, size_t end) { return SourceRange(self.source_, start, end); }) .def_property_readonly("source", [](const SourceRangeFactory& self) { auto text_view = self.source_->text_str().str(); return text_view; }); py::class_(m, "TreeView") .def("range", &TreeView::range) .def( "__str__", [](const TreeView& tree) { std::ostringstream stream; stream << tree.get(); return stream.str(); }) .def("dump", [](const TreeView& tree) { tree.dump(); }); py::class_(m, "Ident") .def(py::init(&Ident::create)) .def_property_readonly( "name", [](const Ident& self) { return self.name(); }); py::class_(m, "Param") .def(py::init([](const Expr& type, const Ident& name, bool kwarg_only) { return Param::create( name.range(), name, Maybe::create(type.range(), type), Maybe::create(name.range()), kwarg_only); })) .def(py::init( [](const Maybe& type, const Ident& name, bool kwarg_only) { return Param::create( name.range(), name, type, Maybe::create(name.range()), kwarg_only); })); py::class_(m, "Attribute") .def(py::init([](const Ident& name, const Expr& value) { return Attribute::create(name.range(), name, value); })); m.def("TrueLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_TRUE, range, {})); }); m.def("FalseLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_FALSE, range, {})); }); m.def("NoneLiteral", [](const SourceRange& range) { return Expr(Compound::create(TK_NONE, range, {})); }); py::class_(m, "Stmt") // NOLINT(bugprone-unused-raii) .def(py::init([](const TreeView& thing) { return Stmt(thing.get()); })); py::class_(m, "Expr"); // NOLINT(bugprone-unused-raii) py::class_(m, "Def") .def(py::init( [](const Ident& name, const Decl& decl, std::vector body) { const auto& r = name.range(); return Def::create(r, name, decl, wrap_list(r, std::move(body))); })) .def("decl", [](const Def& def) { return def.decl(); }) .def("name", [](const Def& def) { return def.name(); }); py::class_(m, "Property") .def(py::init([](const SourceRange& r, const Ident& name, const Def& getter, Def* setter) { return Property::create(r, name, getter, wrap_maybe(r, setter)); })) .def("name", [](const Property& property) { return property.name(); }) .def( "getter_name", [](const Property& property) { return property.getter().name(); }) .def("setter_name", [](const Property& property) { if (property.setter().present()) { return std::optional(property.setter().get().name()); } return std::optional(std::nullopt); }); py::class_(m, "ClassDef") .def(py::init([](const Ident& name, std::vector body, std::vector props, std::vector assigns) { const auto& r = name.range(); return ClassDef::create( r, name, Maybe::create(r), wrap_list(r, std::move(body)), wrap_list(r, std::move(props)), wrap_list(r, std::move(assigns))); })); py::class_(m, "Decl").def(py::init( [](const SourceRange& r, std::vector params, Expr* return_type) { return Decl::create( r, wrap_list(r, std::move(params)), wrap_maybe(r, return_type)); })); py::class_(m, "Delete") .def(py::init([](const SourceRange& range, std::vector targets) { return Delete::create(range, wrap_list(range, std::move(targets))); })); py::class_(m, "WithItem") .def(py::init([](const SourceRange& range, const Expr& target, Var* var) { return WithItem::create(range, target, wrap_maybe(range, var)); })); py::class_(m, "Assign") .def(py::init([](std::vector lhs, const Expr& rhs) { auto li = wrap_list(rhs.range(), std::move(lhs)); return Assign::create( li.range(), li, Maybe::create(rhs.range(), rhs), Maybe::create(li.range())); })) .def(py::init([](std::vector lhs, const Expr& rhs, Expr* type) { auto li = wrap_list(rhs.range(), std::move(lhs)); return Assign::create( li.range(), li, Maybe::create(rhs.range(), rhs), wrap_maybe(li.range(), type)); })); py::class_(m, "AugAssign") .def(py::init( [](const Expr& lhs, const std::string& kind_str, const Expr& rhs) { const auto& r = lhs.range(); auto kind = AugAssignKind(Compound::create(stringToKind(kind_str), r, {})); return AugAssign::create(r, lhs, kind, rhs); })); py::class_(m, "Return") .def(py::init([](const SourceRange& range, Expr* value) { return Return::create( range, value ? *value : Expr(Compound::create(TK_NONE, range, {}))); })); py::class_(m, "Raise") .def(py::init([](const SourceRange& range, const Expr& expr) { return Raise::create(range, expr); })); py::class_(m, "Assert") .def(py::init([](const SourceRange& range, const Expr& test, Expr* msg) { return Assert::create(range, test, wrap_maybe(range, msg)); })); py::class_(m, "Pass").def( py::init([](const SourceRange& range) { return Pass::create(range); })); py::class_(m, "Break") .def(py::init( [](const SourceRange& range) { return Break::create(range); })); py::class_(m, "Continue") .def(py::init( [](const SourceRange& range) { return Continue::create(range); })); py::class_(m, "Dots").def( py::init([](const SourceRange& range) { return Dots::create(range); })); py::class_(m, "If").def( py::init([](const SourceRange& range, const Expr& cond, std::vector true_branch, std::vector false_branch) { return If::create( range, cond, wrap_list(range, std::move(true_branch)), wrap_list(range, std::move(false_branch))); })); py::class_(m, "While") .def(py::init([](const SourceRange& range, const Expr& cond, std::vector body) { return While::create(range, cond, wrap_list(range, std::move(body))); })); py::class_(m, "With").def( py::init([](const SourceRange& range, std::vector targets, std::vector body) { return With::create( range, wrap_list(range, std::move(targets)), wrap_list(range, std::move(body))); })); py::class_(m, "For").def(py::init([](const SourceRange& range, std::vector& targets, std::vector& itrs, std::vector body) { return For::create( range, wrap_list(range, std::move(targets)), wrap_list(range, std::move(itrs)), wrap_list(range, std::move(body))); })); py::class_(m, "ExprStmt").def(py::init([](const Expr& expr) { return ExprStmt::create(expr.range(), expr); })); py::class_(m, "Var") .def(py::init( [](const Ident& name) { return Var::create(name.range(), name); })) .def_property_readonly("name", [](const Var& var) { return var.name(); }); py::class_(m, "BinOp") .def(py::init( [](const std::string& kind, const Expr& lhs, const Expr& rhs) { return BinOp::create(lhs.range(), stringToKind(kind), lhs, rhs); })); // NB: we take range here, because unary ops precede their exprs, so we need // to include them py::class_(m, "UnaryOp") .def(py::init([](const SourceRange& range, const std::string& kind, const Expr& expr) { auto resolved_kind = stringToKind(kind); resolved_kind = resolved_kind == '-' ? TK_UNARY_MINUS : resolved_kind; return UnaryOp::create(range, resolved_kind, expr); })); py::class_(m, "Const") .def(py::init([](const SourceRange& range, const std::string& value) { return Const::create(range, value); })); py::class_(m, "StringLiteral") .def(py::init([](const SourceRange& range, const std::string& value) { return StringLiteral::create(range, value); })); py::class_(m, "Apply") .def(py::init([](const Expr& expr, std::vector args, std::vector kwargs) { const auto& r = expr.range(); return Apply::create( expr.range(), expr, wrap_list(r, std::move(args)), wrap_list(r, std::move(kwargs))); })); py::class_(m, "Select") .def(py::init([](const Expr& expr, const Ident& field) { return Select::create(expr.range(), expr, field); })); py::class_(m, "TernaryIf") .def(py::init( [](const Expr& cond, const Expr& true_expr, const Expr& false_expr) { return TernaryIf::create(cond.range(), cond, true_expr, false_expr); })); py::class_(m, "ListComp") .def(py::init([](const SourceRange& range, const Expr& elt, const Expr& target, const Expr& iter) { return ListComp::create(range, elt, target, iter); })); py::class_(m, "DictComp") .def(py::init([](const SourceRange& range, const Expr& key, const Expr& value, const Expr& target, const Expr& iter) { return DictComp::create(range, key, value, target, iter); })); py::class_(m, "ListLiteral") .def(py::init([](const SourceRange& range, std::vector args) { return ListLiteral::create(range, wrap_list(range, std::move(args))); })); py::class_(m, "TupleLiteral") .def(py::init([](const SourceRange& range, std::vector args) { return TupleLiteral::create(range, wrap_list(range, std::move(args))); })); py::class_(m, "DictLiteral") .def(py::init([](const SourceRange& range, std::vector keys, std::vector values) { return DictLiteral::create( range, wrap_list(range, std::move(keys)), wrap_list(range, std::move(values))); })); py::class_(m, "Subscript") .def(py::init([](const Expr& base, std::vector subscript_exprs) { return Subscript::create( base.range(), base, wrap_list(base.range(), std::move(subscript_exprs))); })); py::class_(m, "SliceExpr") .def(py::init( [](const SourceRange& range, Expr* lower, Expr* upper, Expr* step) { return SliceExpr::create( range, wrap_maybe(range, lower), wrap_maybe(range, upper), wrap_maybe(range, step)); })); py::class_(m, "Starred") .def(py::init([](const SourceRange& range, const Expr& expr) { return Starred::create(range, expr); })); py::class_, TreeView>(m, "EmptyTypeAnnotation") .def(py::init( [](const SourceRange& range) { return Maybe::create(range); })); } } // namespace torch::jit