1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
18
19 #include "absl/strings/str_cat.h"
20 #include "absl/strings/str_join.h"
21 #include "pybind11/pybind11.h"
22 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
23 #include "tensorflow/compiler/xla/python/py_client.h"
24 #include "tensorflow/compiler/xla/python/pytree.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27
28 namespace jax {
29
30 // Describes the abstract shape and dtype of an argument.
31 struct ArgSignature {
ArgSignatureArgSignature32 ArgSignature(xla::PrimitiveType dtype, absl::Span<const xla::int64> shape,
33 bool weak_type)
34 : dtype(dtype), shape(shape.begin(), shape.end()), weak_type(weak_type) {}
35 // This is the XLA dtype of the object.
36 const xla::PrimitiveType dtype;
37 const absl::InlinedVector<xla::int64, 4> shape;
38 // JAX arguments can be of weak type, if and only if they are Python scalars
39 // or `DeviceArray` values such that `aval.weak_type` is true.
40 const bool weak_type;
41 bool operator==(const ArgSignature& other) const {
42 return std::tie(dtype, weak_type, shape) ==
43 std::tie(other.dtype, other.weak_type, other.shape);
44 }
45 bool operator!=(const ArgSignature& other) const { return !(*this == other); }
46 std::string DebugString() const;
47 };
48
49 template <typename H>
AbslHashValue(H h,const ArgSignature & s)50 H AbslHashValue(H h, const ArgSignature& s) {
51 h = H::combine(std::move(h), s.dtype);
52 h = H::combine_contiguous(std::move(h), s.shape.data(), s.shape.size());
53 return h;
54 }
55
// The signature of a Python jitted function call, partitioned into:
// - dynamic positional arguments (i.e. positional args which are not static)
// - static positional arguments (i.e. the args associated with static_argnums)
// - keyword arguments
// The CallSignature should unambiguously identify a function call; thus,
// equality is based on:
// (a) the same PyTree for all dynamic positional arguments and keyword
//     arguments,
// (b) equality of the ArgSignatures of the dynamic positional and keyword
//     arguments,
// (c) equality (delegated to Python) of the static arguments.
65 struct CallSignature {
66 struct KwargEntry {
67 // To avoid comparing strings, we intern the kwargs strings.
68 // The compilation cache holds a reference to all the keys.
69 pybind11::handle key;
70 xla::PyTreeDef value_treedef;
71 bool operator==(const KwargEntry& other) const {
72 return key.ptr() == other.key.ptr() &&
73 value_treedef == other.value_treedef;
74 }
75 bool operator!=(const KwargEntry& other) const { return !(*this == other); }
76 };
77
78 // Only contains the arguments associated to `static_argnums`, sorted in the
79 // order of their argnum index.
80 std::vector<pybind11::object> static_args;
81 // A PyTreeDef for each positional dynamic (i.e. not static) argument.
82 std::vector<xla::PyTreeDef> dynamic_positional_args_treedef;
83 // Keyword arguments. Sorted by the keyword name.
84 std::vector<KwargEntry> keyword_args;
85 // Shape and dtype for both the dynamic positional arguments and the keyword
86 // arguments (sorted by keyword name).
87 std::vector<ArgSignature> dynamic_args_signatures;
88 xla::PjRtDevice* device;
89
90 bool operator==(const CallSignature& other) const;
91 bool operator!=(const CallSignature& other) const {
92 return !(*this == other);
93 }
94
95 // To be used when we want to keep ownership of Python values referenced by
96 // the `CallSignature` (i.e. when we insert an entry).
97 void IncRef() const;
98 // The destructor of the cache should call this on all entries.
99 void DecRef() const;
100
101 std::string DebugString() const;
102 };
103
104 template <typename H>
AbslHashValue(H h,const CallSignature::KwargEntry & kw)105 H AbslHashValue(H h, const CallSignature::KwargEntry& kw) {
106 h = H::combine(std::move(h), kw.key.ptr(), kw.value_treedef);
107 return h;
108 }
109
// Hashes a CallSignature for absl hash containers. Only declared in this
// header.
// NOTE(review): a function template whose definition lives out of line is only
// usable where that definition is visible or explicitly instantiated — confirm
// all users are in the defining translation unit.
template <typename H>
H AbslHashValue(H h, const CallSignature& s);
112
// The resulting information of the parsing and conversion of the arguments.
struct ParsedArgumentsAsBuffers {
  // The call signature will be filled during 2 steps:
  // - `ParseArguments` will fill the static arguments and the pytree
  //   structures
  // - the shapes and dtypes are filled later, by `ParseAndTransferArguments`.
  CallSignature signature;
  // The concatenation of the dynamic positional arguments and the sorted
  // keyword arguments.
  std::vector<pybind11::object> flat_dynamic_args;
  // NOTE(review): presumably owning references to Python objects that must
  // stay alive while the arguments are in use — TODO confirm against the
  // implementation.
  std::vector<pybind11::object> keep_alive_objects;

  // The following is only valid if the parsing succeeds.
  // Raw buffer pointers for the arguments; they do not own the buffers (see
  // the comment on `keep_alive` below).
  std::vector<xla::PjRtBuffer*> arg_buffers;
  // We may need to keep these objects around, because:
  // (a) we need to extend the lifetime of objects created within
  // `ConvertArgsToBuffers`
  // (b) `arg_buffers` do not maintain ownership
  std::vector<std::unique_ptr<xla::PjRtBuffer>> keep_alive;
};
133
// Filter out static arguments, flatten and concatenate other arguments (i.e.
// dynamic positional and keyword arguments), filling `arguments` in place.
//
// `static_argnums` lists the indices of the positional `args` that are static.
// Per the comments on `ParsedArgumentsAsBuffers`, this step fills the static
// arguments and pytree structures of `arguments.signature` (and presumably
// `arguments.flat_dynamic_args` — TODO confirm). The per-argument shapes and
// dtypes are filled by a later step.
// Returns a non-OK status when parsing fails.
xla::Status ParseArguments(const pybind11::args& args,
                           const pybind11::kwargs& py_kwargs,
                           absl::Span<int const> static_argnums,
                           ParsedArgumentsAsBuffers& arguments);
140
141 struct DevicePutResult {
DevicePutResultDevicePutResult142 explicit DevicePutResult(xla::PjRtBuffer* b, bool weak_type)
143 : buffer(b), weak_type(weak_type), owned_buffer(nullptr) {}
DevicePutResultDevicePutResult144 DevicePutResult(std::unique_ptr<xla::PjRtBuffer> new_buffer, bool weak_type)
145 : buffer(new_buffer.get()),
146 weak_type(weak_type),
147 owned_buffer(std::move(new_buffer)) {}
148
149 xla::PjRtBuffer* buffer;
150 bool weak_type;
151 std::unique_ptr<xla::PjRtBuffer> owned_buffer;
152 };
153
// Returns the ArgSignature associated with an argument. Returns an error if
// the argument is not supported.
// NOTE(review): `jax_enable_x64` presumably decides whether 64-bit dtypes are
// kept or narrowed to 32-bit — TODO confirm against the implementation.
xla::StatusOr<ArgSignature> ArgSignatureOfValue(pybind11::handle arg,
                                                bool jax_enable_x64);
158
// Moves a device-like object to be on device.
// - If the object is already on device, `owned_buffer` will be nullptr.
// - If it's not, a new buffer will be created and returned using
//   `owned_buffer`.
// In all cases, `buffer` will point to the already existing or newly created
// buffer.
// If `obj` is not convertible to an `xla::PjRtBuffer` from C++, an error will
// be returned; float0 dtype and `_DeviceArray` with non-trivial LazyExpr are
// not supported yet.
// NOTE(review): the text above says `obj`, but the parameter is named `arg`.
// `to_device` is the target device; `pyclient` is presumably the client used
// to create any new buffer — TODO confirm.
xla::StatusOr<DevicePutResult> DevicePut(pybind11::handle arg,
                                         xla::PjRtDevice* to_device,
                                         bool jax_enable_x64,
                                         xla::PyClient& pyclient);
172
// The function to call in `xla.cc` to add the bindings for this module.
// `m` is the pybind11 module the bindings are registered on.
void BuildJaxjitSubmodule(pybind11::module& m);
175
176 } // namespace jax
177
178 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_JAX_JIT_H_
179