/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

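/* The COND_LOWER_* macros below choose, while building replacement code,
 * between emitting the native 64-bit opcode and recursively lowering it,
 * depending on which opcodes the backend asked to have lowered.  They are
 * used by lowerings such as lower_2f() that are themselves built from other
 * 64-bit operations.
 */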
#define COND_LOWER_OP(b, name, ...)                                   \
        (b->shader->options->lower_int64_options &                    \
         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
        lower_##name##64(b, __VA_ARGS__) : nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CMP(b, name, ...)                                  \
        (b->shader->options->lower_int64_options &                    \
         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
        lower_int64_compare(b, nir_op_##name, __VA_ARGS__) :          \
        nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CAST(b, name, ...)                                 \
        (b->shader->options->lower_int64_options &                    \
         nir_lower_int64_op_to_options_mask(nir_op_##name)) ?         \
        lower_##name(b, __VA_ARGS__) :                                \
        nir_##name(b, __VA_ARGS__)

static nir_ssa_def *
lower_b2i64(nir_builder *b, nir_ssa_def *x)
{
   return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2b(nir_builder *b, nir_ssa_def *x)
{
   return nir_ine(b, nir_ior(b, nir_unpack_64_2x32_split_x(b, x),
                                nir_unpack_64_2x32_split_y(b, x)),
                     nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_i2i8(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_i2i16(nir_builder *b, nir_ssa_def *x)
{
   return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
}


static nir_ssa_def *
lower_i2i32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_i2i64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
}

static nir_ssa_def *
lower_u2u8(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u16(nir_builder *b, nir_ssa_def *x)
{
   return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_ssa_def *
lower_u2u32(nir_builder *b, nir_ssa_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_ssa_def *
lower_u2u64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
}

static nir_ssa_def *
lower_bcsel64(nir_builder *b, nir_ssa_def *cond, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
                                    nir_bcsel(b, cond, x_hi, y_hi));
}

static nir_ssa_def *
lower_inot64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
}

static nir_ssa_def *
lower_iand64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
                                    nir_iand(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ior64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
                                    nir_ior(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ixor64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
                                    nir_ixor(b, x_hi, y_hi));
}

static nir_ssa_def *
lower_ishl64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
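   /* For example, a shift of c = 40 takes the "else" branch above:
    * abs(32 - 40) = 8, so the low dword shifted left by 8 becomes the new
    * high dword and the new low dword is zero.
    */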
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_ssa_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                                nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                                nir_ishl(b, x_lo, reverse_count));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                                 res_if_ge_32, res_if_lt_32));
}

static nir_ssa_def *
lower_ishr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t  hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo);
    *    }
    * }
    */
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                                hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                                nir_ishr(b, x_hi, nir_imm_int(b, 31)));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                                 res_if_ge_32, res_if_lt_32));
}

static nir_ssa_def *
lower_ushr64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   /* Implemented as
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted, hi_shifted_lo | lo_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(0, hi_shifted_lo);
    *    }
    * }
    */

   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_ssa_def *reverse_count = nir_iabs(b, nir_iadd(b, y, nir_imm_int(b, -32)));
   nir_ssa_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_ssa_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_ssa_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_ssa_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                                hi_shifted);
   nir_ssa_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                                nir_imm_int(b, 0));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge(b, y, nir_imm_int(b, 32)),
                                 res_if_ge_32, res_if_lt_32));
}

static nir_ssa_def *
lower_iadd64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_iadd(b, x_lo, y_lo);
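   /* An unsigned 32-bit add overflowed exactly when the result wraps around
    * below either operand, so comparing res_lo against x_lo recovers the
    * carry into the high dword.
    */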
   nir_ssa_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
   nir_ssa_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static nir_ssa_def *
lower_isub64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *res_lo = nir_isub(b, x_lo, y_lo);
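   /* A borrow out of the low dword happens exactly when x_lo < y_lo; encode
    * it as 0 or -1 and fold it into the high-dword subtraction.
    */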
   nir_ssa_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
   nir_ssa_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static nir_ssa_def *
lower_ineg64(nir_builder *b, nir_ssa_def *x)
{
   /* Since isub is the same number of instructions (with better dependencies)
    * as iadd, subtraction is actually more efficient for ineg than the usual
    * 2's complement "flip the bits and add one".
    */
   return lower_isub64(b, nir_imm_int64(b, 0), x);
}

static nir_ssa_def *
lower_iabs64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *x_is_neg = nir_ilt(b, x_hi, nir_imm_int(b, 0));
   return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
}

static nir_ssa_def *
lower_int64_compare(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

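   /* A 64-bit comparison reduces to comparing the high dwords and, when they
    * are equal, breaking the tie with an unsigned comparison of the low
    * dwords.  Signedness only affects the high half.
    */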
   switch (op) {
   case nir_op_ieq:
      return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
   case nir_op_ine:
      return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
   case nir_op_ult:
      return nir_ior(b, nir_ult(b, x_hi, y_hi),
                        nir_iand(b, nir_ieq(b, x_hi, y_hi),
                                    nir_ult(b, x_lo, y_lo)));
   case nir_op_ilt:
      return nir_ior(b, nir_ilt(b, x_hi, y_hi),
                        nir_iand(b, nir_ieq(b, x_hi, y_hi),
                                    nir_ult(b, x_lo, y_lo)));
      break;
   case nir_op_uge:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
   case nir_op_ige:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
   default:
      unreachable("Invalid comparison");
   }
}

static nir_ssa_def *
lower_umax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
}

static nir_ssa_def *
lower_imax64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
}

static nir_ssa_def *
lower_umin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
}

static nir_ssa_def *
lower_imin64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
}

static nir_ssa_def *
lower_mul_2x32_64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                  bool sign_extend)
{
   nir_ssa_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
                                     : nir_umul_high(b, x, y);

   return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
}

static nir_ssa_def *
lower_imul64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_ssa_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_ssa_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
   nir_ssa_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
                         nir_iadd(b, nir_imul(b, x_lo, y_hi),
                                     nir_imul(b, x_hi, y_lo)));

   return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
                                 res_hi);
}

static nir_ssa_def *
lower_mul_high64(nir_builder *b, nir_ssa_def *x, nir_ssa_def *y,
                 bool sign_extend)
{
   nir_ssa_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   nir_ssa_def *res[8] = { NULL, };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_ssa_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[i] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_ssa_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr_imm(b, tmp, 32);
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

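   /* res[] holds the extended product in 32-bit chunks starting from the
    * least significant end; the 64-bit mul_high result is bits [64, 127],
    * i.e. res[2] and res[3].
    */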
   return nir_pack_64_2x32_split(b, res[2], res[3]);
}

static nir_ssa_def *
lower_isign64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

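   /* res_hi replicates the sign bit of x (0 for non-negative, ~0u for
    * negative).  OR-ing a 1 into the low dword for any non-zero x then makes
    * the packed result 0, 1 or -1 as required.
    */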
   nir_ssa_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
   nir_ssa_def *res_hi = nir_ishr_imm(b, x_hi, 31);
   nir_ssa_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static void
lower_udiv64_mod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d,
                   nir_ssa_def **q, nir_ssa_def **r)
{
   /* TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
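   /* Binary long division (shift and conditionally subtract), one quotient
    * bit per iteration: first divide the high dword of the numerator by the
    * low dword of the denominator (only needed when d_hi is zero), then
    * divide the remaining 64-bit value by the full denominator.
    */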
   nir_ssa_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_ssa_def *q_lo = nir_imm_zero(b, n->num_components, 32);
   nir_ssa_def *q_hi = nir_imm_zero(b, n->num_components, 32);

   nir_ssa_def *n_hi_before_if = n_hi;
   nir_ssa_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_ssa_def *need_high_div =
      nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_ssa_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_ssa_def *d_shift = nir_ishl(b, d_lo, nir_imm_int(b, i));
         nir_ssa_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_ssa_def *new_q_hi = nir_ior(b, q_hi, nir_imm_int(b, 1u << i));
         nir_ssa_def *cond = nir_iand(b, need_high_div,
                                         nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                               nir_ige(b, nir_imm_int(b, 31 - i), log2_d_lo));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_ssa_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_ssa_def *d_shift = nir_ishl(b, d, nir_imm_int(b, i));
      nir_ssa_def *new_n = nir_isub(b, n, d_shift);
      nir_ssa_def *new_q_lo = nir_ior(b, q_lo, nir_imm_int(b, 1u << i));
      nir_ssa_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                            nir_ige(b, nir_imm_int(b, 31 - i), log2_denom));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}

static nir_ssa_def *
lower_udiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return q;
}

static nir_ssa_def *
lower_idiv64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_ssa_def *negate = nir_ine(b, nir_ilt(b, n_hi, nir_imm_int(b, 0)),
                                    nir_ilt(b, d_hi, nir_imm_int(b, 0)));
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, negate, nir_ineg(b, q), q);
}

static nir_ssa_def *
lower_umod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return r;
}

static nir_ssa_def *
lower_imod64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));
   nir_ssa_def *d_is_neg = nir_ilt(b, d_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   nir_ssa_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
          nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                       nir_iadd(b, rem, d)));
}

static nir_ssa_def *
lower_irem64(nir_builder *b, nir_ssa_def *n, nir_ssa_def *d)
{
   nir_ssa_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_ssa_def *n_is_neg = nir_ilt(b, n_hi, nir_imm_int(b, 0));

   nir_ssa_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}

static nir_ssa_def *
lower_extract(nir_builder *b, nir_op op, nir_ssa_def *x, nir_ssa_def *c)
{
   assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
          op == nir_op_extract_u16 || op == nir_op_extract_i16);

   const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
   const int chunk_bits =
      (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
   const int num_chunks_in_32 = 32 / chunk_bits;

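   /* Chunks are numbered from the least significant end, so chunks below
    * num_chunks_in_32 come out of the low dword and the rest come out of the
    * high dword with the chunk index rebased.
    */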
   nir_ssa_def *extract32;
   if (chunk < num_chunks_in_32) {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
                                   nir_imm_int(b, chunk),
                                   NULL, NULL);
   } else {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
                                   nir_imm_int(b, chunk - num_chunks_in_32),
                                   NULL, NULL);
   }

   if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
      return lower_i2i64(b, extract32);
   else
      return lower_u2u64(b, extract32);
}

static nir_ssa_def *
lower_ufind_msb64(nir_builder *b, nir_ssa_def *x)
{

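   /* The most significant set bit is in the high dword whenever that dword
    * is non-zero; in that case add 32 to its bit index.  Otherwise fall back
    * to the low dword (ufind_msb already yields -1 for an all-zero input).
    */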
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *lo_count = nir_ufind_msb(b, x_lo);
   nir_ssa_def *hi_count = nir_ufind_msb(b, x_hi);
   nir_ssa_def *valid_hi_bits = nir_ine(b, x_hi, nir_imm_int(b, 0));
   nir_ssa_def *hi_res = nir_iadd(b, nir_imm_intN_t(b, 32, 32), hi_count);
   return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
}

static nir_ssa_def *
lower_2f(nir_builder *b, nir_ssa_def *x, unsigned dest_bit_size,
         bool src_is_signed)
{
   nir_ssa_def *x_sign = NULL;

   if (src_is_signed) {
      x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
                         nir_imm_floatN_t(b, -1, dest_bit_size),
                         nir_imm_floatN_t(b, 1, dest_bit_size));
      x = COND_LOWER_OP(b, iabs, x);
   }

   nir_ssa_def *exp = COND_LOWER_OP(b, ufind_msb, x);
   unsigned significand_bits;

   switch (dest_bit_size) {
   case 32:
      significand_bits = 23;
      break;
   case 16:
      significand_bits = 10;
      break;
   default:
      unreachable("Invalid dest_bit_size");
   }

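   /* "discard" is how many low-order bits of x cannot be represented in the
    * destination's significand; shift them out here and fold them back in
    * via the round-to-nearest-even logic below.
    */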
   nir_ssa_def *discard =
      nir_imax(b, nir_isub(b, exp, nir_imm_int(b, significand_bits)),
                  nir_imm_int(b, 0));
   nir_ssa_def *significand =
      COND_LOWER_CAST(b, u2u32, COND_LOWER_OP(b, ushr, x, discard));

   /* Round-to-nearest-even implementation:
    * - if the non-representable part of the significand is higher than half
    *   the minimum representable significand, we round-up
    * - if the non-representable part of the significand is equal to half the
    *   minimum representable significand and the representable part of the
    *   significand is odd, we round-up
    * - in any other case, we round-down
    */
   nir_ssa_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
   nir_ssa_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
   nir_ssa_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
   nir_ssa_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
   nir_ssa_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
                                   nir_ine(b, discard, nir_imm_int(b, 0)));
   nir_ssa_def *is_odd = nir_i2b(b, nir_iand(b, significand, nir_imm_int(b, 1)));
   nir_ssa_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
                                   nir_iand(b, halfway, is_odd));
   significand = nir_iadd(b, significand, nir_b2i32(b, round_up));

   nir_ssa_def *res;

   if (dest_bit_size == 32)
      res = nir_fmul(b, nir_u2f32(b, significand),
                     nir_fexp2(b, nir_u2f32(b, discard)));
   else
      res = nir_fmul(b, nir_u2f16(b, significand),
                     nir_fexp2(b, nir_u2f16(b, discard)));

   if (src_is_signed)
      res = nir_fmul(b, res, x_sign);

   return res;
}

static nir_ssa_def *
lower_f2(nir_builder *b, nir_ssa_def *x, bool dst_is_signed)
{
   assert(x->bit_size == 16 || x->bit_size == 32);
   nir_ssa_def *x_sign = NULL;

   if (dst_is_signed)
      x_sign = nir_fsign(b, x);
   else
      x = nir_fmin(b, x, nir_imm_floatN_t(b, UINT64_MAX, x->bit_size));

   x = nir_ftrunc(b, x);

   if (dst_is_signed) {
      x = nir_fmin(b, x, nir_imm_floatN_t(b, INT64_MAX, x->bit_size));
      x = nir_fmax(b, x, nir_imm_floatN_t(b, INT64_MIN, x->bit_size));
      x = nir_fabs(b, x);
   }

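   /* Split the (now non-negative) value into dwords by dividing by 2^32: the
    * truncated quotient becomes the high dword and the remainder the low
    * dword.
    */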
   nir_ssa_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
   nir_ssa_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
   nir_ssa_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
   nir_ssa_def *res = nir_pack_64_2x32_split(b, res_lo, res_hi);

   if (dst_is_signed)
      res = nir_bcsel(b, nir_flt(b, x_sign, nir_imm_floatN_t(b, 0, x->bit_size)),
                      nir_ineg(b, res), res);

   return res;
}

static nir_ssa_def *
lower_bit_count64(nir_builder *b, nir_ssa_def *x)
{
   nir_ssa_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_ssa_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_ssa_def *lo_count = nir_bit_count(b, x_lo);
   nir_ssa_def *hi_count = nir_bit_count(b, x_hi);
   return nir_iadd(b, lo_count, hi_count);
}

nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   case nir_op_imul:
   case nir_op_amul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   case nir_op_b2i64:
   case nir_op_i2b1:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
   case nir_op_bcsel:
      return nir_lower_mov64;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return nir_lower_extract64;
   case nir_op_ufind_msb:
      return nir_lower_ufind_msb64;
   case nir_op_bit_count:
      return nir_lower_bit_count64;
   default:
      return 0;
   }
}

static nir_ssa_def *
lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
{
   nir_ssa_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
   case nir_op_amul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2b1:
      return lower_i2b(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return lower_extract(b, alu->op, src[0], src[1]);
   case nir_op_ufind_msb:
      return lower_ufind_msb64(b, src[0]);
   case nir_op_bit_count:
      return lower_bit_count64(b, src[0]);
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
      return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), true);
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
      return lower_2f(b, src[0], nir_dest_bit_size(alu->dest.dest), false);
   case nir_op_f2i64:
   case nir_op_f2u64:
      /* We don't support f64toi64 (yet?). */
      if (src[0]->bit_size > 32)
         return false;

      return lower_f2(b, src[0], alu->op == nir_op_f2i64);
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}

static bool
should_lower_int64_alu_instr(const nir_alu_instr *alu,
                             const nir_shader_compiler_options *options)
{
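   /* For most opcodes the interesting bit size is on the destination, but
    * down-conversions, comparisons and bcsel produce a narrow result from
    * 64-bit sources, so those cases check the relevant source instead.
    */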
   switch (alu->op) {
   case nir_op_i2b1:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
      assert(alu->src[0].src.is_ssa);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_bcsel:
      assert(alu->src[1].src.is_ssa);
      assert(alu->src[2].src.is_ssa);
      assert(alu->src[1].src.ssa->bit_size ==
             alu->src[2].src.ssa->bit_size);
      if (alu->src[1].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      assert(alu->src[0].src.is_ssa);
      assert(alu->src[1].src.is_ssa);
      assert(alu->src[0].src.ssa->bit_size ==
             alu->src[1].src.ssa->bit_size);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ufind_msb:
   case nir_op_bit_count:
      assert(alu->src[0].src.is_ssa);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_amul:
      assert(alu->dest.dest.is_ssa);
      if (options->has_imul24)
         return false;
      if (alu->dest.dest.ssa.bit_size != 64)
         return false;
      break;
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
      assert(alu->src[0].src.is_ssa);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_f2u64:
   case nir_op_f2i64:
      FALLTHROUGH;
   default:
      assert(alu->dest.dest.is_ssa);
      if (alu->dest.dest.ssa.bit_size != 64)
         return false;
      break;
   }

   unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
   return (options->lower_int64_options & mask) != 0;
}

static nir_ssa_def *
split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];

   /* This works on subgroup ops with a single 64-bit source which can be
    * trivially lowered by doing the exact same op on both halves.
    */
   assert(intrin->src[0].is_ssa && intrin->src[0].ssa->bit_size == 64);
   nir_ssa_def *split_src0[2] = {
      nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
      nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
   };

   assert(info->has_dest && intrin->dest.is_ssa &&
          intrin->dest.ssa.bit_size == 64);

   nir_ssa_def *res[2];
   for (unsigned i = 0; i < 2; i++) {
      nir_intrinsic_instr *split =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      split->num_components = intrin->num_components;
      split->src[0] = nir_src_for_ssa(split_src0[i]);

      /* Other sources must be less than 64 bits and get copied directly */
      for (unsigned j = 1; j < info->num_srcs; j++) {
         assert(intrin->src[j].is_ssa && intrin->src[j].ssa->bit_size < 64);
         split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
      }

      /* Copy const indices, if any */
      memcpy(split->const_index, intrin->const_index,
             sizeof(intrin->const_index));

      nir_ssa_dest_init(&split->instr, &split->dest,
                        intrin->dest.ssa.num_components, 32, NULL);
      nir_builder_instr_insert(b, &split->instr);

      res[i] = &split->dest.ssa;
   }

   return nir_pack_64_2x32_split(b, res[0], res[1]);
}

static nir_ssa_def *
build_vote_ieq(nir_builder *b, nir_ssa_def *x)
{
   nir_intrinsic_instr *vote =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
   vote->src[0] = nir_src_for_ssa(x);
   vote->num_components = x->num_components;
   nir_ssa_dest_init(&vote->instr, &vote->dest, 1, 1, NULL);
   nir_builder_instr_insert(b, &vote->instr);
   return &vote->dest.ssa;
}

static nir_ssa_def *
lower_vote_ieq(nir_builder *b, nir_ssa_def *x)
{
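   /* A 64-bit value is equal across the subgroup exactly when both of its
    * 32-bit halves are, so vote on each half and AND the results.
    */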
   return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
                      build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
}

static nir_ssa_def *
build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
                     nir_op reduction_op, unsigned cluster_size,
                     nir_ssa_def *val)
{
   nir_intrinsic_instr *scan =
      nir_intrinsic_instr_create(b->shader, scan_op);
   scan->num_components = val->num_components;
   scan->src[0] = nir_src_for_ssa(val);
   nir_intrinsic_set_reduction_op(scan, reduction_op);
   if (scan_op == nir_intrinsic_reduce)
      nir_intrinsic_set_cluster_size(scan, cluster_size);
   nir_ssa_dest_init(&scan->instr, &scan->dest,
                     val->num_components, val->bit_size, NULL);
   nir_builder_instr_insert(b, &scan->instr);
   return &scan->dest.ssa;
}

static nir_ssa_def *
lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   unsigned cluster_size =
      intrin->intrinsic == nir_intrinsic_reduce ?
      nir_intrinsic_cluster_size(intrin) : 0;

   /* Split it into three chunks of no more than 24 bits each.  With 8 bits
    * of headroom, we're guaranteed that there will never be overflow in the
    * individual subgroup operations.  (Assuming, of course, a subgroup size
    * no larger than 256 which seems reasonable.)  We can then scan on each of
    * the chunks and add them back together at the end.
    */
   assert(intrin->src[0].is_ssa);
   nir_ssa_def *x = intrin->src[0].ssa;
   nir_ssa_def *x_low =
      nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
   nir_ssa_def *x_mid =
      nir_u2u32(b, nir_iand_imm(b, nir_ushr(b, x, nir_imm_int(b, 24)),
                                   0xffffff));
   nir_ssa_def *x_hi =
      nir_u2u32(b, nir_ushr(b, x, nir_imm_int(b, 48)));

   nir_ssa_def *scan_low =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                              cluster_size, x_low);
   nir_ssa_def *scan_mid =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                              cluster_size, x_mid);
   nir_ssa_def *scan_hi =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                              cluster_size, x_hi);

   scan_low = nir_u2u64(b, scan_low);
   scan_mid = nir_ishl(b, nir_u2u64(b, scan_mid), nir_imm_int(b, 24));
   scan_hi = nir_ishl(b, nir_u2u64(b, scan_hi), nir_imm_int(b, 48));

   return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
}

static bool
should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
                             const nir_shader_compiler_options *options)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      assert(intrin->dest.is_ssa);
      return intrin->dest.ssa.bit_size == 64 &&
             (options->lower_int64_options & nir_lower_subgroup_shuffle64);

   case nir_intrinsic_vote_ieq:
      assert(intrin->src[0].is_ssa);
      return intrin->src[0].ssa->bit_size == 64 &&
             (options->lower_int64_options & nir_lower_vote_ieq64);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      assert(intrin->dest.is_ssa);
      if (intrin->dest.ssa.bit_size != 64)
         return false;

      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
      default:
         return false;
      }
      break;

   default:
      return false;
   }
}

static nir_ssa_def *
lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      return split_64bit_subgroup_op(b, intrin);

   case nir_intrinsic_vote_ieq:
      assert(intrin->src[0].is_ssa);
      return lower_vote_ieq(b, intrin->src[0].ssa);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return lower_scan_iadd64(b, intrin);
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return split_64bit_subgroup_op(b, intrin);
      default:
         unreachable("Unsupported subgroup scan/reduce op");
      }
      break;

   default:
      unreachable("Unsupported intrinsic");
   }
}

static bool
should_lower_int64_instr(const nir_instr *instr, const void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
   case nir_instr_type_intrinsic:
      return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
                                          _options);
   default:
      return false;
   }
}

static nir_ssa_def *
lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
   case nir_instr_type_intrinsic:
      return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
   default:
      return NULL;
   }
}

bool
nir_lower_int64(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader, should_lower_int64_instr,
                                        lower_int64_instr,
                                        (void *)shader->options);
}