/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/pmap_lib.h"

#include <stdexcept>
#include <string>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/synchronization/notification.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "pybind11/cast.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/compiler/xla/python/absl_casters.h"
#include "tensorflow/compiler/xla/python/jax_jit.h"
#include "tensorflow/compiler/xla/python/py_executable.h"
#include "tensorflow/compiler/xla/python/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace jax {

namespace py = pybind11;

// TODO(jblespiau): Using `NoSharding` instead of `None` would allow us to
// simplify the conversion logic.
std::vector<AvalDimSharding> PyShardingToCpp(pybind11::tuple py_sharding) {
  std::vector<AvalDimSharding> cpp_sharding;
  cpp_sharding.reserve(py_sharding.size());
  for (py::handle value : py_sharding) {
    if (value.is_none()) {
      cpp_sharding.push_back(NoSharding());
    } else if (py::isinstance<Chunked>(value)) {
      cpp_sharding.push_back(py::cast<Chunked>(value));
    } else if (py::isinstance<Unstacked>(value)) {
      cpp_sharding.push_back(py::cast<Unstacked>(value));
    } else {
      throw std::runtime_error(
          absl::StrCat("Unsupported Python object in PyShardingToCpp in "
                       "pmap_lib.cc. The object was of type ",
                       py::cast<std::string>(py::str(value.get_type())),
                       "\n:", py::cast<std::string>(py::str(value))));
    }
  }
  return cpp_sharding;
}
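
// For illustration only (hypothetical values): a Python per-dimension sharding
// tuple such as
//   (None, Chunked([2, 2]), Unstacked(8))
// is converted to the C++ vector
//   {NoSharding{}, Chunked{{2, 2}}, Unstacked{8}}
// and CppShardingToPy below performs the inverse conversion.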

pybind11::tuple CppShardingToPy(std::vector<AvalDimSharding> sharding) {
  py::tuple result(sharding.size());
  int counter = 0;
  for (auto value : sharding) {
    if (absl::holds_alternative<NoSharding>(value)) {
      result[counter++] = py::none();
    } else if (absl::holds_alternative<Chunked>(value)) {
      py::handle handle = py::cast(absl::get<Chunked>(value));
      result[counter++] = py::cast<py::object>(handle);
    } else if (absl::holds_alternative<Unstacked>(value)) {
      py::handle handle = py::cast(absl::get<Unstacked>(value));
      result[counter++] = py::cast<py::object>(handle);
    } else {
      LOG(FATAL) << "Unhandled CPP type in CppShardingToPy.";
    }
  }
  return result;
}

std::vector<MeshDimAssignment> PyMeshShardingToCpp(
    pybind11::tuple py_mesh_mapping) {
  return py::cast<std::vector<MeshDimAssignment>>(py_mesh_mapping);
}

pybind11::tuple CppMeshMappingToPy(
    std::vector<MeshDimAssignment> mesh_mapping) {
  py::tuple result(mesh_mapping.size());
  int counter = 0;
  for (auto& value : mesh_mapping) {
    result[counter] = py::cast(value);
    ++counter;
  }
  return result;
}
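
// As a hedged illustration (hypothetical values): a C++ mesh mapping such as
//   std::vector<MeshDimAssignment>{ShardedAxis{0}, Replicated{2}}
// round-trips through CppMeshMappingToPy/PyMeshShardingToCpp as the Python
// tuple (ShardedAxis(axis=0), Replicated(replicas=2)).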

namespace {

struct PmapCacheEntry {
  // To get a first version running, we rely extensively on Python here for
  // the handling of the arguments and outputs.
  // TODO(jblespiau): Move more to C++.
  std::shared_ptr<xla::PyExecutable> executable;
  // See _cpp_pmap in api.py.
  py::object backend;
  // A function taking a list of arguments and returning a list of lists of
  // buffers of shape `[num_devices x num_args]`.
  py::function handle_args;
  // A function taking the output of `ExecuteOnLocalDevices` and returning a
  // list of ShardedDeviceArray objects.
  py::function out_handler;
  xla::PyTreeDef out_pytree_def;

  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated
  // with a signature; if the object has already been inserted, other threads
  // will wait for the notification.
  absl::Notification compilation_complete;
  absl::optional<xla::Status> compilation_error = absl::nullopt;

  bool fall_back_to_python = false;
};
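
// A hedged illustration of the Python helpers stored above (hypothetical
// values): with num_devices == 2 and num_args == 3,
//   handle_args([a0, a1, a2])
// returns a list of lists of buffers such as
//   [[b0_0, b0_1, b0_2],   // device 0
//    [b1_0, b1_1, b1_2]]   // device 1
// and `out_handler` maps the per-device outputs of `ExecuteOnLocalDevices`
// back to a flat list of ShardedDeviceArray objects.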

}  // namespace

// A `PmapFunction` is associated with a `jax.pmap(f)` and takes care of the
// bookkeeping of the different signatures used and the dispatch of calls to
// the correct underlying `PyExecutable`. This class is thread-safe.
class PmapFunction {
 public:
  PmapFunction(py::function fun, py::function cache_miss,
               py::function get_jax_enable_x64,
               std::vector<int> static_argnums)
      : fun_(std::move(fun)),
        cache_miss_(std::move(cache_miss)),
        static_argnums_(std::move(static_argnums)),
        get_jax_enable_x64_(get_jax_enable_x64) {
    std::sort(static_argnums_.begin(), static_argnums_.end());
  }

  ~PmapFunction() {
    for (const auto& entry : executables_) {
      entry.first.DecRef();
    }
  }

  // This function will:
  // (a) flatten the inputs using pytree
  // (b) get buffer objects from the arguments
  // (c) call the executable
  // (d) construct `ShardedDeviceArray` objects from the outputs
  // (e) reconstruct the `PyTree`.
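  //
  // A rough sketch of the fast path (hypothetical Python-side usage):
  //   f = jax.pmap(fun)   # wraps a PmapFunction via pmap_lib.pmap
  //   y = f(x)            # x.shape == (num_devices, ...)
  // which, on a cache hit, roughly corresponds to:
  //   buffers = handle_args(flat_args)              # [num_devices x num_args]
  //   outs    = executable->ExecuteOnLocalDevices(buffers)
  //   return out_pytree_def.Unflatten(out_handler(outs))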
  py::object Call(py::args args, py::kwargs kwargs);

  py::object PythonSignature() {
    static const auto* inspect = new py::module(py::module::import("inspect"));
    return inspect->attr("signature")(fun_);
  }

  int cache_size() const { return executables_.size(); }

 private:
  // Returns nullptr if not present in the cache.
  PmapCacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
  // Should never return nullptr.
  PmapCacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
                                const CallSignature& signature,
                                py::object out_and_fastpath_data);

  bool always_fallback_to_python_ = false;

  const py::function fun_;  // The Python function to pmap.
  // See JAX _cpp_pmap in api.py for documentation.
  const py::function cache_miss_;

  // We need to know the static arguments to remove them from the arguments
  // passed to the underlying PyExecutable. In sorted order.
  std::vector<int> static_argnums_;
  // We need a `unique_ptr` here to ensure value pointer stability.
  absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
      executables_;

  const py::function get_jax_enable_x64_;
  absl::optional<bool> jax_enable_x64_ = absl::nullopt;

  // A vector of size `num_outputs`, specifying the sharding of each output.
  std::vector<ShardingSpec> sharding_specs_;
};

PmapCacheEntry* PmapFunction::GetCacheEntryIfPresent(
    const CallSignature& signature) {
  auto found_iterator = executables_.find(signature);
  if (found_iterator != executables_.end()) {  // Cache hit!
    if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
      py::gil_scoped_release gil_release;
      found_iterator->second->compilation_complete.WaitForNotification();
    }
    if (found_iterator->second->compilation_error) {
      throw std::invalid_argument(
          found_iterator->second->compilation_error.value().error_message());
    }
    return found_iterator->second.get();
  }
  return nullptr;
}
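
// A minimal sketch of the synchronization for a signature hit concurrently by
// two threads (assuming thread A inserts the entry first):
//   thread A: cache_miss_(...)      // Python-side compilation
//             AddCacheEntry(...)    // emplaces the entry, then Notify()
//   thread B: GetCacheEntryIfPresent(...) finds the un-notified entry,
//             releases the GIL and blocks in WaitForNotification() until A
//             calls Notify().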

PmapCacheEntry* PmapFunction::AddCacheEntry(const py::args& args,
                                            const py::kwargs& kwargs,
                                            const CallSignature& signature,
                                            py::object out_and_fastpath_data) {
  // We need to insert the element.
  auto result =
      executables_.emplace(signature, std::make_unique<PmapCacheEntry>());
  auto it = result.first;
  PmapCacheEntry* cache_entry = it->second.get();
  // CallSignatures in the cache own their keyword argument reference.
  result.first->first.IncRef();

  py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
  CHECK_EQ(tuple.size(), 2);
  if (tuple[1].is_none()) {
    cache_entry->fall_back_to_python = true;
    cache_entry->compilation_complete.Notify();
    return cache_entry;
  }

  py::dict pmap_data = py::cast<py::dict>(tuple[1]);
  if (py::cast<int>(pmap_data["version"]) != 1) {
    throw std::runtime_error(absl::StrCat(
        "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 "
        "expected, but got ",
        py::cast<int>(pmap_data["version"]),
        "). Upgrade jaxlib and jax. Provided data was: ",
        py::cast<std::string>(py::str(py::repr(pmap_data)))));
  }
  // { "version": 1,
  //   "xla_executable": xla_executable,
  //   "in_handler": in_handler,
  //   "out_handler": out_handler,
  //   "out_pytree_def": out_pytree_def }
  auto executable =
      py::cast<std::shared_ptr<xla::PyExecutable>>(pmap_data["xla_executable"]);
  cache_entry->executable = std::move(executable);
  cache_entry->handle_args = py::cast<py::function>(pmap_data["in_handler"]);
  cache_entry->out_handler = py::cast<py::function>(pmap_data["out_handler"]);
  auto out_tree = py::cast<xla::PyTreeDef>(pmap_data["out_pytree_def"]);
  cache_entry->out_pytree_def = std::move(out_tree);

  cache_entry->compilation_complete.Notify();
  return cache_entry;
}

py::object PmapFunction::Call(py::args args, py::kwargs kwargs) {
  if (always_fallback_to_python_) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }
  // Delayed values are retrieved on the first call to `Call`.
  if (jax_enable_x64_ == absl::nullopt) {
    jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
  }

  ParsedArgumentsAsBuffers arguments;
  if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  // Get dynamic argument signatures.
  for (py::handle arg : arguments.flat_dynamic_args) {
    auto signature_or_error = ArgSignatureOfValue(arg, jax_enable_x64_.value());
    if (!signature_or_error.ok()) {
      return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
    }
    arguments.signature.dynamic_args_signatures.push_back(
        std::move(signature_or_error).ValueOrDie());
  }

  // Retrieve the executable from the cache, or maybe add it.
  PmapCacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);
  if (!cache_entry) {
    py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
    cache_entry = GetCacheEntryIfPresent(arguments.signature);
    if (!cache_entry) {
      cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
                                  out_and_fastpath_data);
    }
    CHECK(cache_entry);
    if (cache_entry->fall_back_to_python) {
      return py::cast<py::tuple>(out_and_fastpath_data)[0];
    }
    // As we have already computed the results, we can return them.
    // It's even *required*, e.g. if there are donated arguments, because
    // otherwise the buffers that have already been donated would be invalid.
    return py::cast<py::tuple>(out_and_fastpath_data)[0];
  }

  CHECK(cache_entry);
  if (cache_entry->fall_back_to_python) {
    return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
  }

  // TODO(jblespiau): Use C++ only for this.
  py::list arg_list;
  for (auto& v : arguments.flat_dynamic_args) {
    arg_list.append(v);
  }

  py::object handled_args = cache_entry->handle_args(arg_list);
  py::list list_of_list_of_buffers = py::cast<py::list>(handled_args);

  arguments.keep_alive_objects.push_back(
      py::cast<py::object>(list_of_list_of_buffers));
  // Should be `[num_devices x num_args]`.
  std::vector<std::vector<xla::PyBuffer*>> arg_buffers;
  arg_buffers.reserve(list_of_list_of_buffers.size());
  for (int i = 0; i < list_of_list_of_buffers.size(); ++i) {
    std::vector<xla::PyBuffer*> buffers;
    buffers.reserve(py::cast<py::list>(list_of_list_of_buffers[i]).size());
    for (auto& buf : list_of_list_of_buffers[i]) {
      buffers.push_back(py::cast<xla::PyBuffer*>(buf));
    }
    arg_buffers.push_back(std::move(buffers));
  }

  std::vector<std::vector<std::unique_ptr<xla::PyBuffer>>> outputs =
      ValueOrThrow(cache_entry->executable->ExecuteOnLocalDevices(arg_buffers));

  // TODO(jblespiau): Move this to C++.
  py::list outputs_as_python_objects;
  for (int i = 0; i < outputs.size(); ++i) {
    outputs_as_python_objects.append(py::cast(std::move(outputs[i])));
  }
  py::list flat_sharded_device_arrays =
      cache_entry->out_handler(outputs_as_python_objects);
  return cache_entry->out_pytree_def.Unflatten(flat_sharded_device_arrays);
}

void BuildPmapSubmodule(pybind11::module& m) {
  py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");

  py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding");
  no_sharding.def(py::init<>())
      .def("__repr__", [](const NoSharding&) { return "NoSharding()"; })
      .def("__eq__", [](const NoSharding& self, py::object obj) {
        return py::isinstance<NoSharding>(obj);
      });

  py::class_<Chunked> chunked(pmap_lib, "Chunked");
  chunked.def(py::init<std::vector<int>>())
      .def_readonly("chunks", &Chunked::chunks)
      .def("__repr__",
           [](const Chunked& x) {
             return absl::StrCat("Chunked(", absl::StrJoin(x.chunks, ","), ")");
           })
      .def("__eq__", [](const Chunked& self, py::object other) {
        if (!py::isinstance<Chunked>(other)) {
          return false;
        }
        return self == py::cast<const Chunked&>(other);
      });

  py::class_<Unstacked> unstacked(pmap_lib, "Unstacked");
  unstacked.def(py::init<int>())
      .def_readonly("size", &Unstacked::size)
      .def("__repr__",
           [](const Unstacked& x) {
             return absl::StrCat("Unstacked(", x.size, ")");
           })
      .def("__eq__", [](const Unstacked& self, py::object other) {
        if (!py::isinstance<Unstacked>(other)) {
          return false;
        }
        return self == py::cast<const Unstacked&>(other);
      });

  py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis");
  sharded_axis.def(py::init<int>()).def_readonly("axis", &ShardedAxis::axis);
  sharded_axis
      .def("__repr__",
           [](const ShardedAxis& x) {
             return absl::StrCat("ShardedAxis(axis=", x.axis, ")");
           })
      .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) {
        return self == other;
      });

  py::class_<Replicated> replicated(pmap_lib, "Replicated");
  replicated.def(py::init<int>())
      .def_readonly("replicas", &Replicated::replicas)
      .def("__repr__",
           [](const Replicated& x) {
             return absl::StrCat("Replicated(replicas=", x.replicas, ")");
           })
      .def("__eq__", [](const Replicated& self, const Replicated& other) {
        return self == other;
      });

  py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec");
  sharding_spec
      .def(py::init<std::vector<AvalDimSharding>,
                    std::vector<MeshDimAssignment>>(),
           py::arg("sharding"), py::arg("mesh_mapping"))
      .def_property_readonly("sharding", &ShardingSpec::GetSharding)
      .def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping);
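
  // A hedged, illustrative construction from Python (hypothetical values):
  //   spec = pmap_lib.ShardingSpec(
  //       sharding=(Unstacked(8), NoSharding()),
  //       mesh_mapping=(ShardedAxis(0),))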

  py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray");
  sda.def(py::init<pybind11::handle, ShardingSpec, pybind11::list>())
      .def_property_readonly("aval", &ShardedDeviceArray::GetAval)
      .def_property_readonly("sharding_spec",
                             &ShardedDeviceArray::GetShardingSpec)
      .def_property_readonly("device_buffers",
                             &ShardedDeviceArray::GetDeviceBuffers);
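
  // Hedged illustration (hypothetical Python-side values): a
  // ShardedDeviceArray ties together an abstract value, a sharding spec, and
  // the per-device buffers, e.g.
  //   sda = pmap_lib.ShardedDeviceArray(aval, spec, device_buffers)
  //   sda.sharding_spec   # -> spec
  //   sda.device_buffers  # -> device_buffers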

  py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib,
                                                               "PmapFunction");
  cfun.def("__call__", &PmapFunction::Call);
  cfun.def_property_readonly("__signature__", &PmapFunction::PythonSignature);

  pmap_lib.def(
      "pmap",
      [](py::function fun, py::function cache_miss,
         py::function get_jax_enable_x64,
         std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> {
        return std::make_unique<PmapFunction>(
            std::move(fun), std::move(cache_miss),
            std::move(get_jax_enable_x64), std::move(static_argnums));
      });
}
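
// A minimal, hedged sketch of how the bindings above are exercised from
// Python (names are placeholders; the real caller is JAX's _cpp_pmap in
// api.py, and the exact module path depends on how `m` is built):
//   pmapped = pmap_lib.pmap(fun, cache_miss, get_jax_enable_x64, [])
//   out = pmapped(x)        # dispatches to PmapFunction::Call above
//   pmapped.__signature__   # inspect.Signature of `fun`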

}  // namespace jax