1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/py_client.h"
17
18 #include <memory>
19
20 #include "absl/container/flat_hash_map.h"
21 #include "tensorflow/compiler/xla/python/py_buffer.h"
22 #include "tensorflow/compiler/xla/python/py_executable.h"
23 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
24 #include "tensorflow/compiler/xla/python/traceback.h"
25 #include "tensorflow/compiler/xla/python/types.h"
26 #include "tensorflow/core/profiler/profile.pb.h"
27
28 namespace xla {
29
30 namespace py = pybind11;
31 namespace pprof = tensorflow::tfprof::pprof;
32
PyClient(std::unique_ptr<PjRtClient> pjrt_client)33 PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
34 : pjrt_client_(std::move(pjrt_client)) {}
PyClient(std::shared_ptr<PjRtClient> pjrt_client)35 PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
36 : pjrt_client_(std::move(pjrt_client)) {}
37
Devices()38 std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
39 std::vector<ClientAndPtr<PjRtDevice>> devices;
40 auto span = pjrt_client_->devices();
41 devices.reserve(span.size());
42 for (PjRtDevice* device : span) {
43 devices.push_back(WrapWithClient(shared_from_this(), device));
44 }
45 return devices;
46 }
47
LocalDevices()48 std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
49 std::vector<ClientAndPtr<PjRtDevice>> devices;
50 devices.reserve(pjrt_client_->addressable_devices().size());
51 for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
52 devices.push_back(WrapWithClient(shared_from_this(), device));
53 }
54 return devices;
55 }
56
LiveBuffers()57 std::vector<ClientAndPtr<PyBuffer>> PyClient::LiveBuffers() {
58 CHECK(PyGILState_Check());
59 std::vector<ClientAndPtr<PyBuffer>> buffers;
60 for (PyBuffer* buffer = buffers_; buffer; buffer = buffer->next_) {
61 if (!buffer->is_deleted()) {
62 buffers.push_back(WrapWithClient(shared_from_this(), buffer));
63 }
64 }
65 return buffers;
66 }
67
68 StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
GetDefaultDeviceAssignment(int num_replicas,int num_partitions)69 PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
70 TF_ASSIGN_OR_RETURN(
71 DeviceAssignment device_assignment,
72 pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
73 std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
74 result.resize(num_replicas);
75 for (int r = 0; r < num_replicas; ++r) {
76 result[r].resize(num_partitions);
77 for (int p = 0; p < num_partitions; ++p) {
78 int device_id = device_assignment(r, p);
79 TF_ASSIGN_OR_RETURN(PjRtDevice * device,
80 pjrt_client_->LookupDevice(device_id));
81 result[r][p] = WrapWithClient(shared_from_this(), device);
82 }
83 }
84 return result;
85 }
86
87 StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
GetDefaultDeviceAssignment1D(int num_replicas)88 PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
89 TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
90 pjrt_client_->GetDefaultDeviceAssignment(
91 num_replicas, /*num_partitions=*/1));
92 std::vector<ClientAndPtr<PjRtDevice>> result;
93 for (int i = 0; i < num_replicas; ++i) {
94 int device_id = device_assignment(i, 0);
95 TF_ASSIGN_OR_RETURN(PjRtDevice * device,
96 pjrt_client_->LookupDevice(device_id));
97 result.push_back(WrapWithClient(shared_from_this(), device));
98 }
99 return result;
100 }
101
PjRtBufferFromPyval(pybind11::handle argument,PjRtDevice * device,bool force_copy,PjRtClient::HostBufferSemantics host_buffer_semantics)102 StatusOr<std::unique_ptr<PjRtBuffer>> PyClient::PjRtBufferFromPyval(
103 pybind11::handle argument, PjRtDevice* device, bool force_copy,
104 PjRtClient::HostBufferSemantics host_buffer_semantics) {
105 if (device == nullptr) {
106 TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
107 device = pjrt_client_->addressable_devices().front();
108 }
109 CHECK(device != nullptr);
110 TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
111 pjrt_client_->LookupDevice(device->id()));
112 if (found_device != device) {
113 return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
114 device->DebugString(),
115 pjrt_client_->platform_name());
116 }
117 GlobalPyRefManager()->CollectGarbage();
118
119 absl::optional<CastToArrayResult> c = CastToArray(argument);
120 if (!c) {
121 return InvalidArgument(
122 "from_python argument must be an array, got value %s",
123 py::cast<std::string>(py::repr(argument)));
124 }
125
126 std::function<void()> on_done_with_host_buffer;
127 if (host_buffer_semantics !=
128 PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) {
129 std::shared_ptr<PythonRefManager::ManagedPyObjects> py_buffer_ref =
130 GlobalPyRefManager()->ManageReference(std::move(c->array));
131 on_done_with_host_buffer =
132 [py_buffer_ref{
133 std::move(py_buffer_ref)}]() { /* keeps py_buffer_ref alive */ };
134 }
135
136 std::unique_ptr<PjRtBuffer> buffer;
137 {
138 py::gil_scoped_release gil_release;
139 TF_ASSIGN_OR_RETURN(buffer,
140 pjrt_client_->BufferFromHostBuffer(
141 c->buf_ptr, c->shape, host_buffer_semantics,
142 std::move(on_done_with_host_buffer), device));
143 }
144 return buffer;
145 }
BufferFromPyval(pybind11::handle argument,PjRtDevice * device,bool force_copy,PjRtClient::HostBufferSemantics host_buffer_semantics)146 StatusOr<std::unique_ptr<PyBuffer>> PyClient::BufferFromPyval(
147 pybind11::handle argument, PjRtDevice* device, bool force_copy,
148 PjRtClient::HostBufferSemantics host_buffer_semantics) {
149 TF_ASSIGN_OR_RETURN(
150 std::unique_ptr<PjRtBuffer> buffer,
151 PjRtBufferFromPyval(argument, device, force_copy, host_buffer_semantics));
152
153 auto traceback = Traceback::Get();
154 return std::make_unique<PyBuffer>(shared_from_this(), std::move(buffer),
155 std::move(traceback));
156 }
157
Compile(const XlaComputation & computation,CompileOptions options)158 StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
159 const XlaComputation& computation, CompileOptions options) {
160 std::unique_ptr<PjRtExecutable> executable;
161 absl::optional<std::string> fingerprint;
162 {
163 py::gil_scoped_release gil_release;
164 TF_ASSIGN_OR_RETURN(executable,
165 pjrt_client_->Compile(computation, std::move(options)));
166 TF_ASSIGN_OR_RETURN(fingerprint,
167 pjrt_client_->ExecutableFingerprint(*executable));
168 }
169 auto traceback = Traceback::Get();
170 return std::make_shared<PyExecutable>(
171 shared_from_this(), std::move(executable), std::move(traceback),
172 std::move(fingerprint));
173 }
174
175 class ProfileBuilder {
176 public:
177 ProfileBuilder();
profile()178 pprof::Profile& profile() { return profile_; }
179
180 // Adds or returns the ID of `s` in the table.
181 int StringId(const std::string& s);
182
183 // Adds or returns the ID of a function.
184 int FunctionId(PyCodeObject* code);
185
186 // Adds or returns the ID of a code location.
187 int LocationId(PyCodeObject* code, int instruction);
188
189 private:
190 pprof::Profile profile_;
191
192 absl::flat_hash_map<std::string, int> strings_;
193 absl::flat_hash_map<PyCodeObject*, int> functions_;
194 absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
195 };
196
ProfileBuilder()197 ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }
198
StringId(const std::string & s)199 int ProfileBuilder::StringId(const std::string& s) {
200 auto ret = strings_.emplace(s, profile_.string_table_size());
201 if (ret.second) {
202 profile_.add_string_table(s);
203 }
204 return ret.first->second;
205 }
206
FunctionId(PyCodeObject * code)207 int ProfileBuilder::FunctionId(PyCodeObject* code) {
208 // +1 because id 0 is reserved.
209 auto ret = functions_.emplace(code, profile_.function_size() + 1);
210 if (ret.second) {
211 auto* function = profile_.add_function();
212 function->set_id(ret.first->second);
213 int name = StringId(py::str(code->co_name));
214 function->set_name(name);
215 function->set_system_name(name);
216 function->set_filename(StringId(py::str(code->co_filename)));
217 function->set_start_line(code->co_firstlineno);
218 }
219 return ret.first->second;
220 }
221
LocationId(PyCodeObject * code,int instruction)222 int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
223 // +1 because id 0 is reserved.
224 auto ret = locations_.emplace(std::make_pair(code, instruction),
225 profile_.location_size() + 1);
226 if (ret.second) {
227 auto* location = profile_.add_location();
228 location->set_id(ret.first->second);
229 auto* line = location->add_line();
230 line->set_function_id(FunctionId(code));
231 line->set_line(PyCode_Addr2Line(code, instruction));
232 }
233 return ret.first->second;
234 }
235
236 namespace {
237
238 struct HeapProfileKey {
239 Traceback* traceback;
240 int64 size;
241 PjRtDevice* device;
242 bool operator==(const HeapProfileKey& other) const;
243 };
244
operator ==(const HeapProfileKey & other) const245 bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
246 if (size != other.size || device != other.device) {
247 return false;
248 }
249 if ((traceback == nullptr) != (other.traceback == nullptr)) {
250 return false;
251 }
252 if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
253 return false;
254 }
255 return true;
256 }
257
258 template <typename H>
AbslHashValue(H h,const HeapProfileKey & key)259 H AbslHashValue(H h, const HeapProfileKey& key) {
260 if (key.traceback) {
261 h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
262 key.traceback->raw_frames().size());
263 }
264 h = H::combine(std::move(h), key.size, key.device);
265 return h;
266 }
267
268 } // namespace
269
HeapProfile()270 py::bytes PyClient::HeapProfile() {
271 CHECK(PyGILState_Check());
272 absl::flat_hash_map<HeapProfileKey, int64> entries;
273 for (PyBuffer* buffer = buffers_; buffer; buffer = buffer->next_) {
274 HeapProfileKey key{buffer->traceback(),
275 buffer->buffer()->OnDeviceSizeInBytes(),
276 buffer->buffer()->device()};
277 ++entries[key];
278 }
279
280 for (PyExecutable* executable = executables_; executable;
281 executable = executable->next_) {
282 HeapProfileKey key{executable->traceback(),
283 executable->SizeOfGeneratedCodeInBytes(), nullptr};
284 ++entries[key];
285 }
286
287 ProfileBuilder builder;
288 auto* allocations = builder.profile().add_sample_type();
289 allocations->set_type(builder.StringId("allocations"));
290 allocations->set_unit(builder.StringId("count"));
291 auto* space = builder.profile().add_sample_type();
292 space->set_type(builder.StringId("space"));
293 space->set_unit(builder.StringId("bytes"));
294
295 const int kind_string_id = builder.StringId("kind");
296 const int buffer_string_id = builder.StringId("buffer");
297 const int executable_string_id = builder.StringId("executable");
298 const int device_string_id = builder.StringId("device");
299 for (const auto& entry : entries) {
300 auto* sample = builder.profile().add_sample();
301 if (entry.first.traceback) {
302 for (const auto& frame : entry.first.traceback->raw_frames()) {
303 sample->add_location_id(builder.LocationId(frame.first, frame.second));
304 }
305 }
306 sample->add_value(entry.second);
307 sample->add_value(entry.first.size * entry.second);
308
309 auto* kind_label = sample->add_label();
310 kind_label->set_key(kind_string_id);
311 if (entry.first.device) {
312 kind_label->set_str(buffer_string_id);
313 auto* device_label = sample->add_label();
314 device_label->set_key(device_string_id);
315 device_label->set_str(
316 builder.StringId(entry.first.device->DebugString()));
317 } else {
318 kind_label->set_str(executable_string_id);
319 }
320 }
321 return builder.profile().SerializeAsString();
322 }
323
324 } // namespace xla
325