1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/core/profiler/utils/buffer_pool.h" 17 18 #include "tensorflow/core/platform/logging.h" 19 #include "tensorflow/core/platform/mem.h" 20 #include "tensorflow/core/platform/mutex.h" 21 22 namespace tensorflow { 23 namespace profiler { 24 BufferPool(size_t buffer_size_in_bytes)25BufferPool::BufferPool(size_t buffer_size_in_bytes) 26 : buffer_size_in_bytes_(buffer_size_in_bytes) {} 27 ~BufferPool()28BufferPool::~BufferPool() { DestroyAllBuffers(); } 29 GetOrCreateBuffer()30uint8_t* BufferPool::GetOrCreateBuffer() { 31 // Get a relinquished buffer if it exists. 32 { 33 mutex_lock lock(buffers_mutex_); 34 if (!buffers_.empty()) { 35 uint8_t* buffer = buffers_.back(); 36 buffers_.pop_back(); 37 if (!buffer) { 38 LOG(ERROR) << "A reused buffer must not be null!"; 39 return nullptr; 40 } 41 VLOG(3) << "Reused Buffer, buffer=" << std::hex 42 << reinterpret_cast<uintptr_t>(buffer) << std::dec; 43 return buffer; 44 } 45 } 46 47 // Allocate and return a new buffer. 48 constexpr size_t kBufferAlignSize = 8; 49 uint8_t* buffer = reinterpret_cast<uint8_t*>( 50 port::AlignedMalloc(buffer_size_in_bytes_, kBufferAlignSize)); 51 if (buffer == nullptr) { 52 LOG(WARNING) << "Buffer not allocated."; 53 return nullptr; 54 } 55 VLOG(3) << "Allocated Buffer, buffer=" << std::hex 56 << reinterpret_cast<uintptr_t>(buffer) << std::dec 57 << " size=" << buffer_size_in_bytes_; 58 return buffer; 59 } 60 ReclaimBuffer(uint8_t * buffer)61void BufferPool::ReclaimBuffer(uint8_t* buffer) { 62 mutex_lock lock(buffers_mutex_); 63 64 buffers_.push_back(buffer); 65 VLOG(3) << "Reclaimed Buffer, buffer=" << std::hex 66 << reinterpret_cast<uintptr_t>(buffer) << std::dec; 67 } 68 DestroyAllBuffers()69void BufferPool::DestroyAllBuffers() { 70 mutex_lock lock(buffers_mutex_); 71 for (uint8_t* buffer : buffers_) { 72 VLOG(3) << "Freeing Buffer, buffer:" << std::hex 73 << reinterpret_cast<uintptr_t>(buffer) << std::dec; 74 port::AlignedFree(buffer); 75 } 76 buffers_.clear(); 77 } 78 GetBufferSizeInBytes() const79size_t BufferPool::GetBufferSizeInBytes() const { 80 return buffer_size_in_bytes_; 81 } 82 83 } // namespace profiler 84 } // namespace tensorflow 85