1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17
18 #include <algorithm>
19 #include <cmath>
20
21 #include "absl/base/call_once.h"
22 #include "absl/strings/str_replace.h"
23 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
24 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/stream_executor/device_memory.h"
30 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
31 #include "tensorflow/stream_executor/kernel.h"
32 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
33
34 namespace xla {
35 namespace gpu {
36
// Relative error threshold used both by the device comparison kernels (passed
// as their rel_error_threshold argument, cast to float) and by the host-side
// fallback comparison in HostCompare. Note: initialized from a float literal.
static constexpr double kTolerance = 0.1f;
38
39 // Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
40 // buffer_length where the relative error does not exceed the passed
41 // rel_error_threshold. Write the number of mismatches into out parameter
42 // mismatch_count.
43 //
// NaNs are considered equal, and for fp16 we clamp all numbers to the largest
// and smallest representable values to avoid miscomparisons due to overflows.
46 //
47 // The PTX below is compiled from the following CUDA code:
48 //
49 // #include<cuda_fp16.h>
50 // extern "C" { // avoid name mangling
51 // __device__ float __xla_buffer_comparator_canonicalize(float input) {
52 // // All fp16 infinities are treated as 65505 or -65505, in order to avoid
53 // // differences due to overflows.
54 // return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
55 // }
56
57 // __device__ float __xla_buffer_comparator_extract_int8(int pack) {
58 // // Extract the lower 8 bits from pack and convert it to float
59 // const unsigned int bit_mask = 0xff;
60 // unsigned int bits = pack & bit_mask;
61 // char* int8_ptr = (char*)&bits;
62 // return __int2float_rn(*int8_ptr);
63 // }
64
65 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
66 // float rel_error_threshold,
67 // unsigned long long buffer_length,
68 // int* mismatch_count) {
69 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
70 // if (idx >= buffer_length) return;
71 // float elem_a = __half2float(buffer_a[idx]);
72 // float elem_b = __half2float(buffer_b[idx]);
73 // elem_a = __xla_buffer_comparator_canonicalize(elem_a);
74 // elem_b = __xla_buffer_comparator_canonicalize(elem_b);
75 // if (isnan(elem_a) && isnan(elem_b)) return;
76 // float rel_error = abs(elem_a - elem_b)
77 // / (max(abs(elem_a), abs(elem_b)) + 1);
78 // if (rel_error > rel_error_threshold || isnan(rel_error))
79 // atomicAdd(mismatch_count, 1);
80 // }
81
82 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
83 // float rel_error_threshold,
84 // unsigned long long buffer_length,
85 // int* mismatch_count) {
86 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
87 // if (idx >= buffer_length) return;
88 // float elem_a = buffer_a[idx];
89 // float elem_b = buffer_b[idx];
90 // if (isnan(elem_a) && isnan(elem_b)) return;
91 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
92 // return;
93 // float rel_error = abs(elem_a - elem_b)
94 // / (max(abs(elem_a), abs(elem_b)) + 1);
95 // if (rel_error > rel_error_threshold || isnan(rel_error))
96 // atomicAdd(mismatch_count, 1);
97 // }
98
99 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
100 // float rel_error_threshold,
101 // unsigned long long buffer_length,
102 // int* mismatch_count) {
103 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
104 // if (idx >= buffer_length) return;
105 // double elem_a = buffer_a[idx];
106 // double elem_b = buffer_b[idx];
107 // if (isnan(elem_a) && isnan(elem_b)) return;
108 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
109 // return;
110 // double rel_error = abs(elem_a - elem_b)
111 // / (max(abs(elem_a), abs(elem_b)) + 1);
112 // if (rel_error > rel_error_threshold || isnan(rel_error))
113 // atomicAdd(mismatch_count, 1);
114 // }
115
116 // __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
117 // float rel_error_threshold,
118 // unsigned long long buffer_length,
119 // int* mismatch_count) {
120 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
121 // if (idx >= buffer_length) return;
122 // int pack_a = buffer_a[idx];
123 // int pack_b = buffer_b[idx];
124 // for(int i = 0; i < 4; ++i) {
125 // float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
126 // float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
127 // float rel_error = abs(elem_a - elem_b)
128 // / (max(abs(elem_a), abs(elem_b)) + 1);
129 // if (rel_error > rel_error_threshold || isnan(rel_error))
130 // atomicAdd(mismatch_count, 1);
131 // pack_a >>= 8;
132 // pack_b >>= 8;
133 // }
134 // }
135 // } // end extern declaration.
// PTX assembly implementing the four comparison kernels, compiled from the
// CUDA source shown in the comment above. Keep the two in sync: any change to
// the CUDA source must be recompiled and this string replaced to match.
static const char* buffer_compare_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

  // .globl  __xla_fp16_comparison

.visible .entry __xla_fp16_comparison(
  .param .u64 __xla_fp16_comparison_param_0,
  .param .u64 __xla_fp16_comparison_param_1,
  .param .f32 __xla_fp16_comparison_param_2,
  .param .u64 __xla_fp16_comparison_param_3,
  .param .u64 __xla_fp16_comparison_param_4
)
{
  .reg .pred %p<9>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<28>;
  .reg .b32 %r<6>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp16_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp16_comparison_param_1];
  ld.param.f32 %f10, [__xla_fp16_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp16_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp16_comparison_param_4];
  mov.u32 %r2, %ntid.x;
  mov.u32 %r3, %ctaid.x;
  mov.u32 %r4, %tid.x;
  mad.lo.s32 %r1, %r2, %r3, %r4;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB0_9;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 2;
  add.s64 %rd8, %rd6, %rd7;
  ld.global.u16 %rs1, [%rd8];
  // inline asm
  { cvt.f32.f16 %f26, %rs1;}

  // inline asm
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.u16 %rs2, [%rd10];
  // inline asm
  { cvt.f32.f16 %f27, %rs2;}

  // inline asm
  abs.f32 %f13, %f26;
  setp.gtu.f32 %p2, %f13, 0f7F800000;
  @%p2 bra BB0_3;

  mov.f32 %f14, 0f477FE100;
  min.f32 %f15, %f26, %f14;
  mov.f32 %f16, 0fC77FE100;
  max.f32 %f26, %f16, %f15;

BB0_3:
  abs.f32 %f17, %f27;
  setp.gtu.f32 %p3, %f17, 0f7F800000;
  @%p3 bra BB0_5;

  mov.f32 %f18, 0f477FE100;
  min.f32 %f19, %f27, %f18;
  mov.f32 %f20, 0fC77FE100;
  max.f32 %f27, %f20, %f19;

BB0_5:
  abs.f32 %f7, %f26;
  setp.gtu.f32 %p4, %f7, 0f7F800000;
  abs.f32 %f8, %f27;
  setp.gtu.f32 %p5, %f8, 0f7F800000;
  and.pred %p6, %p4, %p5;
  @%p6 bra BB0_9;

  sub.f32 %f21, %f26, %f27;
  abs.f32 %f22, %f21;
  max.f32 %f23, %f7, %f8;
  add.f32 %f24, %f23, 0f3F800000;
  div.rn.f32 %f9, %f22, %f24;
  setp.gt.f32 %p7, %f9, %f10;
  @%p7 bra BB0_8;

  abs.f32 %f25, %f9;
  setp.le.f32 %p8, %f25, 0f7F800000;
  @%p8 bra BB0_9;

BB0_8:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r5, [%rd11], 1;

BB0_9:
  ret;
}

  // .globl  __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
  .param .u64 __xla_fp32_comparison_param_0,
  .param .u64 __xla_fp32_comparison_param_1,
  .param .f32 __xla_fp32_comparison_param_2,
  .param .u64 __xla_fp32_comparison_param_3,
  .param .u64 __xla_fp32_comparison_param_4
)
{
  .reg .pred %p<10>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<13>;
  .reg .b32 %r<10>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp32_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp32_comparison_param_1];
  ld.param.f32 %f6, [__xla_fp32_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp32_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp32_comparison_param_4];
  mov.u32 %r2, %ntid.x;
  mov.u32 %r3, %ctaid.x;
  mov.u32 %r4, %tid.x;
  mad.lo.s32 %r1, %r2, %r3, %r4;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB1_8;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 4;
  add.s64 %rd8, %rd6, %rd7;
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.f32 %f1, [%rd10];
  ld.global.f32 %f2, [%rd8];
  abs.f32 %f3, %f2;
  setp.le.f32 %p2, %f3, 0f7F800000;
  @%p2 bra BB1_3;

  abs.f32 %f7, %f1;
  setp.gtu.f32 %p3, %f7, 0f7F800000;
  @%p3 bra BB1_8;

