/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the `jax.jit` dispatch and just-in-time compilation
// feature.
//
// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
// based on the dtypes/shapes/identity of the passed arguments) the execution
// to a just-in-time compiled XLA Executable. All of that is done in C++ for
// performance reasons.
//
// This file contains the utilities to:
// (a) inspect arguments and describe their structure, dtype/shapes, etc.
// (b) keep a mapping from function signatures to compiled XLA Executables.
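//
// Illustrative sketch of the resulting behavior (Python side; a hypothetical
// example, not code from this file):
//   f_jitted = jit(f)          # returns a C++ `CompiledFunction`
//   f_jitted(np.ones((2, 3)))  # cache miss: compiles and caches an Executable
//   f_jitted(np.ones((2, 3)))  # same signature: reuses the cached Executable
//   f_jitted(np.ones((4, 5)))  # new shapes: new signature, compiles again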

#include "tensorflow/compiler/xla/python/jax_jit.h"

#include <Python.h>

#include <algorithm>
#include <exception>
#include <memory>
#include <stdexcept>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/status.h"

namespace jax {

namespace py = pybind11;

// TODO(phawkins): Add support for Tracers.
// TODO(jblespiau): Use absl Status.
// TODO(jblespiau): Remove the "xla::" prefixes when not needed.

std::string ArgSignature::DebugString() const {
  std::string result = "";
  if (weak_type) {
    absl::StrAppend(&result, "weak_");
  }
  absl::StrAppend(&result, xla::PrimitiveType_Name(dtype));
  absl::StrAppend(&result, "[", absl::StrJoin(shape, ","), "]");
  return result;
}
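
// For example, DebugString() renders a weakly-typed Python float (F32 when
// jax_enable_x64 is off) as "weak_F32[]" and a 2x3 int32 array as "S32[2,3]".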

bool CallSignature::operator==(const CallSignature& other) const {
  return std::tie(dynamic_positional_args_treedef, keyword_args,
                  dynamic_args_signatures, device) ==
             std::tie(other.dynamic_positional_args_treedef,
                      other.keyword_args, other.dynamic_args_signatures,
                      other.device) &&
         // `==` on py::objects is the Python `is`. We need equality.
         std::equal(
             static_args.begin(), static_args.end(), other.static_args.begin(),
             other.static_args.end(),
             [](const py::object& a, const py::object& b) {
               try {
                 return a.equal(b);
               } catch (const py::error_already_set& e) {
                 throw std::invalid_argument(absl::StrCat(
                     "static arguments should be comparable using __eq__. "
                     "The following error was raised when comparing two "
                     "objects of types ",
                     py::cast<std::string>(py::str(py::type::of(a))), " and ",
                     py::cast<std::string>(py::str(py::type::of(b))),
                     ". The error was:\n", e.what()));
               }
             });
}
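
// For example, with `jit(f, static_argnums=(1,))`, the calls `f(x, 3)` and
// `f(y, 3)` share a cache entry (as long as `x` and `y` have the same
// signature) because the static arguments compare equal under `__eq__`.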

void CallSignature::IncRef() const {
  for (const auto& kw : keyword_args) {
    kw.key.inc_ref();
  }
}

void CallSignature::DecRef() const {
  for (const auto& kw : keyword_args) {
    kw.key.dec_ref();
  }
}

namespace {

thread_local bool disable_jit;
void SetDisableJit(bool disable_jit_) { disable_jit = disable_jit_; }
bool GetDisableJit() { return disable_jit; }

}  // namespace

std::string CallSignature::DebugString() const {
  std::vector<std::string> static_args_str;
  static_args_str.reserve(static_args.size());
  for (auto& static_arg : static_args) {
    static_args_str.emplace_back(py::cast<std::string>(py::str(static_arg)));
  }

  std::vector<std::string> signature_str;
  signature_str.reserve(dynamic_args_signatures.size());

  for (auto& arg_signature : dynamic_args_signatures) {
    signature_str.emplace_back(arg_signature.DebugString());
  }
  std::vector<std::string> tree_def_str;
  tree_def_str.reserve(dynamic_positional_args_treedef.size());
  for (auto& tree_def : dynamic_positional_args_treedef) {
    tree_def_str.emplace_back(tree_def.ToString());
  }
  std::vector<std::string> keyword_names;
  keyword_names.reserve(keyword_args.size());
  for (auto& kwarg_entry : keyword_args) {
    keyword_names.emplace_back(py::cast<std::string>(kwarg_entry.key));
    tree_def_str.emplace_back(kwarg_entry.value_treedef.ToString());
  }
  return absl::StrCat(
      static_args.size(), " static_args: ", absl::StrJoin(static_args_str, ","),
      "\n",  // new line
      keyword_args.size(), " keyword args:", absl::StrJoin(keyword_names, ","),
      "\n",  // new line
      dynamic_positional_args_treedef.size(), " positional args.\n",
      dynamic_args_signatures.size(),
      " dynamic args (positional+keyword):\n - ",
      absl::StrJoin(signature_str, ", "), "\n - ",
      absl::StrJoin(tree_def_str, " | "));
}

template <typename H>
H AbslHashValue(H h, const CallSignature& s) {
  h = H::combine_contiguous(std::move(h),
                            s.dynamic_positional_args_treedef.data(),
                            s.dynamic_positional_args_treedef.size());
  h = H::combine_contiguous(std::move(h), s.keyword_args.data(),
                            s.keyword_args.size());
  h = H::combine_contiguous(std::move(h), s.dynamic_args_signatures.data(),
                            s.dynamic_args_signatures.size());
  h = H::combine(std::move(h), s.device);
  for (const auto& static_arg : s.static_args) {
    ssize_t hash;
    try {
      hash = py::hash(static_arg);
    } catch (const py::error_already_set& e) {
      throw std::invalid_argument(absl::StrCat(
          "Non-hashable static arguments are not supported. An error occurred "
          "while trying to hash an object of type ",
          py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
          py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
          e.what(), "\n"));
    }
    h = H::combine(std::move(h), hash);
  }
  return h;
}
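
// As in Python, hashing and equality of static arguments must be consistent:
// objects that compare equal under `__eq__` must also hash equal, otherwise
// two equal signatures could end up in different cache buckets.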

// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
xla::Status ParseArguments(const py::args& args, const py::kwargs& py_kwargs,
                           absl::Span<int const> static_argnums,
                           ParsedArgumentsAsBuffers& arguments) {
  if (static_argnums.size() > args.size()) {
    return xla::InvalidArgument(
        "%s", "[jaxjit] Error with static argnums, executing the Python path.");
  }
  arguments.flat_dynamic_args.reserve(args.size() + py_kwargs.size() -
                                      static_argnums.size());
  arguments.signature.dynamic_positional_args_treedef.reserve(
      args.size() - static_argnums.size());

  // Positional arguments.
  for (size_t i = 0; i < args.size(); ++i) {
    if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
        static_argnums.end()) {
      xla::PyTreeDef pytree_def;
      pytree_def.FlattenInto(args[i], arguments.flat_dynamic_args);
      arguments.signature.dynamic_positional_args_treedef.push_back(pytree_def);
    } else {
      arguments.signature.static_args.emplace_back(
          // borrow is mandatory here.
          py::reinterpret_borrow<py::object>(args[i]));
    }
  }

  // Keyword arguments.
  std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs.begin(),
                                                        py_kwargs.end());
  // We first intern the keys, then sort them (by name, as in the Python path)
  // (see also xla::PyTreeDef::Flatten) and then create the signatures.
  // TODO(jblespiau): We should be able to sort the keys by interned-key
  // pointers, but this requires the Python compilation to do the same.
  arguments.signature.keyword_args.resize(kwargs.size());
  for (size_t i = 0; i < kwargs.size(); ++i) {
    // Intern the key if not already interned.
    if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
      PyObject* key = kwargs[i].first.ptr();
      kwargs[i].first.inc_ref();
      PyUnicode_InternInPlace(&key);
      arguments.keep_alive_objects.push_back(
          py::reinterpret_steal<py::object>(key));
      kwargs[i].first = py::handle(key);
    }
  }

  std::sort(kwargs.begin(), kwargs.end(),
            [](const std::pair<py::handle, py::handle>& a,
               const std::pair<py::handle, py::handle>& b) {
              return a.first < b.first;
            });
  for (size_t i = 0; i < kwargs.size(); ++i) {
    arguments.signature.keyword_args[i].key = kwargs[i].first;
    arguments.signature.keyword_args[i].value_treedef.FlattenInto(
        kwargs[i].second, arguments.flat_dynamic_args);
  }
  return xla::Status::OK();
}
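
// For example, for `jit(f, static_argnums=(0,))` called as
// `f(3, np.zeros((2,)), y=np.ones((2,)))`, `3` is recorded as a static
// argument while the positional array and the `y` keyword are flattened into
// `flat_dynamic_args`, with their tree structure stored in the signature.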

namespace {

struct NumpyScalarTypes {
  py::object np_bool;
  py::object np_int8;
  py::object np_int16;
  py::object np_int32;
  py::object np_int64;
  py::object np_uint8;
  py::object np_uint16;
  py::object np_uint32;
  py::object np_uint64;
  py::object np_float16;
  py::object np_float32;
  py::object np_float64;
  py::object np_complex64;
  py::object np_complex128;
  py::object np_longlong;
  py::object np_intc;
};

const NumpyScalarTypes& GetNumpyScalarTypes() {
  static const NumpyScalarTypes* singleton = []() {
    // Use designated initializers when they are available.
    const auto numpy = py::module::import("numpy");
    NumpyScalarTypes* dtypes = new NumpyScalarTypes();
    dtypes->np_bool = py::object(numpy.attr("bool_"));
    dtypes->np_int8 = py::object(numpy.attr("int8"));
    dtypes->np_int16 = py::object(numpy.attr("int16"));
    dtypes->np_int32 = py::object(numpy.attr("int32"));
    dtypes->np_int64 = py::object(numpy.attr("int64"));
    dtypes->np_uint8 = py::object(numpy.attr("uint8"));
    dtypes->np_uint16 = py::object(numpy.attr("uint16"));
    dtypes->np_uint32 = py::object(numpy.attr("uint32"));
    dtypes->np_uint64 = py::object(numpy.attr("uint64"));
    dtypes->np_float16 = py::object(numpy.attr("float16"));
    dtypes->np_float32 = py::object(numpy.attr("float32"));
    dtypes->np_float64 = py::object(numpy.attr("float64"));
    dtypes->np_complex64 = py::object(numpy.attr("complex64"));
    dtypes->np_complex128 = py::object(numpy.attr("complex128"));
    dtypes->np_longlong = py::object(numpy.attr("longlong"));
    dtypes->np_intc = py::object(numpy.attr("intc"));

    return dtypes;
  }();

  return *singleton;
}

// Returns the 32-bit narrowing of a 64-bit dtype (e.g. int64 -> int32), or
// nullptr if the dtype does not need to be narrowed.
const py::dtype* DtypeTo32BitDtype(const py::dtype& dtype) {
  // TODO(jblespiau): Use GetNumpyScalarTypes instead.
  static const auto* int64_dt = new py::dtype("int64");
  static const auto* int32_dt = new py::dtype("int32");
  static const auto* uint64_dt = new py::dtype("uint64");
  static const auto* uint32_dt = new py::dtype("uint32");
  static const auto* float64_dt = new py::dtype("float64");
  static const auto* float32_dt = new py::dtype("float32");
  static const auto* complex64_dt = new py::dtype("complex64");
  static const auto* complex128_dt = new py::dtype("complex128");

  if (dtype.equal(*int64_dt)) {
    return int32_dt;
  }
  if (dtype.equal(*float64_dt)) {
    return float32_dt;
  }
  if (dtype.equal(*uint64_dt)) {
    return uint32_dt;
  }
  if (dtype.equal(*complex128_dt)) {
    return complex64_dt;
  }

  return nullptr;
}

// The equivalent of the Python jax/lazy.py::is_trivial:
//   return (type(lexpr.input) is ArrayVar and
//           lexpr.dims == tuple(range(len(lexpr.shape))))
//
// Expects *only* `None` or a `LazyExpr` object.
bool IsTrivialLazyExpr(py::handle lexpr) {
  if (lexpr.is_none()) {
    return true;
  }

  static const auto* lazy_module =
      new py::module(py::module::import("jax.lazy"));
  auto input = py::getattr(lexpr, "input");
  if (!input.get_type().is(lazy_module->attr("ArrayVar"))) {
    return false;
  }
  py::tuple dims = py::cast<py::tuple>(lexpr.attr("dims"));
  py::tuple shape = py::cast<py::tuple>(lexpr.attr("shape"));

  for (int i = 0; i < shape.size(); ++i) {
    if (dims[i].is_none()) {
      return false;
    }
    if (py::cast<int>(dims[i]) != i) {
      return false;
    }
  }
  return true;
}

bool IsFloat0(py::array arg) {
  static const auto* dtypes_module =
      new py::module(py::module::import("jax.dtypes"));
  static const auto* float0_dtype =
      new py::handle(dtypes_module->attr("float0"));
  return float0_dtype->is(arg.attr("dtype"));
}

// Builds a rank-0 device buffer from a Python scalar.
template <typename CppType, typename Pybind11Type>
std::unique_ptr<xla::PjRtBuffer> ConvertToScalarBuffer(
    const py::handle& scalar, xla::PjRtClient* client,
    xla::PjRtDevice* device) {
  CppType data = py::cast<Pybind11Type>(scalar);
  // Workaround for https://github.com/pybind/pybind11/issues/2786
  if (PyErr_Occurred()) {
    throw py::error_already_set();
  }
  xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<CppType>({});
  return ValueOrThrow(client->BufferFromHostBuffer(
      &data, shape,
      xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr,
      device));
}
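
// For instance, a Python float under jax_enable_x64 is converted with
// `ConvertToScalarBuffer<double, py::float_>(h, client, device)`, producing a
// rank-0 F64 buffer on `device`.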

}  // namespace

namespace {

using ToArgSignatureHandler =
    std::function<xla::StatusOr<ArgSignature>(py::handle, bool)>;

}  // namespace

xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
                                                bool jax_enable_x64) {
  static const absl::flat_hash_map<PyObject*,
                                   ToArgSignatureHandler>* const handlers = [] {
    auto p = new absl::flat_hash_map<PyObject*, ToArgSignatureHandler>();

    const auto xla_module = py::module::import("jax.interpreters.xla");
    const auto& device_array = xla_module.attr("_DeviceArray");

    const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

    // The 4 Python native types.
    ToArgSignatureHandler bool_handler =
        [](py::handle, bool) -> xla::StatusOr<ArgSignature> {
      return ArgSignature(xla::PrimitiveType::PRED, {}, true);
    };
    ToArgSignatureHandler int_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      if (jax_enable_x64) {
        return ArgSignature(xla::PrimitiveType::S64, {}, true);
      } else {
        return ArgSignature(xla::PrimitiveType::S32, {}, true);
      }
    };
    ToArgSignatureHandler float_handler =
        [&dtypes](py::handle h,
                  bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      // Only Python native types have a true weak_type.
      bool weak_type = !py::isinstance(h, dtypes.np_float64);
      if (jax_enable_x64) {
        return ArgSignature(xla::PrimitiveType::F64, {}, weak_type);
      } else {
        return ArgSignature(xla::PrimitiveType::F32, {}, weak_type);
      }
    };
    ToArgSignatureHandler complex_handler =
        [&dtypes](py::handle h,
                  bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      // Note that this branch is also taken for np.complex128:
      // isinstance(np.complex128(3), complex) returns True
      // isinstance(np.complex64(3), complex) returns False
      bool weak_type = !py::isinstance(h, dtypes.np_complex128);
      if (jax_enable_x64) {
        return ArgSignature(xla::PrimitiveType::C128, {}, weak_type);
      } else {
        return ArgSignature(xla::PrimitiveType::C64, {}, weak_type);
      }
    };

    (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = bool_handler;
    (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = int_handler;
    (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = float_handler;
    (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] = complex_handler;

    // The Buffer types.
    // PyBuffer necessarily has a trivial LazyExpr, no need to check it.
    ToArgSignatureHandler buffer_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(h);
      bool weak_type = py::cast<py::bool_>(h.attr("aval").attr("weak_type"));
      return ArgSignature(buffer->buffer()->on_device_shape().element_type(),
                          buffer->buffer()->on_device_shape().dimensions(),
                          weak_type);
    };
    (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] = buffer_handler;
    ToArgSignatureHandler device_array_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      py::handle aval = h.attr("aval");
      TF_ASSIGN_OR_RETURN(auto dtype,
                          xla::DtypeToPrimitiveType(aval.attr("dtype")));
      return ArgSignature(dtype,
                          py::cast<std::vector<xla::int64>>(aval.attr("shape")),
                          py::cast<py::bool_>(aval.attr("weak_type")));
    };
    // ShardedDeviceArray is covered by the MRO fallback on _DeviceArray.
    (*p)[device_array.ptr()] = device_array_handler;

    ToArgSignatureHandler numpy_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      py::array numpy_array = py::cast<py::array>(h);
      if (IsFloat0(numpy_array)) {
        return xla::InvalidArgument(
            "float0 numpy arrays not supported in C++. "
            "Falling back to Python.");
      }
      if (!jax_enable_x64) {
        const py::dtype raw_dtype = numpy_array.dtype();
        const py::dtype* to_dtype = DtypeTo32BitDtype(raw_dtype);

        xla::PrimitiveType dtype;
        if (to_dtype) {
          TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(*to_dtype));
        } else {
          TF_ASSIGN_OR_RETURN(dtype, xla::DtypeToPrimitiveType(raw_dtype));
        }
        // We need the reinterpret_cast for the OSS version to build.
        return ArgSignature(
            dtype,
            absl::MakeConstSpan(
                reinterpret_cast<const xla::int64*>(numpy_array.shape()),
                numpy_array.ndim()),
            /*weak_type=*/false);
      }
      TF_ASSIGN_OR_RETURN(auto dtype,
                          xla::DtypeToPrimitiveType(numpy_array.dtype()));
      return ArgSignature(
          dtype,
          absl::MakeConstSpan(
              reinterpret_cast<const xla::int64*>(numpy_array.shape()),
              numpy_array.ndim()),
          /*weak_type=*/false);
    };
    const auto numpy = py::module::import("numpy");
    const auto& ndarray = numpy.attr("ndarray");
    (*p)[ndarray.ptr()] = numpy_handler;

    ToArgSignatureHandler np_uint64_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      if (jax_enable_x64) {
        return ArgSignature(xla::PrimitiveType::U64, {}, /*weak_type=*/false);
      } else {
        return ArgSignature(xla::PrimitiveType::U32, {}, /*weak_type=*/false);
      }
    };
    ToArgSignatureHandler np_int_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      if (jax_enable_x64) {
        return ArgSignature(xla::PrimitiveType::S64, {}, /*weak_type=*/false);
      } else {
        return ArgSignature(xla::PrimitiveType::S32, {}, /*weak_type=*/false);
      }
    };
    ToArgSignatureHandler numpy_array_handler =
        [](py::handle h, bool jax_enable_x64) -> xla::StatusOr<ArgSignature> {
      // This handler deals with the remaining NumPy scalar types; np.int64,
      // np.uint64, np.float64 and np.complex128 have dedicated handlers that
      // respect jax_enable_x64.
      TF_ASSIGN_OR_RETURN(auto dtype,
                          xla::DtypeToPrimitiveType(h.attr("dtype")));
      return ArgSignature(dtype, {}, /*weak_type=*/false);
    };

    // The NumPy scalar types. np.int64, np.uint64, np.float64 and
    // np.complex128 get the dedicated handlers defined above.
    (*p)[dtypes.np_bool.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_int8.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_int16.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_int32.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_int64.ptr()] = np_int_handler;
    (*p)[dtypes.np_uint8.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_uint16.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_uint32.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_uint64.ptr()] = np_uint64_handler;
    (*p)[dtypes.np_float16.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_float32.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_float64.ptr()] = float_handler;
    (*p)[dtypes.np_complex64.ptr()] = numpy_array_handler;
    (*p)[dtypes.np_complex128.ptr()] = complex_handler;
    (*p)[dtypes.np_longlong.ptr()] = np_int_handler;
    (*p)[dtypes.np_intc.ptr()] = numpy_array_handler;

    return p;
  }();

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    // Fall back to walking the MRO (method resolution order) of the type.
    for (auto base_class : arg.get_type().attr("mro")()) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, jax_enable_x64);
      }
    }
    return xla::InvalidArgument(
        "%s", absl::StrCat("Not supported: The C++ ToArgSignature only accepts "
                           "Buffer/DeviceArray/ShardedDeviceArray, NumPy "
                           "arrays or scalars of supported types "
                           "(see implementation), or Python scalars. Got type ",
                           py::cast<std::string>(py::str(arg.get_type()))));
  } else {
    return res->second(arg, jax_enable_x64);
  }
}
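
// For example (illustrative): with jax_enable_x64 disabled,
// ArgSignatureOfValue(np.float64(1.0)) yields a non-weak F32[] signature,
// while a plain Python float yields weak_F32[].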

namespace {
using DevicePutFunc = std::function<xla::StatusOr<DevicePutResult>(
    py::handle, xla::PjRtDevice*, bool, xla::PyClient&)>;

DevicePutResult HandleBool(py::handle h, xla::PjRtDevice* to_device,
                           bool jax_enable_x64, xla::PyClient& pyclient) {
  return DevicePutResult(ConvertToScalarBuffer<bool, py::bool_>(
                             h, pyclient.pjrt_client(), to_device),
                         /*weak_type=*/true);
}

DevicePutResult HandleInt(py::handle obj, xla::PjRtDevice* to_device,
                          bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<xla::int64, py::int_>(
                               obj, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/true);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
                               obj, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/true);
  }
}

template <bool weak_type>
xla::StatusOr<DevicePutResult> HandleFloat(py::handle h,
                                           xla::PjRtDevice* to_device,
                                           bool jax_enable_x64,
                                           xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<double, py::float_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/weak_type);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<float, py::float_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/weak_type);
  }
}

template <bool weak_type>
xla::StatusOr<DevicePutResult> HandleComplex(py::handle h,
                                             xla::PjRtDevice* to_device,
                                             bool jax_enable_x64,
                                             xla::PyClient& pyclient) {
  // This branch is also taken for np.complex128:
  // isinstance(np.complex128(3), complex) returns True
  // isinstance(np.complex64(3), complex) returns False
  Py_complex result = PyComplex_AsCComplex(h.ptr());
  if (result.real == -1.0 && PyErr_Occurred()) {
    PyErr_Clear();
    throw std::runtime_error("Could not convert the complex number");
  }
  if (jax_enable_x64) {
    xla::complex128 data(result.real, result.imag);
    xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex128>({});
    return DevicePutResult(
        ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
            &data, shape,
            xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
            nullptr, to_device)),
        /*weak_type=*/weak_type);
  } else {
    xla::complex64 data(result.real, result.imag);
    xla::Shape shape = xla::ShapeUtil::MakeShapeWithType<xla::complex64>({});
    return DevicePutResult(
        ValueOrThrow(pyclient.pjrt_client()->BufferFromHostBuffer(
            &data, shape,
            xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall,
            nullptr, to_device)),
        /*weak_type=*/weak_type);
  }
}

xla::StatusOr<DevicePutResult> HandleDeviceArray(py::handle obj,
                                                 xla::PjRtDevice* to_device,
                                                 bool jax_enable_x64,
                                                 xla::PyClient& pyclient) {
  if (!IsTrivialLazyExpr(py::getattr(obj, "_lazy_expr"))) {
    return xla::InvalidArgument(
        "Non-trivial lazy expression not supported in C++. "
        "Falling back to Python.");
  }
  xla::PyBuffer* buffer = py::cast<xla::PyBuffer*>(obj.attr("device_buffer"));
  bool weak_type = py::cast<py::bool_>(obj.attr("aval").attr("weak_type"));
  // Reuse the underlying buffer when it already lives on the target device;
  // otherwise copy it over.
  if (buffer->device().contents == to_device) {
    return DevicePutResult(buffer->buffer(), weak_type);
  } else {
    std::unique_ptr<xla::PjRtBuffer> copied_buffer =
        ValueOrThrow(buffer->buffer()->CopyToDevice(to_device));
    return DevicePutResult(std::move(copied_buffer), weak_type);
  }
}

// Do not convert the type; only call PjRtBufferFromPyval, regardless of the
// value of jax_enable_x64.
DevicePutResult HandleBufferFromPyval(py::handle h, xla::PjRtDevice* to_device,
                                      bool jax_enable_x64,
                                      xla::PyClient& pyclient) {
  std::unique_ptr<xla::PjRtBuffer> buffer =
      ValueOrThrow(pyclient.PjRtBufferFromPyval(
          h, to_device,
          /*force_copy=*/false, /*host_buffer_semantics=*/
          xla::PjRtClient::HostBufferSemantics::kZeroCopy));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

DevicePutResult HandleNpBool(py::handle h, xla::PjRtDevice* to_device,
                             bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    return DevicePutResult(ConvertToScalarBuffer<xla::int64, py::int_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/false);
  } else {
    return DevicePutResult(ConvertToScalarBuffer<int, py::int_>(
                               h, pyclient.pjrt_client(), to_device),
                           /*weak_type=*/false);
  }
}

DevicePutResult HandleUint64(py::handle h, xla::PjRtDevice* to_device,
                             bool jax_enable_x64, xla::PyClient& pyclient) {
  if (jax_enable_x64) {
    std::unique_ptr<xla::PjRtBuffer> buffer =
        ValueOrThrow(pyclient.PjRtBufferFromPyval(
            h, to_device,
            /*force_copy=*/false, /*host_buffer_semantics=*/
            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
    return DevicePutResult(std::move(buffer), /*weak_type=*/false);
  } else {
    static const auto* numpy = new py::module(py::module::import("numpy"));
    const auto& np_array = numpy->attr("array");

    // Note that this is calling back to Python!
    std::unique_ptr<xla::PjRtBuffer> buffer =
        ValueOrThrow(pyclient.PjRtBufferFromPyval(
            np_array(h, py::dtype("uint32")), to_device,
            /*force_copy=*/false, /*host_buffer_semantics=*/
            xla::PjRtClient::HostBufferSemantics::kZeroCopy));
    return DevicePutResult(std::move(buffer), /*weak_type=*/false);
  }
}

xla::StatusOr<DevicePutResult> HandleNdarray(py::handle h,
                                             xla::PjRtDevice* to_device,
                                             bool jax_enable_x64,
                                             xla::PyClient& pyclient) {
  py::array numpy_array = py::cast<py::array>(h);
  if (IsFloat0(numpy_array)) {
    return xla::InvalidArgument("%s",
                                "float0 numpy arrays not supported in C++. "
                                "Falling back to Python.");
  }
  // If jax_enable_x64 is not set, we need to coerce to 32-bit types.
  // Note that this is calling back to Python!
  if (!jax_enable_x64) {
    const py::dtype* to_dtype = DtypeTo32BitDtype(numpy_array.dtype());
    if (to_dtype) {
      static const auto* numpy = new py::module(py::module::import("numpy"));
      const auto& np_array = numpy->attr("array");
      numpy_array = np_array(numpy_array, *to_dtype);
    }
  }
  std::unique_ptr<xla::PjRtBuffer> buffer =
      ValueOrThrow(pyclient.PjRtBufferFromPyval(
          numpy_array, to_device,
          /*force_copy=*/false, /*host_buffer_semantics=*/
          xla::PjRtClient::HostBufferSemantics::kZeroCopy));
  return DevicePutResult(std::move(buffer), /*weak_type=*/false);
}

}  // namespace

xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
                                         xla::PjRtDevice* to_device,
                                         bool jax_enable_x64,
                                         xla::PyClient& pyclient) {
  static const absl::flat_hash_map<PyObject*, DevicePutFunc>* const handlers =
      [] {
        auto p = new absl::flat_hash_map<PyObject*, DevicePutFunc>();

        const NumpyScalarTypes& dtypes = GetNumpyScalarTypes();

        const auto numpy = py::module::import("numpy");
        const auto xla_module = py::module::import("jax.interpreters.xla");
        const auto& device_array = xla_module.attr("_DeviceArray");

        // Python base types.
        (*p)[reinterpret_cast<PyObject*>(&PyBool_Type)] = HandleBool;
        (*p)[reinterpret_cast<PyObject*>(&PyLong_Type)] = HandleInt;
        (*p)[reinterpret_cast<PyObject*>(&PyFloat_Type)] = HandleFloat<true>;
        (*p)[reinterpret_cast<PyObject*>(&PyComplex_Type)] =
            HandleComplex<true>;

        // DeviceArray and co.
        const auto pxla_module = py::module::import("jax.interpreters.pxla");
        const auto& sda = pxla_module.attr("ShardedDeviceArray");
        (*p)[device_array.ptr()] = HandleDeviceArray;
        (*p)[py::type::handle_of<xla::DeviceArrayBase>().ptr()] =
            HandleDeviceArray;
        (*p)[sda.ptr()] = HandleBufferFromPyval;
        // NumPy arrays.
        (*p)[numpy.attr("ndarray").ptr()] = HandleNdarray;

        // NumPy scalar types. For some of them, we share the handler with
        // the Python types (np_int64, np_float64, np_complex128).
        (*p)[dtypes.np_bool.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int8.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_int64.ptr()] = HandleNpBool;
        (*p)[dtypes.np_uint8.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_uint64.ptr()] = HandleUint64;
        (*p)[dtypes.np_float16.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_float32.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_float64.ptr()] = HandleFloat<false>;
        (*p)[dtypes.np_complex64.ptr()] = HandleBufferFromPyval;
        (*p)[dtypes.np_complex128.ptr()] = HandleComplex<false>;
        (*p)[dtypes.np_longlong.ptr()] = HandleNpBool;
        (*p)[dtypes.np_intc.ptr()] = HandleBufferFromPyval;

        return p;
      }();

  auto res = handlers->find(arg.get_type().ptr());
  if (res == handlers->end()) {
    for (auto base_class : arg.get_type().attr("mro")()) {
      res = handlers->find(base_class.ptr());
      if (res != handlers->end()) {
        return res->second(arg, to_device, jax_enable_x64, pyclient);
      }
    }
    return xla::InvalidArgument(
        "%s", absl::StrCat(
                  "Not supported: The C++ jax jit execution path only accepts "
                  "DeviceArray, NumPy arrays or scalars of supported types "
                  "(see implementation), or Python scalars. Got type ",
                  py::cast<std::string>(py::str(arg.get_type()))));
  } else {
    return res->second(arg, to_device, jax_enable_x64, pyclient);
  }
}
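
// For example (illustrative): DevicePut(np.float64(1.0), device, false,
// client) produces an F32 scalar buffer on `device`, while a DeviceArray
// already on `device` is passed through without any copy.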

namespace {

struct CacheEntry {
  std::shared_ptr<xla::PyExecutable> executable;
  xla::PyTreeDef out_pytree_def;
  // We use Python types within the vector because this is what we will be
  // returning to Python. No need to convert back and forth.
  // We need py::object to keep the objects alive.
  std::vector<py::object> out_avals;
  // The processing done in `AddCacheEntry` ensures that LazyExpr are stored as
  // `py::none()`.
  std::vector<py::object> out_lazy_exprs;
  py::object sticky_device;

  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated
  // to a signature and if the object has been inserted already, other threads
  // will wait for the notification.
  absl::Notification compilation_complete;
  absl::optional<xla::Status> compilation_error = absl::nullopt;
  // Trivial computations fall back to Python.
  // Running jit(pmap) will also fall back to Python.
  bool fall_back_to_python = false;
};

// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
 public:
  CompiledFunction(py::function fun, py::function cache_miss,
                   py::function get_device, py::function get_jax_enable_x64,
                   py::function get_jax_disable_jit,
                   std::vector<int> static_argnums);
  ~CompiledFunction();

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `DeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
  py::object Call(py::args args, py::kwargs kwargs);

  // This allows `inspect.signature(cpp_jitted_f)` from Python.
  py::object PythonSignature() {
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  int cache_size() const { return executables_.size(); }

 private:
  // Returns nullptr if not present in the cache.
  CacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
  // Should never return nullptr.
  CacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
                            const CallSignature& signature,
                            py::object out_and_fastpath_data);
  bool JitIsDisabled() { return GetDisableJit() || jax_disable_jit_.value(); }

  bool always_fallback_to_python_ = false;

  const py::function fun_;  // The Python function to jit.
  // See JAX _cpp_jit in api.py for documentation.
  const py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // We need a `unique_ptr` here to ensure value pointer stability.
  absl::flat_hash_map<CallSignature, std::unique_ptr<CacheEntry>> executables_;

  // As top-level functions are decorated with `jax.jit`, when
  // `CompiledFunction` is being instantiated from Python, the clients are not
  // yet available (done after GoogleInit). They will be retrieved during the
  // first call to `Call`.
  // `get_device_` is a function taking no arguments and returning the default
  // device and whether jax.jit has been committed to it.
  const py::function get_jax_enable_x64_;
  const py::function get_jax_disable_jit_;
  const py::function get_device_;

  // Writes to the following fields are protected by the mutex.
  absl::Mutex mu_;
  // The values of the Python flags. They are computed only during the first
  // object call, because GoogleInit must have been executed.
  absl::optional<bool> jax_enable_x64_ = absl::nullopt;
  absl::optional<bool> jax_disable_jit_ = absl::nullopt;

  // The logic is the following:
  // - if `device` or `backend` are not specified to `jax.jit`, we will use
  //   the input sticky buffer device, or `default_device_` if there is no
  //   such sticky buffer.
  // - When one of `device` or `backend` is specified, this will determine
  //   the `default_device_` which will be used as the targeted device. In
  //   which case, we will always copy input buffers to this device.
  std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
  xla::ClientAndPtr<xla::PjRtDevice> default_pydevice_;
  xla::PjRtDevice* default_device_ = nullptr;
  bool is_committed_;
};

CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
                                   py::function get_device,
                                   py::function get_jax_enable_x64,
                                   py::function get_jax_disable_jit,
                                   std::vector<int> static_argnums)
    : fun_(std::move(fun)),
      cache_miss_(std::move(cache_miss)),
      static_argnums_(std::move(static_argnums)),
      get_jax_enable_x64_(get_jax_enable_x64),
      get_jax_disable_jit_(get_jax_disable_jit),
      get_device_(std::move(get_device)) {
  std::sort(static_argnums_.begin(), static_argnums_.end());
}

CompiledFunction::~CompiledFunction() {
  for (const auto& entry : executables_) {
    entry.first.DecRef();
  }
}

// Converts flattened arguments contained in ParsedArgumentsAsBuffers in
// place. If arguments are `DeviceArray`, they must all be on the same
// `Device`.
//
// Returns `xla::Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
xla::Status ConvertArgsToBuffers(bool jax_enable_x64, xla::PyClient& pyclient,
                                 xla::PjRtDevice* default_device,
                                 bool is_committed,
                                 ParsedArgumentsAsBuffers& arguments) {
  std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
  auto& keep_alive = arguments.keep_alive;

  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
  arg_buffers.reserve(num_flat_dynamic_args);
  arguments.signature.dynamic_args_signatures.reserve(num_flat_dynamic_args);

  static const auto* xla_module =
      new py::module(py::module::import("jax.interpreters.xla"));
  const auto& device_array = xla_module->attr("_DeviceArray");

  // When the jitted function is not committed, we first check whether any
  // sticky `DeviceArray` is present and on which device it lives. See also:
  // https://github.com/google/jax/pull/1884
  // https://github.com/google/jax/pull/1916 for the rationale why the
  // computation follows the data locality.
  // It's also similar to PyTorch's behavior.
  xla::PjRtDevice* data_device = nullptr;
  if (is_committed) {
    data_device = default_device;
  } else {
    for (py::handle arg : arguments.flat_dynamic_args) {
      // We specifically only deal with DeviceArray (not ShardedDeviceArray,
      // which can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
      if (py::isinstance<xla::PyBuffer>(arg) ||
          arg.get_type().is(device_array)) {
        xla::PyBuffer* buffer;
        if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
          continue;
        }
        try {
          // This can fail, e.g. when device_buffer is a `DeviceConstant`.
          buffer = py::cast<xla::PyBuffer*>(arg.attr("device_buffer"));
        } catch (const py::cast_error& e) {
          return xla::InvalidArgument(
              "%s",
              absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
                           "`device_buffer` field is of type ",
                           py::cast<std::string>(
                               arg.attr("device_buffer").get_type().str()),
                           " while a `PyBuffer` was expected."));
        }
        xla::PjRtDevice* device = buffer->buffer()->device();
        if (data_device && (device != data_device)) {
          throw std::invalid_argument(absl::StrCat(
              "primitive arguments must be colocated on the same device ("
              "C++ jax.jit). Arguments are on devices: ",
              device->DebugString(), " and ", data_device->DebugString()));
        } else {
          data_device = device;
        }
      }
    }
  }
  if (!data_device) {
    // No sticky `DeviceArray` was found; default to `default_device`.
    data_device = default_device;
  }
  CHECK(data_device);
  arguments.signature.device = data_device;

  for (py::handle arg : arguments.flat_dynamic_args) {
    TF_ASSIGN_OR_RETURN(DevicePutResult on_device,
                        DevicePut(arg, data_device, jax_enable_x64, pyclient));

    xla::PjRtBuffer* buffer = on_device.buffer;
    arg_buffers.push_back(buffer);
    if (on_device.owned_buffer) {
      keep_alive.emplace_back(std::move(on_device.owned_buffer));
    }

    ArgSignature sig(buffer->on_device_shape().element_type(),
                     buffer->on_device_shape().dimensions(),
                     on_device.weak_type);
    arguments.signature.dynamic_args_signatures.push_back(std::move(sig));
  }
  return xla::Status::OK();
}
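
// Consequently, an uncommitted `jit(f)(x)` where `x` is a sticky DeviceArray
// on a GPU runs on that GPU, whereas mixing sticky arguments living on two
// different devices raises an invalid-argument error.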

CacheEntry* CompiledFunction::GetCacheEntryIfPresent(
    const CallSignature& signature) {
  auto found_iterator = executables_.find(signature);
  if (found_iterator != executables_.end()) {  // Cache hit!
    if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
      py::gil_scoped_release gil_release;
      found_iterator->second->compilation_complete.WaitForNotification();
    }
    if (found_iterator->second->compilation_error) {
      throw std::invalid_argument(
          found_iterator->second->compilation_error.value().error_message());
    }
    return found_iterator->second.get();
  }
  return nullptr;
}

CacheEntry* CompiledFunction::AddCacheEntry(const py::args& args,
                                            const py::kwargs& kwargs,
                                            const CallSignature& signature,
                                            py::object out_and_fastpath_data) {
  // We need to insert the element.
  auto result = executables_.emplace(signature, std::make_unique<CacheEntry>());
  auto it = result.first;
  CacheEntry* cache_entry = it->second.get();
  // CallSignatures in the cache own their keyword argument reference.
  result.first->first.IncRef();

  py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
  CHECK_EQ(tuple.size(), 2);
  if (tuple[1].is_none()) {
    cache_entry->fall_back_to_python = true;
    cache_entry->compilation_complete.Notify();
    return cache_entry;
  }

  py::tuple executable_handlers_out_tree = py::cast<py::tuple>(tuple[1]);
  if (executable_handlers_out_tree.size() != 5) {
    throw std::runtime_error(absl::StrCat(
        "The versions of jaxlib and jax are incompatible (jaxlib is too "
        "recent compared to jax); upgrading jax is advised. The C++ code "
        "expects 5 arguments but ",
        executable_handlers_out_tree.size(), " were provided: ",
        py::cast<std::string>(
            py::str(py::repr(executable_handlers_out_tree)))));
  }
  // (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
  auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
      executable_handlers_out_tree[0]);
  cache_entry->executable = std::move(executable);
  int num_devices =
      cache_entry->executable->pjrt_executable().addressable_devices().size();
  // The presence of jit(pmap) is detected from Python.
  CHECK_EQ(num_devices, 1);

  auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
  cache_entry->out_pytree_def = std::move(out_tree);

  cache_entry->sticky_device =
      py::cast<py::object>(executable_handlers_out_tree[2]);
  auto avals = py::cast<py::list>(executable_handlers_out_tree[3]);
  auto lazy_exprs = py::cast<py::list>(executable_handlers_out_tree[4]);
  CHECK_EQ(avals.size(), lazy_exprs.size());

  cache_entry->out_avals.reserve(avals.size());
  cache_entry->out_lazy_exprs.reserve(avals.size());
  for (int i = 0; i < avals.size(); ++i) {
    py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
    py::object lazy_expr = py::reinterpret_borrow<py::object>(lazy_exprs[i]);

    cache_entry->out_avals.push_back(shaped_array);
    CHECK(lazy_expr.is_none() || !IsTrivialLazyExpr(lazy_expr));
    cache_entry->out_lazy_exprs.push_back(lazy_expr);
  }

  cache_entry->compilation_complete.Notify();
  return cache_entry;
}

py::object CompiledFunction::Call(py::args args, py::kwargs kwargs) {
  if (always_fallback_to_python_) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }
  // Delayed values are retrieved on the first call to `Call`.
  if (!default_device_) {
    // This may call back into Python, and Python code may release the GIL.
    // To avoid deadlocks, we release the GIL, acquire mu_, and only then
    // re-acquire the GIL.
    py::gil_scoped_release gil_release;
    {
      absl::MutexLock lock1(&mu_);
      py::gil_scoped_acquire gil_acquire;

      jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
      jax_disable_jit_ = py::cast<bool>(get_jax_disable_jit_());
      if (!default_device_) {
        py::object device_and_is_committed = get_device_();
        try {
          default_pydevice_ = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
              device_and_is_committed.attr("default_device"));
        } catch (const py::cast_error& e) {
          // Pathways and Cloud TPU 2VM runtime.
          always_fallback_to_python_ = true;
          return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
        }
        default_pyclient_ = default_pydevice_.client;
        default_device_ = default_pydevice_.contents;
        if (!default_device_) {  // UPTC
          always_fallback_to_python_ = true;
          return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
        }
        is_committed_ = py::cast<bool>(
            device_and_is_committed.attr("committed_to_device"));
      }
    }
  }
  CHECK(default_device_);
  if (JitIsDisabled()) {
    return fun_(*args, **kwargs);
  }
  ParsedArgumentsAsBuffers arguments;
  if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  // The C++ jit does not support Tracer arguments yet. The Python-based jit
  // function will be called if any of the dynamic arguments is unsupported.
  if (!ConvertArgsToBuffers(jax_enable_x64_.value(), *default_pyclient_,
                            default_device_, is_committed_, arguments)
           .ok()) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  CacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);

  if (!cache_entry) {
    py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
    cache_entry = GetCacheEntryIfPresent(arguments.signature);
    if (!cache_entry) {
      cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
                                  out_and_fastpath_data);
    }
    CHECK(cache_entry);
    if (cache_entry->fall_back_to_python) {
      return py::cast<py::tuple>(out_and_fastpath_data)[0];
    }
    // As we have already computed the results, we can return them.
    // It's even *required*, e.g. if there are donated arguments, because
    // otherwise the buffers which have been donated would already be invalid.
    return py::cast<py::tuple>(out_and_fastpath_data)[0];
  }
  CHECK(cache_entry);
  if (cache_entry->fall_back_to_python) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }
  std::vector<std::unique_ptr<xla::PyBuffer>> outputs =
      ValueOrThrow(cache_entry->executable->PjRtExecute(arguments.arg_buffers));

  const std::vector<py::object>& out_avals = cache_entry->out_avals;
  const std::vector<py::object>& out_lazy_exprs = cache_entry->out_lazy_exprs;
  const py::object& sticky_device = cache_entry->sticky_device;

  py::list flat_device_arrays;
  for (int i = 0; i < outputs.size(); ++i) {
    auto& buffer = outputs[i];
    if (out_lazy_exprs[i].is_none()) {  // No LazyExpr.
      buffer->SetAval(out_avals[i]);
      buffer->SetStickyDevice(sticky_device);
      flat_device_arrays.append(py::cast(std::move(outputs[i])));
    } else {
      static const auto* xla_module =
          new py::module(py::module::import("jax.interpreters.xla"));
      static const auto* device_array =
          new py::handle(xla_module->attr("_DeviceArray"));
      flat_device_arrays.append(
          (*device_array)(out_avals[i], sticky_device, out_lazy_exprs[i],
                          py::cast(std::move(outputs[i]))));
    }
  }
  return cache_entry->out_pytree_def.Unflatten(flat_device_arrays);
}

}  // namespace

void BuildJaxjitSubmodule(pybind11::module& m) {
  py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");

  py::class_<CompiledFunction, std::unique_ptr<CompiledFunction>> cfun(
      jitlib, "CompiledFunction");
  cfun.def("__call__", &CompiledFunction::Call);
  cfun.def_property_readonly("__signature__",
                             &CompiledFunction::PythonSignature);

  jitlib.def("set_disable_jit", &SetDisableJit);
  jitlib.def("get_disable_jit", &GetDisableJit);
  jitlib.def(
      "jit",
      [](py::function fun, py::function cache_miss, py::function get_device,
         py::function get_jax_enable_x64, py::function get_jax_disable_jit,
         std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> {
        return std::make_unique<CompiledFunction>(
            std::move(fun), std::move(cache_miss), std::move(get_device),
            std::move(get_jax_enable_x64), std::move(get_jax_disable_jit),
            std::move(static_argnums));
      });

  // This function is not yet a full replacement for the Python one, because:
  // (a) it does not support abstract types,
  // (b) it does not set the device stickiness yet.
  // TODO(jblespiau): Finish the replacement of the Python feature.
  jitlib.def("device_put", [](py::handle obj, bool jax_enable_x64,
                              xla::ClientAndPtr<xla::PjRtDevice> to_device) {
    std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
    xla::StatusOr<DevicePutResult> results =
        DevicePut(obj, to_device.contents, jax_enable_x64, *pyclient);
    if (!results.ok()) {
      throw std::runtime_error(results.status().error_message());
    }
    if (results->owned_buffer) {
      auto buffer = std::make_unique<xla::PyBuffer>(
          pyclient, std::move(results->owned_buffer), xla::Traceback::Get());

      static const auto* jax_core =
          new py::module(py::module::import("jax.core"));
      static const auto* shaped_array =
          new py::handle(jax_core->attr("ShapedArray"));
      buffer->SetAval((*shaped_array)(
          buffer->python_shape(), buffer->python_dtype(), results->weak_type));
      buffer->SetStickyDevice(py::none());

      return py::cast(std::move(buffer));
    } else {
      return py::cast<py::object>(obj);
    }
  });

  py::class_<ArgSignature> arg_signature(jitlib, "ArgSignature");
  arg_signature
      .def_property_readonly("dtype",
                             [](const ArgSignature& sig) {
                               return xla::PrimitiveTypeToDtype(sig.dtype);
                             })
      .def_property_readonly("shape",
                             [](const ArgSignature& sig) {
                               return xla::IntSpanToTuple(sig.shape);
                             })
      .def_readonly("weak_type", &ArgSignature::weak_type);
  jitlib.def("_ArgSignatureOfValue", &ArgSignatureOfValue);

  // All private members are only for testing purposes.
  cfun.def("_cache_size", &CompiledFunction::cache_size);
  jitlib.def("_DtypeTo32BitDtype", [](const py::object obj) -> py::object {
    py::dtype dtype = py::dtype::from_args(obj);
    const py::dtype* res = DtypeTo32BitDtype(dtype);
    if (res) {
      return *res;
    } else {
      return py::none();
    }
  });
  jitlib.def("_is_float0", &IsFloat0);
  jitlib.def("_is_trivial", &IsTrivialLazyExpr);
}
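
// Sketch of the expected Python-side wiring (hypothetical import path and
// names; the real glue lives in JAX's _cpp_jit in api.py):
//
//   from jaxlib import xla_extension
//   jitted = xla_extension.jax_jit.jit(fun, cache_miss, get_device,
//                                      get_jax_enable_x64,
//                                      get_jax_disable_jit, static_argnums)
//   out = jitted(*args, **kwargs)  # dispatches via CompiledFunction::Call.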

}  // namespace jax