#include #include #include #include #include #include #include #include #include #include #include #include using namespace torch::autograd; using namespace torch::jit; using namespace torch::jit::tracer; namespace torch::jit::tracer { // Python interpreter retrieval routine adapted from // https://stackoverflow.com/a/8706144 std::vector _pythonCallstack() { pybind11::gil_scoped_acquire gil; PyFrameObject* frame = PyEval_GetFrame(); Py_XINCREF(frame); std::vector entries; while (nullptr != frame) { auto code = THPCodeObjectPtr(PyFrame_GetCode(frame)); size_t line = PyCode_Addr2Line(code.get(), PyFrame_GetLasti(frame)); std::string filename = THPUtils_unpackString(code->co_filename); std::string funcname = THPUtils_unpackString(code->co_name); auto source = std::make_shared(funcname, filename, line); entries.emplace_back( StackEntry{funcname, SourceRange(source, 0, funcname.size())}); auto new_frame = PyFrame_GetBack(frame); Py_DECREF(frame); frame = new_frame; } return entries; } SourceRange getPythonInterpreterSourceRange() { auto cs = pythonCallstack(); std::optional source_filename; size_t source_line = 0; std::stringstream stack_trace; for (const auto& entry : cs) { auto& range = entry.range; if (range.source()) { auto& src = range.source(); if (src && src->filename()) { auto line = src->starting_line_no() + src->lineno_for_offset(range.start()); stack_trace << *(src->filename()) << "(" << line << "): " << entry.filename << "\n"; if (!source_filename) { source_filename = *(src->filename()); source_line = line; } } } } auto stack_trace_text = stack_trace.str(); auto source = std::make_shared(stack_trace_text, source_filename, source_line); return SourceRange(source, 0, stack_trace_text.size()); } std::pair, Stack> createGraphByTracingWithDict( const py::function& func, const py::dict& inputs_dict, const Stack& trace_inputs, const py::function& var_name_lookup_fn, bool strict, bool force_outplace, Module* self, const std::vector& argument_names) { C10_LOG_API_USAGE_ONCE("torch.tracer"); auto lookup_fn_adapter = [var_name_lookup_fn](const Variable& var) -> std::string { pybind11::gil_scoped_acquire ag; return py::cast(var_name_lookup_fn(var)); }; // The argument_names parameter is parsed in python and its order // is the same as the arguments' decalaration order in forward() method. // These name shall be added to the graph as debug name and the order // should align with the traceable stack we generated by the python dict. std::vector compact_argument_names; Stack compact_trace_inputs; for (const auto& argument_name : argument_names) { if (inputs_dict.contains(argument_name)) { compact_argument_names.push_back(argument_name); } } for (const auto& compact_argument_name : compact_argument_names) { for (auto it = inputs_dict.begin(); it != inputs_dict.end(); it++) { if (py::cast(it->first) == compact_argument_name) { compact_trace_inputs.push_back( toIValue(it->second, tryToInferType(it->second).type())); } } } auto outs = tracer::trace( std::move(compact_trace_inputs), [&](const Stack& inputs) -> Stack { // We just leave the inputs_dict as it was and pass it to forward // method. auto out = func(**inputs_dict); if (out.ptr() == Py_None) { AT_ERROR( "The traced function didn't return any values! Side-effects are not " "captured in traces, so it would be a no-op."); } return {toTypeInferredIValue(out)}; }, lookup_fn_adapter, strict, force_outplace, self, compact_argument_names); return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); } std::pair, Stack> createGraphByTracing( const py::function& func, Stack trace_inputs, const py::function& var_name_lookup_fn, bool strict, bool force_outplace, Module* self, const std::vector& argument_names) { C10_LOG_API_USAGE_ONCE("torch.tracer"); auto lookup_fn_adapter = [var_name_lookup_fn](const Variable& var) -> std::string { pybind11::gil_scoped_acquire ag; return py::cast(var_name_lookup_fn(var)); }; auto outs = tracer::trace( std::move(trace_inputs), [&func](Stack inputs) -> Stack { size_t num_func_inputs = inputs.size(); py::tuple py_inputs(num_func_inputs); for (const auto i : c10::irange(num_func_inputs)) { py_inputs[i] = py::cast(inputs[i]); } auto out = func(*py_inputs); if (out.ptr() == Py_None) { AT_ERROR( "The traced function didn't return any values! Side-effects are not " "captured in traces, so it would be a no-op."); } return {toTypeInferredIValue(out)}; }, lookup_fn_adapter, strict, force_outplace, self, argument_names); return std::make_pair(std::get<0>(outs)->graph, std::get<1>(outs)); } Node* preRecordPythonTrace( THPObjectPtr pyobj, const std::string& arg_types, at::ArrayRef inputs, pyobj_list scalar_args) { THPObjectPtr apply(PyObject_GetAttrString(pyobj.get(), "apply")); if (!apply) { throw python_error(); } auto& graph = getTracingState()->graph; Node* n = graph->createPythonOp( std::move(apply), arg_types, std::move(scalar_args)); recordSourceLocation(n); for (const Variable& input : inputs) { n->addInput(getValueTrace(input)); } graph->insertNode(n); return n; } void pythonRecordSourceLocation(Node* n) { n->setSourceRange(getPythonInterpreterSourceRange()); } void pythonWarn(const std::string& reason) { pybind11::gil_scoped_acquire gil; auto warn_class = py::module::import("torch.jit").attr("TracerWarning"); PyErr_WarnEx(warn_class.ptr(), reason.c_str(), 1); } void initPythonTracerBindings(PyObject* module) { setPythonCallstack(_pythonCallstack); setRecordSourceLocation(pythonRecordSourceLocation); auto m = py::handle(module).cast(); py::class_>( m, "TracingState", py::dynamic_attr()) // NB: no constructor; you have to get it from C++ code .def( "__repr__", [](const TracingState& s) { std::ostringstream ss; ss << ""; return ss.str(); }) .def( "__str__", [](const TracingState& s) -> std::string { std::ostringstream ss; ss << *s.graph; return ss.str(); }) .def( "push_scope", [](TracingState& s, const std::string& scope_name) { s.graph->push_scope(scope_name); }) .def("pop_scope", [](TracingState& s) { s.graph->pop_scope(); }) .def( "current_scope", [](TracingState& s) { return s.graph->current_scope()->name().toUnqualString(); }) .def( "set_graph", [](TracingState& s, std::shared_ptr g) { s.graph = std::move(g); }) .def("graph", [](TracingState& s) { return s.graph; }); m.def("_tracer_warn_use_python", []() { tracer::setWarn(pythonWarn); }); m.def( "_create_graph_by_tracing", createGraphByTracing, py::arg("func"), py::arg("inputs"), py::arg("var_name_lookup_fn"), py::arg("strict"), py::arg("force_outplace"), py::arg("self") = nullptr, py::arg("argument_names") = std::vector()); m.def("_get_tracing_state", []() { return getTracingState(); }); m.def("_set_tracing_state", [](std::shared_ptr state) { return setTracingState(std::move(state)); }); m.def("_get_value_trace", [](const Variable& var) { return getValueTrace(var); }); m.def("_set_value_trace", [](const Variable& var, Value* value) { return setValueTrace(var, value); }); m.def("_tracer_set_get_unique_name_fn", [](const py::function& func) { const auto& tracing_state = getTracingState(); AT_ASSERT(tracing_state); tracing_state->lookup_var_name_fn = [func](const Variable& var) -> std::string { pybind11::gil_scoped_acquire ag; return py::cast(func(var)); }; }); m.def("_tracer_set_force_outplace", [](bool force_outplace) { const auto& tracing_state = getTracingState(); AT_ASSERT(tracing_state); tracing_state->force_outplace = force_outplace; }); } } // namespace torch::jit::tracer