1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_ 17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_ 18 19 #include <memory> 20 21 #include "absl/base/casts.h" 22 #include "absl/types/optional.h" 23 #include "absl/types/span.h" 24 #include "tensorflow/core/platform/logging.h" 25 #include "tensorflow/core/platform/mem.h" 26 27 namespace tensorflow { 28 namespace tpu { 29 30 // Uncopyable buffer type with optional ownership of the underlying data. If 31 // data is not owned then ensuring lifetime of the data exceeds the lifetime of 32 // the buffer is the responsibility of the user. 33 class NoncopyableBuffer { 34 public: 35 NoncopyableBuffer() = default; 36 37 // Allocate an owning buffer without initializing the data. Useful when it 38 // will be filled by a subsequent function and want to avoid initialization 39 // cost. Size is specified in number of bytes. NoncopyableBuffer(size_t size)40 explicit NoncopyableBuffer(size_t size) 41 : data_(static_cast<uint8_t*>(malloc(size)), free), 42 buf_(data_.get()), 43 size_(size) {} 44 45 // Allocates an owning buffer and initializes it with the specified data. Size 46 // is specified in number of uint32's. NoncopyableBuffer(size_t size_in_u32s,absl::optional<uint32_t> value)47 NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value) 48 : NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) { 49 #ifndef MEMORY_SANITIZER 50 if (!value.has_value()) { 51 return; 52 } 53 #endif 54 uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get()); 55 uint32_t v = value.value_or(0); 56 for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) { 57 *p = v; 58 } 59 } 60 61 // Directly use buf pointer without copying it to owning data_. This delays 62 // the memcpy until mutable access is requested. "buf" is not owned by this 63 // data structure, so it is the user's duty to ensure the live range of "buf" 64 // is longer than this data structure. NoncopyableBuffer(const uint8_t * buf,size_t size)65 NoncopyableBuffer(const uint8_t* buf, size_t size) // Size is in uint8's. 66 : buf_(buf), size_(size) {} NoncopyableBuffer(const uint32_t * buf,size_t size_in_u32s)67 NoncopyableBuffer(const uint32_t* buf, 68 size_t size_in_u32s) // Size is in uint32_t's. 69 : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {} 70 71 NoncopyableBuffer(const NoncopyableBuffer&) = delete; 72 NoncopyableBuffer(NoncopyableBuffer&&) = default; 73 74 NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete; 75 NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default; 76 77 // Ensure that the buffer owns the data and returns a mutable view into the 78 // owned data for modification. 79 template <typename T> mutable_data()80 absl::Span<T> mutable_data() { 81 static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type."); 82 EnsureDataOwned(); 83 DCHECK_EQ(size_ % sizeof(T), 0); 84 return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T)); 85 } 86 87 template <typename T> const_data()88 absl::Span<const T> const_data() const { 89 static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type."); 90 DCHECK_EQ(size_ % sizeof(T), 0); 91 return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T)); 92 } 93 // Clone the content to a given buffer. CloneTo(void * buf)94 void CloneTo(void* buf) { memcpy(buf, buf_, size_); } 95 96 // Return true if data is owned by this buffer (have been copied to `data_`). owns_data()97 bool owns_data() const { return data_ != nullptr; } 98 99 // Returns a copy of the object that owns its buffer. 100 NoncopyableBuffer Clone(size_t alignment = 1) const { 101 auto clone = alignment <= 1 102 ? NoncopyableBuffer(size_) 103 : NoncopyableBuffer(AlignedAlloc(size_, alignment), size_); 104 memcpy(clone.data_.get(), buf_, size_); 105 return clone; 106 } 107 108 // Ensure that the buffer owns the data. EnsureDataOwned()109 void EnsureDataOwned() { 110 if (data_ == nullptr) { 111 data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free); 112 memcpy(data_.get(), buf_, size_); 113 buf_ = data_.get(); 114 } 115 } 116 117 private: 118 using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>; NoncopyableBuffer(OwnedDataPtr data,size_t size)119 NoncopyableBuffer(OwnedDataPtr data, size_t size) 120 : data_(std::move(data)), buf_(data_.get()), size_(size) {} 121 AlignedAlloc(size_t size,size_t alignment)122 static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) { 123 return OwnedDataPtr( 124 static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)), 125 port::AlignedFree); 126 } 127 // If data_ != nullptr then buf_ == data_.get() 128 OwnedDataPtr data_ = {nullptr, free}; // Owning data pointer. 129 const void* buf_; // Non-owning data pointer. 130 size_t size_; // Size in number of bytes. 131 }; 132 133 } // namespace tpu 134 } // namespace tensorflow 135 136 #endif // TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_ 137