/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"

#include <functional>
#include <limits>
#include <memory>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
                                   LocalClient* client,
                                   AllocationModel allocation_model,
                                   int max_inflight_computations,
                                   bool allow_event_reuse,
                                   bool use_callback_stream)
    : allocation_model_(allocation_model),
      event_pool_(allow_event_reuse),
      compute_semaphore_(
          /*capacity=*/max_inflight_computations),
      executor_(executor),
      client_(client),
      prng_seed_generator_(prng_seed_device_()),
      prng_seed_distribution_(std::numeric_limits<int>::min(),
                              std::numeric_limits<int>::max()) {
  compute_stream_ = std::make_unique<se::Stream>(executor);
  host_to_device_stream_ = std::make_unique<se::Stream>(executor);
  compute_stream_->Init();
  host_to_device_stream_->Init();
  if (use_callback_stream) {
    callback_stream_map_ =
        absl::flat_hash_map<se::Stream*, std::unique_ptr<se::Stream>>();
  }
  // Create fixed-size pools of streams for device-to-host and
  // device-to-device transfers.
  device_to_host_streams_.reserve(kNumDeviceToHostStreams);
  for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
    auto stream = std::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_host_streams_.push_back(std::move(stream));
  }
  device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
  for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
    auto stream = std::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_device_streams_.push_back(std::move(stream));
  }
  // Dedicated worker threads keep execution dispatch and host callbacks off
  // the caller's thread.
  execute_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                   "py_xla_execute");
  callback_thread_ = std::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_callback");
}

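// Example (sketch, not part of the original file): a PJRT client might
// construct a LocalDeviceState roughly as follows. `client` and `executor`
// are hypothetical locals obtained from a LocalClient, and the allocation
// model and argument values are assumptions for illustration only:
//
//   auto device_state = std::make_unique<LocalDeviceState>(
//       executor, client, LocalDeviceState::kComputeSynchronized,
//       /*max_inflight_computations=*/32,
//       /*allow_event_reuse=*/true, /*use_callback_stream=*/true);
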
LocalDeviceState::~LocalDeviceState() {
  Status status = SynchronizeAllActivity();
  if (!status.ok()) {
    LOG(ERROR) << "Error when closing device: " << status;
  }
}

Status LocalDeviceState::SynchronizeAllActivity() {
  Status status;
  // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
  // suffice. However on the Host platform SynchronizeAllActivity is a dummy
  // implementation that doesn't actually block. To make sure activity has
  // stopped, also block on the compute stream. If SynchronizeAllActivity is
  // fixed, we could remove the BlockHostUntilDone call.
  status.Update(compute_stream_->BlockHostUntilDone());
  if (callback_stream_map_.has_value()) {
    for (auto& callback_stream : callback_stream_map_.value()) {
      status.Update(callback_stream.second->BlockHostUntilDone());
    }
  }
  bool ok = compute_stream_->parent()->SynchronizeAllActivity();
  if (!ok) {
    status.Update(Unknown("SynchronizeAllActivity failed."));
  }
  return status;
}

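// Example (sketch): a caller that needs the device to be quiescent, e.g.
// before tearing it down or reconfiguring it, can block on all outstanding
// work; `device_state` is a hypothetical pointer to this object:
//
//   TF_RETURN_IF_ERROR(device_state->SynchronizeAllActivity());
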
Status LocalDeviceState::ThenMemcpyDeviceToDevice(
    se::Stream* transfer_stream, se::Stream* dst_stream,
    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
  // The default implementation simply calls ThenMemcpyD2D, and assumes that
  // the buffer addresses identify the devices. This does not work
  // on all platforms; this method is virtual so it can be overridden.
  transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
  return Status::OK();
}

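// Example (sketch): a platform where buffer addresses alone do not identify
// the source and destination devices could override this method in a
// hypothetical subclass, selecting the devices explicitly before enqueueing
// the copy on `transfer_stream` and returning Status::OK() on success.
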
void LocalDeviceState::ThenExecuteCallback(se::Stream* stream,
                                           std::function<void()> callback) {
  tensorflow::profiler::TraceMe traceme("ThenExecuteCallback");
  if (callback_stream_map_.has_value()) {
    // Prevent concurrent updates to the callback stream map.
    absl::MutexLock lock(&mu_);
    auto callback_stream = callback_stream_map_->find(stream);
    if (callback_stream == callback_stream_map_->end()) {
      auto new_stream = std::make_unique<se::Stream>(executor_);
      new_stream->Init();
      callback_stream =
          callback_stream_map_->insert({stream, std::move(new_stream)}).first;
    }
    callback_stream->second->ThenWaitFor(stream);
    stream = callback_stream->second.get();
  }
  // Run the callback on callback_thread_ so that a slow callback does not
  // block the thread that services the stream's host callbacks.
  stream->ThenDoHostCallback([this, callback{std::move(callback)}]() mutable {
    callback_thread_->Schedule(std::move(callback));
  });
}

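// Example (sketch): enqueue host work that must run only after the work
// already enqueued on `stream` completes; here `buffer` is a hypothetical
// std::shared_ptr whose last reference is dropped by the callback:
//
//   device_state->ThenExecuteCallback(
//       stream, [buffer = std::move(buffer)]() mutable { buffer.reset(); });
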
se::Stream* LocalDeviceState::GetDeviceToHostStream() {
  // Hands out streams from the fixed device-to-host pool in round-robin
  // order.
  absl::MutexLock lock(&mu_);
  int i = next_device_to_host_stream_;
  next_device_to_host_stream_ =
      (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
  return device_to_host_streams_.at(i).get();
}

se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
  // Hands out streams from the fixed device-to-device pool in round-robin
  // order.
  absl::MutexLock lock(&mu_);
  int i = next_device_to_device_stream_;
  next_device_to_device_stream_ =
      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_streams_.at(i).get();
}

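// Example (sketch): a transfer path might combine the round-robin getters
// with a StreamExecutor copy. `host_dst`, `device_src`, and `size` are
// hypothetical, and the compute_stream() accessor is assumed to be declared
// in the header:
//
//   se::Stream* d2h = device_state->GetDeviceToHostStream();
//   d2h->ThenWaitFor(device_state->compute_stream());
//   d2h->ThenMemcpy(host_dst, device_src, size);
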
std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
  absl::MutexLock lock(&mu_);
  if (usage_stream_pool_.empty()) {
    auto stream = std::make_unique<se::Stream>(compute_stream_->parent());
    stream->Init();
    return stream;
  } else {
    std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
    usage_stream_pool_.pop();
    auto status = stream->RefreshStatus();  // Can return error::Unimplemented
    // Stream may fail with "ABORTED: Bad connection".
    if (status.code() != tensorflow::error::ABORTED) {
      CHECK(stream->ok()) << status;
    }
    return stream;
  }
}

void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
  auto status = stream->RefreshStatus();  // Can return error::Unimplemented
  // Stream may fail with "ABORTED: Bad connection".
  if (status.code() != tensorflow::error::ABORTED) {
    CHECK(stream->ok()) << status;
  }
  absl::MutexLock lock(&mu_);
  usage_stream_pool_.push(std::move(stream));
}

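// Example (sketch): BorrowStreamFromPool and ReturnStreamToPool are intended
// to be paired so streams are recycled rather than recreated; the
// compute_stream() accessor is assumed from the header:
//
//   std::unique_ptr<se::Stream> usage = device_state->BorrowStreamFromPool();
//   usage->ThenWaitFor(device_state->compute_stream());
//   ... enqueue usage events on `usage` ...
//   device_state->ReturnStreamToPool(std::move(usage));
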
int LocalDeviceState::GetNewPrngSeed() {
  absl::MutexLock lock(&mu_);
  int x = 0;
  // Draw repeatedly so that we never hand out a zero seed.
  do {
    x = prng_seed_distribution_(prng_seed_generator_);
  } while (x == 0);
  return x;
}

}  // namespace xla