/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/local_device_state.h"

#include <functional>
#include <limits>
#include <memory>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

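// Sets up per-device stream and thread state. One stream each is created for
// compute, host-to-device transfers, and host callbacks, plus fixed-size
// pools of device-to-host and device-to-device streams that are handed out
// round-robin. The compute semaphore bounds how many executions may be
// enqueued but not yet complete: 32 when `asynchronous` is set, otherwise 1,
// which serializes dispatch.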
LocalDeviceState::LocalDeviceState(se::StreamExecutor* executor,
                                   LocalClient* client,
                                   AllocationModel allocation_model,
                                   bool asynchronous, bool allow_event_reuse)
    : allocation_model_(allocation_model),
      event_pool_(allow_event_reuse),
      compute_semaphore_(/*capacity=*/asynchronous ? 32 : 1),
      executor_(executor),
      client_(client),
      prng_seed_generator_(prng_seed_device_()),
      prng_seed_distribution_(std::numeric_limits<int>::min(),
                              std::numeric_limits<int>::max()) {
  compute_stream_ = absl::make_unique<se::Stream>(executor);
  host_to_device_stream_ = absl::make_unique<se::Stream>(executor);
  callback_stream_ = absl::make_unique<se::Stream>(executor);
  compute_stream_->Init();
  host_to_device_stream_->Init();
  callback_stream_->Init();
  device_to_host_streams_.reserve(kNumDeviceToHostStreams);
  for (int i = 0; i < kNumDeviceToHostStreams; ++i) {
    auto stream = absl::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_host_streams_.push_back(std::move(stream));
  }
  device_to_device_streams_.reserve(kNumDeviceToDeviceStreams);
  for (int i = 0; i < kNumDeviceToDeviceStreams; ++i) {
    auto stream = absl::make_unique<se::Stream>(executor);
    stream->Init();
    device_to_device_streams_.push_back(std::move(stream));
  }
  execute_thread_ = absl::make_unique<WorkerThread>(tensorflow::Env::Default(),
                                                    "py_xla_execute");
  callback_thread_ = absl::make_unique<WorkerThread>(
      tensorflow::Env::Default(), "py_xla_callback");
}

LocalDeviceState::~LocalDeviceState() {
  Status status = SynchronizeAllActivity();
  if (!status.ok()) {
    LOG(ERROR) << "Error when closing device: " << status;
  }
}

Status LocalDeviceState::SynchronizeAllActivity() {
  Status status;
  // TODO(phawkins): in theory the call to SynchronizeAllActivity below should
  // suffice. However on the Host platform SynchronizeAllActivity is a dummy
  // implementation that doesn't actually block. To make sure activity has
  // stopped, also block on the compute stream. If SynchronizeAllActivity is
  // fixed, we could remove the BlockHostUntilDone call.
  status.Update(compute_stream_->BlockHostUntilDone());
  status.Update(callback_stream_->BlockHostUntilDone());
  bool ok = compute_stream_->parent()->SynchronizeAllActivity();
  if (!ok) {
    status.Update(Unknown("SynchronizeAllActivity failed."));
  }
  return status;
}

Status LocalDeviceState::ThenMemcpyDeviceToDevice(
    se::Stream* transfer_stream, se::Stream* dst_stream,
    se::DeviceMemoryBase src_buffer, se::DeviceMemoryBase dst_buffer) {
  // The default implementation simply calls ThenMemcpyD2D, and assumes that
  // the buffer addresses identify the devices. This does not work on all
  // platforms; this method is virtual so it can be overridden.
  transfer_stream->ThenMemcpyD2D(&dst_buffer, src_buffer, dst_buffer.size());
  return Status::OK();
}

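// Chains `callback` onto `stream`, but runs its body on the dedicated
// callback worker thread rather than inside the host callback itself. Host
// callbacks enqueued with ThenDoHostCallback are expected to be brief;
// moving long-running work onto callback_thread_ keeps it from stalling
// later callbacks on the same stream.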
void LocalDeviceState::ThenExecuteOnCallbackThread(
    se::Stream* stream, std::function<void()> callback) const {
  stream->ThenDoHostCallback([this, callback]() mutable {
    callback_thread_->Schedule(std::move(callback));
  });
}

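// Returns one of the device-to-host streams, chosen round-robin under the
// lock so that concurrent transfers are spread across streams instead of
// serializing on a single one.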
se::Stream* LocalDeviceState::GetDeviceToHostStream() {
  absl::MutexLock lock(&mu_);
  int i = next_device_to_host_stream_;
  next_device_to_host_stream_ =
      (next_device_to_host_stream_ + 1) % device_to_host_streams_.size();
  return device_to_host_streams_.at(i).get();
}

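// As above, but round-robins over the device-to-device stream pool.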
se::Stream* LocalDeviceState::GetDeviceToDeviceStream() {
  absl::MutexLock lock(&mu_);
  int i = next_device_to_device_stream_;
  next_device_to_device_stream_ =
      (next_device_to_device_stream_ + 1) % device_to_device_streams_.size();
  return device_to_device_streams_.at(i).get();
}

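// Hands out a stream for transient work. The pool grows lazily: if it is
// empty a new stream is created, otherwise a pooled stream is reused after a
// health check, since a stream that has entered an error state must not be
// handed out again.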
std::unique_ptr<se::Stream> LocalDeviceState::BorrowStreamFromPool() {
  absl::MutexLock lock(&mu_);
  if (usage_stream_pool_.empty()) {
    auto stream = absl::make_unique<se::Stream>(compute_stream_->parent());
    stream->Init();
    return stream;
  } else {
    std::unique_ptr<se::Stream> stream = std::move(usage_stream_pool_.top());
    usage_stream_pool_.pop();
    auto status = stream->RefreshStatus();  // Can return error::Unimplemented
    // Stream may fail with "ABORTED: Bad connection".
    if (status.code() != tensorflow::error::ABORTED) {
      CHECK(stream->ok()) << status;
    }
    return stream;
  }
}

void LocalDeviceState::ReturnStreamToPool(std::unique_ptr<se::Stream> stream) {
  auto status = stream->RefreshStatus();  // Can return error::Unimplemented
  // Stream may fail with "ABORTED: Bad connection".
  if (status.code() != tensorflow::error::ABORTED) {
    CHECK(stream->ok()) << status;
  }
  absl::MutexLock lock(&mu_);
  usage_stream_pool_.push(std::move(stream));
}
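
// A minimal usage sketch for the stream pool (hypothetical caller code, not
// part of this file; assumes a `LocalDeviceState* device_state` and the
// compute_stream() accessor declared in local_device_state.h):
//
//   std::unique_ptr<se::Stream> stream = device_state->BorrowStreamFromPool();
//   stream->ThenWaitFor(device_state->compute_stream());
//   // ... enqueue transfers or events on `stream` ...
//   device_state->ReturnStreamToPool(std::move(stream));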
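
// Draws a new PRNG seed under the lock, resampling until the value is
// nonzero; a seed of 0 is never returned, leaving it free to serve as an
// "unseeded" sentinel for callers.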
int LocalDeviceState::GetNewPrngSeed() {
  absl::MutexLock lock(&mu_);
  int x = 0;
  do {
    x = prng_seed_distribution_(prng_seed_generator_);
  } while (x == 0);
  return x;
}

}  // namespace xla