1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "absl/types/optional.h"
24 #include "pybind11/pybind11.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/types.h"
29 
30 namespace xla {
31 
32 class PyBuffer;
33 class PyClient;
34 class PyExecutable;
35 
36 // Custom holder types.
37 //
38 // We must keep the PyClient object alive as long as any of the runtime
39 // objects are alive. Since we don't have a lot of control over Python
40 // destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
41 // and ensure that each Python runtime object holds a reference to the
42 // PyClient. An alternative design would be to keep a single global
43 // singleton PyClient, although this seems less flexible, especially for
44 // writing tests.
45 //
46 // To maintain PyClient references, we define pybind11 holder classes that
47 // are custom smart pointers that also keep a reference to a PyClient.
48 // pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
49 // seem sufficiently flexible to describe ownership relationships in cases where
50 // the ownership doesn't pertain to a direct argument or return value of a
51 // function. Another alternative to the holder classes would be to create proxy
52 // objects that contain both a reference and a runtime class; holder classes
53 // seem less tedious to define.
54 
55 // A pair of a PyClient reference and an unowned pointer to T.
56 template <typename T>
57 struct ClientAndPtr {
58   ClientAndPtr() = default;
59   // pybind11 requires that we define a constructor that takes a raw pointer,
60   // but it should be unreachable.
ClientAndPtrClientAndPtr61   explicit ClientAndPtr(T*) {
62     LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
63   }
64 
65   ClientAndPtr(const ClientAndPtr&) = default;
66   ClientAndPtr(ClientAndPtr&&) = default;
67   ClientAndPtr& operator=(const ClientAndPtr&) = default;
68   ClientAndPtr& operator=(ClientAndPtr&&) = default;
69 
70   std::shared_ptr<PyClient> client;
71   T* contents;
72 
getClientAndPtr73   T* get() const { return contents; }
74   T* operator->() const { return contents; }
75   T& operator*() const { return *contents; }
76 };
77 
78 // By defining a templated helper function, we can use return type deduction
79 // and avoid specifying types at the caller.
80 template <typename T>
WrapWithClient(std::shared_ptr<PyClient> client,T * contents)81 ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
82   ClientAndPtr<T> result;
83   result.client = std::move(client);
84   result.contents = contents;
85   return result;
86 }
87 
88 // Python wrapper around PjRtClient.
89 // We use a wrapper class to add Python-specific functionality.
90 class PyClient : public std::enable_shared_from_this<PyClient> {
91  public:
92   explicit PyClient(std::unique_ptr<PjRtClient> pjrt_client);
93   explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
94   ~PyClient();
95 
pjrt_client()96   PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
shared_pjrt_client()97   std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
98 
platform_name()99   absl::string_view platform_name() const {
100     return pjrt_client_->platform_name();
101   }
platform_version()102   absl::string_view platform_version() const {
103     return pjrt_client_->platform_version();
104   }
runtime_type()105   absl::string_view runtime_type() const {
106     return PjRtRuntimeTypeString(pjrt_client_->runtime_type());
107   }
addressable_device_count()108   int addressable_device_count() const {
109     return pjrt_client_->addressable_device_count();
110   }
device_count()111   int device_count() const { return pjrt_client_->device_count(); }
process_index()112   int process_index() const { return pjrt_client_->process_index(); }
113 
114   std::vector<ClientAndPtr<PjRtDevice>> Devices();
115   std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
116 
117   // Returns a vector of live PyBuffer objects. PyBuffer objects may share
118   // PjRtBuffers, so there may be duplicates of the same underlying device
119   // buffer.
120   std::vector<pybind11::object> LiveBuffers();
121   std::vector<pybind11::object> LiveBuffersOnDevice(PjRtDevice* device);
122 
123   // Returns a vector of live PyExecutable objects.
124   // note: must return std::shared_ptr instead of raw ptrs
125   // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#std-shared-ptr
126   std::vector<std::shared_ptr<PyExecutable>> LiveExecutables();
127 
128   // TODO(zhangqiaorjc): Remove when we have transparent defragmentation.
129   Status Defragment();
130 
131   StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
132   GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
133 
134   // TODO(skye): delete after all callers can handle 2D output
135   StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
136       int num_replicas);
137 
CreateChannelHandle()138   StatusOr<ChannelHandle> CreateChannelHandle() {
139     return pjrt_client_->CreateChannelHandle();
140   }
CreateDeviceToHostChannelHandle()141   StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
142     return pjrt_client_->CreateDeviceToHostChannelHandle();
143   }
CreateHostToDeviceChannelHandle()144   StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
145     return pjrt_client_->CreateHostToDeviceChannelHandle();
146   }
147 
148   StatusOr<pybind11::object> BufferFromPyval(
149       pybind11::handle argument, PjRtDevice* device, bool force_copy,
150       PjRtClient::HostBufferSemantics host_buffer_semantics);
151 
152   StatusOr<std::shared_ptr<PyExecutable>> Compile(
153       const XlaComputation& computation, CompileOptions options);
154 
155   StatusOr<pybind11::bytes> SerializeExecutable(
156       const PyExecutable& executable) const;
157   StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
158       const std::string& serialized, CompileOptions options);
159 
160   // TODO(skyewm): remove when jax stop providing hlo_module
DeserializeExecutable(const std::string & serialized,std::shared_ptr<HloModule> hlo_module,CompileOptions options)161   StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
162       const std::string& serialized, std::shared_ptr<HloModule> hlo_module,
163       CompileOptions options) {
164     return DeserializeExecutable(serialized, options);
165   }
166 
167   StatusOr<pybind11::bytes> HeapProfile();
168 
169   // Adds code to `builder` to call Python host function `callable` with
170   // `operands`, returning a result of `result_shape`. If desired, the operand
171   // layouts can be constrained by `operand_layouts`. Returns a pair of the
172   // output XlaOp, together with an object that must be kept alive as long as
173   // the Python callback may be called. Typically the callback may be kept
174   // alive by attaching it to the executable built from this computation.
175   //
176   // Callable receives as arguments NumPy arrays for arguments with array types,
177   // and None for Token argument. The callable must return a tuple of either
178   // arrays or None values.
179   //
180   // This is a method of PyClient since different platforms may implement this
181   // functionality in different ways.
182   StatusOr<std::pair<XlaOp, pybind11::object>> EmitPythonCallback(
183       pybind11::function callable, XlaBuilder& builder,
184       absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
185       absl::optional<std::vector<Shape>> operand_layouts, bool has_side_effect);
186 
187  private:
188   friend class PyBuffer;
189   friend class PyExecutable;
190 
191   std::shared_ptr<PjRtClient> pjrt_client_;
192 
193   // Pointers to intrusive doubly-linked lists of buffers and executables, used
194   // to iterate over all known objects when heap profiling. The list structure
195   // is protected by the GIL.
196 
197   // buffers_ is a per-device list, indexed by device->id().
198   std::vector<PyBuffer*> buffers_;
199   PyExecutable* executables_ = nullptr;
200 };
201 
202 }  // namespace xla
203 
// Declares xla::ClientAndPtr<T> to pybind11 as a custom holder (smart-pointer)
// type for T, so bound classes can name it as their holder. The invocation
// sits after the closing brace of namespace xla, i.e. at global scope.
204 PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
205 
206 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
207