• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 
21 #include "absl/base/call_once.h"
22 #include "absl/strings/str_replace.h"
23 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
24 #include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/status_macros.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/stream_executor/device_memory.h"
30 #include "tensorflow/stream_executor/gpu/asm_compiler.h"
31 #include "tensorflow/stream_executor/kernel.h"
32 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
33 
34 namespace xla {
35 namespace gpu {
36 
// Relative-error threshold shared by the device kernels (where it is
// narrowed to float in DeviceCompare) and the host-side recomputation in
// HostCompare.  Written as a double literal so the double constant holds
// exactly 0.1 rather than float(0.1) widened to double.
static constexpr double kTolerance = 0.1;
38 
39 // Comparison kernel code: compare two buffers of fp16/fp32/fp64/int8 of length
40 // buffer_length where the relative error does not exceed the passed
41 // rel_error_threshold. Write the number of mismatches into out parameter
42 // mismatch_count.
43 //
44 // NaN's are considered equal, and for half's we clamp all numbers to largest
45 // and smallest numbers representable to avoid miscomparisons due to overflows.
46 //
47 // The PTX below is compiled from the following CUDA code:
48 //
49 // #include<cuda_fp16.h>
50 // extern "C" { // avoid name mangling
51 // __device__ float __xla_buffer_comparator_canonicalize(float input) {
52 //   // All fp16 infinities are treated as 65505 or -65505, in order to avoid
53 //   // differences due to overflows.
54 //   return isnan(input) ? input : max(-65505.0f, min(input, 65505.0f));
55 // }
56 
57 // __device__ float __xla_buffer_comparator_extract_int8(int pack) {
58 //   // Extract the lower 8 bits from pack and convert it to float
59 //   const unsigned int bit_mask = 0xff;
60 //   unsigned int bits = pack & bit_mask;
61 //   char* int8_ptr = (char*)&bits;
62 //   return __int2float_rn(*int8_ptr);
63 // }
64 
65 // __global__ void __xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
66 //                                       float rel_error_threshold,
67 //                                       unsigned long long buffer_length,
68 //                                       int* mismatch_count) {
69 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
70 //   if (idx >= buffer_length) return;
71 //   float elem_a = __half2float(buffer_a[idx]);
72 //   float elem_b = __half2float(buffer_b[idx]);
73 //   elem_a = __xla_buffer_comparator_canonicalize(elem_a);
74 //   elem_b = __xla_buffer_comparator_canonicalize(elem_b);
75 //   if (isnan(elem_a) && isnan(elem_b)) return;
76 //   float rel_error = abs(elem_a - elem_b)
77 //       / (max(abs(elem_a), abs(elem_b)) + 1);
78 //   if (rel_error > rel_error_threshold || isnan(rel_error))
79 //     atomicAdd(mismatch_count, 1);
80 // }
81 
82 // __global__ void __xla_fp32_comparison(float* buffer_a, float* buffer_b,
83 //                                       float rel_error_threshold,
84 //                                       unsigned long long buffer_length,
85 //                                       int* mismatch_count) {
86 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
87 //   if (idx >= buffer_length) return;
88 //   float elem_a = buffer_a[idx];
89 //   float elem_b = buffer_b[idx];
90 //   if (isnan(elem_a) && isnan(elem_b)) return;
91 //   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
92 //     return;
93 //   float rel_error = abs(elem_a - elem_b)
94 //       / (max(abs(elem_a), abs(elem_b)) + 1);
95 //   if (rel_error > rel_error_threshold || isnan(rel_error))
96 //     atomicAdd(mismatch_count, 1);
97 // }
98 
99 // __global__ void __xla_fp64_comparison(double* buffer_a, double* buffer_b,
100 //                                       float rel_error_threshold,
101 //                                       unsigned long long buffer_length,
102 //                                       int* mismatch_count) {
103 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
104 //   if (idx >= buffer_length) return;
105 //   double elem_a = buffer_a[idx];
106 //   double elem_b = buffer_b[idx];
107 //   if (isnan(elem_a) && isnan(elem_b)) return;
108 //   if (isinf(elem_a) && isinf(elem_b) && signbit(elem_a) == signbit(elem_b))
109 //     return;
110 //   double rel_error = abs(elem_a - elem_b)
111 //       / (max(abs(elem_a), abs(elem_b)) + 1);
112 //   if (rel_error > rel_error_threshold || isnan(rel_error))
113 //     atomicAdd(mismatch_count, 1);
114 // }
115 
116 // __global__ void __xla_int8_comparison(int* buffer_a, int* buffer_b,
117 //                                       float rel_error_threshold,
118 //                                       unsigned long long buffer_length,
119 //                                       int* mismatch_count) {
120 //   int idx = threadIdx.x + blockIdx.x * blockDim.x;
121 //   if (idx >= buffer_length) return;
122 //   int pack_a = buffer_a[idx];
123 //   int pack_b = buffer_b[idx];
124 //   for(int i = 0; i < 4; ++i) {
125 //     float elem_a = __xla_buffer_comparator_extract_int8(pack_a);
126 //     float elem_b = __xla_buffer_comparator_extract_int8(pack_b);
127 //     float rel_error = abs(elem_a - elem_b)
128 //         / (max(abs(elem_a), abs(elem_b)) + 1);
129 //     if (rel_error > rel_error_threshold || isnan(rel_error))
130 //         atomicAdd(mismatch_count, 1);
131 //     pack_a >>= 8;
132 //     pack_b >>= 8;
133 //   }
134 // }
135 // } // end extern declaration.
// PTX assembly (ISA 4.2, target sm_30) compiled from the CUDA source shown
// in the comment above.  DeviceCompare() below first tries to compile this
// text with ptxas (se::CompileGpuAsmOrGetCached) and otherwise hands the raw
// PTX to the driver.  The four .entry names must stay in sync with the
// kernel_name strings passed from BufferComparator::CompareEqual.
static const char* buffer_compare_ptx = R"(
.version 4.2
.target sm_30
.address_size 64

 // .globl __xla_fp16_comparison

.visible .entry __xla_fp16_comparison(
 .param .u64 __xla_fp16_comparison_param_0,
 .param .u64 __xla_fp16_comparison_param_1,
 .param .f32 __xla_fp16_comparison_param_2,
 .param .u64 __xla_fp16_comparison_param_3,
 .param .u64 __xla_fp16_comparison_param_4
)
{
 .reg .pred  %p<9>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<28>;
 .reg .b32  %r<6>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp16_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp16_comparison_param_1];
 ld.param.f32  %f10, [__xla_fp16_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp16_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp16_comparison_param_4];
 mov.u32  %r2, %ntid.x;
 mov.u32  %r3, %ctaid.x;
 mov.u32  %r4, %tid.x;
 mad.lo.s32  %r1, %r2, %r3, %r4;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB0_9;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 2;
 add.s64  %rd8, %rd6, %rd7;
 ld.global.u16  %rs1, [%rd8];
 // inline asm
 {  cvt.f32.f16 %f26, %rs1;}

 // inline asm
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.u16  %rs2, [%rd10];
 // inline asm
 {  cvt.f32.f16 %f27, %rs2;}

 // inline asm
 abs.f32  %f13, %f26;
 setp.gtu.f32 %p2, %f13, 0f7F800000;
 @%p2 bra  BB0_3;

 mov.f32  %f14, 0f477FE100;
 min.f32  %f15, %f26, %f14;
 mov.f32  %f16, 0fC77FE100;
 max.f32  %f26, %f16, %f15;

BB0_3:
 abs.f32  %f17, %f27;
 setp.gtu.f32 %p3, %f17, 0f7F800000;
 @%p3 bra  BB0_5;

 mov.f32  %f18, 0f477FE100;
 min.f32  %f19, %f27, %f18;
 mov.f32  %f20, 0fC77FE100;
 max.f32  %f27, %f20, %f19;

BB0_5:
 abs.f32  %f7, %f26;
 setp.gtu.f32 %p4, %f7, 0f7F800000;
 abs.f32  %f8, %f27;
 setp.gtu.f32 %p5, %f8, 0f7F800000;
 and.pred   %p6, %p4, %p5;
 @%p6 bra  BB0_9;

 sub.f32  %f21, %f26, %f27;
 abs.f32  %f22, %f21;
 max.f32  %f23, %f7, %f8;
 add.f32  %f24, %f23, 0f3F800000;
 div.rn.f32  %f9, %f22, %f24;
 setp.gt.f32 %p7, %f9, %f10;
 @%p7 bra  BB0_8;

 abs.f32  %f25, %f9;
 setp.le.f32 %p8, %f25, 0f7F800000;
 @%p8 bra  BB0_9;

BB0_8:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r5, [%rd11], 1;

BB0_9:
 ret;
}

 // .globl __xla_fp32_comparison
.visible .entry __xla_fp32_comparison(
 .param .u64 __xla_fp32_comparison_param_0,
 .param .u64 __xla_fp32_comparison_param_1,
 .param .f32 __xla_fp32_comparison_param_2,
 .param .u64 __xla_fp32_comparison_param_3,
 .param .u64 __xla_fp32_comparison_param_4
)
{
 .reg .pred  %p<10>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<13>;
 .reg .b32  %r<10>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp32_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp32_comparison_param_1];
 ld.param.f32  %f6, [__xla_fp32_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp32_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp32_comparison_param_4];
 mov.u32  %r2, %ntid.x;
 mov.u32  %r3, %ctaid.x;
 mov.u32  %r4, %tid.x;
 mad.lo.s32  %r1, %r2, %r3, %r4;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB1_8;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 4;
 add.s64  %rd8, %rd6, %rd7;
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.f32  %f1, [%rd10];
 ld.global.f32  %f2, [%rd8];
 abs.f32  %f3, %f2;
 setp.le.f32 %p2, %f3, 0f7F800000;
 @%p2 bra  BB1_3;

 abs.f32  %f7, %f1;
 setp.gtu.f32 %p3, %f7, 0f7F800000;
 @%p3 bra  BB1_8;

BB1_3:
 setp.neu.f32 %p4, %f3, 0f7F800000;
 abs.f32  %f4, %f1;
 setp.neu.f32 %p5, %f4, 0f7F800000;
 or.pred   %p6, %p4, %p5;
 @%p6 bra  BB1_5;

 mov.b32   %r5, %f2;
 shr.u32  %r6, %r5, 31;
 cvt.u16.u32 %rs1, %r6;
 mov.b32   %r7, %f1;
 shr.u32  %r8, %r7, 31;
 cvt.u16.u32 %rs2, %r8;
 setp.eq.s16 %p7, %rs1, %rs2;
 @%p7 bra  BB1_8;

BB1_5:
 sub.f32  %f8, %f2, %f1;
 abs.f32  %f9, %f8;
 max.f32  %f10, %f3, %f4;
 add.f32  %f11, %f10, 0f3F800000;
 div.rn.f32  %f5, %f9, %f11;
 setp.gt.f32 %p8, %f5, %f6;
 @%p8 bra  BB1_7;

 abs.f32  %f12, %f5;
 setp.le.f32 %p9, %f12, 0f7F800000;
 @%p9 bra  BB1_8;

BB1_7:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r9, [%rd11], 1;

BB1_8:
 ret;
}

 // .globl __xla_fp64_comparison