BB1_3:
  setp.neu.f32 %p4, %f3, 0f7F800000;
  abs.f32 %f4, %f1;
  setp.neu.f32 %p5, %f4, 0f7F800000;
  or.pred %p6, %p4, %p5;
  @%p6 bra BB1_5;

  mov.b32 %r5, %f2;
  shr.u32 %r6, %r5, 31;
  cvt.u16.u32 %rs1, %r6;
  mov.b32 %r7, %f1;
  shr.u32 %r8, %r7, 31;
  cvt.u16.u32 %rs2, %r8;
  setp.eq.s16 %p7, %rs1, %rs2;
  @%p7 bra BB1_8;

BB1_5:
  sub.f32 %f8, %f2, %f1;
  abs.f32 %f9, %f8;
  max.f32 %f10, %f3, %f4;
  add.f32 %f11, %f10, 0f3F800000;
  div.rn.f32 %f5, %f9, %f11;
  setp.gt.f32 %p8, %f5, %f6;
  @%p8 bra BB1_7;

  abs.f32 %f12, %f5;
  setp.le.f32 %p9, %f12, 0f7F800000;
  @%p9 bra BB1_8;

BB1_7:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r9, [%rd11], 1;

BB1_8:
  ret;
}

  // .globl  __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
  .param .u64 __xla_fp64_comparison_param_0,
  .param .u64 __xla_fp64_comparison_param_1,
  .param .f32 __xla_fp64_comparison_param_2,
  .param .u64 __xla_fp64_comparison_param_3,
  .param .u64 __xla_fp64_comparison_param_4
)
{
  .reg .pred %p<11>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<2>;
  .reg .b32 %r<14>;
  .reg .f64 %fd<13>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp64_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp64_comparison_param_1];
  ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp64_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp64_comparison_param_4];
  mov.u32 %r4, %ntid.x;
  mov.u32 %r5, %ctaid.x;
  mov.u32 %r6, %tid.x;
  mad.lo.s32 %r1, %r4, %r5, %r6;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB2_11;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 8;
  add.s64 %rd8, %rd6, %rd7;
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.f64 %fd1, [%rd10];
  ld.global.f64 %fd2, [%rd8];
  abs.f64 %fd3, %fd2;
  setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
  @%p2 bra BB2_3;

  abs.f64 %fd5, %fd1;
  setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
  @%p3 bra BB2_11;

