1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 18 19 #include <memory> 20 #include <stdexcept> 21 #include <vector> 22 23 #include "absl/strings/string_view.h" 24 #include "absl/synchronization/notification.h" 25 #include "absl/types/optional.h" 26 #include "pybind11/numpy.h" 27 #include "pybind11/pybind11.h" 28 #include "tensorflow/compiler/xla/python/py_client.h" 29 #include "tensorflow/compiler/xla/python/traceback.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/types.h" 32 33 namespace xla { 34 35 // Python wrapper around PjRtBuffer. We use a wrapper class: 36 // a) to keep the PjRtClient alive via a std::shared_ptr<> 37 // b) to add Python-specific functionality. 38 // 39 // A `PyBuffer` can be used from Python without being wrapped in a Python 40 // `DeviceArray` object, at the condition there is no associated LazyExpr. 41 class PyBuffer { 42 public: 43 // pybind11::object typed subclass for PyBuffer objects. 44 class pyobject : public pybind11::object { 45 public: 46 PYBIND11_OBJECT(pyobject, // NOLINT 47 pybind11::object, PyBuffer::IsPyBuffer); 48 pyobject() = default; buf()49 PyBuffer* buf() const { return PyBuffer::AsPyBufferUnchecked(*this); } 50 }; 51 using object = pyobject; 52 53 static object Make(std::shared_ptr<PyClient> client, 54 std::shared_ptr<PjRtBuffer> buffer, 55 std::shared_ptr<Traceback> traceback); 56 57 // Returns true if `h` is a PyBuffer. 58 static bool IsPyBuffer(pybind11::handle handle); 59 // Converts `handle` to a PyBuffer*. Does not do any checking. 60 static PyBuffer* AsPyBufferUnchecked(pybind11::handle handle); 61 // Converts `handle` to a PyBuffer*. Returns an error status if 62 // !IsPyBuffer(handle) 63 static StatusOr<PyBuffer*> AsPyBuffer(pybind11::handle handle); 64 65 // Gets a Python handle to an existing PyBuffer. Assumes the PyObject was 66 // allocated on the Python heap, which is the case if Make() was used. 67 pybind11::handle AsHandle(); 68 69 ~PyBuffer(); 70 client()71 std::shared_ptr<PyClient> client() const { return client_; } buffer()72 PjRtBuffer* buffer() const { return buffer_.get(); } shared_ptr_buffer()73 std::shared_ptr<PjRtBuffer> shared_ptr_buffer() const { return buffer_; } 74 75 ClientAndPtr<PjRtDevice> device() const; platform_name()76 absl::string_view platform_name() const { 77 return buffer_->client()->platform_name(); 78 } is_deleted()79 bool is_deleted() const { return buffer_->IsDeleted(); } 80 81 StatusOr<pybind11::object> CopyToDevice( 82 const ClientAndPtr<PjRtDevice>& dst_device) const; 83 OnDeviceSizeInBytes()84 StatusOr<size_t> OnDeviceSizeInBytes() { 85 return buffer_->GetOnDeviceSizeInBytes(); 86 } 87 Delete()88 void Delete() { 89 buffer_->Delete(); 90 host_value_ = nullptr; 91 } 92 93 // Makes a copy of this PyBuffer object that shares the underlying PjRtBuffer. 94 // This is useful because we may wish to change JAX metadata (e.g., the sticky 95 // device) without copying the buffer. 96 object Clone() const; 97 98 // Returns xla::InvalidArgument if the buffer has been deleted. 99 Status BlockHostUntilReady(); 100 Status CopyToHostAsync(); 101 shape()102 const Shape& shape() { return buffer_->on_device_shape(); } 103 104 StatusOr<std::uintptr_t> UnsafeBufferPointer() const; 105 106 // Implementation of the CUDA array interface for sharing GPU buffers with 107 // other Python libraries. 108 StatusOr<pybind11::dict> CudaArrayInterface(); 109 traceback()110 const std::shared_ptr<Traceback>& traceback() const { return traceback_; } 111 112 // Returns the size (i.e. number of elements) of the (host) numpy array. 113 StatusOr<int64> size(); 114 115 // Returns the number of dimensions of the (host) numpy array. ndim()116 int ndim() const { return buffer()->on_device_shape().dimensions_size(); } 117 118 pybind11::tuple python_shape() const; 119 pybind11::dtype python_dtype() const; 120 121 // Representing the logical view of the underlying dynamic shapes. 122 StatusOr<const Shape*> xla_dynamic_shape(); 123 set_sticky_device(PjRtDevice * sticky_device)124 Status set_sticky_device(PjRtDevice* sticky_device) { 125 TF_RET_CHECK(sticky_device == nullptr || 126 sticky_device == buffer_->device()); 127 sticky_device_ = sticky_device; 128 return Status::OK(); 129 } sticky_device()130 PjRtDevice* sticky_device() const { return sticky_device_; } 131 set_weak_type(absl::optional<bool> weak_type)132 void set_weak_type(absl::optional<bool> weak_type) { weak_type_ = weak_type; } weak_type()133 absl::optional<bool> weak_type() const { return weak_type_; } 134 135 StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj); 136 SetAval(pybind11::object aval)137 void SetAval(pybind11::object aval) { aval_ = aval; } GetAval()138 pybind11::object GetAval() const { return aval_; } 139 140 static Status RegisterTypes(pybind11::module& m); base_type()141 static PyObject* base_type() { return base_type_; } type()142 static PyObject* type() { return type_; } 143 144 private: 145 // PyBuffer objects must not be allocated directly since they must always live 146 // on the Python heap. Use Make() instead. 147 PyBuffer(std::shared_ptr<PyClient> client, std::shared_ptr<PjRtBuffer> buffer, 148 std::shared_ptr<Traceback> traceback); 149 150 static PyObject* base_type_; 151 static PyObject* type_; 152 153 friend class PyClient; 154 155 struct HostValue { 156 absl::Notification ready; 157 Status status; 158 std::shared_ptr<xla::Literal> value; 159 }; 160 std::shared_ptr<PyClient> client_; 161 std::shared_ptr<PjRtBuffer> buffer_; 162 std::shared_ptr<Traceback> traceback_; 163 std::shared_ptr<HostValue> host_value_; // Protected by the GIL. 164 165 // JAX uses this field to record whether a buffer is committed to a particular 166 // device by the user (https://github.com/google/jax/pull/1916). 167 PjRtDevice* sticky_device_ = nullptr; 168 169 // TODO(phawkins): consider not keeping an explicit aval on C++ buffer 170 // objects. 171 pybind11::object aval_ = pybind11::none(); 172 173 // An optional weak type. If absent, the JAX jit code computes the weak_type 174 // from the aval_.weak_type attribute. This is a backwards compatibility 175 // measure for older Python code that does not set weak_type explicitly. 176 // TODO(phawkins): drop support for older jax Python versions and make 177 // weak_type mandatory. 178 absl::optional<bool> weak_type_ = absl::nullopt; 179 180 absl::optional<Shape> dynamic_shape_ = absl::nullopt; 181 // Doubly-linked list of all PyBuffers known to the client. Protected by the 182 // GIL. Since multiple PyBuffers may share the same PjRtBuffer, there may be 183 // duplicate PjRtBuffers in this list. 184 PyBuffer* next_; 185 PyBuffer* prev_; 186 }; 187 188 } // namespace xla 189 190 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 191