.visible .entry __xla_fp64_comparison(
 .param .u64 __xla_fp64_comparison_param_0,
 .param .u64 __xla_fp64_comparison_param_1,
 .param .f32 __xla_fp64_comparison_param_2,
 .param .u64 __xla_fp64_comparison_param_3,
 .param .u64 __xla_fp64_comparison_param_4
)
{
 .reg .pred  %p<11>;
 .reg .b16  %rs<3>;
 .reg .f32  %f<2>;
 .reg .b32  %r<14>;
 .reg .f64  %fd<13>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd1, [__xla_fp64_comparison_param_0];
 ld.param.u64  %rd2, [__xla_fp64_comparison_param_1];
 ld.param.f32  %f1, [__xla_fp64_comparison_param_2];
 ld.param.u64  %rd4, [__xla_fp64_comparison_param_3];
 ld.param.u64  %rd3, [__xla_fp64_comparison_param_4];
 mov.u32  %r4, %ntid.x;
 mov.u32  %r5, %ctaid.x;
 mov.u32  %r6, %tid.x;
 mad.lo.s32  %r1, %r4, %r5, %r6;
 cvt.s64.s32 %rd5, %r1;
 setp.ge.u64 %p1, %rd5, %rd4;
 @%p1 bra  BB2_11;

 cvta.to.global.u64  %rd6, %rd1;
 mul.wide.s32  %rd7, %r1, 8;
 add.s64  %rd8, %rd6, %rd7;
 cvta.to.global.u64  %rd9, %rd2;
 add.s64  %rd10, %rd9, %rd7;
 ld.global.f64  %fd1, [%rd10];
 ld.global.f64  %fd2, [%rd8];
 abs.f64  %fd3, %fd2;
 setp.le.f64 %p2, %fd3, 0d7FF0000000000000;
 @%p2 bra  BB2_3;

 abs.f64  %fd5, %fd1;
 setp.gtu.f64 %p3, %fd5, 0d7FF0000000000000;
 @%p3 bra  BB2_11;

BB2_3:
 {
 .reg .b32 %temp;
 mov.b64  {%temp, %r2}, %fd2;
 }
 and.b32   %r7, %r2, 2147483647;
 setp.ne.s32 %p4, %r7, 2146435072;
 @%p4 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%r8, %temp}, %fd2;
 }
 setp.ne.s32 %p5, %r8, 0;
 @%p5 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%temp, %r3}, %fd1;
 }
 and.b32   %r9, %r3, 2147483647;
 setp.ne.s32 %p6, %r9, 2146435072;
 @%p6 bra  BB2_8;

 {
 .reg .b32 %temp;
 mov.b64  {%r10, %temp}, %fd1;
 }
 setp.ne.s32 %p7, %r10, 0;
 @%p7 bra  BB2_8;

 shr.u32  %r11, %r2, 31;
 cvt.u16.u32 %rs1, %r11;
 shr.u32  %r12, %r3, 31;
 cvt.u16.u32 %rs2, %r12;
 setp.eq.s16 %p8, %rs1, %rs2;
 @%p8 bra  BB2_11;

