• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/stream_executor/gpu/redzone_allocator.h"
17 
18 #include "absl/base/call_once.h"
19 #include "absl/container/fixed_array.h"
20 #include "absl/strings/str_format.h"
21 #include "absl/types/optional.h"
22 #include "tensorflow/core/framework/allocator.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/stream_executor/device_memory.h"
26 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
27 #include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
28 #include "tensorflow/stream_executor/kernel.h"
29 #include "tensorflow/stream_executor/kernel_spec.h"
30 #include "tensorflow/stream_executor/stream.h"
31 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
32 
33 namespace stream_executor {
34 
35 // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio
36 // then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16
37 template <typename T>
RoundUpToNearest(T value,T divisor)38 static T RoundUpToNearest(T value, T divisor) {
39   return tensorflow::MathUtil::CeilOfRatio(value, divisor) * divisor;
40 }
41 
// The size of the redzone at the end of the user buffer is rounded up to a
// multiple of kRhsRedzoneAlign.  This simplifies the implementation a bit.
constexpr int64 kRhsRedzoneAlign = 4;

// Shorthand for the nested status type used throughout this file.
using RedzoneCheckStatus = RedzoneAllocator::RedzoneCheckStatus;
47 
// Constructs an allocator bound to `stream`'s device.  `memory_limit` caps
// how many bytes a single AllocateBytes call may request; `redzone_size` is
// the requested size of each guard region and `redzone_pattern` the byte
// written into it.
RedzoneAllocator::RedzoneAllocator(Stream* stream,
                                   DeviceMemoryAllocator* memory_allocator,
                                   GpuAsmOpts ptx_compilation_opts,
                                   int64 memory_limit, int64 redzone_size,
                                   uint8 redzone_pattern)
    : device_ordinal_(stream->parent()->device_ordinal()),
      stream_(stream),
      memory_limit_(memory_limit),
      // Round the redzone size up to the allocator's alignment so the user
      // buffer placed between the redzones stays properly aligned.
      redzone_size_(RoundUpToNearest(
          redzone_size,
          static_cast<int64>(tensorflow::Allocator::kAllocatorAlignment))),
      redzone_pattern_(redzone_pattern),
      memory_allocator_(memory_allocator),
      gpu_compilation_opts_(ptx_compilation_opts) {}
62 
AllocateBytes(int64 byte_size)63 port::StatusOr<DeviceMemory<uint8>> RedzoneAllocator::AllocateBytes(
64     int64 byte_size) {
65   CHECK_GE(byte_size, 0) << "byte_size must be positive.";
66   if (byte_size > GetMemoryLimitInBytes()) {
67     return port::Status(
68         port::error::RESOURCE_EXHAUSTED,
69         absl::StrFormat(
70             "Allocating %d bytes exceeds the memory limit of %d bytes.",
71             byte_size, GetMemoryLimitInBytes()));
72   }
73 
74   int64 rhs_slop = RoundUpToNearest(byte_size, kRhsRedzoneAlign) - byte_size;
75   TF_ASSIGN_OR_RETURN(
76       OwningDeviceMemory allocated_buffer,
77       memory_allocator_->Allocate(device_ordinal_,
78                                   byte_size + 2 * redzone_size_ + rhs_slop,
79                                   /*retry_on_failure=*/false));
80   allocated_bytes_excluding_redzones_ += byte_size;
81 
82   static_assert(sizeof(uint8) == 1, "Unexpected size");
83   DeviceMemory<uint8> allocated_buffer_memory(*allocated_buffer);
84 
85   DeviceMemory<uint8> lhs_redzone = stream_->parent()->GetSubBuffer(
86       &allocated_buffer_memory, 0, redzone_size_);
87 
88   DeviceMemory<uint8> data_chunk = stream_->parent()->GetSubBuffer(
89       &allocated_buffer_memory, redzone_size_, byte_size);
90 
91   // Split up the RHS redzone into two pieces:
92   //  - 0 to kRhsRedzoneAlign bytes adjacent to the user buffer, followed by
93   //  - redzone_size_ bytes.
94   // We do this because Stream::ThenMemset32 requires the buffer address and
95   // size to be aligned to 4 bytes.
96   DeviceMemory<uint8> rhs_redzone_slop = stream_->parent()->GetSubBuffer(
97       &allocated_buffer_memory, redzone_size_ + byte_size, rhs_slop);
98 
99   DeviceMemory<uint8> rhs_redzone_nonslop = stream_->parent()->GetSubBuffer(
100       &allocated_buffer_memory, redzone_size_ + byte_size + rhs_slop,
101       redzone_size_);
102 
103   uint8 pattern_arr[] = {redzone_pattern_, redzone_pattern_, redzone_pattern_,
104                          redzone_pattern_};
105   uint32 pattern32;
106   std::memcpy(&pattern32, pattern_arr, sizeof(pattern32));
107   stream_->ThenMemset32(&lhs_redzone, pattern32, redzone_size_);
108   if (rhs_slop != 0) {
109     stream_->ThenMemcpy(&rhs_redzone_slop, &pattern32, rhs_slop);
110   }
111   stream_->ThenMemset32(&rhs_redzone_nonslop, pattern32, redzone_size_);
112 
113   allocated_buffers_.emplace_back(std::move(allocated_buffer), byte_size);
114   return data_chunk;
115 }
116 
// PTX blob for the function which checks that every byte in
// input_buffer (length is buffer_length) is equal to redzone_pattern.
//
// On mismatch, increment the counter pointed to by out_mismatch_cnt_ptr.
//
// Generated from:
// __global__ void redzone_checker(unsigned char* input_buffer,
//                                 unsigned char redzone_pattern,
//                                 unsigned long long buffer_length,
//                                 int* out_mismatched_ptr) {
//   unsigned long long idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   if (input_buffer[idx] != redzone_pattern) atomicAdd(out_mismatched_ptr, 1);
// }
//
// Code must compile for the oldest GPU XLA may be compiled for.
static const char* redzone_checker_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

.visible .entry redzone_checker(
  .param .u64 input_buffer,
  .param .u8 redzone_pattern,
  .param .u64 buffer_length,
  .param .u64 out_mismatch_cnt_ptr
)
{
  .reg .pred   %p<3>;
  .reg .b16   %rs<3>;
  .reg .b32   %r<6>;
  .reg .b64   %rd<8>;

  ld.param.u64   %rd6, [buffer_length];
  mov.u32   %r1, %tid.x;
  mov.u32   %r2, %ctaid.x;
  mov.u32   %r3, %ntid.x;
  mad.lo.s32   %r4, %r3, %r2, %r1;
  cvt.u64.u32   %rd3, %r4;
  setp.ge.u64   %p1, %rd3, %rd6;
  @%p1 bra   LBB6_3;
  ld.param.u8   %rs1, [redzone_pattern];
  ld.param.u64   %rd4, [input_buffer];
  cvta.to.global.u64   %rd2, %rd4;
  add.s64   %rd7, %rd2, %rd3;
  ld.global.u8   %rs2, [%rd7];
  setp.eq.s16   %p2, %rs2, %rs1;
  @%p2 bra   LBB6_3;
  ld.param.u64   %rd5, [out_mismatch_cnt_ptr];
  cvta.to.global.u64   %rd1, %rd5;
  atom.global.add.u32   %r5, [%rd1], 1;
LBB6_3:
  ret;
}
)";

