1 /*
2 * Copyright © 2020 Collabora Ltd.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23
24 #ifndef NIR_CONVERSION_BUILDER_H
25 #define NIR_CONVERSION_BUILDER_H
26
27 #include "util/u_math.h"
28 #include "nir_builder.h"
29 #include "nir_builtin_builder.h"
30
31 #ifdef __cplusplus
32 extern "C" {
33 #endif
34
35 static inline nir_ssa_def *
nir_round_float_to_int(nir_builder * b,nir_ssa_def * src,nir_rounding_mode round)36 nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
37 nir_rounding_mode round)
38 {
39 switch (round) {
40 case nir_rounding_mode_ru:
41 return nir_fceil(b, src);
42
43 case nir_rounding_mode_rd:
44 return nir_ffloor(b, src);
45
46 case nir_rounding_mode_rtne:
47 return nir_fround_even(b, src);
48
49 case nir_rounding_mode_undef:
50 case nir_rounding_mode_rtz:
51 break;
52 }
53 unreachable("unexpected rounding mode");
54 }
55
56 static inline nir_ssa_def *
nir_round_float_to_float(nir_builder * b,nir_ssa_def * src,unsigned dest_bit_size,nir_rounding_mode round)57 nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
58 unsigned dest_bit_size,
59 nir_rounding_mode round)
60 {
61 unsigned src_bit_size = src->bit_size;
62 if (dest_bit_size > src_bit_size)
63 return src; /* No rounding is needed for an up-convert */
64
65 nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66 nir_type_float | dest_bit_size,
67 nir_rounding_mode_undef);
68 nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69 nir_type_float | src_bit_size,
70 nir_rounding_mode_undef);
71
72 switch (round) {
73 case nir_rounding_mode_ru: {
74 /* If lower-precision conversion results in a lower value, push it
75 * up one ULP. */
76 nir_ssa_def *lower_prec =
77 nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78 nir_ssa_def *roundtrip =
79 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80 nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
81 nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83 }
84 case nir_rounding_mode_rd: {
85 /* If lower-precision conversion results in a higher value, push it
86 * down one ULP. */
87 nir_ssa_def *lower_prec =
88 nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89 nir_ssa_def *roundtrip =
90 nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91 nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
92 nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93 return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94 }
95 case nir_rounding_mode_rtz:
96 return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
97 nir_round_float_to_float(b, src, dest_bit_size,
98 nir_rounding_mode_ru),
99 nir_round_float_to_float(b, src, dest_bit_size,
100 nir_rounding_mode_rd));
101 case nir_rounding_mode_rtne:
102 case nir_rounding_mode_undef:
103 break;
104 }
105 unreachable("unexpected rounding mode");
106 }
107
108 static inline nir_ssa_def *
nir_round_int_to_float(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,unsigned dest_bit_size,nir_rounding_mode round)109 nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
110 nir_alu_type src_type,
111 unsigned dest_bit_size,
112 nir_rounding_mode round)
113 {
114 /* We only care whether or not its signed */
115 src_type = nir_alu_type_get_base_type(src_type);
116
117 unsigned mantissa_bits;
118 switch (dest_bit_size) {
119 case 16:
120 mantissa_bits = 10;
121 break;
122 case 32:
123 mantissa_bits = 23;
124 break;
125 case 64:
126 mantissa_bits = 52;
127 break;
128 default: unreachable("Unsupported bit size");
129 }
130
131 if (src->bit_size < mantissa_bits)
132 return src;
133
134 if (src_type == nir_type_int) {
135 nir_ssa_def *sign =
136 nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
137 nir_ssa_def *abs = nir_iabs(b, src);
138 nir_ssa_def *positive_rounded =
139 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
140 nir_ssa_def *max_positive =
141 nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
142 switch (round) {
143 case nir_rounding_mode_rtz:
144 return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
145 positive_rounded);
146 break;
147 case nir_rounding_mode_ru:
148 return nir_bcsel(b, sign,
149 nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
150 nir_umin(b, positive_rounded, max_positive));
151 break;
152 case nir_rounding_mode_rd:
153 return nir_bcsel(b, sign,
154 nir_ineg(b,
155 nir_umin(b, max_positive,
156 nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
157 positive_rounded);
158 case nir_rounding_mode_rtne:
159 case nir_rounding_mode_undef:
160 break;
161 }
162 unreachable("unexpected rounding mode");
163 } else {
164 nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
165 nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
166 nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
167 nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
168 nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
169 nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
170 nir_ssa_def *truncated = nir_iand(b, src, mask);
171 switch (round) {
172 case nir_rounding_mode_rtz:
173 case nir_rounding_mode_rd:
174 return truncated;
175 break;
176 case nir_rounding_mode_ru:
177 return nir_bcsel(b, nir_ieq(b, src, truncated),
178 src, nir_uadd_sat(b, truncated, adjust));
179 case nir_rounding_mode_rtne:
180 case nir_rounding_mode_undef:
181 break;
182 }
183 unreachable("unexpected rounding mode");
184 }
185 }
186
187 /** Returns true if the representable range of a contains the representable
188 * range of b.
189 */
190 static inline bool
nir_alu_type_range_contains_type_range(nir_alu_type a,nir_alu_type b)191 nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
192 {
193 /* Split types from bit sizes */
194 nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
195 nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
196 unsigned a_bit_size = nir_alu_type_get_type_size(a);
197 unsigned b_bit_size = nir_alu_type_get_type_size(b);
198
199 /* This requires sized types */
200 assert(a_bit_size > 0 && b_bit_size > 0);
201
202 if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
203 return true;
204
205 if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
206 a_bit_size > b_bit_size)
207 return true;
208
209 /* 16-bit floats fit in 32-bit integers */
210 if (a_base_type == nir_type_int && a_bit_size >= 32 &&
211 b == nir_type_float16)
212 return true;
213
214 /* All signed or unsigned ints can fit in float or above. A uint8 can fit
215 * in a float16.
216 */
217 if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
218 (a_bit_size >= 32 || b_bit_size == 8))
219 return true;
220
221 return false;
222 }
223
224 /**
225 * Clamp the source value into the widest representatble range of the
226 * destination type with cmp + bcsel.
227 */
228 static inline nir_ssa_def *
nir_clamp_to_type_range(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type)229 nir_clamp_to_type_range(nir_builder *b,
230 nir_ssa_def *src, nir_alu_type src_type,
231 nir_alu_type dest_type)
232 {
233 assert(nir_alu_type_get_type_size(src_type) == 0 ||
234 nir_alu_type_get_type_size(src_type) == src->bit_size);
235 src_type |= src->bit_size;
236 if (nir_alu_type_range_contains_type_range(dest_type, src_type))
237 return src;
238
239 /* Split types from bit sizes */
240 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
241 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
242 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
243 assert(dest_bit_size != 0);
244
245 /* limits of the destination type, expressed in the source type */
246 nir_ssa_def *low = NULL, *high = NULL;
247 switch (dest_base_type) {
248 case nir_type_int: {
249 int64_t ilow, ihigh;
250 if (dest_bit_size == 64) {
251 ilow = INT64_MIN;
252 ihigh = INT64_MAX;
253 } else {
254 ilow = -(1ll << (dest_bit_size - 1));
255 ihigh = (1ll << (dest_bit_size - 1)) - 1;
256 }
257
258 if (src_base_type == nir_type_int) {
259 low = nir_imm_intN_t(b, ilow, src->bit_size);
260 high = nir_imm_intN_t(b, ihigh, src->bit_size);
261 } else if (src_base_type == nir_type_uint) {
262 assert(src->bit_size >= dest_bit_size);
263 high = nir_imm_intN_t(b, ihigh, src->bit_size);
264 } else {
265 low = nir_imm_floatN_t(b, ilow, src->bit_size);
266 high = nir_imm_floatN_t(b, ihigh, src->bit_size);
267 }
268 break;
269 }
270 case nir_type_uint: {
271 uint64_t uhigh = dest_bit_size == 64 ?
272 ~0ull : (1ull << dest_bit_size) - 1;
273 if (src_base_type != nir_type_float) {
274 low = nir_imm_intN_t(b, 0, src->bit_size);
275 if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size)
276 high = nir_imm_intN_t(b, uhigh, src->bit_size);
277 } else {
278 low = nir_imm_floatN_t(b, 0.0f, src->bit_size);
279 high = nir_imm_floatN_t(b, uhigh, src->bit_size);
280 }
281 break;
282 }
283 case nir_type_float: {
284 double flow, fhigh;
285 switch (dest_bit_size) {
286 case 16:
287 flow = -65504.0f;
288 fhigh = 65504.0f;
289 break;
290 case 32:
291 flow = -FLT_MAX;
292 fhigh = FLT_MAX;
293 break;
294 case 64:
295 flow = -DBL_MAX;
296 fhigh = DBL_MAX;
297 break;
298 default:
299 unreachable("Unhandled bit size");
300 }
301
302 switch (src_base_type) {
303 case nir_type_int: {
304 int64_t src_ilow, src_ihigh;
305 if (src->bit_size == 64) {
306 src_ilow = INT64_MIN;
307 src_ihigh = INT64_MAX;
308 } else {
309 src_ilow = -(1ll << (src->bit_size - 1));
310 src_ihigh = (1ll << (src->bit_size - 1)) - 1;
311 }
312 if (src_ilow < flow)
313 low = nir_imm_intN_t(b, flow, src->bit_size);
314 if (src_ihigh > fhigh)
315 high = nir_imm_intN_t(b, fhigh, src->bit_size);
316 break;
317 }
318 case nir_type_uint: {
319 uint64_t src_uhigh = src->bit_size == 64 ?
320 ~0ull : (1ull << src->bit_size) - 1;
321 if (src_uhigh > fhigh)
322 high = nir_imm_intN_t(b, fhigh, src->bit_size);
323 break;
324 }
325 case nir_type_float:
326 low = nir_imm_floatN_t(b, flow, src->bit_size);
327 high = nir_imm_floatN_t(b, fhigh, src->bit_size);
328 break;
329 default:
330 unreachable("Clamping from unknown type");
331 }
332 break;
333 }
334 default:
335 unreachable("clamping to unknown type");
336 break;
337 }
338
339 nir_ssa_def *low_cond = NULL, *high_cond = NULL;
340 switch (src_base_type) {
341 case nir_type_int:
342 low_cond = low ? nir_ilt(b, src, low) : NULL;
343 high_cond = high ? nir_ilt(b, high, src) : NULL;
344 break;
345 case nir_type_uint:
346 low_cond = low ? nir_ult(b, src, low) : NULL;
347 high_cond = high ? nir_ult(b, high, src) : NULL;
348 break;
349 case nir_type_float:
350 low_cond = low ? nir_flt(b, src, low) : NULL;
351 high_cond = high ? nir_flt(b, high, src) : NULL;
352 break;
353 default:
354 unreachable("clamping from unknown type");
355 }
356
357 nir_ssa_def *res = src;
358 if (low_cond)
359 res = nir_bcsel(b, low_cond, low, res);
360 if (high_cond)
361 res = nir_bcsel(b, high_cond, high, res);
362
363 return res;
364 }
365
366 static inline nir_rounding_mode
nir_simplify_conversion_rounding(nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode rounding)367 nir_simplify_conversion_rounding(nir_alu_type src_type,
368 nir_alu_type dest_type,
369 nir_rounding_mode rounding)
370 {
371 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
372 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
373 unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
374 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
375 assert(src_bit_size > 0 && dest_bit_size > 0);
376
377 if (rounding == nir_rounding_mode_undef)
378 return rounding;
379
380 /* Pure integer conversion doesn't have any rounding */
381 if (src_base_type != nir_type_float &&
382 dest_base_type != nir_type_float)
383 return nir_rounding_mode_undef;
384
385 /* Float down-casts don't round */
386 if (src_base_type == nir_type_float &&
387 dest_base_type == nir_type_float &&
388 dest_bit_size >= src_bit_size)
389 return nir_rounding_mode_undef;
390
391 /* Regular float to int conversions are RTZ */
392 if (src_base_type == nir_type_float &&
393 dest_base_type != nir_type_float &&
394 rounding == nir_rounding_mode_rtz)
395 return nir_rounding_mode_undef;
396
397 /* The CL spec requires regular conversions to float to be RTNE */
398 if (dest_base_type == nir_type_float &&
399 rounding == nir_rounding_mode_rtne)
400 return nir_rounding_mode_undef;
401
402 /* Couldn't simplify */
403 return rounding;
404 }
405
406 static inline nir_ssa_def *
nir_convert_with_rounding(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode round,bool clamp)407 nir_convert_with_rounding(nir_builder *b,
408 nir_ssa_def *src, nir_alu_type src_type,
409 nir_alu_type dest_type,
410 nir_rounding_mode round,
411 bool clamp)
412 {
413 /* Some stuff wants sized types */
414 assert(nir_alu_type_get_type_size(src_type) == 0 ||
415 nir_alu_type_get_type_size(src_type) == src->bit_size);
416 src_type |= src->bit_size;
417
418 /* Split types from bit sizes */
419 nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
420 nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
421 unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
422
423 /* Try to simplify the conversion if we can */
424 clamp = clamp &&
425 !nir_alu_type_range_contains_type_range(dest_type, src_type);
426 round = nir_simplify_conversion_rounding(src_type, dest_type, round);
427
428 /*
429 * If we don't care about rounding and clamping, we can just use NIR's
430 * built-in ops. There is also a special case for SPIR-V in shaders, where
431 * f32/f64 -> f16 conversions can have one of two rounding modes applied,
432 * which NIR has built-in opcodes for.
433 *
434 * For the rest, we have our own implementation of rounding and clamping.
435 */
436 bool trivial_convert;
437 if (!clamp && round == nir_rounding_mode_undef) {
438 trivial_convert = true;
439 } else if (!clamp && src_type == nir_type_float32 &&
440 dest_type == nir_type_float16 &&
441 (round == nir_rounding_mode_rtne ||
442 round == nir_rounding_mode_rtz)) {
443 trivial_convert = true;
444 } else {
445 trivial_convert = false;
446 }
447 if (trivial_convert) {
448 nir_op op = nir_type_conversion_op(src_type, dest_type, round);
449 return nir_build_alu(b, op, src, NULL, NULL, NULL);
450 }
451
452 nir_ssa_def *dest = src;
453
454 /* clamp the result into range */
455 if (clamp)
456 dest = nir_clamp_to_type_range(b, dest, src_type, dest_type);
457
458 /* round with selected rounding mode */
459 if (!trivial_convert && round != nir_rounding_mode_undef) {
460 if (src_base_type == nir_type_float) {
461 if (dest_base_type == nir_type_float) {
462 dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
463 } else {
464 dest = nir_round_float_to_int(b, dest, round);
465 }
466 } else {
467 dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
468 }
469
470 round = nir_rounding_mode_undef;
471 }
472
473 /* now we can convert the value */
474 nir_op op = nir_type_conversion_op(src_type, dest_type, round);
475 return nir_build_alu(b, op, dest, NULL, NULL, NULL);
476 }
477
478 #ifdef __cplusplus
479 }
480 #endif
481
482 #endif /* NIR_CONVERSION_BUILDER_H */
483