BB2_8:
 sub.f64  %fd6, %fd2, %fd1;
 abs.f64  %fd7, %fd6;
 abs.f64  %fd8, %fd1;
 max.f64  %fd9, %fd3, %fd8;
 add.f64  %fd10, %fd9, 0d3FF0000000000000;
 div.rn.f64  %fd4, %fd7, %fd10;
 cvt.f64.f32 %fd11, %f1;
 setp.gt.f64 %p9, %fd4, %fd11;
 @%p9 bra  BB2_10;

 abs.f64  %fd12, %fd4;
 setp.le.f64 %p10, %fd12, 0d7FF0000000000000;
 @%p10 bra  BB2_11;

BB2_10:
 cvta.to.global.u64  %rd11, %rd3;
 atom.global.add.u32  %r13, [%rd11], 1;

BB2_11:
 ret;
}

 // .globl __xla_int8_comparison
.visible .entry __xla_int8_comparison(
 .param .u64 __xla_int8_comparison_param_0,
 .param .u64 __xla_int8_comparison_param_1,
 .param .f32 __xla_int8_comparison_param_2,
 .param .u64 __xla_int8_comparison_param_3,
 .param .u64 __xla_int8_comparison_param_4
)
{
 .reg .pred  %p<10>;
 .reg .f32  %f<42>;
 .reg .b32  %r<23>;
 .reg .b64  %rd<12>;


 ld.param.u64  %rd2, [__xla_int8_comparison_param_0];
 ld.param.u64  %rd3, [__xla_int8_comparison_param_1];
 ld.param.f32  %f5, [__xla_int8_comparison_param_2];
 ld.param.u64  %rd4, [__xla_int8_comparison_param_3];
 ld.param.u64  %rd5, [__xla_int8_comparison_param_4];
 cvta.to.global.u64  %rd1, %rd5;
 mov.u32  %r4, %ntid.x;
 mov.u32  %r5, %ctaid.x;
 mov.u32  %r6, %tid.x;
 mad.lo.s32  %r1, %r4, %r5, %r6;
 cvt.s64.s32 %rd6, %r1;
 setp.ge.u64 %p1, %rd6, %rd4;
 @%p1 bra  BB3_13;

 cvta.to.global.u64  %rd7, %rd2;
 mul.wide.s32  %rd8, %r1, 4;
 add.s64  %rd9, %rd7, %rd8;
 cvta.to.global.u64  %rd10, %rd3;
 add.s64  %rd11, %rd10, %rd8;
 ld.global.u32  %r2, [%rd9];
 cvt.s32.s8  %r7, %r2;
 cvt.rn.f32.s32 %f6, %r7;
 ld.global.u32  %r3, [%rd11];
 cvt.s32.s8  %r8, %r3;
 cvt.rn.f32.s32 %f7, %r8;
 sub.f32  %f8, %f6, %f7;
 abs.f32  %f9, %f8;
 abs.f32  %f10, %f6;
 abs.f32  %f11, %f7;
 max.f32  %f12, %f10, %f11;
 add.f32  %f13, %f12, 0f3F800000;
 div.rn.f32  %f1, %f9, %f13;
 setp.gt.f32 %p2, %f1, %f5;
 @%p2 bra  BB3_3;

 abs.f32  %f14, %f1;
 setp.le.f32 %p3, %f14, 0f7F800000;
 @%p3 bra  BB3_4;

BB3_3:
 atom.global.add.u32  %r9, [%rd1], 1;

BB3_4:
 shr.u32  %r10, %r3, 8;
 shr.u32  %r11, %r2, 8;
 cvt.s32.s8  %r12, %r11;
 cvt.rn.f32.s32 %f15, %r12;
 cvt.s32.s8  %r13, %r10;
 cvt.rn.f32.s32 %f16, %r13;
 sub.f32  %f17, %f15, %f16;
 abs.f32  %f18, %f17;
 abs.f32  %f19, %f15;
 abs.f32  %f20, %f16;
 max.f32  %f21, %f19, %f20;
 add.f32  %f22, %f21, 0f3F800000;
 div.rn.f32  %f2, %f18, %f22;
 setp.gt.f32 %p4, %f2, %f5;
 @%p4 bra  BB3_6;

 abs.f32  %f23, %f2;
 setp.le.f32 %p5, %f23, 0f7F800000;
 @%p5 bra  BB3_7;

BB3_6:
 atom.global.add.u32  %r14, [%rd1], 1;

BB3_7:
 shr.u32  %r15, %r3, 16;
 shr.u32  %r16, %r2, 16;
 cvt.s32.s8  %r17, %r16;
 cvt.rn.f32.s32 %f24, %r17;
 cvt.s32.s8  %r18, %r15;
 cvt.rn.f32.s32 %f25, %r18;
 sub.f32  %f26, %f24, %f25;
 abs.f32  %f27, %f26;
 abs.f32  %f28, %f24;
 abs.f32  %f29, %f25;
 max.f32  %f30, %f28, %f29;
 add.f32  %f31, %f30, 0f3F800000;
 div.rn.f32  %f3, %f27, %f31;
 setp.gt.f32 %p6, %f3, %f5;
 @%p6 bra  BB3_9;

 abs.f32  %f32, %f3;
 setp.le.f32 %p7, %f32, 0f7F800000;
 @%p7 bra  BB3_10;

BB3_9:
 atom.global.add.u32  %r19, [%rd1], 1;

BB3_10:
 shr.s32  %r20, %r2, 24;
 cvt.rn.f32.s32 %f33, %r20;
 shr.s32  %r21, %r3, 24;
 cvt.rn.f32.s32 %f34, %r21;
 sub.f32  %f35, %f33, %f34;
 abs.f32  %f36, %f35;
 abs.f32  %f37, %f33;
 abs.f32  %f38, %f34;
 max.f32  %f39, %f37, %f38;
 add.f32  %f40, %f39, 0f3F800000;
 div.rn.f32  %f4, %f36, %f40;
 setp.gt.f32 %p8, %f4, %f5;
 @%p8 bra  BB3_12;

 abs.f32  %f41, %f4;
 setp.le.f32 %p9, %f41, 0f7F800000;
 @%p9 bra  BB3_13;

BB3_12:
 atom.global.add.u32  %r22, [%rd1], 1;

BB3_13:
 ret;
}
)";
551 
// StreamExecutor-side signature of the comparison kernels above: the two
// typed buffers, the relative-error threshold, the element count, and the
// device buffer receiving the mismatch count.
// NOTE(review): the PTX bumps the counter with atom.global.add.u32 (32-bit)
// while this slot is typed uint64; DeviceCompare zero-initializes the full
// 8 bytes and only tests the read-back value against zero, so this works on
// little-endian GPUs — confirm before changing either type.
template <typename ElementT>
using ComparisonKernelT =
    se::TypedKernel<se::DeviceMemory<ElementT>, se::DeviceMemory<ElementT>,
                    float, uint64, se::DeviceMemory<uint64>>;
