1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17
18 #include <algorithm>
19 #include <cmath>
20
21 #include "absl/base/call_once.h"
22 #include "absl/strings/str_replace.h"
23 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
24 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/stream_executor/device_memory.h"
30 #include "tensorflow/stream_executor/kernel.h"
31 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
32
33 namespace xla {
34 namespace gpu {
35
// Relative error threshold used by both the device kernels (DeviceCompare)
// and the host-side re-check (HostCompare). Kept as-is: note the float
// literal, so the effective value is float(0.1) widened to double.
static constexpr double kTolerance = 0.1f;
37
38 // Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
39 // buffer_length where the relative error does not exceed the passed
40 // rel_error_threshold. Write the number of mismatches into out parameter
41 // mismatch_count.
42 //
43 // NaN's are considered equal, and for half's we clamp all numbers to largest
44 // and smallest numbers representable to avoid miscomparisons due to overflows.
45 //
46 // The PTX below is compiled from the following CUDA code:
47 //
48 // #include<cuda_fp16.h>
49 // extern "C" { // avoid name mangling
50 // __device__ float __xla_buffer_comparator_canonicalize(float input) {
51 // // All fp16 infinities are treated as 65505 or -65505, in order to avoid
52 // // differences due to overflows.
53 // return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
54 // }
55
56 // __device__ float __xla_buffer_comparator_extract_int8(int pack) {
57 // // Extract the lower 8 bits from pack and convert it to float
58 // const unsigned int bit_mask = 0xff;
59 // unsigned int bits = pack & bit_mask;
60 // char* int8_ptr = (char*)&bits;
61 // return __int2float_rn(*int8_ptr);
62 // }
63
64 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
65 // float rel_error_threshold,
66 // unsigned long long buffer_length,
67 // int* mismatch_count) {
68 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
69 // if (idx >= buffer_length) return;
70 // float elem_a = __half2float(buffer_a[idx]);
71 // float elem_b = __half2float(buffer_b[idx]);
72 // elem_a = __xla_buffer_comparator_canonicalize(elem_a);
73 // elem_b = __xla_buffer_comparator_canonicalize(elem_b);
74 // if (isnan(elem_a) && isnan(elem_b)) return;
75 // float rel_error = abs(elem_a - elem_b)
76 // / (max(abs(elem_a), abs(elem_b)) + 1);
77 // if (rel_error > rel_error_threshold || isnan(rel_error))
78 // atomicAdd(mismatch_count, 1);
79 // }
80
81 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
82 // float rel_error_threshold,
83 // unsigned long long buffer_length,
84 // int* mismatch_count) {
85 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
86 // if (idx >= buffer_length) return;
87 // float elem_a = buffer_a[idx];
88 // float elem_b = buffer_b[idx];
89 // if (isnan(elem_a) && isnan(elem_b)) return;
90 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
91 // return;
92 // float rel_error = abs(elem_a - elem_b)
93 // / (max(abs(elem_a), abs(elem_b)) + 1);
94 // if (rel_error > rel_error_threshold || isnan(rel_error))
95 // atomicAdd(mismatch_count, 1);
96 // }
97
98 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
99 // float rel_error_threshold,
100 // unsigned long long buffer_length,
101 // int* mismatch_count) {
102 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
103 // if (idx >= buffer_length) return;
104 // double elem_a = buffer_a[idx];
105 // double elem_b = buffer_b[idx];
106 // if (isnan(elem_a) && isnan(elem_b)) return;
107 // if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
108 // return;
109 // double rel_error = abs(elem_a - elem_b)
110 // / (max(abs(elem_a), abs(elem_b)) + 1);
111 // if (rel_error > rel_error_threshold || isnan(rel_error))
112 // atomicAdd(mismatch_count, 1);
113 // }
114
115 // __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
116 // float rel_error_threshold,
117 // unsigned long long buffer_length,
118 // int* mismatch_count) {
119 // int idx = threadIdx.x + blockIdx.x * blockDim.x;
120 // if (idx >= buffer_length) return;
121 // int pack_a = buffer_a[idx];
122 // int pack_b = buffer_b[idx];
123 // for(int i = 0; i < 4; ++i) {
124 // float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
125 // float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
126 // float rel_error = abs(elem_a - elem_b)
127 // / (max(abs(elem_a), abs(elem_b)) + 1);
128 // if (rel_error > rel_error_threshold || isnan(rel_error))
129 // atomicAdd(mismatch_count, 1);
130 // pack_a >>= 8;
131 // pack_b >>= 8;
132 // }
133 // }
134 // } // end extern declaration.
// PTX (sm_30) for the four comparison kernels described by the CUDA source in
// the comment above. Loaded at runtime by DeviceCompare, either pre-compiled
// via ptxas or JIT-compiled by the driver. If the CUDA source changes, this
// text must be regenerated; do not hand-edit individual instructions.
static const char* buffer_compare_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

  // .globl __xla_fp16_comparison

.visible .entry __xla_fp16_comparison(
  .param .u64 __xla_fp16_comparison_param_0,
  .param .u64 __xla_fp16_comparison_param_1,
  .param .f32 __xla_fp16_comparison_param_2,
  .param .u64 __xla_fp16_comparison_param_3,
  .param .u64 __xla_fp16_comparison_param_4
)
{
  .reg .pred %p<9>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<28>;
  .reg .b32 %r<6>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp16_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp16_comparison_param_1];
  ld.param.f32 %f10, [__xla_fp16_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp16_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp16_comparison_param_4];
  mov.u32 %r2, %ntid.x;
  mov.u32 %r3, %ctaid.x;
  mov.u32 %r4, %tid.x;
  mad.lo.s32 %r1, %r2, %r3, %r4;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB0_9;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 2;
  add.s64 %rd8, %rd6, %rd7;
  ld.global.u16 %rs1, [%rd8];
  // inline asm
  { cvt.f32.f16 %f26, %rs1;}

  // inline asm
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.u16 %rs2, [%rd10];
  // inline asm
  { cvt.f32.f16 %f27, %rs2;}

  // inline asm
  abs.f32 %f13, %f26;
  setp.gtu.f32 %p2, %f13, 0f7F800000;
  @%p2 bra BB0_3;

  mov.f32 %f14, 0f477FE100;
  min.f32 %f15, %f26, %f14;
  mov.f32 %f16, 0fC77FE100;
  max.f32 %f26, %f16, %f15;

BB0_3:
  abs.f32 %f17, %f27;
  setp.gtu.f32 %p3, %f17, 0f7F800000;
  @%p3 bra BB0_5;

  mov.f32 %f18, 0f477FE100;
  min.f32 %f19, %f27, %f18;
  mov.f32 %f20, 0fC77FE100;
  max.f32 %f27, %f20, %f19;

BB0_5:
  abs.f32 %f7, %f26;
  setp.gtu.f32 %p4, %f7, 0f7F800000;
  abs.f32 %f8, %f27;
  setp.gtu.f32 %p5, %f8, 0f7F800000;
  and.pred %p6, %p4, %p5;
  @%p6 bra BB0_9;

  sub.f32 %f21, %f26, %f27;
  abs.f32 %f22, %f21;
  max.f32 %f23, %f7, %f8;
  add.f32 %f24, %f23, 0f3F800000;
  div.rn.f32 %f9, %f22, %f24;
  setp.gt.f32 %p7, %f9, %f10;
  @%p7 bra BB0_8;

  abs.f32 %f25, %f9;
  setp.le.f32 %p8, %f25, 0f7F800000;
  @%p8 bra BB0_9;

BB0_8:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r5, [%rd11], 1;

BB0_9:
  ret;
}

  // .globl __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
  .param .u64 __xla_fp32_comparison_param_0,
  .param .u64 __xla_fp32_comparison_param_1,
  .param .f32 __xla_fp32_comparison_param_2,
  .param .u64 __xla_fp32_comparison_param_3,
  .param .u64 __xla_fp32_comparison_param_4
)
{
  .reg .pred %p<10>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<13>;
  .reg .b32 %r<10>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp32_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp32_comparison_param_1];
  ld.param.f32 %f6, [__xla_fp32_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp32_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp32_comparison_param_4];
  mov.u32 %r2, %ntid.x;
  mov.u32 %r3, %ctaid.x;
  mov.u32 %r4, %tid.x;
  mad.lo.s32 %r1, %r2, %r3, %r4;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB1_8;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 4;
  add.s64 %rd8, %rd6, %rd7;
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.f32 %f1, [%rd10];
  ld.global.f32 %f2, [%rd8];
  abs.f32 %f3, %f2;
  setp.le.f32 %p2, %f3, 0f7F800000;
  @%p2 bra BB1_3;

  abs.f32 %f7, %f1;
  setp.gtu.f32 %p3, %f7, 0f7F800000;
  @%p3 bra BB1_8;

