1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
18
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/types/optional.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
29
30 namespace xla {
31
// Forward declarations. They let the holder types below name PyClient before
// its definition, and let PyClient refer to PyBuffer and PyExecutable by
// pointer, reference, or smart pointer without their full definitions.
class PyBuffer;
class PyClient;
class PyExecutable;
35
36 // Custom holder types.
37 //
38 // We must keep the PyClient object alive as long as any of the runtime
39 // objects are alive. Since we don't have a lot of control over Python
40 // destructor ordering, we keep the PyClient object as a std::shared_ptr<>,
41 // and ensure that each Python runtime object holds a reference to the
42 // PyClient. An alternative design would be to keep a single global
43 // singleton PyClient, although this seems less flexible, especially for
44 // writing tests.
45 //
46 // To maintain PyClient references, we define pybind11 holder classes that
47 // are custom smart pointers that also keep a reference to a PyClient.
48 // pybind11 has a `keep_alive` feature that has a similar goal, but it doesn't
49 // seem sufficiently flexible to describe ownership relationships in cases where
50 // the ownership doesn't pertain to a direct argument or return value of a
51 // function. Another alternative to the holder classes would be to create proxy
52 // objects that contain both a reference and a runtime class; holder classes
53 // seem less tedious to define.
54
55 // A pair of a PyClient reference and an unowned pointer to T.
56 template <typename T>
57 struct ClientAndPtr {
58 ClientAndPtr() = default;
59 // pybind11 requires that we define a constructor that takes a raw pointer,
60 // but it should be unreachable.
ClientAndPtrClientAndPtr61 explicit ClientAndPtr(T*) {
62 LOG(FATAL) << "ClientAndPtr should constructed via WrapWithClient.";
63 }
64
65 ClientAndPtr(const ClientAndPtr&) = default;
66 ClientAndPtr(ClientAndPtr&&) = default;
67 ClientAndPtr& operator=(const ClientAndPtr&) = default;
68 ClientAndPtr& operator=(ClientAndPtr&&) = default;
69
70 std::shared_ptr<PyClient> client;
71 T* contents;
72
getClientAndPtr73 T* get() const { return contents; }
74 T* operator->() const { return contents; }
75 T& operator*() const { return *contents; }
76 };
77
78 // By defining a templated helper function, we can use return type deduction
79 // and avoid specifying types at the caller.
80 template <typename T>
WrapWithClient(std::shared_ptr<PyClient> client,T * contents)81 ClientAndPtr<T> WrapWithClient(std::shared_ptr<PyClient> client, T* contents) {
82 ClientAndPtr<T> result;
83 result.client = std::move(client);
84 result.contents = contents;
85 return result;
86 }
87
88 // Python wrapper around PjRtClient.
89 // We use a wrapper class to add Python-specific functionality.
90 class PyClient : public std::enable_shared_from_this<PyClient> {
91 public:
92 explicit PyClient(std::unique_ptr<PjRtClient> pjrt_client);
93 explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
94 ~PyClient();
95
pjrt_client()96 PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
shared_pjrt_client()97 std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
98
platform_name()99 absl::string_view platform_name() const {
100 return pjrt_client_->platform_name();
101 }
platform_version()102 absl::string_view platform_version() const {
103 return pjrt_client_->platform_version();
104 }
runtime_type()105 absl::string_view runtime_type() const {
106 return PjRtRuntimeTypeString(pjrt_client_->runtime_type());
107 }
addressable_device_count()108 int addressable_device_count() const {
109 return pjrt_client_->addressable_device_count();
110 }
device_count()111 int device_count() const { return pjrt_client_->device_count(); }
process_index()112 int process_index() const { return pjrt_client_->process_index(); }
113
114 std::vector<ClientAndPtr<PjRtDevice>> Devices();
115 std::vector<ClientAndPtr<PjRtDevice>> LocalDevices();
116
117 // Returns a vector of live PyBuffer objects. PyBuffer objects may share
118 // PjRtBuffers, so there may be duplicates of the same underlying device
119 // buffer.
120 std::vector<pybind11::object> LiveBuffers();
121 std::vector<pybind11::object> LiveBuffersOnDevice(PjRtDevice* device);
122
123 // Returns a vector of live PyExecutable objects.
124 // note: must return std::shared_ptr instead of raw ptrs
125 // https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#std-shared-ptr
126 std::vector<std::shared_ptr<PyExecutable>> LiveExecutables();
127
128 // TODO(zhangqiaorjc): Remove when we have transparent defragmentation.
129 Status Defragment();
130
131 StatusOr<std::vector<std::vector<ClientAndPtr<PjRtDevice>>>>
132 GetDefaultDeviceAssignment(int num_replicas, int num_partitions);
133
134 // TODO(skye): delete after all callers can handle 2D output
135 StatusOr<std::vector<ClientAndPtr<PjRtDevice>>> GetDefaultDeviceAssignment1D(
136 int num_replicas);
137
CreateChannelHandle()138 StatusOr<ChannelHandle> CreateChannelHandle() {
139 return pjrt_client_->CreateChannelHandle();
140 }
CreateDeviceToHostChannelHandle()141 StatusOr<ChannelHandle> CreateDeviceToHostChannelHandle() {
142 return pjrt_client_->CreateDeviceToHostChannelHandle();
143 }
CreateHostToDeviceChannelHandle()144 StatusOr<ChannelHandle> CreateHostToDeviceChannelHandle() {
145 return pjrt_client_->CreateHostToDeviceChannelHandle();
146 }
147
148 StatusOr<pybind11::object> BufferFromPyval(
149 pybind11::handle argument, PjRtDevice* device, bool force_copy,
150 PjRtClient::HostBufferSemantics host_buffer_semantics);
151
152 StatusOr<std::shared_ptr<PyExecutable>> Compile(
153 const XlaComputation& computation, CompileOptions options);
154
155 StatusOr<pybind11::bytes> SerializeExecutable(
156 const PyExecutable& executable) const;
157 StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
158 const std::string& serialized, CompileOptions options);
159
160 // TODO(skyewm): remove when jax stop providing hlo_module
DeserializeExecutable(const std::string & serialized,std::shared_ptr<HloModule> hlo_module,CompileOptions options)161 StatusOr<std::shared_ptr<PyExecutable>> DeserializeExecutable(
162 const std::string& serialized, std::shared_ptr<HloModule> hlo_module,
163 CompileOptions options) {
164 return DeserializeExecutable(serialized, options);
165 }
166
167 StatusOr<pybind11::bytes> HeapProfile();
168
169 // Adds code to `builder` to call Python host function `callable` with
170 // `operands`, returning a result of `result_shape`. If desired, the operand
171 // layouts can be constrained by `operand_layouts`. Returns a pair of the
172 // output XlaOp, together with an object that must be kept alive as long as
173 // the Python callback may be called. Typically the callback may be kept
174 // alive by attaching it to the executable built from this computation.
175 //
176 // Callable receives as arguments NumPy arrays for arguments with array types,
177 // and None for Token argument. The callable must return a tuple of either
178 // arrays or None values.
179 //
180 // This is a method of PyClient since different platforms may implement this
181 // functionality in different ways.
182 StatusOr<std::pair<XlaOp, pybind11::object>> EmitPythonCallback(
183 pybind11::function callable, XlaBuilder& builder,
184 absl::Span<XlaOp const> operands, absl::Span<Shape const> result_shapes,
185 absl::optional<std::vector<Shape>> operand_layouts, bool has_side_effect);
186
187 private:
188 friend class PyBuffer;
189 friend class PyExecutable;
190
191 std::shared_ptr<PjRtClient> pjrt_client_;
192
193 // Pointers to intrusive doubly-linked lists of buffers and executables, used
194 // to iterate over all known objects when heap profiling. The list structure
195 // is protected by the GIL.
196
197 // buffers_ is a per-device list, indexed by device->id().
198 std::vector<PyBuffer*> buffers_;
199 PyExecutable* executables_ = nullptr;
200 };
201
202 } // namespace xla
203
// Declares xla::ClientAndPtr<T> to pybind11 as a custom holder (smart pointer)
// type. The pybind11 documentation asks for this macro to be used at the
// top-level namespace, which is why it sits after namespace xla is closed.
PYBIND11_DECLARE_HOLDER_TYPE(T, xla::ClientAndPtr<T>);
205
206 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_CLIENT_H_
207