• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2016 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir.h"
25 #include "nir_builder.h"
26 
/* Build op `name` on builder `b`: dispatch to the 32-bit lowering helper
 * when the shader's lower_int64_options request that nir_op_##name be
 * lowered, otherwise emit the native 64-bit op.  The three variants differ
 * only in the naming convention of the helper they dispatch to.
 */
#define COND_LOWER_OP(b, name, ...)                    \
   (b->shader->options->lower_int64_options &          \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name##64(b, __VA_ARGS__)               \
      : nir_##name(b, __VA_ARGS__)

/* Comparisons share a single helper that takes the op as a parameter. */
#define COND_LOWER_CMP(b, name, ...)                       \
   (b->shader->options->lower_int64_options &              \
    nir_lower_int64_op_to_options_mask(nir_op_##name))     \
      ? lower_int64_compare(b, nir_op_##name, __VA_ARGS__) \
      : nir_##name(b, __VA_ARGS__)

/* Casts use helpers named without a "64" suffix (e.g. lower_u2u32). */
#define COND_LOWER_CAST(b, name, ...)                  \
   (b->shader->options->lower_int64_options &          \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name(b, __VA_ARGS__)                   \
      : nir_##name(b, __VA_ARGS__)
44 
45 static nir_def *
lower_b2i64(nir_builder * b,nir_def * x)46 lower_b2i64(nir_builder *b, nir_def *x)
47 {
48    return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
49 }
50 
51 static nir_def *
lower_i2i8(nir_builder * b,nir_def * x)52 lower_i2i8(nir_builder *b, nir_def *x)
53 {
54    return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
55 }
56 
57 static nir_def *
lower_i2i16(nir_builder * b,nir_def * x)58 lower_i2i16(nir_builder *b, nir_def *x)
59 {
60    return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
61 }
62 
63 static nir_def *
lower_i2i32(nir_builder * b,nir_def * x)64 lower_i2i32(nir_builder *b, nir_def *x)
65 {
66    return nir_unpack_64_2x32_split_x(b, x);
67 }
68 
69 static nir_def *
lower_i2i64(nir_builder * b,nir_def * x)70 lower_i2i64(nir_builder *b, nir_def *x)
71 {
72    nir_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
73    return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
74 }
75 
76 static nir_def *
lower_u2u8(nir_builder * b,nir_def * x)77 lower_u2u8(nir_builder *b, nir_def *x)
78 {
79    return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
80 }
81 
82 static nir_def *
lower_u2u16(nir_builder * b,nir_def * x)83 lower_u2u16(nir_builder *b, nir_def *x)
84 {
85    return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
86 }
87 
88 static nir_def *
lower_u2u32(nir_builder * b,nir_def * x)89 lower_u2u32(nir_builder *b, nir_def *x)
90 {
91    return nir_unpack_64_2x32_split_x(b, x);
92 }
93 
94 static nir_def *
lower_u2u64(nir_builder * b,nir_def * x)95 lower_u2u64(nir_builder *b, nir_def *x)
96 {
97    nir_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
98    return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
99 }
100 
101 static nir_def *
lower_bcsel64(nir_builder * b,nir_def * cond,nir_def * x,nir_def * y)102 lower_bcsel64(nir_builder *b, nir_def *cond, nir_def *x, nir_def *y)
103 {
104    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
105    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
106    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
107    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
108 
109    return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
110                                  nir_bcsel(b, cond, x_hi, y_hi));
111 }
112 
113 static nir_def *
lower_inot64(nir_builder * b,nir_def * x)114 lower_inot64(nir_builder *b, nir_def *x)
115 {
116    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
117    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
118 
119    return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
120 }
121 
122 static nir_def *
lower_iand64(nir_builder * b,nir_def * x,nir_def * y)123 lower_iand64(nir_builder *b, nir_def *x, nir_def *y)
124 {
125    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
126    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
127    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
128    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
129 
130    return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
131                                  nir_iand(b, x_hi, y_hi));
132 }
133 
134 static nir_def *
lower_ior64(nir_builder * b,nir_def * x,nir_def * y)135 lower_ior64(nir_builder *b, nir_def *x, nir_def *y)
136 {
137    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
138    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
139    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
140    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
141 
142    return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
143                                  nir_ior(b, x_hi, y_hi));
144 }
145 
146 static nir_def *
lower_ixor64(nir_builder * b,nir_def * x,nir_def * y)147 lower_ixor64(nir_builder *b, nir_def *x, nir_def *y)
148 {
149    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
150    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
151    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
152    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
153 
154    return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
155                                  nir_ixor(b, x_hi, y_hi));
156 }
157 
static nir_def *
lower_ishl64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   /* 64-bit shift counts are defined modulo 64. */
   y = nir_iand_imm(b, y, 0x3f);

   /* |32 - y| serves as the cross-dword shift amount for both the y < 32
    * and y >= 32 branches above.
    */
   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                             nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                             nir_ishl(b, x_lo, reverse_count));

   /* When y == 0, reverse_count is 32 — an out-of-range 32-bit shift — so
    * select the untouched input for that case instead of either result.
    */
   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}