BB1_3:
  setp.neu.f32 %p4, %f3, 0f7F800000;
  abs.f32 %f4, %f1;
  setp.neu.f32 %p5, %f4, 0f7F800000;
  or.pred %p6, %p4, %p5;
  @%p6 bra BB1_5;

  mov.b32 %r5, %f2;
  shr.u32 %r6, %r5, 31;
  cvt.u16.u32 %rs1, %r6;
  mov.b32 %r7, %f1;
  shr.u32 %r8, %r7, 31;
  cvt.u16.u32 %rs2, %r8;
  setp.eq.s16 %p7, %rs1, %rs2;
  @%p7 bra BB1_8;

BB1_5:
  sub.f32 %f8, %f2, %f1;
  abs.f32 %f9, %f8;
  max.f32 %f10, %f3, %f4;
  add.f32 %f11, %f10, 0f3F800000;
  div.rn.f32 %f5, %f9, %f11;
  setp.gt.f32 %p8, %f5, %f6;
  @%p8 bra BB1_7;

  abs.f32 %f12, %f5;
  setp.le.f32 %p9, %f12, 0f7F800000;
  @%p9 bra BB1_8;

BB1_7:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r9, [%rd11], 1;

BB1_8:
  ret;
}

  // .globl __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
  .param .u64 __xla_fp64_comparison_param_0,
  .param .u64 __xla_fp64_comparison_param_1,
  .param .f32 __xla_fp64_comparison_param_2,
  .param .u64 __xla_fp64_comparison_param_3,
  .param .u64 __xla_fp64_comparison_param_4
)
{
  .reg .pred %p<11>;
  .reg .b16 %rs<3>;
  .reg .f32 %f<2>;
  .reg .b32 %r<14>;
  .reg .f64 %fd<13>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd1, [__xla_fp64_comparison_param_0];
  ld.param.u64 %rd2, [__xla_fp64_comparison_param_1];
  ld.param.f32 %f1, [__xla_fp64_comparison_param_2];
  ld.param.u64 %rd4, [__xla_fp64_comparison_param_3];
  ld.param.u64 %rd3, [__xla_fp64_comparison_param_4];
  mov.u32 %r4, %ntid.x;
  mov.u32 %r5, %ctaid.x;
  mov.u32 %r6, %tid.x;
  mad.lo.s32 %r1, %r4, %r5, %r6;
  cvt.s64.s32 %rd5, %r1;
  setp.ge.u64 %p1, %rd5, %rd4;
  @%p1 bra BB2_11;

  cvta.to.global.u64 %rd6, %rd1;
  mul.wide.s32 %rd7, %r1, 8;
  add.s64 %rd8, %rd6, %rd7;
  cvta.to.global.u64 %rd9, %rd2;
  add.s64 %rd10, %rd9, %rd7;
  ld.global.f64 %fd1, [%rd10];
  ld.global.f64 %fd2, [%rd8];
  abs.f64 %fd3, %fd2;
  setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
  @%p2 bra BB2_3;

  abs.f64 %fd5, %fd1;
  setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
  @%p3 bra BB2_11;

BB2_3:
  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r2}, %fd2;
  }
  and.b32 %r7, %r2, 2147483647;
  setp.ne.s32 %p4, %r7, 2146435072;
  @%p4 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%r8, %temp}, %fd2;
  }
  setp.ne.s32 %p5, %r8, 0;
  @%p5 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%temp, %r3}, %fd1;
  }
  and.b32 %r9, %r3, 2147483647;
  setp.ne.s32 %p6, %r9, 2146435072;
  @%p6 bra BB2_8;

  {
  .reg .b32 %temp;
  mov.b64 {%r10, %temp}, %fd1;
  }
  setp.ne.s32 %p7, %r10, 0;
  @%p7 bra BB2_8;

  shr.u32 %r11, %r2, 31;
  cvt.u16.u32 %rs1, %r11;
  shr.u32 %r12, %r3, 31;
  cvt.u16.u32 %rs2, %r12;
  setp.eq.s16 %p8, %rs1, %rs2;
  @%p8 bra BB2_11;

