1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 18 19 #include <memory> 20 #include <stdexcept> 21 #include <vector> 22 23 #include "absl/strings/string_view.h" 24 #include "absl/synchronization/notification.h" 25 #include "absl/types/optional.h" 26 #include "pybind11/numpy.h" 27 #include "pybind11/pybind11.h" 28 #include "tensorflow/compiler/xla/python/py_client.h" 29 #include "tensorflow/compiler/xla/python/traceback.h" 30 #include "tensorflow/compiler/xla/statusor.h" 31 #include "tensorflow/compiler/xla/types.h" 32 33 namespace xla { 34 35 // As we are deploying both a C++ and a Python implementation for DeviceArray, 36 // we use an empty base-class to ensure `isinstance(x, DeviceArray)` works. 37 // DeviceArrayBase == DeviceArray 38 // / \ 39 // / \ 40 // PyBuffer _DeviceArray (Python) 41 // in C++ 42 class DeviceArrayBase { 43 public: 44 DeviceArrayBase() = default; 45 }; 46 47 // Python wrapper around PjRtBuffer. We use a wrapper class: 48 // a) to keep the PjRtClient alive via a std::shared_ptr<> 49 // b) to add Python-specific functionality. 50 // 51 // A `PyBuffer` can be used from Python without being wrapped in a Python 52 // `DeviceArray` object, at the condition there is no associated LazyExpr. 53 class PyBuffer : public DeviceArrayBase { 54 public: 55 PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer, 56 std::shared_ptr<Traceback> traceback); 57 ~PyBuffer(); 58 client()59 std::shared_ptr<PyClient> client() const { return client_; } buffer()60 PjRtBuffer* buffer() const { return buffer_.get(); } 61 62 ClientAndPtr<PjRtDevice> device() const; platform_name()63 absl::string_view platform_name() const { 64 return buffer_->client()->platform_name(); 65 } is_deleted()66 bool is_deleted() const { return buffer_->IsDeleted(); } 67 68 StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice( 69 const ClientAndPtr<PjRtDevice>& dst_device) const; 70 OnDeviceSizeInBytes()71 int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); } 72 Delete()73 void Delete() { 74 buffer_->Delete(); 75 host_value_ = nullptr; 76 } 77 78 // Returns xla::InvalidArgument if the buffer has been deleted. 79 Status BlockHostUntilReady(); 80 Status CopyToHostAsync(); 81 shape()82 const Shape& shape() { return buffer_->on_device_shape(); } 83 84 StatusOr<std::uintptr_t> UnsafeBufferPointer() const; 85 86 // Implementation of the CUDA array interface for sharing GPU buffers with 87 // other Python libraries. 88 StatusOr<pybind11::dict> CudaArrayInterface() const; 89 90 // PEP 3118 Python buffer protocol implementation. 91 static PyBufferProcs* BufferProtocol(); 92 traceback()93 Traceback* traceback() { return traceback_.get(); } 94 95 // Returns the size (i.e. number of elements) of the (host) numpy array. size()96 int64 size() { return ShapeUtil::ElementsIn(buffer()->on_device_shape()); } 97 98 // Returns the number of dimensions of the (host) numpy array. ndim()99 int ndim() const { return buffer()->on_device_shape().dimensions_size(); } 100 101 pybind11::tuple python_shape() const; 102 pybind11::dtype python_dtype() const; 103 104 void SetStickyDevice(pybind11::object sticky_device); GetStickyDevice()105 pybind11::object GetStickyDevice() const { return sticky_device_.value(); } 106 107 StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj); 108 109 void SetAval(pybind11::object aval); GetAval()110 pybind11::object GetAval() const { return aval_.value(); } 111 112 private: 113 friend class PyClient; 114 115 struct HostValue { 116 absl::Notification ready; 117 Status status; 118 std::shared_ptr<xla::Literal> value; 119 }; 120 std::shared_ptr<PyClient> client_; 121 std::unique_ptr<PjRtBuffer> buffer_; 122 std::shared_ptr<Traceback> traceback_; 123 std::shared_ptr<HostValue> host_value_; // Protected by the GIL. 124 125 absl::optional<pybind11::object> sticky_device_ = absl::nullopt; 126 // TODO(jblespiau): It's currently there for convenience but maybe we can do 127 // without it (adding `weak_type` instead). 128 absl::optional<pybind11::object> aval_ = absl::nullopt; 129 // Doubly-linked list of all buffers known to the client. Protected by the 130 // GIL. 131 PyBuffer* next_; 132 PyBuffer* prev_; 133 }; 134 135 } // namespace xla 136 137 #endif // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_ 138