• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
18 
19 #include <string>
20 
21 #include "absl/container/inlined_vector.h"
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
26 #include "tensorflow/compiler/xla/python/py_client.h"
27 #include "tensorflow/compiler/xla/python/py_values.h"
28 #include "tensorflow/compiler/xla/python/python_ref_manager.h"
29 #include "tensorflow/compiler/xla/python/pytree.h"
30 #include "tensorflow/compiler/xla/types.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32 
33 namespace jax {
34 
35 // Flags, such as JIT disable and the x64 mode, are controlled by:
36 // - a global flag value, e.g., associated to --jax_enable_x64
37 // - possibly a thread-local value, which initially is absl::nullopt and
38 //   overrides the global value if set. The thread-local state is
39 //   used to implement context managers that locally override the global state.
40 // TODO(phawkins): consider changing the global state to optional types to
41 // catch cases where we fail to set it.
42 struct GlobalJitState {
43   bool disable_jit = false;
44   bool enable_x64 = false;
45 
46   // Extra context that should be included in the JIT cache key. Must be
47   // hashable and have an equality defined.
48   pybind11::object extra_jit_context = pybind11::none();
49 
50   // A callback that, if present, is called when a JITted function is executed
51   // from cache.
52   absl::optional<pybind11::function> post_hook;
53 };
54 
55 struct ThreadLocalJitState {
~ThreadLocalJitStateThreadLocalJitState56   ~ThreadLocalJitState() {
57     if (extra_jit_context) {
58       // We likely do not hold the GIL, so we hand the Python object to the
59       // global reference manager to destroy.
60       pybind11::object o = std::move(*extra_jit_context);
61       xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1));
62       extra_jit_context = absl::nullopt;
63     }
64   }
65   absl::optional<bool> disable_jit;
66   absl::optional<bool> enable_x64;
67   absl::optional<pybind11::object> extra_jit_context;
68   absl::optional<pybind11::function> post_hook;
69 };
70 
71 GlobalJitState& GetGlobalState();
72 ThreadLocalJitState& GetLocalState();
73 
74 // Returns the value for jax_enable_x64 (defined by a thread-local value if
75 // defined, defaulting to the value of the flag otherwise).
76 bool GetEnableX64();
77 
78 // The signature of Python jitted function call, partitioned into:
79 // - dynamic positional arguments (i.e. positional args which are not static)
80 // - static positional arguments (i.e. the args associated to static_argnums)
81 // - keyword arguments
82 // The CallSignature should unambiguously identify a function call, thus,
83 // equality is based on:
84 // (a) Same PyTree for all dynamic positional arguments and keyword arguments
85 // (a) equality of the arguments and keyword arguments ArgSignature
86 // (a) equality (delegated to Python) of the static arguments.
87 struct CallSignature {
88   // A PyTreeDef for each dynamic argument, positional arguments first
89   // followed by keyword arguments. Keyword arguments are in the order given
90   // by dynamic_arg_names.
91   absl::InlinedVector<xla::PyTreeDef, 2> dynamic_arg_treedefs;
92   // Dynamic keyword argument names. Interned, and sorted by the keyword
93   // name.
94   std::vector<pybind11::object> dynamic_arg_names;
95   // Shape and dtype for both the dynamic positional arguments and the keyword
96   // arguments (sorted by keyword name).
97   absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures;
98 
99   // Static arguments. Contains the positional arguments sorted in argument
100   // order, followed by static keyword arguments in the order given by
101   // `static_arg_names`.
102   std::vector<pybind11::object> static_args;
103   // Static keyword argument names. Interned, and sorted by keyword name.
104   std::vector<pybind11::object> static_arg_names;
105 
106   // For JIT, we need this in the key because computation follows the data, so
107   // we may have multiple executables depending on the devices the data is on.
108   // This is not the case for PMAP, and is set to `nullptr`.
109   xla::PjRtDevice* device = nullptr;
110   bool jax_enable_x64;
111 
112   // Opaque additional context that should be included as part of the cache key.
113   pybind11::object global_extra_jit_context;
114   absl::optional<pybind11::object> thread_local_extra_jit_context;
115 
116   bool operator==(const CallSignature& other) const;
117   bool operator!=(const CallSignature& other) const {
118     return !(*this == other);
119   }
120 
121   std::string DebugString() const;
122 };
123 
124 template <typename H>
125 H AbslHashValue(H h, const CallSignature& s);
126 
// The result of parsing the arguments of a jitted call and converting them
// to device buffers.
struct ParsedArgumentsAsBuffers {
  // The call signature is filled in two steps:
  // - `ParseArguments` fills the static arguments and the pytree
  //    structures;
  // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
  CallSignature signature;
  // The concatenation of the dynamic positional arguments and the sorted
  // keyword arguments.
  absl::InlinedVector<pybind11::object, 2> flat_dynamic_args;
  // NOTE(review): not documented in this header; presumably Python objects
  // whose lifetime must be extended for the duration of the call -- confirm
  // against the implementation.
  std::vector<pybind11::object> keep_alive_objects;

  // The following is only valid if the parsing succeeds.
  // Non-owning pointers to the argument buffers; ownership, where needed, is
  // held by `keep_alive` below.
  std::vector<xla::PjRtBuffer*> arg_buffers;
  // We may need to keep these objects around, because:
  // (a) we need to extend the lifetime of objects created within
  //    `CopyBuffersToDevice`
  // (b) `arg_buffers` do not maintain ownership
  std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
};
147 
148 // Filter out static arguments, flatten and concatenate other arguments (i.e.
149 // dynamic positional and keyword arguments), filling `arguments` in place.
150 xla::Status ParseArguments(pybind11::handle args,
151                            const absl::optional<pybind11::kwargs>& py_kwargs,
152                            absl::Span<int const> static_argnums,
153                            absl::Span<pybind11::str const> static_argnames,
154                            ParsedArgumentsAsBuffers& arguments);
155 
156 // The function to call in `xla.cc` to add the bindings for this module.
157 void BuildJaxjitSubmodule(pybind11::module& m);
158 
159 }  // namespace jax
160 
161 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
162