• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_executable.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/core/platform/fingerprint.h"
20 
21 namespace xla {
22 
23 namespace py = pybind11;
24 
PyExecutable(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtExecutable> executable,std::shared_ptr<Traceback> traceback,absl::optional<std::string> fingerprint)25 PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
26                            std::unique_ptr<PjRtExecutable> executable,
27                            std::shared_ptr<Traceback> traceback,
28                            absl::optional<std::string> fingerprint)
29     : client_(std::move(client)),
30       executable_(std::move(executable)),
31       traceback_(std::move(traceback)),
32       fingerprint_(std::move(fingerprint)) {
33   CHECK(PyGILState_Check());
34   next_ = client_->executables_;
35   client_->executables_ = this;
36   prev_ = nullptr;
37   if (next_) {
38     next_->prev_ = this;
39   }
40   options_.untuple_result = true;
41   if (fingerprint_) {
42     options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
43     VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
44             << *fingerprint_;
45   }
46 }
47 
~PyExecutable()48 PyExecutable::~PyExecutable() {
49   CHECK(PyGILState_Check());
50   if (client_->executables_ == this) {
51     client_->executables_ = next_;
52   }
53   if (prev_) {
54     prev_->next_ = next_;
55   }
56   if (next_) {
57     next_->prev_ = prev_;
58   }
59 }
60 
AddressableDevices() const61 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
62   std::vector<ClientAndPtr<PjRtDevice>> devices;
63   devices.reserve(executable_->addressable_devices().size());
64   for (PjRtDevice* device : executable_->addressable_devices()) {
65     devices.push_back(WrapWithClient(client_, device));
66   }
67   return devices;
68 }
69 
70 // Used by JAX JIT which has C++ PjRtBuffers as inputs (Numpy to PjRtBuffer is
71 // faster and simpler than Numpy to PyBuffer to PjRtBuffer) and requires
72 // PyBuffer as outputs as it will return to Python.
PjRtExecute(const std::vector<PjRtBuffer * > & args)73 StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::PjRtExecute(
74     const std::vector<PjRtBuffer*>& args) {
75   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
76   {
77     py::gil_scoped_release gil_release;
78     TF_ASSIGN_OR_RETURN(output_buffers, executable_->Execute({args}, options_));
79   }
80   auto traceback = Traceback::Get();
81   std::vector<std::unique_ptr<PyBuffer>> outputs;
82   outputs.reserve(output_buffers[0].size());
83   for (auto& buffer : output_buffers[0]) {
84     outputs.push_back(
85         std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
86   }
87   return outputs;
88 }
89 
Execute(absl::Span<PyBuffer * const> args)90 StatusOr<std::vector<std::unique_ptr<PyBuffer>>> PyExecutable::Execute(
91     absl::Span<PyBuffer* const> args) {
92   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
93   {
94     py::gil_scoped_release gil_release;
95     std::vector<PjRtBuffer*> arg_buffers(args.size());
96     absl::c_transform(args, arg_buffers.begin(),
97                       [](PyBuffer* buf) { return buf->buffer(); });
98     TF_ASSIGN_OR_RETURN(output_buffers,
99                         executable_->Execute({arg_buffers}, options_));
100   }
101   auto traceback = Traceback::Get();
102   std::vector<std::unique_ptr<PyBuffer>> outputs;
103   outputs.reserve(output_buffers[0].size());
104   for (auto& buffer : output_buffers[0]) {
105     outputs.push_back(
106         std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
107   }
108   return outputs;
109 }
110 
111 StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteOnLocalDevices(absl::Span<const std::vector<PyBuffer * >> args)112 PyExecutable::ExecuteOnLocalDevices(
113     absl::Span<const std::vector<PyBuffer*>> args) {
114   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
115   {
116     py::gil_scoped_release gil_release;
117     std::vector<std::vector<PjRtBuffer*>> arg_buffers(args.size());
118     for (int computation = 0; computation < args.size(); ++computation) {
119       arg_buffers[computation].resize(args[computation].size());
120       absl::c_transform(args[computation], arg_buffers[computation].begin(),
121                         [](PyBuffer* buf) { return buf->buffer(); });
122     }
123     TF_ASSIGN_OR_RETURN(output_buffers,
124                         executable_->Execute(arg_buffers, options_));
125   }
126   auto traceback = Traceback::Get();
127   std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
128   outputs.resize(output_buffers.size());
129   for (int computation = 0; computation < output_buffers.size();
130        ++computation) {
131     for (auto& buffer : output_buffers[computation]) {
132       outputs[computation].push_back(
133           std::make_unique<PyBuffer>(client_, std::move(buffer), traceback));
134     }
135   }
136   return outputs;
137 }
138 
139 StatusOr<std::vector<std::vector<std::unique_ptr<PyBuffer>>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer * >> args)140 PyExecutable::ExecuteShardedOnLocalDevices(
141     absl::Span<const std::vector<PyBuffer*>> args) {
142   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
143   int num_computations = executable_->addressable_devices().size();
144   {
145     py::gil_scoped_release gil_release;
146     for (const auto& arg : args) {
147       if (arg.size() != num_computations) {
148         return xla::InvalidArgument(
149             "Expected args to execute_sharded_on_local_devices to have %d "
150             "shards, got: [%s]",
151             num_computations,
152             absl::StrJoin(
153                 args, ", ",
154                 [](std::string* out, const std::vector<PyBuffer*>& arg) {
155                   out->append(std::to_string(arg.size()));
156                 }));
157       }
158     }
159     std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
160     for (int computation = 0; computation < num_computations; ++computation) {
161       arg_buffers[computation].resize(args.size());
162       absl::c_transform(args, arg_buffers[computation].begin(),
163                         [&](const std::vector<PyBuffer*>& arg) {
164                           return arg[computation]->buffer();
165                         });
166     }
167     TF_ASSIGN_OR_RETURN(output_buffers,
168                         executable_->Execute(arg_buffers, options_));
169   }
170   auto traceback = Traceback::Get();
171   int num_output_buffers = output_buffers[0].size();
172   std::vector<std::vector<std::unique_ptr<PyBuffer>>> outputs;
173   outputs.resize(num_output_buffers);
174   for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
175     outputs[buffer_id].reserve(num_computations);
176     for (int computation = 0; computation < num_computations; ++computation) {
177       outputs[buffer_id].push_back(std::make_unique<PyBuffer>(
178           client_, std::move(output_buffers[computation][buffer_id]),
179           traceback));
180     }
181   }
182   return outputs;
183 }
184 
HloModules() const185 StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
186     const {
187   return executable_->GetHloModules();
188 }
189 
190 }  // namespace xla
191