/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/python/py_executable.h"

#include "absl/algorithm/container.h"
#include "tensorflow/core/platform/fingerprint.h"

namespace xla {

namespace py = pybind11;

PyExecutable(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtExecutable> executable,std::shared_ptr<Traceback> traceback,absl::optional<std::string> fingerprint)25 PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
26 std::unique_ptr<PjRtExecutable> executable,
27 std::shared_ptr<Traceback> traceback,
28 absl::optional<std::string> fingerprint)
29 : client_(std::move(client)),
30 executable_(std::move(executable)),
31 traceback_(std::move(traceback)),
32 fingerprint_(std::move(fingerprint)) {
33 CHECK(PyGILState_Check());
34 next_ = client_->executables_;
35 client_->executables_ = this;
36 prev_ = nullptr;
37 if (next_) {
38 next_->prev_ = this;
39 }
40 options_.untuple_result = true;
41 if (fingerprint_) {
42 options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
43 VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
44 << *fingerprint_;
45 }
46 }
47
~PyExecutable()48 PyExecutable::~PyExecutable() {
49 CHECK(PyGILState_Check());
50 if (client_->executables_ == this) {
51 client_->executables_ = next_;
52 }
53 if (prev_) {
54 prev_->next_ = next_;
55 }
56 if (next_) {
57 next_->prev_ = prev_;
58 }
59 }
60
AddressableDevices() const61 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
62 std::vector<ClientAndPtr<PjRtDevice>> devices;
63 devices.reserve(executable_->addressable_devices().size());
64 for (PjRtDevice* device : executable_->addressable_devices()) {
65 devices.push_back(WrapWithClient(client_, device));
66 }
67 return devices;
68 }
69
70 // Used by JAX JIT which has C++ PjRtBuffers as inputs (Numpy to PjRtBuffer is
71 // faster and simpler than Numpy to PyBuffer to PjRtBuffer) and requires
72 // PyBuffer as outputs as it will return to Python.
PjRtExecute(const std::vector<PjRtBuffer * > & args)73 StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
74 const std::vector<PjRtBuffer*>& args) {
75 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
76 {
77 py::gil_scoped_release gil_release;
78 TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute({args}, options_));
79 }
80 auto traceback = Traceback::Get();
81 std::vector<std::unique_ptr<PyBuffer>> outputs;
82 outputs.reserve(output_buffers[0].size());
83 for (auto& buffer : output_buffers[0]) {
84 outputs.push_back(
85 std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
86 }
87 return outputs;
88 }
89
Execute(absl::Span<PyBuffer * const> args)90 StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
91 absl::Span<PyBuffer* const> args) {
92 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
93 {
94 py::gil_scoped_release gil_release;
95 std::vector<PjRtBuffer*> arg_buffers(args.size());
96 absl::c_transform(args, arg_buffers.begin(),
97 [](PyBuffer* buf) { return buf->buffer(); });
98 TF_ASSIGN_OR_RETURN(output_buffers,
99 executable_->Execute({arg_buffers}, options_));
100 }
101 auto traceback = Traceback::Get();
102 std::vector<std::unique_ptr<PyBuffer>> outputs;
103 outputs.reserve(output_buffers[0].size());
104 for (auto& buffer : output_buffers[0]) {
105 outputs.push_back(
106 std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
107 }
108 return outputs;
109 }
110
111 StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer * >> args)112 PyExecutable::ExecuteOnLocalDevices(
113 absl::Span<const std::vector<PyBuffer*>> args) {
114 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
115 {
116 py::gil_scoped_release gil_release;
117 std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
118 for (int computation = 0; computation < args.size(); ++computation) {
119 arg_buffers[computation].resize(args[computation].size());
120 absl::c_transform(args[computation], arg_buffers[computation].begin(),
121 [](PyBuffer* buf) { return buf->buffer(); });
122 }
123 TF_ASSIGN_OR_RETURN(output_buffers,
124 executable_->Execute(arg_buffers, options_));
125 }
126 auto traceback = Traceback::Get();
127 std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
128 outputs.resize(output_buffers.size());
129 for (int computation = 0; computation < output_buffers.size();
130 ++computation) {
131 for (auto& buffer : output_buffers[computation]) {
132 outputs[computation].push_back(
133 std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
134 }
135 }
136 return outputs;
137 }
138
139 StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer * >> args)140 PyExecutable::ExecuteShardedOnLocalDevices(
141 absl::Span<const std::vector<PyBuffer*>> args) {
142 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
143 int num_computations = executable_->addressable_devices().size();
144 {
145 py::gil_scoped_release gil_release;
146 for (const auto& arg : args) {
147 if (arg.size() != num_computations) {
148 return xla::InvalidArgument(
149 "Expected args to execute_sharded_on_local_devices to have %d "
150 "shards, got: [%s]",
151 num_computations,
152 absl::StrJoin(
153 args, ", ",
154 [](std::string* out, const std::vector<PyBuffer*>& arg) {
155 out->append(std::to_string(arg.size()));
156 }));
157 }
158 }
159 std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
160 for (int computation = 0; computation < num_computations; ++computation) {
161 arg_buffers[computation].resize(args.size());
162 absl::c_transform(args, arg_buffers[computation].begin(),
163 [&](const std::vector<PyBuffer*>& arg) {
164 return arg[computation]->buffer();
165 });
166 }
167 TF_ASSIGN_OR_RETURN(output_buffers,
168 executable_->Execute(arg_buffers, options_));
169 }
170 auto traceback = Traceback::Get();
171 int num_output_buffers = output_buffers[0].size();
172 std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
173 outputs.resize(num_output_buffers);
174 for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
175 outputs[buffer_id].reserve(num_computations);
176 for (int computation = 0; computation < num_computations; ++computation) {
177 outputs[buffer_id].push_back(std::make_unique<PyBuffer>(
178 client_, std::move(output_buffers[computation][buffer_id]),
179 traceback));
180 }
181 }
182 return outputs;
183 }
184
// Returns the HLO modules backing this executable, forwarded from the
// underlying PjRtExecutable's GetHloModules().
StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
    const {
  return executable_->GetHloModules();
}

}  // namespace xla