1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 17 #define TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 18 19 #include <memory> 20 21 #include "absl/container/inlined_vector.h" 22 #include "absl/types/span.h" 23 #include "tensorflow/compiler/xla/cpu_function_runtime.h" 24 #include "tensorflow/compiler/xla/shape_util.h" 25 #include "tensorflow/core/platform/mem.h" 26 #include "tfrt/host_context/async_value_ref.h" // from @tf_runtime 27 28 namespace xla { 29 30 class MaybeOwningCpuMemory { 31 public: 32 MaybeOwningCpuMemory() = default; 33 34 // Non-owning. MaybeOwningCpuMemory(void * buf,size_t size)35 explicit MaybeOwningCpuMemory(void* buf, size_t size) 36 : buf_(buf), size_(size) {} 37 38 // Owning. 39 using OwnedDataPtr = 40 std::unique_ptr<uint8_t[], decltype(tensorflow::port::AlignedFree)*>; MaybeOwningCpuMemory(OwnedDataPtr data,size_t size)41 explicit MaybeOwningCpuMemory(OwnedDataPtr data, size_t size) 42 : buf_(data.get()), data_(std::move(data)), size_(size) {} 43 44 // Move-only. 45 MaybeOwningCpuMemory(MaybeOwningCpuMemory&&) = default; 46 MaybeOwningCpuMemory& operator=(MaybeOwningCpuMemory&&) = default; 47 MaybeOwningCpuMemory(const MaybeOwningCpuMemory&) = delete; 48 MaybeOwningCpuMemory& operator=(const MaybeOwningCpuMemory&) = delete; 49 50 // Owning. AllocateShared(size_t size)51 static std::shared_ptr<MaybeOwningCpuMemory> AllocateShared(size_t size) { 52 return std::make_shared<MaybeOwningCpuMemory>( 53 OwnedDataPtr{static_cast<uint8_t*>(tensorflow::port::AlignedMalloc( 54 size, cpu_function_runtime::kMinAlign)), 55 tensorflow::port::AlignedFree}, 56 size); 57 } 58 data()59 void* data() const { return buf_; } size()60 size_t size() const { return size_; } owns_data()61 bool owns_data() const { return data_ != nullptr; } 62 63 private: 64 void* buf_ = nullptr; // Non-owning data pointer. 65 OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer; 66 size_t size_ = 0; // Size in number of bytes. 67 }; 68 69 // tfrt::AsyncValueRef<CpuEvent> is used to indicate the completion of a CPU 70 // operation, e.g., data transfer or running a program. 71 struct CpuEvent { 72 CpuEvent() = default; 73 }; 74 75 // Class that represents CPU buffers. It optionally owns the buffers. It also 76 // tracks the definition and usage of the memory to allow for synchronized usage 77 // and deletion of CPU memory. 78 class TrackedTfrtCpuDeviceBuffer { 79 public: 80 // For non-tuple, takes a single buffer. 81 // For tuple, takes the leaf buffers. Tuple index table created internally. 82 // Nested tuple is not supported. 83 TrackedTfrtCpuDeviceBuffer( 84 bool is_tuple, 85 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers, 86 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events, 87 std::function<void()> on_delete_callback = nullptr); 88 89 // Move-only. 90 TrackedTfrtCpuDeviceBuffer(TrackedTfrtCpuDeviceBuffer&&) = default; 91 TrackedTfrtCpuDeviceBuffer& operator=(TrackedTfrtCpuDeviceBuffer&&) = default; 92 TrackedTfrtCpuDeviceBuffer(const TrackedTfrtCpuDeviceBuffer&) = delete; 93 TrackedTfrtCpuDeviceBuffer& operator=(const TrackedTfrtCpuDeviceBuffer&) = 94 delete; 95 96 ~TrackedTfrtCpuDeviceBuffer(); 97 Buffers()98 absl::Span<const std::shared_ptr<MaybeOwningCpuMemory>> Buffers() { 99 return buffers_; 100 } 101 102 std::shared_ptr<MaybeOwningCpuMemory> Buffer(const ShapeIndex& shape_index); 103 DefinitionEvents()104 absl::Span<const tfrt::AsyncValueRef<CpuEvent>> DefinitionEvents() const { 105 return definition_events_; 106 } 107 UsageEvents()108 absl::Span<const tfrt::AsyncValueRef<CpuEvent>> UsageEvents() const { 109 return usage_events_; 110 } 111 112 void AddUsageEvents(absl::Span<tfrt::AsyncValueRef<CpuEvent>> events); 113 114 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> 115 ConsumeBuffers(); 116 117 // Return the usage events for the buffers. After 118 // LockUseAndTransferUsageEvents is called, it is illegal to AddUsageEvent. 119 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> 120 LockUseAndTransferUsageEvents(); 121 122 // Relinquishes ownership of the buffer's device memory, e.g., after the 123 // buffer is passed to a computation that aliases its inputs to outputs. 124 void ReleaseDeviceMemory(); 125 126 private: 127 bool is_tuple_; 128 // If tuple, tuple index table is created and stored. 129 std::shared_ptr<MaybeOwningCpuMemory> tuple_index_table_; 130 // If non-tuple, `buffers_` contains 1 buffer; otherwise all leaf buffers. 131 absl::InlinedVector<std::shared_ptr<MaybeOwningCpuMemory>, 4> buffers_; 132 // Definition events are associated with CPU operations that write to the 133 // buffers. 134 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events_; 135 // Usage events are associated with CPU operations that read from the buffers. 136 absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> usage_events_; 137 // A callback to call when the TrackedTfrtCpuDeviceBuffer is about to be 138 // destroyed. 139 std::function<void()> on_delete_callback_; 140 }; 141 } // namespace xla 142 143 #endif // TENSORFLOW_COMPILER_XLA_PJRT_TRACKED_TFRT_CPU_DEVICE_BUFFER_H_ 144