1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/runtime/memory_mapper.h"
17
18 #include <sys/mman.h>
19 #include <sys/types.h>
20 #include <unistd.h>
21
22 #include <memory>
23 #include <system_error> // NOLINT
24
25 #include "llvm/ExecutionEngine/SectionMemoryManager.h"
26
27 // Support for memfd_create(2) was added in glibc v2.27.
28 #if defined(__linux__) && defined(__GLIBC__) && defined(__GLIBC_PREREQ)
29 #if __GLIBC_PREREQ(2, 27)
30 #define XLA_RUNTIME_ENABLE_MEMORY_MAPPER
31 #endif // __GLIBC_PREREQ(2, 27)
32 #endif // __linux__ and __GLIBC__ and __GLIBC_PREREQ
33
34 namespace xla {
35 namespace runtime {
36
37 //===-----------------------------------------------------------------------===/
38 #ifndef XLA_RUNTIME_ENABLE_MEMORY_MAPPER
39 //===-----------------------------------------------------------------------===/
40
Create(llvm::StringRef name)41 std::unique_ptr<XlaRuntimeMemoryMapper> XlaRuntimeMemoryMapper::Create(
42 llvm::StringRef name) {
43 return nullptr;
44 }
45
allocateMappedMemory(llvm::SectionMemoryManager::AllocationPurpose purpose,size_t len,const llvm::sys::MemoryBlock * const near_block,unsigned prot_flags,std::error_code & error_code)46 llvm::sys::MemoryBlock XlaRuntimeMemoryMapper::allocateMappedMemory(
47 llvm::SectionMemoryManager::AllocationPurpose purpose, size_t len,
48 const llvm::sys::MemoryBlock* const near_block, unsigned prot_flags,
49 std::error_code& error_code) {
50 llvm_unreachable("XlaRuntimeMemoryMapper is not implemented");
51 }
52
protectMappedMemory(const llvm::sys::MemoryBlock & block,unsigned prot_flags)53 std::error_code XlaRuntimeMemoryMapper::protectMappedMemory(
54 const llvm::sys::MemoryBlock& block, unsigned prot_flags) {
55 llvm_unreachable("XlaRuntimeMemoryMapper is not implemented");
56 }
57
releaseMappedMemory(llvm::sys::MemoryBlock & block)58 std::error_code XlaRuntimeMemoryMapper::releaseMappedMemory(
59 llvm::sys::MemoryBlock& block) {
60 llvm_unreachable("XlaRuntimeMemoryMapper is not implemented");
61 }
62
63 //===-----------------------------------------------------------------------===/
64 #else // XLA_RUNTIME_ENABLE_MEMORY_MAPPER
65 //===-----------------------------------------------------------------------===/
66
67 namespace {
68
69 using MemoryMapper = llvm::SectionMemoryManager::MemoryMapper;
70 using AllocationPurpose = llvm::SectionMemoryManager::AllocationPurpose;
71
// Closes `fd` and returns the result of close(2) (0 on success, -1 with errno
// set on failure).
//
// Despite the historical name, the call is deliberately NOT retried on EINTR.
// This code is only compiled for Linux/glibc, where the descriptor is released
// even when close() is interrupted; calling close() again could close an
// unrelated descriptor that another thread was handed in the meantime.
int retrying_close(int fd) {
  return close(fd);
}
75
76 int retrying_ftruncate(int fd, off_t length) {
77 return RetryOnEINTR([&]() { return ftruncate(fd, length); }, -1);
78 }
79
80 int retrying_memfd_create(const char* name, unsigned int flags) {
81 return RetryOnEINTR([&]() { return memfd_create(name, flags); }, -1);
82 }
83
84 void* retrying_mmap(void* addr, size_t length, int prot, int flags, int fd,
85 off_t offset) {
86 return RetryOnEINTR(
87 [&]() { return mmap(addr, length, prot, flags, fd, offset); },
88 MAP_FAILED);
89 }
90
91 int retrying_mprotect(void* addr, size_t len, int prot) {
92 return RetryOnEINTR([&]() { return mprotect(addr, len, prot); }, -1);
93 }
94
95 int retrying_munmap(void* addr, size_t length) {
96 return RetryOnEINTR([&]() { return munmap(addr, length); }, -1);
97 }
98
99 int64_t retrying_sysconf(int name) {
100 return RetryOnEINTR([&]() { return sysconf(name); }, -1);
101 }
102
103 int ToPosixProtectionFlags(unsigned flags) {
104 int ret = 0;
105 if (flags & llvm::sys::Memory::MF_READ) {
106 ret |= PROT_READ;
107 }
108 if (flags & llvm::sys::Memory::MF_WRITE) {
109 ret |= PROT_WRITE;
110 }
111 if (flags & llvm::sys::Memory::MF_EXEC) {
112 ret |= PROT_EXEC;
113 }
114 return ret;
115 }
116
117 } // namespace
118
119 std::unique_ptr<XlaRuntimeMemoryMapper> XlaRuntimeMemoryMapper::Create(
120 llvm::StringRef name) {
121 std::unique_ptr<XlaRuntimeMemoryMapper> ret(new XlaRuntimeMemoryMapper(name));
122 return ret;
123 }
124
125 llvm::sys::MemoryBlock XlaRuntimeMemoryMapper::allocateMappedMemory(
126 AllocationPurpose purpose, size_t len,
127 const llvm::sys::MemoryBlock* const near_block, unsigned prot_flags,
128 std::error_code& error_code) {
129 auto round_up = [](size_t size, size_t align) {
130 return (size + align - 1) & ~(align - 1);
131 };
132 int64_t page_size = retrying_sysconf(_SC_PAGESIZE);
133 len = round_up(len, page_size);
134
135 int fd = -1;
136 int mmap_flags = MAP_PRIVATE;
137 if (purpose == llvm::SectionMemoryManager::AllocationPurpose::Code) {
138 // Try to get a truncated memfd. If that fails, use an anonymous mapping.
139 fd = retrying_memfd_create(name_.c_str(), 0);
140 if (fd != -1 && retrying_ftruncate(fd, len) == -1) {
141 retrying_close(fd);
142 fd = -1;
143 }
144 }
145 if (fd == -1) {
146 mmap_flags |= MAP_ANONYMOUS;
147 }
148 prot_flags = ToPosixProtectionFlags(prot_flags);
149 void* map = retrying_mmap(nullptr, len, prot_flags, mmap_flags, fd, 0);
150 // Regardless of the outcome of the mmap, we can close the fd now.
151 if (fd != -1) retrying_close(fd);
152
153 if (map == MAP_FAILED) {
154 error_code = std::error_code(errno, std::generic_category());
155 return llvm::sys::MemoryBlock();
156 }
157 return llvm::sys::MemoryBlock(map, len);
158 }
159
160 std::error_code XlaRuntimeMemoryMapper::protectMappedMemory(
161 const llvm::sys::MemoryBlock& block, unsigned prot_flags) {
162 int64_t page_size = retrying_sysconf(_SC_PAGESIZE);
163 uintptr_t base = reinterpret_cast<uintptr_t>(block.base());
164 uintptr_t rounded_down_base = base & ~(page_size - 1);
165 size_t size = block.allocatedSize();
166 size += base - rounded_down_base;
167
168 prot_flags = ToPosixProtectionFlags(prot_flags);
169 void* addr = reinterpret_cast<void*>(rounded_down_base);
170 if (retrying_mprotect(addr, size, prot_flags) == -1) {
171 return std::error_code(errno, std::generic_category());
172 }
173 return std::error_code();
174 }
175
176 std::error_code XlaRuntimeMemoryMapper::releaseMappedMemory(
177 llvm::sys::MemoryBlock& block) {
178 if (retrying_munmap(block.base(), block.allocatedSize()) == -1) {
179 return std::error_code(errno, std::generic_category());
180 }
181 return std::error_code();
182 }
183
184 #endif // XLA_RUNTIME_ENABLE_MEMORY_MAPPER
185
186 } // namespace runtime
187 } // namespace xla
188