/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <string>
#include <vector>

#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "pybind11/attr.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "pybind11/stl_bind.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/pjrt/cpu_device.h"
#include "tensorflow/compiler/xla/pjrt/distributed/client.h"
#include "tensorflow/compiler/xla/pjrt/distributed/distributed.h"
#include "tensorflow/compiler/xla/pjrt/distributed/service.h"
#include "tensorflow/compiler/xla/pjrt/gpu_device.h"
#include "tensorflow/compiler/xla/pjrt/interpreter_device.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/pjrt/tpu_client.h"
#include "tensorflow/compiler/xla/python/dlpack.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/ops.h"
#include "tensorflow/compiler/xla/python/outfeed_receiver_py.h"
#include "tensorflow/compiler/xla/python/pmap_lib.h"
#include "tensorflow/compiler/xla/python/profiler.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_traceback.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/python/xla_compiler.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/python/lib/core/bfloat16.h"
// TODO(phawkins): remove host_id properties after JAX is updated to avoid
// them.

namespace xla {
namespace {

namespace py = pybind11;

bool IsOptimizedBuild() {
#if NDEBUG
  return true;
#else
  return false;
#endif  // NDEBUG
}

}  // namespace

PYBIND11_MODULE(xla_extension, m) {
  CHECK(tensorflow::RegisterNumpyBfloat16());

  // Types
  py::enum_<PrimitiveType>(m, "PrimitiveType")
      .value("PRIMITIVE_TYPE_INVALID", PRIMITIVE_TYPE_INVALID)
      .value("PRED", PRED)
      .value("S8", S8)
      .value("S16", S16)
      .value("S32", S32)
      .value("S64", S64)
      .value("U8", U8)
      .value("U16", U16)
      .value("U32", U32)
      .value("U64", U64)
      .value("F16", F16)
      .value("BF16", BF16)
      .value("F32", F32)
      .value("F64", F64)
      .value("C64", C64)
      .value("C128", C128)
      .value("TUPLE", TUPLE)
      .value("OPAQUE_TYPE", OPAQUE_TYPE)
      .value("TOKEN", TOKEN);

  m.def("bfloat16_dtype",
        []() { return py::handle(tensorflow::Bfloat16Dtype()); });
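
  // A minimal usage sketch from Python. The import name "xla_extension" is an
  // assumption here (JAX, for instance, loads this module through its own
  // wrappers):
  //
  //   import xla_extension as xe
  //   f32 = xe.PrimitiveType.F32     # enum value bound above
  //   bf16 = xe.bfloat16_dtype()     # NumPy dtype object for bfloat16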

  // Must be before PyClient.compile.
  BuildXlaCompilerSubmodule(m);

  py::class_<PjRtDevice, ClientAndPtr<PjRtDevice>>(
      m, "Device",
      "A descriptor of an available device.\n\nSubclasses are used to "
      "represent specific types of devices, e.g. CPUs, GPUs. Subclasses may "
      "have additional properties specific to that device type.")
      .def_property_readonly(
          "id", &PjRtDevice::id,
          "Integer ID of this device.\n\nUnique across all available devices "
          "of this type, including remote devices on multi-host platforms.")
      .def_property_readonly("host_id", &PjRtDevice::task_id,
                             "Integer ID of this device's task.\n\n"
                             "This is always 0 except on multi-task platforms.")
      .def_property_readonly("task_id", &PjRtDevice::task_id,
                             "Integer ID of this device's task.\n\n"
                             "This is always 0 except on multi-task platforms.")
      .def_property_readonly("platform",
                             [](const PjRtDevice& device) {
                               return device.client()->platform_name();
                             })
      .def_property_readonly("device_kind", &PjRtDevice::device_kind)
      .def_property_readonly(
          "client",
          [](const ClientAndPtr<PjRtDevice>& device) { return device.client; })
      .def("__str__", &PjRtDevice::DebugString)
      .def("transfer_to_infeed",
           [](PjRtDevice& device, const LiteralSlice& literal) {
             GlobalPyRefManager()->CollectGarbage();
             py::gil_scoped_release gil_release;
             return device.TransferToInfeed(literal);
           })
      .def("transfer_from_outfeed",
           [](PjRtDevice& device, const Shape& shape) -> StatusOr<py::object> {
             GlobalPyRefManager()->CollectGarbage();
             std::shared_ptr<Literal> literal;
             {
               py::gil_scoped_release gil_release;
               Shape shape_with_layout = shape;
               // Fill in any missing layouts with defaults so the literal has
               // a fully specified shape.
               ShapeUtil::ForEachMutableSubshape(
                   &shape_with_layout, [](Shape* subshape, const ShapeIndex&) {
                     if (!subshape->has_layout()) {
                       LayoutUtil::SetToDefaultLayout(subshape);
                     }
                   });
               literal = std::make_shared<Literal>(shape_with_layout);
               TF_RETURN_IF_ERROR(device.TransferFromOutfeed(literal.get()));
             }
             return LiteralToPython(std::move(literal));
           });
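
  // Sketch: how devices surface in Python (the "xe" import name is an
  // assumption; clients come from the get_*_client factories below):
  //
  //   client = xe.get_cpu_client()
  //   dev = client.local_devices()[0]
  //   dev.id, dev.platform, dev.device_kind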

  py::class_<CpuDevice, PjRtDevice, ClientAndPtr<CpuDevice>>(m, "CpuDevice")
      .def("__repr__", [](const CpuDevice& device) {
        return absl::StrFormat("CpuDevice(id=%i)", device.id());
      });

  py::class_<GpuDevice, PjRtDevice, ClientAndPtr<GpuDevice>>(m, "GpuDevice")
      .def("__repr__", [](const GpuDevice& device) {
        return absl::StrFormat("GpuDevice(id=%i)", device.id());
      });

  py::class_<PjRtTpuDevice, PjRtDevice, ClientAndPtr<PjRtTpuDevice>>(
      m, "TpuDevice")
      .def_property_readonly("host_id", &PjRtTpuDevice::task_id)
      .def_property_readonly("task_id", &PjRtTpuDevice::task_id)
      .def_property_readonly(
          "coords",
          [](const PjRtTpuDevice& device) -> pybind11::tuple {
            return IntSpanToTuple(device.coords());
          },
          "The coordinates of this TpuDevice's chip in the TPU mesh network.")
      .def_property_readonly(
          "core_on_chip", &PjRtTpuDevice::core_on_chip,
          "The index of this TpuDevice's core on the TPU chip.")
      .def("__repr__", [](const PjRtTpuDevice& device) {
        return absl::StrFormat(
            "TpuDevice(id=%i, host=%i, coords=(%s), core_on_chip=%i)",
            device.id(), device.task_id(), absl::StrJoin(device.coords(), ","),
            device.core_on_chip());
      });

  // Local XLA client methods.

  py::class_<GpuAllocatorConfig> alloc_config(m, "GpuAllocatorConfig");
  alloc_config.def(py::init<>())
      .def_readwrite("kind", &GpuAllocatorConfig::kind)
      .def_readwrite("memory_fraction", &GpuAllocatorConfig::memory_fraction)
      .def_readwrite("preallocate", &GpuAllocatorConfig::preallocate);
  py::enum_<GpuAllocatorConfig::Kind>(alloc_config, "Kind")
      .value("DEFAULT", GpuAllocatorConfig::Kind::kDefault)
      .value("PLATFORM", GpuAllocatorConfig::Kind::kPlatform)
      .value("BFC", GpuAllocatorConfig::Kind::kBFC);

  py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics")
      .value("IMMUTABLE_ONLY_DURING_CALL",
             PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall)
      .value("IMMUTABLE_UNTIL_TRANSFER_COMPLETES",
             PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes)
      .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy);

  py::class_<PyClient, std::shared_ptr<PyClient>> py_local_client(m, "Client");
  py_local_client.def_property_readonly("platform", &PyClient::platform_name)
      .def("device_count", &PyClient::device_count)
      .def("local_device_count", &PyClient::addressable_device_count)
      .def("devices", &PyClient::Devices)
      .def("local_devices", &PyClient::LocalDevices)
      .def("live_buffers", &PyClient::LiveBuffers)
      .def("host_id", &PyClient::task_id)
      .def("task_id", &PyClient::task_id)
      .def("get_default_device_assignment",
           &PyClient::GetDefaultDeviceAssignment)
      // TODO(skye): delete after all callers can handle 2D output
      .def("get_default_device_assignment",
           &PyClient::GetDefaultDeviceAssignment1D)
      .def("create_channel_handle", &PyClient::CreateChannelHandle)
      .def("create_device_to_host_channel_handle",
           &PyClient::CreateDeviceToHostChannelHandle)
      .def("create_host_to_device_channel_handle",
           &PyClient::CreateHostToDeviceChannelHandle)
      .def("buffer_from_pyval", &PyClient::BufferFromPyval, py::arg("argument"),
           py::arg("device") = nullptr, py::arg("force_copy") = false,
           py::arg("host_buffer_semantics") =
               PjRtClient::HostBufferSemantics::kZeroCopy)
      .def("compile", &PyClient::Compile, py::arg("computation"),
           py::arg("compile_options") = CompileOptions())
      .def("heap_profile", &PyClient::HeapProfile);

  m.def(
      "get_cpu_client",
      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
        TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                            GetCpuClient(asynchronous));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true);
  m.def("get_interpreter_client", []() -> StatusOr<std::shared_ptr<PyClient>> {
    TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client,
                        GetInterpreterClient());
    return std::make_shared<PyClient>(std::move(client));
  });
  m.def(
      "get_gpu_client",
      [](bool asynchronous, const GpuAllocatorConfig& allocator_config,
         std::shared_ptr<DistributedRuntimeClient> distributed_client,
         int node_id) -> StatusOr<std::shared_ptr<PyClient>> {
        TF_ASSIGN_OR_RETURN(
            std::unique_ptr<PjRtClient> client,
            GetGpuClient(asynchronous, allocator_config,
                         std::move(distributed_client), node_id));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true,
      py::arg("allocator_config") = GpuAllocatorConfig(),
      py::arg("distributed_client") = nullptr, py::arg("node_id") = 0);
  m.def(
      "get_tpu_client",
      [](bool asynchronous) -> StatusOr<std::shared_ptr<PyClient>> {
        TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client,
                            GetTpuClient(asynchronous));
        return std::make_shared<PyClient>(std::move(client));
      },
      py::arg("asynchronous") = true);

  py::class_<DeviceArrayBase> device_array_base(m, "DeviceArrayBase");
  device_array_base.def(py::init<>());

  py::class_<PyBuffer, DeviceArrayBase, std::unique_ptr<PyBuffer>> buffer(
      m, "Buffer");
  // TODO(phawkins): alias for backward compatibility. Remove after JAX no
  // longer uses this name.
  m.add_object("PyLocalBuffer", buffer);
  buffer
      .def_property_readonly("__array_priority__",
                             [](py::object) { return 100; })
      .def_property("_device", &PyBuffer::GetStickyDevice,
                    &PyBuffer::SetStickyDevice)
      .def_property("aval", &PyBuffer::GetAval, &PyBuffer::SetAval)
      .def_property_readonly("_lazy_expr",
                             [](py::object buffer) { return py::none(); })
      .def_property_readonly("device_buffer",
                             [](py::object buffer) { return buffer; })
      .def_property_readonly(
          "shape",
          [](const PyBuffer& pybuffer) -> pybind11::tuple {
            return IntSpanToTuple(
                pybuffer.buffer()->on_device_shape().dimensions());
          })
      .def_property_readonly(
          "dtype",
          [](const PyBuffer& buffer) {
            PrimitiveType primitive =
                buffer.buffer()->on_device_shape().element_type();
            return PrimitiveTypeToDtype(primitive).ValueOrDie();
          })
      .def_property_readonly("size", &PyBuffer::size)
      .def_property_readonly("ndim", &PyBuffer::ndim)
      .def_property_readonly(
          "_value",
          [](py::handle buffer_obj) -> StatusOr<pybind11::object> {
            GlobalPyRefManager()->CollectGarbage();
            PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
            return buffer->AsNumPyArray(buffer_obj);
          })
      .def("copy_to_device", &PyBuffer::CopyToDevice)
      .def("on_device_size_in_bytes", &PyBuffer::OnDeviceSizeInBytes)
      .def("delete", &PyBuffer::Delete)
      // The GIL is released within BlockHostUntilReady.
      .def("block_until_ready",
           [](py::object buffer_obj) -> xla::StatusOr<py::object> {
             PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
             TF_RETURN_IF_ERROR(buffer->BlockHostUntilReady());
             return buffer_obj;
           })
      .def("block_host_until_ready", &PyBuffer::BlockHostUntilReady)
      .def("copy_to_host_async", &PyBuffer::CopyToHostAsync)
      .def("to_py",
           [](py::handle buffer_obj) {
             PyBuffer* buffer = buffer_obj.cast<PyBuffer*>();
             return buffer->AsNumPyArray(buffer_obj);
           })
      .def("xla_shape", &PyBuffer::shape)
      .def_property_readonly("client", &PyBuffer::client)
      .def("device", &PyBuffer::device)
      .def("platform", &PyBuffer::platform_name)
      .def("is_deleted", &PyBuffer::is_deleted)
      .def("unsafe_buffer_pointer", &PyBuffer::UnsafeBufferPointer)
      .def_property_readonly("__cuda_array_interface__",
                             &PyBuffer::CudaArrayInterface)
      .def_property_readonly("traceback", &PyBuffer::traceback);

  // pybind11's implementation of the buffer protocol doesn't allow for correct
  // error handling. We bypass it and implement the buffer protocol ourselves.
  PyTypeObject* buffer_type = reinterpret_cast<PyTypeObject*>(buffer.ptr());
  buffer_type->tp_as_buffer = PyBuffer::BufferProtocol();
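
  // Sketch (an assumption: this presumably only succeeds where the buffer's
  // memory is host-addressable, e.g. on CPU):
  //
  //   mv = memoryview(buf)  # served by the tp_as_buffer slot installed above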

  py::class_<PyExecutable, std::shared_ptr<PyExecutable>> executable(
      m, "Executable");
  executable.def_property_readonly("client", &PyExecutable::client)
341 .def("local_logical_device_ids",
342 [](PyExecutable* exec) {
343 auto span = exec->addressable_device_logical_ids();
344 // Not on dispatch critical path, so ok to have heap allocation.
345 std::vector<std::pair<int, int>> addressable_device_logic_ids;
346 addressable_device_logic_ids.reserve(span.size());
347 for (const auto& logical_device_id : span) {
348 addressable_device_logic_ids.push_back(std::make_pair(
349 logical_device_id.replica, logical_device_id.partition));
350 }
351 })
352 .def("local_devices", &PyExecutable::AddressableDevices)
353 .def("size_of_generated_code_in_bytes",
354 &PyExecutable::SizeOfGeneratedCodeInBytes)
355 .def("delete", &PyExecutable::Delete)
356 .def("execute", &PyExecutable::Execute, py::arg("arguments"))
357 .def("execute_on_local_devices", &PyExecutable::ExecuteOnLocalDevices,
358 py::arg("arguments"))
359 .def("execute_sharded_on_local_devices",
360 &PyExecutable::ExecuteShardedOnLocalDevices, py::arg("arguments"))
361 .def("hlo_modules", &PyExecutable::HloModules)
362 .def_property_readonly("traceback", &PyExecutable::traceback);

  m.def("buffer_to_dlpack_managed_tensor", BufferToDLPackManagedTensor,
        py::arg("buffer"), py::arg("take_ownership") = true);
  m.def("dlpack_managed_tensor_to_buffer", DLPackManagedTensorToBuffer);

  BuildProfilerSubmodule(&m);
  BuildOpsSubmodule(&m);
  BuildOutfeedReceiverSubmodule(&m);
  BuildPytreeSubmodule(m);
  jax::BuildJaxjitSubmodule(m);
  jax::BuildPmapSubmodule(m);
  BuildTracebackSubmodule(m);

  py::class_<DistributedRuntimeService,
             std::unique_ptr<DistributedRuntimeService>>
      distributed_runtime_service(m, "DistributedRuntimeService");
  py::class_<DistributedRuntimeClient,
             std::shared_ptr<DistributedRuntimeClient>>
      distributed_runtime_client(m, "DistributedRuntimeClient");
  distributed_runtime_client.def("connect", &DistributedRuntimeClient::Connect)
      .def("shutdown", &DistributedRuntimeClient::Shutdown);

  m.def("get_distributed_runtime_service", &GetDistributedRuntimeService);
  m.def("get_distributed_runtime_client", &GetDistributedRuntimeClient);

  m.def("collect_garbage", []() { GlobalPyRefManager()->CollectGarbage(); });

  m.def("is_optimized_build", &IsOptimizedBuild);
}  // NOLINT(readability/fn_size)

}  // namespace xla