202 
static nir_def *
lower_ishr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t  hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo);
    *    }
    * }
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   /* 64-bit shift counts are defined modulo 64. */
   y = nir_iand_imm(b, y, 0x3f);

   /* |32 - y| serves as the cross-dword shift amount for both branches. */
   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   /* For y >= 32 the high dword is the replicated sign bit and the low
    * dword is the (arithmetically) shifted high dword.
    */
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                             nir_ishr_imm(b, x_hi, 31));

   /* When y == 0, reverse_count is 32 — an out-of-range 32-bit shift — so
    * select the untouched input for that case instead of either result.
    */
   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}
249 
static nir_def *
lower_ushr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(0, hi_shifted_lo);
    *    }
    * }
    */

   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   /* 64-bit shift counts are defined modulo 64. */
   y = nir_iand_imm(b, y, 0x3f);

   /* |32 - y| serves as the cross-dword shift amount for both branches. */
   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                             nir_imm_int(b, 0));

   /* When y == 0, reverse_count is 32 — an out-of-range 32-bit shift — so
    * select the untouched input for that case instead of either result.
    */
   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}
295 
296 static nir_def *
lower_iadd64(nir_builder * b,nir_def * x,nir_def * y)297 lower_iadd64(nir_builder *b, nir_def *x, nir_def *y)
298 {
299    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
300    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
301    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
302    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
303 
304    nir_def *res_lo = nir_iadd(b, x_lo, y_lo);
305    nir_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
306    nir_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));
307 
308    return nir_pack_64_2x32_split(b, res_lo, res_hi);
309 }
310 
311 static nir_def *
lower_isub64(nir_builder * b,nir_def * x,nir_def * y)312 lower_isub64(nir_builder *b, nir_def *x, nir_def *y)
313 {
314    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
315    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
316    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
317    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
318 
319    nir_def *res_lo = nir_isub(b, x_lo, y_lo);
320    nir_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
321    nir_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);
322 
323    return nir_pack_64_2x32_split(b, res_lo, res_hi);
324 }
325 
326 static nir_def *
lower_ineg64(nir_builder * b,nir_def * x)327 lower_ineg64(nir_builder *b, nir_def *x)
328 {
329    /* Since isub is the same number of instructions (with better dependencies)
330     * as iadd, subtraction is actually more efficient for ineg than the usual
331     * 2's complement "flip the bits and add one".
332     */
333    return lower_isub64(b, nir_imm_int64(b, 0), x);
334 }
335 
336 static nir_def *
lower_iabs64(nir_builder * b,nir_def * x)337 lower_iabs64(nir_builder *b, nir_def *x)
338 {
339    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
340    nir_def *x_is_neg = nir_ilt_imm(b, x_hi, 0);
341    return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
342 }
343 
344 static nir_def *
lower_int64_compare(nir_builder * b,nir_op op,nir_def * x,nir_def * y)345 lower_int64_compare(nir_builder *b, nir_op op, nir_def *x, nir_def *y)
346 {
347    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
348    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
349    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
350    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
351 
352    switch (op) {
353    case nir_op_ieq:
354       return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
355    case nir_op_ine:
356       return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
357    case nir_op_ult:
358       return nir_ior(b, nir_ult(b, x_hi, y_hi),
359                      nir_iand(b, nir_ieq(b, x_hi, y_hi),
360                               nir_ult(b, x_lo, y_lo)));
361    case nir_op_ilt:
362       return nir_ior(b, nir_ilt(b, x_hi, y_hi),
363                      nir_iand(b, nir_ieq(b, x_hi, y_hi),
364                               nir_ult(b, x_lo, y_lo)));
365       break;
366    case nir_op_uge:
367       /* Lower as !(x < y) in the hopes of better CSE */
368       return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
369    case nir_op_ige:
370       /* Lower as !(x < y) in the hopes of better CSE */
371       return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
372    default:
373       unreachable("Invalid comparison");
374    }
375 }
376 
377 static nir_def *
lower_umax64(nir_builder * b,nir_def * x,nir_def * y)378 lower_umax64(nir_builder *b, nir_def *x, nir_def *y)
379 {
380    return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
381 }
382 
383 static nir_def *
lower_imax64(nir_builder * b,nir_def * x,nir_def * y)384 lower_imax64(nir_builder *b, nir_def *x, nir_def *y)
385 {
386    return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
387 }
388 
389 static nir_def *
lower_umin64(nir_builder * b,nir_def * x,nir_def * y)390 lower_umin64(nir_builder *b, nir_def *x, nir_def *y)
391 {
392    return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
393 }
394 
395 static nir_def *
lower_imin64(nir_builder * b,nir_def * x,nir_def * y)396 lower_imin64(nir_builder *b, nir_def *x, nir_def *y)
397 {
398    return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
399 }
400 
401 static nir_def *
lower_mul_2x32_64(nir_builder * b,nir_def * x,nir_def * y,bool sign_extend)402 lower_mul_2x32_64(nir_builder *b, nir_def *x, nir_def *y,
403                   bool sign_extend)
404 {
405    nir_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
406                                  : nir_umul_high(b, x, y);
407 
408    return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
409 }
410 
411 static nir_def *
lower_imul64(nir_builder * b,nir_def * x,nir_def * y)412 lower_imul64(nir_builder *b, nir_def *x, nir_def *y)
413 {
414    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
415    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
416    nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
417    nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);
418 
419    nir_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
420    nir_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
421                               nir_iadd(b, nir_imul(b, x_lo, y_hi),
422                                        nir_imul(b, x_hi, y_lo)));
423 
424    return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
425                                  res_hi);
426 }
427 
static nir_def *
lower_mul_high64(nir_builder *b, nir_def *x, nir_def *y,
                 bool sign_extend)
{
   /* High 64 bits of a 64x64 multiply, computed as a 4x4 grid of 32x32->64
    * partial products over the operands extended to 128 bits (sign- or
    * zero-extended according to sign_extend).
    */
   nir_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   /* res[k] accumulates dword k of the 128-bit product. */
   nir_def *res[8] = {
      NULL,
   };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr_imm(b, tmp, 32);
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

   /* Dwords 2 and 3 of the 128-bit product are bits [127:64], i.e. the
    * high 64 bits of the 64x64 multiply.
    */
   return nir_pack_64_2x32_split(b, res[2], res[3]);
}
486 
487 static nir_def *
lower_isign64(nir_builder * b,nir_def * x)488 lower_isign64(nir_builder *b, nir_def *x)
489 {
490    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
491    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
492 
493    nir_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
494    nir_def *res_hi = nir_ishr_imm(b, x_hi, 31);
495    nir_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));
496 
497    return nir_pack_64_2x32_split(b, res_lo, res_hi);
498 }
499 
static void
lower_udiv64_mod64(nir_builder *b, nir_def *n, nir_def *d,
                   nir_def **q, nir_def **r)
{
   /* Bit-at-a-time restoring long division of n by d, producing both the
    * quotient (*q) and the remainder (*r).  The quotient's high dword is
    * computed first — and only when it can be non-zero — then the low
    * dword, one bit per unrolled iteration from MSB to LSB.
    *
    * TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
   nir_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_def *q_lo = nir_imm_zero(b, n->num_components, 32);
   nir_def *q_hi = nir_imm_zero(b, n->num_components, 32);

   /* Saved for the phis that merge the if-block results below. */
   nir_def *n_hi_before_if = n_hi;
   nir_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_def *need_high_div =
      nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_def *d_shift = nir_ishl_imm(b, d_lo, i);
         nir_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_def *new_q_hi = nir_ior_imm(b, q_hi, 1ull << i);
         nir_def *cond = nir_iand(b, need_high_div,
                                  nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                            nir_ile_imm(b, log2_d_lo, 31 - i));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   /* Merge the if-block values with the values from before the if. */
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_def *d_shift = nir_ishl_imm(b, d, i);
      nir_def *new_n = nir_isub(b, n, d_shift);
      nir_def *new_q_lo = nir_ior_imm(b, q_lo, 1ull << i);
      nir_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                         nir_ile_imm(b, log2_denom, 31 - i));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   /* Whatever is left of n after all the subtractions is the remainder. */
   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}
