/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_client.h"

#include <cstring>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/profiler/profile.pb.h"

namespace xla {

namespace py = pybind11;
namespace pprof = tensorflow::tfprof::pprof;

PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
    : PyClient(std::shared_ptr<PjRtClient>(std::move(pjrt_client))) {}

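// buffers_ is indexed by device id and holds the head of a per-device
// intrusive list of live PyBuffers (linked via PyBuffer::next_). Device ids
// are not guaranteed to be dense, so after sizing for device_count() we grow
// the vector until every addressable device id has a slot.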
PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
    : pjrt_client_(std::move(pjrt_client)) {
  CHECK(pjrt_client_ != nullptr);
  buffers_.resize(pjrt_client_->device_count());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    if (device->id() >= buffers_.size()) {
      buffers_.resize(device->id() + 1);
    }
  }
}

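// Tearing down the PjRtClient may block, e.g. while outstanding device work
// drains, so drop the GIL first; holding it across a potentially long wait
// would stall (and in the worst case deadlock) other Python threads.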
PyClient::~PyClient() {
  py::gil_scoped_release gil;
  pjrt_client_ = nullptr;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  auto span = pjrt_client_->devices();
  devices.reserve(span.size());
  for (PjRtDevice* device : span) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  devices.reserve(pjrt_client_->addressable_devices().size());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

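// Walks each per-device intrusive list and returns strong (incref'd)
// references to every buffer that has not been deleted. Must be called with
// the GIL held, which serializes traversal against buffer creation and
// deletion.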
std::vector<py::object> PyClient::LiveBuffers() {
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      if (!buffer->is_deleted()) {
        buffers.push_back(
            py::reinterpret_borrow<py::object>(buffer->AsHandle()));
      }
    }
  }
  return buffers;
}

std::vector<py::object> PyClient::LiveBuffersOnDevice(PjRtDevice* device) {
  CHECK_EQ(device->client(), pjrt_client());
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* buffer = buffers_[device->id()]; buffer;
       buffer = buffer->next_) {
    if (!buffer->is_deleted()) {
      buffers.push_back(py::reinterpret_borrow<py::object>(buffer->AsHandle()));
    }
  }
  return buffers;
}

std::vector<std::shared_ptr<PyExecutable>> PyClient::LiveExecutables() {
  CHECK(PyGILState_Check());
  std::vector<std::shared_ptr<PyExecutable>> executables;
  for (PyExecutable* exec = executables_; exec; exec = exec->next_) {
    if (!exec->is_deleted()) {
      executables.push_back(exec->shared_from_this());
    }
  }
  return executables;
}

Status PyClient::Defragment() {
  CHECK(PyGILState_Check());
  return pjrt_client_->Defragment();
}

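// Returns the backend's default device assignment as a replica-major grid:
// result[replica][partition] is the device that (replica, partition) runs
// on, wrapped so that each returned device keeps this client alive.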
StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
  std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
  result.resize(num_replicas);
  for (int r = 0; r < num_replicas; ++r) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                          pjrt_client_->LookupDevice(device_id));
      result[r][p] = WrapWithClient(shared_from_this(), device);
    }
  }
  return result;
}

StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      pjrt_client_->GetDefaultDeviceAssignment(
                          num_replicas, /*num_partitions=*/1));
  std::vector<ClientAndPtr<PjRtDevice>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                        pjrt_client_->LookupDevice(device_id));
    result.push_back(WrapWithClient(shared_from_this(), device));
  }
  return result;
}

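// Transfers (or, when zero-copy is permitted, aliases) a Python value into a
// buffer on `device`, defaulting to the first addressable device when none
// is given. A typical call from Python, assuming the usual binding name
// exposed by the pybind11 registration (illustrative, not defined here):
//
//   buf = backend.buffer_from_pyval(np.ones((2, 3), np.float32), device)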
StatusOr<py::object> PyClient::BufferFromPyval(
    pybind11::handle argument, PjRtDevice* device, bool force_copy,
    PjRtClient::HostBufferSemantics host_buffer_semantics) {
  if (device == nullptr) {
    TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
    device = pjrt_client_->addressable_devices().front();
  }
  CHECK(device != nullptr);
  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
                      pjrt_client_->LookupDevice(device->id()));
  if (found_device != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
  }
  GlobalPyRefManager()->CollectGarbage();

  DevicePutOptions options;
  options.squash_64bit_types = false;
  options.allow_zero_copy =
      (!force_copy &&
       (host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy));
  options.force_lazy_arrays = true;
  TF_ASSIGN_OR_RETURN(DevicePutResult put,
                      DevicePut(argument, device, options));

  if (put.owned_buffer) {
    auto traceback = Traceback::Get();
    return PyBuffer::Make(shared_from_this(), std::move(put.owned_buffer),
                          std::move(traceback));
  } else {
    return py::reinterpret_borrow<py::object>(put.owning_pybuffer);
  }
}

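// Compilation can take a long time, so the GIL is released while PjRt
// compiles and fingerprints the executable; only the cheap Python-side
// wrapping happens with the GIL held. The fingerprint, when the backend
// provides one, is forwarded to the PyExecutable.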
StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  std::unique_ptr<PjRtExecutable> executable;
  absl::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable,
                        pjrt_client_->Compile(computation, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint));
}

StatusOr<py::bytes> PyClient::SerializeExecutable(
    const PyExecutable& executable) const {
  return pjrt_client_->SerializeExecutable(executable.pjrt_executable());
}

StatusOr<std::shared_ptr<PyExecutable>> PyClient::DeserializeExecutable(
    const std::string& serialized, CompileOptions options) {
  std::unique_ptr<PjRtExecutable> executable;
  absl::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable, pjrt_client_->DeserializeExecutable(
                                        serialized, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint));
}

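// Incrementally builds a pprof Profile proto. Per the pprof format, index 0
// of the string table must be the empty string and id 0 is reserved for
// functions and locations, hence the CHECK in the constructor and the "+1"s
// in FunctionId/LocationId below.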
class ProfileBuilder {
 public:
  ProfileBuilder();
  pprof::Profile& profile() { return profile_; }

  // Adds or returns the ID of `s` in the table.
  int StringId(const std::string& s);

  // Adds or returns the ID of a function.
  int FunctionId(PyCodeObject* code);

  // Adds or returns the ID of a code location.
  int LocationId(PyCodeObject* code, int instruction);

 private:
  pprof::Profile profile_;

  absl::flat_hash_map<std::string, int> strings_;
  absl::flat_hash_map<PyCodeObject*, int> functions_;
  absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
};

ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }

int ProfileBuilder::StringId(const std::string& s) {
  auto ret = strings_.emplace(s, profile_.string_table_size());
  if (ret.second) {
    profile_.add_string_table(s);
  }
  return ret.first->second;
}

int ProfileBuilder::FunctionId(PyCodeObject* code) {
  // +1 because id 0 is reserved.
  auto ret = functions_.emplace(code, profile_.function_size() + 1);
  if (ret.second) {
    auto* function = profile_.add_function();
    function->set_id(ret.first->second);
    int name = StringId(py::str(code->co_name));
    function->set_name(name);
    function->set_system_name(name);
    function->set_filename(StringId(py::str(code->co_filename)));
    function->set_start_line(code->co_firstlineno);
  }
  return ret.first->second;
}

int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
  // +1 because id 0 is reserved.
  auto ret = locations_.emplace(std::make_pair(code, instruction),
                                profile_.location_size() + 1);
  if (ret.second) {
    auto* location = profile_.add_location();
    location->set_id(ret.first->second);
    auto* line = location->add_line();
    line->set_function_id(FunctionId(code));
    line->set_line(PyCode_Addr2Line(code, instruction));
  }
  return ret.first->second;
}

namespace {

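// Aggregation key for heap-profile samples: live allocations are grouped by
// (allocating traceback, on-device size, device). Executables are recorded
// with a null device. Equality and hashing compare the traceback's frames by
// value rather than by Traceback pointer identity.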
struct HeapProfileKey {
  Traceback* traceback;
  int64_t size;
  PjRtDevice* device;
  bool operator==(const HeapProfileKey& other) const;
};

bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
  if (size != other.size || device != other.device) {
    return false;
  }
  if ((traceback == nullptr) != (other.traceback == nullptr)) {
    return false;
  }
  if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
    return false;
  }
  return true;
}

template <typename H>
H AbslHashValue(H h, const HeapProfileKey& key) {
  if (key.traceback) {
    h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
                              key.traceback->raw_frames().size());
  }
  h = H::combine(std::move(h), key.size, key.device);
  return h;
}

}  // namespace

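// Emits a serialized pprof profile with two sample types, "allocations"
// (count) and "space" (bytes). Each PjRtBuffer is counted once even when it
// is shared by several PyBuffers, and samples are labeled kind=buffer or
// kind=executable, with a device label for buffers. Assuming the usual
// Python binding name, the result can be written straight to a file for
// pprof:
//
//   with open("heap.pb", "wb") as f:
//     f.write(client.heap_profile())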
StatusOr<py::bytes> PyClient::HeapProfile() {
  CHECK(PyGILState_Check());
  absl::flat_hash_set<PjRtBuffer*> buffer_set;
  absl::flat_hash_map<HeapProfileKey, int64_t> entries;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      // We only wish to count each PjRtBuffer once, even though they may be
      // shared by multiple PyBuffers.
      if (!buffer->is_deleted() && buffer_set.insert(buffer->buffer()).second) {
        TF_ASSIGN_OR_RETURN(size_t size,
                            buffer->buffer()->GetOnDeviceSizeInBytes());
        HeapProfileKey key{buffer->traceback().get(),
                           static_cast<int64_t>(size),
                           buffer->buffer()->device()};
        ++entries[key];
      }
    }
  }

  for (PyExecutable* executable = executables_; executable;
       executable = executable->next_) {
    if (!executable->is_deleted()) {
      HeapProfileKey key{executable->traceback(),
                         executable->SizeOfGeneratedCodeInBytes(), nullptr};
      ++entries[key];
    }
  }

  ProfileBuilder builder;
  auto* allocations = builder.profile().add_sample_type();
  allocations->set_type(builder.StringId("allocations"));
  allocations->set_unit(builder.StringId("count"));
  auto* space = builder.profile().add_sample_type();
  space->set_type(builder.StringId("space"));
  space->set_unit(builder.StringId("bytes"));

  const int kind_string_id = builder.StringId("kind");
  const int buffer_string_id = builder.StringId("buffer");
  const int executable_string_id = builder.StringId("executable");
  const int device_string_id = builder.StringId("device");
  for (const auto& entry : entries) {
    auto* sample = builder.profile().add_sample();
    if (entry.first.traceback) {
      for (const auto& frame : entry.first.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    sample->add_value(entry.second);
    sample->add_value(entry.first.size * entry.second);

    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_string_id);
    if (entry.first.device) {
      kind_label->set_str(buffer_string_id);
      auto* device_label = sample->add_label();
      device_label->set_key(device_string_id);
      device_label->set_str(
          builder.StringId(entry.first.device->DebugString()));
    } else {
      kind_label->set_str(executable_string_id);
    }
  }
  return py::bytes(builder.profile().SerializeAsString());
}


namespace {

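// Bridges an XLA CPU custom call to a Python callable. Arguments are
// presented to Python as read-only NumPy views over the runtime's buffers;
// results are validated against the expected shapes and copied (transposing
// where the strides differ) into the layouts the compiled code expects.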
class CpuCallback {
 public:
  struct Arg {
    PrimitiveType type;                    // XLA type
    py::dtype dtype;                       // NumPy type, for array types.
    absl::InlinedVector<int64_t, 4> dims;  // Dimensions, for array types.
    std::vector<ssize_t> strides;          // Byte strides, for array types.
  };
  struct Result {
    PrimitiveType type;  // XLA type
    // Expected output shape, for array types
    absl::InlinedVector<int64_t, 4> expected_dims;
    // Expected output byte strides, for array types. If the strides do not
    // match, the output will be transposed into the expected layout.
    std::vector<int64_t> expected_strides;
    // The desired order of output dimensions in major-to-minor order.
    absl::InlinedVector<int64_t, 4> reversed_layout;
    // Size of the array in bytes.
    size_t size_in_bytes;
  };

  explicit CpuCallback(py::function callable, std::vector<Arg> args,
                       std::vector<Result> results)
      : callable_(std::move(callable)),
        args_(std::move(args)),
        results_(std::move(results)),
        transpose_cache_(/*capacity=*/16) {}

  void Call(void* result, void** arg_ptrs);

 private:
  py::function callable_;
  std::vector<Arg> const args_;
  std::vector<Result> const results_;
  TransposePlanCache transpose_cache_;
};

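// Invoked on the runtime thread with the raw custom-call buffers. Acquires
// the GIL, so Python callbacks execute serially with respect to the
// interpreter and should be kept cheap on hot paths.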
void CpuCallback::Call(void* result, void** arg_ptrs) {
  absl::Span<void* const> inputs(arg_ptrs, args_.size());
  absl::Span<void* const> outputs(reinterpret_cast<void**>(result),
                                  results_.size());

  py::gil_scoped_acquire gil;
  py::tuple args(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (args_[i].type == TOKEN) {
      args[i] = py::none();
    } else {
      args[i] = py::array(args_[i].dtype, args_[i].dims, args_[i].strides,
                          const_cast<void*>(inputs[i]));
      args[i].attr("flags").attr("writeable") = Py_False;
    }
  }
  py::object result_tuple = callable_(*py::reinterpret_borrow<py::args>(args));
  if (!PyTuple_Check(result_tuple.ptr())) {
    throw std::runtime_error(
        absl::StrFormat("CPU callback expected a tuple result, got %s",
                        static_cast<std::string>(py::repr(result_tuple))));
  }
  if (PyTuple_Size(result_tuple.ptr()) != results_.size()) {
    throw std::runtime_error(
        absl::StrFormat("CPU callback expected a tuple with %d results, got %d",
                        results_.size(), PyTuple_Size(result_tuple.ptr())));
  }
  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    if (results_[i].type == TOKEN) {
      if (!output.is_none()) {
        throw std::runtime_error(absl::StrFormat(
            "Token output from Python callback should be None, got %s",
            static_cast<std::string>(py::repr(output))));
      }
      continue;
    }
    py::array array = py::cast<py::array>(std::move(output));
    static_assert(sizeof(ssize_t) == sizeof(int64_t),
                  "Expected ssize_t to be of equal size to int64_t");
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    if (dims != results_[i].expected_dims) {
      throw std::runtime_error(absl::StrFormat(
          "Mismatched result shape for %d-th return value from CPU callback; "
          "expected array with dimensions %s, got %s",
          i, absl::StrJoin(results_[i].expected_dims, ","),
          absl::StrJoin(dims, ",")));
    }
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    if (strides == results_[i].expected_strides) {
      std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes);
    } else {
      StatusOr<std::shared_ptr<TransposePlan>> plan =
          transpose_cache_.GetOrCreate(
              primitive_util::ByteWidth(results_[i].type), dims,
              results_[i].reversed_layout,
              /*input_layout=*/TransposePlan::Striding{strides});
      if (!plan.ok()) {
        throw std::runtime_error(plan.status().ToString());
      }
      plan.ValueOrDie()->Execute(array.data(), outputs[i]);
    }
  }
}

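// C entry point registered as the "xla_python_cpu_callback" custom-call
// target. inputs[0] carries the CpuCallback* encoded as a uint64 scalar (see
// EmitPythonCallback below); the remaining inputs are the user's operands.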
extern "C" void XlaPythonCpuCallback(void* output, void** inputs) {
  CpuCallback* callback =
      absl::bit_cast<CpuCallback*>(*static_cast<uintptr_t*>(inputs[0]));
  callback->Call(output, inputs + 1);
}

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
                                             &XlaPythonCpuCallback);

}  // namespace

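// Lowers `callable` into `builder` as a custom call whose operand 0 is a u64
// scalar holding the address of a heap-allocated CpuCallback. The returned
// capsule owns that callback, so the caller must keep the capsule alive for
// as long as any executable compiled from this computation may still run. A
// sketch of the intended Python-side usage (names illustrative, not defined
// in this file):
//
//   op, keepalive = client.emit_python_callback(fn, builder, operands,
//                                               result_shapes)
//   # ... compile and run; hold on to `keepalive` meanwhile.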
StatusOr<std::pair<XlaOp, pybind11::object>> PyClient::EmitPythonCallback(
    pybind11::function callable, XlaBuilder& builder,
    absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
    absl::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  if (pjrt_client_->platform_id() != kCpuId) {
    return Unimplemented("EmitPythonCallback is only implemented on CPU");
  }

  std::vector<CpuCallback::Arg> callback_args(operands.size());
  std::vector<XlaOp> custom_call_args(operands.size() + 1);
  absl::c_copy(operands, custom_call_args.begin() + 1);

  if (operand_layouts && operand_layouts->size() != operands.size()) {
    return InvalidArgument(
        "Mismatched number of operands (%d) and operand_layouts (%d)",
        operands.size(), operand_layouts->size());
  }

  std::vector<Shape> custom_call_arg_layouts(operands.size() + 1);
  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");
  custom_call_arg_layouts[0] =
      ShapeUtil::MakeShapeWithDescendingLayout(U64, {});
  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    xla::Shape& layout = custom_call_arg_layouts[i + 1];
    if (operand_layouts) {
      if (!(*operand_layouts)[i].has_layout()) {
        return InvalidArgument(
            "operand_layout shapes for callback must have "
            "layouts, got %s",
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      if (!ShapeUtil::Compatible(shape, (*operand_layouts)[i])) {
        return InvalidArgument(
            "Incompatible shapes for Python callback argument %d: %s vs %s", i,
            shape.ToString(),
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      layout = (*operand_layouts)[i];
    } else {
      layout = LayoutUtil::GetWithDefaultLayout(shape);
    }

    if (shape.IsArray()) {
      callback_args[i].dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(), callback_args[i].dims.begin());
      callback_args[i].strides = ByteStridesForShape(layout);
      callback_args[i].type = shape.element_type();
      TF_ASSIGN_OR_RETURN(callback_args[i].dtype,
                          PrimitiveTypeToDtype(shape.element_type()));
    } else if (shape.IsToken()) {
      callback_args[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token arguments to Python callbacks are supported, "
          "got %s",
          shape.ToString());
    }
  }

  std::vector<Shape> result_shapes_with_layout(result_shapes.size());
  std::vector<CpuCallback::Result> callback_results(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      result_shapes_with_layout[i] =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
      const Shape& shape = result_shapes_with_layout[i];
      callback_results[i].expected_dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(),
                   callback_results[i].expected_dims.begin());
      callback_results[i].expected_strides = ByteStridesForShapeInt64(shape);
      callback_results[i].type = shape.element_type();
      callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape);
      callback_results[i].reversed_layout.resize(shape.dimensions_size());
      absl::c_reverse_copy(shape.layout().minor_to_major(),
                           callback_results[i].reversed_layout.begin());
    } else if (result_shapes[i].IsToken()) {
      callback_results[i].type = TOKEN;
      result_shapes_with_layout[i] = result_shapes[i];
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }

  auto callback = std::make_unique<CpuCallback>(
      std::move(callable), callback_args, callback_results);
  custom_call_args[0] = ConstantR0<std::uint64_t>(
      &builder, absl::bit_cast<std::uint64_t>(callback.get()));

  Shape result_shape = ShapeUtil::MakeTupleShape(result_shapes_with_layout);
  XlaOp result = CustomCallWithLayout(&builder, "xla_python_cpu_callback",
                                      custom_call_args, result_shape,
                                      custom_call_arg_layouts,
                                      /*opaque=*/"", has_side_effect);

  py::capsule callback_capsule(callback.release(), [](void* ptr) {
    delete reinterpret_cast<CpuCallback*>(ptr);
  });
  return std::make_pair(result, py::object(std::move(callback_capsule)));
}

}  // namespace xla