BB2_3:
  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r2}, %fd2;
  }
  and.b32 %r7, %r2, 2147483647;
  setp.ne.s32 %p4, %r7, 2146435072;
  @%p4 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%r8, %temp}, %fd2;
  }
  setp.ne.s32 %p5, %r8, 0;
  @%p5 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r3}, %fd1;
  }
  and.b32 %r9, %r3, 2147483647;
  setp.ne.s32 %p6, %r9, 2146435072;
  @%p6 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%r10, %temp}, %fd1;
  }
  setp.ne.s32 %p7, %r10, 0;
  @%p7 bra BB2_8;

  shr.u32 %r11, %r2, 31;
  cvt.u16.u32 %rs1, %r11;
  shr.u32 %r12, %r3, 31;
  cvt.u16.u32 %rs2, %r12;
  setp.eq.s16 %p8, %rs1, %rs2;
  @%p8 bra BB2_11;

BB2_8:
  sub.f64 %fd6, %fd2, %fd1;
  abs.f64 %fd7, %fd6;
  abs.f64 %fd8, %fd1;
  max.f64 %fd9, %fd3, %fd8;
  add.f64 %fd10, %fd9, 0d3FF0000000000000;
  div.rn.f64 %fd4, %fd7, %fd10;
  cvt.f64.f32 %fd11, %f1;
  setp.gt.f64 %p9, %fd4, %fd11;
  @%p9 bra BB2_10;

  abs.f64 %fd12, %fd4;
  setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
  @%p10 bra BB2_11;

BB2_10:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r13, [%rd11], 1;