590 
591 static nir_def *
lower_udiv64(nir_builder * b,nir_def * n,nir_def * d)592 lower_udiv64(nir_builder *b, nir_def *n, nir_def *d)
593 {
594    nir_def *q, *r;
595    lower_udiv64_mod64(b, n, d, &q, &r);
596    return q;
597 }
598 
599 static nir_def *
lower_idiv64(nir_builder * b,nir_def * n,nir_def * d)600 lower_idiv64(nir_builder *b, nir_def *n, nir_def *d)
601 {
602    nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
603    nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
604 
605    nir_def *negate = nir_ine(b, nir_ilt_imm(b, n_hi, 0),
606                              nir_ilt_imm(b, d_hi, 0));
607    nir_def *q, *r;
608    lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
609    return nir_bcsel(b, negate, nir_ineg(b, q), q);
610 }
611 
612 static nir_def *
lower_umod64(nir_builder * b,nir_def * n,nir_def * d)613 lower_umod64(nir_builder *b, nir_def *n, nir_def *d)
614 {
615    nir_def *q, *r;
616    lower_udiv64_mod64(b, n, d, &q, &r);
617    return r;
618 }
619 
static nir_def *
lower_imod64(nir_builder *b, nir_def *n, nir_def *d)
{
   /* Modulo whose non-zero result takes the sign of the denominator,
    * computed from the unsigned remainder of the magnitudes and fixed up
    * afterwards.
    */
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
   nir_def *d_is_neg = nir_ilt_imm(b, d_hi, 0);

   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   /* C-style remainder: |n| % |d| with the sign of the numerator. */
   nir_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   /* Zero remainder stays zero; when the operand signs differ, a non-zero
    * remainder is shifted into the denominator's interval by adding d.
    */
   return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
                    nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                              nir_iadd(b, rem, d)));
}
637 
638 static nir_def *
lower_irem64(nir_builder * b,nir_def * n,nir_def * d)639 lower_irem64(nir_builder *b, nir_def *n, nir_def *d)
640 {
641    nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
642    nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
643 
644    nir_def *q, *r;
645    lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
646    return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
647 }
648 
649 static nir_def *
lower_extract(nir_builder * b,nir_op op,nir_def * x,nir_def * c)650 lower_extract(nir_builder *b, nir_op op, nir_def *x, nir_def *c)
651 {
652    assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
653           op == nir_op_extract_u16 || op == nir_op_extract_i16);
654 
655    const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
656    const int chunk_bits =
657       (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
658    const int num_chunks_in_32 = 32 / chunk_bits;
659 
660    nir_def *extract32;
661    if (chunk < num_chunks_in_32) {
662       extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
663                                 nir_imm_int(b, chunk),
664                                 NULL, NULL);
665    } else {
666       extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
667                                 nir_imm_int(b, chunk - num_chunks_in_32),
668                                 NULL, NULL);
669    }
670 
671    if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
672       return lower_i2i64(b, extract32);
673    else
674       return lower_u2u64(b, extract32);
675 }
676 
static nir_def *
lower_ufind_msb64(nir_builder *b, nir_def *x)
{
   /* 64-bit ufind_msb: take the high dword's answer (offset by 32) when it
    * has any bits set, otherwise the low dword's.  A zero input yields -1,
    * matching 32-bit ufind_msb.
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_count = nir_ufind_msb(b, x_lo);
   nir_def *hi_count = nir_ufind_msb(b, x_hi);

   if (b->shader->options->lower_uadd_sat) {
      /* uadd_sat itself gets lowered on this backend, so select explicitly
       * between the offset high result and the low result instead.
       */
      nir_def *valid_hi_bits = nir_ine_imm(b, x_hi, 0);
      nir_def *hi_res = nir_iadd_imm(b, hi_count, 32);
      return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
   } else {
      /* If hi_count was -1, it will still be -1 after this uadd_sat. As a
       * result, hi_count is either -1 or the correct return value for 64-bit
       * ufind_msb.
       */
      nir_def *hi_res = nir_uadd_sat(b, nir_imm_intN_t(b, 32, 32), hi_count);

      /* hi_res is either -1 or a value in the range [63, 32]. lo_count is
       * either -1 or a value in the range [31, 0]. The imax will pick
       * lo_count only when hi_res is -1. In those cases, lo_count is
       * guaranteed to be the correct answer.
       */
      return nir_imax(b, hi_res, lo_count);
   }
}
705 
706 static nir_def *
lower_find_lsb64(nir_builder * b,nir_def * x)707 lower_find_lsb64(nir_builder *b, nir_def *x)
708 {
709    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
710    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
711    nir_def *lo_lsb = nir_find_lsb(b, x_lo);
712    nir_def *hi_lsb = nir_find_lsb(b, x_hi);
713 
714    /* Use umin so that -1 (no bits found) becomes larger (0xFFFFFFFF)
715     * than any actual bit position, so we return a found bit instead.
716     * This is similar to the ufind_msb lowering. If you need this lowering
717     * without uadd_sat, add code like in lower_ufind_msb64.
718     */
719    assert(!b->shader->options->lower_uadd_sat);
720    return nir_umin(b, lo_lsb, nir_uadd_sat(b, hi_lsb, nir_imm_int(b, 32)));
721 }
722 
static nir_def *
lower_2f(nir_builder *b, nir_def *x, unsigned dest_bit_size,
         bool src_is_signed)
{
   /* Convert a 64-bit integer to a 16-, 32- or 64-bit float with
    * round-to-nearest-even (unless the execution mode requests RTZ).  The
    * 64-bit integer ops used along the way are themselves conditionally
    * lowered through the COND_LOWER_* macros.
    */
   nir_def *x_sign = NULL;

   if (src_is_signed) {
      /* Remember the sign as +/-1.0 and continue with the magnitude. */
      x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
                         nir_imm_floatN_t(b, -1, dest_bit_size),
                         nir_imm_floatN_t(b, 1, dest_bit_size));
      x = COND_LOWER_OP(b, iabs, x);
   }

   /* Position of the most significant set bit == the unbiased exponent. */
   nir_def *exp = COND_LOWER_OP(b, ufind_msb, x);
   unsigned significand_bits;

   switch (dest_bit_size) {
   case 64:
      significand_bits = 52;
      break;
   case 32:
      significand_bits = 23;
      break;
   case 16:
      significand_bits = 10;
      break;
   default:
      unreachable("Invalid dest_bit_size");
   }

   /* Number of low bits that do not fit in the significand and must be
    * discarded (zero when the value already fits).
    */
   nir_def *discard =
      nir_imax(b, nir_iadd_imm(b, exp, -significand_bits),
               nir_imm_int(b, 0));
   nir_def *significand = COND_LOWER_OP(b, ushr, x, discard);
   if (significand_bits < 32)
      significand = COND_LOWER_CAST(b, u2u32, significand);

   /* Round-to-nearest-even implementation:
    * - if the non-representable part of the significand is higher than half
    *   the minimum representable significand, we round-up
    * - if the non-representable part of the significand is equal to half the
    *   minimum representable significand and the representable part of the
    *   significand is odd, we round-up
    * - in any other case, we round-down
    */
   nir_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
   nir_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
   nir_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
   nir_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
   nir_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
                               nir_ine_imm(b, discard, 0));
   nir_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
                                    COND_LOWER_OP(b, iand, x, lsb_mask));
   nir_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
                               nir_iand(b, halfway, is_odd));
   if (!nir_is_rounding_mode_rtz(b->shader->info.float_controls_execution_mode,
                                 dest_bit_size)) {
      if (significand_bits >= 32)
         significand = COND_LOWER_OP(b, iadd, significand,
                                     COND_LOWER_CAST(b, b2i64, round_up));
      else
         significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
   }

   nir_def *res;

   if (dest_bit_size == 64) {
      /* Compute the left shift required to normalize the original
       * unrounded input manually.
       */
      nir_def *shift =
         nir_imax(b, nir_isub_imm(b, significand_bits, exp),
                  nir_imm_int(b, 0));
      significand = COND_LOWER_OP(b, ishl, significand, shift);

      /* Check whether normalization led to overflow of the available
       * significand bits, which can only happen if round_up was true
       * above, in which case we need to add carry to the exponent and
       * discard an extra bit from the significand.  Note that we
       * don't need to repeat the round-up logic again, since the LSB
       * of the significand is guaranteed to be zero if there was
       * overflow.
       */
      nir_def *carry = nir_b2i32(
         b, nir_uge_imm(b, nir_unpack_64_2x32_split_y(b, significand),
                        (uint64_t)(1 << (significand_bits - 31))));
      significand = COND_LOWER_OP(b, ishr, significand, carry);
      exp = nir_iadd(b, exp, carry);

      /* Compute the biased exponent, taking care to handle a zero
       * input correctly, which would have caused exp to be negative.
       */
      nir_def *biased_exp = nir_bcsel(b, nir_ilt_imm(b, exp, 0),
                                      nir_imm_int(b, 0),
                                      nir_iadd_imm(b, exp, 1023));

      /* Pack the significand and exponent manually. */
      nir_def *lo = nir_unpack_64_2x32_split_x(b, significand);
      nir_def *hi = nir_bitfield_insert(
         b, nir_unpack_64_2x32_split_y(b, significand),
         biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));

      res = nir_pack_64_2x32_split(b, lo, hi);

   } else if (dest_bit_size == 32) {
      /* For the narrower formats, scale the rounded significand back up by
       * 2^discard in floating point.
       */
      res = nir_fmul(b, nir_u2f32(b, significand),
                     nir_fexp2(b, nir_u2f32(b, discard)));
   } else {
      res = nir_fmul(b, nir_u2f16(b, significand),
                     nir_fexp2(b, nir_u2f16(b, discard)));
   }

   if (src_is_signed)
      res = nir_fmul(b, res, x_sign);

   return res;
}
840 
841 static nir_def *
lower_f2(nir_builder * b,nir_def * x,bool dst_is_signed)842 lower_f2(nir_builder *b, nir_def *x, bool dst_is_signed)
843 {
844    assert(x->bit_size == 16 || x->bit_size == 32 || x->bit_size == 64);
845    nir_def *x_sign = NULL;
846 
847    if (dst_is_signed)
848       x_sign = nir_fsign(b, x);
849 
850    x = nir_ftrunc(b, x);
851 
852    if (dst_is_signed)
853       x = nir_fabs(b, x);
854 
855    nir_def *res;
856    if (x->bit_size < 32) {
857       res = nir_pack_64_2x32_split(b, nir_f2u32(b, x), nir_imm_int(b, 0));
858    } else {
859       nir_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
860       nir_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
861       nir_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
862       res = nir_pack_64_2x32_split(b, res_lo, res_hi);
863    }
864 
865    if (dst_is_signed)
866       res = nir_bcsel(b, nir_flt_imm(b, x_sign, 0),
867                       nir_ineg(b, res), res);
868 
869    return res;
870 }
871 
872 static nir_def *
lower_bit_count64(nir_builder * b,nir_def * x)873 lower_bit_count64(nir_builder *b, nir_def *x)
874 {
875    nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
876    nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
877    nir_def *lo_count = nir_bit_count(b, x_lo);
878    nir_def *hi_count = nir_bit_count(b, x_hi);
879    return nir_iadd(b, lo_count, hi_count);
880 }
881 
/* Map an ALU opcode to the nir_lower_int64_options bit that requests its
 * 64-bit lowering.  Returns 0 for opcodes this pass cannot lower.
 */
nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   /* Multiplication variants */
   case nir_op_imul:
   case nir_op_amul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   /* Division and remainder share one lowering option */
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   /* Integer resize and int<->float conversions */
   case nir_op_b2i64:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
      return nir_lower_conv64;
   case nir_op_bcsel:
      return nir_lower_bcsel64;
   /* Comparisons */
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   /* Bitwise logic */
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   /* Shifts */
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return nir_lower_extract64;
   case nir_op_ufind_msb:
      return nir_lower_ufind_msb64;
   case nir_op_find_lsb:
      return nir_lower_find_lsb64;
   case nir_op_bit_count:
      return nir_lower_bit_count64;
   default:
      return 0;
   }
}
966 
/* Emit the lowered (64-bit-free or reduced) instruction sequence for a
 * single ALU instruction and return the replacement SSA def.  Callers must
 * have already checked should_lower_int64_alu_instr(); an unhandled opcode
 * here is a programming error.
 */
static nir_def *
lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
{
   /* Materialize each source as a plain SSA def (resolving swizzles etc.).
    * Four entries is enough: no lowered opcode has more than three inputs.
    */
   nir_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
   case nir_op_amul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return lower_extract(b, alu->op, src[0], src[1]);
   case nir_op_ufind_msb:
      return lower_ufind_msb64(b, src[0]);
   case nir_op_find_lsb:
      return lower_find_lsb64(b, src[0]);
   case nir_op_bit_count:
      return lower_bit_count64(b, src[0]);
   /* int64 -> float: the destination bit size picks the float width. */
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
      return lower_2f(b, src[0], alu->def.bit_size, true);
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
      return lower_2f(b, src[0], alu->def.bit_size, false);
   /* float -> int64: signedness comes from the opcode. */
   case nir_op_f2i64:
   case nir_op_f2u64:
      return lower_f2(b, src[0], alu->op == nir_op_f2i64);
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}
1081 
1082 static bool
should_lower_int64_alu_instr(const nir_alu_instr * alu,const nir_shader_compiler_options * options)1083 should_lower_int64_alu_instr(const nir_alu_instr *alu,
1084                              const nir_shader_compiler_options *options)
1085 {
1086    switch (alu->op) {
1087    case nir_op_i2i8:
1088    case nir_op_i2i16:
1089    case nir_op_i2i32:
1090    case nir_op_u2u8:
1091    case nir_op_u2u16:
1092    case nir_op_u2u32:
1093       if (alu->src[0].src.ssa->bit_size != 64)
1094          return false;
1095       break;
1096    case nir_op_bcsel:
1097       assert(alu->src[1].src.ssa->bit_size ==
1098              alu->src[2].src.ssa->bit_size);
1099       if (alu->src[1].src.ssa->bit_size != 64)
1100          return false;
1101       break;
1102    case nir_op_ieq:
1103    case nir_op_ine:
1104    case nir_op_ult:
1105    case nir_op_ilt:
1106    case nir_op_uge:
1107    case nir_op_ige:
1108       assert(alu->src[0].src.ssa->bit_size ==
1109              alu->src[1].src.ssa->bit_size);
1110       if (alu->src[0].src.ssa->bit_size != 64)
1111          return false;
1112       break;
1113    case nir_op_ufind_msb:
1114    case nir_op_find_lsb:
1115    case nir_op_bit_count:
1116       if (alu->src[0].src.ssa->bit_size != 64)
1117          return false;
1118       break;
1119    case nir_op_amul:
1120       if (options->has_imul24)
1121          return false;
1122       if (alu->def.bit_size != 64)
1123          return false;
1124       break;
1125    case nir_op_i2f64:
1126    case nir_op_u2f64:
1127    case nir_op_i2f32:
1128    case nir_op_u2f32:
1129    case nir_op_i2f16:
1130    case nir_op_u2f16:
1131       if (alu->src[0].src.ssa->bit_size != 64)
1132          return false;
1133       break;
1134    case nir_op_f2u64:
1135    case nir_op_f2i64:
1136       FALLTHROUGH;
1137    default:
1138       if (alu->def.bit_size != 64)
1139          return false;
1140       break;
1141    }
1142 
1143    unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
1144    return (options->lower_int64_options & mask) != 0;
1145 }
1146 
1147 static nir_def *
split_64bit_subgroup_op(nir_builder * b,const nir_intrinsic_instr * intrin)1148 split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
1149 {
1150    const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];
1151 
1152    /* This works on subgroup ops with a single 64-bit source which can be
1153     * trivially lowered by doing the exact same op on both halves.
1154     */
1155    assert(nir_src_bit_size(intrin->src[0]) == 64);
1156    nir_def *split_src0[2] = {
1157       nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
1158       nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
1159    };
1160 
1161    assert(info->has_dest && intrin->def.bit_size == 64);
1162 
1163    nir_def *res[2];
1164    for (unsigned i = 0; i < 2; i++) {
1165       nir_intrinsic_instr *split =
1166          nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
1167       split->num_components = intrin->num_components;
1168       split->src[0] = nir_src_for_ssa(split_src0[i]);
1169 
1170       /* Other sources must be less than 64 bits and get copied directly */
1171       for (unsigned j = 1; j < info->num_srcs; j++) {
1172          assert(nir_src_bit_size(intrin->src[j]) < 64);
1173          split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
1174       }
1175 
1176       /* Copy const indices, if any */
1177       memcpy(split->const_index, intrin->const_index,
1178              sizeof(intrin->const_index));
1179 
1180       nir_def_init(&split->instr, &split->def,
1181                    intrin->def.num_components, 32);
1182       nir_builder_instr_insert(b, &split->instr);
1183 
1184       res[i] = &split->def;
1185    }
1186 
1187    return nir_pack_64_2x32_split(b, res[0], res[1]);
1188 }
1189 
1190 static nir_def *
build_vote_ieq(nir_builder * b,nir_def * x)1191 build_vote_ieq(nir_builder *b, nir_def *x)
1192 {
1193    nir_intrinsic_instr *vote =
1194       nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
1195    vote->src[0] = nir_src_for_ssa(x);
1196    vote->num_components = x->num_components;
1197    nir_def_init(&vote->instr, &vote->def, 1, 1);
1198    nir_builder_instr_insert(b, &vote->instr);
1199    return &vote->def;
1200 }
1201 
1202 static nir_def *
lower_vote_ieq(nir_builder * b,nir_def * x)1203 lower_vote_ieq(nir_builder *b, nir_def *x)
1204 {
1205    return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
1206                    build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
1207 }
1208 
1209 static nir_def *
build_scan_intrinsic(nir_builder * b,nir_intrinsic_op scan_op,nir_op reduction_op,unsigned cluster_size,nir_def * val)1210 build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
1211                      nir_op reduction_op, unsigned cluster_size,
1212                      nir_def *val)
1213 {
1214    nir_intrinsic_instr *scan =
1215       nir_intrinsic_instr_create(b->shader, scan_op);
1216    scan->num_components = val->num_components;
1217    scan->src[0] = nir_src_for_ssa(val);
1218    nir_intrinsic_set_reduction_op(scan, reduction_op);
1219    if (scan_op == nir_intrinsic_reduce)
1220       nir_intrinsic_set_cluster_size(scan, cluster_size);
1221    nir_def_init(&scan->instr, &scan->def, val->num_components,
1222                 val->bit_size);
1223    nir_builder_instr_insert(b, &scan->instr);
1224    return &scan->def;
1225 }
1226 
1227 static nir_def *
lower_scan_iadd64(nir_builder * b,const nir_intrinsic_instr * intrin)1228 lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
1229 {
1230    unsigned cluster_size =
1231       intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;
1232 
1233    /* Split it into three chunks of no more than 24 bits each.  With 8 bits
1234     * of headroom, we're guaranteed that there will never be overflow in the
1235     * individual subgroup operations.  (Assuming, of course, a subgroup size
1236     * no larger than 256 which seems reasonable.)  We can then scan on each of
1237     * the chunks and add them back together at the end.
1238     */
1239    nir_def *x = intrin->src[0].ssa;
1240    nir_def *x_low =
1241       nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
1242    nir_def *x_mid =
1243       nir_u2u32(b, nir_iand_imm(b, nir_ushr_imm(b, x, 24),
1244                                 0xffffff));
1245    nir_def *x_hi =
1246       nir_u2u32(b, nir_ushr_imm(b, x, 48));
1247 
1248    nir_def *scan_low =
1249       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1250                            cluster_size, x_low);
1251    nir_def *scan_mid =
1252       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1253                            cluster_size, x_mid);
1254    nir_def *scan_hi =
1255       build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
1256                            cluster_size, x_hi);
1257 
1258    scan_low = nir_u2u64(b, scan_low);
1259    scan_mid = nir_ishl_imm(b, nir_u2u64(b, scan_mid), 24);
1260    scan_hi = nir_ishl_imm(b, nir_u2u64(b, scan_hi), 48);
1261 
1262    return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
1263 }
1264 
1265 static bool
should_lower_int64_intrinsic(const nir_intrinsic_instr * intrin,const nir_shader_compiler_options * options)1266 should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
1267                              const nir_shader_compiler_options *options)
1268 {
1269    switch (intrin->intrinsic) {
1270    case nir_intrinsic_read_invocation:
1271    case nir_intrinsic_read_first_invocation:
1272    case nir_intrinsic_shuffle:
1273    case nir_intrinsic_shuffle_xor:
1274    case nir_intrinsic_shuffle_up:
1275    case nir_intrinsic_shuffle_down:
1276    case nir_intrinsic_quad_broadcast:
1277    case nir_intrinsic_quad_swap_horizontal:
1278    case nir_intrinsic_quad_swap_vertical:
1279    case nir_intrinsic_quad_swap_diagonal:
1280       return intrin->def.bit_size == 64 &&
1281              (options->lower_int64_options & nir_lower_subgroup_shuffle64);
1282 
1283    case nir_intrinsic_vote_ieq:
1284       return intrin->src[0].ssa->bit_size == 64 &&
1285              (options->lower_int64_options & nir_lower_vote_ieq64);
1286 
1287    case nir_intrinsic_reduce:
1288    case nir_intrinsic_inclusive_scan:
1289    case nir_intrinsic_exclusive_scan:
1290       if (intrin->def.bit_size != 64)
1291          return false;
1292 
1293       switch (nir_intrinsic_reduction_op(intrin)) {
1294       case nir_op_iadd:
1295          return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
1296       case nir_op_iand:
1297       case nir_op_ior:
1298       case nir_op_ixor:
1299          return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
1300       default:
1301          return false;
1302       }
1303       break;
1304 
1305    default:
1306       return false;
1307    }
1308 }
1309 
1310 static nir_def *
lower_int64_intrinsic(nir_builder * b,nir_intrinsic_instr * intrin)1311 lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
1312 {
1313    switch (intrin->intrinsic) {
1314    case nir_intrinsic_read_invocation:
1315    case nir_intrinsic_read_first_invocation:
1316    case nir_intrinsic_shuffle:
1317    case nir_intrinsic_shuffle_xor:
1318    case nir_intrinsic_shuffle_up:
1319    case nir_intrinsic_shuffle_down:
1320    case nir_intrinsic_quad_broadcast:
1321    case nir_intrinsic_quad_swap_horizontal:
1322    case nir_intrinsic_quad_swap_vertical:
1323    case nir_intrinsic_quad_swap_diagonal:
1324       return split_64bit_subgroup_op(b, intrin);
1325 
1326    case nir_intrinsic_vote_ieq:
1327       return lower_vote_ieq(b, intrin->src[0].ssa);
1328 
1329    case nir_intrinsic_reduce:
1330    case nir_intrinsic_inclusive_scan:
1331    case nir_intrinsic_exclusive_scan:
1332       switch (nir_intrinsic_reduction_op(intrin)) {
1333       case nir_op_iadd:
1334          return lower_scan_iadd64(b, intrin);
1335       case nir_op_iand:
1336       case nir_op_ior:
1337       case nir_op_ixor:
1338          return split_64bit_subgroup_op(b, intrin);
1339       default:
1340          unreachable("Unsupported subgroup scan/reduce op");
1341       }
1342       break;
1343 
1344    default:
1345       unreachable("Unsupported intrinsic");
1346    }
1347    return NULL;
1348 }
1349 
1350 static bool
should_lower_int64_instr(const nir_instr * instr,const void * _options)1351 should_lower_int64_instr(const nir_instr *instr, const void *_options)
1352 {
1353    switch (instr->type) {
1354    case nir_instr_type_alu:
1355       return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
1356    case nir_instr_type_intrinsic:
1357       return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
1358                                           _options);
1359    default:
1360       return false;
1361    }
1362 }
1363 
1364 static nir_def *
lower_int64_instr(nir_builder * b,nir_instr * instr,void * _options)1365 lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
1366 {
1367    switch (instr->type) {
1368    case nir_instr_type_alu:
1369       return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
1370    case nir_instr_type_intrinsic:
1371       return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
1372    default:
1373       return NULL;
1374    }
1375 }
1376 
1377 bool
nir_lower_int64(nir_shader * shader)1378 nir_lower_int64(nir_shader *shader)
1379 {
1380    return nir_shader_lower_instructions(shader, should_lower_int64_instr,
1381                                         lower_int64_instr,
1382                                         (void *)shader->options);
1383 }
1384 
1385 static bool
should_lower_int64_float_conv(const nir_instr * instr,const void * _options)1386 should_lower_int64_float_conv(const nir_instr *instr, const void *_options)
1387 {
1388    if (instr->type != nir_instr_type_alu)
1389       return false;
1390 
1391    nir_alu_instr *alu = nir_instr_as_alu(instr);
1392 
1393    switch (alu->op) {
1394    case nir_op_i2f64:
1395    case nir_op_i2f32:
1396    case nir_op_i2f16:
1397    case nir_op_u2f64:
1398    case nir_op_u2f32:
1399    case nir_op_u2f16:
1400    case nir_op_f2i64:
1401    case nir_op_f2u64:
1402       return should_lower_int64_alu_instr(alu, _options);
1403    default:
1404       return false;
1405    }
1406 }
1407 
1408 /**
1409  * Like nir_lower_int64(), but only lowers conversions to/from float.
1410  *
1411  * These operations in particular may affect double-precision lowering,
1412  * so it can be useful to run them in tandem with nir_lower_doubles().
1413  */
1414 bool
nir_lower_int64_float_conversions(nir_shader * shader)1415 nir_lower_int64_float_conversions(nir_shader *shader)
1416 {
1417    return nir_shader_lower_instructions(shader, should_lower_int64_float_conv,
1418                                         lower_int64_instr,
1419                                         (void *)shader->options);
1420 }
1421