• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
18 
19 #include <memory>
20 #include <stdexcept>
21 #include <vector>
22 
23 #include "absl/strings/string_view.h"
24 #include "absl/synchronization/notification.h"
25 #include "absl/types/optional.h"
26 #include "pybind11/numpy.h"
27 #include "pybind11/pybind11.h"
28 #include "tensorflow/compiler/xla/python/py_client.h"
29 #include "tensorflow/compiler/xla/python/traceback.h"
30 #include "tensorflow/compiler/xla/statusor.h"
31 #include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
35 // Python wrapper around PjRtBuffer. We use a wrapper class:
36 // a) to keep the PjRtClient alive via a std::shared_ptr<>
37 // b) to add Python-specific functionality.
38 //
39 // A `PyBuffer` can be used from Python without being wrapped in a Python
40 // `DeviceArray` object, provided that there is no associated LazyExpr.
41 class PyBuffer {
42  public:
43   // pybind11::object typed subclass for PyBuffer objects.
44   class pyobject : public pybind11::object {
45    public:
46     PYBIND11_OBJECT(pyobject,  // NOLINT
47                     pybind11::object, PyBuffer::IsPyBuffer);
48     pyobject() = default;
buf()49     PyBuffer* buf() const { return PyBuffer::AsPyBufferUnchecked(*this); }
50   };
51   using object = pyobject;
52 
53   static object Make(std::shared_ptr<PyClient> client,
54                      std::shared_ptr<PjRtBuffer> buffer,
55                      std::shared_ptr<Traceback> traceback);
56 
57   // Returns true if `h` is a PyBuffer.
58   static bool IsPyBuffer(pybind11::handle handle);
59   // Converts `handle` to a PyBuffer*. Does not do any checking.
60   static PyBuffer* AsPyBufferUnchecked(pybind11::handle handle);
61   // Converts `handle` to a PyBuffer*. Returns an error status if
62   // !IsPyBuffer(handle)
63   static StatusOr<PyBuffer*> AsPyBuffer(pybind11::handle handle);
64 
65   // Gets a Python handle to an existing PyBuffer. Assumes the PyObject was
66   // allocated on the Python heap, which is the case if Make() was used.
67   pybind11::handle AsHandle();
68 
69   ~PyBuffer();
70 
client()71   std::shared_ptr<PyClient> client() const { return client_; }
buffer()72   PjRtBuffer* buffer() const { return buffer_.get(); }
shared_ptr_buffer()73   std::shared_ptr<PjRtBuffer> shared_ptr_buffer() const { return buffer_; }
74 
75   ClientAndPtr<PjRtDevice> device() const;
platform_name()76   absl::string_view platform_name() const {
77     return buffer_->client()->platform_name();
78   }
is_deleted()79   bool is_deleted() const { return buffer_->IsDeleted(); }
80 
81   StatusOr<pybind11::object> CopyToDevice(
82       const ClientAndPtr<PjRtDevice>& dst_device) const;
83 
OnDeviceSizeInBytes()84   StatusOr<size_t> OnDeviceSizeInBytes() {
85     return buffer_->GetOnDeviceSizeInBytes();
86   }
87 
Delete()88   void Delete() {
89     buffer_->Delete();
90     host_value_ = nullptr;
91   }
92 
93   // Makes a copy of this PyBuffer object that shares the underlying PjRtBuffer.
94   // This is useful because we may wish to change JAX metadata (e.g., the sticky
95   // device) without copying the buffer.
96   object Clone() const;
97 
98   // Returns xla::InvalidArgument if the buffer has been deleted.
99   Status BlockHostUntilReady();
100   Status CopyToHostAsync();
101 
shape()102   const Shape& shape() { return buffer_->on_device_shape(); }
103 
104   StatusOr<std::uintptr_t> UnsafeBufferPointer() const;
105 
106   // Implementation of the CUDA array interface for sharing GPU buffers with
107   // other Python libraries.
108   StatusOr<pybind11::dict> CudaArrayInterface();
109 
traceback()110   const std::shared_ptr<Traceback>& traceback() const { return traceback_; }
111 
112   // Returns the size (i.e. number of elements) of the (host) numpy array.
113   StatusOr<int64> size();
114 
115   // Returns the number of dimensions of the (host) numpy array.
ndim()116   int ndim() const { return buffer()->on_device_shape().dimensions_size(); }
117 
118   pybind11::tuple python_shape() const;
119   pybind11::dtype python_dtype() const;
120 
121   // Representing the logical view of the underlying dynamic shapes.
122   StatusOr<const Shape*> xla_dynamic_shape();
123 
set_sticky_device(PjRtDevice * sticky_device)124   Status set_sticky_device(PjRtDevice* sticky_device) {
125     TF_RET_CHECK(sticky_device == nullptr ||
126                  sticky_device == buffer_->device());
127     sticky_device_ = sticky_device;
128     return Status::OK();
129   }
sticky_device()130   PjRtDevice* sticky_device() const { return sticky_device_; }
131 
set_weak_type(absl::optional<bool> weak_type)132   void set_weak_type(absl::optional<bool> weak_type) { weak_type_ = weak_type; }
weak_type()133   absl::optional<bool> weak_type() const { return weak_type_; }
134 
135   StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj);
136 
SetAval(pybind11::object aval)137   void SetAval(pybind11::object aval) { aval_ = aval; }
GetAval()138   pybind11::object GetAval() const { return aval_; }
139 
140   static Status RegisterTypes(pybind11::module& m);
base_type()141   static PyObject* base_type() { return base_type_; }
type()142   static PyObject* type() { return type_; }
143 
144  private:
145   // PyBuffer objects must not be allocated directly since they must always live
146   // on the Python heap. Use Make() instead.
147   PyBuffer(std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer,
148            std::shared_ptr<Traceback> traceback);
149 
150   static PyObject* base_type_;
151   static PyObject* type_;
152 
153   friend class PyClient;
154 
155   struct HostValue {
156     absl::Notification ready;
157     Status status;
158     std::shared_ptr<xla::Literal> value;
159   };
160   std::shared_ptr<PyClient> client_;
161   std::shared_ptr<PjRtBuffer> buffer_;
162   std::shared_ptr<Traceback> traceback_;
163   std::shared_ptr<HostValue> host_value_;  // Protected by the GIL.
164 
165   // JAX uses this field to record whether a buffer is committed to a particular
166   // device by the user (https://github.com/google/jax/pull/1916).
167   PjRtDevice* sticky_device_ = nullptr;
168 
169   // TODO(phawkins): consider not keeping an explicit aval on C++ buffer
170   // objects.
171   pybind11::object aval_ = pybind11::none();
172 
173   // An optional weak type. If absent, the JAX jit code computes the weak_type
174   // from the aval_.weak_type attribute. This is a backwards compatibility
175   // measure for older Python code that does not set weak_type explicitly.
176   // TODO(phawkins): drop support for older jax Python versions and make
177   // weak_type mandatory.
178   absl::optional<bool> weak_type_ = absl::nullopt;
179 
180   absl::optional<Shape> dynamic_shape_ = absl::nullopt;
181   // Doubly-linked list of all PyBuffers known to the client. Protected by the
182   // GIL. Since multiple PyBuffers may share the same PjRtBuffer, there may be
183   // duplicate PjRtBuffers in this list.
184   PyBuffer* next_;
185   PyBuffer* prev_;
186 };
187 
188 }  // namespace xla
189 
190 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
191