1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 18 19 #include <string> 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/strings/str_cat.h" 23 #include "absl/strings/str_join.h" 24 #include "pybind11/pybind11.h" 25 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h" 26 #include "tensorflow/compiler/xla/python/py_client.h" 27 #include "tensorflow/compiler/xla/python/py_values.h" 28 #include "tensorflow/compiler/xla/python/python_ref_manager.h" 29 #include "tensorflow/compiler/xla/python/pytree.h" 30 #include "tensorflow/compiler/xla/types.h" 31 #include "tensorflow/compiler/xla/xla_data.pb.h" 32 33 namespace jax { 34 35 // Flags, such as JIT disable and the x64 mode, are controlled by: 36 // - a global flag value, e.g., associated to --jax_enable_x64 37 // - possibly a thread-local value, which initially is absl::nullopt and 38 // overrides the global value if set. The thread-local state is 39 // used to implement context managers that locally override the global state. 40 // TODO(phawkins): consider changing the global state to optional types to 41 // catch cases where we fail to set it. 42 struct GlobalJitState { 43 bool disable_jit = false; 44 bool enable_x64 = false; 45 46 // Extra context that should be included in the JIT cache key. Must be 47 // hashable and have an equality defined. 48 pybind11::object extra_jit_context = pybind11::none(); 49 50 // A callback that, if present, is called when a JITted function is executed 51 // from cache. 52 absl::optional<pybind11::function> post_hook; 53 }; 54 55 struct ThreadLocalJitState { ~ThreadLocalJitStateThreadLocalJitState56 ~ThreadLocalJitState() { 57 if (extra_jit_context) { 58 // We likely do not hold the GIL, so we hand the Python object to the 59 // global reference manager to destroy. 60 pybind11::object o = std::move(*extra_jit_context); 61 xla::GlobalPyRefManager()->AddGarbage(absl::MakeSpan(&o, 1)); 62 extra_jit_context = absl::nullopt; 63 } 64 } 65 absl::optional<bool> disable_jit; 66 absl::optional<bool> enable_x64; 67 absl::optional<pybind11::object> extra_jit_context; 68 absl::optional<pybind11::function> post_hook; 69 }; 70 71 GlobalJitState& GetGlobalState(); 72 ThreadLocalJitState& GetLocalState(); 73 74 // Returns the value for jax_enable_x64 (defined by a thread-local value if 75 // defined, defaulting to the value of the flag otherwise). 76 bool GetEnableX64(); 77 78 // The signature of Python jitted function call, partitioned into: 79 // - dynamic positional arguments (i.e. positional args which are not static) 80 // - static positional arguments (i.e. the args associated to static_argnums) 81 // - keyword arguments 82 // The CallSignature should unambiguously identify a function call, thus, 83 // equality is based on: 84 // (a) Same PyTree for all dynamic positional arguments and keyword arguments 85 // (a) equality of the arguments and keyword arguments ArgSignature 86 // (a) equality (delegated to Python) of the static arguments. 87 struct CallSignature { 88 // A PyTreeDef for each dynamic argument, positional arguments first 89 // followed by keyword arguments. Keyword arguments are in the order given 90 // by dynamic_arg_names. 91 absl::InlinedVector<xla::PyTreeDef, 2> dynamic_arg_treedefs; 92 // Dynamic keyword argument names. Interned, and sorted by the keyword 93 // name. 94 std::vector<pybind11::object> dynamic_arg_names; 95 // Shape and dtype for both the dynamic positional arguments and the keyword 96 // arguments (sorted by keyword name). 97 absl::InlinedVector<xla::PyArgSignature, 2> dynamic_arg_signatures; 98 99 // Static arguments. Contains the positional arguments sorted in argument 100 // order, followed by static keyword arguments in the order given by 101 // `static_arg_names`. 102 std::vector<pybind11::object> static_args; 103 // Static keyword argument names. Interned, and sorted by keyword name. 104 std::vector<pybind11::object> static_arg_names; 105 106 // For JIT, we need this in the key because computation follows the data, so 107 // we may have multiple executables depending on the devices the data is on. 108 // This is not the case for PMAP, and is set to `nullptr`. 109 xla::PjRtDevice* device = nullptr; 110 bool jax_enable_x64; 111 112 // Opaque additional context that should be included as part of the cache key. 113 pybind11::object global_extra_jit_context; 114 absl::optional<pybind11::object> thread_local_extra_jit_context; 115 116 bool operator==(const CallSignature& other) const; 117 bool operator!=(const CallSignature& other) const { 118 return !(*this == other); 119 } 120 121 std::string DebugString() const; 122 }; 123 124 template <typename H> 125 H AbslHashValue(H h, const CallSignature& s); 126 127 // The resulting information of the parsing and conversion of the arguments. 128 struct ParsedArgumentsAsBuffers { 129 // The call signature will be filled during 2 steps: 130 // - `ParseArguments` will fill the static arguments and the pytree 131 // structures 132 // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`. 133 CallSignature signature; 134 // The concatenation of the dynamic positional arguments and the sorted 135 // keyword arguments. 136 absl::InlinedVector<pybind11::object, 2> flat_dynamic_args; 137 std::vector<pybind11::object> keep_alive_objects; 138 139 // The following is only valid if the parsing succeeds. 140 std::vector<xla::PjRtBuffer*> arg_buffers; 141 // We may need to keep these objects around, because: 142 // (a) we need to extend the lifetime of objects created within 143 // `CopyBuffersToDevice` 144 // (b) `arg_buffers` do not maintain ownership 145 std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive; 146 }; 147 148 // Filter out static arguments, flatten and concatenate other arguments (i.e. 149 // dynamic positional and keyword arguments), filling `arguments` in place. 150 xla::Status ParseArguments(pybind11::handle args, 151 const absl::optional<pybind11::kwargs>& py_kwargs, 152 absl::Span<int const> static_argnums, 153 absl::Span<pybind11::str const> static_argnames, 154 ParsedArgumentsAsBuffers& arguments); 155 156 // The function to call in `xla.cc` to add the bindings for this module. 157 void BuildJaxjitSubmodule(pybind11::module& m); 158 159 } // namespace jax 160 161 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_ 162