BB2_8:
  sub.f64 %fd6, %fd2, %fd1;
  abs.f64 %fd7, %fd6;
  abs.f64 %fd8, %fd1;
  max.f64 %fd9, %fd3, %fd8;
  add.f64 %fd10, %fd9, 0d3FF0000000000000;
  div.rn.f64 %fd4, %fd7, %fd10;
  cvt.f64.f32 %fd11, %f1;
  setp.gt.f64 %p9, %fd4, %fd11;
  @%p9 bra BB2_10;

  abs.f64 %fd12, %fd4;
  setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
  @%p10 bra BB2_11;

BB2_10:
  cvta.to.global.u64 %rd11, %rd3;
  atom.global.add.u32 %r13, [%rd11], 1;

BB2_11:
  ret;
}

  // .globl __xla_int8_comparison
.visible .entry __xla_int8_comparison(
  .param .u64 __xla_int8_comparison_param_0,
  .param .u64 __xla_int8_comparison_param_1,
  .param .f32 __xla_int8_comparison_param_2,
  .param .u64 __xla_int8_comparison_param_3,
  .param .u64 __xla_int8_comparison_param_4
)
{
  .reg .pred %p<10>;
  .reg .f32 %f<42>;
  .reg .b32 %r<23>;
  .reg .b64 %rd<12>;


  ld.param.u64 %rd2, [__xla_int8_comparison_param_0];
  ld.param.u64 %rd3, [__xla_int8_comparison_param_1];
  ld.param.f32 %f5, [__xla_int8_comparison_param_2];
  ld.param.u64 %rd4, [__xla_int8_comparison_param_3];
  ld.param.u64 %rd5, [__xla_int8_comparison_param_4];
  cvta.to.global.u64 %rd1, %rd5;
  mov.u32 %r4, %ntid.x;
  mov.u32 %r5, %ctaid.x;
  mov.u32 %r6, %tid.x;
  mad.lo.s32 %r1, %r4, %r5, %r6;
  cvt.s64.s32 %rd6, %r1;
  setp.ge.u64 %p1, %rd6, %rd4;
  @%p1 bra BB3_13;

  cvta.to.global.u64 %rd7, %rd2;
  mul.wide.s32 %rd8, %r1, 4;
  add.s64 %rd9, %rd7, %rd8;
  cvta.to.global.u64 %rd10, %rd3;
  add.s64 %rd11, %rd10, %rd8;
  ld.global.u32 %r2, [%rd9];
  cvt.s32.s8 %r7, %r2;
  cvt.rn.f32.s32 %f6, %r7;
  ld.global.u32 %r3, [%rd11];
  cvt.s32.s8 %r8, %r3;
  cvt.rn.f32.s32 %f7, %r8;
  sub.f32 %f8, %f6, %f7;
  abs.f32 %f9, %f8;
  abs.f32 %f10, %f6;
  abs.f32 %f11, %f7;
  max.f32 %f12, %f10, %f11;
  add.f32 %f13, %f12, 0f3F800000;
  div.rn.f32 %f1, %f9, %f13;
  setp.gt.f32 %p2, %f1, %f5;
  @%p2 bra BB3_3;

  abs.f32 %f14, %f1;
  setp.le.f32 %p3, %f14, 0f7F800000;
  @%p3 bra BB3_4;

BB3_3:
  atom.global.add.u32 %r9, [%rd1], 1;

BB3_4:
  shr.u32 %r10, %r3, 8;
  shr.u32 %r11, %r2, 8;
  cvt.s32.s8 %r12, %r11;
  cvt.rn.f32.s32 %f15, %r12;
  cvt.s32.s8 %r13, %r10;
  cvt.rn.f32.s32 %f16, %r13;
  sub.f32 %f17, %f15, %f16;
  abs.f32 %f18, %f17;
  abs.f32 %f19, %f15;
  abs.f32 %f20, %f16;
  max.f32 %f21, %f19, %f20;
  add.f32 %f22, %f21, 0f3F800000;
  div.rn.f32 %f2, %f18, %f22;
  setp.gt.f32 %p4, %f2, %f5;
  @%p4 bra BB3_6;

  abs.f32 %f23, %f2;
  setp.le.f32 %p5, %f23, 0f7F800000;
  @%p5 bra BB3_7;

BB3_6:
  atom.global.add.u32 %r14, [%rd1], 1;

BB3_7:
  shr.u32 %r15, %r3, 16;
  shr.u32 %r16, %r2, 16;
  cvt.s32.s8 %r17, %r16;
  cvt.rn.f32.s32 %f24, %r17;
  cvt.s32.s8 %r18, %r15;
  cvt.rn.f32.s32 %f25, %r18;
  sub.f32 %f26, %f24, %f25;
  abs.f32 %f27, %f26;
  abs.f32 %f28, %f24;
  abs.f32 %f29, %f25;
  max.f32 %f30, %f28, %f29;
  add.f32 %f31, %f30, 0f3F800000;
  div.rn.f32 %f3, %f27, %f31;
  setp.gt.f32 %p6, %f3, %f5;
  @%p6 bra BB3_9;

  abs.f32 %f32, %f3;
  setp.le.f32 %p7, %f32, 0f7F800000;
  @%p7 bra BB3_10;

BB3_9:
  atom.global.add.u32 %r19, [%rd1], 1;

BB3_10:
  shr.s32 %r20, %r2, 24;
  cvt.rn.f32.s32 %f33, %r20;
  shr.s32 %r21, %r3, 24;
  cvt.rn.f32.s32 %f34, %r21;
  sub.f32 %f35, %f33, %f34;
  abs.f32 %f36, %f35;
  abs.f32 %f37, %f33;
  abs.f32 %f38, %f34;
  max.f32 %f39, %f37, %f38;
  add.f32 %f40, %f39, 0f3F800000;
  div.rn.f32 %f4, %f36, %f40;
  setp.gt.f32 %p8, %f4, %f5;
  @%p8 bra BB3_12;

  abs.f32 %f41, %f4;
  setp.le.f32 %p9, %f41, 0f7F800000;
  @%p9 bra BB3_13;

BB3_12:
  atom.global.add.u32 %r22, [%rd1], 1;

BB3_13:
  ret;
}
)";
550
// Signature of the device comparison kernels declared in buffer_compare_ptx:
// (buffer_a, buffer_b, rel_error_threshold, buffer_length, mismatch_count).
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64, se::DeviceMemory<uint64>>;
555
556 // Compares two buffers on the GPU.
557 //
558 // Returns `true` if two buffers are equal, `false` otherwise.
559 template <typename ElementT>
DeviceCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & buffer_shape,const HloModuleConfig & config,absl::string_view kernel_name)560 static StatusOr<bool> DeviceCompare(se::Stream* stream,
561 se::DeviceMemoryBase lhs,
562 se::DeviceMemoryBase rhs,
563 const Shape& buffer_shape,
564 const HloModuleConfig& config,
565 absl::string_view kernel_name) {
566 se::StreamExecutor* executor = stream->parent();
567
568 se::ScopedDeviceMemory<uint64> out_param =
569 executor->AllocateOwnedScalar<uint64>();
570
571 stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
572 if (lhs.size() != rhs.size()) {
573 return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
574 lhs.size(), rhs.size());
575 }
576
577 se::DeviceMemory<ElementT> lhs_typed(lhs);
578 se::DeviceMemory<ElementT> rhs_typed(rhs);
579 uint64 buffer_size = lhs_typed.ElementCount();
580
581 absl::Span<const uint8> compiled_ptx = {};
582 StatusOr<absl::Span<const uint8>> compiled_ptx_or =
583 se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
584 buffer_compare_ptx,
585 PtxOptsFromConfig(config));
586 if (compiled_ptx_or.ok()) {
587 compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
588 } else {
589 static absl::once_flag ptxas_not_found_logged;
590 absl::call_once(ptxas_not_found_logged, [&]() {
591 LOG(WARNING)
592 << compiled_ptx_or.status().ToString()
593 << "\nRelying on driver to perform ptx compilation. "
594 << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
595 << " or modifying $PATH can be used to set the location of ptxas"
596 << "\nThis message will only be logged once.";
597 });
598 }
599
600 TF_ASSIGN_OR_RETURN(
601 std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
602 (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
603 se::DeviceMemory<ElementT>, float, uint64,
604 se::DeviceMemory<uint64>>(
605 kernel_name, buffer_compare_ptx, compiled_ptx)));
606
607 LaunchDimensions dim =
608 CalculateLaunchDimensions(buffer_shape, executor->GetDeviceDescription());
609
610 stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
611 se::BlockDim(dim.block_count()), *comparison_kernel,
612 lhs_typed, rhs_typed, static_cast<float>(kTolerance),
613 buffer_size, out_param.cref());
614
615 uint64 result = -1;
616 CHECK_EQ(out_param->size(), sizeof(result));
617 stream->ThenMemcpy(&result, *out_param, sizeof(result));
618 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
619 return result == 0;
620 }
621
622 // Host side comparison code that does the same thing, but reports some of the
623 // differences as well. It only print logs for debugging.
624 //
625 // Returns true if no differences were seen, false otherwise.
626 template <typename ElementType, typename ComparisonType>
HostCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs)627 StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
628 se::DeviceMemoryBase rhs) {
629 int64 n = lhs.size() / sizeof(ElementType);
630 std::vector<ElementType> host_lhs(n), host_rhs(n);
631 stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
632 stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
633 TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
634
635 const auto canonicalize = [](ComparisonType a) -> ComparisonType {
636 if (std::is_same<ElementType, Eigen::half>::value && a) {
637 constexpr ComparisonType kMaxFp16Value = 65505.;
638 if (std::isnan(a)) {
639 return a;
640 }
641 return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
642 }
643 return a;
644 };
645 int differences_seen = 0;
646 for (int64 i = 0; i < n && differences_seen < 10; i++) {
647 auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
648 auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
649 ComparisonType lhs = canonicalize(original_lhs);
650 ComparisonType rhs = canonicalize(original_rhs);
651 if (std::isnan(lhs) && std::isnan(rhs)) {
652 continue;
653 }
654 if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
655 continue;
656 }
657 if (std::isfinite(lhs) != std::isfinite(rhs) ||
658 !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
659 kTolerance)) {
660 differences_seen++;
661 LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
662 << original_rhs;
663 }
664 }
665 return differences_seen == 0;
666 }
667
668 template <typename ElementT, typename ComparisonT>
CompareEqualParameterized(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & shape,const HloModuleConfig & config,absl::string_view kernel_name)669 static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
670 se::DeviceMemoryBase lhs,
671 se::DeviceMemoryBase rhs,
672 const Shape& shape,
673 const HloModuleConfig& config,
674 absl::string_view kernel_name) {
675 XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
676 TF_ASSIGN_OR_RETURN(
677 bool result,
678 DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));
679
680 if (result) {
681 return true;
682 }
683
684 TF_ASSIGN_OR_RETURN(bool host_return,
685 (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
686 CHECK(host_return == result) << "Different comparison result on GPU vs host";
687
688 return false;
689 }
690
CompareEqual(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs) const691 StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
692 se::DeviceMemoryBase lhs,
693 se::DeviceMemoryBase rhs) const {
694 switch (shape_.element_type()) {
695 case xla::F16:
696 return CompareEqualParameterized<Eigen::half, float>(
697 stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
698 case xla::F32:
699 return CompareEqualParameterized<float, float>(
700 stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
701 case xla::F64:
702 return CompareEqualParameterized<double, double>(
703 stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
704 case xla::S8:
705 return CompareEqualParameterized<int8, float>(
706 stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
707 default:
708 return Unimplemented("Unimplemented element type");
709 }
710 }
711
BufferComparator(const Shape & shape,const HloModuleConfig & config)712 BufferComparator::BufferComparator(const Shape& shape,
713 const HloModuleConfig& config)
714 : shape_(shape), config_(config) {
715 // Normalize complex shapes: since we treat the passed array as a contiguous
716 // storage it does not matter which dimension are we doubling.
717 auto double_dim_size = [&]() {
718 int64 prev_zero_dim_size = shape_.dimensions(0);
719 shape_.set_dimensions(0, prev_zero_dim_size * 2);
720 };
721
722 if (shape_.element_type() == PrimitiveType::C64) {
723 // C64 is just two F32s next to each other.
724 shape_.set_element_type(PrimitiveType::F32);
725 double_dim_size();
726 } else if (shape_.element_type() == PrimitiveType::C128) {
727 // C128 is just two F64s next to each other.
728 shape_.set_element_type(PrimitiveType::F64);
729 double_dim_size();
730 }
731 }
732
733 } // namespace gpu
734 } // namespace xla
735