• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
18 
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/synchronization/notification.h"
#include "absl/types/optional.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#include "tensorflow/compiler/xla/python/py_client.h"
#include "tensorflow/compiler/xla/python/traceback.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
32 
33 namespace xla {
34 
// As we are deploying both a C++ and a Python implementation for DeviceArray,
// we use an empty base class to ensure `isinstance(x, DeviceArray)` works for
// either of them. `DeviceArrayBase` plays the role of `DeviceArray`:
//
//   DeviceArrayBase (== DeviceArray)
//     |-- PyBuffer      (implemented in C++)
//     `-- _DeviceArray  (implemented in Python)
//
// No line of this diagram may end in a backslash: inside a `//` comment a
// trailing `\` splices the following source line into the comment (and
// triggers GCC's -Wcomment "multi-line comment" warning under -Wall).
class DeviceArrayBase {
 public:
  DeviceArrayBase() = default;
};
46 
47 // Python wrapper around PjRtBuffer. We use a wrapper class:
48 // a) to keep the PjRtClient alive via a std::shared_ptr<>
49 // b) to add Python-specific functionality.
50 //
51 // A `PyBuffer` can be used from Python without being wrapped in a Python
52 // `DeviceArray` object, at the condition there is no associated LazyExpr.
53 class PyBuffer : public DeviceArrayBase {
54  public:
55   PyBuffer(std::shared_ptr<PyClient> client, std::unique_ptr<PjRtBuffer> buffer,
56            std::shared_ptr<Traceback> traceback);
57   ~PyBuffer();
58 
client()59   std::shared_ptr<PyClient> client() const { return client_; }
buffer()60   PjRtBuffer* buffer() const { return buffer_.get(); }
61 
62   ClientAndPtr<PjRtDevice> device() const;
platform_name()63   absl::string_view platform_name() const {
64     return buffer_->client()->platform_name();
65   }
is_deleted()66   bool is_deleted() const { return buffer_->IsDeleted(); }
67 
68   StatusOr<std::unique_ptr<PyBuffer>> CopyToDevice(
69       const ClientAndPtr<PjRtDevice>& dst_device) const;
70 
OnDeviceSizeInBytes()71   int64 OnDeviceSizeInBytes() { return buffer_->OnDeviceSizeInBytes(); }
72 
Delete()73   void Delete() {
74     buffer_->Delete();
75     host_value_ = nullptr;
76   }
77 
78   // Returns xla::InvalidArgument if the buffer has been deleted.
79   Status BlockHostUntilReady();
80   Status CopyToHostAsync();
81 
shape()82   const Shape& shape() { return buffer_->on_device_shape(); }
83 
84   StatusOr<std::uintptr_t> UnsafeBufferPointer() const;
85 
86   // Implementation of the CUDA array interface for sharing GPU buffers with
87   // other Python libraries.
88   StatusOr<pybind11::dict> CudaArrayInterface() const;
89 
90   // PEP 3118 Python buffer protocol implementation.
91   static PyBufferProcs* BufferProtocol();
92 
traceback()93   Traceback* traceback() { return traceback_.get(); }
94 
95   // Returns the size (i.e. number of elements) of the (host) numpy array.
size()96   int64 size() { return ShapeUtil::ElementsIn(buffer()->on_device_shape()); }
97 
98   // Returns the number of dimensions of the (host) numpy array.
ndim()99   int ndim() const { return buffer()->on_device_shape().dimensions_size(); }
100 
101   pybind11::tuple python_shape() const;
102   pybind11::dtype python_dtype() const;
103 
104   void SetStickyDevice(pybind11::object sticky_device);
GetStickyDevice()105   pybind11::object GetStickyDevice() const { return sticky_device_.value(); }
106 
107   StatusOr<pybind11::object> AsNumPyArray(pybind11::handle this_obj);
108 
109   void SetAval(pybind11::object aval);
GetAval()110   pybind11::object GetAval() const { return aval_.value(); }
111 
112  private:
113   friend class PyClient;
114 
115   struct HostValue {
116     absl::Notification ready;
117     Status status;
118     std::shared_ptr<xla::Literal> value;
119   };
120   std::shared_ptr<PyClient> client_;
121   std::unique_ptr<PjRtBuffer> buffer_;
122   std::shared_ptr<Traceback> traceback_;
123   std::shared_ptr<HostValue> host_value_;  // Protected by the GIL.
124 
125   absl::optional<pybind11::object> sticky_device_ = absl::nullopt;
126   // TODO(jblespiau): It's currently there for convenience but maybe we can do
127   // without it (adding `weak_type` instead).
128   absl::optional<pybind11::object> aval_ = absl::nullopt;
129   // Doubly-linked list of all buffers known to the client. Protected by the
130   // GIL.
131   PyBuffer* next_;
132   PyBuffer* prev_;
133 };
134 
135 }  // namespace xla
136 
137 #endif  // TENSORFLOW_COMPILER_XLA_PYTHON_PY_BUFFER_H_
138