• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
18 
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/pjrt/event_pool.h"
#include "tensorflow/compiler/xla/pjrt/local_device_state.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/stream.h"
31 
32 namespace xla {
33 
34 // A BufferSequencingEvent keeps track of dependencies of a buffer on each
35 // stream it has been used on.
36 //
37 // Each logical buffer in an XLA computation may be defined (i.e., written to)
38 // at most once. We call the operation that writes the buffer's value on some
39 // stream (e.g., a transfer or compute kernel) the buffer's definition event.
40 //
41 // After the operation that populates the value of a buffer has been enqueued on
42 // 'stream', RecordOnStream(stream) should also be called to trigger the
43 // definition event after the operation has completed.
44 //
45 // After the buffer is read on 'stream' another event should be added so that
46 // it is possible to sequence buffer donation after all reads have completed.
47 //
48 // Since different streams are not necessarily synchronized with one another,
49 // if we wish to consume the value of the buffer on a different stream, we
50 // should first call WaitForEventOnStream(stream), which add a cross-stream
51 // from 'stream' to the buffer's definition event, causing 'stream' to pause
52 // until the definition event has been triggered, if needed. Operations on
53 // 'stream' may then assume that the buffer is valid and its contents correspond
54 // to the desired buffer.
55 //
56 // The dependency logic caches the set of streams at the tail of which the
57 // definition event is known to have occurred; waiting for the same event on the
58 // same stream causes no additional waiting.
59 class BufferSequencingEvent {
60  public:
61   BufferSequencingEvent() = default;
62 
63   // Sets the sequencing event to 'event', which is recorded on 'stream'. Must
64   // be called at most once. Unblocks any other host threads that are blocked in
65   // WaitForEventOnStream.
66   void SetSequencingEvent(EventPool::Handle event, se::Stream* stream);
67 
68   // Adds synchronization events to 'stream' that wait for this event to be
69   // defined on 'stream'. Does nothing if the event is already known to have
70   // occurred by the tail of 'stream'. If RecordOnStream has not yet been
71   // called, blocks the calling thread until the event has been recorded.
72   void WaitForEventOnStream(se::Stream* stream);
73 
74   // Returns true if the event is known to have occurred by the tail of
75   // 'stream'. If RecordOnStream has not yet been called, blocks the calling
76   // thread until the event has been recorded.
77   bool DefinedOn(se::Stream* stream);
78 
79   // Returns true if the event is known by the host to have already occurred. If
80   // RecordOnStream has not yet been called, blocks the calling thread until the
81   // event has been recorded.
82   bool IsComplete();
83 
84   // Compares the sequence numbers of two recorded events. It is illegal to call
85   // the comparison operators unless both events have been recorded.
86   inline bool operator<(const BufferSequencingEvent& rhs) const {
87     return sequence_number() < rhs.sequence_number();
88   }
89   inline bool operator>(const BufferSequencingEvent& rhs) const {
90     return rhs < *this;
91   }
92   inline bool operator<=(const BufferSequencingEvent& rhs) const {
93     return !(*this > rhs);
94   }
95   inline bool operator>=(const BufferSequencingEvent& rhs) const {
96     return !(*this < rhs);
97   }
98 
99  private:
100   bool EventHasBeenRecorded() const TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
101   uint64 sequence_number() const;
102 
103   // An event that is triggered when the content of one or more buffers has been
104   // read or written. If this event is used as a definition event and is
105   // nullptr, it is assumed that the buffer's content is always defined for
106   // example because it uses storage borrowed from elsewhere.
107   EventPool::Handle event_;
108 
109   // Cache of event_->sequence_number that avoids synchronization overhead.
110   // TODO(phawkins): In fact, event_->sequence_number is unused beyond the
111   // initial population of sequence_number_, and we could remove it if we
112   // refactored the EventPool API.
113   std::atomic<uint64_t> sequence_number_{0};
114 
115   mutable absl::Mutex mu_;
116   // A list of all streams for which the buffer's content is known to be defined
117   // at the tail of the queue, i.e., for any newly enqueued command.
118   absl::InlinedVector<se::Stream*, 2> streams_defined_on_ TF_GUARDED_BY(mu_);
119 };
120 
121 // Class that represents a tuple of device buffers. Like a ScopedShapedBuffer it
122 // owns all of the device memory in the tuple. It also tracks the definition and
123 // usage of the memory on streams, to allow for synchronized usage and deletion
124 // of memory under all of the allocation model semantics.
125 class TrackedDeviceBuffer {
126  public:
127   // Helper object to keep track of usage of the buffer on streams.
128   struct StreamAndEvent {
129     // A stream the buffer has been used on.
130     se::Stream* stream;
131     // An event that is later than the most recent usage of the buffer on
132     // stream.
133     std::shared_ptr<BufferSequencingEvent> event;
134     // True if and only if a reference to the buffer is kept live until after
135     // the host knows that event is complete.
136     bool reference_held;
137   };
138 
139   // Converts a ScopedShapedBuffer into a TrackedDeviceBuffer. Takes ownership
140   // of the buffers of the shaped_buffer.
141   static std::shared_ptr<TrackedDeviceBuffer> FromScopedShapedBuffer(
142       ScopedShapedBuffer* shaped_buffer,
143       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
144           definition_events);
145 
146   // Builds a ShapedBuffer view onto the buffers of 'tree'.
147   ShapedBuffer AsShapedBuffer(const Shape& on_device_shape) const;
148 
149   // Adds the owned device buffers in order to 'iterator'. Used to add the
150   // buffers to an ExecutionInput. We require but do not verify that 'iterator'
151   // when passed in is pointing to a sub-tuple of the ExecutionInput whose
152   // on_device_shape matches that of the TrackedDeviceBuffer. 'end' is used to
153   // check that 'iterator' doesn't run out of bounds.
154   void AddToInputAsImmutable(
155       ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
156       const ShapeTree<MaybeOwningDeviceMemory>::iterator& end) const;
157 
158   // Adds the owned device buffers in order to 'iterator', marking them as
159   // available to be donated. If donation succeeds, i.e., execution_input is
160   // subsequently successfully enqueued to a computation,
161   // this->ReleaseDeviceMemory() must be called to avoid freeing the device
162   // memory twice. We require but do not verify that 'iterator' when passed in
163   // is pointing to a sub-tuple of execution_input whose on_device_shape matches
164   // that of the TrackedDeviceBuffer. 'end' is used to check that 'iterator'
165   // doesn't run out of bounds.
166   void AddToInputAsDonated(
167       ShapeTree<MaybeOwningDeviceMemory>::iterator* iterator,
168       const ShapeTree<MaybeOwningDeviceMemory>::iterator& end,
169       ExecutionInput* execution_input,
170       se::DeviceMemoryAllocator* allocator) const;
171 
allocator()172   se::DeviceMemoryAllocator* allocator() const { return allocator_; }
device_ordinal()173   int device_ordinal() const { return device_ordinal_; }
device_memory()174   absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() {
175     return device_memory_;
176   }
device_memory()177   const absl::InlinedVector<se::DeviceMemoryBase, 1>& device_memory() const {
178     return device_memory_;
179   }
definition_events()180   absl::Span<const std::shared_ptr<BufferSequencingEvent>> definition_events()
181       const {
182     return definition_events_;
183   }
usage_events()184   absl::Span<const StreamAndEvent> usage_events() const {
185     return usage_events_;
186   }
187 
188   // Relinquishes ownership of the buffer's device memory, e.g., after the
189   // buffer is passed to a computation that aliases its inputs to outputs.
ReleaseDeviceMemory()190   void ReleaseDeviceMemory() { device_memory_.clear(); }
191 
192   // Indicates that the buffer has been used on a stream.
193   //
194   //   usage_stream:   a stream that the buffer was used on.
195   //   event:          an event that has been recorded on usage_stream after the
196   //                   buffer was used.
197   //   reference_held: true if and only if the caller has caused a memory
198   //                   reference to *this to stay live until after the host
199   //                   is sure that the usage (transfer or execution) has
200   //                   completed.
201   void AddUsageEvent(se::Stream* usage_stream,
202                      std::shared_ptr<BufferSequencingEvent> event,
203                      bool reference_held);
204 
205   using StreamAndEventContainer = absl::InlinedVector<StreamAndEvent, 3>;
206   // Returns the set of streams that the buffer was used on, and for each stream
207   // an event later than the last use of the buffer. After
208   // LockUseAndTransferUsageEvents is called it is illegal to use the buffer on
209   // any stream and, e.g. AddUsageHold will CHECK fail.
210   StreamAndEventContainer LockUseAndTransferUsageEvents();
211 
TrackedDeviceBuffer()212   TrackedDeviceBuffer() : in_use_(true) {}
213   TrackedDeviceBuffer(se::DeviceMemoryAllocator* allocator, int device_ordinal,
214                       absl::Span<se::DeviceMemoryBase const> device_memory,
215                       absl::Span<const std::shared_ptr<BufferSequencingEvent>>
216                           definition_events,
217                       std::function<void()> on_delete_callback);
218   ~TrackedDeviceBuffer();
219 
220  private:
221   // Are the buffers in device_memory_ owned? If so, which allocator and device
222   // ordinal? May be nullptr, indicating the buffers are not owned.
223   se::DeviceMemoryAllocator* allocator_;
224   int device_ordinal_;
225 
226   // Each host-side buffer may have several buffers on-device.
227   absl::InlinedVector<se::DeviceMemoryBase, 1> device_memory_;
228 
229   // Events that are triggered when the content of one or more buffers is ready
230   // during multistream execution. May be nullptr, which is used in the
231   // single-stream execution case where events are not necessary for buffer
232   // event sequencing. All events must be triggered before the buffers can be
233   // used.
234   absl::InlinedVector<std::shared_ptr<BufferSequencingEvent>, 2>
235       definition_events_;
236 
237   // in_use_ starts out true, and is set to false when the buffer is released
238   // from its owning PjRtBuffer. Once in_use_ is false, the buffer may no
239   // longer be used on any stream.
240   bool in_use_;
241   // Set of streams that the buffer has ever been used on, see comment on
242   // StreamAndEvent.
243   StreamAndEventContainer usage_events_;
244 
245   // A callback to call when the TrackedDeviceBuffer is about to be destroyed.
246   std::function<void()> on_delete_callback_;
247 };
248 
249 // Populates 'events' with the set of buffer events for buffer. If
250 // get_usage_events=true populates with the latest usage events, otherwise
251 // populates with the definition events.
252 void GetDeviceBufferEvents(const TrackedDeviceBuffer& buffer,
253                            bool get_usage_events,
254                            absl::flat_hash_set<BufferSequencingEvent*>* events);
255 
256 // Waits for all of the definition events in a buffer on 'stream'.
257 void WaitForBufferDefinitionEventsOnStream(const TrackedDeviceBuffer& buffer,
258                                            se::Stream* stream);
259 
260 }  // namespace xla
261 
262 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_DEVICE_BUFFER_H_
263