• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <utility>
21 
22 #include "absl/base/call_once.h"
23 #include "absl/strings/str_replace.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
25 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
26 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
32 #include "tensorflow/stream_executor/kernel.h"
33 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
34 
35 namespace xla {
36 namespace gpu {
37 
// Relative-error threshold shared by the device (PTX) and host comparison
// paths. NOTE: deliberately initialized from the float literal 0.1f, so this
// double holds the float-rounded value of 0.1 — it matches exactly the float
// threshold that DeviceCompare passes to the GPU kernels.
static constexpr double kTolerance = 0.1f;
39 
40 // Comparison kernel code: compare two buffers of
41 // bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the relative
42 // error does not exceed the passed rel_error_threshold. Write the number of
43 // mismatches into out parameter mismatch_count.
44 //
45 // NaN's are considered equal, and for half's we clamp all numbers to largest
46 // and smallest numbers representable to avoid miscomparisons due to overflows.
47 //
48 // The PTX below is compiled from the following CUDA code:
49 //
50 // #include <cuda_fp16.h>
51 // #include <cuda_bf16.h>
52 //
53 // namespace {
54 //
55 // __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input)
56 // {
57 //   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
58 //   // differences due to overflows.
59 //   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
60 // }
61 //
62 // } // end anonymous namespace
63 //
64 // extern "C" {  // avoid name mangling
65 //
66 //
67 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
68 //                                       float rel_error_threshold,
69 //                                       unsigned long long buffer_length,
70 //                                       int* mismatch_count) {
71 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
72 //   if (idx >= buffer_length) return;
73 //   float elem_a = __half2float(buffer_a[idx]);
74 //   float elem_b = __half2float(buffer_b[idx]);
75 //   elem_a = __xla_buffer_comparator_canonicalize(elem_a);
76 //   elem_b = __xla_buffer_comparator_canonicalize(elem_b);
77 //   if (isnan(elem_a) && isnan(elem_b)) return;
78 //
79 //   float rel_error = abs(elem_a - elem_b)
80 //       / (max(abs(elem_a), abs(elem_b)) + 1);
81 //
82 //   if (rel_error > rel_error_threshold || isnan(rel_error))
83 //     atomicAdd(mismatch_count, 1);
84 // }
85 //
86 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
87 //                                       float rel_error_threshold,
88 //                                       unsigned long long buffer_length,
89 //                                       int* mismatch_count) {
90 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
91 //   if (idx >= buffer_length) return;
92 //   float elem_a = buffer_a[idx];
93 //   float elem_b = buffer_b[idx];
94 //   if (isnan(elem_a) && isnan(elem_b)) return;
95 //   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
96 //     return;
97 //
98 //   float rel_error = abs(elem_a - elem_b)
99 //       / (max(abs(elem_a), abs(elem_b)) + 1);
100 //   if (rel_error > rel_error_threshold || isnan(rel_error))
101 //     atomicAdd(mismatch_count, 1);
102 // }
103 //
104 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
105 //                                       float rel_error_threshold,
106 //                                       unsigned long long buffer_length,
107 //                                       int* mismatch_count) {
108 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
109 //   if (idx >= buffer_length) return;
110 //
111 //   double elem_a = buffer_a[idx];
112 //   double elem_b = buffer_b[idx];
113 //   if (isnan(elem_a) && isnan(elem_b)) return;
114 //   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
115 //     return;
116 //   double rel_error = abs(elem_a - elem_b)
117 //       / (max(abs(elem_a), abs(elem_b)) + 1);
118 //   if (rel_error > rel_error_threshold || isnan(rel_error))
119 //     atomicAdd(mismatch_count, 1);
120 // }
121 //
122 // __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a,
123 //                                       __nv_bfloat16* buffer_b,
124 //                                       float rel_error_threshold,
125 //                                       unsigned long long buffer_length,
126 //                                       int* mismatch_count) {
127 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
128 //   if (idx >= buffer_length) return;
129 //   float elem_a = __bfloat162float(buffer_a[idx]);
130 //   float elem_b = __bfloat162float(buffer_b[idx]);
131 //   elem_a = __xla_buffer_comparator_canonicalize(elem_a);
132 //   elem_b = __xla_buffer_comparator_canonicalize(elem_b);
133 //   if (isnan(elem_a) && isnan(elem_b)) return;
134 //
135 //   float rel_error = abs(elem_a - elem_b)
136 //       / (max(abs(elem_a), abs(elem_b)) + 1);
137 //
138 //   if (rel_error > rel_error_threshold || isnan(rel_error))
139 //     atomicAdd(mismatch_count, 1);
140 // }
141 //
142 // // TODO(b/191520348): The comparison below requires exact equality.
143 // __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b,
144 //                                       float rel_error_threshold,
145 //                                       unsigned long long buffer_length,
146 //                                       int* mismatch_count) {
147 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
148 //   if (idx >= buffer_length) return;
149 //   float a = buffer_a[idx];
150 //   float b = buffer_b[idx];
151 //   float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1);
152 //   if (rel_error > rel_error_threshold || isnan(rel_error))
153 //       atomicAdd(mismatch_count, 1);
154 // }
155 //
156 // __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b,
157 //                                        float rel_error_threshold,
158 //                                        unsigned long long buffer_length,
159 //                                        int* mismatch_count) {
160 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
161 //   if (idx >= buffer_length) return;
162 //   float elem_a = static_cast<float>(buffer_a[idx]);
163 //   float elem_b = static_cast<float>(buffer_b[idx]);
164 //   float rel_error = abs(elem_a - elem_b)
165 //       / (max(abs(elem_a), abs(elem_b)) + 1);
166 //   if (rel_error > rel_error_threshold || isnan(rel_error))
167 //     atomicAdd(mismatch_count, 1);
168 // }
169 // } // end extern declaration
// PTX for the six comparison kernels documented in the comment above. This
// string is runtime data: it is handed to ptxas (or the driver JIT) by
// DeviceCompare. Do not edit by hand — regenerate from the CUDA source in the
// comment above. Each kernel atomically increments *mismatch_count (as a
// 32-bit add) for every out-of-tolerance element pair.
static const char* buffer_compare_ptx = R"(
//
// Generated by LLVM NVPTX Back-End
//

.version 4.2
.target sm_30
.address_size 64

// .globl__xla_fp16_comparison

.visible .entry __xla_fp16_comparison(
.param .u64 __xla_fp16_comparison_param_0,
.param .u64 __xla_fp16_comparison_param_1,
.param .f32 __xla_fp16_comparison_param_2,
.param .u64 __xla_fp16_comparison_param_3,
.param .u64 __xla_fp16_comparison_param_4
)
{
.reg .pred %p<10>;
.reg .b16 %rs<3>;
.reg .f32 %f<20>;
.reg .b32 %r<6>;
.reg .b64 %rd<12>;

ld.param.u64 %rd8, [__xla_fp16_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB0_4;
ld.param.u64 %rd5, [__xla_fp16_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp16_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 1;
add.s64 %rd10, %rd3, %rd9;
ld.global.u16 %rs1, [%rd10];
// begin inline asm
{  cvt.f32.f16 %f6, %rs1;}

// end inline asm
add.s64 %rd11, %rd2, %rd9;
ld.global.u16 %rs2, [%rd11];
// begin inline asm
{  cvt.f32.f16 %f7, %rs2;}

// end inline asm
abs.f32 %f8, %f6;
setp.gtu.f32 %p2, %f8, 0f7F800000;
min.f32 %f9, %f6, 0f477FE100;
max.f32 %f10, %f9, 0fC77FE100;
selp.f32 %f1, %f6, %f10, %p2;
abs.f32 %f11, %f7;
setp.gtu.f32 %p3, %f11, 0f7F800000;
min.f32 %f12, %f7, 0f477FE100;
max.f32 %f13, %f12, 0fC77FE100;
selp.f32 %f2, %f7, %f13, %p3;
abs.f32 %f3, %f1;
setp.gtu.f32 %p4, %f3, 0f7F800000;
abs.f32 %f4, %f2;
setp.gtu.f32 %p5, %f4, 0f7F800000;
and.pred  %p6, %p4, %p5;
@%p6 bra LBB0_4;
ld.param.f32 %f5, [__xla_fp16_comparison_param_2];
sub.f32 %f14, %f1, %f2;
abs.f32 %f15, %f14;
max.f32 %f16, %f3, %f4;
add.f32 %f17, %f16, 0f3F800000;
div.rn.f32 %f18, %f15, %f17;
setp.gt.f32 %p7, %f18, %f5;
abs.f32 %f19, %f18;
setp.gtu.f32 %p8, %f19, 0f7F800000;
or.pred  %p9, %p7, %p8;
@!%p9 bra LBB0_4;
bra.uni LBB0_3;
LBB0_3:
ld.param.u64 %rd6, [__xla_fp16_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r5, [%rd1], 1;
LBB0_4:
ret;

}
// .globl__xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
.param .u64 __xla_fp32_comparison_param_0,
.param .u64 __xla_fp32_comparison_param_1,
.param .f32 __xla_fp32_comparison_param_2,
.param .u64 __xla_fp32_comparison_param_3,
.param .u64 __xla_fp32_comparison_param_4
)
{
.reg .pred %p<12>;
.reg .f32 %f<12>;
.reg .b32 %r<9>;
.reg .b64 %rd<12>;

ld.param.u64 %rd8, [__xla_fp32_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB1_6;
ld.param.u64 %rd5, [__xla_fp32_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp32_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 2;
add.s64 %rd10, %rd3, %rd9;
ld.global.f32 %f1, [%rd10];
add.s64 %rd11, %rd2, %rd9;
ld.global.f32 %f2, [%rd11];
abs.f32 %f3, %f1;
setp.gtu.f32 %p2, %f3, 0f7F800000;
abs.f32 %f4, %f2;
setp.gtu.f32 %p3, %f4, 0f7F800000;
and.pred  %p4, %p2, %p3;
@%p4 bra LBB1_6;
setp.eq.f32 %p5, %f3, 0f7F800000;
setp.eq.f32 %p6, %f4, 0f7F800000;
and.pred  %p7, %p5, %p6;
@!%p7 bra LBB1_4;
bra.uni LBB1_3;
LBB1_3:
mov.b32 %r5, %f1;
mov.b32 %r6, %f2;
xor.b32  %r7, %r6, %r5;
setp.gt.s32 %p8, %r7, -1;
@%p8 bra LBB1_6;
LBB1_4:
ld.param.f32 %f5, [__xla_fp32_comparison_param_2];
sub.f32 %f6, %f1, %f2;
abs.f32 %f7, %f6;
max.f32 %f8, %f3, %f4;
add.f32 %f9, %f8, 0f3F800000;
div.rn.f32 %f10, %f7, %f9;
setp.gt.f32 %p9, %f10, %f5;
abs.f32 %f11, %f10;
setp.gtu.f32 %p10, %f11, 0f7F800000;
or.pred  %p11, %p9, %p10;
@!%p11 bra LBB1_6;
bra.uni LBB1_5;
LBB1_5:
ld.param.u64 %rd6, [__xla_fp32_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r8, [%rd1], 1;
LBB1_6:
ret;

}
// .globl__xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
.param .u64 __xla_fp64_comparison_param_0,
.param .u64 __xla_fp64_comparison_param_1,
.param .f32 __xla_fp64_comparison_param_2,
.param .u64 __xla_fp64_comparison_param_3,
.param .u64 __xla_fp64_comparison_param_4
)
{
.reg .pred %p<16>;
.reg .f32 %f<2>;
.reg .b32 %r<13>;
.reg .f64 %fd<12>;
.reg .b64 %rd<12>;

ld.param.u64 %rd8, [__xla_fp64_comparison_param_3];
mov.u32 %r2, %tid.x;
mov.u32 %r3, %ctaid.x;
mov.u32 %r4, %ntid.x;
mad.lo.s32 %r5, %r4, %r3, %r2;
cvt.s64.s32 %rd4, %r5;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB2_6;
ld.param.u64 %rd5, [__xla_fp64_comparison_param_0];
ld.param.u64 %rd7, [__xla_fp64_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 3;
add.s64 %rd10, %rd3, %rd9;
ld.global.f64 %fd1, [%rd10];
add.s64 %rd11, %rd2, %rd9;
ld.global.f64 %fd2, [%rd11];
abs.f64 %fd3, %fd1;
setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000;
abs.f64 %fd4, %fd2;
setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000;
and.pred  %p4, %p2, %p3;
@%p4 bra LBB2_6;
{
.reg .b32 %temp;
mov.b64 {%r6, %temp}, %fd1;
}
{
.reg .b32 %temp;
mov.b64 {%temp, %r1}, %fd1;
}
and.b32  %r7, %r1, 2147483647;
setp.eq.s32 %p5, %r7, 2146435072;
setp.eq.s32 %p6, %r6, 0;
and.pred  %p7, %p5, %p6;
@!%p7 bra LBB2_4;
bra.uni LBB2_3;
LBB2_3:
{
.reg .b32 %temp;
mov.b64 {%r8, %temp}, %fd2;
}
{
.reg .b32 %temp;
mov.b64 {%temp, %r9}, %fd2;
}
and.b32  %r10, %r9, 2147483647;
setp.eq.s32 %p8, %r10, 2146435072;
setp.eq.s32 %p9, %r8, 0;
and.pred  %p10, %p8, %p9;
xor.b32  %r11, %r9, %r1;
setp.gt.s32 %p11, %r11, -1;
and.pred  %p12, %p10, %p11;
@%p12 bra LBB2_6;
LBB2_4:
ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
sub.f64 %fd5, %fd1, %fd2;
abs.f64 %fd6, %fd5;
max.f64 %fd7, %fd3, %fd4;
add.f64 %fd8, %fd7, 0d3FF0000000000000;
div.rn.f64 %fd9, %fd6, %fd8;
cvt.f64.f32 %fd10, %f1;
setp.gt.f64 %p13, %fd9, %fd10;
abs.f64 %fd11, %fd9;
setp.gtu.f64 %p14, %fd11, 0d7FF0000000000000;
or.pred  %p15, %p13, %p14;
@!%p15 bra LBB2_6;
bra.uni LBB2_5;
LBB2_5:
ld.param.u64 %rd6, [__xla_fp64_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r12, [%rd1], 1;
LBB2_6:
ret;

}
// .globl__xla_bf16_comparison
.visible .entry __xla_bf16_comparison(
.param .u64 __xla_bf16_comparison_param_0,
.param .u64 __xla_bf16_comparison_param_1,
.param .f32 __xla_bf16_comparison_param_2,
.param .u64 __xla_bf16_comparison_param_3,
.param .u64 __xla_bf16_comparison_param_4
)
{
.reg .pred %p<10>;
.reg .b16 %rs<3>;
.reg .f32 %f<20>;
.reg .b32 %r<6>;
.reg .b64 %rd<12>;

ld.param.u64 %rd8, [__xla_bf16_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB3_4;
ld.param.u64 %rd5, [__xla_bf16_comparison_param_0];
ld.param.u64 %rd7, [__xla_bf16_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 1;
add.s64 %rd10, %rd3, %rd9;
ld.global.u16 %rs1, [%rd10];
// begin inline asm
{ mov.b32 %f6, {0,%rs1};}

// end inline asm
add.s64 %rd11, %rd2, %rd9;
ld.global.u16 %rs2, [%rd11];
// begin inline asm
{ mov.b32 %f7, {0,%rs2};}

// end inline asm
abs.f32 %f8, %f6;
setp.gtu.f32 %p2, %f8, 0f7F800000;
min.f32 %f9, %f6, 0f477FE100;
max.f32 %f10, %f9, 0fC77FE100;
selp.f32 %f1, %f6, %f10, %p2;
abs.f32 %f11, %f7;
setp.gtu.f32 %p3, %f11, 0f7F800000;
min.f32 %f12, %f7, 0f477FE100;
max.f32 %f13, %f12, 0fC77FE100;
selp.f32 %f2, %f7, %f13, %p3;
abs.f32 %f3, %f1;
setp.gtu.f32 %p4, %f3, 0f7F800000;
abs.f32 %f4, %f2;
setp.gtu.f32 %p5, %f4, 0f7F800000;
and.pred  %p6, %p4, %p5;
@%p6 bra LBB3_4;
ld.param.f32 %f5, [__xla_bf16_comparison_param_2];
sub.f32 %f14, %f1, %f2;
abs.f32 %f15, %f14;
max.f32 %f16, %f3, %f4;
add.f32 %f17, %f16, 0f3F800000;
div.rn.f32 %f18, %f15, %f17;
setp.gt.f32 %p7, %f18, %f5;
abs.f32 %f19, %f18;
setp.gtu.f32 %p8, %f19, 0f7F800000;
or.pred  %p9, %p7, %p8;
@!%p9 bra LBB3_4;
bra.uni LBB3_3;
LBB3_3:
ld.param.u64 %rd6, [__xla_bf16_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r5, [%rd1], 1;
LBB3_4:
ret;

}
// .globl__xla_int8_comparison
.visible .entry __xla_int8_comparison(
.param .u64 __xla_int8_comparison_param_0,
.param .u64 __xla_int8_comparison_param_1,
.param .f32 __xla_int8_comparison_param_2,
.param .u64 __xla_int8_comparison_param_3,
.param .u64 __xla_int8_comparison_param_4
)
{
  .reg .pred %p<5>;
  .reg .f32 %f<12>;
  .reg .b32 %r<8>;
  .reg .b64 %rd<11>;

  ld.param.u64 %rd8, [__xla_int8_comparison_param_3];
  mov.u32 %r1, %tid.x;
  mov.u32 %r2, %ctaid.x;
  mov.u32 %r3, %ntid.x;
  mad.lo.s32 %r4, %r3, %r2, %r1;
  cvt.s64.s32 %rd4, %r4;
  setp.ge.u64 %p1, %rd4, %rd8;
  @%p1 bra LBB7_3;
  ld.param.f32 %f1, [__xla_int8_comparison_param_2];
  ld.param.u64 %rd5, [__xla_int8_comparison_param_0];
  ld.param.u64 %rd7, [__xla_int8_comparison_param_1];
  cvta.to.global.u64 %rd2, %rd7;
  cvta.to.global.u64 %rd3, %rd5;
  add.s64 %rd9, %rd3, %rd4;
  ld.global.s8 %r5, [%rd9];
  add.s64 %rd10, %rd2, %rd4;
  ld.global.s8 %r6, [%rd10];
  cvt.rn.f32.s32 %f2, %r5;
  cvt.rn.f32.s32 %f3, %r6;
  sub.f32 %f4, %f2, %f3;
  abs.f32 %f5, %f4;
  abs.f32 %f6, %f2;
  abs.f32 %f7, %f3;
  max.f32 %f8, %f6, %f7;
  add.f32 %f9, %f8, 0f3F800000;
  div.rn.f32 %f10, %f5, %f9;
  setp.leu.f32 %p2, %f10, %f1;
  abs.f32 %f11, %f10;
  setp.le.f32 %p3, %f11, 0f7F800000;
  and.pred %p4, %p2, %p3;
  @%p4 bra LBB7_3;
  ld.param.u64 %rd6, [__xla_int8_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd6;
  atom.global.add.u32 %r7, [%rd1], 1;
LBB7_3:
  ret;
}

// .globl__xla_int32_comparison
.visible .entry __xla_int32_comparison(
.param .u64 __xla_int32_comparison_param_0,
.param .u64 __xla_int32_comparison_param_1,
.param .f32 __xla_int32_comparison_param_2,
.param .u64 __xla_int32_comparison_param_3,
.param .u64 __xla_int32_comparison_param_4
)
{
.reg .pred %p<5>;
.reg .f32 %f<12>;
.reg .b32 %r<8>;
.reg .b64 %rd<12>;

ld.param.u64 %rd8, [__xla_int32_comparison_param_3];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %ntid.x;
mad.lo.s32 %r4, %r3, %r2, %r1;
cvt.s64.s32 %rd4, %r4;
setp.ge.u64 %p1, %rd4, %rd8;
@%p1 bra LBB5_3;
ld.param.f32 %f1, [__xla_int32_comparison_param_2];
ld.param.u64 %rd5, [__xla_int32_comparison_param_0];
ld.param.u64 %rd7, [__xla_int32_comparison_param_1];
cvta.to.global.u64 %rd2, %rd7;
cvta.to.global.u64 %rd3, %rd5;
shl.b64 %rd9, %rd4, 2;
add.s64 %rd10, %rd3, %rd9;
ld.global.u32 %r5, [%rd10];
cvt.rn.f32.s32 %f2, %r5;
add.s64 %rd11, %rd2, %rd9;
ld.global.u32 %r6, [%rd11];
cvt.rn.f32.s32 %f3, %r6;
sub.f32 %f4, %f2, %f3;
abs.f32 %f5, %f4;
abs.f32 %f6, %f2;
abs.f32 %f7, %f3;
max.f32 %f8, %f6, %f7;
add.f32 %f9, %f8, 0f3F800000;
div.rn.f32 %f10, %f5, %f9;
setp.gt.f32 %p2, %f10, %f1;
abs.f32 %f11, %f10;
setp.gtu.f32 %p3, %f11, 0f7F800000;
or.pred  %p4, %p2, %p3;
@!%p4 bra LBB5_3;
bra.uni LBB5_2;
LBB5_2:
ld.param.u64 %rd6, [__xla_int32_comparison_param_4];
cvta.to.global.u64 %rd1, %rd6;
atom.global.add.u32 %r7, [%rd1], 1;
LBB5_3:
ret;

}
)";
600 
// Typed-kernel signature matching the PTX entry points above:
// (buffer_a, buffer_b, rel_error_threshold, buffer_length, mismatch_count).
// NOTE(review): the PTX increments the counter with a 32-bit atomic add while
// the out-parameter is a uint64_t; DeviceCompare zeroes all 8 bytes before
// launch, so only the low word is ever written — assumes a little-endian
// device, which holds for the GPUs targeted here.
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64_t, se::DeviceMemory<uint64_t>>;
605 
606 // Compares two buffers on the GPU.
607 //
608 // Returns `true` if two buffers are equal, `false` otherwise.
609 template <typename ElementT>
DeviceCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & buffer_shape,const HloModuleConfig & config,absl::string_view kernel_name)610 static StatusOr<bool> DeviceCompare(se::Stream* stream,
611                                     se::DeviceMemoryBase lhs,
612                                     se::DeviceMemoryBase rhs,
613                                     const Shape& buffer_shape,
614                                     const HloModuleConfig& config,
615                                     absl::string_view kernel_name) {
616   se::StreamExecutor* executor = stream->parent();
617 
618   se::ScopedDeviceMemory<uint64_t> out_param =
619       executor->AllocateOwnedScalar<uint64_t>();
620 
621   stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t));
622   if (lhs.size() != rhs.size()) {
623     return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
624                          lhs.size(), rhs.size());
625   }
626 
627   se::DeviceMemory<ElementT> lhs_typed(lhs);
628   se::DeviceMemory<ElementT> rhs_typed(rhs);
629   uint64_t buffer_size = lhs_typed.ElementCount();
630 
631   absl::Span<const uint8_t> compiled_ptx = {};
632   StatusOr<absl::Span<const uint8_t>> compiled_ptx_or =
633       se::CompileGpuAsmOrGetCached(
634           executor->device_ordinal(), buffer_compare_ptx,
635           PtxOptsFromDebugOptions(config.debug_options()));
636   if (compiled_ptx_or.ok()) {
637     compiled_ptx = std::move(compiled_ptx_or).value();
638   } else {
639     static absl::once_flag ptxas_not_found_logged;
640     absl::call_once(ptxas_not_found_logged, [&]() {
641       LOG(WARNING)
642           << compiled_ptx_or.status().ToString()
643           << "\nRelying on driver to perform ptx compilation. "
644           << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
645           << " or modifying $PATH can be used to set the location of ptxas"
646           << "\nThis message will only be logged once.";
647     });
648   }
649 
650   TF_ASSIGN_OR_RETURN(
651       std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
652       (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
653                                    se::DeviceMemory<ElementT>, float, uint64_t,
654                                    se::DeviceMemory<uint64_t>>(
655           kernel_name, buffer_compare_ptx, compiled_ptx)));
656 
657   GpuDeviceInfo gpu_device_info;
658   gpu_device_info.threads_per_block_limit =
659       executor->GetDeviceDescription().threads_per_block_limit();
660   gpu_device_info.threads_per_warp =
661       executor->GetDeviceDescription().threads_per_warp();
662   gpu_device_info.shared_memory_per_block =
663       executor->GetDeviceDescription().shared_memory_per_block();
664   gpu_device_info.threads_per_core_limit =
665       executor->GetDeviceDescription().threads_per_core_limit();
666   gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
667   gpu_device_info.block_dim_limit_x =
668       executor->GetDeviceDescription().block_dim_limit().x;
669   gpu_device_info.block_dim_limit_y =
670       executor->GetDeviceDescription().block_dim_limit().y;
671   gpu_device_info.block_dim_limit_z =
672       executor->GetDeviceDescription().block_dim_limit().z;
673 
674   TF_ASSIGN_OR_RETURN(LaunchDimensions dim,
675                       CalculateLaunchDimensions(buffer_shape, gpu_device_info));
676 
677   LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
678   LaunchDimensions::Dim3D block_counts = dim.block_counts();
679   TF_RETURN_IF_ERROR(stream->ThenLaunch(
680       se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
681       se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
682       *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
683       buffer_size, out_param.cref()));
684 
685   uint64_t result = -1;
686   CHECK_EQ(out_param->size(), sizeof(result));
687   stream->ThenMemcpy(&result, *out_param, sizeof(result));
688   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
689   return result == 0;
690 }
691 
692 // Host side comparison code that does the same thing, but reports some of the
693 // differences as well. It only print logs for debugging.
694 //
695 // Returns true if no differences were seen, false otherwise.
696 template <typename ElementType, typename ComparisonType>
HostCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs)697 StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
698                            se::DeviceMemoryBase rhs) {
699   int64_t n = lhs.size() / sizeof(ElementType);
700   std::vector<ElementType> host_lhs(n), host_rhs(n);
701   stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
702   stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
703   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
704 
705   const auto canonicalize = [](ComparisonType a) -> ComparisonType {
706     if (std::is_same<ElementType, Eigen::half>::value && a) {
707       constexpr ComparisonType kMaxFp16Value = 65505;
708       if (std::isnan(a)) {
709         return a;
710       }
711       return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
712     }
713     return a;
714   };
715   int differences_seen = 0;
716   for (int64_t i = 0; i < n && differences_seen < 10; i++) {
717     auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
718     auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
719     ComparisonType lhs = canonicalize(original_lhs);
720     ComparisonType rhs = canonicalize(original_rhs);
721     if (std::isnan(lhs) && std::isnan(rhs)) {
722       continue;
723     }
724     if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
725       continue;
726     }
727     if (std::isfinite(lhs) != std::isfinite(rhs) ||
728         !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
729           kTolerance)) {
730       differences_seen++;
731       LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
732                  << original_rhs;
733     }
734   }
735   return differences_seen == 0;
736 }
737 
738 template <typename ElementT, typename ComparisonT>
CompareEqualParameterized(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & shape,const HloModuleConfig & config,absl::string_view kernel_name)739 static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
740                                                 se::DeviceMemoryBase lhs,
741                                                 se::DeviceMemoryBase rhs,
742                                                 const Shape& shape,
743                                                 const HloModuleConfig& config,
744                                                 absl::string_view kernel_name) {
745   XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
746   TF_ASSIGN_OR_RETURN(
747       bool result,
748       DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));
749 
750   if (result) {
751     return true;
752   }
753 
754   TF_ASSIGN_OR_RETURN(bool host_return,
755                       (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
756   CHECK_EQ(host_return, result)
757       << "Host comparison succeeded even though GPU comparison failed.";
758 
759   return false;
760 }
761 
CompareEqual(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs) const762 StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
763                                               se::DeviceMemoryBase lhs,
764                                               se::DeviceMemoryBase rhs) const {
765   switch (shape_.element_type()) {
766     case xla::F16:
767       return CompareEqualParameterized<Eigen::half, float>(
768           stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
769     case xla::BF16:
770       return CompareEqualParameterized<Eigen::bfloat16, float>(
771           stream, lhs, rhs, shape_, config_, "__xla_bf16_comparison");
772     case xla::F32:
773       return CompareEqualParameterized<float, float>(
774           stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
775     case xla::F64:
776       return CompareEqualParameterized<double, double>(
777           stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
778     case xla::S8:
779       return CompareEqualParameterized<int8_t, float>(
780           stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
781     case xla::S32:
782       return CompareEqualParameterized<int32_t, float>(
783           stream, lhs, rhs, shape_, config_, "__xla_int32_comparison");
784     default:
785       return Unimplemented("Unimplemented element type");
786   }
787 }
788 
BufferComparator(const Shape & shape,const HloModuleConfig & config)789 BufferComparator::BufferComparator(const Shape& shape,
790                                    const HloModuleConfig& config)
791     : shape_(shape), config_(config) {
792   // Normalize complex shapes: since we treat the passed array as a contiguous
793   // storage it does not matter which dimension are we doubling.
794   auto double_dim_size = [&]() {
795     int64_t prev_zero_dim_size = shape_.dimensions(0);
796     shape_.set_dimensions(0, prev_zero_dim_size * 2);
797   };
798 
799   if (shape_.element_type() == PrimitiveType::C64) {
800     // C64 is just two F32s next to each other.
801     shape_.set_element_type(PrimitiveType::F32);
802     double_dim_size();
803   } else if (shape_.element_type() == PrimitiveType::C128) {
804     // C128 is just two F64s next to each other.
805     shape_.set_element_type(PrimitiveType::F64);
806     double_dim_size();
807   }
808 }
809 
810 }  // namespace gpu
811 }  // namespace xla
812