• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2020 Collabora Ltd.
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #ifndef NIR_CONVERSION_BUILDER_H
25 #define NIR_CONVERSION_BUILDER_H
26 
27 #include "util/u_math.h"
28 #include "nir_builder.h"
29 #include "nir_builtin_builder.h"
30 
31 #ifdef __cplusplus
32 extern "C" {
33 #endif
34 
35 static inline nir_ssa_def *
nir_round_float_to_int(nir_builder * b,nir_ssa_def * src,nir_rounding_mode round)36 nir_round_float_to_int(nir_builder *b, nir_ssa_def *src,
37                        nir_rounding_mode round)
38 {
39    switch (round) {
40    case nir_rounding_mode_ru:
41       return nir_fceil(b, src);
42 
43    case nir_rounding_mode_rd:
44       return nir_ffloor(b, src);
45 
46    case nir_rounding_mode_rtne:
47       return nir_fround_even(b, src);
48 
49    case nir_rounding_mode_undef:
50    case nir_rounding_mode_rtz:
51       break;
52    }
53    unreachable("unexpected rounding mode");
54 }
55 
56 static inline nir_ssa_def *
nir_round_float_to_float(nir_builder * b,nir_ssa_def * src,unsigned dest_bit_size,nir_rounding_mode round)57 nir_round_float_to_float(nir_builder *b, nir_ssa_def *src,
58                          unsigned dest_bit_size,
59                          nir_rounding_mode round)
60 {
61    unsigned src_bit_size = src->bit_size;
62    if (dest_bit_size > src_bit_size)
63       return src; /* No rounding is needed for an up-convert */
64 
65    nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size,
66                                             nir_type_float | dest_bit_size,
67                                             nir_rounding_mode_undef);
68    nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size,
69                                              nir_type_float | src_bit_size,
70                                              nir_rounding_mode_undef);
71 
72    switch (round) {
73    case nir_rounding_mode_ru: {
74       /* If lower-precision conversion results in a lower value, push it
75       * up one ULP. */
76       nir_ssa_def *lower_prec =
77          nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
78       nir_ssa_def *roundtrip =
79          nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
80       nir_ssa_def *cmp = nir_flt(b, roundtrip, src);
81       nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size);
82       return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec);
83    }
84    case nir_rounding_mode_rd: {
85       /* If lower-precision conversion results in a higher value, push it
86       * down one ULP. */
87       nir_ssa_def *lower_prec =
88          nir_build_alu(b, low_conv, src, NULL, NULL, NULL);
89       nir_ssa_def *roundtrip =
90          nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL);
91       nir_ssa_def *cmp = nir_flt(b, src, roundtrip);
92       nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size);
93       return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec);
94    }
95    case nir_rounding_mode_rtz:
96       return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)),
97                           nir_round_float_to_float(b, src, dest_bit_size,
98                                                    nir_rounding_mode_ru),
99                           nir_round_float_to_float(b, src, dest_bit_size,
100                                                    nir_rounding_mode_rd));
101    case nir_rounding_mode_rtne:
102    case nir_rounding_mode_undef:
103       break;
104    }
105    unreachable("unexpected rounding mode");
106 }
107 
108 static inline nir_ssa_def *
nir_round_int_to_float(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,unsigned dest_bit_size,nir_rounding_mode round)109 nir_round_int_to_float(nir_builder *b, nir_ssa_def *src,
110                        nir_alu_type src_type,
111                        unsigned dest_bit_size,
112                        nir_rounding_mode round)
113 {
114    /* We only care whether or not its signed */
115    src_type = nir_alu_type_get_base_type(src_type);
116 
117    unsigned mantissa_bits;
118    switch (dest_bit_size) {
119    case 16:
120       mantissa_bits = 10;
121       break;
122    case 32:
123       mantissa_bits = 23;
124       break;
125    case 64:
126       mantissa_bits = 52;
127       break;
128    default: unreachable("Unsupported bit size");
129    }
130 
131    if (src->bit_size < mantissa_bits)
132       return src;
133 
134    if (src_type == nir_type_int) {
135       nir_ssa_def *sign =
136          nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1)));
137       nir_ssa_def *abs = nir_iabs(b, src);
138       nir_ssa_def *positive_rounded =
139          nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round);
140       nir_ssa_def *max_positive =
141          nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size);
142       switch (round) {
143       case nir_rounding_mode_rtz:
144          return nir_bcsel(b, sign, nir_ineg(b, positive_rounded),
145                                    positive_rounded);
146          break;
147       case nir_rounding_mode_ru:
148          return nir_bcsel(b, sign,
149                           nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)),
150                           nir_umin(b, positive_rounded, max_positive));
151          break;
152       case nir_rounding_mode_rd:
153          return nir_bcsel(b, sign,
154                           nir_ineg(b,
155                                    nir_umin(b, max_positive,
156                                             nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))),
157                           positive_rounded);
158       case nir_rounding_mode_rtne:
159       case nir_rounding_mode_undef:
160          break;
161       }
162       unreachable("unexpected rounding mode");
163    } else {
164       nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits);
165       nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size);
166       nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size);
167       nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size);
168       nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose);
169       nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one));
170       nir_ssa_def *truncated = nir_iand(b, src, mask);
171       switch (round) {
172       case nir_rounding_mode_rtz:
173       case nir_rounding_mode_rd:
174          return truncated;
175          break;
176       case nir_rounding_mode_ru:
177          return nir_bcsel(b, nir_ieq(b, src, truncated),
178                              src, nir_uadd_sat(b, truncated, adjust));
179       case nir_rounding_mode_rtne:
180       case nir_rounding_mode_undef:
181          break;
182       }
183       unreachable("unexpected rounding mode");
184    }
185 }
186 
187 /** Returns true if the representable range of a contains the representable
188  * range of b.
189  */
190 static inline bool
nir_alu_type_range_contains_type_range(nir_alu_type a,nir_alu_type b)191 nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b)
192 {
193    /* Split types from bit sizes */
194    nir_alu_type a_base_type = nir_alu_type_get_base_type(a);
195    nir_alu_type b_base_type = nir_alu_type_get_base_type(b);
196    unsigned a_bit_size = nir_alu_type_get_type_size(a);
197    unsigned b_bit_size = nir_alu_type_get_type_size(b);
198 
199    /* This requires sized types */
200    assert(a_bit_size > 0 && b_bit_size > 0);
201 
202    if (a_base_type == b_base_type && a_bit_size >= b_bit_size)
203       return true;
204 
205    if (a_base_type == nir_type_int && b_base_type == nir_type_uint &&
206        a_bit_size > b_bit_size)
207       return true;
208 
209    /* 16-bit floats fit in 32-bit integers */
210    if (a_base_type == nir_type_int && a_bit_size >= 32 &&
211        b == nir_type_float16)
212       return true;
213 
214    /* All signed or unsigned ints can fit in float or above. A uint8 can fit
215     * in a float16.
216     */
217    if (a_base_type == nir_type_float && b_base_type != nir_type_float &&
218        (a_bit_size >= 32 || b_bit_size == 8))
219       return true;
220 
221    return false;
222 }
223 
224 /**
225  * Clamp the source value into the widest representatble range of the
226  * destination type with cmp + bcsel.
227  */
228 static inline nir_ssa_def *
nir_clamp_to_type_range(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type)229 nir_clamp_to_type_range(nir_builder *b,
230                         nir_ssa_def *src, nir_alu_type src_type,
231                         nir_alu_type dest_type)
232 {
233    assert(nir_alu_type_get_type_size(src_type) == 0 ||
234           nir_alu_type_get_type_size(src_type) == src->bit_size);
235    src_type |= src->bit_size;
236    if (nir_alu_type_range_contains_type_range(dest_type, src_type))
237       return src;
238 
239    /* Split types from bit sizes */
240    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
241    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
242    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
243    assert(dest_bit_size != 0);
244 
245    /* limits of the destination type, expressed in the source type */
246    nir_ssa_def *low = NULL, *high = NULL;
247    switch (dest_base_type) {
248    case nir_type_int: {
249       int64_t ilow, ihigh;
250       if (dest_bit_size == 64) {
251          ilow = INT64_MIN;
252          ihigh = INT64_MAX;
253       } else {
254          ilow = -(1ll << (dest_bit_size - 1));
255          ihigh = (1ll << (dest_bit_size - 1)) - 1;
256       }
257 
258       if (src_base_type == nir_type_int) {
259          low = nir_imm_intN_t(b, ilow, src->bit_size);
260          high = nir_imm_intN_t(b, ihigh, src->bit_size);
261       } else if (src_base_type == nir_type_uint) {
262          assert(src->bit_size >= dest_bit_size);
263          high = nir_imm_intN_t(b, ihigh, src->bit_size);
264       } else {
265          low = nir_imm_floatN_t(b, ilow, src->bit_size);
266          high = nir_imm_floatN_t(b, ihigh, src->bit_size);
267       }
268       break;
269    }
270    case nir_type_uint: {
271       uint64_t uhigh = dest_bit_size == 64 ?
272          ~0ull : (1ull << dest_bit_size) - 1;
273       if (src_base_type != nir_type_float) {
274          low = nir_imm_intN_t(b, 0, src->bit_size);
275          if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size)
276             high = nir_imm_intN_t(b, uhigh, src->bit_size);
277       } else {
278          low = nir_imm_floatN_t(b, 0.0f, src->bit_size);
279          high = nir_imm_floatN_t(b, uhigh, src->bit_size);
280       }
281       break;
282    }
283    case nir_type_float: {
284       double flow, fhigh;
285       switch (dest_bit_size) {
286       case 16:
287          flow = -65504.0f;
288          fhigh = 65504.0f;
289          break;
290       case 32:
291          flow = -FLT_MAX;
292          fhigh = FLT_MAX;
293          break;
294       case 64:
295          flow = -DBL_MAX;
296          fhigh = DBL_MAX;
297          break;
298       default:
299          unreachable("Unhandled bit size");
300       }
301 
302       switch (src_base_type) {
303       case nir_type_int: {
304          int64_t src_ilow, src_ihigh;
305          if (src->bit_size == 64) {
306             src_ilow = INT64_MIN;
307             src_ihigh = INT64_MAX;
308          } else {
309             src_ilow = -(1ll << (src->bit_size - 1));
310             src_ihigh = (1ll << (src->bit_size - 1)) - 1;
311          }
312          if (src_ilow < flow)
313             low = nir_imm_intN_t(b, flow, src->bit_size);
314          if (src_ihigh > fhigh)
315             high = nir_imm_intN_t(b, fhigh, src->bit_size);
316          break;
317       }
318       case nir_type_uint: {
319          uint64_t src_uhigh = src->bit_size == 64 ?
320             ~0ull : (1ull << src->bit_size) - 1;
321          if (src_uhigh > fhigh)
322             high = nir_imm_intN_t(b, fhigh, src->bit_size);
323          break;
324       }
325       case nir_type_float:
326          low = nir_imm_floatN_t(b, flow, src->bit_size);
327          high = nir_imm_floatN_t(b, fhigh, src->bit_size);
328          break;
329       default:
330          unreachable("Clamping from unknown type");
331       }
332       break;
333    }
334    default:
335       unreachable("clamping to unknown type");
336       break;
337    }
338 
339    nir_ssa_def *low_cond = NULL, *high_cond = NULL;
340    switch (src_base_type) {
341    case nir_type_int:
342       low_cond = low ? nir_ilt(b, src, low) : NULL;
343       high_cond = high ? nir_ilt(b, high, src) : NULL;
344       break;
345    case nir_type_uint:
346       low_cond = low ? nir_ult(b, src, low) : NULL;
347       high_cond = high ? nir_ult(b, high, src) : NULL;
348       break;
349    case nir_type_float:
350       low_cond = low ? nir_flt(b, src, low) : NULL;
351       high_cond = high ? nir_flt(b, high, src) : NULL;
352       break;
353    default:
354       unreachable("clamping from unknown type");
355    }
356 
357    nir_ssa_def *res = src;
358    if (low_cond)
359       res = nir_bcsel(b, low_cond, low, res);
360    if (high_cond)
361       res = nir_bcsel(b, high_cond, high, res);
362 
363    return res;
364 }
365 
366 static inline nir_rounding_mode
nir_simplify_conversion_rounding(nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode rounding)367 nir_simplify_conversion_rounding(nir_alu_type src_type,
368                                  nir_alu_type dest_type,
369                                  nir_rounding_mode rounding)
370 {
371    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
372    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
373    unsigned src_bit_size = nir_alu_type_get_type_size(src_type);
374    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
375    assert(src_bit_size > 0 && dest_bit_size > 0);
376 
377    if (rounding == nir_rounding_mode_undef)
378       return rounding;
379 
380    /* Pure integer conversion doesn't have any rounding */
381    if (src_base_type != nir_type_float &&
382        dest_base_type != nir_type_float)
383       return nir_rounding_mode_undef;
384 
385    /* Float down-casts don't round */
386    if (src_base_type == nir_type_float &&
387        dest_base_type == nir_type_float &&
388        dest_bit_size >= src_bit_size)
389       return nir_rounding_mode_undef;
390 
391    /* Regular float to int conversions are RTZ */
392    if (src_base_type == nir_type_float &&
393        dest_base_type != nir_type_float &&
394        rounding == nir_rounding_mode_rtz)
395       return nir_rounding_mode_undef;
396 
397    /* The CL spec requires regular conversions to float to be RTNE */
398    if (dest_base_type == nir_type_float &&
399        rounding == nir_rounding_mode_rtne)
400       return nir_rounding_mode_undef;
401 
402    /* Couldn't simplify */
403    return rounding;
404 }
405 
406 static inline nir_ssa_def *
nir_convert_with_rounding(nir_builder * b,nir_ssa_def * src,nir_alu_type src_type,nir_alu_type dest_type,nir_rounding_mode round,bool clamp)407 nir_convert_with_rounding(nir_builder *b,
408                           nir_ssa_def *src, nir_alu_type src_type,
409                           nir_alu_type dest_type,
410                           nir_rounding_mode round,
411                           bool clamp)
412 {
413    /* Some stuff wants sized types */
414    assert(nir_alu_type_get_type_size(src_type) == 0 ||
415           nir_alu_type_get_type_size(src_type) == src->bit_size);
416    src_type |= src->bit_size;
417 
418    /* Split types from bit sizes */
419    nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type);
420    nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type);
421    unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type);
422 
423    /* Try to simplify the conversion if we can */
424    clamp = clamp &&
425       !nir_alu_type_range_contains_type_range(dest_type, src_type);
426    round = nir_simplify_conversion_rounding(src_type, dest_type, round);
427 
428    /*
429     * If we don't care about rounding and clamping, we can just use NIR's
430     * built-in ops. There is also a special case for SPIR-V in shaders, where
431     * f32/f64 -> f16 conversions can have one of two rounding modes applied,
432     * which NIR has built-in opcodes for.
433     *
434     * For the rest, we have our own implementation of rounding and clamping.
435     */
436    bool trivial_convert;
437    if (!clamp && round == nir_rounding_mode_undef) {
438       trivial_convert = true;
439    } else if (!clamp && src_type == nir_type_float32 &&
440                         dest_type == nir_type_float16 &&
441                         (round == nir_rounding_mode_rtne ||
442                          round == nir_rounding_mode_rtz)) {
443       trivial_convert = true;
444    } else {
445       trivial_convert = false;
446    }
447    if (trivial_convert) {
448       nir_op op = nir_type_conversion_op(src_type, dest_type, round);
449       return nir_build_alu(b, op, src, NULL, NULL, NULL);
450    }
451 
452    nir_ssa_def *dest = src;
453 
454    /* clamp the result into range */
455    if (clamp)
456       dest = nir_clamp_to_type_range(b, dest, src_type, dest_type);
457 
458    /* round with selected rounding mode */
459    if (!trivial_convert && round != nir_rounding_mode_undef) {
460       if (src_base_type == nir_type_float) {
461          if (dest_base_type == nir_type_float) {
462             dest = nir_round_float_to_float(b, dest, dest_bit_size, round);
463          } else {
464             dest = nir_round_float_to_int(b, dest, round);
465          }
466       } else {
467          dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round);
468       }
469 
470       round = nir_rounding_mode_undef;
471    }
472 
473    /* now we can convert the value */
474    nir_op op = nir_type_conversion_op(src_type, dest_type, round);
475    return nir_build_alu(b, op, dest, NULL, NULL, NULL);
476 }
477 
478 #ifdef __cplusplus
479 }
480 #endif
481 
482 #endif /* NIR_CONVERSION_BUILDER_H */
483