/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_client.h"

#include <cstring>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/base/casts.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/transpose.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
#include "tensorflow/core/profiler/profile.pb.h"

namespace xla {

namespace py = pybind11;
namespace pprof = tensorflow::tfprof::pprof;

PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client)
    : PyClient(std::shared_ptr<PjRtClient>(std::move(pjrt_client))) {}

PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client)
    : pjrt_client_(std::move(pjrt_client)) {
  CHECK(pjrt_client_ != nullptr);
  buffers_.resize(pjrt_client_->device_count());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    if (device->id() >= buffers_.size()) {
      buffers_.resize(device->id() + 1);
    }
  }
}
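
// Note on the data structure: buffers_ is indexed by device id, and each
// slot heads an intrusive singly linked list (chained through
// PyBuffer::next_) of the PyBuffers resident on that device; see the
// traversals in LiveBuffers() and LiveBuffersOnDevice() below. The loop
// above grows the vector so that every addressable device id has a slot.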

PyClient::~PyClient() {
  py::gil_scoped_release gil;
  pjrt_client_ = nullptr;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::Devices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  auto span = pjrt_client_->devices();
  devices.reserve(span.size());
  for (PjRtDevice* device : span) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

std::vector<ClientAndPtr<PjRtDevice>> PyClient::LocalDevices() {
  std::vector<ClientAndPtr<PjRtDevice>> devices;
  devices.reserve(pjrt_client_->addressable_devices().size());
  for (PjRtDevice* device : pjrt_client_->addressable_devices()) {
    devices.push_back(WrapWithClient(shared_from_this(), device));
  }
  return devices;
}

std::vector<py::object> PyClient::LiveBuffers() {
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      if (!buffer->is_deleted()) {
        buffers.push_back(
            py::reinterpret_borrow<py::object>(buffer->AsHandle()));
      }
    }
  }
  return buffers;
}

std::vector<py::object> PyClient::LiveBuffersOnDevice(PjRtDevice* device) {
  CHECK_EQ(device->client(), pjrt_client());
  CHECK(PyGILState_Check());
  std::vector<py::object> buffers;
  for (PyBuffer* buffer = buffers_[device->id()]; buffer;
       buffer = buffer->next_) {
    if (!buffer->is_deleted()) {
      buffers.push_back(py::reinterpret_borrow<py::object>(buffer->AsHandle()));
    }
  }
  return buffers;
}

std::vector<std::shared_ptr<PyExecutable>> PyClient::LiveExecutables() {
  CHECK(PyGILState_Check());
  std::vector<std::shared_ptr<PyExecutable>> executables;
  for (PyExecutable* exec = executables_; exec; exec = exec->next_) {
    if (!exec->is_deleted()) {
      executables.push_back(exec->shared_from_this());
    }
  }
  return executables;
}

Status PyClient::Defragment() {
  CHECK(PyGILState_Check());
  return pjrt_client_->Defragment();
}

StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
PyClient::GetDefaultDeviceAssignment(int num_replicas, int num_partitions) {
  TF_ASSIGN_OR_RETURN(
      DeviceAssignment device_assignment,
      pjrt_client_->GetDefaultDeviceAssignment(num_replicas, num_partitions));
  std::vector<std::vector<ClientAndPtr<PjRtDevice>>> result;
  result.resize(num_replicas);
  for (int r = 0; r < num_replicas; ++r) {
    result[r].resize(num_partitions);
    for (int p = 0; p < num_partitions; ++p) {
      int device_id = device_assignment(r, p);
      TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                          pjrt_client_->LookupDevice(device_id));
      result[r][p] = WrapWithClient(shared_from_this(), device);
    }
  }
  return result;
}

StatusOr<std::vector<ClientAndPtr<PjRtDevice>>>
PyClient::GetDefaultDeviceAssignment1D(int num_replicas) {
  TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment,
                      pjrt_client_->GetDefaultDeviceAssignment(
                          num_replicas, /*num_partitions=*/1));
  std::vector<ClientAndPtr<PjRtDevice>> result;
  for (int i = 0; i < num_replicas; ++i) {
    int device_id = device_assignment(i, 0);
    TF_ASSIGN_OR_RETURN(PjRtDevice * device,
                        pjrt_client_->LookupDevice(device_id));
    result.push_back(WrapWithClient(shared_from_this(), device));
  }
  return result;
}

StatusOr<py::object> PyClient::BufferFromPyval(
    pybind11::handle argument, PjRtDevice* device, bool force_copy,
    PjRtClient::HostBufferSemantics host_buffer_semantics) {
  if (device == nullptr) {
    TF_RET_CHECK(!pjrt_client_->addressable_devices().empty());
    device = pjrt_client_->addressable_devices().front();
  }
  CHECK(device != nullptr);
  TF_ASSIGN_OR_RETURN(PjRtDevice * found_device,
                      pjrt_client_->LookupDevice(device->id()));
  if (found_device != device) {
    return InvalidArgument("Cannot copy value to device '%s' with '%s' backend",
                           device->DebugString(),
                           pjrt_client_->platform_name());
  }
  GlobalPyRefManager()->CollectGarbage();

  DevicePutOptions options;
  options.squash_64bit_types = false;
  options.allow_zero_copy =
      (!force_copy &&
       (host_buffer_semantics == PjRtClient::HostBufferSemantics::kZeroCopy));
  options.force_lazy_arrays = true;
  TF_ASSIGN_OR_RETURN(DevicePutResult put,
                      DevicePut(argument, device, options));

  if (put.owned_buffer) {
    auto traceback = Traceback::Get();
    return PyBuffer::Make(shared_from_this(), std::move(put.owned_buffer),
                          std::move(traceback));
  } else {
    return py::reinterpret_borrow<py::object>(put.owning_pybuffer);
  }
}
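
// Illustrative Python-side use via the pybind11 bindings (the snake_case
// method name is an assumption of the sketch, not guaranteed by this file):
//
//   device = client.local_devices()[0]
//   buf = client.buffer_from_pyval(np.ones((2, 2), np.float32), device)
//
// With host_buffer_semantics == kZeroCopy and force_copy == false,
// DevicePut() may alias the host buffer rather than copy it.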

StatusOr<std::shared_ptr<PyExecutable>> PyClient::Compile(
    const XlaComputation& computation, CompileOptions options) {
  std::unique_ptr<PjRtExecutable> executable;
  absl::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable,
                        pjrt_client_->Compile(computation, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint));
}
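
// Note: the GIL is dropped around compilation above because compiling can
// take a long time and touches no Python state; holding the GIL would block
// every other Python thread for the duration.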

StatusOr<py::bytes> PyClient::SerializeExecutable(
    const PyExecutable& executable) const {
  return pjrt_client_->SerializeExecutable(executable.pjrt_executable());
}

StatusOr<std::shared_ptr<PyExecutable>> PyClient::DeserializeExecutable(
    const std::string& serialized, CompileOptions options) {
  std::unique_ptr<PjRtExecutable> executable;
  absl::optional<std::string> fingerprint;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(executable, pjrt_client_->DeserializeExecutable(
                                        serialized, std::move(options)));
    TF_ASSIGN_OR_RETURN(fingerprint,
                        pjrt_client_->ExecutableFingerprint(*executable));
  }
  auto traceback = Traceback::Get();
  return std::make_shared<PyExecutable>(
      shared_from_this(), std::move(executable), std::move(traceback),
      std::move(fingerprint));
}
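
// Serialize/deserialize round trip (a sketch; `client`, `executable`, and
// `options` are assumed to exist in the caller's scope):
//
//   TF_ASSIGN_OR_RETURN(py::bytes blob,
//                       client->SerializeExecutable(*executable));
//   TF_ASSIGN_OR_RETURN(std::shared_ptr<PyExecutable> restored,
//                       client->DeserializeExecutable(std::string(blob),
//                                                     std::move(options)));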

class ProfileBuilder {
 public:
  ProfileBuilder();
  pprof::Profile& profile() { return profile_; }

  // Adds or returns the ID of `s` in the string table.
  int StringId(const std::string& s);

  // Adds or returns the ID of a function.
  int FunctionId(PyCodeObject* code);

  // Adds or returns the ID of a code location.
  int LocationId(PyCodeObject* code, int instruction);

 private:
  pprof::Profile profile_;

  absl::flat_hash_map<std::string, int> strings_;
  absl::flat_hash_map<PyCodeObject*, int> functions_;
  absl::flat_hash_map<std::pair<PyCodeObject*, int>, int> locations_;
};

ProfileBuilder::ProfileBuilder() { CHECK_EQ(0, StringId("")); }

int ProfileBuilder::StringId(const std::string& s) {
  auto ret = strings_.emplace(s, profile_.string_table_size());
  if (ret.second) {
    profile_.add_string_table(s);
  }
  return ret.first->second;
}

int ProfileBuilder::FunctionId(PyCodeObject* code) {
  // +1 because id 0 is reserved.
  auto ret = functions_.emplace(code, profile_.function_size() + 1);
  if (ret.second) {
    auto* function = profile_.add_function();
    function->set_id(ret.first->second);
    int name = StringId(py::str(code->co_name));
    function->set_name(name);
    function->set_system_name(name);
    function->set_filename(StringId(py::str(code->co_filename)));
    function->set_start_line(code->co_firstlineno);
  }
  return ret.first->second;
}

int ProfileBuilder::LocationId(PyCodeObject* code, int instruction) {
  // +1 because id 0 is reserved.
  auto ret = locations_.emplace(std::make_pair(code, instruction),
                                profile_.location_size() + 1);
  if (ret.second) {
    auto* location = profile_.add_location();
    location->set_id(ret.first->second);
    auto* line = location->add_line();
    line->set_function_id(FunctionId(code));
    line->set_line(PyCode_Addr2Line(code, instruction));
  }
  return ret.first->second;
}
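
// The pprof Profile proto interns everything: samples refer to locations by
// id, locations refer to functions by id, and every name is an index into a
// single string table. Id 0 is reserved in each table (the empty string, or
// "no function/location"), which is why the constructor seeds StringId("")
// and why FunctionId() and LocationId() start counting from 1.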

namespace {

struct HeapProfileKey {
  Traceback* traceback;
  int64 size;
  PjRtDevice* device;
  bool operator==(const HeapProfileKey& other) const;
};

bool HeapProfileKey::operator==(const HeapProfileKey& other) const {
  if (size != other.size || device != other.device) {
    return false;
  }
  if ((traceback == nullptr) != (other.traceback == nullptr)) {
    return false;
  }
  if (traceback && traceback->raw_frames() != other.traceback->raw_frames()) {
    return false;
  }
  return true;
}

template <typename H>
H AbslHashValue(H h, const HeapProfileKey& key) {
  if (key.traceback) {
    h = H::combine_contiguous(std::move(h), key.traceback->raw_frames().begin(),
                              key.traceback->raw_frames().size());
  }
  h = H::combine(std::move(h), key.size, key.device);
  return h;
}
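
// AbslHashValue, found via ADL, is Abseil's extension point for hashing
// user-defined types; together with operator== above it lets HeapProfileKey
// serve as an absl::flat_hash_map key in HeapProfile() below. Tracebacks
// hash by frame contents, consistent with the equality defined above.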

}  // namespace

StatusOr<py::bytes> PyClient::HeapProfile() {
  CHECK(PyGILState_Check());
  absl::flat_hash_set<PjRtBuffer*> buffer_set;
  absl::flat_hash_map<HeapProfileKey, int64> entries;
  for (PyBuffer* device_buffers : buffers_) {
    for (PyBuffer* buffer = device_buffers; buffer; buffer = buffer->next_) {
      // We only wish to count each PjRtBuffer once, even though they may be
      // shared by multiple PyBuffers.
      if (!buffer->is_deleted() && buffer_set.insert(buffer->buffer()).second) {
        TF_ASSIGN_OR_RETURN(size_t size,
                            buffer->buffer()->GetOnDeviceSizeInBytes());
        HeapProfileKey key{buffer->traceback().get(),
                           static_cast<int64_t>(size),
                           buffer->buffer()->device()};
        ++entries[key];
      }
    }
  }

  for (PyExecutable* executable = executables_; executable;
       executable = executable->next_) {
    if (!executable->is_deleted()) {
      HeapProfileKey key{executable->traceback(),
                         executable->SizeOfGeneratedCodeInBytes(), nullptr};
      ++entries[key];
    }
  }

  ProfileBuilder builder;
  auto* allocations = builder.profile().add_sample_type();
  allocations->set_type(builder.StringId("allocations"));
  allocations->set_unit(builder.StringId("count"));
  auto* space = builder.profile().add_sample_type();
  space->set_type(builder.StringId("space"));
  space->set_unit(builder.StringId("bytes"));

  const int kind_string_id = builder.StringId("kind");
  const int buffer_string_id = builder.StringId("buffer");
  const int executable_string_id = builder.StringId("executable");
  const int device_string_id = builder.StringId("device");
  for (const auto& entry : entries) {
    auto* sample = builder.profile().add_sample();
    if (entry.first.traceback) {
      for (const auto& frame : entry.first.traceback->raw_frames()) {
        sample->add_location_id(builder.LocationId(frame.first, frame.second));
      }
    }
    sample->add_value(entry.second);
    sample->add_value(entry.first.size * entry.second);

    auto* kind_label = sample->add_label();
    kind_label->set_key(kind_string_id);
    if (entry.first.device) {
      kind_label->set_str(buffer_string_id);
      auto* device_label = sample->add_label();
      device_label->set_key(device_string_id);
      device_label->set_str(
          builder.StringId(entry.first.device->DebugString()));
    } else {
      kind_label->set_str(executable_string_id);
    }
  }
  return py::bytes(builder.profile().SerializeAsString());
}
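
// The returned bytes are a serialized pprof Profile proto; standard pprof
// tooling can render it once written to disk, e.g. (illustrative):
//
//   $ pprof --web heap_profile.pb
//
// Each entry contributes two sample values, matching the two sample types
// declared above: an allocation count and the total bytes attributed to the
// recorded Python traceback.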

namespace {

class CpuCallback {
 public:
  struct Arg {
    PrimitiveType type;                    // XLA type.
    py::dtype dtype;                       // NumPy type, for array types.
    absl::InlinedVector<int64_t, 4> dims;  // Dimensions, for array types.
    std::vector<ssize_t> strides;          // Byte strides, for array types.
  };
  struct Result {
    PrimitiveType type;  // XLA type.
    // Expected output shape, for array types.
    absl::InlinedVector<int64_t, 4> expected_dims;
    // Expected output byte strides, for array types. If the strides do not
    // match, the output will be transposed into the expected layout.
    std::vector<int64_t> expected_strides;
    // The desired order of output dimensions in major-to-minor order.
    absl::InlinedVector<int64_t, 4> reversed_layout;
    // Size of the array in bytes.
    size_t size_in_bytes;
  };

  explicit CpuCallback(py::function callable, std::vector<Arg> args,
                       std::vector<Result> results)
      : callable_(std::move(callable)),
        args_(std::move(args)),
        results_(std::move(results)),
        transpose_cache_(/*capacity=*/16) {}

  void Call(void* result, void** arg_ptrs);

 private:
  py::function callable_;
  std::vector<Arg> const args_;
  std::vector<Result> const results_;
  TransposePlanCache transpose_cache_;
};
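
// All dtype/shape/stride metadata above is computed once, host-side, in
// EmitPythonCallback() below; Call() itself only wraps the operand pointers
// as NumPy arrays, validates the results, and copies (or, when the strides
// disagree with the expected layout, transposes) them into XLA's output
// buffers via the TransposePlanCache.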

void CpuCallback::Call(void* result, void** arg_ptrs) {
  absl::Span<void* const> inputs(arg_ptrs, args_.size());
  absl::Span<void* const> outputs(reinterpret_cast<void**>(result),
                                  results_.size());

  py::gil_scoped_acquire gil;
  py::tuple args(inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i) {
    if (args_[i].type == TOKEN) {
      args[i] = py::none();
    } else {
      args[i] = py::array(args_[i].dtype, args_[i].dims, args_[i].strides,
                          const_cast<void*>(inputs[i]));
      args[i].attr("flags").attr("writeable") = Py_False;
    }
  }
  py::object result_tuple = callable_(*py::reinterpret_borrow<py::args>(args));
  if (!PyTuple_Check(result_tuple.ptr())) {
    throw std::runtime_error(
        absl::StrFormat("CPU callback expected a tuple result, got %s",
                        static_cast<std::string>(py::repr(result_tuple))));
  }
  if (PyTuple_Size(result_tuple.ptr()) != results_.size()) {
    throw std::runtime_error(
        absl::StrFormat("CPU callback expected a tuple with %d results, got %d",
                        results_.size(), PyTuple_Size(result_tuple.ptr())));
  }
  for (size_t i = 0; i < results_.size(); ++i) {
    py::object output = py::reinterpret_borrow<py::object>(
        PyTuple_GetItem(result_tuple.ptr(), i));
    if (results_[i].type == TOKEN) {
      if (!output.is_none()) {
        throw std::runtime_error(absl::StrFormat(
            "Token output from Python callback should be None, got %s",
            static_cast<std::string>(py::repr(output))));
      }
      continue;
    }
    py::array array = py::cast<py::array>(std::move(output));
    static_assert(sizeof(ssize_t) == sizeof(int64_t),
                  "Expected ssize_t to be of equal size to int64_t");
    absl::Span<int64_t const> dims(
        reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
    if (dims != results_[i].expected_dims) {
      throw std::runtime_error(absl::StrFormat(
          "Mismatched result shape for %d-th return value from CPU callback; "
          "expected array with dimensions %s, got %s",
          i, absl::StrJoin(results_[i].expected_dims, ","),
          absl::StrJoin(dims, ",")));
    }
    absl::Span<int64_t const> strides(
        reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
    if (strides == results_[i].expected_strides) {
      std::memcpy(outputs[i], array.data(), results_[i].size_in_bytes);
    } else {
      StatusOr<std::shared_ptr<TransposePlan>> plan =
          transpose_cache_.GetOrCreate(
              primitive_util::ByteWidth(results_[i].type), dims,
              results_[i].reversed_layout,
              /*input_layout=*/TransposePlan::Striding{strides});
      if (!plan.ok()) {
        throw std::runtime_error(plan.status().ToString());
      }
      plan.ValueOrDie()->Execute(array.data(), outputs[i]);
    }
  }
}

extern "C" void XlaPythonCpuCallback(void* output, void** inputs) {
  CpuCallback* callback =
      absl::bit_cast<CpuCallback*>(*static_cast<uintptr_t*>(inputs[0]));
  callback->Call(output, inputs + 1);
}
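
// Calling convention: XLA passes the custom call's operands in `inputs`.
// EmitPythonCallback() (below) prepends a scalar uint64 operand holding the
// CpuCallback pointer, so inputs[0] recovers the callback and the user's
// operands start at inputs + 1.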

XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
                                             &XlaPythonCpuCallback);

}  // namespace

StatusOr<std::pair<XlaOp, pybind11::object>> PyClient::EmitPythonCallback(
    pybind11::function callable, XlaBuilder& builder,
    absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
    absl::optional<std::vector<Shape>> operand_layouts, bool has_side_effect) {
  if (pjrt_client_->platform_id() != kCpuId) {
    return Unimplemented("EmitPythonCallback is only implemented on CPU");
  }

  std::vector<CpuCallback::Arg> callback_args(operands.size());
  std::vector<XlaOp> custom_call_args(operands.size() + 1);
  absl::c_copy(operands, custom_call_args.begin() + 1);

  if (operand_layouts && operand_layouts->size() != operands.size()) {
    return InvalidArgument(
        "Mismatched number of operands (%d) and operand_layouts (%d)",
        operands.size(), operand_layouts->size());
  }

  std::vector<Shape> custom_call_arg_layouts(operands.size() + 1);
  static_assert(sizeof(uintptr_t) == sizeof(uint64_t),
                "Expected 64-bit pointers");
  custom_call_arg_layouts[0] =
      ShapeUtil::MakeShapeWithDescendingLayout(U64, {});
  for (int i = 0; i < operands.size(); ++i) {
    TF_ASSIGN_OR_RETURN(Shape shape, builder.GetShape(operands[i]));
    xla::Shape& layout = custom_call_arg_layouts[i + 1];
    if (operand_layouts) {
      if (!(*operand_layouts)[i].has_layout()) {
        return InvalidArgument(
            "operand_layout shapes for callback must have "
            "layouts, got %s",
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      if (!ShapeUtil::Compatible(shape, (*operand_layouts)[i])) {
        return InvalidArgument(
            "Incompatible shapes for Python callback argument %d: %s vs %s", i,
            shape.ToString(),
            (*operand_layouts)[i].ToString(/*print_layout=*/true));
      }
      layout = (*operand_layouts)[i];
    } else {
      layout = LayoutUtil::GetWithDefaultLayout(shape);
    }

    if (shape.IsArray()) {
      callback_args[i].dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(), callback_args[i].dims.begin());
      callback_args[i].strides = ByteStridesForShape(layout);
      callback_args[i].type = shape.element_type();
      TF_ASSIGN_OR_RETURN(callback_args[i].dtype,
                          PrimitiveTypeToDtype(shape.element_type()));
    } else if (shape.IsToken()) {
      callback_args[i].type = TOKEN;
    } else {
      return InvalidArgument(
          "Only array and token arguments to Python callbacks are supported, "
          "got %s",
          shape.ToString());
    }
  }

  std::vector<Shape> result_shapes_with_layout(result_shapes.size());
  std::vector<CpuCallback::Result> callback_results(result_shapes.size());
  for (int i = 0; i < result_shapes.size(); ++i) {
    if (result_shapes[i].IsArray()) {
      result_shapes_with_layout[i] =
          result_shapes[i].has_layout()
              ? result_shapes[i]
              : LayoutUtil::GetWithDefaultLayout(result_shapes[i]);
      const Shape& shape = result_shapes_with_layout[i];
      callback_results[i].expected_dims.resize(shape.dimensions_size());
      absl::c_copy(shape.dimensions(),
                   callback_results[i].expected_dims.begin());
      callback_results[i].expected_strides = ByteStridesForShapeInt64(shape);
      callback_results[i].type = shape.element_type();
      callback_results[i].size_in_bytes = ShapeUtil::ByteSizeOf(shape);
      callback_results[i].reversed_layout.resize(shape.dimensions_size());
      absl::c_reverse_copy(shape.layout().minor_to_major(),
                           callback_results[i].reversed_layout.begin());
    } else if (result_shapes[i].IsToken()) {
      callback_results[i].type = TOKEN;
      result_shapes_with_layout[i] = result_shapes[i];
    } else {
      return InvalidArgument(
          "Only array and token return values from Python callbacks are "
          "supported, got %s",
          result_shapes[i].ToString());
    }
  }

  auto callback = std::make_unique<CpuCallback>(
      std::move(callable), callback_args, callback_results);
  custom_call_args[0] = ConstantR0<std::uint64_t>(
      &builder, absl::bit_cast<std::uint64_t>(callback.get()));

  Shape result_shape = ShapeUtil::MakeTupleShape(result_shapes_with_layout);
  XlaOp result = CustomCallWithLayout(&builder, "xla_python_cpu_callback",
                                      custom_call_args, result_shape,
                                      custom_call_arg_layouts,
                                      /*opaque=*/"", has_side_effect);

  py::capsule callback_capsule(callback.release(), [](void* ptr) {
    delete reinterpret_cast<CpuCallback*>(ptr);
  });
  return std::make_pair(result, py::object(std::move(callback_capsule)));
}
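
// Lifetime note: the capsule returned alongside the XlaOp owns the
// CpuCallback. The caller must keep that Python object alive for as long as
// any executable built from `builder` may run, because the custom call
// dereferences the raw pointer baked into custom_call_args[0] above.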

}  // namespace xla