• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This files implements the `jax.jit` dispatch and just-in-time feature.
17 //
18 // In a nutshell, `Jit(f)` returns a callable that will dispatch (i.e. forward
19 // based on passed arguments dtypes/shapes/identity) the execution to a
20 // just-in-time compiled XLA Executable. All of that is done in C++ for
21 // performance reasons.
22 //
23 // This file contains the utilities to:
24 // (a) inspect arguments and describe their structure, dtype/shapes, etc.
25 // (b) keep a mapping from function signatures to compiled XLA Executables.
26 
27 #include "tensorflow/compiler/xla/python/jax_jit.h"
28 
29 #include <Python.h>
30 
31 #include <algorithm>
32 #include <exception>
33 #include <memory>
34 #include <stdexcept>
35 #include <string>
36 #include <utility>
37 
38 #include "absl/container/flat_hash_map.h"
39 #include "absl/container/inlined_vector.h"
40 #include "absl/strings/str_cat.h"
41 #include "absl/strings/str_format.h"
42 #include "absl/synchronization/notification.h"
43 #include "absl/types/optional.h"
44 #include "absl/types/span.h"
45 #include "pybind11/cast.h"
46 #include "pybind11/numpy.h"
47 #include "pybind11/pybind11.h"
48 #include "pybind11/pytypes.h"
49 #include "tensorflow/compiler/xla/pjrt/lru_cache.h"
50 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
51 #include "tensorflow/compiler/xla/python/py_buffer.h"
52 #include "tensorflow/compiler/xla/python/py_executable.h"
53 #include "tensorflow/compiler/xla/python/py_values.h"
54 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
55 #include "tensorflow/compiler/xla/python/pytree.h"
56 #include "tensorflow/compiler/xla/python/types.h"
57 #include "tensorflow/compiler/xla/shape_util.h"
58 #include "tensorflow/compiler/xla/statusor.h"
59 #include "tensorflow/compiler/xla/types.h"
60 #include "tensorflow/compiler/xla/util.h"
61 #include "tensorflow/compiler/xla/xla_data.pb.h"
62 #include "tensorflow/core/platform/status.h"
63 #include "tensorflow/core/profiler/lib/traceme.h"
64 
65 namespace jax {
66 
67 namespace py = pybind11;
68 
69 // TODO(phawkins): Add support for Tracers.
70 // TODO(jblespiau): Use absl Status.
71 
72 namespace {
73 
74 // Protected by the GIL.
75 GlobalJitState& global_state = *new GlobalJitState();
76 
77 // TODO(phawkins): Google style guide forbids thread-local values with
78 // non-trivial destructors.
79 ABSL_CONST_INIT thread_local ThreadLocalJitState thread_local_state;  // NOLINT
80 
JitIsDisabled()81 bool JitIsDisabled() {
82   return thread_local_state.disable_jit.value_or(global_state.disable_jit);
83 }
84 
85 }  // namespace
86 
GetGlobalState()87 GlobalJitState& GetGlobalState() { return global_state; }
GetLocalState()88 ThreadLocalJitState& GetLocalState() { return thread_local_state; }
89 
GetEnableX64()90 bool GetEnableX64() {
91   return thread_local_state.enable_x64.value_or(global_state.enable_x64);
92 }
93 
DebugString() const94 std::string CallSignature::DebugString() const {
95   auto py_object_formatter = [](std::string* out, const py::object& o) {
96     out->append(py::cast<std::string>(py::str(o)));
97   };
98   auto treedef_formatter = [](std::string* out, const xla::PyTreeDef& d) {
99     out->append(d.ToString());
100   };
101   auto signature_formatter = [](std::string* out,
102                                 const xla::PyArgSignature& s) {
103     out->append(s.DebugString());
104   };
105   std::string thread_local_extra_jit_context_str;
106   if (thread_local_extra_jit_context.has_value()) {
107     thread_local_extra_jit_context_str =
108         py::cast<std::string>(py::str(thread_local_extra_jit_context.value()));
109   } else {
110     thread_local_extra_jit_context_str = "None";
111   }
112   return absl::StrFormat(
113       "static args (positional + keyword): %s\nstatic arg keyword names: %s\n"
114       "dynamic arg signatures (positional + keyword): %s\n"
115       "dynamic arg keyword names: %s\ndynamic arg treedefs: %s\n"
116       "device: %p\n"
117       "jax_enable_x64: %d\n"
118       "global_extra_jit_context: %s\n"
119       "thread_local_extra_jit_context: %s\n",
120       absl::StrJoin(static_args, ",", py_object_formatter),
121       absl::StrJoin(static_arg_names, ",", py_object_formatter),
122       absl::StrJoin(dynamic_arg_signatures, ", ", signature_formatter),
123       absl::StrJoin(dynamic_arg_names, ",", py_object_formatter),
124       absl::StrJoin(dynamic_arg_treedefs, "| ", treedef_formatter),  // new line
125       device, jax_enable_x64,
126       py::cast<std::string>(py::str(global_extra_jit_context)),
127       thread_local_extra_jit_context_str);
128 }
129 
operator ==(const CallSignature & other) const130 bool CallSignature::operator==(const CallSignature& other) const {
131   return std::tie(dynamic_arg_treedefs, dynamic_arg_names,
132                   dynamic_arg_signatures, device, jax_enable_x64,
133                   static_arg_names) ==
134              std::tie(other.dynamic_arg_treedefs, other.dynamic_arg_names,
135                       other.dynamic_arg_signatures, other.device,
136                       other.jax_enable_x64, static_arg_names) &&
137          // `==` on py:objects is the Python `is`. We need equal.
138          std::equal(
139              static_args.begin(), static_args.end(), other.static_args.begin(),
140              other.static_args.end(),
141              [](const py::object& a, const py::object& b) {
142                try {
143                  return a.equal(b);
144                } catch (const py::error_already_set& e) {
145                  throw std::invalid_argument(absl::StrCat(
146                      "static arguments should be comparable using __eq__."
147                      "The following error was raised when comparing two "
148                      "objects of types ",
149                      py::cast<std::string>(py::str(py::type::of(a))), " and ",
150                      py::cast<std::string>(py::str(py::type::of(b))),
151                      ". The error was:\n", e.what()));
152                }
153              }) &&
154          global_extra_jit_context.equal(other.global_extra_jit_context) &&
155          (thread_local_extra_jit_context.has_value() ==
156           other.thread_local_extra_jit_context.has_value()) &&
157          (!thread_local_extra_jit_context.has_value() ||
158           thread_local_extra_jit_context->equal(
159               *other.thread_local_extra_jit_context));
160 }
161 
162 template <typename H>
AbslHashValue(H h,const CallSignature & s)163 H AbslHashValue(H h, const CallSignature& s) {
164   h = H::combine_contiguous(std::move(h), s.dynamic_arg_treedefs.data(),
165                             s.dynamic_arg_treedefs.size());
166   for (const auto& name : s.dynamic_arg_names) {
167     h = H::combine(std::move(h), name.ptr());
168   }
169   h = H::combine_contiguous(std::move(h), s.dynamic_arg_signatures.data(),
170                             s.dynamic_arg_signatures.size());
171   for (const auto& static_arg : s.static_args) {
172     ssize_t hash;
173     try {
174       hash = py::hash(static_arg);
175     } catch (const py::error_already_set& e) {
176       throw std::invalid_argument(absl::StrCat(
177           "Non-hashable static arguments are not supported. An error occured "
178           "while trying to hash an object of type ",
179           py::cast<std::string>(py::str(py::type::of(static_arg))), ", ",
180           py::cast<std::string>(py::str(static_arg)), ". The error was:\n",
181           e.what(), "\n"));
182     }
183     h = H::combine(std::move(h), hash);
184   }
185   for (const auto& name : s.static_arg_names) {
186     h = H::combine(std::move(h), name.ptr());
187   }
188   h = H::combine(std::move(h), s.device, s.jax_enable_x64);
189 
190   // We do not hash the extra_jit_context fields since calling Python hash
191   // functions is expensive (~300ns) and we don't expect a large number of
192   // different contexts.
193   return h;
194 }
195 
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
//
// Static positional arguments (selected by index via `static_argnums`) and
// static keyword arguments (selected by name via `static_argnames`) are
// recorded verbatim in `arguments.signature`; every other argument is
// flattened through its pytree into `arguments.flat_dynamic_args`.
xla::Status ParseArguments(py::handle args,
                           const absl::optional<py::kwargs>& py_kwargs,
                           absl::Span<int const> static_argnums,
                           absl::Span<py::str const> static_argnames,
                           ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ParseArguments");
  int num_args = PyTuple_GET_SIZE(args.ptr());
  int num_kwargs = py_kwargs ? py_kwargs->size() : 0;

  arguments.flat_dynamic_args.reserve(num_args + num_kwargs);
  if (static_argnums.empty()) {
    // Fast path: every positional argument is dynamic, so the treedef vector
    // can be sized up-front.
    arguments.signature.dynamic_arg_treedefs.resize(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      xla::PyTreeDef& pytree_def = arguments.signature.dynamic_arg_treedefs[i];
      pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                             arguments.flat_dynamic_args);
    }
  } else {
    arguments.signature.dynamic_arg_treedefs.reserve(num_args);

    // Positional arguments.
    for (int i = 0; i < num_args; ++i) {
      if (std::find(static_argnums.begin(), static_argnums.end(), i) ==
          static_argnums.end()) {
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(PyTuple_GET_ITEM(args.ptr(), i),
                               arguments.flat_dynamic_args);
      } else {
        // PyTuple_GET_ITEM returns a borrowed reference; take our own so the
        // signature keeps the static argument alive.
        arguments.signature.static_args.emplace_back(
            py::reinterpret_borrow<py::object>(
                PyTuple_GET_ITEM(args.ptr(), i)));
      }
    }
  }

  // Keyword arguments.
  if (py_kwargs) {
    std::vector<std::pair<py::handle, py::handle>> kwargs(py_kwargs->begin(),
                                                          py_kwargs->end());
    // We first intern the keys, then sort them (by name, as in the Python path)
    // (see also xla::PyTreeDef::Flatten) and then create the signatures.
    // TODO(jblespiau): We should be able to sort the keys by interned-key
    // pointers, but this requires the Python compilation to do the same.
    for (int i = 0; i < num_kwargs; ++i) {
      // Intern the key if not already interned.
      // The inc_ref() creates the reference that the reinterpret_steal<> calls
      // below take ownership of; PyUnicode_InternInPlace may replace the
      // pointer with the canonical interned string.
      kwargs[i].first.inc_ref();
      if (!PyUnicode_CHECK_INTERNED(kwargs[i].first.ptr())) {
        PyUnicode_InternInPlace(&kwargs[i].first.ptr());
      }
    }

    // `<` on py::handle invokes Python's comparison, i.e. this sorts the
    // keyword arguments by name (not by pointer).
    std::sort(kwargs.begin(), kwargs.end(),
              [](const std::pair<py::handle, py::handle>& a,
                 const std::pair<py::handle, py::handle>& b) {
                return a.first < b.first;
              });
    // Static keyword arguments are matched by pointer identity; this relies on
    // the names in `static_argnames` being interned, like the keys above.
    auto kwarg_is_static = [&](py::handle name) {
      for (const auto& kw : static_argnames) {
        if (kw.ptr() == name.ptr()) return true;
      }
      return false;
    };

    arguments.signature.dynamic_arg_names.reserve(num_kwargs);
    for (int i = 0; i < num_kwargs; ++i) {
      if (kwarg_is_static(kwargs[i].first)) {
        arguments.signature.static_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.static_args.push_back(
            py::reinterpret_borrow<py::object>(kwargs[i].second));
      } else {
        arguments.signature.dynamic_arg_names.push_back(
            py::reinterpret_steal<py::object>(kwargs[i].first));
        arguments.signature.dynamic_arg_treedefs.emplace_back();
        xla::PyTreeDef& pytree_def =
            arguments.signature.dynamic_arg_treedefs.back();
        pytree_def.FlattenInto(kwargs[i].second, arguments.flat_dynamic_args);
      }
    }
  }
  return xla::Status::OK();
}
284 
285 namespace {
286 
287 // Elements of CacheEntry are protected by the GIL.
288 struct CacheEntry {
289   // Ensures a single thread performs the compilation for a given executable.
290   //
291   // The first thread (holding the GIL) will create the CacheEntry associated to
292   // a signature and fill it. Other threads will wait for the notification.
293   // If an error occured during the compilation, `fall_back_to_python` is set
294   // to `true`, and other threads will fail with the same error.
295   absl::Notification compilation_complete;
296 
297   std::shared_ptr<xla::PyExecutable> executable;
298   xla::PyTreeDef out_pytree_def;
299   // We use Python types within the vector because this is what we will be
300   // returning to Python. No need to convert back and forth.
301   // We need py::object to maintain the objects alive.
302   std::vector<py::object> out_avals;
303   std::vector<bool> out_weak_types;
304 
305   // The processing done in `AddCacheEntry` ensures that LazyExpr are stored as
306   // `py::none()`.
307   std::vector<py::object> out_lazy_exprs;
308 
309   // Bitvector of kept arguments from Jaxpr DCE pass. Used to drop some `args`
310   // in CompiledFunction::Call before calling into compiled computation.
311   absl::optional<std::vector<bool>> kept_var_bitvec;
312   xla::PjRtDevice* sticky_device;
313 
314   // Fallback to Python happens:
315   // - for trivial computations
316   // - when running a jax(pmap)
317   // - after a compilation error, for threads that did not compile it the first
318   //   time
319   bool fall_back_to_python = false;
320 
321   // Python objects (notably in the cache key) that must remain alive as long
322   // as the cache entry does. Currently this is the `key` values in the kwarg
323   // entries in the cache key.
324   std::vector<py::object> keepalive;
325 };
326 
// A CompiledFunctionCache represents a cache of compiled functions that can be
// shared between one or more CompiledFunction objects. It serves two goals:
// - reduce the number of lru caches (hash map) across multiple JITs.
// - make the cache global to increase cache hits (e.g. calling jit(f)(3) twice)
//   keeping entries alive as long as the underlying function f is alive.
// Assume the cache is protected by the GIL.
class CompiledFunctionCache {
 public:
  static constexpr int kDefaultCapacity = 4096;
  explicit CompiledFunctionCache(int capacity);

  // Cache entries are shared_ptr<>s because it's possible the cache entry
  // might be evicted before we finish tracing/compiling.
  typedef xla::LRUCache<CallSignature, std::shared_ptr<CacheEntry>> Cache;

  // The lifetime of the returned cache lasts at least as long as the lifetime
  // of `function`, so if the caller holds a strong reference to `function`, the
  // `Cache` remains valid.
  // We include as part of the cache key `donate_argnums` (and any other fields
  // that aren't subsumed by the CallSignature we compute for each call).
  Cache* Lookup(py::handle function, absl::Span<const int> donate_argnums);

  // Size/Capacity/Clear act on the LRU list shared by *all* per-function
  // caches, so Clear() evicts every compiled function at once.
  int Size() const { return lru_list_.Size(); }
  int Capacity() const { return lru_list_.Capacity(); }
  void Clear() { lru_list_.Clear(); }

 private:
  // Hash-map key: the (function, donate_argnums) pair identifying one
  // per-function cache.
  struct Key {
    py::handle function;  // Does not hold a reference.

    // Other fields that are part of the arguments to `jit`, but are not
    // otherwise part of CallSignature.
    std::vector<int> donate_argnums;

    bool operator==(const Key& other) const {
      return std::tie(function, donate_argnums) ==
             std::tie(other.function, other.donate_argnums);
    }
  };
  template <typename H>
  friend H AbslHashValue(H h, const Key& key) {
    // The function is hashed by identity (pointer), matching py::handle
    // equality above.
    h = H::combine(std::move(h), key.function.ptr());
    h = H::combine_contiguous(std::move(h), key.donate_argnums.data(),
                              key.donate_argnums.size());
    return h;
  }

  struct Value {
    explicit Value(Cache::LRUList* lru_list) : cache(lru_list) {}
    Cache cache;

    // A weak reference to the key function. We use the weak reference to
    // register a callback that is triggered when the key function is destroyed.
    // We use a weak pointer because we want to allow caching across multiple
    // calls to `jax.jit(f)` if `f` remains alive, but we do not want the cache
    // to keep `f` alive if all other references are dropped.
    py::weakref weakref;
  };

  // Single LRU list shared by every per-function Cache, so the total number of
  // live executables is bounded by `capacity` across all jitted functions.
  Cache::LRUList lru_list_;
  absl::flat_hash_map<Key, std::unique_ptr<Value>> functions_;
};
389 
// Creates a cache whose shared LRU list holds at most `capacity` entries
// across all functions using this cache.
CompiledFunctionCache::CompiledFunctionCache(int capacity)
    : lru_list_(capacity) {}
392 
// Returns the per-(function, donate_argnums) cache, creating it on first use.
// Like the rest of this class, assumed to run under the GIL.
CompiledFunctionCache::Cache* CompiledFunctionCache::Lookup(
    py::handle function, absl::Span<const int> donate_argnums) {
  Key key;
  key.function = function;
  key.donate_argnums =
      std::vector<int>(donate_argnums.begin(), donate_argnums.end());
  // `emplace` copies `key` into the map, so moving `key` into the weakref
  // callback below is safe.
  auto insert = functions_.emplace(key, nullptr);
  std::unique_ptr<Value>& entry = insert.first->second;
  if (insert.second) {
    entry = std::make_unique<Value>(&lru_list_);
    // Register a callback that drops this entry when the Python function is
    // garbage-collected; using a weakref avoids keeping `function` alive.
    entry->weakref = py::weakref(
        function,
        py::cpp_function([this, key{std::move(key)}](py::handle weakref) {
          functions_.erase(key);
        }));
  }
  return &entry->cache;
}
411 
// A `CompiledFunction` is associated to a `jax.jit(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class CompiledFunction {
 public:
  CompiledFunction(py::function fun, py::function cache_miss,
                   py::function get_device, std::vector<int> static_argnums,
                   std::vector<py::str> static_argnames,
                   std::vector<int> donate_argnums,
                   std::shared_ptr<CompiledFunctionCache> cache);
  ~CompiledFunction();

  // pybind11::object typed subclass for CompiledFunction objects.
  class pyobject : public pybind11::object {
   public:
    PYBIND11_OBJECT(pyobject,  // NOLINT
                    pybind11::object, CompiledFunction::IsCompiledFunction);
    pyobject() = default;
    CompiledFunction* func() const {
      return CompiledFunction::AsCompiledFunctionUnchecked(*this);
    }
  };
  // Alias as ::object; outside the scope above we won't confuse pybind11's
  // macros.
  using object = pyobject;

  // Returns true if `h` is a CompiledFunction.
  static bool IsCompiledFunction(pybind11::handle handle);
  // Converts `handle` to a CompiledFunction*. Does not do any checking.
  static CompiledFunction* AsCompiledFunctionUnchecked(pybind11::handle handle);

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `DeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
  xla::StatusOr<py::object> Call(py::handle args,
                                 absl::optional<py::kwargs> kwargs);

  // This allows `inspect.signature(cpp_jitted_f)` from Python.
  py::object PythonSignature() {
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  // Number of compiled signatures cached for this function.
  int cache_size() const { return executables_->Size(); }
  void ClearCache() { executables_->Clear(); }

  // Simple accessors for the `jax.jit` parameters this wrapper was built with.
  const py::function& fun() const { return fun_; }
  const py::function& cache_miss() const { return cache_miss_; }
  const py::function& get_device() const { return get_device_; }
  const std::vector<int>& static_argnums() const { return static_argnums_; }
  const std::vector<py::str>& static_argnames() const {
    return static_argnames_;
  }
  const std::vector<int>& donate_argnums() const { return donate_argnums_; }
  const std::shared_ptr<CompiledFunctionCache>& cache() const { return cache_; }

  // Helper function used by the tp_clear GC method.
  void ClearPythonReferences() {
    py::function fun, cache_miss, get_device;
    // Swap values for nulls before they are destroyed. See the Python
    // Py_CLEAR() documentation for a discussion of this topic.
    std::swap(fun_, fun);
    std::swap(cache_miss_, cache_miss);
    std::swap(get_device_, get_device);
  }

  py::handle AsPyHandle();
  const std::string& function_name() const { return function_name_; }

 private:
  // Attempts to populate default_device_. May release the GIL; is
  // reentrant-safe.
  void TryToPopulateDefaultDevice();

  void PopulateCacheEntry(CacheEntry* entry, const CallSignature& signature,
                          const py::tuple& out_and_fastpath_data);
  bool always_fallback_to_python_ = false;

  py::function fun_;  // The Python function to jit.
  std::string function_name_;

  // See JAX _cpp_jit in api.py for documentation.
  py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // Keyword arguments, interned.
  std::vector<py::str> static_argnames_;
  std::vector<int> donate_argnums_;

  // A function taking no arguments and returning the default device and whether
  // jax.jit has been committed to it.
  py::function get_device_;

  // Keeps the shared LRU cache alive as long as the CompiledFunction is alive.
  std::shared_ptr<CompiledFunctionCache> cache_;

  // The part of cache_ specific to this CompiledFunction.
  CompiledFunctionCache::Cache* executables_;

  // The logic is the following:
  // - if `device` or `backend` are not specified to `jax.jit`, we will use
  //   the input sticky buffer device, or `default_device_` if there is no
  //   such sticky buffer.
  // - When one of `device` or `backend` is specified, this will determine
  //   the `default_device_` which will be used as the targeted device. In
  //   which case, we will always copy input buffers to this device.
  // These fields are protected by the GIL.
  std::shared_ptr<xla::PyClient> default_pyclient_ = nullptr;
  xla::PjRtDevice* default_device_ = nullptr;
  // NOTE(review): left uninitialized here; presumably assigned by
  // TryToPopulateDefaultDevice() before being read -- TODO confirm.
  bool is_committed_;
};
528 
CompiledFunction(py::function fun,py::function cache_miss,py::function get_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)529 CompiledFunction::CompiledFunction(py::function fun, py::function cache_miss,
530                                    py::function get_device,
531                                    std::vector<int> static_argnums,
532                                    std::vector<py::str> static_argnames,
533                                    std::vector<int> donate_argnums,
534                                    std::shared_ptr<CompiledFunctionCache> cache)
535     : fun_(std::move(fun)),
536       cache_miss_(std::move(cache_miss)),
537       static_argnums_(std::move(static_argnums)),
538       static_argnames_(std::move(static_argnames)),
539       donate_argnums_(donate_argnums),
540       get_device_(std::move(get_device)),
541       cache_(std::move(cache)) {
542   std::sort(static_argnums_.begin(), static_argnums_.end());
543   for (py::str& s : static_argnames) {
544     PyUnicode_InternInPlace(&s.ptr());
545   }
546   executables_ = cache_->Lookup(fun_, donate_argnums);
547   function_name_ = py::str(py::getattr(fun_, "__name__", fun));
548 }
549 
// Defaulted out of line; the py::function/py::object members release their
// Python references here.
CompiledFunction::~CompiledFunction() = default;
551 
// Compute signature for arguments.
//
// Returns `Status::OK()` on success. Returning an error should lead to
// calling the Python fallback.
//
// Fills `arguments.signature.device` (committed device, or the device of
// sticky input buffers, or `default_device`) and the per-argument
// `dynamic_arg_signatures`.
// NOTE(review): `pyclient` is not referenced in this function body.
xla::Status ComputeSignature(bool jax_enable_x64, xla::PyClient& pyclient,
                             xla::PjRtDevice* default_device, bool is_committed,
                             ParsedArgumentsAsBuffers& arguments) {
  tensorflow::profiler::TraceMe traceme("ComputeSignature");

  int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
  struct PythonTypes {
    py::object device_array;
  };
  // Resolved once per process; leaked intentionally to avoid destructor-order
  // issues with the Python interpreter.
  static const auto& types = *[]() -> PythonTypes* {
    py::module xla_module(py::module::import("jax.interpreters.xla"));
    py::object device_array(xla_module.attr("_DeviceArray"));
    return new PythonTypes{device_array};
  }();
  // When the jitted function is not committed, we first check whether any
  // sticky `DeviceArray` is present and on which device they live. See also:
  // https://github.com/google/jax/pull/1884
  // https://github.com/google/jax/pull/1916 for the rationale why the
  // computation follows the data locality.
  // It's also similar to PyTorch's behavior.
  xla::PjRtDevice* data_device = nullptr;
  if (is_committed) {
    data_device = default_device;
  } else {
    for (int i = 0; i < num_flat_dynamic_args; ++i) {
      py::handle arg = arguments.flat_dynamic_args[i];
      // We specifically only deal with DeviceArray (not ShardedDeviceArray).
      // (Can happen in jit(pmap), e.g. "test_jit_nested_donate_ignored").
      xla::PjRtDevice* device = nullptr;
      if (arg.get_type().ptr() == xla::PyBuffer::type()) {
        xla::PyBuffer* buffer = xla::PyBuffer::AsPyBufferUnchecked(arg);
        if (!buffer->sticky_device()) {
          continue;
        }
        device = buffer->sticky_device();
      } else if (arg.get_type().ptr() == types.device_array.ptr()) {
        if (arg.attr("_device").is_none()) {  // Skip non-sticky devices.
          continue;
        }
        try {
          // This can fail, e.g. for cloud TPU 2VM buffers.
          TF_ASSIGN_OR_RETURN(
              xla::PyBuffer * buffer,
              xla::PyBuffer::AsPyBuffer(arg.attr("device_buffer")));
          device = buffer->buffer()->device();
        } catch (const py::cast_error& e) {
          return xla::InvalidArgument(
              "%s",
              absl::StrCat("[jaxjit] Unsupported subclass of `DeviceArray`: "
                           "`device_buffer` field is of type ",
                           py::cast<std::string>(
                               arg.attr("device_buffer").get_type().str()),
                           " while a `PyBuffer` was expected."

                           ));
        }
      }
      if (device) {
        // All sticky inputs must agree on a single device; mixing devices is
        // a user error surfaced as a Python ValueError.
        if (data_device && (device != data_device)) {
          throw std::invalid_argument(absl::StrCat(
              "primitive arguments must be colocated on the same device ("
              "C++ jax.jit). Arguments are on devices: ",
              device->DebugString(), " and ", data_device->DebugString()));
        } else {
          data_device = device;
        }
      }
    }
  }
  if (!data_device) {
    // No `DeviceArray` were found default to `default_device`.
    data_device = default_device;
  }
  CHECK(data_device);
  arguments.signature.device = data_device;

  arguments.signature.dynamic_arg_signatures.reserve(num_flat_dynamic_args);
  for (int i = 0; i < num_flat_dynamic_args; ++i) {
    py::handle arg = arguments.flat_dynamic_args[i];
    TF_ASSIGN_OR_RETURN(auto sig,
                        xla::PyArgSignatureOfValue(arg, jax_enable_x64));
    arguments.signature.dynamic_arg_signatures.push_back(std::move(sig));
  }
  return xla::Status::OK();
}
641 
642 // Copy buffers to device, skipping pruned arguments.
643 // Returns `Status::OK()` on success. Returning an error should lead to
644 // calling the Python fallback.
CopyBuffersToDevice(bool jax_enable_x64,const absl::optional<std::vector<bool>> & kept_args,ParsedArgumentsAsBuffers & arguments)645 xla::Status CopyBuffersToDevice(
646     bool jax_enable_x64, const absl::optional<std::vector<bool>>& kept_args,
647     ParsedArgumentsAsBuffers& arguments) {
648   std::vector<xla::PjRtBuffer*>& arg_buffers = arguments.arg_buffers;
649   xla::PjRtDevice* data_device = arguments.signature.device;
650 
651   int num_flat_dynamic_args = arguments.flat_dynamic_args.size();
652   xla::DevicePutOptions options;
653   options.squash_64bit_types = !jax_enable_x64;
654   // TODO(phawkins): consider allowing forces here.
655   options.force_lazy_arrays = false;
656   options.allow_zero_copy = true;
657   arg_buffers.reserve(num_flat_dynamic_args);
658   bool input_pruning_enabled = kept_args.has_value();
659   for (int i = 0; i < num_flat_dynamic_args; ++i) {
660     if (input_pruning_enabled && !kept_args.value()[i]) {
661       continue;
662     }
663 
664     py::handle arg = arguments.flat_dynamic_args[i];
665     TF_ASSIGN_OR_RETURN(xla::DevicePutResult on_device,
666                         DevicePut(arg, data_device, options));
667 
668     xla::PjRtBuffer* buffer = on_device.buffer;
669     arg_buffers.push_back(buffer);
670     if (on_device.owned_buffer) {
671       arguments.keep_alive.push_back(std::move(on_device.owned_buffer));
672     } else if (on_device.owning_pybuffer) {
673       arguments.keep_alive_objects.push_back(
674           std::move(on_device.owning_pybuffer));
675     }
676   }
677   return xla::Status::OK();
678 }
679 
// Fills in `cache_entry` (the C++ fast-path state for `signature`) from
// `out_and_fastpath_data`, the (outputs, fastpath_data) pair produced by the
// Python cache-miss function. If fastpath_data is None, the entry is marked
// so calls with this signature always take the Python fallback.
void CompiledFunction::PopulateCacheEntry(
    CacheEntry* cache_entry, const CallSignature& signature,
    const py::tuple& out_and_fastpath_data) {
  CHECK_EQ(out_and_fastpath_data.size(), 2);
  if (out_and_fastpath_data[1].is_none()) {
    // Python-side tracing decided the fast path cannot be used.
    cache_entry->fall_back_to_python = true;
    return;
  }

  py::tuple executable_handlers_out_tree =
      py::cast<py::tuple>(out_and_fastpath_data[1]);
  // TODO(zhangqiaorjc): Lookup NamedTuple by name after min jax version bump.
  size_t arity = executable_handlers_out_tree.size();
  // Accept the legacy 5-tuple, or any NamedTuple (detected via `_fields`).
  if (arity != 5 && !py::hasattr(executable_handlers_out_tree, "_fields")) {
    throw std::runtime_error(absl::StrCat(
        "The versions of jaxlib and Jax are incompatible (jaxlib is too recent "
        "compared to Jax. Upgrade Jax is advised. The C++ code expects "
        "5 or 6 arguments but ",
        arity, " were provided: ",
        py::cast<std::string>(
            py::str(py::repr(executable_handlers_out_tree)))));
  }
  // (xla_executable, out_pytree_def, sticky_device, avals, lazy_exprs)
  auto executable = py::cast<std::shared_ptr<xla::PyExecutable>>(
      executable_handlers_out_tree[0]);
  cache_entry->executable = std::move(executable);
  int num_devices =
      cache_entry->executable->pjrt_executable().addressable_devices().size();
  // The presence of jit(pmap) is detected from Python.
  CHECK_EQ(num_devices, 1);

  auto out_tree = py::cast<xla::PyTreeDef>(executable_handlers_out_tree[1]);
  cache_entry->out_pytree_def = std::move(out_tree);

  cache_entry->sticky_device =
      py::cast<xla::PjRtDevice*>(executable_handlers_out_tree[2]);
  auto avals = py::cast<py::list>(executable_handlers_out_tree[3]);
  auto lazy_exprs = py::cast<py::list>(executable_handlers_out_tree[4]);
  CHECK_EQ(avals.size(), lazy_exprs.size());

  // Cache per-output metadata (aval, weak_type, lazy expression) so outputs
  // can be rebuilt on the fast path without calling back into Python tracing.
  cache_entry->out_avals.reserve(avals.size());
  cache_entry->out_weak_types.reserve(avals.size());
  cache_entry->out_lazy_exprs.reserve(avals.size());
  for (int i = 0; i < avals.size(); ++i) {
    py::object shaped_array = py::reinterpret_borrow<py::object>(avals[i]);
    py::object lazy_expr = py::reinterpret_borrow<py::object>(lazy_exprs[i]);

    cache_entry->out_avals.push_back(shaped_array);
    cache_entry->out_weak_types.push_back(
        py::cast<bool>(shaped_array.attr("weak_type")));
    cache_entry->out_lazy_exprs.push_back(lazy_expr);
  }
  // Optional attribute (absent on older jax versions): a bitvector of which
  // inputs the executable kept; consumed by CopyBuffersToDevice for pruning.
  auto kept_var_bitvec_attr =
      py::getattr(executable_handlers_out_tree, "kept_var_bitvec", py::none());
  if (!kept_var_bitvec_attr.is_none()) {
    auto kept_var_bitvec = py::cast<py::list>(kept_var_bitvec_attr);
    cache_entry->kept_var_bitvec =
        absl::make_optional<std::vector<bool>>(kept_var_bitvec.size(), false);
    for (int i = 0; i < kept_var_bitvec.size(); ++i) {
      cache_entry->kept_var_bitvec.value()[i] =
          py::cast<bool>(kept_var_bitvec[i]);
    }
  }
}
744 
// Attempts to fill in default_pyclient_/default_device_/is_committed_ by
// calling the Python get_device_ callable. On failure, sets
// always_fallback_to_python_ so all subsequent calls use the Python path.
void CompiledFunction::TryToPopulateDefaultDevice() {
  // The following line calls Python and may release the GIL.
  py::object device_and_is_committed;
  try {
    device_and_is_committed = get_device_();
  } catch (py::error_already_set& e) {
    // Backend or device initialization failed. Handle this in Python.
    always_fallback_to_python_ = true;
    return;
  }
  // If the GIL was released by the call to get_device_, another thread may
  // have filled in default_device_.
  if (!default_device_) {
    try {
      auto default_pydevice = py::cast<xla::ClientAndPtr<xla::PjRtDevice>>(
          device_and_is_committed.attr("default_device"));
      is_committed_ =
          py::cast<bool>(device_and_is_committed.attr("committed_to_device"));
      default_pyclient_ = default_pydevice.client;
      default_device_ = default_pydevice.contents;
    } catch (const py::cast_error& e) {
      // Pathways, Cloud TPU 2VM, and UPTC runtime.
      always_fallback_to_python_ = true;
    }
  }
}
771 
Call(py::handle args,absl::optional<py::kwargs> kwargs)772 xla::StatusOr<py::object> CompiledFunction::Call(
773     py::handle args, absl::optional<py::kwargs> kwargs) {
774   // Make sure we trigger a garbage collection on JIT function calls. Otherwise
775   // code like
776   // f = jit(...)
777   // while True:
778   //   f(x)
779   // may never free temporary buffers for copies of arguments.
780   xla::GlobalPyRefManager()->MaybeCollectGarbage();
781 
782   auto& tls = thread_local_state;
783   if (tls.disable_jit.value_or(global_state.disable_jit)) {
784     return fun_(*py::reinterpret_borrow<py::args>(args),
785                 **kwargs.value_or(py::kwargs()));
786   }
787   if (always_fallback_to_python_) {
788     return py::object(
789         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
790                                         **kwargs.value_or(py::kwargs())))[0]);
791   }
792 
793   // On the first call to `Call`, compute a default device. We need to wait
794   // until after platform initialization is complete before doing so, but @jit
795   // may be used as a decorator.
796   if (!default_device_) {
797     TryToPopulateDefaultDevice();
798     if (!default_device_) {
799       return py::object(py::cast<py::tuple>(
800           cache_miss_(*py::reinterpret_borrow<py::args>(args),
801                       **kwargs.value_or(py::kwargs())))[0]);
802     }
803   }
804 
805   ParsedArgumentsAsBuffers arguments;
806   xla::Status status = ParseArguments(args, kwargs, static_argnums_,
807                                       static_argnames_, arguments);
808   if (!status.ok()) {
809     VLOG(2) << "ParseArguments failed: " << status;
810     return py::object(
811         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
812                                         **kwargs.value_or(py::kwargs())))[0]);
813   }
814 
815   bool jax_enable_x64 = tls.enable_x64.value_or(global_state.enable_x64);
816   arguments.signature.jax_enable_x64 = jax_enable_x64;
817   // The C++ jit do not support Tracers arguments inputs yet. The Python-based
818   // jit function will be called if any of the dynamic arguments is unsupported.
819   status = ComputeSignature(jax_enable_x64, *default_pyclient_, default_device_,
820                             is_committed_, arguments);
821   if (!status.ok()) {
822     VLOG(2) << "ComputeSignature failed: " << status;
823     return py::object(
824         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
825                                         **kwargs.value_or(py::kwargs())))[0]);
826   }
827   arguments.signature.global_extra_jit_context = global_state.extra_jit_context;
828   arguments.signature.thread_local_extra_jit_context = tls.extra_jit_context;
829 
830   bool inserted = false;
831   std::shared_ptr<CacheEntry> cache_entry = executables_->GetOrCreateIfAbsent(
832       arguments.signature, [&inserted](const CallSignature& key) {
833         inserted = true;
834         return std::make_shared<CacheEntry>();
835       });
836 
837   if (!cache_entry->compilation_complete.HasBeenNotified()) {
838     // In case of several threads attempting to compile the executable, only
839     // the one that inserted the item will perform the compilation.
840     if (inserted) {
841       py::object out_and_fastpath_data;
842       py::tuple out_tuple;
843       VLOG(2) << "Cache miss for " << arguments.signature.DebugString();
844       try {
845         // Calls Python and may release the GIL. May also throw if
846         // compilation/tracing fails.
847         out_and_fastpath_data = out_and_fastpath_data =
848             cache_miss_(*py::reinterpret_borrow<py::args>(args),
849                         **kwargs.value_or(py::kwargs()));
850         out_tuple = py::cast<py::tuple>(out_and_fastpath_data);
851         PopulateCacheEntry(cache_entry.get(), arguments.signature, out_tuple);
852       } catch (const std::exception& e) {
853         cache_entry->fall_back_to_python = true;
854         cache_entry->compilation_complete.Notify();
855         throw;
856       }
857       cache_entry->compilation_complete.Notify();
858 
859       // We have already computed the result in the miss path so we can return
860       // it. We are even *required* to do so if there are donated arguments,
861       // because any donated buffers will now be invalid.
862       return py::object(out_tuple[0]);
863     } else {
864       // Release the GIL while we wait, making sure the compile thread can
865       // lock it.
866       py::gil_scoped_release release;
867       cache_entry->compilation_complete.WaitForNotification();
868     }
869   }
870   // It's hard to reraise the exact same kind of errors when a compilation error
871   // occured. If the first compilation failed, other threads will also execute
872   // the Python path.
873   if (cache_entry->fall_back_to_python) {
874     return py::object(
875         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
876                                         **kwargs.value_or(py::kwargs())))[0]);
877   }
878 
879   status = CopyBuffersToDevice(jax_enable_x64, cache_entry->kept_var_bitvec,
880                                arguments);
881   if (!status.ok()) {
882     VLOG(2) << "CopyBuffersToDevice failed: " << status;
883     return py::object(
884         py::cast<py::tuple>(cache_miss_(*py::reinterpret_borrow<py::args>(args),
885                                         **kwargs.value_or(py::kwargs())))[0]);
886   }
887 
888   // Executes the computation.
889   std::vector<std::vector<std::unique_ptr<xla::PjRtBuffer>>> output_buffers;
890   {
891     py::gil_scoped_release gil_release;
892     TF_ASSIGN_OR_RETURN(
893         output_buffers,
894         cache_entry->executable->mutable_pjrt_executable()->Execute(
895             {arguments.arg_buffers}, cache_entry->executable->options()));
896   }
897   auto traceback = xla::Traceback::Get();
898 
899   int num_outputs = output_buffers[0].size();
900   absl::InlinedVector<py::object, 1> flat_device_arrays;
901   flat_device_arrays.reserve(num_outputs);
902   for (int i = 0; i < output_buffers[0].size(); ++i) {
903     bool last = (i == (num_outputs - 1));
904     xla::PyBuffer::object buffer = xla::PyBuffer::Make(
905         cache_entry->executable->client(), std::move(output_buffers[0][i]),
906         last ? std::move(traceback) : traceback);
907     if (cache_entry->out_lazy_exprs[i].is_none()) {  // No LazyExpr.
908       buffer.buf()->SetAval(cache_entry->out_avals[i]);
909       buffer.buf()->set_weak_type(cache_entry->out_weak_types[i]);
910       TF_RETURN_IF_ERROR(
911           buffer.buf()->set_sticky_device(cache_entry->sticky_device));
912       flat_device_arrays.push_back(std::move(buffer));
913     } else {
914       static const auto* xla_module =
915           new py::module(py::module::import("jax.interpreters.xla"));
916       static const auto* device_array =
917           new py::handle(xla_module->attr("_DeviceArray"));
918       flat_device_arrays.push_back(
919           (*device_array)(cache_entry->out_avals[i],
920                           py::cast(WrapWithClient(default_pyclient_,
921                                                   cache_entry->sticky_device)),
922                           cache_entry->out_lazy_exprs[i], std::move(buffer)));
923     }
924   }
925   py::object out = cache_entry->out_pytree_def.Unflatten(flat_device_arrays);
926 
927   // If there is a post-hook function, call it with the inputs and the outputs.
928   absl::optional<py::object> post_hook =
929       tls.post_hook.has_value() ? tls.post_hook : global_state.post_hook;
930   if (post_hook) {
931     (*post_hook)(AsPyHandle(), args,
932                  py::cast<py::dict>(kwargs.value_or(py::kwargs())), out);
933   }
934   return std::move(out);
935 }
936 
// Object layout backing a CompiledFunction Python instance. PyObject_HEAD
// must come first. `fun` is constructed with placement new by
// InitializeCompiledFunction and explicitly destroyed in tp_dealloc.
struct JaxCompiledFunctionObject {
  PyObject_HEAD;
  PyObject* dict;      // Dictionary for __dict__
  PyObject* weakrefs;  // Weak references; for use by the Python interpreter.
  CompiledFunction fun;
};

// The (heap-allocated) Python type object for CompiledFunction; filled in by
// BuildJaxjitSubmodule.
PyObject* JaxCompiledFunction_Type = nullptr;
945 
IsCompiledFunction(py::handle handle)946 bool CompiledFunction::IsCompiledFunction(py::handle handle) {
947   return handle.get_type() == JaxCompiledFunction_Type;
948 }
949 
AsCompiledFunctionUnchecked(py::handle handle)950 CompiledFunction* CompiledFunction::AsCompiledFunctionUnchecked(
951     py::handle handle) {
952   return &(reinterpret_cast<JaxCompiledFunctionObject*>(handle.ptr())->fun);
953 }
954 
AsCompiledFunction(py::handle handle)955 xla::StatusOr<CompiledFunction*> AsCompiledFunction(py::handle handle) {
956   if (!CompiledFunction::IsCompiledFunction(handle)) {
957     return xla::InvalidArgument("Expected a CompiledFunction");
958   }
959   return CompiledFunction::AsCompiledFunctionUnchecked(handle);
960 }
961 
// Recovers the owning PyObject* for this CompiledFunction by subtracting the
// offset of the `fun` member from `this`. Only valid for CompiledFunctions
// that are embedded inside a JaxCompiledFunctionObject.
py::handle CompiledFunction::AsPyHandle() {
  return reinterpret_cast<PyObject*>(reinterpret_cast<char*>(this) -
                                     offsetof(JaxCompiledFunctionObject, fun));
}
966 
967 extern "C" {
968 
JaxCompiledFunction_tp_new(PyTypeObject * subtype,PyObject * args,PyObject * kwds)969 PyObject* JaxCompiledFunction_tp_new(PyTypeObject* subtype, PyObject* args,
970                                      PyObject* kwds) {
971   JaxCompiledFunctionObject* self =
972       reinterpret_cast<JaxCompiledFunctionObject*>(
973           subtype->tp_alloc(subtype, 0));
974   if (!self) return nullptr;
975   self->dict = nullptr;
976   self->weakrefs = nullptr;
977   return reinterpret_cast<PyObject*>(self);
978 }
979 
// tp_dealloc slot: clears weak references and __dict__, runs the C++
// destructor of the embedded CompiledFunction, frees the memory, and finally
// drops the reference the instance holds on its heap-allocated type.
void JaxCompiledFunction_tp_dealloc(PyObject* self) {
  PyTypeObject* tp = Py_TYPE(self);
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (o->weakrefs) {
    PyObject_ClearWeakRefs(self);
  }
  Py_CLEAR(o->dict);
  // `fun` was placement-new'ed (see InitializeCompiledFunction), so its
  // destructor must be invoked explicitly.
  o->fun.~CompiledFunction();
  tp->tp_free(self);
  Py_DECREF(tp);
}
992 
// tp_traverse slot for the cycle GC: visits every Python object this object
// can keep alive, including the callables held by the embedded
// CompiledFunction. Py_VISIT returns early on a nonzero visitor result.
int JaxCompiledFunction_tp_traverse(PyObject* self, visitproc visit,
                                    void* arg) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_VISIT(o->dict);
  Py_VISIT(o->fun.fun().ptr());
  Py_VISIT(o->fun.cache_miss().ptr());
  Py_VISIT(o->fun.get_device().ptr());
  return 0;
}
1003 
// tp_clear slot for the cycle GC: drops the references that can participate
// in reference cycles (__dict__ and the Python callables held by `fun`).
int JaxCompiledFunction_tp_clear(PyObject* self) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  Py_CLEAR(o->dict);
  o->fun.ClearPythonReferences();
  return 0;
}
1011 
1012 // Implements the Python descriptor protocol so JIT-compiled functions can be
1013 // used as bound methods. See:
1014 // https://docs.python.org/3/howto/descriptor.html#functions-and-methods
JaxCompiledFunction_tp_descr_get(PyObject * self,PyObject * obj,PyObject * type)1015 PyObject* JaxCompiledFunction_tp_descr_get(PyObject* self, PyObject* obj,
1016                                            PyObject* type) {
1017   if (obj == nullptr || obj == Py_None) {
1018     Py_INCREF(self);
1019     return self;
1020   }
1021   return PyMethod_New(self, obj);
1022 }
1023 
// Support d = instance.__dict__.
// Getter for __dict__: creates the dictionary lazily on first access and
// returns a new reference (may return nullptr with an error set if
// PyDict_New fails).
PyObject* JaxCompiledFunction_get_dict(PyObject* self, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (!o->dict) {
    o->dict = PyDict_New();
  }
  Py_XINCREF(o->dict);
  return o->dict;
}
1034 
// Setter for __dict__: replaces the instance dictionary. Returns 0 on
// success, or -1 with TypeError set if `new_dict` is not a dict.
int JaxCompiledFunction_set_dict(PyObject* self, PyObject* new_dict, void*) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  if (!PyDict_Check(new_dict)) {
    PyErr_Format(PyExc_TypeError,
                 "__dict__ must be set to a dictionary, not a '%s'",
                 Py_TYPE(new_dict)->tp_name);
    return -1;
  }
  // Acquire the new reference before dropping the old dictionary.
  Py_INCREF(new_dict);
  Py_CLEAR(o->dict);
  o->dict = new_dict;
  return 0;
}
1049 
static PyGetSetDef JaxCompiledFunction_tp_getset[] = {
    // Having a __dict__ seems necessary to allow functools.wraps to override
    // __doc__.
    {const_cast<char*>("__dict__"), JaxCompiledFunction_get_dict,
     JaxCompiledFunction_set_dict, nullptr, nullptr},
    // Sentinel entry terminating the table.
    {nullptr, nullptr, nullptr, nullptr, nullptr}};
1056 
// tp_call slot: forwards the Python call to CompiledFunction::Call,
// translating error Statuses and C++ exceptions into Python exceptions
// (nullptr return with the Python error indicator set).
PyObject* JaxCompiledFunction_tp_call(PyObject* self, PyObject* args,
                                      PyObject* kwargs) {
  JaxCompiledFunctionObject* o =
      reinterpret_cast<JaxCompiledFunctionObject*>(self);
  tensorflow::profiler::TraceMe traceme([&] {
    return absl::StrCat("JaxCompiledFunction(", o->fun.function_name(), ")");
  });
  absl::optional<py::kwargs> py_kwargs;
  if (kwargs) {
    py_kwargs = py::reinterpret_borrow<py::kwargs>(kwargs);
  }
  try {
    xla::StatusOr<py::object> out = o->fun.Call(args, std::move(py_kwargs));
    if (!out.ok()) {
      PyErr_SetString(PyExc_ValueError, out.status().ToString().c_str());
      return nullptr;
    }
    // Transfer ownership of the result to the interpreter.
    return out.ValueOrDie().release().ptr();
  } catch (py::error_already_set& e) {
    // A Python exception is already pending; restore it for the interpreter.
    e.restore();
    return nullptr;
  } catch (py::cast_error& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  } catch (std::invalid_argument& e) {
    PyErr_SetString(PyExc_ValueError, e.what());
    return nullptr;
  }
}
1086 
// Constructs the C++ CompiledFunction member of `cfun` in place (tp_new
// leaves that memory uninitialized). Must be called exactly once per object;
// tp_dealloc runs the matching explicit destructor.
void InitializeCompiledFunction(JaxCompiledFunctionObject* cfun,
                                py::function fun, py::function cache_miss,
                                py::function get_device,
                                std::vector<int> static_argnums,
                                std::vector<py::str> static_argnames,
                                std::vector<int> donate_argnums,
                                std::shared_ptr<CompiledFunctionCache> cache) {
  new (&cfun->fun) CompiledFunction(
      std::move(fun), std::move(cache_miss), std::move(get_device),
      std::move(static_argnums), std::move(static_argnames),
      std::move(donate_argnums), std::move(cache));
}
1099 
1100 }  // extern "C"
1101 
MakeCompiledFunction(py::function fun,py::function cache_miss,py::function get_device,std::vector<int> static_argnums,std::vector<py::str> static_argnames,std::vector<int> donate_argnums,std::shared_ptr<CompiledFunctionCache> cache)1102 py::object MakeCompiledFunction(py::function fun, py::function cache_miss,
1103                                 py::function get_device,
1104                                 std::vector<int> static_argnums,
1105                                 std::vector<py::str> static_argnames,
1106                                 std::vector<int> donate_argnums,
1107                                 std::shared_ptr<CompiledFunctionCache> cache) {
1108   py::object obj = py::reinterpret_steal<py::object>(JaxCompiledFunction_tp_new(
1109       reinterpret_cast<PyTypeObject*>(JaxCompiledFunction_Type), nullptr,
1110       nullptr));
1111   JaxCompiledFunctionObject* buf =
1112       reinterpret_cast<JaxCompiledFunctionObject*>(obj.ptr());
1113   if (!cache) {
1114     cache = std::make_shared<CompiledFunctionCache>(
1115         CompiledFunctionCache::kDefaultCapacity);
1116   }
1117   InitializeCompiledFunction(buf, std::move(fun), std::move(cache_miss),
1118                              std::move(get_device), std::move(static_argnums),
1119                              std::move(static_argnames),
1120                              std::move(donate_argnums), std::move(cache));
1121   return obj;
1122 }
1123 
// Helpers for building Python properties
// Wraps `get` in Python's built-in `property` type as a read-only property:
// no setter, no deleter, empty docstring.
template <typename Func>
py::object property_readonly(Func&& get) {
  py::handle property(reinterpret_cast<PyObject*>(&PyProperty_Type));
  return property(py::cpp_function(std::forward<Func>(get)), py::none(),
                  py::none(), "");
}
1131 
// Version numbers for the pickled representations of
// CompiledFunction/CompiledFunctionCache. Increment these if changing them.
// Checked in the respective __setstate__ implementations, which raise
// std::invalid_argument on a mismatch.
const int kCompiledFunctionCachePickleVersion = 1;
const int kCompiledFunctionPickleVersion = 1;
1136 
1137 }  // namespace
1138 
BuildJaxjitSubmodule(py::module & m)1139 void BuildJaxjitSubmodule(py::module& m) {
1140   py::module jitlib = m.def_submodule("jax_jit", "Jax C++ jit library");
1141 
1142   // We define CompiledFunctionCache in the xla_extension module to work
1143   // around https://github.com/cloudpipe/cloudpickle/issues/354, which
1144   // means classes defined in submodules aren't pickleable.
1145   // TODO(phawkins): remove this workaround once Google is using a newer
1146   // cloudpickle. Opensource users can just pip install a newer version already.
1147   py::class_<CompiledFunctionCache, std::shared_ptr<CompiledFunctionCache>>
1148       cache(m, "CompiledFunctionCache");
1149   cache.def(py::init<int>(),
1150             py::arg("capacity") = CompiledFunctionCache::kDefaultCapacity);
1151   cache.def("size", &CompiledFunctionCache::Size);
1152   cache.def("capacity", &CompiledFunctionCache::Capacity);
1153   cache.def("clear", &CompiledFunctionCache::Clear);
1154   cache.def(py::pickle(
1155       // __getstate__
1156       // Pickles as an empty cache; the client can repopulate as needed.
1157       [](const CompiledFunctionCache& cache) {
1158         py::dict pickle;
1159         pickle["version"] = kCompiledFunctionCachePickleVersion;
1160         pickle["capacity"] = cache.Capacity();
1161         return pickle;
1162       },
1163       // __setstate__
1164       [](const py::dict& pickle) {
1165         int version = py::cast<int>(pickle["version"]);
1166         if (version != kCompiledFunctionCachePickleVersion) {
1167           throw std::invalid_argument(absl::StrFormat(
1168               "Invalid CompiledFunction pickle version, got %d, expected %d",
1169               version, kCompiledFunctionCachePickleVersion));
1170         }
1171         int capacity = py::cast<int>(pickle["capacity"]);
1172         return std::make_shared<CompiledFunctionCache>(capacity);
1173       }));
1174 
1175   // Alias CompiledFunctionCache in the submodule where we actually wanted to
1176   // define it.
1177   jitlib.attr("CompiledFunctionCache") = cache;
1178 
1179   // We need to use heap-allocated type objects because we want to add
1180   // additional methods dynamically.
1181   py::object cfun;
1182   {
1183     py::str name = py::str("CompiledFunction");
1184     py::str qualname = py::str("CompiledFunction");
1185     PyHeapTypeObject* heap_type = reinterpret_cast<PyHeapTypeObject*>(
1186         PyType_Type.tp_alloc(&PyType_Type, 0));
1187     // Caution: we must not call any functions that might invoke the GC until
1188     // PyType_Ready() is called. Otherwise the GC might see a half-constructed
1189     // type object.
1190     CHECK(heap_type) << "Unable to create heap type object";
1191     heap_type->ht_name = name.release().ptr();
1192     heap_type->ht_qualname = qualname.release().ptr();
1193     PyTypeObject* type = &heap_type->ht_type;
1194     type->tp_name = "CompiledFunction";
1195     type->tp_basicsize = sizeof(JaxCompiledFunctionObject);
1196     type->tp_flags =
1197         Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE | Py_TPFLAGS_HAVE_GC;
1198     type->tp_new = JaxCompiledFunction_tp_new;
1199     type->tp_dealloc = JaxCompiledFunction_tp_dealloc;
1200     type->tp_dictoffset = offsetof(JaxCompiledFunctionObject, dict);
1201     type->tp_traverse = JaxCompiledFunction_tp_traverse;
1202     type->tp_clear = JaxCompiledFunction_tp_clear;
1203     type->tp_weaklistoffset = offsetof(JaxCompiledFunctionObject, weakrefs);
1204     type->tp_getset = JaxCompiledFunction_tp_getset;
1205     type->tp_descr_get = JaxCompiledFunction_tp_descr_get;
1206     type->tp_call = JaxCompiledFunction_tp_call;
1207     CHECK_EQ(PyType_Ready(type), 0);
1208     JaxCompiledFunction_Type = reinterpret_cast<PyObject*>(type);
1209     cfun = py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1210   }
1211   py::object cfun_type =
1212       py::reinterpret_borrow<py::object>(JaxCompiledFunction_Type);
1213 
1214   // Add CompiledFunction to the xla_extension module so it can be pickled.
1215   m.attr("CompiledFunction") = cfun_type;
1216 
1217   cfun.attr("__signature__") =
1218       property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1219         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1220         return fun->PythonSignature();
1221       });
1222   cfun.attr("_cache_miss") =
1223       property_readonly([](py::handle self) -> xla::StatusOr<py::object> {
1224         TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
1225         return fun->cache_miss();
1226       });
1227   cfun.attr("__getstate__") = py::cpp_function(
1228       [](const CompiledFunction::object& self) {
1229         CompiledFunction* fn = self.func();
1230         py::dict pickle;
1231         pickle["version"] = kCompiledFunctionPickleVersion;
1232         pickle["fun"] = fn->fun();
1233         pickle["cache_miss"] = fn->cache_miss();
1234         pickle["get_device"] = fn->get_device();
1235         pickle["static_argnums"] = fn->static_argnums();
1236         pickle["static_argnames"] = fn->static_argnames();
1237         pickle["donate_argnums"] = fn->donate_argnums();
1238         pickle["cache"] = fn->cache();
1239         return pickle;
1240       },
1241       py::is_method(cfun_type));
1242   cfun.attr("__setstate__") = py::cpp_function(
1243       [](CompiledFunction::object& self, const py::dict& pickle) {
1244         int version = py::cast<int>(pickle["version"]);
1245         if (version != kCompiledFunctionPickleVersion) {
1246           throw std::invalid_argument(absl::StrFormat(
1247               "Invalid CompiledFunction pickle version, got %d, expected %d",
1248               version, kCompiledFunctionPickleVersion));
1249         }
1250         py::function fun = py::cast<py::function>(pickle["fun"]);
1251         py::function cache_miss = py::cast<py::function>(pickle["cache_miss"]);
1252         py::function get_device = py::cast<py::function>(pickle["get_device"]);
1253         std::vector<int> static_argnums =
1254             py::cast<std::vector<int>>(pickle["static_argnums"]);
1255         std::vector<py::str> static_argnames =
1256             py::cast<std::vector<py::str>>(pickle["static_argnames"]);
1257         std::vector<int> donate_argnums =
1258             py::cast<std::vector<int>>(pickle["donate_argnums"]);
1259         std::shared_ptr<CompiledFunctionCache> cache =
1260             py::cast<std::shared_ptr<CompiledFunctionCache>>(pickle["cache"]);
1261         InitializeCompiledFunction(
1262             reinterpret_cast<JaxCompiledFunctionObject*>(self.ptr()),
1263             std::move(fun), std::move(cache_miss), std::move(get_device),
1264             std::move(static_argnums), std::move(static_argnames),
1265             std::move(donate_argnums), std::move(cache));
1266       },
1267       py::is_method(cfun_type));
1268   py::class_<GlobalJitState> global_state_(jitlib, "GlobalJitState");
1269   global_state_.def_readwrite("disable_jit", &GlobalJitState::disable_jit);
1270   global_state_.def_readwrite("enable_x64", &GlobalJitState::enable_x64);
1271   global_state_.def_readwrite("extra_jit_context",
1272                               &GlobalJitState::extra_jit_context);
1273   global_state_.def_readwrite("post_hook", &GlobalJitState::post_hook);
1274 
1275   py::class_<ThreadLocalJitState> thread_local_state_(jitlib,
1276                                                       "ThreadLocalJitState");
1277   thread_local_state_.def_readwrite("disable_jit",
1278                                     &ThreadLocalJitState::disable_jit);
1279   thread_local_state_.def_readwrite("enable_x64",
1280                                     &ThreadLocalJitState::enable_x64);
1281   thread_local_state_.def_readwrite("extra_jit_context",
1282                                     &ThreadLocalJitState::extra_jit_context);
1283   thread_local_state_.def_readwrite("post_hook",
1284                                     &ThreadLocalJitState::post_hook);
1285 
1286   jitlib.def(
1287       "global_state", [&]() { return &global_state; },
1288       py::return_value_policy::reference);
1289   jitlib.def(
1290       "thread_local_state", [&]() { return &thread_local_state; },
1291       py::return_value_policy::reference);
1292 
1293   jitlib.def("jit_is_disabled", &JitIsDisabled);
1294   jitlib.def("get_enable_x64", &GetEnableX64);
1295 
1296   // TODO(phawkins): delete the following methods after dropping compatibility
1297   // with JAX python versions older than 0.2.10.
1298   jitlib.def("set_disable_jit_cpp_flag",
1299              [&](bool disable_jit) { global_state.disable_jit = disable_jit; });
1300   jitlib.def("get_disable_jit_cpp_flag",
1301              [&]() { return global_state.disable_jit; });
1302   jitlib.def("set_disable_jit_thread_local",
1303              [&](absl::optional<bool> disable_jit) {
1304                thread_local_state.disable_jit = disable_jit;
1305              });
1306   jitlib.def("get_disable_jit_thread_local",
1307              [&]() { return thread_local_state.disable_jit; });
1308   // TODO(jblespiau): Remove from the Python code and remove this
1309   jitlib.def("set_disable_jit", [&](bool disable_jit) {
1310     thread_local_state.disable_jit = disable_jit;
1311   });
1312   jitlib.def("get_disable_jit",
1313              [&]() { return thread_local_state.disable_jit; });
1314 
1315   jitlib.def("set_enable_x64_cpp_flag",
1316              [&](bool enable_x64) { global_state.enable_x64 = enable_x64; });
1317   jitlib.def("get_enable_x64_cpp_flag",
1318              [&]() { return global_state.enable_x64; });
1319   jitlib.def("set_enable_x64_thread_local",
1320              [&](absl::optional<bool> enable_x64) {
1321                thread_local_state.enable_x64 = enable_x64;
1322              });
1323   jitlib.def("get_enable_x64_thread_local", [&](bool enable_x64) {
1324     thread_local_state.enable_x64 = enable_x64;
1325   });
1326   // TODO(phawkins): delete up to here.
1327 
  // Main entry point: wraps a Python callable in a C++ CompiledFunction that
  // dispatches on the runtime signature of its arguments. All pybind objects
  // are moved into the wrapper. `static_argnames`, `donate_argnums` and
  // `cache` are optional; `cache` defaults to null (MakeCompiledFunction
  // presumably supplies a default cache in that case — not visible here).
  jitlib.def(
      "jit",
      [](py::function fun, py::function cache_miss, py::function get_device,
         std::vector<int> static_argnums, std::vector<py::str> static_argnames,
         std::vector<int> donate_argnums,
         std::shared_ptr<CompiledFunctionCache> cache) -> py::object {
        return MakeCompiledFunction(
            std::move(fun), std::move(cache_miss), std::move(get_device),
            std::move(static_argnums), std::move(static_argnames),
            std::move(donate_argnums), std::move(cache));
      },
      py::arg("fun"), py::arg("cache_miss"), py::arg("get_device"),
      py::arg("static_argnums"),
      py::arg("static_argnames") = std::vector<py::str>(),
      py::arg("donate_argnums") = std::vector<int>(),
      py::arg("cache") = nullptr);
1344 
  // This function is not yet a full replacement for the Python one, because:
  // (a) it does not support abstract types,
  // (b) it does not set the device stickiness yet.
  // TODO(jblespiau): Finish the replacement of the Python feature.
  // Transfers `obj` to `to_device`. When x64 is disabled, 64-bit types are
  // squashed to their 32-bit counterparts (squash_64bit_types); zero-copy
  // aliasing of the host buffer is permitted.
  jitlib.def("device_put",
             [](py::handle obj, bool jax_enable_x64,
                xla::ClientAndPtr<xla::PjRtDevice> to_device)
                 -> xla::StatusOr<py::object> {
               std::shared_ptr<xla::PyClient>& pyclient = to_device.client;
               xla::DevicePutOptions options;
               options.squash_64bit_types = !jax_enable_x64;
               options.force_lazy_arrays = true;
               options.allow_zero_copy = true;
               xla::StatusOr<xla::DevicePutResult> results =
                   DevicePut(obj, to_device.contents, options);
               if (!results.ok()) {
                 // Surface the failure as a Python RuntimeError instead of
                 // returning the error status.
                 throw std::runtime_error(results.status().error_message());
               }
               if (results->owned_buffer) {
                 // A new device buffer was produced: wrap it in a PyBuffer
                 // and attach a jax.core.ShapedArray aval so it can be used
                 // as a JAX value.
                 auto buffer = xla::PyBuffer::Make(
                     pyclient, std::move(results->owned_buffer),
                     xla::Traceback::Get());

                 // Cached handles into jax.core, heap-allocated and never
                 // freed — presumably to avoid running Python object
                 // destructors after interpreter shutdown.
                 static const auto* jax_core =
                     new py::module(py::module::import("jax.core"));
                 static const auto* shaped_array =
                     new py::handle(jax_core->attr("ShapedArray"));
                 buffer.buf()->SetAval((*shaped_array)(
                     buffer.buf()->python_shape(), buffer.buf()->python_dtype(),
                     results->weak_type));
                 // No sticky device is set yet (see TODO above).
                 TF_RETURN_IF_ERROR(buffer.buf()->set_sticky_device(nullptr));

                 return std::move(buffer);
               } else {
                 // No new buffer was created; hand the original object back
                 // unchanged.
                 return py::cast<py::object>(obj);
               }
             });
1382 
  // Expose xla::PyArgSignature to Python as a read-only
  // (dtype, shape, weak_type) triple.
  py::class_<xla::PyArgSignature> arg_signature(jitlib, "PyArgSignature");
  arg_signature
      // XLA PrimitiveType converted to a numpy dtype for the Python side.
      .def_property_readonly("dtype",
                             [](const xla::PyArgSignature& sig) {
                               return PrimitiveTypeToDtype(sig.dtype);
                             })
      // Dimensions exposed as a Python tuple.
      .def_property_readonly(
          "shape",
          [](const xla::PyArgSignature& sig) {
            return xla::SpanToTuple(absl::MakeConstSpan(sig.shape));
          })
      .def_readonly("weak_type", &xla::PyArgSignature::weak_type);
  // Computes the PyArgSignature of an arbitrary Python value.
  jitlib.def("_ArgSignatureOfValue", &xla::PyArgSignatureOfValue);
1396 
  // All private members are only for testing/debugging purposes
  // Number of entries currently held in this compiled function's cache.
  cfun.attr("_cache_size") = py::cpp_function(
      [](py::handle self) -> xla::StatusOr<int> {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        return fun->cache_size();
      },
      py::is_method(cfun));
  // Drops all cached entries for this compiled function.
  cfun.attr("_clear_cache") = py::cpp_function(
      [](py::handle self) -> xla::Status {
        TF_ASSIGN_OR_RETURN(CompiledFunction * fun, AsCompiledFunction(self));
        fun->ClearCache();
        return xla::Status::OK();
      },
      py::is_method(cfun));
  jitlib.def("_is_float0", &xla::IsFloat0);
1412 }
1413 
1414 }  // namespace jax
1415