/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"

#include <algorithm>
#include <iterator>
#include <memory>

#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/stream.h"

namespace xla {

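// Records `event` as the sequencing event for this object and notes that it
// was recorded on `stream`. May be called at most once; threads waiting on
// the event block until it has been called.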
void BufferSequencingEvent::SetSequencingEvent(EventPool::Handle event,
                                               se::Stream* stream) {
  absl::MutexLock lock(&mu_);
  CHECK(!event_.event());
  event_ = std::move(event);
  CHECK(streams_defined_on_.empty());
  streams_defined_on_.push_back(stream);
}

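// Returns true once SetSequencingEvent has been called.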
bool BufferSequencingEvent::EventHasBeenRecorded() const {
  return event_.event() != nullptr;
}

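// Returns the sequence number of the recorded event; the event must already
// have been recorded.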
uint64 BufferSequencingEvent::sequence_number() const {
  absl::MutexLock lock(&mu_);
  CHECK(EventHasBeenRecorded());
  return event_.sequence_number();
}

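// Makes all future work submitted to `stream` wait for this event. Blocks the
// calling thread until the event has been recorded, and does nothing if the
// event is already known to be defined on `stream`.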
void BufferSequencingEvent::WaitForEventOnStream(se::Stream* stream) {
  absl::MutexLock lock(&mu_);

  // We cannot wait for an event until ThenRecordEvent has been called; on GPU
  // newly created events are deemed to have already happened in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  // The set of defined streams is expected to be very small indeed (usually
  // 1-2), so a simple linear scan should be fast enough.
  if (std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                stream) != streams_defined_on_.end()) {
    // stream is in streams_defined_on_; it doesn't need to be waited on.
    return;
  }

  stream->ThenWaitFor(event_.event());
  streams_defined_on_.push_back(stream);
}

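// Returns true if this event is known to be defined on `stream`, i.e. the
// stream either recorded the event or has already waited for it. Blocks the
// calling thread until the event has been recorded.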
bool BufferSequencingEvent::DefinedOn(se::Stream* stream) {
  absl::MutexLock lock(&mu_);

  // We cannot wait for an event until ThenRecordEvent has been called; on GPU
  // newly created events are deemed to have already happened in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  // The set of defined streams is expected to be very small indeed (usually
  // 1-2), so a simple linear scan should be fast enough.
  return std::find(streams_defined_on_.begin(), streams_defined_on_.end(),
                   stream) != streams_defined_on_.end();
}

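// Polls the underlying se::Event and returns true if the work it captures has
// completed on the device. Blocks the calling thread until the event has been
// recorded.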
bool BufferSequencingEvent::IsComplete() {
  absl::MutexLock lock(&mu_);

  // We cannot wait for an event until ThenRecordEvent has been called; on
  // GPU newly created events are deemed to have already happened in the past.
  mu_.Await(
      absl::Condition(this, &BufferSequencingEvent::EventHasBeenRecorded));

  return event_.event()->PollForStatus() == se::Event::Status::kComplete;
}

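// Creates a TrackedDeviceBuffer that takes ownership of the device memory in
// `shaped_buffer`: each buffer is moved into the returned object and the
// corresponding entry in `shaped_buffer` is reset to a null
// se::DeviceMemoryBase, so it will not be freed when `shaped_buffer` dies.
//
// Usage sketch (assumes `result` is a ScopedShapedBuffer produced by an
// earlier computation and that no definition events are needed):
//   std::shared_ptr<TrackedDeviceBuffer> tracked =
//       TrackedDeviceBuffer::FromScopedShapedBuffer(
//           &result, /*definition_events=*/{});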
/* static */ std::shared_ptr<TrackedDeviceBuffer>
TrackedDeviceBuffer::FromScopedShapedBuffer(
    ScopedShapedBuffer* shaped_buffer,
    absl::Span<const std::shared_ptr<BufferSequencingEvent>>
        definition_events) {
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer->buffers().begin();
  std::vector<se::DeviceMemoryBase> buffers;
  buffers.reserve(1);

  ShapeUtil::ForEachSubshape(
      shaped_buffer->on_device_shape(), [&](const Shape&, const ShapeIndex&) {
        CHECK(iterator != shaped_buffer->buffers().end());
        buffers.push_back(iterator->second);
        iterator->second = se::DeviceMemoryBase();
        ++iterator;
      });
  CHECK(iterator == shaped_buffer->buffers().end());
  return std::make_shared<TrackedDeviceBuffer>(
      shaped_buffer->memory_allocator(), shaped_buffer->device_ordinal(),
      absl::Span<se::DeviceMemoryBase>(buffers), definition_events,
      /*on_delete_callback=*/nullptr);
}

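// Returns a ShapedBuffer with shape `on_device_shape` that aliases this
// buffer's device memory; the returned ShapedBuffer does not own the memory.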
ShapedBuffer TrackedDeviceBuffer::AsShapedBuffer(
    const Shape& on_device_shape) const {
  ShapedBuffer shaped_buffer(on_device_shape, device_ordinal_);
  ShapeTree<se::DeviceMemoryBase>::iterator iterator =
      shaped_buffer.buffers().begin();
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(iterator != shaped_buffer.buffers().end());
    iterator->second = buf;
    ++iterator;
  }
  CHECK(iterator == shaped_buffer.buffers().end());
  return shaped_buffer;
}

// See comment on ExecutionInput in xla/service/executable.h to understand
// the meaning of owned/unowned in that class.

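// Adds this buffer's device memory to an ExecutionInput shape tree as unowned
// entries (case (1) above), advancing `*iterator` once per buffer.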
void TrackedDeviceBuffer::AddToInputAsImmutable(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const {
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(*iterator != end);
    // Set buffers to be case (1) in the comment on ExecutionInput.
    (*iterator)->second = MaybeOwningDeviceMemory(buf);
    ++(*iterator);
  }
}

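// Adds this buffer's device memory to an ExecutionInput shape tree as owned,
// donated entries (case (2) above), advancing `*iterator` once per buffer and
// recording each donated index in `execution_input`.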
void TrackedDeviceBuffer::AddToInputAsDonated(
    ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
    const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
    ExecutionInput* execution_input,
    se::DeviceMemoryAllocator* allocator) const {
  for (const se::DeviceMemoryBase& buf : device_memory_) {
    CHECK(*iterator != end);
    // Set buffers to be case (2) in the comment on ExecutionInput.
    (*iterator)->second = MaybeOwningDeviceMemory(
        se::OwningDeviceMemory(buf, device_ordinal_, allocator));
    execution_input->SetUnownedIndex((*iterator)->first);
    ++(*iterator);
  }
}

TrackedDeviceBuffer::TrackedDeviceBuffer(
    se::DeviceMemoryAllocator* allocator, int device_ordinal,
    absl::Span<se::DeviceMemoryBase const> device_memory,
    absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events,
    std::function<void()> on_delete_callback)
    : allocator_(allocator),
      device_ordinal_(device_ordinal),
      device_memory_(device_memory.begin(), device_memory.end()),
      definition_events_(std::make_move_iterator(definition_events.begin()),
                         std::make_move_iterator(definition_events.end())),
      in_use_(true),
      on_delete_callback_(std::move(on_delete_callback)) {}

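// Deallocates the device memory through `allocator_`, if one was provided,
// and then runs the on-delete callback, if any.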
TrackedDeviceBuffer::~TrackedDeviceBuffer() {
  if (allocator_) {
    for (const se::DeviceMemoryBase& buffer : device_memory_) {
      Status status = allocator_->Deallocate(device_ordinal_, buffer);
      if (!status.ok()) {
        LOG(ERROR) << "Buffer deallocation failed: " << status;
      }
    }
  }
  if (on_delete_callback_) {
    on_delete_callback_();
  }
}

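// Records a usage of this buffer on `usage_stream`. If a usage has already
// been recorded for the same stream, only the later of the two events is
// kept. Must be called while the buffer is still in use.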
void TrackedDeviceBuffer::AddUsageEvent(
    se::Stream* usage_stream, std::shared_ptr<BufferSequencingEvent> event,
    bool reference_held) {
  CHECK(in_use_);

  for (auto& existing : usage_events_) {
    if (existing.stream == usage_stream) {
      if (*existing.event < *event) {
        existing.event = event;
        existing.reference_held = reference_held;
      }
      return;
    }
  }
  usage_events_.push_back({usage_stream, event, reference_held});
}

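// Marks the buffer as no longer in use and transfers the accumulated usage
// events to the caller; no further usage events may be added afterwards.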
TrackedDeviceBuffer::StreamAndEventContainer
TrackedDeviceBuffer::LockUseAndTransferUsageEvents() {
  CHECK(in_use_);
  in_use_ = false;
  return std::move(usage_events_);
}

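// Inserts either `buffer`'s usage events (if `get_usage_events` is true) or
// its definition events into `events`.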
void GetDeviceBufferEvents(
    const TrackedDeviceBuffer& buffer, bool get_usage_events,
    absl::flat_hash_set<BufferSequencingEvent*>* events) {
  if (get_usage_events) {
    for (const auto& e : buffer.usage_events()) {
      events->insert(e.event.get());
    }
  } else {
    for (const auto& e : buffer.definition_events()) {
      events->insert(e.get());
    }
  }
}

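// Makes all future work submitted to `stream` wait for all of `buffer`'s
// definition events, i.e. until the buffer's contents are defined.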
void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
                                           se::Stream* stream) {
  absl::flat_hash_set<BufferSequencingEvent*> events;
  GetDeviceBufferEvents(buffer, /*get_usage_events=*/false, &events);
  for (BufferSequencingEvent* event : events) {
    event->WaitForEventOnStream(stream);
  }
}

}  // namespace xla