/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"

#include <algorithm>
#include <cmath>

#include "absl/base/call_once.h"
#include "absl/strings/str_replace.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"

namespace xla {
namespace gpu {

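// Relative error threshold used by both the device kernels and the host-side
// fallback below; an element pair whose relative error exceeds this value is
// counted as a mismatch.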
static constexpr double kTolerance = 0.1f;

// Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of
// length buffer_length where the relative error does not exceed the passed
// rel_error_threshold. Write the number of mismatches into the out parameter
// mismatch_count.
//
// NaNs compare equal, and for fp16 all values are clamped to the largest and
// smallest representable finite numbers to avoid miscomparisons due to
// overflow.
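//
// As a worked example of the relative-error check: comparing 100.0f against
// 101.0f gives |100 - 101| / (max(100, 101) + 1) = 1/102 ≈ 0.0098, well under
// kTolerance (0.1), so the pair is accepted; the +1 in the denominator keeps
// the error well-behaved for values near zero.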
//
// The PTX below is compiled from the following CUDA code:
//
// #include<cuda_fp16.h>
// extern "C" { // avoid name mangling
// __device__ float __xla_buffer_comparator_canonicalize(float input) {
//   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
//   // differences due to overflows.
//   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
// }

// __device__ float __xla_buffer_comparator_extract_int8(int pack) {
//   // Extract the lower 8 bits from pack and convert it to float
//   const unsigned int bit_mask = 0xff;
//   unsigned int bits = pack & bit_mask;
//   char* int8_ptr = (char*)&bits;
//   return __int2float_rn(*int8_ptr);
// }

// __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   float elem_a = __half2float(buffer_a[idx]);
//   float elem_b = __half2float(buffer_b[idx]);
//   elem_a = __xla_buffer_comparator_canonicalize(elem_a);
//   elem_b = __xla_buffer_comparator_canonicalize(elem_b);
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   float rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }

// __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   float elem_a = buffer_a[idx];
//   float elem_b = buffer_b[idx];
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
//     return;
//   float rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }

// __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   double elem_a = buffer_a[idx];
//   double elem_b = buffer_b[idx];
//   if (isnan(elem_a) && isnan(elem_b)) return;
//   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
//     return;
//   double rel_error = abs(elem_a - elem_b)
//       / (max(abs(elem_a), abs(elem_b)) + 1);
//   if (rel_error > rel_error_threshold || isnan(rel_error))
//     atomicAdd(mismatch_count, 1);
// }

// __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
//                                       float rel_error_threshold,
//                                       unsigned long long buffer_length,
//                                       int* mismatch_count) {
//   int idx = threadIdx.x + blockIdx.x * blockDim.x;
//   if (idx >= buffer_length) return;
//   int pack_a = buffer_a[idx];
//   int pack_b = buffer_b[idx];
//   for(int i = 0; i < 4; ++i) {
//     float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
//     float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
//     float rel_error = abs(elem_a - elem_b)
//         / (max(abs(elem_a), abs(elem_b)) + 1);
//     if (rel_error > rel_error_threshold || isnan(rel_error))
//         atomicAdd(mismatch_count, 1);
//     pack_a >>= 8;
//     pack_b >>= 8;
//   }
// }
// } // end extern declaration.
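//
// One possible way to regenerate the PTX below, assuming a CUDA toolchain is
// available (the .cu file name here is hypothetical):
//   nvcc -arch=sm_30 -ptx xla_buffer_comparator_kernels.cu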
static const char* buffer_compare_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

 // .globl __xla_fp16_comparison

.visible .entry __xla_fp16_comparison(
 .param .u64 __xla_fp16_comparison_param_0,
 .param .u64 __xla_fp16_comparison_param_1,
 .param .f32 __xla_fp16_comparison_param_2,
 .param .u64 __xla_fp16_comparison_param_3,
 .param .u64 __xla_fp16_comparison_param_4
)
{
 .reg .pred  %p<9>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<28>;
 .reg .b32  %r<6>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp16_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp16_comparison_param_1];
 ld.param.f32  %f10, [__xla_fp16_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp16_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp16_comparison_param_4];
 mov.u32  %r2, %ntid.x;
 mov.u32  %r3, %ctaid.x;
 mov.u32  %r4, %tid.x;
 mad.lo.s32  %r1, %r2, %r3, %r4;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB0_9;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 2;
 add.s64  %rd8, %rd6, %rd7;
 ld.global.u16  %rs1, [%rd8];
 // inline asm
 {  cvt.f32.f16 %f26, %rs1;}

 // inline asm
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.u16  %rs2, [%rd10];
 // inline asm
 {  cvt.f32.f16 %f27, %rs2;}

 // inline asm
 abs.f32  %f13, %f26;
 setp.gtu.f32 %p2, %f13, 0f7F800000;
 @%p2 bra  BB0_3;

 mov.f32  %f14, 0f477FE100;
 min.f32  %f15, %f26, %f14;
 mov.f32  %f16, 0fC77FE100;
 max.f32  %f26, %f16, %f15;

BB0_3:
 abs.f32  %f17, %f27;
 setp.gtu.f32 %p3, %f17, 0f7F800000;
 @%p3 bra  BB0_5;

 mov.f32  %f18, 0f477FE100;
 min.f32  %f19, %f27, %f18;
 mov.f32  %f20, 0fC77FE100;
 max.f32  %f27, %f20, %f19;

BB0_5:
 abs.f32  %f7, %f26;
 setp.gtu.f32 %p4, %f7, 0f7F800000;
 abs.f32  %f8, %f27;
 setp.gtu.f32 %p5, %f8, 0f7F800000;
 and.pred   %p6, %p4, %p5;
 @%p6 bra  BB0_9;

 sub.f32  %f21, %f26, %f27;
 abs.f32  %f22, %f21;
 max.f32  %f23, %f7, %f8;
 add.f32  %f24, %f23, 0f3F800000;
 div.rn.f32  %f9, %f22, %f24;
 setp.gt.f32 %p7, %f9, %f10;
 @%p7 bra  BB0_8;

 abs.f32  %f25, %f9;
 setp.le.f32 %p8, %f25, 0f7F800000;
 @%p8 bra  BB0_9;

BB0_8:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r5, [%rd11], 1;

BB0_9:
 ret;
}

 // .globl __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
 .param .u64 __xla_fp32_comparison_param_0,
 .param .u64 __xla_fp32_comparison_param_1,
 .param .f32 __xla_fp32_comparison_param_2,
 .param .u64 __xla_fp32_comparison_param_3,
 .param .u64 __xla_fp32_comparison_param_4
)
{
 .reg .pred  %p<10>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<13>;
 .reg .b32  %r<10>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp32_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp32_comparison_param_1];
 ld.param.f32  %f6, [__xla_fp32_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp32_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp32_comparison_param_4];
 mov.u32  %r2, %ntid.x;
 mov.u32  %r3, %ctaid.x;
 mov.u32  %r4, %tid.x;
 mad.lo.s32  %r1, %r2, %r3, %r4;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB1_8;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 4;
 add.s64  %rd8, %rd6, %rd7;
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.f32  %f1, [%rd10];
 ld.global.f32  %f2, [%rd8];
 abs.f32  %f3, %f2;
 setp.le.f32 %p2, %f3, 0f7F800000;
 @%p2 bra  BB1_3;

 abs.f32  %f7, %f1;
 setp.gtu.f32 %p3, %f7, 0f7F800000;
 @%p3 bra  BB1_8;

BB1_3:
 setp.neu.f32 %p4, %f3, 0f7F800000;
 abs.f32  %f4, %f1;
 setp.neu.f32 %p5, %f4, 0f7F800000;
 or.pred   %p6, %p4, %p5;
 @%p6 bra  BB1_5;

 mov.b32   %r5, %f2;
 shr.u32  %r6, %r5, 31;
 cvt.u16.u32 %rs1, %r6;
 mov.b32   %r7, %f1;
 shr.u32  %r8, %r7, 31;
 cvt.u16.u32 %rs2, %r8;
 setp.eq.s16 %p7, %rs1, %rs2;
 @%p7 bra  BB1_8;

BB1_5:
 sub.f32  %f8, %f2, %f1;
 abs.f32  %f9, %f8;
 max.f32  %f10, %f3, %f4;
 add.f32  %f11, %f10, 0f3F800000;
 div.rn.f32  %f5, %f9, %f11;
 setp.gt.f32 %p8, %f5, %f6;
 @%p8 bra  BB1_7;

 abs.f32  %f12, %f5;
 setp.le.f32 %p9, %f12, 0f7F800000;
 @%p9 bra  BB1_8;

BB1_7:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r9, [%rd11], 1;

BB1_8:
 ret;
}

 // .globl __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
 .param .u64 __xla_fp64_comparison_param_0,
 .param .u64 __xla_fp64_comparison_param_1,
 .param .f32 __xla_fp64_comparison_param_2,
 .param .u64 __xla_fp64_comparison_param_3,
 .param .u64 __xla_fp64_comparison_param_4
)
{
 .reg .pred  %p<11>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<2>;
 .reg .b32  %r<14>;
 .reg .f64  %fd<13>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp64_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp64_comparison_param_1];
 ld.param.f32  %f1, [__xla_fp64_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp64_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp64_comparison_param_4];
 mov.u32  %r4, %ntid.x;
 mov.u32  %r5, %ctaid.x;
 mov.u32  %r6, %tid.x;
 mad.lo.s32  %r1, %r4, %r5, %r6;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB2_11;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 8;
 add.s64  %rd8, %rd6, %rd7;
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.f64  %fd1, [%rd10];
 ld.global.f64  %fd2, [%rd8];
 abs.f64  %fd3, %fd2;
 setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
 @%p2 bra  BB2_3;

 abs.f64  %fd5, %fd1;
 setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
 @%p3 bra  BB2_11;

BB2_3:
 {
 .reg .b32 %temp;
 mov.b64  {%temp, %r2}, %fd2;
 }
 and.b32   %r7, %r2, 2147483647;
 setp.ne.s32 %p4, %r7, 2146435072;
 @%p4 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%r8, %temp}, %fd2;
 }
 setp.ne.s32 %p5, %r8, 0;
 @%p5 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%temp, %r3}, %fd1;
 }
 and.b32   %r9, %r3, 2147483647;
 setp.ne.s32 %p6, %r9, 2146435072;
 @%p6 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%r10, %temp}, %fd1;
 }
 setp.ne.s32 %p7, %r10, 0;
 @%p7 bra  BB2_8;

 shr.u32  %r11, %r2, 31;
 cvt.u16.u32 %rs1, %r11;
 shr.u32  %r12, %r3, 31;
 cvt.u16.u32 %rs2, %r12;
 setp.eq.s16 %p8, %rs1, %rs2;
 @%p8 bra  BB2_11;

BB2_8:
 sub.f64  %fd6, %fd2, %fd1;
 abs.f64  %fd7, %fd6;
 abs.f64  %fd8, %fd1;
 max.f64  %fd9, %fd3, %fd8;
 add.f64  %fd10, %fd9, 0d3FF0000000000000;
 div.rn.f64  %fd4, %fd7, %fd10;
 cvt.f64.f32 %fd11, %f1;
 setp.gt.f64 %p9, %fd4, %fd11;
 @%p9 bra  BB2_10;

 abs.f64  %fd12, %fd4;
 setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
 @%p10 bra  BB2_11;

BB2_10:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r13, [%rd11], 1;

BB2_11:
 ret;
}

 // .globl __xla_int8_comparison
.visible .entry __xla_int8_comparison(
 .param .u64 __xla_int8_comparison_param_0,
 .param .u64 __xla_int8_comparison_param_1,
 .param .f32 __xla_int8_comparison_param_2,
 .param .u64 __xla_int8_comparison_param_3,
 .param .u64 __xla_int8_comparison_param_4
)
{
 .reg .pred  %p<10>;
 .reg .f32  %f<42>;
 .reg .b32  %r<23>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd2, [__xla_int8_comparison_param_0];
 ld.param.u64  %rd3, [__xla_int8_comparison_param_1];
 ld.param.f32  %f5, [__xla_int8_comparison_param_2];
 ld.param.u64  %rd4, [__xla_int8_comparison_param_3];
 ld.param.u64  %rd5, [__xla_int8_comparison_param_4];
 cvta.to.global.u64  %rd1, %rd5;
 mov.u32  %r4, %ntid.x;
 mov.u32  %r5, %ctaid.x;
 mov.u32  %r6, %tid.x;
 mad.lo.s32  %r1, %r4, %r5, %r6;
 cvt.s64.s32 %rd6, %r1;
 setp.ge.u64 %p1, %rd6, %rd4;
 @%p1 bra  BB3_13;

 cvta.to.global.u64  %rd7, %rd2;
 mul.wide.s32  %rd8, %r1, 4;
 add.s64  %rd9, %rd7, %rd8;
 cvta.to.global.u64  %rd10, %rd3;
 add.s64  %rd11, %rd10, %rd8;
 ld.global.u32  %r2, [%rd9];
 cvt.s32.s8  %r7, %r2;
 cvt.rn.f32.s32 %f6, %r7;
 ld.global.u32  %r3, [%rd11];
 cvt.s32.s8  %r8, %r3;
 cvt.rn.f32.s32 %f7, %r8;
 sub.f32  %f8, %f6, %f7;
 abs.f32  %f9, %f8;
 abs.f32  %f10, %f6;
 abs.f32  %f11, %f7;
 max.f32  %f12, %f10, %f11;
 add.f32  %f13, %f12, 0f3F800000;
 div.rn.f32  %f1, %f9, %f13;
 setp.gt.f32 %p2, %f1, %f5;
 @%p2 bra  BB3_3;

 abs.f32  %f14, %f1;
 setp.le.f32 %p3, %f14, 0f7F800000;
 @%p3 bra  BB3_4;

BB3_3:
 atom.global.add.u32  %r9, [%rd1], 1;

BB3_4:
 shr.u32  %r10, %r3, 8;
 shr.u32  %r11, %r2, 8;
 cvt.s32.s8  %r12, %r11;
 cvt.rn.f32.s32 %f15, %r12;
 cvt.s32.s8  %r13, %r10;
 cvt.rn.f32.s32 %f16, %r13;
 sub.f32  %f17, %f15, %f16;
 abs.f32  %f18, %f17;
 abs.f32  %f19, %f15;
 abs.f32  %f20, %f16;
 max.f32  %f21, %f19, %f20;
 add.f32  %f22, %f21, 0f3F800000;
 div.rn.f32  %f2, %f18, %f22;
 setp.gt.f32 %p4, %f2, %f5;
 @%p4 bra  BB3_6;

 abs.f32  %f23, %f2;
 setp.le.f32 %p5, %f23, 0f7F800000;
 @%p5 bra  BB3_7;

BB3_6:
 atom.global.add.u32  %r14, [%rd1], 1;

BB3_7:
 shr.u32  %r15, %r3, 16;
 shr.u32  %r16, %r2, 16;
 cvt.s32.s8  %r17, %r16;
 cvt.rn.f32.s32 %f24, %r17;
 cvt.s32.s8  %r18, %r15;
 cvt.rn.f32.s32 %f25, %r18;
 sub.f32  %f26, %f24, %f25;
 abs.f32  %f27, %f26;
 abs.f32  %f28, %f24;
 abs.f32  %f29, %f25;
 max.f32  %f30, %f28, %f29;
 add.f32  %f31, %f30, 0f3F800000;
 div.rn.f32  %f3, %f27, %f31;
 setp.gt.f32 %p6, %f3, %f5;
 @%p6 bra  BB3_9;

 abs.f32  %f32, %f3;
 setp.le.f32 %p7, %f32, 0f7F800000;
 @%p7 bra  BB3_10;

BB3_9:
 atom.global.add.u32  %r19, [%rd1], 1;

BB3_10:
 shr.s32  %r20, %r2, 24;
 cvt.rn.f32.s32 %f33, %r20;
 shr.s32  %r21, %r3, 24;
 cvt.rn.f32.s32 %f34, %r21;
 sub.f32  %f35, %f33, %f34;
 abs.f32  %f36, %f35;
 abs.f32  %f37, %f33;
 abs.f32  %f38, %f34;
 max.f32  %f39, %f37, %f38;
 add.f32  %f40, %f39, 0f3F800000;
 div.rn.f32  %f4, %f36, %f40;
 setp.gt.f32 %p8, %f4, %f5;
 @%p8 bra  BB3_12;

 abs.f32  %f41, %f4;
 setp.le.f32 %p9, %f41, 0f7F800000;
 @%p9 bra  BB3_13;

BB3_12:
 atom.global.add.u32  %r22, [%rd1], 1;

BB3_13:
 ret;
}
)";
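
// The five TypedKernel template arguments below correspond, in order, to the
// kernel parameters buffer_a, buffer_b, rel_error_threshold, buffer_length,
// and mismatch_count from the CUDA signatures above.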
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64, se::DeviceMemory<uint64>>;

// Compares two buffers on the GPU.
//
// Returns `true` if two buffers are equal, `false` otherwise.
template <typename ElementT>
static StatusOr<bool> DeviceCompare(se::Stream* stream,
                                    se::DeviceMemoryBase lhs,
                                    se::DeviceMemoryBase rhs,
                                    const Shape& buffer_shape,
                                    const HloModuleConfig& config,
                                    absl::string_view kernel_name) {
  se::StreamExecutor* executor = stream->parent();

  se::ScopedDeviceMemory<uint64> out_param =
      executor->AllocateOwnedScalar<uint64>();

  stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
  if (lhs.size() != rhs.size()) {
    return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
                         lhs.size(), rhs.size());
  }

  se::DeviceMemory<ElementT> lhs_typed(lhs);
  se::DeviceMemory<ElementT> rhs_typed(rhs);
  uint64 buffer_size = lhs_typed.ElementCount();

  absl::Span<const uint8> compiled_ptx = {};
  StatusOr<absl::Span<const uint8>> compiled_ptx_or =
      se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
                                   buffer_compare_ptx,
                                   PtxOptsFromConfig(config));
  if (compiled_ptx_or.ok()) {
    compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
  } else {
    static absl::once_flag ptxas_not_found_logged;
    absl::call_once(ptxas_not_found_logged, [&]() {
      LOG(WARNING)
          << compiled_ptx_or.status().ToString()
          << "\nRelying on driver to perform ptx compilation. "
          << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
          << " or modifying $PATH can be used to set the location of ptxas"
          << "\nThis message will only be logged once.";
    });
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
      (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
                                   se::DeviceMemory<ElementT>, float, uint64,
                                   se::DeviceMemory<uint64>>(
          kernel_name, buffer_compare_ptx, compiled_ptx)));

  LaunchDimensions dim =
      CalculateLaunchDimensions(buffer_shape, executor->GetDeviceDescription());

  stream->ThenLaunch(se::ThreadDim(dim.threads_per_block()),
                     se::BlockDim(dim.block_count()), *comparison_kernel,
                     lhs_typed, rhs_typed, static_cast<float>(kTolerance),
                     buffer_size, out_param.cref());

  uint64 result = -1;
  CHECK_EQ(out_param->size(), sizeof(result));
  stream->ThenMemcpy(&result, *out_param, sizeof(result));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return result == 0;
}

// Host side comparison code that does the same thing, but reports some of the
// differences as well. It only prints logs for debugging.
//
// Returns true if no differences were seen, false otherwise.
template <typename ElementType, typename ComparisonType>
StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
                           se::DeviceMemoryBase rhs) {
  int64 n = lhs.size() / sizeof(ElementType);
  std::vector<ElementType> host_lhs(n), host_rhs(n);
  stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
  stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());

  const auto canonicalize = [](ComparisonType a) -> ComparisonType {
    if (std::is_same<ElementType, Eigen::half>::value && a) {
      constexpr ComparisonType kMaxFp16Value = 65505.;
      if (std::isnan(a)) {
        return a;
      }
      return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
    }
    return a;
  };
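  // Note: dividing by (max(|lhs|, |rhs|) + 1) rather than by the magnitudes
  // alone keeps the relative error well-defined for near-zero values; this
  // mirrors the device kernels above.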
  int differences_seen = 0;
  for (int64 i = 0; i < n && differences_seen < 10; i++) {
    auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
    auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
    ComparisonType lhs = canonicalize(original_lhs);
    ComparisonType rhs = canonicalize(original_rhs);
    if (std::isnan(lhs) && std::isnan(rhs)) {
      continue;
    }
    if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
      continue;
    }
    if (std::isfinite(lhs) != std::isfinite(rhs) ||
        !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
          kTolerance)) {
      differences_seen++;
      LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
                 << original_rhs;
    }
  }
  return differences_seen == 0;
}

template <typename ElementT, typename ComparisonT>
static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
                                                se::DeviceMemoryBase lhs,
                                                se::DeviceMemoryBase rhs,
                                                const Shape& shape,
                                                const HloModuleConfig& config,
                                                absl::string_view kernel_name) {
  XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
  TF_ASSIGN_OR_RETURN(
      bool result,
      DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));

  if (result) {
    return true;
  }

  TF_ASSIGN_OR_RETURN(bool host_return,
                      (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
  CHECK(host_return == result) << "Different comparison result on GPU vs host";

  return false;
}

StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
                                              se::DeviceMemoryBase lhs,
                                              se::DeviceMemoryBase rhs) const {
  switch (shape_.element_type()) {
    case xla::F16:
      return CompareEqualParameterized<Eigen::half, float>(
          stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
    case xla::F32:
      return CompareEqualParameterized<float, float>(
          stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
    case xla::F64:
      return CompareEqualParameterized<double, double>(
          stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
    case xla::S8:
      return CompareEqualParameterized<int8, float>(
          stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
    default:
      return Unimplemented("Unimplemented element type");
  }
}
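
// Example of typical usage (a sketch; `stream`, `lhs`, and `rhs` are assumed
// to be supplied by the caller):
//
//   BufferComparator comparator(shape, config);
//   TF_ASSIGN_OR_RETURN(bool equal,
//                       comparator.CompareEqual(stream, lhs, rhs));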

BufferComparator::BufferComparator(const Shape& shape,
                                   const HloModuleConfig& config)
    : shape_(shape), config_(config) {
  // Normalize complex shapes: since we treat the passed array as contiguous
  // storage, it does not matter which dimension we double.
  auto double_dim_size = [&]() {
    int64 prev_zero_dim_size = shape_.dimensions(0);
    shape_.set_dimensions(0, prev_zero_dim_size * 2);
  };

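  // For example (an illustrative case, not from the original comments): a C64
  // shape with dimensions {2, 3} becomes an F32 shape with dimensions {4, 3};
  // the underlying byte count is unchanged.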
  if (shape_.element_type() == PrimitiveType::C64) {
    // C64 is just two F32s next to each other.
    shape_.set_element_type(PrimitiveType::F32);
    double_dim_size();
  } else if (shape_.element_type() == PrimitiveType::C128) {
    // C128 is just two F64s next to each other.
    shape_.set_element_type(PrimitiveType::F64);
    double_dim_size();
  }
}

}  // namespace gpu
}  // namespace xla