1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17
18 #include <algorithm>
19 #include <cmath>
20 #include <utility>
21
22 #include "absl/base/call_once.h"
23 #include "absl/strings/str_replace.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_asm_opts_util.h"
25 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
26 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/status_macros.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/stream_executor/device_memory.h"
31 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
32 #include "tensorflow/stream_executor/kernel.h"
33 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
34
35 namespace xla {
36 namespace gpu {
37
// Relative-error threshold shared by the device kernels and the host
// re-check. Note the float literal: the stored double is the float-precision
// value of 0.1, which matches the `static_cast<float>(kTolerance)` passed to
// the GPU kernels.
static constexpr double kTolerance = 0.1f;
39
40 // Comparison kernel code: compare two buffers of
41 // bf16/fp16/fp32/fp64/int8_t/int32_t of length buffer_length where the relative
42 // error does not exceed the passed rel_error_threshold. Write the number of
43 // mismatches into out parameter mismatch_count.
44 //
45 // NaN's are considered equal, and for half's we clamp all numbers to largest
46 // and smallest numbers representable to avoid miscomparisons due to overflows.
47 //
48 // The PTX below is compiled from the following CUDA code:
49 //
50 // #include <cuda_fp16.h>
51 // #include <cuda_bf16.h>
52 //
53 // namespace {
54 //
55 // __device__ __inline__ float __xla_buffer_comparator_canonicalize(float input)
56 // {
57 // // All fp16 infinities are treated as 65505 or -65505, in order to avoid
58 // // differences due to overflows.
59 // return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
60 // }
61 //
62 // } // end anonymous namespace
63 //
64 // extern "C" { // avoid name mangling
65 //
66 //
67 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
68 // float rel_error_threshold,
69 // unsigned long long buffer_length,
70 // int* mismatch_count) {
71 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
72 // if (idx >= buffer_length) return;
73 // float elem_a = __half2float(buffer_a[idx]);
74 // float elem_b = __half2float(buffer_b[idx]);
75 // elem_a = __xla_buffer_comparator_canonicalize(elem_a);
76 // elem_b = __xla_buffer_comparator_canonicalize(elem_b);
77 // if (isnan(elem_a) && isnan(elem_b)) return;
78 //
79 // float rel_error = abs(elem_a - elem_b)
80 // / (max(abs(elem_a), abs(elem_b)) + 1);
81 //
82 // if (rel_error > rel_error_threshold || isnan(rel_error))
83 // atomicAdd(mismatch_count, 1);
84 // }
85 //
86 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
87 // float rel_error_threshold,
88 // unsigned long long buffer_length,
89 // int* mismatch_count) {
90 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
91 // if (idx >= buffer_length) return;
92 // float elem_a = buffer_a[idx];
93 // float elem_b = buffer_b[idx];
94 // if (isnan(elem_a) && isnan(elem_b)) return;
95 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
96 // return;
97 //
98 // float rel_error = abs(elem_a - elem_b)
99 // / (max(abs(elem_a), abs(elem_b)) + 1);
100 // if (rel_error > rel_error_threshold || isnan(rel_error))
101 // atomicAdd(mismatch_count, 1);
102 // }
103 //
104 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
105 // float rel_error_threshold,
106 // unsigned long long buffer_length,
107 // int* mismatch_count) {
108 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
109 // if (idx >= buffer_length) return;
110 //
111 // double elem_a = buffer_a[idx];
112 // double elem_b = buffer_b[idx];
113 // if (isnan(elem_a) && isnan(elem_b)) return;
114 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
115 // return;
116 // double rel_error = abs(elem_a - elem_b)
117 // / (max(abs(elem_a), abs(elem_b)) + 1);
118 // if (rel_error > rel_error_threshold || isnan(rel_error))
119 // atomicAdd(mismatch_count, 1);
120 // }
121 //
122 // __global__ void __xla_bf16_comparison(__nv_bfloat16* buffer_a,
123 // __nv_bfloat16* buffer_b,
124 // float rel_error_threshold,
125 // unsigned long long buffer_length,
126 // int* mismatch_count) {
127 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
128 // if (idx >= buffer_length) return;
129 // float elem_a = __bfloat162float(buffer_a[idx]);
130 // float elem_b = __bfloat162float(buffer_b[idx]);
131 // elem_a = __xla_buffer_comparator_canonicalize(elem_a);
132 // elem_b = __xla_buffer_comparator_canonicalize(elem_b);
133 // if (isnan(elem_a) && isnan(elem_b)) return;
134 //
135 // float rel_error = abs(elem_a - elem_b)
136 // / (max(abs(elem_a), abs(elem_b)) + 1);
137 //
138 // if (rel_error > rel_error_threshold || isnan(rel_error))
139 // atomicAdd(mismatch_count, 1);
140 // }
141 //
142 // // TODO(b/191520348): The comparison below requires exact equality.
143 // __global__ void __xla_int8_comparison(int8_t* buffer_a, int8_t* buffer_b,
144 // float rel_error_threshold,
145 // unsigned long long buffer_length,
146 // int* mismatch_count) {
147 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
148 // if (idx >= buffer_length) return;
149 // float a = buffer_a[idx];
150 // float b = buffer_b[idx];
151 // float rel_error = abs(a - b) / (max(abs(a), abs(b)) + 1);
152 // if (rel_error > rel_error_threshold || isnan(rel_error))
153 // atomicAdd(mismatch_count, 1);
154 // }
155 //
156 // __global__ void __xla_int32_comparison(int* buffer_a, int* buffer_b,
157 // float rel_error_threshold,
158 // unsigned long long buffer_length,
159 // int* mismatch_count) {
160 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
161 // if (idx >= buffer_length) return;
162 // float elem_a = static_cast<float>(buffer_a[idx]);
163 // float elem_b = static_cast<float>(buffer_b[idx]);
164 // float rel_error = abs(elem_a - elem_b)
165 // / (max(abs(elem_a), abs(elem_b)) + 1);
166 // if (rel_error > rel_error_threshold || isnan(rel_error))
167 // atomicAdd(mismatch_count, 1);
168 // }
169 // } // end extern declaration
170 static const char* buffer_compare_ptx = R"(
171 //
172 // Generated by LLVM NVPTX Back-End
173 //
174
175 .version 4.2
176 .target sm_30
177 .address_size 64
178
179 // .globl__xla_fp16_comparison
180
181 .visible .entry __xla_fp16_comparison(
182 .param .u64 __xla_fp16_comparison_param_0,
183 .param .u64 __xla_fp16_comparison_param_1,
184 .param .f32 __xla_fp16_comparison_param_2,
185 .param .u64 __xla_fp16_comparison_param_3,
186 .param .u64 __xla_fp16_comparison_param_4
187 )
188 {
189 .reg .pred %p<10>;
190 .reg .b16 %rs<3>;
191 .reg .f32 %f<20>;
192 .reg .b32 %r<6>;
193 .reg .b64 %rd<12>;
194
195 ld.param.u64 %rd8, [__xla_fp16_comparison_param_3];
196 mov.u32 %r1, %tid.x;
197 mov.u32 %r2, %ctaid.x;
198 mov.u32 %r3, %ntid.x;
199 mad.lo.s32 %r4, %r3, %r2, %r1;
200 cvt.s64.s32 %rd4, %r4;
201 setp.ge.u64 %p1, %rd4, %rd8;
202 @%p1 bra LBB0_4;
203 ld.param.u64 %rd5, [__xla_fp16_comparison_param_0];
204 ld.param.u64 %rd7, [__xla_fp16_comparison_param_1];
205 cvta.to.global.u64 %rd2, %rd7;
206 cvta.to.global.u64 %rd3, %rd5;
207 shl.b64 %rd9, %rd4, 1;
208 add.s64 %rd10, %rd3, %rd9;
209 ld.global.u16 %rs1, [%rd10];
210 // begin inline asm
211 { cvt.f32.f16 %f6, %rs1;}
212
213 // end inline asm
214 add.s64 %rd11, %rd2, %rd9;
215 ld.global.u16 %rs2, [%rd11];
216 // begin inline asm
217 { cvt.f32.f16 %f7, %rs2;}
218
219 // end inline asm
220 abs.f32 %f8, %f6;
221 setp.gtu.f32 %p2, %f8, 0f7F800000;
222 min.f32 %f9, %f6, 0f477FE100;
223 max.f32 %f10, %f9, 0fC77FE100;
224 selp.f32 %f1, %f6, %f10, %p2;
225 abs.f32 %f11, %f7;
226 setp.gtu.f32 %p3, %f11, 0f7F800000;
227 min.f32 %f12, %f7, 0f477FE100;
228 max.f32 %f13, %f12, 0fC77FE100;
229 selp.f32 %f2, %f7, %f13, %p3;
230 abs.f32 %f3, %f1;
231 setp.gtu.f32 %p4, %f3, 0f7F800000;
232 abs.f32 %f4, %f2;
233 setp.gtu.f32 %p5, %f4, 0f7F800000;
234 and.pred %p6, %p4, %p5;
235 @%p6 bra LBB0_4;
236 ld.param.f32 %f5, [__xla_fp16_comparison_param_2];
237 sub.f32 %f14, %f1, %f2;
238 abs.f32 %f15, %f14;
239 max.f32 %f16, %f3, %f4;
240 add.f32 %f17, %f16, 0f3F800000;
241 div.rn.f32 %f18, %f15, %f17;
242 setp.gt.f32 %p7, %f18, %f5;
243 abs.f32 %f19, %f18;
244 setp.gtu.f32 %p8, %f19, 0f7F800000;
245 or.pred %p9, %p7, %p8;
246 @!%p9 bra LBB0_4;
247 bra.uni LBB0_3;
248 LBB0_3:
249 ld.param.u64 %rd6, [__xla_fp16_comparison_param_4];
250 cvta.to.global.u64 %rd1, %rd6;
251 atom.global.add.u32 %r5, [%rd1], 1;
252 LBB0_4:
253 ret;
254
255 }
256 // .globl__xla_fp32_comparison
257 .visible .entry __xla_fp32_comparison(
258 .param .u64 __xla_fp32_comparison_param_0,
259 .param .u64 __xla_fp32_comparison_param_1,
260 .param .f32 __xla_fp32_comparison_param_2,
261 .param .u64 __xla_fp32_comparison_param_3,
262 .param .u64 __xla_fp32_comparison_param_4
263 )
264 {
265 .reg .pred %p<12>;
266 .reg .f32 %f<12>;
267 .reg .b32 %r<9>;
268 .reg .b64 %rd<12>;
269
270 ld.param.u64 %rd8, [__xla_fp32_comparison_param_3];
271 mov.u32 %r1, %tid.x;
272 mov.u32 %r2, %ctaid.x;
273 mov.u32 %r3, %ntid.x;
274 mad.lo.s32 %r4, %r3, %r2, %r1;
275 cvt.s64.s32 %rd4, %r4;
276 setp.ge.u64 %p1, %rd4, %rd8;
277 @%p1 bra LBB1_6;
278 ld.param.u64 %rd5, [__xla_fp32_comparison_param_0];
279 ld.param.u64 %rd7, [__xla_fp32_comparison_param_1];
280 cvta.to.global.u64 %rd2, %rd7;
281 cvta.to.global.u64 %rd3, %rd5;
282 shl.b64 %rd9, %rd4, 2;
283 add.s64 %rd10, %rd3, %rd9;
284 ld.global.f32 %f1, [%rd10];
285 add.s64 %rd11, %rd2, %rd9;
286 ld.global.f32 %f2, [%rd11];
287 abs.f32 %f3, %f1;
288 setp.gtu.f32 %p2, %f3, 0f7F800000;
289 abs.f32 %f4, %f2;
290 setp.gtu.f32 %p3, %f4, 0f7F800000;
291 and.pred %p4, %p2, %p3;
292 @%p4 bra LBB1_6;
293 setp.eq.f32 %p5, %f3, 0f7F800000;
294 setp.eq.f32 %p6, %f4, 0f7F800000;
295 and.pred %p7, %p5, %p6;
296 @!%p7 bra LBB1_4;
297 bra.uni LBB1_3;
298 LBB1_3:
299 mov.b32 %r5, %f1;
300 mov.b32 %r6, %f2;
301 xor.b32 %r7, %r6, %r5;
302 setp.gt.s32 %p8, %r7, -1;
303 @%p8 bra LBB1_6;
304 LBB1_4:
305 ld.param.f32 %f5, [__xla_fp32_comparison_param_2];
306 sub.f32 %f6, %f1, %f2;
307 abs.f32 %f7, %f6;
308 max.f32 %f8, %f3, %f4;
309 add.f32 %f9, %f8, 0f3F800000;
310 div.rn.f32 %f10, %f7, %f9;
311 setp.gt.f32 %p9, %f10, %f5;
312 abs.f32 %f11, %f10;
313 setp.gtu.f32 %p10, %f11, 0f7F800000;
314 or.pred %p11, %p9, %p10;
315 @!%p11 bra LBB1_6;
316 bra.uni LBB1_5;
317 LBB1_5:
318 ld.param.u64 %rd6, [__xla_fp32_comparison_param_4];
319 cvta.to.global.u64 %rd1, %rd6;
320 atom.global.add.u32 %r8, [%rd1], 1;
321 LBB1_6:
322 ret;
323
324 }
325 // .globl__xla_fp64_comparison
326 .visible .entry __xla_fp64_comparison(
327 .param .u64 __xla_fp64_comparison_param_0,
328 .param .u64 __xla_fp64_comparison_param_1,
329 .param .f32 __xla_fp64_comparison_param_2,
330 .param .u64 __xla_fp64_comparison_param_3,
331 .param .u64 __xla_fp64_comparison_param_4
332 )
333 {
334 .reg .pred %p<16>;
335 .reg .f32 %f<2>;
336 .reg .b32 %r<13>;
337 .reg .f64 %fd<12>;
338 .reg .b64 %rd<12>;
339
340 ld.param.u64 %rd8, [__xla_fp64_comparison_param_3];
341 mov.u32 %r2, %tid.x;
342 mov.u32 %r3, %ctaid.x;
343 mov.u32 %r4, %ntid.x;
344 mad.lo.s32 %r5, %r4, %r3, %r2;
345 cvt.s64.s32 %rd4, %r5;
346 setp.ge.u64 %p1, %rd4, %rd8;
347 @%p1 bra LBB2_6;
348 ld.param.u64 %rd5, [__xla_fp64_comparison_param_0];
349 ld.param.u64 %rd7, [__xla_fp64_comparison_param_1];
350 cvta.to.global.u64 %rd2, %rd7;
351 cvta.to.global.u64 %rd3, %rd5;
352 shl.b64 %rd9, %rd4, 3;
353 add.s64 %rd10, %rd3, %rd9;
354 ld.global.f64 %fd1, [%rd10];
355 add.s64 %rd11, %rd2, %rd9;
356 ld.global.f64 %fd2, [%rd11];
357 abs.f64 %fd3, %fd1;
358 setp.gtu.f64 %p2, %fd3, 0d7FF0000000000000;
359 abs.f64 %fd4, %fd2;
360 setp.gtu.f64 %p3, %fd4, 0d7FF0000000000000;
361 and.pred %p4, %p2, %p3;
362 @%p4 bra LBB2_6;
363 {
364 .reg .b32 %temp;
365 mov.b64 {%r6, %temp}, %fd1;
366 }
367 {
368 .reg .b32 %temp;
369 mov.b64 {%temp, %r1}, %fd1;
370 }
371 and.b32 %r7, %r1, 2147483647;
372 setp.eq.s32 %p5, %r7, 2146435072;
373 setp.eq.s32 %p6, %r6, 0;
374 and.pred %p7, %p5, %p6;
375 @!%p7 bra LBB2_4;
376 bra.uni LBB2_3;
377 LBB2_3:
378 {
379 .reg .b32 %temp;
380 mov.b64 {%r8, %temp}, %fd2;
381 }
382 {
383 .reg .b32 %temp;
384 mov.b64 {%temp, %r9}, %fd2;
385 }
386 and.b32 %r10, %r9, 2147483647;
387 setp.eq.s32 %p8, %r10, 2146435072;
388 setp.eq.s32 %p9, %r8, 0;
389 and.pred %p10, %p8, %p9;
390 xor.b32 %r11, %r9, %r1;
391 setp.gt.s32 %p11, %r11, -1;
392 and.pred %p12, %p10, %p11;
393 @%p12 bra LBB2_6;
394 LBB2_4:
395 ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
396 sub.f64 %fd5, %fd1, %fd2;
397 abs.f64 %fd6, %fd5;
398 max.f64 %fd7, %fd3, %fd4;
399 add.f64 %fd8, %fd7, 0d3FF0000000000000;
400 div.rn.f64 %fd9, %fd6, %fd8;
401 cvt.f64.f32 %fd10, %f1;
402 setp.gt.f64 %p13, %fd9, %fd10;
403 abs.f64 %fd11, %fd9;
404 setp.gtu.f64 %p14, %fd11, 0d7FF0000000000000;
405 or.pred %p15, %p13, %p14;
406 @!%p15 bra LBB2_6;
407 bra.uni LBB2_5;
408 LBB2_5:
409 ld.param.u64 %rd6, [__xla_fp64_comparison_param_4];
410 cvta.to.global.u64 %rd1, %rd6;
411 atom.global.add.u32 %r12, [%rd1], 1;
412 LBB2_6:
413 ret;
414
415 }
416 // .globl__xla_bf16_comparison
417 .visible .entry __xla_bf16_comparison(
418 .param .u64 __xla_bf16_comparison_param_0,
419 .param .u64 __xla_bf16_comparison_param_1,
420 .param .f32 __xla_bf16_comparison_param_2,
421 .param .u64 __xla_bf16_comparison_param_3,
422 .param .u64 __xla_bf16_comparison_param_4
423 )
424 {
425 .reg .pred %p<10>;
426 .reg .b16 %rs<3>;
427 .reg .f32 %f<20>;
428 .reg .b32 %r<6>;
429 .reg .b64 %rd<12>;
430
431 ld.param.u64 %rd8, [__xla_bf16_comparison_param_3];
432 mov.u32 %r1, %tid.x;
433 mov.u32 %r2, %ctaid.x;
434 mov.u32 %r3, %ntid.x;
435 mad.lo.s32 %r4, %r3, %r2, %r1;
436 cvt.s64.s32 %rd4, %r4;
437 setp.ge.u64 %p1, %rd4, %rd8;
438 @%p1 bra LBB3_4;
439 ld.param.u64 %rd5, [__xla_bf16_comparison_param_0];
440 ld.param.u64 %rd7, [__xla_bf16_comparison_param_1];
441 cvta.to.global.u64 %rd2, %rd7;
442 cvta.to.global.u64 %rd3, %rd5;
443 shl.b64 %rd9, %rd4, 1;
444 add.s64 %rd10, %rd3, %rd9;
445 ld.global.u16 %rs1, [%rd10];
446 // begin inline asm
447 { mov.b32 %f6, {0,%rs1};}
448
449 // end inline asm
450 add.s64 %rd11, %rd2, %rd9;
451 ld.global.u16 %rs2, [%rd11];
452 // begin inline asm
453 { mov.b32 %f7, {0,%rs2};}
454
455 // end inline asm
456 abs.f32 %f8, %f6;
457 setp.gtu.f32 %p2, %f8, 0f7F800000;
458 min.f32 %f9, %f6, 0f477FE100;
459 max.f32 %f10, %f9, 0fC77FE100;
460 selp.f32 %f1, %f6, %f10, %p2;
461 abs.f32 %f11, %f7;
462 setp.gtu.f32 %p3, %f11, 0f7F800000;
463 min.f32 %f12, %f7, 0f477FE100;
464 max.f32 %f13, %f12, 0fC77FE100;
465 selp.f32 %f2, %f7, %f13, %p3;
466 abs.f32 %f3, %f1;
467 setp.gtu.f32 %p4, %f3, 0f7F800000;
468 abs.f32 %f4, %f2;
469 setp.gtu.f32 %p5, %f4, 0f7F800000;
470 and.pred %p6, %p4, %p5;
471 @%p6 bra LBB3_4;
472 ld.param.f32 %f5, [__xla_bf16_comparison_param_2];
473 sub.f32 %f14, %f1, %f2;
474 abs.f32 %f15, %f14;
475 max.f32 %f16, %f3, %f4;
476 add.f32 %f17, %f16, 0f3F800000;
477 div.rn.f32 %f18, %f15, %f17;
478 setp.gt.f32 %p7, %f18, %f5;
479 abs.f32 %f19, %f18;
480 setp.gtu.f32 %p8, %f19, 0f7F800000;
481 or.pred %p9, %p7, %p8;
482 @!%p9 bra LBB3_4;
483 bra.uni LBB3_3;
484 LBB3_3:
485 ld.param.u64 %rd6, [__xla_bf16_comparison_param_4];
486 cvta.to.global.u64 %rd1, %rd6;
487 atom.global.add.u32 %r5, [%rd1], 1;
488 LBB3_4:
489 ret;
490
491 }
492 // .globl__xla_int8_comparison
493 .visible .entry __xla_int8_comparison(
494 .param .u64 __xla_int8_comparison_param_0,
495 .param .u64 __xla_int8_comparison_param_1,
496 .param .f32 __xla_int8_comparison_param_2,
497 .param .u64 __xla_int8_comparison_param_3,
498 .param .u64 __xla_int8_comparison_param_4
499 )
500 {
501 .reg .pred %p<5>;
502 .reg .f32 %f<12>;
503 .reg .b32 %r<8>;
504 .reg .b64 %rd<11>;
505
506 ld.param.u64 %rd8, [__xla_int8_comparison_param_3];
507 mov.u32 %r1, %tid.x;
508 mov.u32 %r2, %ctaid.x;
509 mov.u32 %r3, %ntid.x;
510 mad.lo.s32 %r4, %r3, %r2, %r1;
511 cvt.s64.s32 %rd4, %r4;
512 setp.ge.u64 %p1, %rd4, %rd8;
513 @%p1 bra LBB7_3;
514 ld.param.f32 %f1, [__xla_int8_comparison_param_2];
515 ld.param.u64 %rd5, [__xla_int8_comparison_param_0];
516 ld.param.u64 %rd7, [__xla_int8_comparison_param_1];
517 cvta.to.global.u64 %rd2, %rd7;
518 cvta.to.global.u64 %rd3, %rd5;
519 add.s64 %rd9, %rd3, %rd4;
520 ld.global.s8 %r5, [%rd9];
521 add.s64 %rd10, %rd2, %rd4;
522 ld.global.s8 %r6, [%rd10];
523 cvt.rn.f32.s32 %f2, %r5;
524 cvt.rn.f32.s32 %f3, %r6;
525 sub.f32 %f4, %f2, %f3;
526 abs.f32 %f5, %f4;
527 abs.f32 %f6, %f2;
528 abs.f32 %f7, %f3;
529 max.f32 %f8, %f6, %f7;
530 add.f32 %f9, %f8, 0f3F800000;
531 div.rn.f32 %f10, %f5, %f9;
532 setp.leu.f32 %p2, %f10, %f1;
533 abs.f32 %f11, %f10;
534 setp.le.f32 %p3, %f11, 0f7F800000;
535 and.pred %p4, %p2, %p3;
536 @%p4 bra LBB7_3;
537 ld.param.u64 %rd6, [__xla_int8_comparison_param_4];
538 cvta.to.global.u64 %rd1, %rd6;
539 atom.global.add.u32 %r7, [%rd1], 1;
540 LBB7_3:
541 ret;
542 }
543
544 // .globl__xla_int32_comparison
545 .visible .entry __xla_int32_comparison(
546 .param .u64 __xla_int32_comparison_param_0,
547 .param .u64 __xla_int32_comparison_param_1,
548 .param .f32 __xla_int32_comparison_param_2,
549 .param .u64 __xla_int32_comparison_param_3,
550 .param .u64 __xla_int32_comparison_param_4
551 )
552 {
553 .reg .pred %p<5>;
554 .reg .f32 %f<12>;
555 .reg .b32 %r<8>;
556 .reg .b64 %rd<12>;
557
558 ld.param.u64 %rd8, [__xla_int32_comparison_param_3];
559 mov.u32 %r1, %tid.x;
560 mov.u32 %r2, %ctaid.x;
561 mov.u32 %r3, %ntid.x;
562 mad.lo.s32 %r4, %r3, %r2, %r1;
563 cvt.s64.s32 %rd4, %r4;
564 setp.ge.u64 %p1, %rd4, %rd8;
565 @%p1 bra LBB5_3;
566 ld.param.f32 %f1, [__xla_int32_comparison_param_2];
567 ld.param.u64 %rd5, [__xla_int32_comparison_param_0];
568 ld.param.u64 %rd7, [__xla_int32_comparison_param_1];
569 cvta.to.global.u64 %rd2, %rd7;
570 cvta.to.global.u64 %rd3, %rd5;
571 shl.b64 %rd9, %rd4, 2;
572 add.s64 %rd10, %rd3, %rd9;
573 ld.global.u32 %r5, [%rd10];
574 cvt.rn.f32.s32 %f2, %r5;
575 add.s64 %rd11, %rd2, %rd9;
576 ld.global.u32 %r6, [%rd11];
577 cvt.rn.f32.s32 %f3, %r6;
578 sub.f32 %f4, %f2, %f3;
579 abs.f32 %f5, %f4;
580 abs.f32 %f6, %f2;
581 abs.f32 %f7, %f3;
582 max.f32 %f8, %f6, %f7;
583 add.f32 %f9, %f8, 0f3F800000;
584 div.rn.f32 %f10, %f5, %f9;
585 setp.gt.f32 %p2, %f10, %f1;
586 abs.f32 %f11, %f10;
587 setp.gtu.f32 %p3, %f11, 0f7F800000;
588 or.pred %p4, %p2, %p3;
589 @!%p4 bra LBB5_3;
590 bra.uni LBB5_2;
591 LBB5_2:
592 ld.param.u64 %rd6, [__xla_int32_comparison_param_4];
593 cvta.to.global.u64 %rd1, %rd6;
594 atom.global.add.u32 %r7, [%rd1], 1;
595 LBB5_3:
596 ret;
597
598 }
599 )";
600
// Signature of the comparison kernels in `buffer_compare_ptx`:
// (buffer_a, buffer_b, rel_error_threshold, buffer_length, mismatch_count).
// The last parameter receives the number of mismatching elements; it is
// zero-initialized by the caller before launch.
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64_t, se::DeviceMemory<uint64_t>>;
605
// Compares two buffers on the GPU, element-wise, using the PTX comparison
// kernel named by `kernel_name` (which must match ElementT).
//
// Returns `true` if the two buffers compare equal within kTolerance,
// `false` otherwise. Errors (size mismatch, kernel creation/launch failure)
// are returned as non-OK status.
template <typename ElementT>
static StatusOr<bool> DeviceCompare(se::Stream* stream,
                                    se::DeviceMemoryBase lhs,
                                    se::DeviceMemoryBase rhs,
                                    const Shape& buffer_shape,
                                    const HloModuleConfig& config,
                                    absl::string_view kernel_name) {
  se::StreamExecutor* executor = stream->parent();

  // Device scalar that the kernel atomically increments once per mismatch.
  se::ScopedDeviceMemory<uint64_t> out_param =
      executor->AllocateOwnedScalar<uint64_t>();

  // Must be zeroed before launch; the kernel only ever adds to it.
  stream->ThenMemZero(out_param.ptr(), sizeof(uint64_t));
  if (lhs.size() != rhs.size()) {
    return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
                         lhs.size(), rhs.size());
  }

  se::DeviceMemory<ElementT> lhs_typed(lhs);
  se::DeviceMemory<ElementT> rhs_typed(rhs);
  uint64_t buffer_size = lhs_typed.ElementCount();

  // Try to compile the PTX with ptxas; on failure fall back to passing the
  // raw PTX through to the driver (compiled_ptx stays empty).
  absl::Span<const uint8_t> compiled_ptx = {};
  StatusOr<absl::Span<const uint8_t>> compiled_ptx_or =
      se::CompileGpuAsmOrGetCached(
          executor->device_ordinal(), buffer_compare_ptx,
          PtxOptsFromDebugOptions(config.debug_options()));
  if (compiled_ptx_or.ok()) {
    compiled_ptx = std::move(compiled_ptx_or).value();
  } else {
    // Log the ptxas problem only once per process to avoid log spam.
    static absl::once_flag ptxas_not_found_logged;
    absl::call_once(ptxas_not_found_logged, [&]() {
      LOG(WARNING)
          << compiled_ptx_or.status().ToString()
          << "\nRelying on driver to perform ptx compilation. "
          << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
          << " or modifying $PATH can be used to set the location of ptxas"
          << "\nThis message will only be logged once.";
    });
  }

  // Look up the named entry point in the (compiled or raw) PTX module.
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
      (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
                                   se::DeviceMemory<ElementT>, float, uint64_t,
                                   se::DeviceMemory<uint64_t>>(
          kernel_name, buffer_compare_ptx, compiled_ptx)));

  // Gather the device limits needed to size the launch grid.
  GpuDeviceInfo gpu_device_info;
  gpu_device_info.threads_per_block_limit =
      executor->GetDeviceDescription().threads_per_block_limit();
  gpu_device_info.threads_per_warp =
      executor->GetDeviceDescription().threads_per_warp();
  gpu_device_info.shared_memory_per_block =
      executor->GetDeviceDescription().shared_memory_per_block();
  gpu_device_info.threads_per_core_limit =
      executor->GetDeviceDescription().threads_per_core_limit();
  gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
  gpu_device_info.block_dim_limit_x =
      executor->GetDeviceDescription().block_dim_limit().x;
  gpu_device_info.block_dim_limit_y =
      executor->GetDeviceDescription().block_dim_limit().y;
  gpu_device_info.block_dim_limit_z =
      executor->GetDeviceDescription().block_dim_limit().z;

  TF_ASSIGN_OR_RETURN(LaunchDimensions dim,
                      CalculateLaunchDimensions(buffer_shape, gpu_device_info));

  LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dim.block_counts();
  TF_RETURN_IF_ERROR(stream->ThenLaunch(
      se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
      *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
      buffer_size, out_param.cref()));

  // Copy the mismatch count back and wait for the stream to drain before
  // reading it on the host. The sentinel -1 (wraps to UINT64_MAX) would be
  // visible if the copy somehow didn't happen.
  uint64_t result = -1;
  CHECK_EQ(out_param->size(), sizeof(result));
  stream->ThenMemcpy(&result, *out_param, sizeof(result));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return result == 0;
}
691
// Host side comparison code that does the same thing as the device kernels,
// but reports (up to 10 of) the differing elements as well. It only prints
// logs for debugging.
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
                           se::DeviceMemoryBase rhs) {
  // Copy both buffers to the host and wait for the copies to land.
  int64_t n = lhs.size() / sizeof(ElementType);
  std::vector<ElementType> host_lhs(n), host_rhs(n);
  stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
  stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  // For fp16, clamp finite values to +/-65505 (mirroring
  // __xla_buffer_comparator_canonicalize in the PTX) so that overflow to
  // infinity on one side doesn't register as a mismatch. NaN passes through.
  // The `&& a` guard skips the clamp for zero, which clamps to itself anyway.
  const auto canonicalize = [](ComparisonType a) -> ComparisonType {
    if (std::is_same<ElementType, Eigen::half>::value && a) {
      constexpr ComparisonType kMaxFp16Value = 65505;
      if (std::isnan(a)) {
        return a;
      }
      return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
    }
    return a;
  };
  int differences_seen = 0;
  // Stop after 10 reported differences to bound log volume.
  for (int64_t i = 0; i < n && differences_seen < 10; i++) {
    auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
    auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
    ComparisonType lhs = canonicalize(original_lhs);
    ComparisonType rhs = canonicalize(original_rhs);
    // NaN == NaN and same-signed infinities count as equal, as on the device.
    if (std::isnan(lhs) && std::isnan(rhs)) {
      continue;
    }
    if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
      continue;
    }
    // Mismatch when finiteness differs, or when the relative error is not
    // provably below kTolerance. The negated `<` form also flags a NaN
    // relative error (e.g. inf - inf), matching the kernels' isnan check.
    if (std::isfinite(lhs) != std::isfinite(rhs) ||
        !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
          kTolerance)) {
      differences_seen++;
      LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
                 << original_rhs;
    }
  }
  return differences_seen == 0;
}
737
738 template <typename ElementT, typename ComparisonT>
CompareEqualParameterized(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & shape,const HloModuleConfig & config,absl::string_view kernel_name)739 static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
740 se::DeviceMemoryBase lhs,
741 se::DeviceMemoryBase rhs,
742 const Shape& shape,
743 const HloModuleConfig& config,
744 absl::string_view kernel_name) {
745 XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
746 TF_ASSIGN_OR_RETURN(
747 bool result,
748 DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));
749
750 if (result) {
751 return true;
752 }
753
754 TF_ASSIGN_OR_RETURN(bool host_return,
755 (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
756 CHECK_EQ(host_return, result)
757 << "Host comparison succeeded even though GPU comparison failed.";
758
759 return false;
760 }
761
CompareEqual(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs) const762 StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
763 se::DeviceMemoryBase lhs,
764 se::DeviceMemoryBase rhs) const {
765 switch (shape_.element_type()) {
766 case xla::F16:
767 return CompareEqualParameterized<Eigen::half, float>(
768 stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
769 case xla::BF16:
770 return CompareEqualParameterized<Eigen::bfloat16, float>(
771 stream, lhs, rhs, shape_, config_, "__xla_bf16_comparison");
772 case xla::F32:
773 return CompareEqualParameterized<float, float>(
774 stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
775 case xla::F64:
776 return CompareEqualParameterized<double, double>(
777 stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
778 case xla::S8:
779 return CompareEqualParameterized<int8_t, float>(
780 stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
781 case xla::S32:
782 return CompareEqualParameterized<int32_t, float>(
783 stream, lhs, rhs, shape_, config_, "__xla_int32_comparison");
784 default:
785 return Unimplemented("Unimplemented element type");
786 }
787 }
788
BufferComparator(const Shape & shape,const HloModuleConfig & config)789 BufferComparator::BufferComparator(const Shape& shape,
790 const HloModuleConfig& config)
791 : shape_(shape), config_(config) {
792 // Normalize complex shapes: since we treat the passed array as a contiguous
793 // storage it does not matter which dimension are we doubling.
794 auto double_dim_size = [&]() {
795 int64_t prev_zero_dim_size = shape_.dimensions(0);
796 shape_.set_dimensions(0, prev_zero_dim_size * 2);
797 };
798
799 if (shape_.element_type() == PrimitiveType::C64) {
800 // C64 is just two F32s next to each other.
801 shape_.set_element_type(PrimitiveType::F32);
802 double_dim_size();
803 } else if (shape_.element_type() == PrimitiveType::C128) {
804 // C128 is just two F64s next to each other.
805 shape_.set_element_type(PrimitiveType::F64);
806 double_dim_size();
807 }
808 }
809
810 } // namespace gpu
811 } // namespace xla
812