• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
18 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <utility>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/cpu_function_runtime.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/mem.h"
#include "tfrt/host_context/async_value_ref.h"  // from @tf_runtime
27 
28 namespace xla {
29 
30 class MaybeOwningCpuMemory {
31  public:
32   MaybeOwningCpuMemory() = default;
33 
34   // Non-owning.
MaybeOwningCpuMemory(void * buf,size_t size)35   explicit MaybeOwningCpuMemory(void* buf, size_t size)
36       : buf_(buf), size_(size) {}
37 
38   // Owning.
39   using OwnedDataPtr =
40       std::unique_ptr<uint8_t[], decltype(tensorflow::port::AlignedFree)*>;
MaybeOwningCpuMemory(OwnedDataPtr data,size_t size)41   explicit MaybeOwningCpuMemory(OwnedDataPtr data, size_t size)
42       : buf_(data.get()), data_(std::move(data)), size_(size) {}
43 
44   // Move-only.
45   MaybeOwningCpuMemory(MaybeOwningCpuMemory&&) = default;
46   MaybeOwningCpuMemory& operator=(MaybeOwningCpuMemory&&) = default;
47   MaybeOwningCpuMemory(const MaybeOwningCpuMemory&) = delete;
48   MaybeOwningCpuMemory& operator=(const MaybeOwningCpuMemory&) = delete;
49 
50   // Owning.
AllocateShared(size_t size)51   static std::shared_ptr<MaybeOwningCpuMemory> AllocateShared(size_t size) {
52     return std::make_shared<MaybeOwningCpuMemory>(
53         OwnedDataPtr{static_cast<uint8_t*>(tensorflow::port::AlignedMalloc(
54                          size, cpu_function_runtime::kMinAlign)),
55                      tensorflow::port::AlignedFree},
56         size);
57   }
58 
data()59   void* data() const { return buf_; }
size()60   size_t size() const { return size_; }
owns_data()61   bool owns_data() const { return data_ != nullptr; }
62 
63  private:
64   void* buf_ = nullptr;                  // Non-owning data pointer.
65   OwnedDataPtr data_ = {nullptr, free};  // Owning data pointer;
66   size_t size_ = 0;                      // Size in number of bytes.
67 };
68 
// Stateless marker type. A tfrt::AsyncValueRef<CpuEvent> is used to signal
// that a CPU operation (e.g., a data transfer or a program run) has completed.
struct CpuEvent {
  // Explicitly defaulted; the event carries no data of its own.
  CpuEvent() = default;
};
74 
75 // Class that represents CPU buffers. It optionally owns the buffers. It also
76 // tracks the definition and usage of the memory to allow for synchronized usage
77 // and deletion of CPU memory.
78 class TrackedTfrtCpuDeviceBuffer {
79  public:
80   // For non-tuple, takes a single buffer.
81   // For tuple, takes the leaf buffers. Tuple index table created internally.
82   // Nested tuple is not supported.
83   TrackedTfrtCpuDeviceBuffer(
84       bool is_tuple,
85       absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers,
86       absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,
87       std::function<void()> on_delete_callback = nullptr);
88 
89   // Move-only.
90   TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default;
91   TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default;
92   TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete;
93   TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) =
94       delete;
95 
96   ~TrackedTfrtCpuDeviceBuffer();
97 
Buffers()98   absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> Buffers() {
99     return buffers_;
100   }
101 
102   std::shared_ptr<MaybeOwningCpuMemory> Buffer(const ShapeIndex& shape_index);
103 
DefinitionEvents()104   absl::Span<const tfrt::AsyncValueRef<CpuEvent>> DefinitionEvents() const {
105     return definition_events_;
106   }
107 
UsageEvents()108   absl::Span<const tfrt::AsyncValueRef<CpuEvent>> UsageEvents() const {
109     return usage_events_;
110   }
111 
112   void AddUsageEvents(absl::Span<tfrt::AsyncValueRef<CpuEvent>> events);
113 
114   absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4>
115   ConsumeBuffers();
116 
117   // Return the usage events for the buffers. After
118   // LockUseAndTransferUsageEvents is called, it is illegal to AddUsageEvent.
119   absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4>
120   LockUseAndTransferUsageEvents();
121 
122   // Relinquishes ownership of the buffer's device memory, e.g., after the
123   // buffer is passed to a computation that aliases its inputs to outputs.
124   void ReleaseDeviceMemory();
125 
126  private:
127   bool is_tuple_;
128   // If tuple, tuple index table is created and stored.
129   std::shared_ptr<MaybeOwningCpuMemory> tuple_index_table_;
130   // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers.
131   absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers_;
132   // Definition events are associated with CPU operations that write to the
133   // buffers.
134   absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events_;
135   // Usage events are associated with CPU operations that read from the buffers.
136   absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events_;
137   // A callback to call when the TrackedTfrtCpuDeviceBuffer is about to be
138   // destroyed.
139   std::function<void()> on_delete_callback_;
140 };
141 }  // namespace xla
142 
143 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_
144