• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/python/py_executable.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "tensorflow/core/platform/fingerprint.h"
20 
21 namespace xla {
22 
23 namespace py = pybind11;
24 
PyExecutable(std::shared_ptr<PyClient> client,std::unique_ptr<PjRtExecutable> executable,std::shared_ptr<Traceback> traceback,absl::optional<std::string> fingerprint)25 PyExecutable::PyExecutable(std::shared_ptr<PyClient> client,
26                            std::unique_ptr<PjRtExecutable> executable,
27                            std::shared_ptr<Traceback> traceback,
28                            absl::optional<std::string> fingerprint)
29     : client_(std::move(client)),
30       executable_(std::move(executable)),
31       traceback_(std::move(traceback)),
32       fingerprint_(std::move(fingerprint)) {
33   CHECK(PyGILState_Check());
34   next_ = client_->executables_;
35   client_->executables_ = this;
36   prev_ = nullptr;
37   if (next_) {
38     next_->prev_ = this;
39   }
40   options_.untuple_result = true;
41   if (fingerprint_) {
42     options_.launch_id = tensorflow::Fingerprint32(*fingerprint_);
43     VLOG(1) << "Fingerprint for executable " << executable_->name() << ": "
44             << *fingerprint_;
45   }
46 }
47 
~PyExecutable()48 PyExecutable::~PyExecutable() {
49   CHECK(PyGILState_Check());
50   if (client_->executables_ == this) {
51     client_->executables_ = next_;
52   }
53   if (prev_) {
54     prev_->next_ = next_;
55   }
56   if (next_) {
57     next_->prev_ = prev_;
58   }
59 }
60 
AddressableDevices() const61 std::vector<ClientAndPtr<PjRtDevice>> PyExecutable::AddressableDevices() const {
62   std::vector<ClientAndPtr<PjRtDevice>> devices;
63   devices.reserve(executable_->addressable_devices().size());
64   for (PjRtDevice* device : executable_->addressable_devices()) {
65     devices.push_back(WrapWithClient(client_, device));
66   }
67   return devices;
68 }
69 
Execute(absl::Span<PyBuffer::object const> args)70 StatusOr<std::vector<PyBuffer::object>> PyExecutable::Execute(
71     absl::Span<PyBuffer::object const> args) {
72   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
73   {
74     py::gil_scoped_release gil_release;
75     std::vector<PjRtBuffer*> arg_buffers(args.size());
76     absl::c_transform(
77         args, arg_buffers.begin(),
78         [](const PyBuffer::object& buf) { return buf.buf()->buffer(); });
79     TF_ASSIGN_OR_RETURN(output_buffers,
80                         executable_->Execute({arg_buffers}, options_));
81   }
82   auto traceback = Traceback::Get();
83   std::vector<PyBuffer::object> outputs;
84   outputs.reserve(output_buffers[0].size());
85   for (auto& buffer : output_buffers[0]) {
86     outputs.push_back(PyBuffer::Make(client_, std::move(buffer), traceback));
87   }
88   return outputs;
89 }
90 
91 StatusOr<std::vector<std::vector<PyBuffer::object>>>
ExecuteShardedOnLocalDevices(absl::Span<const std::vector<PyBuffer::object>> args)92 PyExecutable::ExecuteShardedOnLocalDevices(
93     absl::Span<const std::vector<PyBuffer::object>> args) {
94   std::vector<std::vector<std::unique_ptr<PjRtBuffer>>> output_buffers;
95   int num_computations = executable_->addressable_devices().size();
96   {
97     py::gil_scoped_release gil_release;
98     for (const auto& arg : args) {
99       if (arg.size() != num_computations) {
100         return xla::InvalidArgument(
101             "Expected args to execute_sharded_on_local_devices to have %d "
102             "shards, got: [%s]",
103             num_computations,
104             absl::StrJoin(
105                 args, ", ",
106                 [](std::string* out, const std::vector<PyBuffer::object>& arg) {
107                   out->append(std::to_string(arg.size()));
108                 }));
109       }
110     }
111     std::vector<std::vector<PjRtBuffer*>> arg_buffers(num_computations);
112     const int num_args = args.size();
113     for (int computation = 0; computation < num_computations; ++computation) {
114       arg_buffers[computation].resize(num_args);
115       absl::c_transform(args, arg_buffers[computation].begin(),
116                         [&](const std::vector<PyBuffer::object>& arg) {
117                           return arg[computation].buf()->buffer();
118                         });
119     }
120     TF_ASSIGN_OR_RETURN(output_buffers,
121                         executable_->Execute(arg_buffers, options_));
122   }
123   auto traceback = Traceback::Get();
124   int num_output_buffers = output_buffers[0].size();
125   std::vector<std::vector<PyBuffer::object>> outputs;
126   outputs.resize(num_output_buffers);
127   for (int buffer_id = 0; buffer_id < num_output_buffers; ++buffer_id) {
128     outputs[buffer_id].reserve(num_computations);
129     for (int computation = 0; computation < num_computations; ++computation) {
130       outputs[buffer_id].push_back(PyBuffer::Make(
131           client_, std::move(output_buffers[computation][buffer_id]),
132           traceback));
133     }
134   }
135   return outputs;
136 }
137 
HloModules() const138 StatusOr<std::vector<std::shared_ptr<HloModule>>> PyExecutable::HloModules()
139     const {
140   return executable_->GetHloModules();
141 }
142 
KeepAlive(py::object obj)143 void PyExecutable::KeepAlive(py::object obj) {
144   keepalives_.push_back(std::move(obj));
145 }
146 
147 }  // namespace xla
148