// The PTX in redzone_checker_ptx has to be launched with specified types
// in the specified order.  This alias must stay in sync with the kernel's
// .param list above: (input_buffer, redzone_pattern, buffer_length,
// out_mismatch_cnt_ptr).
using ComparisonKernelT =
    TypedKernel<DeviceMemory<uint8>, uint8, uint64, DeviceMemory<uint64>>;
177 
178 // Check that redzones weren't overwritten on a host.
179 //
180 // Slower, but gives a more useful error message.
CheckRedzoneHost(DeviceMemoryBase redzone,DeviceMemoryBase user_allocation,absl::string_view name,Stream * stream,uint8 redzone_pattern)181 static port::StatusOr<RedzoneCheckStatus> CheckRedzoneHost(
182     DeviceMemoryBase redzone, DeviceMemoryBase user_allocation,
183     absl::string_view name, Stream* stream, uint8 redzone_pattern) {
184   uint64 size = redzone.size();
185   auto redzone_data = absl::make_unique<uint8[]>(size);
186   TF_RETURN_IF_ERROR(stream->ThenMemcpy(redzone_data.get(), redzone, size)
187                          .BlockHostUntilDone());
188 
189   std::array<uint8, sizeof(uint64)> pattern_arr;
190   pattern_arr.fill(redzone_pattern);
191   uint64 pattern64;
192   std::memcpy(&pattern64, pattern_arr.data(), sizeof(uint64));
193 
194   int64 i;
195   for (i = 0; i + 7 < size; i += sizeof(uint64)) {
196     uint64 rz_value = *reinterpret_cast<uint64*>(&redzone_data[i]);
197     if (rz_value != pattern64) {
198       return RedzoneCheckStatus(name, user_allocation.opaque(), i, pattern64,
199                                 rz_value);
200     }
201   }
202   for (; i < size; ++i) {
203     uint8 rz_value = redzone_data[i];
204     if (rz_value != redzone_pattern) {
205       return RedzoneCheckStatus(name, user_allocation.opaque(), i,
206                                 redzone_pattern, rz_value);
207     }
208   }
209   return RedzoneCheckStatus::OK();
210 }
211 
212 // Run the redzone checker on the provided buffer redzone.
213 //
214 // Increment out_param if mismatch occurs.
// Run the redzone checker on the provided buffer redzone.
//
// Increment out_param if mismatch occurs.
static void RunRedzoneChecker(Stream* stream,
                              const DeviceMemory<uint8>& redzone,
                              uint8 redzone_pattern,
                              const DeviceMemory<uint64>& out_param,
                              const ComparisonKernelT& comparison_kernel) {
  StreamExecutor* executor = stream->parent();

  // One thread per redzone byte, capped at the device's per-block limit;
  // enough blocks to cover all num_elements bytes.
  int64 num_elements = redzone.size();
  int64 threads_per_block = std::min(
      executor->GetDeviceDescription().threads_per_block_limit(), num_elements);
  int64 block_count =
      tensorflow::MathUtil::CeilOfRatio(num_elements, threads_per_block);

  // Asynchronous launch; the caller is responsible for reading out_param
  // after synchronizing the stream.
  stream->ThenLaunch(ThreadDim(threads_per_block), BlockDim(block_count),
                     comparison_kernel, redzone, redzone_pattern,
                     redzone.size(), out_param);
}
232 
// Since we reuse the same buffer for multiple checks, we re-initialize redzone
// with a NaN pattern after a failed check.
//
// This function is blocking, since redzone failing is a rare event.
static port::Status ReinitializeRedzone(Stream* stream,
                                        DeviceMemoryBase redzone,
                                        uint8 redzone_pattern) {
  // Build the full pattern on the host and copy it down in one transfer.
  absl::FixedArray<uint8> redzone_array(redzone.size());
  redzone_array.fill(redzone_pattern);
  stream->ThenMemcpy(&redzone, redzone_array.data(), redzone.size());
  // Block so redzone_array stays alive until the copy completes.
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return port::Status::OK();
}
246 
247 // Check redzones around the user allocation.
248 //
249 // Precondition: the memory pointed out by out_param is zeroed.
// Check redzones around the user allocation.
//
// Precondition: the memory pointed out by out_param is zeroed.
static port::StatusOr<RedzoneCheckStatus> CheckRedzonesForBuffer(
    Stream* stream, DeviceMemoryBase memory,
    const DeviceMemory<uint64>& out_param,
    const ComparisonKernelT& comparison_kernel, int64 user_allocation_size,
    uint64 redzone_size, uint8 redzone_pattern) {
  StreamExecutor* executor = stream->parent();
  // Recompute the alignment slop that AllocateBytes inserted after the user
  // buffer, so the sub-buffer offsets below match the original layout:
  //   [ lhs redzone | user buffer | rhs slop | rhs redzone ].
  int64 rhs_slop =
      RoundUpToNearest<int64>(user_allocation_size, kRhsRedzoneAlign) -
      user_allocation_size;
  CHECK_EQ(memory.size(), user_allocation_size + rhs_slop + 2 * redzone_size);

  DeviceMemory<uint8> buffer_uint8(memory);
  DeviceMemory<uint8> lhs_redzone =
      executor->GetSubBuffer(&buffer_uint8, 0,
                             /*element_count=*/redzone_size);
  DeviceMemory<uint8> user_allocation =
      executor->GetSubBuffer(&buffer_uint8, redzone_size,
                             /*element_count=*/user_allocation_size);
  // The RHS check covers the slop bytes too — they were filled with the same
  // pattern at allocation time.
  DeviceMemory<uint8> rhs_redzone =
      executor->GetSubBuffer(&buffer_uint8, redzone_size + user_allocation_size,
                             /*element_count=*/redzone_size + rhs_slop);

  // Fast path: check both redzones on the device, accumulating the mismatch
  // count into out_param.
  RunRedzoneChecker(stream, lhs_redzone, redzone_pattern, out_param,
                    comparison_kernel);
  RunRedzoneChecker(stream, rhs_redzone, redzone_pattern, out_param,
                    comparison_kernel);
  int64 result;
  CHECK_EQ(out_param.size(), sizeof(result));
  stream->ThenMemcpy(&result, out_param, sizeof(result));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  if (result != 0) {
    // Slow path: the device found a mismatch; redo the check on the host to
    // produce a detailed failure message (offset and observed value).
    TF_ASSIGN_OR_RETURN(RedzoneCheckStatus lhs_check,
                        CheckRedzoneHost(lhs_redzone, user_allocation, "LHS",
                                         stream, redzone_pattern));
    TF_ASSIGN_OR_RETURN(RedzoneCheckStatus rhs_check,
                        CheckRedzoneHost(rhs_redzone, user_allocation, "RHS",
                                         stream, redzone_pattern));

    // If the device saw a mismatch, the host must see one too.
    CHECK(!lhs_check.ok() || !rhs_check.ok())
        << "Mismatched results with host and device comparison";

    // Restore the pattern so subsequent checks of this buffer start clean.
    TF_RETURN_IF_ERROR(
        ReinitializeRedzone(stream, lhs_redzone, redzone_pattern));
    TF_RETURN_IF_ERROR(
        ReinitializeRedzone(stream, rhs_redzone, redzone_pattern));
    return !lhs_check.ok() ? lhs_check : rhs_check;
  }

  return RedzoneCheckStatus::OK();
}
301 
// Checks the redzones of every buffer this allocator has handed out.
// Returns the first failing buffer's status, or OK if all redzones are
// intact.
port::StatusOr<RedzoneCheckStatus> RedzoneAllocator::CheckRedzones() const {
  StreamExecutor* executor = stream_->parent();

  // Try to compile the checker PTX with ptxas (results are cached per
  // device); on failure fall back to driver JIT by passing empty SASS below.
  absl::Span<const uint8> compiled_ptx = {};
  port::StatusOr<absl::Span<const uint8>> compiled_ptx_or =
      CompileGpuAsmOrGetCached(executor->device_ordinal(), redzone_checker_ptx,
                               gpu_compilation_opts_);
  if (compiled_ptx_or.ok()) {
    compiled_ptx = compiled_ptx_or.ValueOrDie();
  } else {
    // Warn only once per process to avoid log spam.
    static absl::once_flag ptxas_not_found_logged;
    absl::call_once(ptxas_not_found_logged, [&]() {
      LOG(WARNING) << compiled_ptx_or.status().ToString()
                   << "\nRelying on driver to perform ptx compilation. "
                   << "\nModify $PATH to customize ptxas location."
                   << "\nThis message will be only logged once.";
    });
  }

  // Device-side mismatch counter, shared by all buffer checks below; it is
  // zeroed here as CheckRedzonesForBuffer requires.
  ScopedDeviceMemory<uint64> out_param =
      executor->AllocateOwnedScalar<uint64>();
  stream_->ThenMemZero(out_param.ptr(), sizeof(uint64));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ComparisonKernelT> comparison_kernel,
      (executor->CreateTypedKernel<DeviceMemory<uint8>, uint8, uint64,
                                   DeviceMemory<uint64>>(
          "redzone_checker", redzone_checker_ptx, compiled_ptx)));

  for (const auto& buf_and_size : allocated_buffers_) {
    TF_ASSIGN_OR_RETURN(
        RedzoneCheckStatus redzone_status,
        CheckRedzonesForBuffer(stream_, *buf_and_size.first, out_param.cref(),
                               *comparison_kernel, buf_and_size.second,
                               redzone_size_, redzone_pattern_));
    if (!redzone_status.ok()) {
      return redzone_status;
    }
  }

  return RedzoneCheckStatus::OK();
}
344 
// Formats a human-readable description of a redzone mismatch; used in error
// messages surfaced to the user.
std::string RedzoneCheckStatus::RedzoneFailureMsg() const {
  return absl::StrFormat(
      "Redzone mismatch in %s redzone of buffer %p at offset %d; "
      "expected %08x but was %08x.",
      buffer_name, user_buffer_address, offset, expected_value, actual_value);
}
351 
352 }  // namespace stream_executor
353