556 
// Compares two buffers on the GPU.
//
// Returns `true` if two buffers are equal, `false` otherwise.
//
// Launches the PTX comparison kernel named `kernel_name` (one of the four
// __xla_*_comparison entry points in buffer_compare_ptx) over `lhs` and
// `rhs`, then reads back the device-side mismatch counter.
//
// Args:
//   stream:       stream on which all device work is enqueued.
//   lhs, rhs:     device buffers to compare; must have equal byte sizes.
//   buffer_shape: shape used to size the kernel launch.
//   config:       module config; forwarded to the PTX compilation options.
//   kernel_name:  PTX entry point to launch; must correspond to ElementT.
template <typename ElementT>
static StatusOr<bool> DeviceCompare(se::Stream* stream,
                                    se::DeviceMemoryBase lhs,
                                    se::DeviceMemoryBase rhs,
                                    const Shape& buffer_shape,
                                    const HloModuleConfig& config,
                                    absl::string_view kernel_name) {
  se::StreamExecutor* executor = stream->parent();

  // Device-side mismatch counter, owned (and eventually freed) by out_param.
  se::ScopedDeviceMemory<uint64> out_param =
      executor->AllocateOwnedScalar<uint64>();

  // Zero the counter before the kernel atomically increments it.
  stream->ThenMemZero(out_param.ptr(), sizeof(uint64));
  if (lhs.size() != rhs.size()) {
    return InternalError("Mismatched buffer size: %d bytes vs. %d bytes",
                         lhs.size(), rhs.size());
  }

  se::DeviceMemory<ElementT> lhs_typed(lhs);
  se::DeviceMemory<ElementT> rhs_typed(rhs);
  // Element count passed to the kernel as buffer_length.
  // NOTE(review): the __xla_int8_comparison PTX consumes one 4-byte pack per
  // thread (mul.wide.s32 ..., 4) while this is the int8 element count —
  // confirm the intended unit of buffer_length for the S8 path.
  uint64 buffer_size = lhs_typed.ElementCount();

  // Try to compile the PTX with ptxas up front; on failure, warn once per
  // process and fall through with an empty compiled_ptx (the raw PTX text is
  // still handed to CreateTypedKernel below, presumably compiled by the
  // driver in that case).
  absl::Span<const uint8> compiled_ptx = {};
  StatusOr<absl::Span<const uint8>> compiled_ptx_or =
      se::CompileGpuAsmOrGetCached(executor->device_ordinal(),
                                   buffer_compare_ptx,
                                   PtxOptsFromConfig(config));
  if (compiled_ptx_or.ok()) {
    compiled_ptx = compiled_ptx_or.ConsumeValueOrDie();
  } else {
    static absl::once_flag ptxas_not_found_logged;
    absl::call_once(ptxas_not_found_logged, [&]() {
      LOG(WARNING)
          << compiled_ptx_or.status().ToString()
          << "\nRelying on driver to perform ptx compilation. "
          << "\nSetting XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda "
          << " or modifying $PATH can be used to set the location of ptxas"
          << "\nThis message will only be logged once.";
    });
  }

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<ComparisonKernelT<ElementT>> comparison_kernel,
      (executor->CreateTypedKernel<se::DeviceMemory<ElementT>,
                                   se::DeviceMemory<ElementT>, float, uint64,
                                   se::DeviceMemory<uint64>>(
          kernel_name, buffer_compare_ptx, compiled_ptx)));

  // Size the launch from the buffer shape and this device's limits.
  GpuDeviceInfo gpu_device_info;
  gpu_device_info.threads_per_block_limit =
      executor->GetDeviceDescription().threads_per_block_limit();
  gpu_device_info.threads_per_warp =
      executor->GetDeviceDescription().threads_per_warp();
  gpu_device_info.shared_memory_per_block =
      executor->GetDeviceDescription().shared_memory_per_block();
  gpu_device_info.threads_per_core_limit =
      executor->GetDeviceDescription().threads_per_core_limit();
  gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
  LaunchDimensions dim =
      CalculateLaunchDimensions(buffer_shape, gpu_device_info);

  LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
  LaunchDimensions::Dim3D block_counts = dim.block_counts();
  // kTolerance is narrowed to float here to match the kernel's f32 param.
  stream->ThenLaunch(
      se::ThreadDim(thread_counts.x, thread_counts.y, thread_counts.z),
      se::BlockDim(block_counts.x, block_counts.y, block_counts.z),
      *comparison_kernel, lhs_typed, rhs_typed, static_cast<float>(kTolerance),
      buffer_size, out_param.cref());

  // Read back the mismatch count; the copy is enqueued on the stream, so we
  // must block before `result` is safe to inspect.
  uint64 result = -1;
  CHECK_EQ(out_param->size(), sizeof(result));
  stream->ThenMemcpy(&result, *out_param, sizeof(result));
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  return result == 0;
}
635 
636 // Host side comparison code that does the same thing, but reports some of the
637 // differences as well. It only print logs for debugging.
638 //
639 // Returns true if no differences were seen, false otherwise.
640 template <typename ElementType, typename ComparisonType>
HostCompare(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs)641 StatusOr<bool> HostCompare(se::Stream* stream, se::DeviceMemoryBase lhs,
642                            se::DeviceMemoryBase rhs) {
643   int64 n = lhs.size() / sizeof(ElementType);
644   std::vector<ElementType> host_lhs(n), host_rhs(n);
645   stream->ThenMemcpy(host_lhs.data(), lhs, lhs.size());
646   stream->ThenMemcpy(host_rhs.data(), rhs, rhs.size());
647   TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
648 
649   const auto canonicalize = [](ComparisonType a) -> ComparisonType {
650     if (std::is_same<ElementType, Eigen::half>::value && a) {
651       constexpr ComparisonType kMaxFp16Value = 65505.;
652       if (std::isnan(a)) {
653         return a;
654       }
655       return std::max(-kMaxFp16Value, std::min(a, kMaxFp16Value));
656     }
657     return a;
658   };
659   int differences_seen = 0;
660   for (int64 i = 0; i < n && differences_seen < 10; i++) {
661     auto original_lhs = static_cast<ComparisonType>(host_lhs[i]);
662     auto original_rhs = static_cast<ComparisonType>(host_rhs[i]);
663     ComparisonType lhs = canonicalize(original_lhs);
664     ComparisonType rhs = canonicalize(original_rhs);
665     if (std::isnan(lhs) && std::isnan(rhs)) {
666       continue;
667     }
668     if (std::isinf(lhs) && std::isinf(rhs) && lhs == rhs) {
669       continue;
670     }
671     if (std::isfinite(lhs) != std::isfinite(rhs) ||
672         !(std::abs(lhs - rhs) / (std::max(std::abs(lhs), std::abs(rhs)) + 1) <
673           kTolerance)) {
674       differences_seen++;
675       LOG(ERROR) << "Difference at " << i << ": " << original_lhs << " vs "
676                  << original_rhs;
677     }
678   }
679   return differences_seen == 0;
680 }
681 
682 template <typename ElementT, typename ComparisonT>
CompareEqualParameterized(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs,const Shape & shape,const HloModuleConfig & config,absl::string_view kernel_name)683 static StatusOr<bool> CompareEqualParameterized(se::Stream* stream,
684                                                 se::DeviceMemoryBase lhs,
685                                                 se::DeviceMemoryBase rhs,
686                                                 const Shape& shape,
687                                                 const HloModuleConfig& config,
688                                                 absl::string_view kernel_name) {
689   XLA_SCOPED_LOGGING_TIMER("BufferComparator::CompareEqual");
690   TF_ASSIGN_OR_RETURN(
691       bool result,
692       DeviceCompare<ElementT>(stream, lhs, rhs, shape, config, kernel_name));
693 
694   if (result) {
695     return true;
696   }
697 
698   TF_ASSIGN_OR_RETURN(bool host_return,
699                       (HostCompare<ElementT, ComparisonT>(stream, lhs, rhs)));
700   CHECK(host_return == result) << "Different comparison result on GPU vs host";
701 
702   return false;
703 }
704 
CompareEqual(se::Stream * stream,se::DeviceMemoryBase lhs,se::DeviceMemoryBase rhs) const705 StatusOr<bool> BufferComparator::CompareEqual(se::Stream* stream,
706                                               se::DeviceMemoryBase lhs,
707                                               se::DeviceMemoryBase rhs) const {
708   switch (shape_.element_type()) {
709     case xla::F16:
710       return CompareEqualParameterized<Eigen::half, float>(
711           stream, lhs, rhs, shape_, config_, "__xla_fp16_comparison");
712     case xla::F32:
713       return CompareEqualParameterized<float, float>(
714           stream, lhs, rhs, shape_, config_, "__xla_fp32_comparison");
715     case xla::F64:
716       return CompareEqualParameterized<double, double>(
717           stream, lhs, rhs, shape_, config_, "__xla_fp64_comparison");
718     case xla::S8:
719       return CompareEqualParameterized<int8, float>(
720           stream, lhs, rhs, shape_, config_, "__xla_int8_comparison");
721     default:
722       return Unimplemented("Unimplemented element type");
723   }
724 }
725 
BufferComparator(const Shape & shape,const HloModuleConfig & config)726 BufferComparator::BufferComparator(const Shape& shape,
727                                    const HloModuleConfig& config)
728     : shape_(shape), config_(config) {
729   // Normalize complex shapes: since we treat the passed array as a contiguous
730   // storage it does not matter which dimension are we doubling.
731   auto double_dim_size = [&]() {
732     int64 prev_zero_dim_size = shape_.dimensions(0);
733     shape_.set_dimensions(0, prev_zero_dim_size * 2);
734   };
735 
736   if (shape_.element_type() == PrimitiveType::C64) {
737     // C64 is just two F32s next to each other.
738     shape_.set_element_type(PrimitiveType::F32);
739     double_dim_size();
740   } else if (shape_.element_type() == PrimitiveType::C128) {
741     // C128 is just two F64s next to each other.
742     shape_.set_element_type(PrimitiveType::F64);
743     double_dim_size();
744   }
745 }
746 
747 }  // namespace gpu
748 }  // namespace xla
749