1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ 18 19 #include <memory> 20 21 #include "absl/container/flat_hash_set.h" 22 #include "tensorflow/compiler/xla/pjrt/event_pool.h" 23 #include "tensorflow/compiler/xla/pjrt/local_device_state.h" 24 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 25 #include "tensorflow/compiler/xla/service/transfer_manager.h" 26 #include "tensorflow/compiler/xla/shape.h" 27 #include "tensorflow/core/platform/thread_annotations.h" 28 #include "tensorflow/stream_executor/device_memory.h" 29 #include "tensorflow/stream_executor/device_memory_allocator.h" 30 #include "tensorflow/stream_executor/stream.h" 31 32 namespace xla { 33 34 // A BufferSequencingEvent keeps track of dependencies of a buffer on each 35 // stream it has been used on. 36 // 37 // Each logical buffer in an XLA computation may be defined (i.e., written to) 38 // at most once. We call the operation that writes the buffer's value on some 39 // stream (e.g., a transfer or compute kernel) the buffer's definition event. 40 // 41 // After the operation that populates the value of a buffer has been enqueued on 42 // 'stream', RecordOnStream(stream) should also be called to trigger the 43 // definition event after the operation has completed. 44 // 45 // After the buffer is read on 'stream' another event should be added so that 46 // it is possible to sequence buffer donation after all reads have completed. 47 // 48 // Since different streams are not necessarily synchronized with one another, 49 // if we wish to consume the value of the buffer on a different stream, we 50 // should first call WaitForEventOnStream(stream), which add a cross-stream 51 // from 'stream' to the buffer's definition event, causing 'stream' to pause 52 // until the definition event has been triggered, if needed. Operations on 53 // 'stream' may then assume that the buffer is valid and its contents correspond 54 // to the desired buffer. 55 // 56 // The dependency logic caches the set of streams at the tail of which the 57 // definition event is known to have occurred; waiting for the same event on the 58 // same stream causes no additional waiting. 59 class BufferSequencingEvent { 60 public: 61 BufferSequencingEvent() = default; 62 63 // Sets the sequencing event to 'event', which is recorded on 'stream'. Must 64 // be called at most once. Unblocks any other host threads that are blocked in 65 // WaitForEventOnStream. 66 void SetSequencingEvent(EventPool::Handle event, se::Stream* stream); 67 68 // Adds synchronization events to 'stream' that wait for this event to be 69 // defined on 'stream'. Does nothing if the event is already known to have 70 // occurred by the tail of 'stream'. If RecordOnStream has not yet been 71 // called, blocks the calling thread until the event has been recorded. 72 void WaitForEventOnStream(se::Stream* stream); 73 74 // Returns true if the event is known to have occurred by the tail of 75 // 'stream'. If RecordOnStream has not yet been called, blocks the calling 76 // thread until the event has been recorded. 77 bool DefinedOn(se::Stream* stream); 78 79 // Returns true if the event is known by the host to have already occurred. If 80 // RecordOnStream has not yet been called, blocks the calling thread until the 81 // event has been recorded. 82 bool IsComplete(); 83 84 // Compares the sequence numbers of two recorded events. It is illegal to call 85 // the comparison operators unless both events have been recorded. 86 inline bool operator<(const BufferSequencingEvent& rhs) const { 87 return sequence_number() < rhs.sequence_number(); 88 } 89 inline bool operator>(const BufferSequencingEvent& rhs) const { 90 return rhs < *this; 91 } 92 inline bool operator<=(const BufferSequencingEvent& rhs) const { 93 return !(*this > rhs); 94 } 95 inline bool operator>=(const BufferSequencingEvent& rhs) const { 96 return !(*this < rhs); 97 } 98 99 private: 100 bool EventHasBeenRecorded() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 101 uint64 sequence_number() const; 102 103 // An event that is triggered when the content of one or more buffers has been 104 // read or written. If this event is used as a definition event and is 105 // nullptr, it is assumed that the buffer's content is always defined for 106 // example because it uses storage borrowed from elsewhere. 107 EventPool::Handle event_; 108 109 mutable absl::Mutex mu_; 110 // A list of all streams for which the buffer's content is known to be defined 111 // at the tail of the queue, i.e., for any newly enqueued command. 112 absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_); 113 }; 114 115 // Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it 116 // owns all of the device memory in the tuple. It also tracks the definition and 117 // usage of the memory on streams, to allow for synchronized usage and deletion 118 // of memory under all of the allocation model semantics. 119 class TrackedDeviceBuffer { 120 public: 121 // Helper object to keep track of usage of the buffer on streams. 122 struct StreamAndEvent { 123 // A stream the buffer has been used on. 124 se::Stream* stream; 125 // An event that is later than the most recent usage of the buffer on 126 // stream. 127 std::shared_ptr<BufferSequencingEvent> event; 128 // True if and only if a reference to the buffer is kept live until after 129 // the host knows that event is complete. 130 bool reference_held; 131 }; 132 133 // Converts a ScopedShapedBuffer into a TrackedDeviceBuffer. Takes ownership 134 // of the buffers of the shaped_buffer. 135 static std::shared_ptr<TrackedDeviceBuffer> FromScopedShapedBuffer( 136 ScopedShapedBuffer* shaped_buffer, 137 absl::Span<const std::shared_ptr<BufferSequencingEvent>> 138 definition_events); 139 140 // Builds a ShapedBuffer view onto the buffers of 'tree'. 141 ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const; 142 143 // Adds the owned device buffers in order to 'iterator'. Used to add the 144 // buffers to an ExecutionInput. We require but do not verify that 'iterator' 145 // when passed in is pointing to a sub-tuple of the ExecutionInput whose 146 // on_device_shape matches that of the TrackedDeviceBuffer. 'end' is used to 147 // check that 'iterator' doesn't run out of bounds. 148 void AddToInputAsImmutable( 149 ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator, 150 const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const; 151 152 // Adds the owned device buffers in order to 'iterator', marking them as 153 // available to be donated. If donation succeeds, i.e., execution_input is 154 // subsequently successfully enqueued to a computation, 155 // this->ReleaseDeviceMemory() must be called to avoid freeing the device 156 // memory twice. We require but do not verify that 'iterator' when passed in 157 // is pointing to a sub-tuple of execution_input whose on_device_shape matches 158 // that of the TrackedDeviceBuffer. 'end' is used to check that 'iterator' 159 // doesn't run out of bounds. 160 void AddToInputAsDonated( 161 ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator, 162 const ShapeTree<MaybeOwningDeviceMemory>::iterator& end, 163 ExecutionInput* execution_input, 164 se::DeviceMemoryAllocator* allocator) const; 165 allocator()166 se::DeviceMemoryAllocator* allocator() const { return allocator_; } device_ordinal()167 int device_ordinal() const { return device_ordinal_; } device_memory()168 absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() { 169 return device_memory_; 170 } device_memory()171 const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const { 172 return device_memory_; 173 } definition_events()174 absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events() 175 const { 176 return definition_events_; 177 } usage_events()178 absl::Span<const StreamAndEvent> usage_events() const { 179 return usage_events_; 180 } 181 182 // Relinquishes ownership of the buffer's device memory, e.g., after the 183 // buffer is passed to a computation that aliases its inputs to outputs. ReleaseDeviceMemory()184 void ReleaseDeviceMemory() { device_memory_.clear(); } 185 186 // Indicates that the buffer has been used on a stream. 187 // 188 // usage_stream: a stream that the buffer was used on. 189 // event: an event that has been recorded on usage_stream after the 190 // buffer was used. 191 // reference_held: true if and only if the caller has caused a memory 192 // reference to *this to stay live until after the host 193 // is sure that the usage (transfer or execution) has 194 // completed. 195 void AddUsageEvent(se::Stream* usage_stream, 196 std::shared_ptr<BufferSequencingEvent> event, 197 bool reference_held); 198 199 using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>; 200 // Returns the set of streams that the buffer was used on, and for each stream 201 // an event later than the last use of the buffer. After 202 // LockUseAndTransferUsageEvents is called it is illegal to use the buffer on 203 // any stream and, e.g. AddUsageHold will CHECK fail. 204 StreamAndEventContainer LockUseAndTransferUsageEvents(); 205 TrackedDeviceBuffer()206 TrackedDeviceBuffer() : in_use_(true) {} 207 TrackedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal, 208 absl::Span<se::DeviceMemoryBase const> device_memory, 209 absl::Span<const std::shared_ptr<BufferSequencingEvent>> 210 definition_events, 211 std::function<void()> on_delete_callback); 212 ~TrackedDeviceBuffer(); 213 214 private: 215 // Are the buffers in device_memory_ owned? If so, which allocator and device 216 // ordinal? May be nullptr, indicating the buffers are not owned. 217 se::DeviceMemoryAllocator* allocator_; 218 int device_ordinal_; 219 220 // Each host-side buffer may have several buffers on-device. 221 absl::InlinedVector<se::DeviceMemoryBase, 1> device_memory_; 222 223 // Events that are triggered when the content of one or more buffers is ready 224 // during multistream execution. May be nullptr, which is used in the 225 // single-stream execution case where events are not necessary for buffer 226 // event sequencing. All events must be triggered before the buffers can be 227 // used. 228 absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2> 229 definition_events_; 230 231 // in_use_ starts out true, and is set to false when the buffer is released 232 // from its owning PjRtBuffer. Once in_use_ is false, the buffer may no 233 // longer be used on any stream. 234 bool in_use_; 235 // Set of streams that the buffer has ever been used on, see comment on 236 // StreamAndEvent. 237 StreamAndEventContainer usage_events_; 238 239 // A callback to call when the TrackedDeviceBuffer is about to be destroyed. 240 std::function<void()> on_delete_callback_; 241 }; 242 243 // Populates 'events' with the set of buffer events for buffer. If 244 // get_usage_events=true populates with the latest usage events, otherwise 245 // populates with the definition events. 246 void GetDeviceBufferEvents(const TrackedDeviceBuffer& buffer, 247 bool get_usage_events, 248 absl::flat_hash_set<BufferSequencingEvent*>* events); 249 250 // Waits for all of the definition events in a buffer on 'stream'. 251 void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer, 252 se::Stream* stream); 253 254 } // namespace xla 255 256 #endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_ 257