• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/pmap_lib.h"
17 
18 #include <stdexcept>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/str_cat.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/synchronization/notification.h"
25 #include "absl/types/span.h"
26 #include "absl/types/variant.h"
27 #include "pybind11/cast.h"
28 #include "pybind11/pybind11.h"
29 #include "pybind11/pytypes.h"
30 #include "tensorflow/compiler/xla/python/absl_casters.h"
31 #include "tensorflow/compiler/xla/python/jax_jit.h"
32 #include "tensorflow/compiler/xla/python/py_executable.h"
33 #include "tensorflow/compiler/xla/python/types.h"
34 #include "tensorflow/compiler/xla/xla_data.pb.h"
35 #include "tensorflow/core/platform/logging.h"
36 
37 namespace jax {
38 
39 namespace py = pybind11;
40 
41 // TODO(jblespiau): Using `NoSharding` instead of `None` would allow us to
42 // simplify the conversion logic.
PyShardingToCpp(pybind11::tuple py_sharding)43 std::vector<AvalDimSharding> PyShardingToCpp(pybind11::tuple py_sharding) {
44   std::vector<AvalDimSharding> cpp_sharding;
45   cpp_sharding.reserve(py_sharding.size());
46   for (py::handle value : py_sharding) {
47     if (value.is_none()) {
48       cpp_sharding.push_back(NoSharding());
49     } else if (py::isinstance<Chunked>(value)) {
50       cpp_sharding.push_back(py::cast<Chunked>(value));
51     } else if (py::isinstance<Unstacked>(value)) {
52       cpp_sharding.push_back(py::cast<Unstacked>(value));
53     } else {
54       throw std::runtime_error(
55           absl::StrCat("Not supported Python object in PyShardingToCpp in "
56                        "pmap_lib.cc. The object was of type ",
57                        py::cast<std::string>(py::str(value.get_type())),
58                        "\n:", py::cast<std::string>(py::str(value))));
59     }
60   }
61   return cpp_sharding;
62 }
63 
CppShardingToPy(std::vector<AvalDimSharding> sharding)64 pybind11::tuple CppShardingToPy(std::vector<AvalDimSharding> sharding) {
65   py::tuple result(sharding.size());
66   int counter = 0;
67   for (auto value : sharding) {
68     if (absl::holds_alternative<NoSharding>(value)) {
69       result[counter++] = py::none();
70     } else if (absl::holds_alternative<Chunked>(value)) {
71       py::handle handle = py::cast(absl::get<Chunked>(value));
72       result[counter++] = py::cast<py::object>(handle);
73     } else if (absl::holds_alternative<Unstacked>(value)) {
74       py::handle handle = py::cast(absl::get<Unstacked>(value));
75       result[counter++] = py::cast<py::object>(handle);
76     } else {
77       LOG(FATAL) << "Unhandled CPP type in CppShardingToPy.";
78     }
79   }
80   return result;
81 }
82 
PyMeshShardingToCpp(pybind11::tuple py_mesh_mapping)83 std::vector<MeshDimAssignment> PyMeshShardingToCpp(
84     pybind11::tuple py_mesh_mapping) {
85   return py::cast<std::vector<MeshDimAssignment>>(py_mesh_mapping);
86 }
87 
CppMeshMappingToPy(std::vector<MeshDimAssignment> mesh_mapping)88 pybind11::tuple CppMeshMappingToPy(
89     std::vector<MeshDimAssignment> mesh_mapping) {
90   py::tuple result(mesh_mapping.size());
91   int counter = 0;
92   for (auto& value : mesh_mapping) {
93     result[counter] = py::cast(value);
94     ++counter;
95   }
96   return result;
97 }
98 
namespace {

// Per-signature cache entry owned by `PmapFunction`. One entry is created
// for each distinct call signature and stores everything needed to dispatch
// a pmap call without re-tracing in Python.
struct PmapCacheEntry {
  // To get a first version running, we use extensively Python here for the
  // handling of the arguments and outputs.
  // TODO(jblespiau): Move more to C++.
  std::shared_ptr<xla::PyExecutable> executable;
  // See _cpp_pmap in api.py.
  py::object backend;
  // A function taking as argument a list of arguments and returns a list of
  // list of buffers `[num_devices x num_args]`.
  py::function handle_args;
  // A function taking as argument the output of `ExecuteOnLocalDevices` and
  // returning a list of ShardedDeviceArray objects.
  py::function out_handler;
  // PyTree structure of the function's outputs; used to unflatten the list
  // of results produced by `out_handler`.
  xla::PyTreeDef out_pytree_def;

  // Ensures a single thread performs the compilation for a given executable.
  //
  // The first thread (holding the GIL) will create the CacheEntry associated to
  // a signature and if the object has been inserted already, other threads
  // will wait for the notification.
  absl::Notification compilation_complete;
  // Set when compilation failed; consumers rethrow it as an exception.
  absl::optional<xla::Status> compilation_error = absl::nullopt;

  // When true, the fast path cannot handle this signature and calls must be
  // delegated to the Python implementation.
  bool fall_back_to_python = false;
};

}  // namespace
128 
129 // A `PmapFunction` is associated to a `jax.pmap(f)` and takes care of the
130 // bookkeeping of the different signatures used and the dispatch of calls to
131 // the correct underlying `PyExecutable`. This class is thread-safe.
132 class PmapFunction {
133  public:
PmapFunction(py::function fun,py::function cache_miss,py::function get_jax_enable_x64,std::vector<int> static_argnums)134   PmapFunction(py::function fun, py::function cache_miss,
135                py::function get_jax_enable_x64, std::vector<int> static_argnums)
136       : fun_(std::move(fun)),
137         cache_miss_(std::move(cache_miss)),
138         static_argnums_(std::move(static_argnums)),
139         get_jax_enable_x64_(get_jax_enable_x64) {
140     std::sort(static_argnums_.begin(), static_argnums_.end());
141   }
142 
~PmapFunction()143   ~PmapFunction() {
144     for (const auto& entry : executables_) {
145       entry.first.DecRef();
146     }
147   }
148 
149   // This function will:
150   // (a) flatten the inputs using pytree
151   // (b) get buffer objects from the arguments
152   // (c) call the executable
153   // (d) construct `ShardedDeviceArray` objects from the outputs
154   // (e) reconstruct the `PyTree`.
155   py::object Call(py::args args, py::kwargs kwargs);
156 
PythonSignature()157   py::object PythonSignature() {
158     static const auto* inspect = new py::module(py::module::import("inspect"));
159     return inspect->attr("signature")(fun_);
160   }
161 
cache_size() const162   int cache_size() const { return executables_.size(); }
163 
164  private:
165   // Returns nullptr if not present in the cache.
166   PmapCacheEntry* GetCacheEntryIfPresent(const CallSignature& signature);
167   // Should never return nullptr.
168   PmapCacheEntry* AddCacheEntry(const py::args& args, const py::kwargs& kwargs,
169                                 const CallSignature& signature,
170                                 py::object out_and_fastpath_data);
171 
172   bool always_fallback_to_python_ = false;
173 
174   const py::function fun_;  // The Python function to pmap.
175   // See JAX _cpp_pmap in api.py for documentation.
176   const py::function cache_miss_;
177 
178   // We need to know the static arguments to remove them from the arguments
179   // passed to the underlying PyExecutable. In sorted order.
180   std::vector<int> static_argnums_;
181   // We need a `unique_ptr` here to ensure value pointer stability.
182   absl::flat_hash_map<CallSignature, std::unique_ptr<PmapCacheEntry>>
183       executables_;
184 
185   const py::function get_jax_enable_x64_;
186   absl::optional<bool> jax_enable_x64_ = absl::nullopt;
187 
188   // A vector of size `num_outputs`, specifying the sharding of each output
189   std::vector<ShardingSpec> sharding_specs_;
190 };
191 
GetCacheEntryIfPresent(const CallSignature & signature)192 PmapCacheEntry* PmapFunction::GetCacheEntryIfPresent(
193     const CallSignature& signature) {
194   auto found_iterator = executables_.find(signature);
195   if (found_iterator != executables_.end()) {  // Cache hit!
196     if (!found_iterator->second->compilation_complete.HasBeenNotified()) {
197       py::gil_scoped_release gil_release;
198       found_iterator->second->compilation_complete.WaitForNotification();
199     }
200     if (found_iterator->second->compilation_error) {
201       throw std::invalid_argument(
202           found_iterator->second->compilation_error.value().error_message());
203     }
204     return found_iterator->second.get();
205   }
206   return nullptr;
207 }
208 
AddCacheEntry(const py::args & args,const py::kwargs & kwargs,const CallSignature & signature,py::object out_and_fastpath_data)209 PmapCacheEntry* PmapFunction::AddCacheEntry(const py::args& args,
210                                             const py::kwargs& kwargs,
211                                             const CallSignature& signature,
212                                             py::object out_and_fastpath_data) {
213   // We need to insert the element.
214   auto result =
215       executables_.emplace(signature, std::make_unique<PmapCacheEntry>());
216   auto it = result.first;
217   PmapCacheEntry* cache_entry = it->second.get();
218   // CallSignatures in the cache own their keyword argument reference.
219   result.first->first.IncRef();
220 
221   py::tuple tuple = py::cast<py::tuple>(out_and_fastpath_data);
222   CHECK_EQ(tuple.size(), 2);
223   if (tuple[1].is_none()) {
224     cache_entry->fall_back_to_python = true;
225     cache_entry->compilation_complete.Notify();
226     return cache_entry;
227   }
228 
229   py::dict pmap_data = py::cast<py::dict>(tuple[1]);
230   if (py::cast<int>(pmap_data["version"]) != 1) {
231     throw std::runtime_error(absl::StrCat(
232         "The versions of jaxlib and Jax are incompatible (pmap cpp version 1 "
233         "expected, but got ",
234         py::cast<int>(pmap_data["version"]),
235         "Upgrade jaxlib and jax. Provided data was:",
236         py::cast<std::string>(py::str(py::repr(pmap_data)))));
237   }
238   // { "version": 1,
239   //   "xla_executable": xla_executable,
240   //   "in_handler": in_handler,
241   //   "out_handler": out_handler,
242   //   "out_pytree_def": out_pytree_def }
243   auto executable =
244       py::cast<std::shared_ptr<xla::PyExecutable>>(pmap_data["xla_executable"]);
245   cache_entry->executable = std::move(executable);
246   cache_entry->handle_args = py::cast<py::function>(pmap_data["in_handler"]);
247   cache_entry->out_handler = py::cast<py::function>(pmap_data["out_handler"]);
248   auto out_tree = py::cast<xla::PyTreeDef>(pmap_data["out_pytree_def"]);
249   cache_entry->out_pytree_def = std::move(out_tree);
250 
251   cache_entry->compilation_complete.Notify();
252   return cache_entry;
253 }
254 
Call(py::args args,py::kwargs kwargs)255 py::object PmapFunction::Call(py::args args, py::kwargs kwargs) {
256   if (always_fallback_to_python_) {
257     return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
258   }
259   // Delayed values are retrieved on the first call to `Call`.
260   if (jax_enable_x64_ == absl::nullopt) {
261     jax_enable_x64_ = py::cast<bool>(get_jax_enable_x64_());
262   }
263 
264   ParsedArgumentsAsBuffers arguments;
265   if (!ParseArguments(args, kwargs, static_argnums_, arguments).ok()) {
266     return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
267   }
268 
269   // Get dynamic argument signatures.
270   for (py::handle arg : arguments.flat_dynamic_args) {
271     auto signature_or_error = ArgSignatureOfValue(arg, jax_enable_x64_.value());
272     if (!signature_or_error.ok()) {
273       return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
274     }
275     arguments.signature.dynamic_args_signatures.push_back(
276         std::move(signature_or_error).ValueOrDie());
277   }
278 
279   // Retrieve/Maybe add the executable to the cache.
280   PmapCacheEntry* cache_entry = GetCacheEntryIfPresent(arguments.signature);
281   if (!cache_entry) {
282     py::object out_and_fastpath_data = cache_miss_(*args, **kwargs);
283     cache_entry = GetCacheEntryIfPresent(arguments.signature);
284     if (!cache_entry) {
285       cache_entry = AddCacheEntry(args, kwargs, arguments.signature,
286                                   out_and_fastpath_data);
287     }
288     CHECK(cache_entry);
289     if (cache_entry->fall_back_to_python) {
290       return py::cast<py::tuple>(out_and_fastpath_data)[0];
291     }
292     // As we have already computed the results, we can return it.
293     // It's even *required* e.g. if there are donated arguments, because
294     // otherwise the buffer which has been donated already will be invalid.
295     return py::cast<py::tuple>(out_and_fastpath_data)[0];
296   }
297 
298   CHECK(cache_entry);
299   if (cache_entry->fall_back_to_python) {
300     return py::cast<py::tuple>(cache_miss_(*args, **kwargs))[0];
301   }
302 
303   // TODO(jblespiau): Use C++ only for this.
304   py::list arg_list;
305   for (auto& v : arguments.flat_dynamic_args) {
306     arg_list.append(v);
307   }
308 
309   py::object handled_args = cache_entry->handle_args(arg_list);
310   py::list list_of_list_of_buffers = py::cast<py::list>(handled_args);
311 
312   arguments.keep_alive_objects.push_back(
313       py::cast<py::object>(list_of_list_of_buffers));
314   // Should be `[num_devices x num_args]`.
315   std::vector<std::vector<xla::PyBuffer*>> arg_buffers;
316   arg_buffers.reserve(list_of_list_of_buffers.size());
317   for (int i = 0; i < list_of_list_of_buffers.size(); ++i) {
318     std::vector<xla::PyBuffer*> buffers;
319     buffers.reserve(py::cast<py::list>(list_of_list_of_buffers[i]).size());
320     for (auto& buf : list_of_list_of_buffers[i]) {
321       buffers.push_back(py::cast<xla::PyBuffer*>(buf));
322     }
323     arg_buffers.push_back(std::move(buffers));
324   }
325 
326   std::vector<std::vector<std::unique_ptr<xla::PyBuffer>>> outputs =
327       ValueOrThrow(cache_entry->executable->ExecuteOnLocalDevices(arg_buffers));
328 
329   // TODO(jblespiau): Move this to C++.
330   py::list outputs_as_python_objects;
331   for (int i = 0; i < outputs.size(); ++i) {
332     outputs_as_python_objects.append(py::cast(std::move(outputs[i])));
333   }
334   py::list flat_sharded_device_arrays =
335       cache_entry->out_handler(outputs_as_python_objects);
336   return cache_entry->out_pytree_def.Unflatten(flat_sharded_device_arrays);
337 }
338 
BuildPmapSubmodule(pybind11::module & m)339 void BuildPmapSubmodule(pybind11::module& m) {
340   py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library");
341 
342   py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding");
343   no_sharding.def(py::init<>())
344       .def("__repr__",
345            [](const NoSharding& chuncked) { return "NoSharding()"; })
346       .def("__eq__", [](const NoSharding& self, py::object obj) {
347         return py::isinstance<NoSharding>(obj);
348       });
349 
350   py::class_<Chunked> chunked(pmap_lib, "Chunked");
351   chunked.def(py::init<std::vector<int>>())
352       .def_readonly("chunks", &Chunked::chunks)
353       .def("__repr__",
354            [](const Chunked& chuncked) {
355              return absl::StrCat("Chunked(",
356                                  absl::StrJoin(chuncked.chunks, ","), ")");
357            })
358       .def("__eq__", [](const Chunked& self, py::object other) {
359         if (!py::isinstance<Chunked>(other)) {
360           return false;
361         }
362         return self == py::cast<const Chunked&>(other);
363       });
364 
365   py::class_<Unstacked> unstacked(pmap_lib, "Unstacked");
366   unstacked.def(py::init<int>())
367       .def_readonly("size", &Unstacked::size)
368       .def("__repr__",
369            [](const Unstacked& x) {
370              return absl::StrCat("Unstacked(", x.size, ")");
371            })
372       .def("__eq__", [](const Unstacked& self, py::object other) {
373         if (!py::isinstance<Unstacked>(other)) {
374           return false;
375         }
376         return self == py::cast<const Unstacked&>(other);
377       });
378 
379   py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis");
380   sharded_axis.def(py::init<int>()).def_readonly("axis", &ShardedAxis::axis);
381   sharded_axis
382       .def("__repr__",
383            [](const ShardedAxis& x) {
384              return absl::StrCat("ShardedAxis(axis=", x.axis, ")");
385            })
386       .def("__eq__", [](const ShardedAxis& self, const ShardedAxis& other) {
387         return self == other;
388       });
389 
390   py::class_<Replicated> replicated(pmap_lib, "Replicated");
391   replicated.def(py::init<int>())
392       .def_readonly("replicas", &Replicated::replicas)
393       .def("__repr__",
394            [](const Replicated& x) {
395              return absl::StrCat("Replicated(replicas=", x.replicas, ")");
396            })
397       .def("__eq__", [](const Replicated& self, const Replicated& other) {
398         return self == other;
399       });
400 
401   py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec");
402   sharding_spec
403       .def(py::init<std::vector<AvalDimSharding>,
404                     std::vector<MeshDimAssignment>>(),
405            py::arg("sharding"), py::arg("mesh_mapping"))
406       .def_property_readonly("sharding", &ShardingSpec::GetSharding)
407       .def_property_readonly("mesh_mapping", &ShardingSpec::GetMeshMapping);
408 
409   py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray");
410   sda.def(py::init<pybind11::handle, ShardingSpec, pybind11::list>())
411       .def_property_readonly("aval", &ShardedDeviceArray::GetAval)
412       .def_property_readonly("sharding_spec",
413                              &ShardedDeviceArray::GetShardingSpec)
414       .def_property_readonly("device_buffers",
415                              &ShardedDeviceArray::GetDeviceBuffers);
416 
417   py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib,
418                                                                "PmapFunction");
419   cfun.def("__call__", &PmapFunction::Call);
420   cfun.def_property_readonly("__signature__", &PmapFunction::PythonSignature);
421 
422   pmap_lib.def(
423       "pmap",
424       [](py::function fun, py::function cache_miss,
425          py::function get_jax_enable_x64,
426          std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> {
427         return std::make_unique<PmapFunction>(
428             std::move(fun), std::move(cache_miss),
429             std::move(get_jax_enable_x64), std::move(static_argnums));
430       });
431 }
432 
433 }  // namespace jax
434