/*
 * Copyright © 2016 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"

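/* Helpers for lowering code that itself needs to emit 64-bit operations:
 * when the nested opcode is also flagged in lower_int64_options, its lowered
 * form is built directly instead of the native NIR opcode (presumably so
 * that newly emitted instructions never require a second lowering pass).
 */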
#define COND_LOWER_OP(b, name, ...) \
   (b->shader->options->lower_int64_options & \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name##64(b, __VA_ARGS__) \
      : nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CMP(b, name, ...) \
   (b->shader->options->lower_int64_options & \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_int64_compare(b, nir_op_##name, __VA_ARGS__) \
      : nir_##name(b, __VA_ARGS__)

#define COND_LOWER_CAST(b, name, ...) \
   (b->shader->options->lower_int64_options & \
    nir_lower_int64_op_to_options_mask(nir_op_##name)) \
      ? lower_##name(b, __VA_ARGS__) \
      : nir_##name(b, __VA_ARGS__)

static nir_def *
lower_b2i64(nir_builder *b, nir_def *x)
{
   return nir_pack_64_2x32_split(b, nir_b2i32(b, x), nir_imm_int(b, 0));
}

static nir_def *
lower_i2i8(nir_builder *b, nir_def *x)
{
   return nir_i2i8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_i2i16(nir_builder *b, nir_def *x)
{
   return nir_i2i16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_i2i32(nir_builder *b, nir_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_def *
lower_i2i64(nir_builder *b, nir_def *x)
{
   nir_def *x32 = x->bit_size == 32 ? x : nir_i2i32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_ishr_imm(b, x32, 31));
}

static nir_def *
lower_u2u8(nir_builder *b, nir_def *x)
{
   return nir_u2u8(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_u2u16(nir_builder *b, nir_def *x)
{
   return nir_u2u16(b, nir_unpack_64_2x32_split_x(b, x));
}

static nir_def *
lower_u2u32(nir_builder *b, nir_def *x)
{
   return nir_unpack_64_2x32_split_x(b, x);
}

static nir_def *
lower_u2u64(nir_builder *b, nir_def *x)
{
   nir_def *x32 = x->bit_size == 32 ? x : nir_u2u32(b, x);
   return nir_pack_64_2x32_split(b, x32, nir_imm_int(b, 0));
}

static nir_def *
lower_bcsel64(nir_builder *b, nir_def *cond, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_bcsel(b, cond, x_lo, y_lo),
                                 nir_bcsel(b, cond, x_hi, y_hi));
}

static nir_def *
lower_inot64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   return nir_pack_64_2x32_split(b, nir_inot(b, x_lo), nir_inot(b, x_hi));
}

static nir_def *
lower_iand64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_iand(b, x_lo, y_lo),
                                 nir_iand(b, x_hi, y_hi));
}

static nir_def *
lower_ior64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ior(b, x_lo, y_lo),
                                 nir_ior(b, x_hi, y_hi));
}

static nir_def *
lower_ixor64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   return nir_pack_64_2x32_split(b, nir_ixor(b, x_lo, y_lo),
                                 nir_ixor(b, x_hi, y_hi));
}

static nir_def *
lower_ishl64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t lshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo << c;
    *       uint32_t hi_shifted = hi << c;
    *       uint32_t lo_shifted_hi = lo >> abs(32 - c);
    *       return pack_64(lo_shifted, hi_shifted | lo_shifted_hi);
    *    } else {
    *       uint32_t lo_shifted_hi = lo << abs(32 - c);
    *       return pack_64(0, lo_shifted_hi);
    *    }
    * }
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ishl(b, x_lo, y);
   nir_def *hi_shifted = nir_ishl(b, x_hi, y);
   nir_def *lo_shifted_hi = nir_ushr(b, x_lo, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, lo_shifted,
                             nir_ior(b, hi_shifted, lo_shifted_hi));
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_imm_int(b, 0),
                             nir_ishl(b, x_lo, reverse_count));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

static nir_def *
lower_ishr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t arshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x);
    *    int32_t hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted = hi >> 31;
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, hi_shifted);
    *    }
    * }
    *
    * where pack_64(lo, hi) takes the low word first, matching the lshift
    * pseudocode above and nir_pack_64_2x32_split().
    */
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ishr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ishr(b, x_hi, reverse_count),
                             nir_ishr_imm(b, x_hi, 31));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

static nir_def *
lower_ushr64(nir_builder *b, nir_def *x, nir_def *y)
{
   /* Implemented as
    *
    * uint64_t rshift(uint64_t x, int c)
    * {
    *    c %= 64;
    *
    *    if (c == 0) return x;
    *
    *    uint32_t lo = LO(x), hi = HI(x);
    *
    *    if (c < 32) {
    *       uint32_t lo_shifted = lo >> c;
    *       uint32_t hi_shifted = hi >> c;
    *       uint32_t hi_shifted_lo = hi << abs(32 - c);
    *       return pack_64(hi_shifted_lo | lo_shifted, hi_shifted);
    *    } else {
    *       uint32_t hi_shifted_lo = hi >> abs(32 - c);
    *       return pack_64(hi_shifted_lo, 0);
    *    }
    * }
    *
    * with the same pack_64(lo, hi) convention as in lower_ishl64().
    */

   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   y = nir_iand_imm(b, y, 0x3f);

   nir_def *reverse_count = nir_iabs(b, nir_iadd_imm(b, y, -32));
   nir_def *lo_shifted = nir_ushr(b, x_lo, y);
   nir_def *hi_shifted = nir_ushr(b, x_hi, y);
   nir_def *hi_shifted_lo = nir_ishl(b, x_hi, reverse_count);

   nir_def *res_if_lt_32 =
      nir_pack_64_2x32_split(b, nir_ior(b, lo_shifted, hi_shifted_lo),
                             hi_shifted);
   nir_def *res_if_ge_32 =
      nir_pack_64_2x32_split(b, nir_ushr(b, x_hi, reverse_count),
                             nir_imm_int(b, 0));

   return nir_bcsel(b, nir_ieq_imm(b, y, 0), x,
                    nir_bcsel(b, nir_uge_imm(b, y, 32),
                              res_if_ge_32, res_if_lt_32));
}

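/* Implemented on split 32-bit halves as
 *
 * uint64_t iadd(uint64_t x, uint64_t y)
 * {
 *    uint32_t res_lo = LO(x) + LO(y);
 *    uint32_t carry = res_lo < LO(x);
 *    return pack_64(res_lo, HI(x) + HI(y) + carry);
 * }
 *
 * where the unsigned wrap-around of the low half signals the carry.
 */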
static nir_def *
lower_iadd64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *res_lo = nir_iadd(b, x_lo, y_lo);
   nir_def *carry = nir_b2i32(b, nir_ult(b, res_lo, x_lo));
   nir_def *res_hi = nir_iadd(b, carry, nir_iadd(b, x_hi, y_hi));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

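/* Same structure as lower_iadd64, except that an unsigned borrow
 * (x_lo < y_lo) is subtracted from the difference of the high words.
 */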
static nir_def *
lower_isub64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *res_lo = nir_isub(b, x_lo, y_lo);
   nir_def *borrow = nir_ineg(b, nir_b2i32(b, nir_ult(b, x_lo, y_lo)));
   nir_def *res_hi = nir_iadd(b, nir_isub(b, x_hi, y_hi), borrow);

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

static nir_def *
lower_ineg64(nir_builder *b, nir_def *x)
{
   /* Since isub is the same number of instructions (with better dependencies)
    * as iadd, subtraction is actually more efficient for ineg than the usual
    * 2's complement "flip the bits and add one".
    */
   return lower_isub64(b, nir_imm_int64(b, 0), x);
}

static nir_def *
lower_iabs64(nir_builder *b, nir_def *x)
{
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *x_is_neg = nir_ilt_imm(b, x_hi, 0);
   return nir_bcsel(b, x_is_neg, nir_ineg(b, x), x);
}

static nir_def *
lower_int64_compare(nir_builder *b, nir_op op, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   switch (op) {
   case nir_op_ieq:
      return nir_iand(b, nir_ieq(b, x_hi, y_hi), nir_ieq(b, x_lo, y_lo));
   case nir_op_ine:
      return nir_ior(b, nir_ine(b, x_hi, y_hi), nir_ine(b, x_lo, y_lo));
   case nir_op_ult:
      return nir_ior(b, nir_ult(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_ilt:
      return nir_ior(b, nir_ilt(b, x_hi, y_hi),
                     nir_iand(b, nir_ieq(b, x_hi, y_hi),
                              nir_ult(b, x_lo, y_lo)));
   case nir_op_uge:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ult, x, y));
   case nir_op_ige:
      /* Lower as !(x < y) in the hopes of better CSE */
      return nir_inot(b, lower_int64_compare(b, nir_op_ilt, x, y));
   default:
      unreachable("Invalid comparison");
   }
}

static nir_def *
lower_umax64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), y, x);
}

static nir_def *
lower_imax64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), y, x);
}

static nir_def *
lower_umin64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ult, x, y), x, y);
}

static nir_def *
lower_imin64(nir_builder *b, nir_def *x, nir_def *y)
{
   return nir_bcsel(b, lower_int64_compare(b, nir_op_ilt, x, y), x, y);
}

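/* 32x32 -> 64-bit multiply: the low 32 bits of the product come from the
 * ordinary 32-bit imul and the high 32 bits from [iu]mul_high, so no 64-bit
 * arithmetic is needed.
 */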
static nir_def *
lower_mul_2x32_64(nir_builder *b, nir_def *x, nir_def *y,
                  bool sign_extend)
{
   nir_def *res_hi = sign_extend ? nir_imul_high(b, x, y)
                                 : nir_umul_high(b, x, y);

   return nir_pack_64_2x32_split(b, nir_imul(b, x, y), res_hi);
}

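/* Schoolbook multiplication on the 32-bit halves, keeping only the low
 * 64 bits of the product:
 *
 *    LO64(x * y) = x_lo * y_lo + ((x_lo * y_hi + x_hi * y_lo) << 32)
 *
 * The x_hi * y_hi term only affects bits above 64 and is dropped.
 */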
static nir_def *
lower_imul64(nir_builder *b, nir_def *x, nir_def *y)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *y_lo = nir_unpack_64_2x32_split_x(b, y);
   nir_def *y_hi = nir_unpack_64_2x32_split_y(b, y);

   nir_def *mul_lo = nir_umul_2x32_64(b, x_lo, y_lo);
   nir_def *res_hi = nir_iadd(b, nir_unpack_64_2x32_split_y(b, mul_lo),
                              nir_iadd(b, nir_imul(b, x_lo, y_hi),
                                       nir_imul(b, x_hi, y_lo)));

   return nir_pack_64_2x32_split(b, nir_unpack_64_2x32_split_x(b, mul_lo),
                                 res_hi);
}

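/* Computes the high 64 bits of the 128-bit product by long multiplication
 * on 32-bit limbs.  For the signed variant, both operands are sign-extended
 * to four limbs first so that the unsigned limb algorithm yields the correct
 * two's complement result modulo 2^128.
 */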
static nir_def *
lower_mul_high64(nir_builder *b, nir_def *x, nir_def *y,
                 bool sign_extend)
{
   nir_def *x32[4], *y32[4];
   x32[0] = nir_unpack_64_2x32_split_x(b, x);
   x32[1] = nir_unpack_64_2x32_split_y(b, x);
   if (sign_extend) {
      x32[2] = x32[3] = nir_ishr_imm(b, x32[1], 31);
   } else {
      x32[2] = x32[3] = nir_imm_int(b, 0);
   }

   y32[0] = nir_unpack_64_2x32_split_x(b, y);
   y32[1] = nir_unpack_64_2x32_split_y(b, y);
   if (sign_extend) {
      y32[2] = y32[3] = nir_ishr_imm(b, y32[1], 31);
   } else {
      y32[2] = y32[3] = nir_imm_int(b, 0);
   }

   nir_def *res[8] = {
      NULL,
   };

   /* Yes, the following generates a pile of code.  However, we throw res[0]
    * and res[1] away in the end and, if we're in the umul case, four of our
    * eight dword operands will be constant zero and opt_algebraic will clean
    * this up nicely.
    */
   for (unsigned i = 0; i < 4; i++) {
      nir_def *carry = NULL;
      for (unsigned j = 0; j < 4; j++) {
         /* The maximum values of x32[i] and y32[j] are UINT32_MAX so the
          * maximum value of tmp is UINT32_MAX * UINT32_MAX.  The maximum
          * value that will fit in tmp is
          *
          *    UINT64_MAX = UINT32_MAX << 32 + UINT32_MAX
          *               = UINT32_MAX * (UINT32_MAX + 1) + UINT32_MAX
          *               = UINT32_MAX * UINT32_MAX + 2 * UINT32_MAX
          *
          * so we're guaranteed that we can add in two more 32-bit values
          * without overflowing tmp.
          */
         nir_def *tmp = nir_umul_2x32_64(b, x32[i], y32[j]);

         if (res[i + j])
            tmp = nir_iadd(b, tmp, nir_u2u64(b, res[i + j]));
         if (carry)
            tmp = nir_iadd(b, tmp, carry);
         res[i + j] = nir_u2u32(b, tmp);
         carry = nir_ushr_imm(b, tmp, 32);
      }
      res[i + 4] = nir_u2u32(b, carry);
   }

   return nir_pack_64_2x32_split(b, res[2], res[3]);
}

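/* isign returns -1, 0 or 1.  The high result word is the sign bit replicated
 * by an arithmetic shift (0 or ~0), and the low word additionally ORs in a 1
 * when x is non-zero, producing exactly those three 64-bit values.
 */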
static nir_def *
lower_isign64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);

   nir_def *is_non_zero = nir_i2b(b, nir_ior(b, x_lo, x_hi));
   nir_def *res_hi = nir_ishr_imm(b, x_hi, 31);
   nir_def *res_lo = nir_ior(b, res_hi, nir_b2i32(b, is_non_zero));

   return nir_pack_64_2x32_split(b, res_lo, res_hi);
}

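/* Bit-at-a-time long division producing both quotient and remainder.  The
 * first loop computes the high quotient word, which can only be non-zero
 * when the divisor fits in 32 bits; the second loop computes the low
 * quotient word with full 64-bit compare/subtract, and whatever is left of
 * the numerator is the remainder.
 */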
static void
lower_udiv64_mod64(nir_builder *b, nir_def *n, nir_def *d,
                   nir_def **q, nir_def **r)
{
   /* TODO: We should specially handle the case where the denominator is a
    * constant.  In that case, we should be able to reduce it to a multiply by
    * a constant, some shifts, and an add.
    */
   nir_def *n_lo = nir_unpack_64_2x32_split_x(b, n);
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_lo = nir_unpack_64_2x32_split_x(b, d);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_def *q_lo = nir_imm_zero(b, n->num_components, 32);
   nir_def *q_hi = nir_imm_zero(b, n->num_components, 32);

   nir_def *n_hi_before_if = n_hi;
   nir_def *q_hi_before_if = q_hi;

   /* If the upper 32 bits of denom are non-zero, it is impossible for shifts
    * greater than 32 bits to occur.  If the upper 32 bits of the numerator
    * are zero, it is impossible for (denom << [63, 32]) <= numer unless
    * denom == 0.
    */
   nir_def *need_high_div =
      nir_iand(b, nir_ieq_imm(b, d_hi, 0), nir_uge(b, n_hi, d_lo));
   nir_push_if(b, nir_bany(b, need_high_div));
   {
      /* If we only have one component, then the bany above goes away and
       * this is always true within the if statement.
       */
      if (n->num_components == 1)
         need_high_div = nir_imm_true(b);

      nir_def *log2_d_lo = nir_ufind_msb(b, d_lo);

      for (int i = 31; i >= 0; i--) {
         /* if ((d.x << i) <= n.y) {
          *    n.y -= d.x << i;
          *    quot.y |= 1U << i;
          * }
          */
         nir_def *d_shift = nir_ishl_imm(b, d_lo, i);
         nir_def *new_n_hi = nir_isub(b, n_hi, d_shift);
         nir_def *new_q_hi = nir_ior_imm(b, q_hi, 1ull << i);
         nir_def *cond = nir_iand(b, need_high_div,
                                  nir_uge(b, n_hi, d_shift));
         if (i != 0) {
            /* log2_d_lo is always <= 31, so we don't need to bother with it
             * in the last iteration.
             */
            cond = nir_iand(b, cond,
                            nir_ile_imm(b, log2_d_lo, 31 - i));
         }
         n_hi = nir_bcsel(b, cond, new_n_hi, n_hi);
         q_hi = nir_bcsel(b, cond, new_q_hi, q_hi);
      }
   }
   nir_pop_if(b, NULL);
   n_hi = nir_if_phi(b, n_hi, n_hi_before_if);
   q_hi = nir_if_phi(b, q_hi, q_hi_before_if);

   nir_def *log2_denom = nir_ufind_msb(b, d_hi);

   n = nir_pack_64_2x32_split(b, n_lo, n_hi);
   d = nir_pack_64_2x32_split(b, d_lo, d_hi);
   for (int i = 31; i >= 0; i--) {
      /* if ((d64 << i) <= n64) {
       *    n64 -= d64 << i;
       *    quot.x |= 1U << i;
       * }
       */
      nir_def *d_shift = nir_ishl_imm(b, d, i);
      nir_def *new_n = nir_isub(b, n, d_shift);
      nir_def *new_q_lo = nir_ior_imm(b, q_lo, 1ull << i);
      nir_def *cond = nir_uge(b, n, d_shift);
      if (i != 0) {
         /* log2_denom is always <= 31, so we don't need to bother with it
          * in the last iteration.
          */
         cond = nir_iand(b, cond,
                         nir_ile_imm(b, log2_denom, 31 - i));
      }
      n = nir_bcsel(b, cond, new_n, n);
      q_lo = nir_bcsel(b, cond, new_q_lo, q_lo);
   }

   *q = nir_pack_64_2x32_split(b, q_lo, q_hi);
   *r = n;
}

static nir_def *
lower_udiv64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return q;
}

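/* Signed division is done on the absolute values; the quotient is negated
 * afterwards when the operands have opposite signs.
 */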
static nir_def *
lower_idiv64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);

   nir_def *negate = nir_ine(b, nir_ilt_imm(b, n_hi, 0),
                             nir_ilt_imm(b, d_hi, 0));
   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, negate, nir_ineg(b, q), q);
}

static nir_def *
lower_umod64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *q, *r;
   lower_udiv64_mod64(b, n, d, &q, &r);
   return r;
}

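/* imod is a true modulo whose result takes the sign of the divisor: when the
 * remainder is non-zero and the operand signs differ, the divisor is added
 * back in to move the result into the divisor's range.
 */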
static nir_def *
lower_imod64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *d_hi = nir_unpack_64_2x32_split_y(b, d);
   nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);
   nir_def *d_is_neg = nir_ilt_imm(b, d_hi, 0);

   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);

   nir_def *rem = nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);

   return nir_bcsel(b, nir_ieq_imm(b, r, 0), nir_imm_int64(b, 0),
                    nir_bcsel(b, nir_ieq(b, n_is_neg, d_is_neg), rem,
                              nir_iadd(b, rem, d)));
}

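/* irem is the C-style remainder: the magnitude is the unsigned remainder of
 * the absolute values and the sign follows the numerator.
 */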
static nir_def *
lower_irem64(nir_builder *b, nir_def *n, nir_def *d)
{
   nir_def *n_hi = nir_unpack_64_2x32_split_y(b, n);
   nir_def *n_is_neg = nir_ilt_imm(b, n_hi, 0);

   nir_def *q, *r;
   lower_udiv64_mod64(b, nir_iabs(b, n), nir_iabs(b, d), &q, &r);
   return nir_bcsel(b, n_is_neg, nir_ineg(b, r), r);
}

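/* An 8- or 16-bit chunk of a 64-bit value lies entirely within one 32-bit
 * half, so the chunk index selects the half, the 32-bit extract opcode does
 * the work, and the result is re-extended to 64 bits.
 */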
static nir_def *
lower_extract(nir_builder *b, nir_op op, nir_def *x, nir_def *c)
{
   assert(op == nir_op_extract_u8 || op == nir_op_extract_i8 ||
          op == nir_op_extract_u16 || op == nir_op_extract_i16);

   const int chunk = nir_src_as_uint(nir_src_for_ssa(c));
   const int chunk_bits =
      (op == nir_op_extract_u8 || op == nir_op_extract_i8) ? 8 : 16;
   const int num_chunks_in_32 = 32 / chunk_bits;

   nir_def *extract32;
   if (chunk < num_chunks_in_32) {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_x(b, x),
                                nir_imm_int(b, chunk),
                                NULL, NULL);
   } else {
      extract32 = nir_build_alu(b, op, nir_unpack_64_2x32_split_y(b, x),
                                nir_imm_int(b, chunk - num_chunks_in_32),
                                NULL, NULL);
   }

   if (op == nir_op_extract_i8 || op == nir_op_extract_i16)
      return lower_i2i64(b, extract32);
   else
      return lower_u2u64(b, extract32);
}

static nir_def *
lower_ufind_msb64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_count = nir_ufind_msb(b, x_lo);
   nir_def *hi_count = nir_ufind_msb(b, x_hi);

   if (b->shader->options->lower_uadd_sat) {
      nir_def *valid_hi_bits = nir_ine_imm(b, x_hi, 0);
      nir_def *hi_res = nir_iadd_imm(b, hi_count, 32);
      return nir_bcsel(b, valid_hi_bits, hi_res, lo_count);
   } else {
      /* If hi_count was -1, it will still be -1 after this uadd_sat.  As a
       * result, hi_count is either -1 or the correct return value for 64-bit
       * ufind_msb.
       */
      nir_def *hi_res = nir_uadd_sat(b, nir_imm_intN_t(b, 32, 32), hi_count);

      /* hi_res is either -1 or a value in the range [63, 32].  lo_count is
       * either -1 or a value in the range [31, 0].  The imax will pick
       * lo_count only when hi_res is -1.  In those cases, lo_count is
       * guaranteed to be the correct answer.
       */
      return nir_imax(b, hi_res, lo_count);
   }
}

static nir_def *
lower_find_lsb64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_lsb = nir_find_lsb(b, x_lo);
   nir_def *hi_lsb = nir_find_lsb(b, x_hi);

   /* Use umin so that -1 (no bits found) becomes larger (0xFFFFFFFF)
    * than any actual bit position, so we return a found bit instead.
    * This is similar to the ufind_msb lowering.  If you need this lowering
    * without uadd_sat, add code like in lower_ufind_msb64.
    */
   assert(!b->shader->options->lower_uadd_sat);
   return nir_umin(b, lo_lsb, nir_uadd_sat(b, hi_lsb, nir_imm_int(b, 32)));
}

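/* int64/uint64 -> float conversion.  A signed source is first reduced to
 * sign and magnitude; ufind_msb then provides the exponent, the bits that do
 * not fit in the destination significand are shifted out with
 * round-to-nearest-even (unless RTZ is requested), and the result is either
 * assembled bit by bit (for f64) or rescaled with fexp2 (for f32/f16).
 */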
static nir_def *
lower_2f(nir_builder *b, nir_def *x, unsigned dest_bit_size,
         bool src_is_signed)
{
   nir_def *x_sign = NULL;

   if (src_is_signed) {
      x_sign = nir_bcsel(b, COND_LOWER_CMP(b, ilt, x, nir_imm_int64(b, 0)),
                         nir_imm_floatN_t(b, -1, dest_bit_size),
                         nir_imm_floatN_t(b, 1, dest_bit_size));
      x = COND_LOWER_OP(b, iabs, x);
   }

   nir_def *exp = COND_LOWER_OP(b, ufind_msb, x);
   unsigned significand_bits;

   switch (dest_bit_size) {
   case 64:
      significand_bits = 52;
      break;
   case 32:
      significand_bits = 23;
      break;
   case 16:
      significand_bits = 10;
      break;
   default:
      unreachable("Invalid dest_bit_size");
   }

   nir_def *discard =
      nir_imax(b, nir_iadd_imm(b, exp, -significand_bits),
               nir_imm_int(b, 0));
   nir_def *significand = COND_LOWER_OP(b, ushr, x, discard);
   if (significand_bits < 32)
      significand = COND_LOWER_CAST(b, u2u32, significand);

   /* Round-to-nearest-even implementation:
    * - if the non-representable part of the significand is higher than half
    *   the minimum representable significand, we round-up
    * - if the non-representable part of the significand is equal to half the
    *   minimum representable significand and the representable part of the
    *   significand is odd, we round-up
    * - in any other case, we round-down
    */
   nir_def *lsb_mask = COND_LOWER_OP(b, ishl, nir_imm_int64(b, 1), discard);
   nir_def *rem_mask = COND_LOWER_OP(b, isub, lsb_mask, nir_imm_int64(b, 1));
   nir_def *half = COND_LOWER_OP(b, ishr, lsb_mask, nir_imm_int(b, 1));
   nir_def *rem = COND_LOWER_OP(b, iand, x, rem_mask);
   nir_def *halfway = nir_iand(b, COND_LOWER_CMP(b, ieq, rem, half),
                               nir_ine_imm(b, discard, 0));
   nir_def *is_odd = COND_LOWER_CMP(b, ine, nir_imm_int64(b, 0),
                                    COND_LOWER_OP(b, iand, x, lsb_mask));
   nir_def *round_up = nir_ior(b, COND_LOWER_CMP(b, ilt, half, rem),
                               nir_iand(b, halfway, is_odd));
   if (!nir_is_rounding_mode_rtz(b->shader->info.float_controls_execution_mode,
                                 dest_bit_size)) {
      if (significand_bits >= 32)
         significand = COND_LOWER_OP(b, iadd, significand,
                                     COND_LOWER_CAST(b, b2i64, round_up));
      else
         significand = nir_iadd(b, significand, nir_b2i32(b, round_up));
   }

   nir_def *res;

   if (dest_bit_size == 64) {
      /* Compute the left shift required to normalize the original
       * unrounded input manually.
       */
      nir_def *shift =
         nir_imax(b, nir_isub_imm(b, significand_bits, exp),
                  nir_imm_int(b, 0));
      significand = COND_LOWER_OP(b, ishl, significand, shift);

      /* Check whether normalization led to overflow of the available
       * significand bits, which can only happen if round_up was true
       * above, in which case we need to add carry to the exponent and
       * discard an extra bit from the significand.  Note that we
       * don't need to repeat the round-up logic again, since the LSB
       * of the significand is guaranteed to be zero if there was
       * overflow.
       */
      nir_def *carry = nir_b2i32(
         b, nir_uge_imm(b, nir_unpack_64_2x32_split_y(b, significand),
                        (uint64_t)(1 << (significand_bits - 31))));
      significand = COND_LOWER_OP(b, ishr, significand, carry);
      exp = nir_iadd(b, exp, carry);

      /* Compute the biased exponent, taking care to handle a zero
       * input correctly, which would have caused exp to be negative.
       */
      nir_def *biased_exp = nir_bcsel(b, nir_ilt_imm(b, exp, 0),
                                      nir_imm_int(b, 0),
                                      nir_iadd_imm(b, exp, 1023));

      /* Pack the significand and exponent manually. */
      nir_def *lo = nir_unpack_64_2x32_split_x(b, significand);
      nir_def *hi = nir_bitfield_insert(
         b, nir_unpack_64_2x32_split_y(b, significand),
         biased_exp, nir_imm_int(b, 20), nir_imm_int(b, 11));

      res = nir_pack_64_2x32_split(b, lo, hi);

   } else if (dest_bit_size == 32) {
      res = nir_fmul(b, nir_u2f32(b, significand),
                     nir_fexp2(b, nir_u2f32(b, discard)));
   } else {
      res = nir_fmul(b, nir_u2f16(b, significand),
                     nir_fexp2(b, nir_u2f16(b, discard)));
   }

   if (src_is_signed)
      res = nir_fmul(b, res, x_sign);

   return res;
}

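/* float -> int64/uint64 conversion.  For 32- and 64-bit sources the
 * truncated magnitude is split into two 32-bit words by dividing by 2^32
 * for the high word and taking the remainder for the low word; an f16
 * source always fits in the low word.  The sign is reapplied at the end
 * for the signed variant.
 */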
static nir_def *
lower_f2(nir_builder *b, nir_def *x, bool dst_is_signed)
{
   assert(x->bit_size == 16 || x->bit_size == 32 || x->bit_size == 64);
   nir_def *x_sign = NULL;

   if (dst_is_signed)
      x_sign = nir_fsign(b, x);

   x = nir_ftrunc(b, x);

   if (dst_is_signed)
      x = nir_fabs(b, x);

   nir_def *res;
   if (x->bit_size < 32) {
      res = nir_pack_64_2x32_split(b, nir_f2u32(b, x), nir_imm_int(b, 0));
   } else {
      nir_def *div = nir_imm_floatN_t(b, 1ULL << 32, x->bit_size);
      nir_def *res_hi = nir_f2u32(b, nir_fdiv(b, x, div));
      nir_def *res_lo = nir_f2u32(b, nir_frem(b, x, div));
      res = nir_pack_64_2x32_split(b, res_lo, res_hi);
   }

   if (dst_is_signed)
      res = nir_bcsel(b, nir_flt_imm(b, x_sign, 0),
                      nir_ineg(b, res), res);

   return res;
}

static nir_def *
lower_bit_count64(nir_builder *b, nir_def *x)
{
   nir_def *x_lo = nir_unpack_64_2x32_split_x(b, x);
   nir_def *x_hi = nir_unpack_64_2x32_split_y(b, x);
   nir_def *lo_count = nir_bit_count(b, x_lo);
   nir_def *hi_count = nir_bit_count(b, x_hi);
   return nir_iadd(b, lo_count, hi_count);
}

nir_lower_int64_options
nir_lower_int64_op_to_options_mask(nir_op opcode)
{
   switch (opcode) {
   case nir_op_imul:
   case nir_op_amul:
      return nir_lower_imul64;
   case nir_op_imul_2x32_64:
   case nir_op_umul_2x32_64:
      return nir_lower_imul_2x32_64;
   case nir_op_imul_high:
   case nir_op_umul_high:
      return nir_lower_imul_high64;
   case nir_op_isign:
      return nir_lower_isign64;
   case nir_op_udiv:
   case nir_op_idiv:
   case nir_op_umod:
   case nir_op_imod:
   case nir_op_irem:
      return nir_lower_divmod64;
   case nir_op_b2i64:
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_i2i64:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
   case nir_op_u2u64:
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
      return nir_lower_conv64;
   case nir_op_bcsel:
      return nir_lower_bcsel64;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return nir_lower_icmp64;
   case nir_op_iadd:
   case nir_op_isub:
      return nir_lower_iadd64;
   case nir_op_imin:
   case nir_op_imax:
   case nir_op_umin:
   case nir_op_umax:
      return nir_lower_minmax64;
   case nir_op_iabs:
      return nir_lower_iabs64;
   case nir_op_ineg:
      return nir_lower_ineg64;
   case nir_op_iand:
   case nir_op_ior:
   case nir_op_ixor:
   case nir_op_inot:
      return nir_lower_logic64;
   case nir_op_ishl:
   case nir_op_ishr:
   case nir_op_ushr:
      return nir_lower_shift64;
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return nir_lower_extract64;
   case nir_op_ufind_msb:
      return nir_lower_ufind_msb64;
   case nir_op_find_lsb:
      return nir_lower_find_lsb64;
   case nir_op_bit_count:
      return nir_lower_bit_count64;
   default:
      return 0;
   }
}

static nir_def *
lower_int64_alu_instr(nir_builder *b, nir_alu_instr *alu)
{
   nir_def *src[4];
   for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++)
      src[i] = nir_ssa_for_alu_src(b, alu, i);

   switch (alu->op) {
   case nir_op_imul:
   case nir_op_amul:
      return lower_imul64(b, src[0], src[1]);
   case nir_op_imul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], true);
   case nir_op_umul_2x32_64:
      return lower_mul_2x32_64(b, src[0], src[1], false);
   case nir_op_imul_high:
      return lower_mul_high64(b, src[0], src[1], true);
   case nir_op_umul_high:
      return lower_mul_high64(b, src[0], src[1], false);
   case nir_op_isign:
      return lower_isign64(b, src[0]);
   case nir_op_udiv:
      return lower_udiv64(b, src[0], src[1]);
   case nir_op_idiv:
      return lower_idiv64(b, src[0], src[1]);
   case nir_op_umod:
      return lower_umod64(b, src[0], src[1]);
   case nir_op_imod:
      return lower_imod64(b, src[0], src[1]);
   case nir_op_irem:
      return lower_irem64(b, src[0], src[1]);
   case nir_op_b2i64:
      return lower_b2i64(b, src[0]);
   case nir_op_i2i8:
      return lower_i2i8(b, src[0]);
   case nir_op_i2i16:
      return lower_i2i16(b, src[0]);
   case nir_op_i2i32:
      return lower_i2i32(b, src[0]);
   case nir_op_i2i64:
      return lower_i2i64(b, src[0]);
   case nir_op_u2u8:
      return lower_u2u8(b, src[0]);
   case nir_op_u2u16:
      return lower_u2u16(b, src[0]);
   case nir_op_u2u32:
      return lower_u2u32(b, src[0]);
   case nir_op_u2u64:
      return lower_u2u64(b, src[0]);
   case nir_op_bcsel:
      return lower_bcsel64(b, src[0], src[1], src[2]);
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      return lower_int64_compare(b, alu->op, src[0], src[1]);
   case nir_op_iadd:
      return lower_iadd64(b, src[0], src[1]);
   case nir_op_isub:
      return lower_isub64(b, src[0], src[1]);
   case nir_op_imin:
      return lower_imin64(b, src[0], src[1]);
   case nir_op_imax:
      return lower_imax64(b, src[0], src[1]);
   case nir_op_umin:
      return lower_umin64(b, src[0], src[1]);
   case nir_op_umax:
      return lower_umax64(b, src[0], src[1]);
   case nir_op_iabs:
      return lower_iabs64(b, src[0]);
   case nir_op_ineg:
      return lower_ineg64(b, src[0]);
   case nir_op_iand:
      return lower_iand64(b, src[0], src[1]);
   case nir_op_ior:
      return lower_ior64(b, src[0], src[1]);
   case nir_op_ixor:
      return lower_ixor64(b, src[0], src[1]);
   case nir_op_inot:
      return lower_inot64(b, src[0]);
   case nir_op_ishl:
      return lower_ishl64(b, src[0], src[1]);
   case nir_op_ishr:
      return lower_ishr64(b, src[0], src[1]);
   case nir_op_ushr:
      return lower_ushr64(b, src[0], src[1]);
   case nir_op_extract_u8:
   case nir_op_extract_i8:
   case nir_op_extract_u16:
   case nir_op_extract_i16:
      return lower_extract(b, alu->op, src[0], src[1]);
   case nir_op_ufind_msb:
      return lower_ufind_msb64(b, src[0]);
   case nir_op_find_lsb:
      return lower_find_lsb64(b, src[0]);
   case nir_op_bit_count:
      return lower_bit_count64(b, src[0]);
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
      return lower_2f(b, src[0], alu->def.bit_size, true);
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
      return lower_2f(b, src[0], alu->def.bit_size, false);
   case nir_op_f2i64:
   case nir_op_f2u64:
      return lower_f2(b, src[0], alu->op == nir_op_f2i64);
   default:
      unreachable("Invalid ALU opcode to lower");
   }
}

static bool
should_lower_int64_alu_instr(const nir_alu_instr *alu,
                             const nir_shader_compiler_options *options)
{
   switch (alu->op) {
   case nir_op_i2i8:
   case nir_op_i2i16:
   case nir_op_i2i32:
   case nir_op_u2u8:
   case nir_op_u2u16:
   case nir_op_u2u32:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_bcsel:
      assert(alu->src[1].src.ssa->bit_size ==
             alu->src[2].src.ssa->bit_size);
      if (alu->src[1].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_ilt:
   case nir_op_uge:
   case nir_op_ige:
      assert(alu->src[0].src.ssa->bit_size ==
             alu->src[1].src.ssa->bit_size);
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_ufind_msb:
   case nir_op_find_lsb:
   case nir_op_bit_count:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_amul:
      if (options->has_imul24)
         return false;
      if (alu->def.bit_size != 64)
         return false;
      break;
   case nir_op_i2f64:
   case nir_op_u2f64:
   case nir_op_i2f32:
   case nir_op_u2f32:
   case nir_op_i2f16:
   case nir_op_u2f16:
      if (alu->src[0].src.ssa->bit_size != 64)
         return false;
      break;
   case nir_op_f2u64:
   case nir_op_f2i64:
      FALLTHROUGH;
   default:
      if (alu->def.bit_size != 64)
         return false;
      break;
   }

   unsigned mask = nir_lower_int64_op_to_options_mask(alu->op);
   return (options->lower_int64_options & mask) != 0;
}

static nir_def *
split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];

   /* This works on subgroup ops with a single 64-bit source which can be
    * trivially lowered by doing the exact same op on both halves.
    */
   assert(nir_src_bit_size(intrin->src[0]) == 64);
   nir_def *split_src0[2] = {
      nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
      nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
   };

   assert(info->has_dest && intrin->def.bit_size == 64);

   nir_def *res[2];
   for (unsigned i = 0; i < 2; i++) {
      nir_intrinsic_instr *split =
         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
      split->num_components = intrin->num_components;
      split->src[0] = nir_src_for_ssa(split_src0[i]);

      /* Other sources must be less than 64 bits and get copied directly */
      for (unsigned j = 1; j < info->num_srcs; j++) {
         assert(nir_src_bit_size(intrin->src[j]) < 64);
         split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
      }

      /* Copy const indices, if any */
      memcpy(split->const_index, intrin->const_index,
             sizeof(intrin->const_index));

      nir_def_init(&split->instr, &split->def,
                   intrin->def.num_components, 32);
      nir_builder_instr_insert(b, &split->instr);

      res[i] = &split->def;
   }

   return nir_pack_64_2x32_split(b, res[0], res[1]);
}

static nir_def *
build_vote_ieq(nir_builder *b, nir_def *x)
{
   nir_intrinsic_instr *vote =
      nir_intrinsic_instr_create(b->shader, nir_intrinsic_vote_ieq);
   vote->src[0] = nir_src_for_ssa(x);
   vote->num_components = x->num_components;
   nir_def_init(&vote->instr, &vote->def, 1, 1);
   nir_builder_instr_insert(b, &vote->instr);
   return &vote->def;
}

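/* A 64-bit value is uniform across the subgroup exactly when both of its
 * 32-bit halves are, so a 64-bit vote_ieq is the AND of two 32-bit votes.
 */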
static nir_def *
lower_vote_ieq(nir_builder *b, nir_def *x)
{
   return nir_iand(b, build_vote_ieq(b, nir_unpack_64_2x32_split_x(b, x)),
                   build_vote_ieq(b, nir_unpack_64_2x32_split_y(b, x)));
}

static nir_def *
build_scan_intrinsic(nir_builder *b, nir_intrinsic_op scan_op,
                     nir_op reduction_op, unsigned cluster_size,
                     nir_def *val)
{
   nir_intrinsic_instr *scan =
      nir_intrinsic_instr_create(b->shader, scan_op);
   scan->num_components = val->num_components;
   scan->src[0] = nir_src_for_ssa(val);
   nir_intrinsic_set_reduction_op(scan, reduction_op);
   if (scan_op == nir_intrinsic_reduce)
      nir_intrinsic_set_cluster_size(scan, cluster_size);
   nir_def_init(&scan->instr, &scan->def, val->num_components,
                val->bit_size);
   nir_builder_instr_insert(b, &scan->instr);
   return &scan->def;
}

static nir_def *
lower_scan_iadd64(nir_builder *b, const nir_intrinsic_instr *intrin)
{
   unsigned cluster_size =
      intrin->intrinsic == nir_intrinsic_reduce ? nir_intrinsic_cluster_size(intrin) : 0;

   /* Split it into three chunks of no more than 24 bits each.  With 8 bits
    * of headroom, we're guaranteed that there will never be overflow in the
    * individual subgroup operations.  (Assuming, of course, a subgroup size
    * no larger than 256 which seems reasonable.)  We can then scan on each of
    * the chunks and add them back together at the end.
    */
   nir_def *x = intrin->src[0].ssa;
   nir_def *x_low =
      nir_u2u32(b, nir_iand_imm(b, x, 0xffffff));
   nir_def *x_mid =
      nir_u2u32(b, nir_iand_imm(b, nir_ushr_imm(b, x, 24),
                                0xffffff));
   nir_def *x_hi =
      nir_u2u32(b, nir_ushr_imm(b, x, 48));

   nir_def *scan_low =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_low);
   nir_def *scan_mid =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_mid);
   nir_def *scan_hi =
      build_scan_intrinsic(b, intrin->intrinsic, nir_op_iadd,
                           cluster_size, x_hi);

   scan_low = nir_u2u64(b, scan_low);
   scan_mid = nir_ishl_imm(b, nir_u2u64(b, scan_mid), 24);
   scan_hi = nir_ishl_imm(b, nir_u2u64(b, scan_hi), 48);

   return nir_iadd(b, scan_hi, nir_iadd(b, scan_mid, scan_low));
}

static bool
should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
                             const nir_shader_compiler_options *options)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      return intrin->def.bit_size == 64 &&
             (options->lower_int64_options & nir_lower_subgroup_shuffle64);

   case nir_intrinsic_vote_ieq:
      return intrin->src[0].ssa->bit_size == 64 &&
             (options->lower_int64_options & nir_lower_vote_ieq64);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      if (intrin->def.bit_size != 64)
         return false;

      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return options->lower_int64_options & nir_lower_scan_reduce_iadd64;
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
      default:
         return false;
      }
      break;

   default:
      return false;
   }
}

static nir_def *
lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
{
   switch (intrin->intrinsic) {
   case nir_intrinsic_read_invocation:
   case nir_intrinsic_read_first_invocation:
   case nir_intrinsic_shuffle:
   case nir_intrinsic_shuffle_xor:
   case nir_intrinsic_shuffle_up:
   case nir_intrinsic_shuffle_down:
   case nir_intrinsic_quad_broadcast:
   case nir_intrinsic_quad_swap_horizontal:
   case nir_intrinsic_quad_swap_vertical:
   case nir_intrinsic_quad_swap_diagonal:
      return split_64bit_subgroup_op(b, intrin);

   case nir_intrinsic_vote_ieq:
      return lower_vote_ieq(b, intrin->src[0].ssa);

   case nir_intrinsic_reduce:
   case nir_intrinsic_inclusive_scan:
   case nir_intrinsic_exclusive_scan:
      switch (nir_intrinsic_reduction_op(intrin)) {
      case nir_op_iadd:
         return lower_scan_iadd64(b, intrin);
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
         return split_64bit_subgroup_op(b, intrin);
      default:
         unreachable("Unsupported subgroup scan/reduce op");
      }
      break;

   default:
      unreachable("Unsupported intrinsic");
   }
   return NULL;
}

static bool
should_lower_int64_instr(const nir_instr *instr, const void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
   case nir_instr_type_intrinsic:
      return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
                                          _options);
   default:
      return false;
   }
}

static nir_def *
lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
{
   switch (instr->type) {
   case nir_instr_type_alu:
      return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
   case nir_instr_type_intrinsic:
      return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
   default:
      return NULL;
   }
}

bool
nir_lower_int64(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader, should_lower_int64_instr,
                                        lower_int64_instr,
                                        (void *)shader->options);
}

static bool
should_lower_int64_float_conv(const nir_instr *instr, const void *_options)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   switch (alu->op) {
   case nir_op_i2f64:
   case nir_op_i2f32:
   case nir_op_i2f16:
   case nir_op_u2f64:
   case nir_op_u2f32:
   case nir_op_u2f16:
   case nir_op_f2i64:
   case nir_op_f2u64:
      return should_lower_int64_alu_instr(alu, _options);
   default:
      return false;
   }
}

/**
 * Like nir_lower_int64(), but only lowers conversions to/from float.
 *
 * These operations in particular may affect double-precision lowering,
 * so it can be useful to run them in tandem with nir_lower_doubles().
 */
bool
nir_lower_int64_float_conversions(nir_shader *shader)
{
   return nir_shader_lower_instructions(shader, should_lower_int64_float_conv,
                                        lower_int64_instr,
                                        (void *)shader->options);
}