/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file implements the `jax.jit` dispatch and just-in-time compilation
// feature.
//
// In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
// based on passed arguments dtypes/shapes/identity) the execution to a
// just-in-time compiled XLA Executable. All of that is done in C++ for
// performance reasons.
//
// This file contains the utilities to:
// (a) inspect arguments and describe their structure, dtype/shapes, etc.
// (b) keep a mapping from function signatures to compiled XLA Executables.
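//
// Illustrative usage from Python (a sketch; in practice `jax.jit` constructs
// the `cache_miss` and `get_device` callbacks itself, and the exact arguments
// below are assumptions matching the `jax_jit.jit` binding at the bottom of
// this file):
//
//   from jaxlib import xla_extension
//   cpp_jitted_f = xla_extension.jax_jit.jit(
//       fun, cache_miss, get_device, static_argnums=[0])
//   cpp_jitted_f(2, x)  # Fast path: dispatches to a cached XLA executable.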

#include "tensorflow/compiler/xla/python/jax_jit.h"

#include <Python.h>

#include <algorithm>
#include <exception>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "pybind11/cast.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/pjrt/lru_cache.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/python/py_buffer.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/py_values.h"
#include "tensorflow/compiler/xla/python/python_ref_manager.h"
#include "tensorflow/compiler/xla/python/pytree.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace jax {

namespace py = pybind11;

// TODO(phawkins): Add support for Tracers.
// TODO(jblespiau): Use absl Status.

namespace {

// Protected by the GIL.
GlobalJitState& global_state = *new GlobalJitState();

// TODO(phawkins): Google style guide forbids thread-local values with
// non-trivial destructors.
ABSL_CONST_INIT thread_local ThreadLocalJitState thread_local_state;  // NOLINT

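// Returns whether jit is disabled for the calling thread. The thread-local
// override, when set, takes precedence over the global flag: e.g. a thread
// that sets `disable_jit = true` locally skips the fast path even when the
// global flag is false.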
bool JitIsDisabled() {
  return thread_local_state.disable_jit.value_or(global_state.disable_jit);
}

}  // namespace

GlobalJitState& GetGlobalState() { return global_state; }
ThreadLocalJitState& GetLocalState() { return thread_local_state; }

bool GetEnableX64() {
  return thread_local_state.enable_x64.value_or(global_state.enable_x64);
}

std::string CallSignature::DebugString() const {
  auto py_object_formatter = [](std::string* out, const py::object& o) {
    out->append(py::cast<std::string>(py::str(o)));
  };
  auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
    out->append(d.ToString());
  };
  auto signature_formatter = [](std::string* out,
                                const xla::PyArgSignature& s) {
    out->append(s.DebugString());
  };
  std::string thread_local_extra_jit_context_str;
  if (thread_local_extra_jit_context.has_value()) {
    thread_local_extra_jit_context_str =
        py::cast<std::string>(py::str(thread_local_extra_jit_context.value()));
  } else {
    thread_local_extra_jit_context_str = "None";
  }
  return absl::StrFormat(
      "static args (positional + keyword): %s\nstatic arg keyword names: %s\n"
      "dynamic arg signatures (positional + keyword): %s\n"
      "dynamic arg keyword names: %s\ndynamic arg treedefs: %s\n"
      "device: %p\n"
      "jax_enable_x64: %d\n"
      "global_extra_jit_context: %s\n"
      "thread_local_extra_jit_context: %s\n",
      absl::StrJoin(static_args, ",", py_object_formatter),
      absl::StrJoin(static_arg_names, ",", py_object_formatter),
      absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter),
      absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
      absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter),  // new line
      device, jax_enable_x64,
      py::cast<std::string>(py::str(global_extra_jit_context)),
      thread_local_extra_jit_context_str);
}

bool CallSignature::operator==(const CallSignature& other) const {
  return std::tie(dynamic_arg_treedefs, dynamic_arg_names,
                  dynamic_arg_signatures, device, jax_enable_x64,
                  static_arg_names) ==
             std::tie(other.dynamic_arg_treedefs, other.dynamic_arg_names,
                      other.dynamic_arg_signatures, other.device,
                      other.jax_enable_x64, other.static_arg_names) &&
         // `==` on py::objects is the Python `is`. We need equality.
         std::equal(
             static_args.begin(), static_args.end(), other.static_args.begin(),
             other.static_args.end(),
             [](const py::object& a, const py::object& b) {
               try {
                 return a.equal(b);
               } catch (const py::error_already_set& e) {
                 throw std::invalid_argument(absl::StrCat(
                     "static arguments should be comparable using __eq__. "
                     "The following error was raised when comparing two "
                     "objects of types ",
                     py::cast<std::string>(py::str(py::type::of(a))), " and ",
                     py::cast<std::string>(py::str(py::type::of(b))),
                     ". The error was:\n", e.what()));
               }
             }) &&
         global_extra_jit_context.equal(other.global_extra_jit_context) &&
         (thread_local_extra_jit_context.has_value() ==
          other.thread_local_extra_jit_context.has_value()) &&
         (!thread_local_extra_jit_context.has_value() ||
          thread_local_extra_jit_context->equal(
              *other.thread_local_extra_jit_context));
}

template <typename H>
H AbslHashValue(H h, const CallSignature& s) {
  h = H::combine_contiguous(std::move(h), s.dynamic_arg_treedefs.data(),
                            s.dynamic_arg_treedefs.size());
  for (const auto& name : s.dynamic_arg_names) {
    h = H::combine(std::move(h), name.ptr());
  }
  h = H::combine_contiguous(std::move(h), s.dynamic_arg_signatures.data(),
                            s.dynamic_arg_signatures.size());
  for (const auto& static_arg : s.static_args) {
    ssize_t hash;
    try {
      hash = py::hash(static_arg);
    } catch (const py::error_already_set& e) {
      throw std::invalid_argument(absl::StrCat(
          "Non-hashable static arguments are not supported. An error occurred "
          "while trying to hash an object of type ",
          py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
          py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
          e.what(), "\n"));
    }
    h = H::combine(std::move(h), hash);
  }
  for (const auto& name : s.static_arg_names) {
    h = H::combine(std::move(h), name.ptr());
  }
  h = H::combine(std::move(h), s.device, s.jax_enable_x64);

  // We do not hash the extra_jit_context fields since calling Python hash
  // functions is expensive (~300ns) and we don't expect a large number of
  // different contexts.
  return h;
}
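
// Note that the fields hashed above are a subset of the fields compared in
// operator==: the extra_jit_context fields are compared but not hashed. This
// is still a valid hash (equal signatures hash equally), at the cost of
// possible collisions between signatures differing only in those contexts.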

// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
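// For example (illustrative): for a function f(a, b, *, c) jitted with
// static_argnums=[1] and static_argnames=["c"], a call f(a, b, c=3) places
// `b` and `c` in signature.static_args / static_arg_names, and flattens the
// pytree `a` into flat_dynamic_args.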
xla::Status ParseArguments(py::handle args,
                           const absl::optional<py::kwargs>& py_kwargs,
                           absl::Span<int const> static_argnums,
                           absl::Span<py::str const> static_argnames,
                           ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ParseArguments");
  int num_args = PyTuple_GET_SIZE(args.ptr());
  int num_kwargs = py_kwargs ? py_kwargs->size() : 0;

  arguments.flat_dynamic_args.reserve(num_args + num_kwargs);
  if (static_argnums.empty()) {
    arguments.signature.dynamic_arg_treedefs.resize(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      xla::PyTreeDef& pytree_def = arguments.signature.dynamic_arg_treedefs[i];
      pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                             arguments.flat_dynamic_args);
    }
  } else {
    arguments.signature.dynamic_arg_treedefs.reserve(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
          static_argnums.end()) {
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                               arguments.flat_dynamic_args);
      } else {
        arguments.signature.static_args.emplace_back(
            py::reinterpret_borrow<py::object>(
                PyTuple_GET_ITEM(args.ptr(), i)));
      }
    }
  }

  // Keyword arguments.
  if (py_kwargs) {
    std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs->begin(),
                                                          py_kwargs->end());
    // We first intern the keys, then sort them by name (as in the Python
    // path; see also xla::PyTreeDef::Flatten), and then create the
    // signatures.
    // TODO(jblespiau): We should be able to sort the keys by interned-key
    // pointers, but this requires the Python compilation to do the same.
    for (int i = 0; i < num_kwargs; ++i) {
      // Intern the key if not already interned.
      kwargs[i].first.inc_ref();
      if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
        PyUnicode_InternInPlace(&kwargs[i].first.ptr());
      }
    }

    std::sort(kwargs.begin(), kwargs.end(),
              [](const std::pair<py::handle, py::handle>& a,
                 const std::pair<py::handle, py::handle>& b) {
                return a.first < b.first;
              });
    auto kwarg_is_static = [&](py::handle name) {
      for (const auto& kw : static_argnames) {
        if (kw.ptr() == name.ptr()) return true;
      }
      return false;
    };

    arguments.signature.dynamic_arg_names.reserve(num_kwargs);
    for (int i = 0; i < num_kwargs; ++i) {
      if (kwarg_is_static(kwargs[i].first)) {
        arguments.signature.static_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.static_args.push_back(
            py::reinterpret_borrow<py::object>(kwargs[i].second));
      } else {
        arguments.signature.dynamic_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(kwargs[i].second, arguments.flat_dynamic_args);
      }
    }
  }
  return xla::Status::OK();
}

namespace {

// Elements of CacheEntry are protected by the GIL.
struct CacheEntry {
  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated
  // with a signature and fill it. Other threads will wait for the
  // notification. If an error occurred during the compilation,
  // `fall_back_to_python` is set to `true`, and other threads will fail with
  // the same error.
  absl::Notification compilation_complete;

  std::shared_ptr<xla::PyExecutable> executable;
  xla::PyTreeDef out_pytree_def;
  // We use Python types within the vector because this is what we will be
  // returning to Python. No need to convert back and forth.
  // We need py::object to keep the objects alive.
  std::vector<py::object> out_avals;
  std::vector<bool> out_weak_types;

  // The processing done in `AddCacheEntry` ensures that LazyExprs are stored
  // as `py::none()`.
  std::vector<py::object> out_lazy_exprs;

  // Bitvector of kept arguments from the jaxpr DCE pass. Used to drop some
  // `args` in CompiledFunction::Call before calling into the compiled
  // computation.
  absl::optional<std::vector<bool>> kept_var_bitvec;
  xla::PjRtDevice* sticky_device;

  // Fallback to Python happens:
  // - for trivial computations
  // - when running a jit(pmap)
  // - after a compilation error, for threads that did not compile it the
  //   first time
  bool fall_back_to_python = false;

  // Python objects (notably in the cache key) that must remain alive as long
  // as the cache entry does. Currently this is the `key` values in the kwarg
  // entries in the cache key.
  std::vector<py::object> keepalive;
};
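
// Lifecycle of a CacheEntry, as used by CompiledFunction::Call below: the
// thread that inserts the entry traces/compiles via `cache_miss`, fills the
// entry, and calls compilation_complete.Notify(); concurrent callers release
// the GIL and block in WaitForNotification(), then either run `executable`
// or fall back to Python when fall_back_to_python is set.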

// A CompiledFunctionCache represents a cache of compiled functions that can
// be shared between one or more CompiledFunction objects. It serves two
// goals:
// - reduce the number of LRU caches (hash maps) across multiple JITs;
// - make the cache global, to increase cache hits (e.g. calling jit(f)(3)
//   twice), while keeping entries alive as long as the underlying function
//   `f` is alive.
// The cache is assumed to be protected by the GIL.
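//
// Illustrative consequence: `jax.jit(f)(3)` followed by a second
// `jax.jit(f)(3)` on a fresh wrapper of the same `f` can reuse the same
// compiled executable, because both wrappers share one underlying Cache.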
class CompiledFunctionCache {
 public:
  static constexpr int kDefaultCapacity = 4096;
  explicit CompiledFunctionCache(int capacity);

  // Cache entries are shared_ptr<>s because it's possible the cache entry
  // might be evicted before we finish tracing/compiling.
  typedef xla::LRUCache<CallSignature, std::shared_ptr<CacheEntry>> Cache;

  // The lifetime of the returned cache lasts at least as long as the lifetime
  // of `function`, so if the caller holds a strong reference to `function`,
  // the `Cache` remains valid.
  // We include as part of the cache key `donate_argnums` (and any other
  // fields that aren't subsumed by the CallSignature we compute for each
  // call).
  Cache* Lookup(py::handle function, absl::Span<const int> donate_argnums);

  int Size() const { return lru_list_.Size(); }
  int Capacity() const { return lru_list_.Capacity(); }
  void Clear() { lru_list_.Clear(); }

 private:
  struct Key {
    py::handle function;  // Does not hold a reference.

    // Other fields that are part of the arguments to `jit`, but are not
    // otherwise part of CallSignature.
    std::vector<int> donate_argnums;

    bool operator==(const Key& other) const {
      return std::tie(function, donate_argnums) ==
             std::tie(other.function, other.donate_argnums);
    }
  };
  template <typename H>
  friend H AbslHashValue(H h, const Key& key) {
    h = H::combine(std::move(h), key.function.ptr());
    h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
                              key.donate_argnums.size());
    return h;
  }

  struct Value {
    explicit Value(Cache::LRUList* lru_list) : cache(lru_list) {}
    Cache cache;

    // A weak reference to the key function. We use the weak reference to
    // register a callback that is triggered when the key function is
    // destroyed. We use a weak pointer because we want to allow caching
    // across multiple calls to `jax.jit(f)` if `f` remains alive, but we do
    // not want the cache to keep `f` alive if all other references are
    // dropped.
    py::weakref weakref;
  };

  Cache::LRUList lru_list_;
  absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
};

CompiledFunctionCache::CompiledFunctionCache(int capacity)
    : lru_list_(capacity) {}

CompiledFunctionCache::Cache* CompiledFunctionCache::Lookup(
    py::handle function, absl::Span<const int> donate_argnums) {
  Key key;
  key.function = function;
  key.donate_argnums =
      std::vector<int>(donate_argnums.begin(), donate_argnums.end());
  auto insert = functions_.emplace(key, nullptr);
  std::unique_ptr<Value>& entry = insert.first->second;
  if (insert.second) {
    entry = std::make_unique<Value>(&lru_list_);
    entry->weakref = py::weakref(
        function,
        py::cpp_function([this, key{std::move(key)}](py::handle weakref) {
          functions_.erase(key);
        }));
  }
  return &entry->cache;
}
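
// Note: when the jitted Python function is garbage collected, the weakref
// callback above erases its Key from functions_, dropping the per-function
// Cache; in-flight shared_ptr<CacheEntry> references keep individual entries
// alive until their users are done with them.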

// A `CompiledFunction` is associated with a `jax.jit(f)` and takes care of
// the bookkeeping of the different signatures used and the dispatch of calls
// to the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
 public:
  CompiledFunction(py::function fun, py::function cache_miss,
                   py::function get_device, std::vector<int> static_argnums,
                   std::vector<py::str> static_argnames,
                   std::vector<int> donate_argnums,
                   std::shared_ptr<CompiledFunctionCache> cache);
  ~CompiledFunction();

  // pybind11::object typed subclass for CompiledFunction objects.
  class pyobject : public pybind11::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    pybind11::object, CompiledFunction::IsCompiledFunction);
    pyobject() = default;
    CompiledFunction* func() const {
      return CompiledFunction::AsCompiledFunctionUnchecked(*this);
    }
  };
  // Alias as ::object; outside the scope above we won't confuse pybind11's
  // macros.
  using object = pyobject;

  // Returns true if `h` is a CompiledFunction.
  static bool IsCompiledFunction(pybind11::handle handle);
  // Converts `handle` to a CompiledFunction*. Does not do any checking.
  static CompiledFunction* AsCompiledFunctionUnchecked(pybind11::handle handle);

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `DeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
  xla::StatusOr<py::object> Call(py::handle args,
                                 absl::optional<py::kwargs> kwargs);

  // This allows `inspect.signature(cpp_jitted_f)` from Python.
  py::object PythonSignature() {
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  int cache_size() const { return executables_->Size(); }
  void ClearCache() { executables_->Clear(); }

  const py::function& fun() const { return fun_; }
  const py::function& cache_miss() const { return cache_miss_; }
  const py::function& get_device() const { return get_device_; }
  const std::vector<int>& static_argnums() const { return static_argnums_; }
  const std::vector<py::str>& static_argnames() const {
    return static_argnames_;
  }
  const std::vector<int>& donate_argnums() const { return donate_argnums_; }
  const std::shared_ptr<CompiledFunctionCache>& cache() const { return cache_; }

  // Helper function used by the tp_clear GC method.
  void ClearPythonReferences() {
    py::function fun, cache_miss, get_device;
    // Swap values for nulls before they are destroyed. See the Python
    // Py_CLEAR() documentation for a discussion of this topic.
    std::swap(fun_, fun);
    std::swap(cache_miss_, cache_miss);
    std::swap(get_device_, get_device);
  }

  py::handle AsPyHandle();
  const std::string& function_name() const { return function_name_; }

 private:
  // Attempts to populate default_device_. May release the GIL; is
  // reentrant-safe.
  void TryToPopulateDefaultDevice();

  void PopulateCacheEntry(CacheEntry* entry, const CallSignature& signature,
                          const py::tuple& out_and_fastpath_data);
  bool always_fallback_to_python_ = false;

  py::function fun_;  // The Python function to jit.
  std::string function_name_;

  // See JAX _cpp_jit in api.py for documentation.
  py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // Keyword arguments, interned.
  std::vector<py::str> static_argnames_;
  std::vector<int> donate_argnums_;

  // A function taking no arguments and returning the default device and
  // whether jax.jit has been committed to it.
  py::function get_device_;

  // Keeps the shared LRU cache alive as long as the CompiledFunction is
  // alive.
  std::shared_ptr<CompiledFunctionCache> cache_;

  // The part of cache_ specific to this CompiledFunction.
  CompiledFunctionCache::Cache* executables_;

  // The logic is the following:
  // - if `device` or `backend` are not specified to `jax.jit`, we will use
  //   the input sticky buffer device, or `default_device_` if there is no
  //   such sticky buffer.
  // - when one of `device` or `backend` is specified, this will determine
  //   the `default_device_`, which will be used as the targeted device. In
  //   that case, we will always copy input buffers to this device.
  // These fields are protected by the GIL.
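  //
  // For example (illustrative): with jax.jit(f, backend="tpu"),
  // `default_device_` is fixed to a TPU device and inputs are always copied
  // to it; with plain jax.jit(f), a sticky input buffer living on GPU #1
  // routes the whole computation to GPU #1.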
  std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
  xla::PjRtDevice* default_device_ = nullptr;
  bool is_committed_;
};

CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
                                   py::function get_device,
                                   std::vector<int> static_argnums,
                                   std::vector<py::str> static_argnames,
                                   std::vector<int> donate_argnums,
                                   std::shared_ptr<CompiledFunctionCache> cache)
    : fun_(std::move(fun)),
      cache_miss_(std::move(cache_miss)),
      static_argnums_(std::move(static_argnums)),
      static_argnames_(std::move(static_argnames)),
      donate_argnums_(donate_argnums),
      get_device_(std::move(get_device)),
      cache_(std::move(cache)) {
  std::sort(static_argnums_.begin(), static_argnums_.end());
  for (py::str& s : static_argnames_) {
    PyUnicode_InternInPlace(&s.ptr());
  }
  executables_ = cache_->Lookup(fun_, donate_argnums);
  function_name_ = py::str(py::getattr(fun_, "__name__", fun_));
}

CompiledFunction::~CompiledFunction() = default;

// Computes the signature for the arguments.
//
// Returns `Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
xla::Status ComputeSignature(bool jax_enable_x64, xla::PyClient& pyclient,
                             xla::PjRtDevice* default_device, bool is_committed,
                             ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ComputeSignature");

  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
  struct PythonTypes {
    py::object device_array;
  };
  static const auto& types = *[]() -> PythonTypes* {
    py::module xla_module(py::module::import("jax.interpreters.xla"));
    py::object device_array(xla_module.attr("_DeviceArray"));
    return new PythonTypes{device_array};
  }();
  // When the jitted function is not committed, we first check whether any
  // sticky `DeviceArray` is present and on which device it lives. See also:
  // https://github.com/google/jax/pull/1884
  // https://github.com/google/jax/pull/1916 for the rationale why the
  // computation follows the data locality.
  // It's also similar to PyTorch's behavior.
  xla::PjRtDevice* data_device = nullptr;
  if (is_committed) {
    data_device = default_device;
  } else {
    for (int i = 0; i < num_flat_dynamic_args; ++i) {
      py::handle arg = arguments.flat_dynamic_args[i];
      // We specifically only deal with DeviceArray (not ShardedDeviceArray).
      // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
      xla::PjRtDevice* device = nullptr;
      if (arg.get_type().ptr() == xla::PyBuffer::type()) {
        xla::PyBuffer* buffer = xla::PyBuffer::AsPyBufferUnchecked(arg);
        if (!buffer->sticky_device()) {
          continue;
        }
        device = buffer->sticky_device();
      } else if (arg.get_type().ptr() == types.device_array.ptr()) {
        if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
          continue;
        }
        try {
          // This can fail, e.g. for cloud TPU 2VM buffers.
          TF_ASSIGN_OR_RETURN(
              xla::PyBuffer * buffer,
              xla::PyBuffer::AsPyBuffer(arg.attr("device_buffer")));
          device = buffer->buffer()->device();
        } catch (const py::cast_error& e) {
          return xla::InvalidArgument(
              "%s",
              absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
                           "`device_buffer` field is of type ",
                           py::cast<std::string>(
                               arg.attr("device_buffer").get_type().str()),
                           " while a `PyBuffer` was expected."));
        }
      }
      if (device) {
        if (data_device && (device != data_device)) {
          throw std::invalid_argument(absl::StrCat(
              "primitive arguments must be colocated on the same device ("
              "C++ jax.jit). Arguments are on devices: ",
              device->DebugString(), " and ", data_device->DebugString()));
        } else {
          data_device = device;
        }
      }
    }
  }
  if (!data_device) {
    // No sticky `DeviceArray` was found; default to `default_device`.
    data_device = default_device;
  }
  CHECK(data_device);
  arguments.signature.device = data_device;

  arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
  for (int i = 0; i < num_flat_dynamic_args; ++i) {
    py::handle arg = arguments.flat_dynamic_args[i];
    TF_ASSIGN_OR_RETURN(auto sig,
                        xla::PyArgSignatureOfValue(arg, jax_enable_x64));
    arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
  }
  return xla::Status::OK();
}

// Copy buffers to device, skipping pruned arguments.
// Returns `Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
xla::Status CopyBuffersToDevice(
    bool jax_enable_x64, const absl::optional<std::vector<bool>>& kept_args,
    ParsedArgumentsAsBuffers& arguments) {
  std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
  xla::PjRtDevice* data_device = arguments.signature.device;

  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
  xla::DevicePutOptions options;
  options.squash_64bit_types = !jax_enable_x64;
  // TODO(phawkins): consider allowing forces here.
  options.force_lazy_arrays = false;
  options.allow_zero_copy = true;
  arg_buffers.reserve(num_flat_dynamic_args);
  bool input_pruning_enabled = kept_args.has_value();
  for (int i = 0; i < num_flat_dynamic_args; ++i) {
    if (input_pruning_enabled && !kept_args.value()[i]) {
      continue;
    }

    py::handle arg = arguments.flat_dynamic_args[i];
    TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
                        DevicePut(arg, data_device, options));

    xla::PjRtBuffer* buffer = on_device.buffer;
    arg_buffers.push_back(buffer);
    if (on_device.owned_buffer) {
      arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
    } else if (on_device.owning_pybuffer) {
      arguments.keep_alive_objects.push_back(
          std::move(on_device.owning_pybuffer));
    }
  }
  return xla::Status::OK();
}
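
// Example (illustrative): if kept_var_bitvec is {true, false, true}, the
// second flattened argument was eliminated by the jaxpr DCE pass, so no
// buffer is transferred or passed to the executable for it.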

void CompiledFunction::PopulateCacheEntry(
    CacheEntry* cache_entry, const CallSignature& signature,
    const py::tuple& out_and_fastpath_data) {
  CHECK_EQ(out_and_fastpath_data.size(), 2);
  if (out_and_fastpath_data[1].is_none()) {
    cache_entry->fall_back_to_python = true;
    return;
  }

  py::tuple executable_handlers_out_tree =
      py::cast<py::tuple>(out_and_fastpath_data[1]);
  // TODO(zhangqiaorjc): Look up the NamedTuple by name after the minimum jax
  // version bump.
  size_t arity = executable_handlers_out_tree.size();
  if (arity != 5 && !py::hasattr(executable_handlers_out_tree, "_fields")) {
    throw std::runtime_error(absl::StrCat(
        "The versions of jaxlib and JAX are incompatible (jaxlib is too "
        "recent compared to JAX); upgrading JAX is advised. The C++ code "
        "expects 5 or 6 arguments but ",
        arity, " were provided: ",
        py::cast<std::string>(
            py::str(py::repr(executable_handlers_out_tree)))));
  }
  // (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
  auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
      executable_handlers_out_tree[0]);
  cache_entry->executable = std::move(executable);
  int num_devices =
      cache_entry->executable->pjrt_executable().addressable_devices().size();
  // The presence of jit(pmap) is detected from Python.
  CHECK_EQ(num_devices, 1);

  auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
  cache_entry->out_pytree_def = std::move(out_tree);

  cache_entry->sticky_device =
      py::cast<xla::PjRtDevice*>(executable_handlers_out_tree[2]);
  auto avals = py::cast<py::list>(executable_handlers_out_tree[3]);
  auto lazy_exprs = py::cast<py::list>(executable_handlers_out_tree[4]);
  CHECK_EQ(avals.size(), lazy_exprs.size());

  cache_entry->out_avals.reserve(avals.size());
  cache_entry->out_weak_types.reserve(avals.size());
  cache_entry->out_lazy_exprs.reserve(avals.size());
  for (int i = 0; i < avals.size(); ++i) {
    py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
    py::object lazy_expr = py::reinterpret_borrow<py::object>(lazy_exprs[i]);

    cache_entry->out_avals.push_back(shaped_array);
    cache_entry->out_weak_types.push_back(
        py::cast<bool>(shaped_array.attr("weak_type")));
    cache_entry->out_lazy_exprs.push_back(lazy_expr);
  }
  auto kept_var_bitvec_attr =
      py::getattr(executable_handlers_out_tree, "kept_var_bitvec", py::none());
  if (!kept_var_bitvec_attr.is_none()) {
    auto kept_var_bitvec = py::cast<py::list>(kept_var_bitvec_attr);
    cache_entry->kept_var_bitvec =
        absl::make_optional<std::vector<bool>>(kept_var_bitvec.size(), false);
    for (int i = 0; i < kept_var_bitvec.size(); ++i) {
      cache_entry->kept_var_bitvec.value()[i] =
          py::cast<bool>(kept_var_bitvec[i]);
    }
  }
}

void CompiledFunction::TryToPopulateDefaultDevice() {
  // The following line calls Python and may release the GIL.
  py::object device_and_is_committed;
  try {
    device_and_is_committed = get_device_();
  } catch (py::error_already_set& e) {
    // Backend or device initialization failed. Handle this in Python.
    always_fallback_to_python_ = true;
    return;
  }
  // If the GIL was released by the call to get_device_, another thread may
  // have filled in default_device_.
  if (!default_device_) {
    try {
      auto default_pydevice = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
          device_and_is_committed.attr("default_device"));
      is_committed_ =
          py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
      default_pyclient_ = default_pydevice.client;
      default_device_ = default_pydevice.contents;
    } catch (const py::cast_error& e) {
      // Pathways, Cloud TPU 2VM, and UPTC runtime.
      always_fallback_to_python_ = true;
    }
  }
}

xla::StatusOr<py::object> CompiledFunction::Call(
    py::handle args, absl::optional<py::kwargs> kwargs) {
  // Make sure we trigger a garbage collection on JIT function calls. Otherwise
  // code like
  //   f = jit(...)
  //   while True:
  //     f(x)
  // may never free temporary buffers for copies of arguments.
  xla::GlobalPyRefManager()->MaybeCollectGarbage();

  auto& tls = thread_local_state;
  if (tls.disable_jit.value_or(global_state.disable_jit)) {
    return fun_(*py::reinterpret_borrow<py::args>(args),
                **kwargs.value_or(py::kwargs()));
  }
  if (always_fallback_to_python_) {
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  // On the first call to `Call`, compute a default device. We need to wait
  // until after platform initialization is complete before doing so, but @jit
  // may be used as a decorator.
  if (!default_device_) {
    TryToPopulateDefaultDevice();
    if (!default_device_) {
      return py::object(py::cast<py::tuple>(
          cache_miss_(*py::reinterpret_borrow<py::args>(args),
                      **kwargs.value_or(py::kwargs())))[0]);
    }
  }

  ParsedArgumentsAsBuffers arguments;
  xla::Status status = ParseArguments(args, kwargs, static_argnums_,
                                      static_argnames_, arguments);
  if (!status.ok()) {
    VLOG(2) << "ParseArguments failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  bool jax_enable_x64 = tls.enable_x64.value_or(global_state.enable_x64);
  arguments.signature.jax_enable_x64 = jax_enable_x64;
  // The C++ jit does not support Tracer arguments yet. The Python-based jit
  // function will be called if any of the dynamic arguments is unsupported.
  status = ComputeSignature(jax_enable_x64, *default_pyclient_, default_device_,
                            is_committed_, arguments);
  if (!status.ok()) {
    VLOG(2) << "ComputeSignature failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }
  arguments.signature.global_extra_jit_context = global_state.extra_jit_context;
  arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context;

  bool inserted = false;
  std::shared_ptr<CacheEntry> cache_entry = executables_->GetOrCreateIfAbsent(
      arguments.signature, [&inserted](const CallSignature& key) {
        inserted = true;
        return std::make_shared<CacheEntry>();
      });

  if (!cache_entry->compilation_complete.HasBeenNotified()) {
    // In case of several threads attempting to compile the executable, only
    // the one that inserted the item will perform the compilation.
    if (inserted) {
      py::object out_and_fastpath_data;
      py::tuple out_tuple;
      VLOG(2) << "Cache miss for " << arguments.signature.DebugString();
      try {
        // Calls Python and may release the GIL. May also throw if
        // compilation/tracing fails.
        out_and_fastpath_data =
            cache_miss_(*py::reinterpret_borrow<py::args>(args),
                        **kwargs.value_or(py::kwargs()));
        out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
        PopulateCacheEntry(cache_entry.get(), arguments.signature, out_tuple);
      } catch (const std::exception& e) {
        cache_entry->fall_back_to_python = true;
        cache_entry->compilation_complete.Notify();
        throw;
      }
      cache_entry->compilation_complete.Notify();

      // We have already computed the result in the miss path so we can return
      // it. We are even *required* to do so if there are donated arguments,
      // because any donated buffers will now be invalid.
      return py::object(out_tuple[0]);
    } else {
      // Release the GIL while we wait, making sure the compile thread can
      // lock it.
      py::gil_scoped_release release;
      cache_entry->compilation_complete.WaitForNotification();
    }
  }
  // It's hard to re-raise exactly the same kind of error when a compilation
  // error occurred. If the first compilation failed, other threads will also
  // execute the Python path.
  if (cache_entry->fall_back_to_python) {
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  status = CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
                               arguments);
  if (!status.ok()) {
    VLOG(2) << "CopyBuffersToDevice failed: " << status;
    return py::object(
        py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
                                        **kwargs.value_or(py::kwargs())))[0]);
  }

  // Executes the computation.
  std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
  {
    py::gil_scoped_release gil_release;
    TF_ASSIGN_OR_RETURN(
        output_buffers,
        cache_entry->executable->mutable_pjrt_executable()->Execute(
            {arguments.arg_buffers}, cache_entry->executable->options()));
  }
  auto traceback = xla::Traceback::Get();

  int num_outputs = output_buffers[0].size();
  absl::InlinedVector<py::object, 1> flat_device_arrays;
  flat_device_arrays.reserve(num_outputs);
  for (int i = 0; i < output_buffers[0].size(); ++i) {
    bool last = (i == (num_outputs - 1));
    xla::PyBuffer::object buffer = xla::PyBuffer::Make(
        cache_entry->executable->client(), std::move(output_buffers[0][i]),
        last ? std::move(traceback) : traceback);
    if (cache_entry->out_lazy_exprs[i].is_none()) {  // No LazyExpr.
      buffer.buf()->SetAval(cache_entry->out_avals[i]);
      buffer.buf()->set_weak_type(cache_entry->out_weak_types[i]);
      TF_RETURN_IF_ERROR(
          buffer.buf()->set_sticky_device(cache_entry->sticky_device));
      flat_device_arrays.push_back(std::move(buffer));
    } else {
      static const auto* xla_module =
          new py::module(py::module::import("jax.interpreters.xla"));
      static const auto* device_array =
          new py::handle(xla_module->attr("_DeviceArray"));
      flat_device_arrays.push_back(
          (*device_array)(cache_entry->out_avals[i],
                          py::cast(WrapWithClient(default_pyclient_,
                                                  cache_entry->sticky_device)),
                          cache_entry->out_lazy_exprs[i], std::move(buffer)));
    }
  }
  py::object out = cache_entry->out_pytree_def.Unflatten(flat_device_arrays);

  // If there is a post-hook function, call it with the inputs and the outputs.
  absl::optional<py::object> post_hook =
      tls.post_hook.has_value() ? tls.post_hook : global_state.post_hook;
  if (post_hook) {
    (*post_hook)(AsPyHandle(), args,
                 py::cast<py::dict>(kwargs.value_or(py::kwargs())), out);
  }
  return std::move(out);
}

struct JaxCompiledFunctionObject {
  PyObject_HEAD;
  PyObject* dict;      // Dictionary for __dict__.
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  CompiledFunction fun;
};

PyObject* JaxCompiledFunction_Type = nullptr;

bool CompiledFunction::IsCompiledFunction(py::handle handle) {
  return handle.get_type() == JaxCompiledFunction_Type;
}

CompiledFunction* CompiledFunction::AsCompiledFunctionUnchecked(
    py::handle handle) {
  return &(reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr())->fun);
}

xla::StatusOr<CompiledFunction*> AsCompiledFunction(py::handle handle) {
  if (!CompiledFunction::IsCompiledFunction(handle)) {
    return xla::InvalidArgument("Expected a CompiledFunction");
  }
  return CompiledFunction::AsCompiledFunctionUnchecked(handle);
}

py::handle CompiledFunction::AsPyHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(JaxCompiledFunctionObject, fun));
}

extern "C" {

PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
                                     PyObject* kwds) {
  JaxCompiledFunctionObject* self =
      reinterpret_cast<JaxCompiledFunctionObject*>(
          subtype->tp_alloc(subtype, 0));
  if (!self) return nullptr;
  self->dict = nullptr;
  self->weakrefs = nullptr;
  return reinterpret_cast<PyObject*>(self);
}

void JaxCompiledFunction_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  o->fun.~CompiledFunction();
  tp->tp_free(self);
  Py_DECREF(tp);
}

int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
                                    void* arg) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_VISIT(o->dict);
  Py_VISIT(o->fun.fun().ptr());
  Py_VISIT(o->fun.cache_miss().ptr());
  Py_VISIT(o->fun.get_device().ptr());
  return 0;
}

int JaxCompiledFunction_tp_clear(PyObject* self) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_CLEAR(o->dict);
  o->fun.ClearPythonReferences();
  return 0;
}

// Implements the Python descriptor protocol so JIT-compiled functions can be
// used as bound methods. See:
// https://docs.python.org/3/howto/descriptor.html#functions-and-methods
PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
                                           PyObject* type) {
  if (obj == nullptr || obj == Py_None) {
    Py_INCREF(self);
    return self;
  }
  return PyMethod_New(self, obj);
}
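
// Illustrative effect of tp_descr_get: given
//   class A:
//     f = jit(g)  # a CompiledFunction stored as a class attribute
// the access `A().f` yields a bound method, so `A().f(x)` calls g(self, x),
// while `A.f` returns the CompiledFunction itself.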

// Support d = instance.__dict__.
PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (!o->dict) {
    o->dict = PyDict_New();
  }
  Py_XINCREF(o->dict);
  return o->dict;
}

int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (!PyDict_Check(new_dict)) {
    PyErr_Format(PyExc_TypeError,
                 "__dict__ must be set to a dictionary, not a '%s'",
                 Py_TYPE(new_dict)->tp_name);
    return -1;
  }
  Py_INCREF(new_dict);
  Py_CLEAR(o->dict);
  o->dict = new_dict;
  return 0;
}

static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxCompiledFunction_get_dict,
     JaxCompiledFunction_set_dict, nullptr, nullptr},
    {nullptr, nullptr, nullptr, nullptr, nullptr}};

PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
                                      PyObject* kwargs) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  tensorflow::profiler::TraceMe traceme([&] {
    return absl::StrCat("JaxCompiledFunction(", o->fun.function_name(), ")");
  });
  absl::optional<py::kwargs> py_kwargs;
  if (kwargs) {
    py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
  }
  try {
    xla::StatusOr<py::object> out = o->fun.Call(args, std::move(py_kwargs));
    if (!out.ok()) {
      PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
      return nullptr;
    }
    return out.ValueOrDie().release().ptr();
  } catch (py::error_already_set& e) {
    e.restore();
    return nullptr;
  } catch (py::cast_error& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  } catch (std::invalid_argument& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  }
}

void InitializeCompiledFunction(JaxCompiledFunctionObject* cfun,
                                py::function fun, py::function cache_miss,
                                py::function get_device,
                                std::vector<int> static_argnums,
                                std::vector<py::str> static_argnames,
                                std::vector<int> donate_argnums,
                                std::shared_ptr<CompiledFunctionCache> cache) {
  new (&cfun->fun) CompiledFunction(
      std::move(fun), std::move(cache_miss), std::move(get_device),
      std::move(static_argnums), std::move(static_argnames),
      std::move(donate_argnums), std::move(cache));
}

}  // extern "C"

py::object MakeCompiledFunction(py::function fun, py::function cache_miss,
                                py::function get_device,
                                std::vector<int> static_argnums,
                                std::vector<py::str> static_argnames,
                                std::vector<int> donate_argnums,
                                std::shared_ptr<CompiledFunctionCache> cache) {
  py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
      reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
      nullptr));
  JaxCompiledFunctionObject* buf =
      reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());
  if (!cache) {
    cache = std::make_shared<CompiledFunctionCache>(
        CompiledFunctionCache::kDefaultCapacity);
  }
  InitializeCompiledFunction(buf, std::move(fun), std::move(cache_miss),
                             std::move(get_device), std::move(static_argnums),
                             std::move(static_argnames),
                             std::move(donate_argnums), std::move(cache));
  return obj;
}

// Helpers for building Python properties.
template <typename Func>
py::object property_readonly(Func&& get) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<Func>(get)), py::none(),
                  py::none(), "");
}

// Version numbers for the pickled representations of
// CompiledFunction/CompiledFunctionCache. Increment these if changing them.
const int kCompiledFunctionCachePickleVersion = 1;
const int kCompiledFunctionPickleVersion = 1;

}  // namespace

void BuildJaxjitSubmodule(py::module& m) {
  py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");

  // We define CompiledFunctionCache in the xla_extension module to work
  // around https://github.com/cloudpipe/cloudpickle/issues/354, which
  // means classes defined in submodules aren't pickleable.
  // TODO(phawkins): remove this workaround once Google is using a newer
  // cloudpickle. Open-source users can simply pip install a newer version.
  py::class_<CompiledFunctionCache, std::shared_ptr<CompiledFunctionCache>>
      cache(m, "CompiledFunctionCache");
  cache.def(py::init<int>(),
            py::arg("capacity") = CompiledFunctionCache::kDefaultCapacity);
  cache.def("size", &CompiledFunctionCache::Size);
  cache.def("capacity", &CompiledFunctionCache::Capacity);
  cache.def("clear", &CompiledFunctionCache::Clear);
  cache.def(py::pickle(
      // __getstate__
      // Pickles as an empty cache; the client can repopulate as needed.
      [](const CompiledFunctionCache& cache) {
        py::dict pickle;
        pickle["version"] = kCompiledFunctionCachePickleVersion;
        pickle["capacity"] = cache.Capacity();
        return pickle;
      },
      // __setstate__
      [](const py::dict& pickle) {
        int version = py::cast<int>(pickle["version"]);
        if (version != kCompiledFunctionCachePickleVersion) {
          throw std::invalid_argument(absl::StrFormat(
              "Invalid CompiledFunctionCache pickle version, got %d, "
              "expected %d",
              version, kCompiledFunctionCachePickleVersion));
        }
        int capacity = py::cast<int>(pickle["capacity"]);
        return std::make_shared<CompiledFunctionCache>(capacity);
      }));

  // Alias CompiledFunctionCache in the submodule where we actually wanted to
  // define it.
  jitlib.attr("CompiledFunctionCache") = cache;

  // We need to use heap-allocated type objects because we want to add
  // additional methods dynamically.
  py::object cfun;
  {
    py::str name = py::str("CompiledFunction");
    py::str qualname = py::str("CompiledFunction");
    PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
        PyType_Type.tp_alloc(&PyType_Type, 0));
    // Caution: we must not call any functions that might invoke the GC until
    // PyType_Ready() is called. Otherwise the GC might see a half-constructed
    // type object.
    CHECK(heap_type) << "Unable to create heap type object";
    heap_type->ht_name = name.release().ptr();
    heap_type->ht_qualname = qualname.release().ptr();
    PyTypeObject* type = &heap_type->ht_type;
    type->tp_name = "CompiledFunction";
    type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
    type->tp_flags =
        Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
    type->tp_new = JaxCompiledFunction_tp_new;
    type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
    type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
    type->tp_traverse = JaxCompiledFunction_tp_traverse;
    type->tp_clear = JaxCompiledFunction_tp_clear;
    type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
    type->tp_getset = JaxCompiledFunction_tp_getset;
    type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
    type->tp_call = JaxCompiledFunction_tp_call;
    CHECK_EQ(PyType_Ready(type), 0);
    JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
    cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
  }
  py::object cfun_type =
      py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);

  // Add CompiledFunction to the xla_extension module so it can be pickled.
  m.attr("CompiledFunction") = cfun_type;

  cfun.attr("__signature__") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        return fun->PythonSignature();
      });
  cfun.attr("_cache_miss") =
      property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        return fun->cache_miss();
      });
  cfun.attr("__getstate__") = py::cpp_function(
      [](const CompiledFunction::object& self) {
        CompiledFunction* fn = self.func();
        py::dict pickle;
        pickle["version"] = kCompiledFunctionPickleVersion;
        pickle["fun"] = fn->fun();
        pickle["cache_miss"] = fn->cache_miss();
        pickle["get_device"] = fn->get_device();
        pickle["static_argnums"] = fn->static_argnums();
        pickle["static_argnames"] = fn->static_argnames();
        pickle["donate_argnums"] = fn->donate_argnums();
        pickle["cache"] = fn->cache();
        return pickle;
      },
      py::is_method(cfun_type));
  cfun.attr("__setstate__") = py::cpp_function(
      [](CompiledFunction::object& self, const py::dict& pickle) {
        int version = py::cast<int>(pickle["version"]);
        if (version != kCompiledFunctionPickleVersion) {
          throw std::invalid_argument(absl::StrFormat(
              "Invalid CompiledFunction pickle version, got %d, expected %d",
              version, kCompiledFunctionPickleVersion));
        }
        py::function fun = py::cast<py::function>(pickle["fun"]);
        py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
        py::function get_device = py::cast<py::function>(pickle["get_device"]);
        std::vector<int> static_argnums =
            py::cast<std::vector<int>>(pickle["static_argnums"]);
        std::vector<py::str> static_argnames =
            py::cast<std::vector<py::str>>(pickle["static_argnames"]);
        std::vector<int> donate_argnums =
            py::cast<std::vector<int>>(pickle["donate_argnums"]);
        std::shared_ptr<CompiledFunctionCache> cache =
            py::cast<std::shared_ptr<CompiledFunctionCache>>(pickle["cache"]);
        InitializeCompiledFunction(
            reinterpret_cast<JaxCompiledFunctionObject*>(self.ptr()),
            std::move(fun), std::move(cache_miss), std::move(get_device),
            std::move(static_argnums), std::move(static_argnames),
            std::move(donate_argnums), std::move(cache));
      },
      py::is_method(cfun_type));
  py::class_<GlobalJitState> global_state_(jitlib, "GlobalJitState");
  global_state_.def_readwrite("disable_jit", &GlobalJitState::disable_jit);
  global_state_.def_readwrite("enable_x64", &GlobalJitState::enable_x64);
  global_state_.def_readwrite("extra_jit_context",
                              &GlobalJitState::extra_jit_context);
  global_state_.def_readwrite("post_hook", &GlobalJitState::post_hook);

  py::class_<ThreadLocalJitState> thread_local_state_(jitlib,
                                                      "ThreadLocalJitState");
  thread_local_state_.def_readwrite("disable_jit",
                                    &ThreadLocalJitState::disable_jit);
  thread_local_state_.def_readwrite("enable_x64",
                                    &ThreadLocalJitState::enable_x64);
  thread_local_state_.def_readwrite("extra_jit_context",
                                    &ThreadLocalJitState::extra_jit_context);
  thread_local_state_.def_readwrite("post_hook",
                                    &ThreadLocalJitState::post_hook);

  jitlib.def(
      "global_state", [&]() { return &global_state; },
      py::return_value_policy::reference);
  jitlib.def(
      "thread_local_state", [&]() { return &thread_local_state; },
      py::return_value_policy::reference);

  jitlib.def("jit_is_disabled", &JitIsDisabled);
  jitlib.def("get_enable_x64", &GetEnableX64);

  // TODO(phawkins): delete the following methods after dropping compatibility
  // with JAX python versions older than 0.2.10.
  jitlib.def("set_disable_jit_cpp_flag",
             [&](bool disable_jit) { global_state.disable_jit = disable_jit; });
  jitlib.def("get_disable_jit_cpp_flag",
             [&]() { return global_state.disable_jit; });
  jitlib.def("set_disable_jit_thread_local",
             [&](absl::optional<bool> disable_jit) {
               thread_local_state.disable_jit = disable_jit;
             });
  jitlib.def("get_disable_jit_thread_local",
             [&]() { return thread_local_state.disable_jit; });
  // TODO(jblespiau): Remove from the Python code and remove this.
  jitlib.def("set_disable_jit", [&](bool disable_jit) {
    thread_local_state.disable_jit = disable_jit;
  });
  jitlib.def("get_disable_jit",
             [&]() { return thread_local_state.disable_jit; });

  jitlib.def("set_enable_x64_cpp_flag",
             [&](bool enable_x64) { global_state.enable_x64 = enable_x64; });
  jitlib.def("get_enable_x64_cpp_flag",
             [&]() { return global_state.enable_x64; });
  jitlib.def("set_enable_x64_thread_local",
             [&](absl::optional<bool> enable_x64) {
               thread_local_state.enable_x64 = enable_x64;
             });
  jitlib.def("get_enable_x64_thread_local",
             [&]() { return thread_local_state.enable_x64; });
  // TODO(phawkins): delete up to here.

  jitlib.def(
      "jit",
      [](py::function fun, py::function cache_miss, py::function get_device,
         std::vector<int> static_argnums, std::vector<py::str> static_argnames,
         std::vector<int> donate_argnums,
         std::shared_ptr<CompiledFunctionCache> cache) -> py::object {
        return MakeCompiledFunction(
            std::move(fun), std::move(cache_miss), std::move(get_device),
            std::move(static_argnums), std::move(static_argnames),
            std::move(donate_argnums), std::move(cache));
      },
      py::arg("fun"), py::arg("cache_miss"), py::arg("get_device"),
      py::arg("static_argnums"),
      py::arg("static_argnames") = std::vector<py::str>(),
      py::arg("donate_argnums") = std::vector<int>(),
      py::arg("cache") = nullptr);

  // This function is not yet a full replacement for the Python one, because:
  // (a) it does not support abstract types,
  // (b) it does not set the device stickiness yet.
  // TODO(jblespiau): Finish the replacement of the Python feature.
  jitlib.def("device_put",
             [](py::handle obj, bool jax_enable_x64,
                xla::ClientAndPtr<xla::PjRtDevice> to_device)
                 -> xla::StatusOr<py::object> {
               std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
               xla::DevicePutOptions options;
               options.squash_64bit_types = !jax_enable_x64;
               options.force_lazy_arrays = true;
               options.allow_zero_copy = true;
               xla::StatusOr<xla::DevicePutResult> results =
                   DevicePut(obj, to_device.contents, options);
               if (!results.ok()) {
                 throw std::runtime_error(results.status().error_message());
               }
               if (results->owned_buffer) {
                 auto buffer = xla::PyBuffer::Make(
                     pyclient, std::move(results->owned_buffer),
                     xla::Traceback::Get());

                 static const auto* jax_core =
                     new py::module(py::module::import("jax.core"));
                 static const auto* shaped_array =
                     new py::handle(jax_core->attr("ShapedArray"));
                 buffer.buf()->SetAval((*shaped_array)(
                     buffer.buf()->python_shape(), buffer.buf()->python_dtype(),
                     results->weak_type));
                 TF_RETURN_IF_ERROR(buffer.buf()->set_sticky_device(nullptr));

                 return std::move(buffer);
               } else {
                 return py::cast<py::object>(obj);
               }
             });

  py::class_<xla::PyArgSignature> arg_signature(jitlib, "PyArgSignature");
  arg_signature
      .def_property_readonly("dtype",
                             [](const xla::PyArgSignature& sig) {
                               return PrimitiveTypeToDtype(sig.dtype);
                             })
      .def_property_readonly(
          "shape",
          [](const xla::PyArgSignature& sig) {
            return xla::SpanToTuple(absl::MakeConstSpan(sig.shape));
          })
      .def_readonly("weak_type", &xla::PyArgSignature::weak_type);
  jitlib.def("_ArgSignatureOfValue", &xla::PyArgSignatureOfValue);

  // All private members are only for testing/debugging purposes.
  cfun.attr("_cache_size") = py::cpp_function(
      [](py::handle self) -> xla::StatusOr<int> {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        return fun->cache_size();
      },
      py::is_method(cfun));
  cfun.attr("_clear_cache") = py::cpp_function(
      [](py::handle self) -> xla::Status {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        fun->ClearCache();
        return xla::Status::OK();
      },
      py::is_method(cfun));
  jitlib.def("_is_float0", &xla::IsFloat0);
}

}  // namespace jax