/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/python/py_executable.h"
17
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/core/platform/fingerprint.h"
20
21 namespace xla {
22
23 namespace py = pybind11;
24
PyExecutable(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtExecutable> executable,std::shared_ptr<Traceback> traceback,absl::optional<std::string> fingerprint)25 PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
26 std::unique_ptr<PjRtExecutable> executable,
27 std::shared_ptr<Traceback> traceback,
28 absl::optional<std::string> fingerprint)
29 : client_(std::move(client)),
30 executable_(std::move(executable)),
31 traceback_(std::move(traceback)),
32 fingerprint_(std::move(fingerprint)) {
33 CHECK(PyGILState_Check());
34 next_ = client_->executables_;
35 client_->executables_ = this;
36 prev_ = nullptr;
37 if (next_) {
38 next_->prev_ = this;
39 }
40 options_.untuple_result = true;
41 if (fingerprint_) {
42 options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
43 VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
44 << *fingerprint_;
45 }
46 }
47
~PyExecutable()48 PyExecutable::~PyExecutable() {
49 CHECK(PyGILState_Check());
50 if (client_->executables_ == this) {
51 client_->executables_ = next_;
52 }
53 if (prev_) {
54 prev_->next_ = next_;
55 }
56 if (next_) {
57 next_->prev_ = prev_;
58 }
59 }
60
AddressableDevices() const61 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
62 std::vector<ClientAndPtr<PjRtDevice>> devices;
63 devices.reserve(executable_->addressable_devices().size());
64 for (PjRtDevice* device : executable_->addressable_devices()) {
65 devices.push_back(WrapWithClient(client_, device));
66 }
67 return devices;
68 }
69
Execute(absl::Span<PyBuffer::object const> args)70 StatusOr<std::vector<PyBuffer::object>> PyExecutable::Execute(
71 absl::Span<PyBuffer::object const> args) {
72 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
73 {
74 py::gil_scoped_release gil_release;
75 std::vector<PjRtBuffer*> arg_buffers(args.size());
76 absl::c_transform(
77 args, arg_buffers.begin(),
78 [](const PyBuffer::object& buf) { return buf.buf()->buffer(); });
79 TF_ASSIGN_OR_RETURN(output_buffers,
80 executable_->Execute({arg_buffers}, options_));
81 }
82 auto traceback = Traceback::Get();
83 std::vector<PyBuffer::object> outputs;
84 outputs.reserve(output_buffers[0].size());
85 for (auto& buffer : output_buffers[0]) {
86 outputs.push_back(PyBuffer::Make(client_, std::move(buffer), traceback));
87 }
88 return outputs;
89 }
90
91 StatusOr<std::vector<std::vector<PyBuffer::object>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer::object>> args)92 PyExecutable::ExecuteShardedOnLocalDevices(
93 absl::Span<const std::vector<PyBuffer::object>> args) {
94 std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
95 int num_computations = executable_->addressable_devices().size();
96 {
97 py::gil_scoped_release gil_release;
98 for (const auto& arg : args) {
99 if (arg.size() != num_computations) {
100 return xla::InvalidArgument(
101 "Expected args to execute_sharded_on_local_devices to have %d "
102 "shards, got: [%s]",
103 num_computations,
104 absl::StrJoin(
105 args, ", ",
106 [](std::string* out, const std::vector<PyBuffer::object>& arg) {
107 out->append(std::to_string(arg.size()));
108 }));
109 }
110 }
111 std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
112 const int num_args = args.size();
113 for (int computation = 0; computation < num_computations; ++computation) {
114 arg_buffers[computation].resize(num_args);
115 absl::c_transform(args, arg_buffers[computation].begin(),
116 [&](const std::vector<PyBuffer::object>& arg) {
117 return arg[computation].buf()->buffer();
118 });
119 }
120 TF_ASSIGN_OR_RETURN(output_buffers,
121 executable_->Execute(arg_buffers, options_));
122 }
123 auto traceback = Traceback::Get();
124 int num_output_buffers = output_buffers[0].size();
125 std::vector<std::vector<PyBuffer::object>> outputs;
126 outputs.resize(num_output_buffers);
127 for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
128 outputs[buffer_id].reserve(num_computations);
129 for (int computation = 0; computation < num_computations; ++computation) {
130 outputs[buffer_id].push_back(PyBuffer::Make(
131 client_, std::move(output_buffers[computation][buffer_id]),
132 traceback));
133 }
134 }
135 return outputs;
136 }
137
// Returns the optimized HLO modules for this executable; forwards
// directly to the underlying PjRtExecutable.
StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
    const {
  return executable_->GetHloModules();
}
142
// Retains a reference to `obj` for the lifetime of this executable,
// preventing the Python object from being garbage-collected while the
// executable may still depend on it.
void PyExecutable::KeepAlive(py::object obj) {
  keepalives_.push_back(std::move(obj));
}
146
147 } // namespace xla
148