• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2015 Intel Corporation
3  * Copyright © 2019 Valve Corporation
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
22  * IN THE SOFTWARE.
23  *
24  * Authors:
25  *    Jason Ekstrand (jason@jlekstrand.net)
26  *    Samuel Pitoiset (samuel.pitoiset@gmail.com)
27  */
28 
29 #include "nir.h"
30 #include "nir_builder.h"
31 
32 static nir_ssa_def *
lower_frexp_sig(nir_builder * b,nir_ssa_def * x)33 lower_frexp_sig(nir_builder *b, nir_ssa_def *x)
34 {
35    nir_ssa_def *abs_x = nir_fabs(b, x);
36    nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
37    nir_ssa_def *sign_mantissa_mask, *exponent_value;
38 
39    switch (x->bit_size) {
40    case 16:
41       /* Half-precision floating-point values are stored as
42        *   1 sign bit;
43        *   5 exponent bits;
44        *   10 mantissa bits.
45        *
46        * An exponent shift of 10 will shift the mantissa out, leaving only the
47        * exponent and sign bit (which itself may be zero, if the absolute value
48        * was taken before the bitcast and shift).
49        */
50       sign_mantissa_mask = nir_imm_intN_t(b, 0x83ffu, 16);
51       /* Exponent of floating-point values in the range [0.5, 1.0). */
52       exponent_value = nir_imm_intN_t(b, 0x3800u, 16);
53       break;
54    case 32:
55       /* Single-precision floating-point values are stored as
56        *   1 sign bit;
57        *   8 exponent bits;
58        *   23 mantissa bits.
59        *
60        * An exponent shift of 23 will shift the mantissa out, leaving only the
61        * exponent and sign bit (which itself may be zero, if the absolute value
62        * was taken before the bitcast and shift.
63        */
64       sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
65       /* Exponent of floating-point values in the range [0.5, 1.0). */
66       exponent_value = nir_imm_int(b, 0x3f000000u);
67       break;
68    case 64:
69       /* Double-precision floating-point values are stored as
70        *   1 sign bit;
71        *   11 exponent bits;
72        *   52 mantissa bits.
73        *
74        * An exponent shift of 20 will shift the remaining mantissa bits out,
75        * leaving only the exponent and sign bit (which itself may be zero, if
76        * the absolute value was taken before the bitcast and shift.
77        */
78       sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
79       /* Exponent of floating-point values in the range [0.5, 1.0). */
80       exponent_value = nir_imm_int(b, 0x3fe00000u);
81       break;
82    default:
83       unreachable("Invalid bitsize");
84    }
85 
86    if (x->bit_size == 64) {
87       /* We only need to deal with the exponent so first we extract the upper
88        * 32 bits using nir_unpack_64_2x32_split_y.
89        */
90       nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
91 
92       /* If x is ±0, ±Inf, or NaN, return x unmodified. */
93       nir_ssa_def *new_upper =
94          nir_bcsel(b,
95                    nir_iand(b,
96                             nir_flt(b, zero, abs_x),
97                             nir_fisfinite(b, x)),
98                    nir_ior(b,
99                            nir_iand(b, upper_x, sign_mantissa_mask),
100                            exponent_value),
101                    upper_x);
102 
103       nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
104 
105       return nir_pack_64_2x32_split(b, lower_x, new_upper);
106    } else {
107       /* If x is ±0, ±Inf, or NaN, return x unmodified. */
108       return nir_bcsel(b,
109                        nir_iand(b,
110                                 nir_flt(b, zero, abs_x),
111                                 nir_fisfinite(b, x)),
112                        nir_ior(b,
113                                nir_iand(b, x, sign_mantissa_mask),
114                                exponent_value),
115                        x);
116    }
117 }
118 
119 static nir_ssa_def *
lower_frexp_exp(nir_builder * b,nir_ssa_def * x)120 lower_frexp_exp(nir_builder *b, nir_ssa_def *x)
121 {
122    nir_ssa_def *abs_x = nir_fabs(b, x);
123    nir_ssa_def *zero = nir_imm_floatN_t(b, 0, x->bit_size);
124    nir_ssa_def *is_not_zero = nir_fneu(b, abs_x, zero);
125    nir_ssa_def *exponent;
126 
127    switch (x->bit_size) {
128    case 16: {
129       nir_ssa_def *exponent_shift = nir_imm_int(b, 10);
130       nir_ssa_def *exponent_bias = nir_imm_intN_t(b, -14, 16);
131 
132       /* Significand return must be of the same type as the input, but the
133        * exponent must be a 32-bit integer.
134        */
135       exponent = nir_i2i32(b, nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
136                               nir_bcsel(b, is_not_zero, exponent_bias, zero)));
137       break;
138    }
139    case 32: {
140       nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
141       nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
142 
143       exponent = nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
144                              nir_bcsel(b, is_not_zero, exponent_bias, zero));
145       break;
146    }
147    case 64: {
148       nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
149       nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
150 
151       nir_ssa_def *zero32 = nir_imm_int(b, 0);
152       nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
153 
154       exponent = nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
155                              nir_bcsel(b, is_not_zero, exponent_bias, zero32));
156       break;
157    }
158    default:
159       unreachable("Invalid bitsize");
160    }
161 
162    return exponent;
163 }
164 
165 static bool
lower_frexp_impl(nir_function_impl * impl)166 lower_frexp_impl(nir_function_impl *impl)
167 {
168    bool progress = false;
169 
170    nir_builder b;
171    nir_builder_init(&b, impl);
172 
173    nir_foreach_block(block, impl) {
174       nir_foreach_instr_safe(instr, block) {
175          if (instr->type != nir_instr_type_alu)
176             continue;
177 
178          nir_alu_instr *alu_instr = nir_instr_as_alu(instr);
179          nir_ssa_def *lower;
180 
181          b.cursor = nir_before_instr(instr);
182 
183          switch (alu_instr->op) {
184          case nir_op_frexp_sig:
185             lower = lower_frexp_sig(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
186             break;
187          case nir_op_frexp_exp:
188             lower = lower_frexp_exp(&b, nir_ssa_for_alu_src(&b, alu_instr, 0));
189             break;
190          default:
191             continue;
192          }
193 
194          nir_ssa_def_rewrite_uses(&alu_instr->dest.dest.ssa,
195                                   lower);
196          nir_instr_remove(instr);
197          progress = true;
198       }
199    }
200 
201    if (progress) {
202       nir_metadata_preserve(impl, nir_metadata_block_index |
203                                   nir_metadata_dominance);
204    }
205 
206    return progress;
207 }
208 
209 bool
nir_lower_frexp(nir_shader * shader)210 nir_lower_frexp(nir_shader *shader)
211 {
212    bool progress = false;
213 
214    nir_foreach_function(function, shader) {
215       if (function->impl)
216          progress |= lower_frexp_impl(function->impl);
217    }
218 
219    return progress;
220 }
221