BB2_11:
  ret;
}

  // .globl  __xla_int8_comparison
.visible .entry __xla_int8_comparison(
  .param .u64 __xla_int8_comparison_param_0,
  .param .u64 __xla_int8_comparison_param_1,
  .param .f32 __xla_int8_comparison_param_2,
  .param .u64 __xla_int8_comparison_param_3,
  .param .u64 __xla_int8_comparison_param_4
)
{
  .reg .pred %p<10>;
  .reg .f32 %f<42>;
  .reg .b32 %r<23>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd2, [__xla_int8_comparison_param_0];
  ld.param.u64 %rd3, [__xla_int8_comparison_param_1];
  ld.param.f32 %f5, [__xla_int8_comparison_param_2];
  ld.param.u64 %rd4, [__xla_int8_comparison_param_3];
  ld.param.u64 %rd5, [__xla_int8_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd5;
  mov.u32 %r4, %ntid.x;
  mov.u32 %r5, %ctaid.x;
  mov.u32 %r6, %tid.x;
  mad.lo.s32 %r1, %r4, %r5, %r6;
  cvt.s64.s32 %rd6, %r1;
  setp.ge.u64 %p1, %rd6, %rd4;
  @%p1 bra BB3_13;

  cvta.to.global.u64 %rd7, %rd2;
  mul.wide.s32 %rd8, %r1, 4;
  add.s64 %rd9, %rd7, %rd8;
  cvta.to.global.u64 %rd10, %rd3;
  add.s64 %rd11, %rd10, %rd8;
  ld.global.u32 %r2, [%rd9];
  cvt.s32.s8 %r7, %r2;
  cvt.rn.f32.s32 %f6, %r7;
  ld.global.u32 %r3, [%rd11];
  cvt.s32.s8 %r8, %r3;
  cvt.rn.f32.s32 %f7, %r8;
  sub.f32 %f8, %f6, %f7;
  abs.f32 %f9, %f8;
  abs.f32 %f10, %f6;
  abs.f32 %f11, %f7;
  max.f32 %f12, %f10, %f11;
  add.f32 %f13, %f12, 0f3F800000;
  div.rn.f32 %f1, %f9, %f13;
  setp.gt.f32 %p2, %f1, %f5;
  @%p2 bra BB3_3;

  abs.f32 %f14, %f1;
  setp.le.f32 %p3, %f14, 0f7F800000;
  @%p3 bra BB3_4;

BB3_3:
  atom.global.add.u32 %r9, [%rd1], 1;

BB3_4:
  shr.u32 %r10, %r3, 8;
  shr.u32 %r11, %r2, 8;
  cvt.s32.s8 %r12, %r11;
  cvt.rn.f32.s32 %f15, %r12;
  cvt.s32.s8 %r13, %r10;
  cvt.rn.f32.s32 %f16, %r13;
  sub.f32 %f17, %f15, %f16;
  abs.f32 %f18, %f17;
  abs.f32 %f19, %f15;
  abs.f32 %f20, %f16;
  max.f32 %f21, %f19, %f20;
  add.f32 %f22, %f21, 0f3F800000;
  div.rn.f32 %f2, %f18, %f22;
  setp.gt.f32 %p4, %f2, %f5;
  @%p4 bra BB3_6;

  abs.f32 %f23, %f2;
  setp.le.f32 %p5, %f23, 0f7F800000;
  @%p5 bra BB3_7;

BB3_6:
  atom.global.add.u32 %r14, [%rd1], 1;

BB3_7:
  shr.u32 %r15, %r3, 16;
  shr.u32 %r16, %r2, 16;
  cvt.s32.s8 %r17, %r16;
  cvt.rn.f32.s32 %f24, %r17;
  cvt.s32.s8 %r18, %r15;
  cvt.rn.f32.s32 %f25, %r18;
  sub.f32 %f26, %f24, %f25;
  abs.f32 %f27, %f26;
  abs.f32 %f28, %f24;
  abs.f32 %f29, %f25;
  max.f32 %f30, %f28, %f29;
  add.f32 %f31, %f30, 0f3F800000;
  div.rn.f32 %f3, %f27, %f31;
  setp.gt.f32 %p6, %f3, %f5;
  @%p6 bra BB3_9;

  abs.f32 %f32, %f3;
  setp.le.f32 %p7, %f32, 0f7F800000;
  @%p7 bra BB3_10;

BB3_9:
  atom.global.add.u32 %r19, [%rd1], 1;

BB3_10:
  shr.s32 %r20, %r2, 24;
  cvt.rn.f32.s32 %f33, %r20;
  shr.s32 %r21, %r3, 24;
  cvt.rn.f32.s32 %f34, %r21;
  sub.f32 %f35, %f33, %f34;
  abs.f32 %f36, %f35;
  abs.f32 %f37, %f33;
  abs.f32 %f38, %f34;
  max.f32 %f39, %f37, %f38;
  add.f32 %f40, %f39, 0f3F800000;
  div.rn.f32 %f4, %f36, %f40;
  setp.gt.f32 %p8, %f4, %f5;
  @%p8 bra BB3_12;

  abs.f32 %f41, %f4;
  setp.le.f32 %p9, %f41, 0f7F800000;
  @%p9 bra BB3_13;

BB3_12:
  atom.global.add.u32 %r22, [%rd1], 1;

BB3_13:
  ret;
}
)";
551
// Typed-kernel signature shared by all four comparison kernels in
// buffer_compare_ptx: (buffer_a, buffer_b, rel_error_threshold,
// buffer_length, mismatch_count). The mismatch counter lives in a
// device scalar that the host zeroes before launch and reads back after.
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64, se::DeviceMemory<uint64>>;
556
557 // Compares two buffers on the GPU.
558 //
559 // Returns `true` if two buffers are equal, `false` otherwise.
560 template <typename ElementT>
DeviceCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & buffer_shape,const HloModuleConfig & config,absl::string_view kernel_name)561 static StatusOr<bool> DeviceCompare(se::Stream* stream,
562 se::DeviceMemoryBase lhs,
563 se::DeviceMemoryBase rhs,
564 const Shape& buffer_shape,
565 const HloModuleConfig& config,
566 absl::string_view kernel_name) {
567 se::StreamExecutor* executor = stream->parent();
568
569 se::ScopedDeviceMemory<uint64> out_param =
570 executor->AllocateOwnedScalar<uint64>();
571
572 stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
573 if (lhs.size() != rhs.size()) {
574 return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
575 lhs.size(), rhs.size());
576 }
577
578 se::DeviceMemory<ElementT> lhs_typed(lhs);
579 se::DeviceMemory<ElementT> rhs_typed(rhs);
580 uint64 buffer_size = lhs_typed.ElementCount();
581
582 absl::Span<const uint8> compiled_ptx = {};
583 StatusOr<absl::Span<const uint8>> compiled_ptx_or =
584 se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
585 buffer_compare_ptx,
586 PtxOptsFromConfig(config));
587 if (compiled_ptx_or.ok()) {
588 compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
589 } else {
590 static absl::once_flag ptxas_not_found_logged;
591 absl::call_once(ptxas_not_found_logged, [&]() {
592 LOG(WARNING)
593 << compiled_ptx_or.status().ToString()
594 << "\nRelying on driver to perform ptx compilation. "
595 << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
596 << " or modifying $PATH can be used to set the location of ptxas"
597 << "\nThis message will only be logged once.";
598 });
599 }
600
601 TF_ASSIGN_OR_RETURN(
602 std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
603 (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
604 se::DeviceMemory<ElementT>, float, uint64,
605 se::DeviceMemory<uint64>>(
606 kernel_name, buffer_compare_ptx, compiled_ptx)));
607
608 GpuDeviceInfo gpu_device_info;
609 gpu_device_info.threads_per_block_limit =
610 executor->GetDeviceDescription().threads_per_block_limit();
611 gpu_device_info.threads_per_warp =
612 executor->GetDeviceDescription().threads_per_warp();
613 gpu_device_info.shared_memory_per_block =
614 executor->GetDeviceDescription().shared_memory_per_block();
615 gpu_device_info.threads_per_core_limit =
616 executor->GetDeviceDescription().threads_per_core_limit();
617 gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
618 LaunchDimensions dim =
619 CalculateLaunchDimensions(buffer_shape, gpu_device_info);
620
621 LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
622 LaunchDimensions::Dim3D block_counts = dim.block_counts();
623 stream->ThenLaunch(
624 se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
625 se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
626 *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
627 buffer_size, out_param.cref());
628
629 uint64 result = -1;
630 CHECK_EQ(out_param->size(), sizeof(result));
631 stream->ThenMemcpy(&result, *out_param, sizeof(result));
632 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
633 return result == 0;
634 }
635
636 // Host side comparison code that does the same thing, but reports some of the
// differences as well. It only prints logs for debugging.
638 //
639 // Returns true if no differences were seen, false otherwise.
640 template <typename ElementType, typename ComparisonType>
HostCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs)641 StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
642 se::DeviceMemoryBase rhs) {
643 int64 n = lhs.size() / sizeof(ElementType);
644 std::vector<ElementType> host_lhs(n), host_rhs(n);
645 stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
646 stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
647 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
648
649 const auto canonicalize = [](ComparisonType a) -> ComparisonType {
650 if (std::is_same<ElementType, Eigen::half>::value && a) {
651 constexpr ComparisonType kMaxFp16Value = 65505.;
652 if (std::isnan(a)) {
653 return a;
654 }
655 return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
656 }
657 return a;
658 };
659 int differences_seen = 0;
660 for (int64 i = 0; i < n && differences_seen < 10; i++) {
661 auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
662 auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
663 ComparisonType lhs = canonicalize(original_lhs);
664 ComparisonType rhs = canonicalize(original_rhs);
665 if (std::isnan(lhs) && std::isnan(rhs)) {
666 continue;
667 }
668 if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
669 continue;
670 }
671 if (std::isfinite(lhs) != std::isfinite(rhs) ||
672 !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
673 kTolerance)) {
674 differences_seen++;
675 LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
676 << original_rhs;
677 }
678 }
679 return differences_seen == 0;
680 }
681
682 template <typename ElementT, typename ComparisonT>
CompareEqualParameterized(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & shape,const HloModuleConfig & config,absl::string_view kernel_name)683 static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
684 se::DeviceMemoryBase lhs,
685 se::DeviceMemoryBase rhs,
686 const Shape& shape,
687 const HloModuleConfig& config,
688 absl::string_view kernel_name) {
689 XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
690 TF_ASSIGN_OR_RETURN(
691 bool result,
692 DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));
693
694 if (result) {
695 return true;
696 }
697
698 TF_ASSIGN_OR_RETURN(bool host_return,
699 (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
700 CHECK(host_return == result) << "Different comparison result on GPU vs host";
701
702 return false;
703 }
704
CompareEqual(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs) const705 StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
706 se::DeviceMemoryBase lhs,
707 se::DeviceMemoryBase rhs) const {
708 switch (shape_.element_type()) {
709 case xla::F16:
710 return CompareEqualParameterized<Eigen::half, float>(
711 stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
712 case xla::F32:
713 return CompareEqualParameterized<float, float>(
714 stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
715 case xla::F64:
716 return CompareEqualParameterized<double, double>(
717 stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
718 case xla::S8:
719 return CompareEqualParameterized<int8, float>(
720 stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
721 default:
722 return Unimplemented("Unimplemented element type");
723 }
724 }
725
BufferComparator(const Shape & shape,const HloModuleConfig & config)726 BufferComparator::BufferComparator(const Shape& shape,
727 const HloModuleConfig& config)
728 : shape_(shape), config_(config) {
729 // Normalize complex shapes: since we treat the passed array as a contiguous
730 // storage it does not matter which dimension are we doubling.
731 auto double_dim_size = [&]() {
732 int64 prev_zero_dim_size = shape_.dimensions(0);
733 shape_.set_dimensions(0, prev_zero_dim_size * 2);
734 };
735
736 if (shape_.element_type() == PrimitiveType::C64) {
737 // C64 is just two F32s next to each other.
738 shape_.set_element_type(PrimitiveType::F32);
739 double_dim_size();
740 } else if (shape_.element_type() == PrimitiveType::C128) {
741 // C128 is just two F64s next to each other.
742 shape_.set_element_type(PrimitiveType::F64);
743 double_dim_size();
744 }
745 }
746
747 } // namespace gpu
748 } // namespace xla
749