• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/traceback.h"
17 
18 #include <stdexcept>
19 
20 #include "absl/strings/str_format.h"
21 #include "absl/strings/str_join.h"
22 #include "pybind11/pytypes.h"
23 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
24 #include "tensorflow/core/platform/logging.h"
25 
26 namespace xla {
27 
28 namespace py = pybind11;
29 
30 bool Traceback::enabled_ = true;
31 
~Traceback()32 Traceback::~Traceback() {
33   // We want Traceback objects to be safe to destroy without holding the
34   // GIL, so we defer destruction of the strings.
35   GlobalPyRefManager()->AddGarbage(frames_);
36 }
37 
ToString() const38 std::string Traceback::Frame::ToString() const {
39   return absl::StrFormat("%s:%d (%s)", file_name, line_num, function_name);
40 }
41 
ToString() const42 std::string Traceback::ToString() const {
43   std::vector<std::string> frame_strs;
44   frame_strs.reserve(frames_.size());
45   for (const Frame& frame : Frames()) {
46     frame_strs.push_back(frame.ToString());
47   }
48   return absl::StrJoin(frame_strs, "\n");
49 }
50 
Frames() const51 std::vector<Traceback::Frame> Traceback::Frames() const {
52   // We require the GIL because we manipulate Python strings.
53   CHECK(PyGILState_Check());
54   std::vector<Traceback::Frame> frames;
55   frames.reserve(frames_.size());
56   for (const auto& frame : frames_) {
57     frames.push_back(Frame{
58         std::string(py::reinterpret_borrow<py::str>(frame.first->co_filename)),
59         std::string(py::reinterpret_borrow<py::str>(frame.first->co_name)),
60         frame.first->co_firstlineno,
61         PyCode_Addr2Line(frame.first, frame.second)});
62   }
63   return frames;
64 }
65 
Get()66 std::shared_ptr<Traceback> Traceback::Get() {
67   DCHECK(PyGILState_Check());
68   if (!enabled_) {
69     return nullptr;
70   }
71   auto tb = std::make_shared<Traceback>();
72   const PyThreadState* thread_state = PyThreadState_GET();
73   for (PyFrameObject* py_frame = thread_state->frame; py_frame != nullptr;
74        py_frame = py_frame->f_back) {
75     Py_INCREF(py_frame->f_code);
76     tb->frames_.emplace_back(py_frame->f_code, py_frame->f_lasti);
77   }
78   return tb;
79 }
80 
SetEnabled(bool enabled)81 void Traceback::SetEnabled(bool enabled) { enabled_ = enabled; }
82 
AsPythonTraceback() const83 py::object Traceback::AsPythonTraceback() const {
84   py::object traceback = py::none();
85   py::dict globals;
86   py::handle traceback_type(reinterpret_cast<PyObject*>(&PyTraceBack_Type));
87   for (const std::pair<PyCodeObject*, int>& frame : frames_) {
88     PyFrameObject* py_frame = PyFrame_New(PyThreadState_Get(), frame.first,
89                                           globals.ptr(), /*locals=*/nullptr);
90 
91     traceback = traceback_type(
92         /*tb_next=*/std::move(traceback),
93         /*tb_frame=*/
94         py::reinterpret_steal<py::object>(
95             reinterpret_cast<PyObject*>(py_frame)),
96         /*tb_lasti=*/frame.second,
97         /*tb_lineno=*/PyCode_Addr2Line(frame.first, frame.second));
98   }
99   return traceback;
100 }
101 
BuildTracebackSubmodule(py::module & m)102 void BuildTracebackSubmodule(py::module& m) {
103   py::class_<Traceback::Frame>(m, "Frame")
104       .def_readonly("file_name", &Traceback::Frame::file_name)
105       .def_readonly("function_name", &Traceback::Frame::function_name)
106       .def_readonly("function_start_line",
107                     &Traceback::Frame::function_start_line)
108       .def_readonly("line_num", &Traceback::Frame::line_num)
109       .def("__repr__", [](const Traceback::Frame& frame) {
110         return absl::StrFormat("%s;%s:%d", frame.function_name, frame.file_name,
111                                frame.line_num);
112       });
113 
114   py::class_<Traceback, std::shared_ptr<Traceback>> traceback(
115       m, "Traceback", "Represents a Python stack trace.");
116   traceback.def_property_static(
117       "enabled", [](py::object /* cls */) { return Traceback::enabled(); },
118       [](py::object /* cls */, bool enabled) {
119         return Traceback::SetEnabled(enabled);
120       });
121   traceback.def_static(
122       "get_traceback", []() { return Traceback::Get(); },
123       R"doc(
124     Returns a :class:`Traceback` for the current thread.
125 
126     If ``Traceback.enabled`` is ``True``, returns a :class:`Traceback` object
127     that describes the Python stack of the calling thread. Stack trace
128     collection has a small overhead, so it is disabled by default. If traceback
129     collection is disabled, returns ``None``.
130     )doc");
131   traceback.def_property_readonly("frames", &Traceback::Frames);
132   traceback.def("__str__", &Traceback::ToString);
133   traceback.def("as_python_traceback", &Traceback::AsPythonTraceback);
134 
135   // This function replaces the exception traceback associated with the current
136   // Python thread.
137   m.def(
138       "replace_thread_exc_traceback",
139       [](py::object tb) {
140         if (!PyTraceBack_Check(tb.ptr())) {
141           throw std::runtime_error("argument must be a traceback object");
142         }
143         PyThreadState* thread_state = PyThreadState_Get();
144         if (!thread_state->exc_info->exc_traceback) {
145           throw std::runtime_error(
146               "Current thread does not have an active "
147               "exception traceback");
148         }
149         PyObject* old_exc_traceback = thread_state->exc_info->exc_traceback;
150         thread_state->exc_info->exc_traceback = tb.release().ptr();
151         Py_XDECREF(old_exc_traceback);
152       },
153       py::arg("traceback"));
154 }
155 
156 }  // namespace xla
157