• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
18 
19 #include <memory>
20 
21 #include "absl/base/casts.h"
22 #include "absl/types/optional.h"
23 #include "absl/types/span.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/mem.h"
26 
27 namespace tensorflow {
28 namespace tpu {
29 
30 // Uncopyable buffer type with optional ownership of the underlying data. If
31 // data is not owned then ensuring lifetime of the data exceeds the lifetime of
32 // the buffer is the responsibility of the user.
33 class NoncopyableBuffer {
34  public:
35   NoncopyableBuffer() = default;
36 
37   // Allocate an owning buffer without initializing the data. Useful when it
38   // will be filled by a subsequent function and want to avoid initialization
39   // cost. Size is specified in number of bytes.
NoncopyableBuffer(size_t size)40   explicit NoncopyableBuffer(size_t size)
41       : data_(static_cast<uint8_t*>(malloc(size)), free),
42         buf_(data_.get()),
43         size_(size) {}
44 
45   // Allocates an owning buffer and initializes it with the specified data. Size
46   // is specified in number of uint32's.
NoncopyableBuffer(size_t size_in_u32s,absl::optional<uint32_t> value)47   NoncopyableBuffer(size_t size_in_u32s, absl::optional<uint32_t> value)
48       : NoncopyableBuffer(size_in_u32s * sizeof(uint32_t)) {
49 #ifndef MEMORY_SANITIZER
50     if (!value.has_value()) {
51       return;
52     }
53 #endif
54     uint32_t* data_u32 = reinterpret_cast<uint32_t*>(data_.get());
55     uint32_t v = value.value_or(0);
56     for (uint32_t *p = data_u32, *e = data_u32 + size_in_u32s; p < e; ++p) {
57       *p = v;
58     }
59   }
60 
61   // Directly use buf pointer without copying it to owning data_. This delays
62   // the memcpy until mutable access is requested. "buf" is not owned by this
63   // data structure, so it is the user's duty to ensure the live range of "buf"
64   // is longer than this data structure.
NoncopyableBuffer(const uint8_t * buf,size_t size)65   NoncopyableBuffer(const uint8_t* buf, size_t size)  // Size is in uint8's.
66       : buf_(buf), size_(size) {}
NoncopyableBuffer(const uint32_t * buf,size_t size_in_u32s)67   NoncopyableBuffer(const uint32_t* buf,
68                     size_t size_in_u32s)  // Size is in uint32_t's.
69       : buf_(buf), size_(size_in_u32s * sizeof(uint32_t)) {}
70 
71   NoncopyableBuffer(const NoncopyableBuffer&) = delete;
72   NoncopyableBuffer(NoncopyableBuffer&&) = default;
73 
74   NoncopyableBuffer& operator=(const NoncopyableBuffer&) = delete;
75   NoncopyableBuffer& operator=(NoncopyableBuffer&&) = default;
76 
77   // Ensure that the buffer owns the data and returns a mutable view into the
78   // owned data for modification.
79   template <typename T>
mutable_data()80   absl::Span<T> mutable_data() {
81     static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
82     EnsureDataOwned();
83     DCHECK_EQ(size_ % sizeof(T), 0);
84     return absl::Span<T>(reinterpret_cast<T*>(data_.get()), size_ / sizeof(T));
85   }
86 
87   template <typename T>
const_data()88   absl::Span<const T> const_data() const {
89     static_assert(std::is_arithmetic<T>::value, "Must be arithmetic type.");
90     DCHECK_EQ(size_ % sizeof(T), 0);
91     return absl::Span<const T>(static_cast<const T*>(buf_), size_ / sizeof(T));
92   }
93   // Clone the content to a given buffer.
CloneTo(void * buf)94   void CloneTo(void* buf) { memcpy(buf, buf_, size_); }
95 
96   // Return true if data is owned by this buffer (have been copied to `data_`).
owns_data()97   bool owns_data() const { return data_ != nullptr; }
98 
99   // Returns a copy of the object that owns its buffer.
100   NoncopyableBuffer Clone(size_t alignment = 1) const {
101     auto clone = alignment <= 1
102                      ? NoncopyableBuffer(size_)
103                      : NoncopyableBuffer(AlignedAlloc(size_, alignment), size_);
104     memcpy(clone.data_.get(), buf_, size_);
105     return clone;
106   }
107 
108   // Ensure that the buffer owns the data.
EnsureDataOwned()109   void EnsureDataOwned() {
110     if (data_ == nullptr) {
111       data_ = OwnedDataPtr(static_cast<uint8_t*>(malloc(size_)), free);
112       memcpy(data_.get(), buf_, size_);
113       buf_ = data_.get();
114     }
115   }
116 
117  private:
118   using OwnedDataPtr = std::unique_ptr<uint8_t[], decltype(port::AlignedFree)*>;
NoncopyableBuffer(OwnedDataPtr data,size_t size)119   NoncopyableBuffer(OwnedDataPtr data, size_t size)
120       : data_(std::move(data)), buf_(data_.get()), size_(size) {}
121 
AlignedAlloc(size_t size,size_t alignment)122   static OwnedDataPtr AlignedAlloc(size_t size, size_t alignment) {
123     return OwnedDataPtr(
124         static_cast<uint8_t*>(port::AlignedMalloc(size, alignment)),
125         port::AlignedFree);
126   }
127   // If data_ != nullptr then buf_ == data_.get()
128   OwnedDataPtr data_ = {nullptr, free};  // Owning data pointer.
129   const void* buf_;                      // Non-owning data pointer.
130   size_t size_;                          // Size in number of bytes.
131 };
132 
133 }  // namespace tpu
134 }  // namespace tensorflow
135 
136 #endif  // TENSORFLOW_STREAM_EXECUTOR_TPU_NONCOPYABLE_BUFFER_H_
137