• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
24 #include "nir_builder.h"
25 
26 /* The following float-to-half conversion routines are based on the "half" library:
27  * https://sourceforge.net/projects/half/
28  *
29  * half - IEEE 754-based half-precision floating-point library.
30  *
31  * Copyright (c) 2012-2019 Christian Rau <rauy@users.sourceforge.net>
32  *
33  * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation
34  * files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy,
35  * modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the
36  * Software is furnished to do so, subject to the following conditions:
37  *
38  * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
39  *
40  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
41  * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
42  * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
43  * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
44  *
45  * Version 2.1.0
46  */
47 static bool
lower_fp16_casts_filter(const nir_instr * instr,const void * data)48 lower_fp16_casts_filter(const nir_instr *instr, const void *data)
49 {
50    if (instr->type == nir_instr_type_alu) {
51       nir_alu_instr *alu = nir_instr_as_alu(instr);
52       switch (alu->op) {
53       case nir_op_f2f16:
54       case nir_op_f2f16_rtne:
55       case nir_op_f2f16_rtz:
56          return true;
57       default:
58          return false;
59       }
60    } else if (instr->type == nir_instr_type_intrinsic) {
61       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
62       return intrin->intrinsic == nir_intrinsic_convert_alu_types &&
63          nir_intrinsic_dest_type(intrin) == nir_type_float16;
64    }
65    return false;
66 }
67 
68 static nir_ssa_def *
half_rounded(nir_builder * b,nir_ssa_def * value,nir_ssa_def * guard,nir_ssa_def * sticky,nir_ssa_def * sign,nir_rounding_mode mode)69 half_rounded(nir_builder *b, nir_ssa_def *value, nir_ssa_def *guard, nir_ssa_def *sticky,
70              nir_ssa_def *sign, nir_rounding_mode mode)
71 {
72    switch (mode) {
73    case nir_rounding_mode_rtne:
74       return nir_iadd(b, value, nir_iand(b, guard, nir_ior(b, sticky, value)));
75    case nir_rounding_mode_ru:
76       sign = nir_ushr(b, sign, nir_imm_int(b, 31));
77       return nir_iadd(b, value, nir_iand(b, nir_inot(b, sign),
78                                             nir_ior(b, guard, sticky)));
79    case nir_rounding_mode_rd:
80       sign = nir_ushr(b, sign, nir_imm_int(b, 31));
81       return nir_iadd(b, value, nir_iand(b, sign,
82                                             nir_ior(b, guard, sticky)));
83    default:
84       return value;
85    }
86 }
87 
88 static nir_ssa_def *
float_to_half_impl(nir_builder * b,nir_ssa_def * src,nir_rounding_mode mode)89 float_to_half_impl(nir_builder *b, nir_ssa_def *src, nir_rounding_mode mode)
90 {
91    nir_ssa_def *f32infinity = nir_imm_int(b, 255 << 23);
92    nir_ssa_def *f16max = nir_imm_int(b, (127 + 16) << 23);
93 
94    if (src->bit_size == 64)
95       src = nir_f2f32(b, src);
96    nir_ssa_def *sign = nir_iand(b, src, nir_imm_int(b, 0x80000000));
97    nir_ssa_def *one = nir_imm_int(b, 1);
98 
99    nir_ssa_def *abs = nir_iand(b, src, nir_imm_int(b, 0x7FFFFFFF));
100    /* NaN or INF. For rtne, overflow also becomes INF, so combine the comparisons */
101    nir_push_if(b, nir_ige(b, abs, mode == nir_rounding_mode_rtne ? f16max : f32infinity));
102    nir_ssa_def *inf_nanfp16 = nir_bcsel(b,
103                                     nir_ilt(b, f32infinity, abs),
104                                     nir_imm_int(b, 0x7E00),
105                                     nir_imm_int(b, 0x7C00));
106    nir_push_else(b, NULL);
107 
108    nir_ssa_def *overflowed_fp16 = NULL;
109    if (mode != nir_rounding_mode_rtne) {
110       /* Handle overflow */
111       nir_push_if(b, nir_ige(b, abs, f16max));
112       switch (mode) {
113       case nir_rounding_mode_rtz:
114          overflowed_fp16 = nir_imm_int(b, 0x7BFF);
115          break;
116       case nir_rounding_mode_ru:
117          /* Negative becomes max float, positive becomes inf */
118          overflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), nir_imm_int(b, 0x7BFF), nir_imm_int(b, 0x7C00));
119          break;
120       case nir_rounding_mode_rd:
121          /* Negative becomes inf, positive becomes max float */
122          overflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), nir_imm_int(b, 0x7C00), nir_imm_int(b, 0x7BFF));
123          break;
124       default: unreachable("Should've been handled already");
125       }
126       nir_push_else(b, NULL);
127    }
128 
129    nir_push_if(b, nir_ige(b, abs, nir_imm_int(b, 113 << 23)));
130 
131    /* FP16 will be normal */
132    nir_ssa_def *zero = nir_imm_int(b, 0);
133    nir_ssa_def *value = nir_ior(b,
134                                 nir_ishl(b,
135                                          nir_isub(b,
136                                                   nir_ushr(b, abs, nir_imm_int(b, 23)),
137                                                   nir_imm_int(b, 112)),
138                                          nir_imm_int(b, 10)),
139                                 nir_iand(b, nir_ushr(b, abs, nir_imm_int(b, 13)), nir_imm_int(b, 0x3FFF)));
140    nir_ssa_def *guard = nir_iand(b, nir_ushr(b, abs, nir_imm_int(b, 12)), one);
141    nir_ssa_def *sticky = nir_bcsel(b, nir_ine(b, nir_iand(b, abs, nir_imm_int(b, 0xFFF)), zero), one, zero);
142    nir_ssa_def *normal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);
143 
144    nir_push_else(b, NULL);
145    nir_push_if(b, nir_ige(b, abs, nir_imm_int(b, 102 << 23)));
146 
147    /* FP16 will be denormal */
148    nir_ssa_def *i = nir_isub(b, nir_imm_int(b, 125), nir_ushr(b, abs, nir_imm_int(b, 23)));
149    nir_ssa_def *masked = nir_ior(b, nir_iand(b, abs, nir_imm_int(b, 0x7FFFFF)), nir_imm_int(b, 0x800000));
150    value = nir_ushr(b, masked, nir_iadd(b, i, one));
151    guard = nir_iand(b, nir_ushr(b, masked, i), one);
152    sticky = nir_bcsel(b, nir_ine(b, nir_iand(b, masked, nir_isub(b, nir_ishl(b, one, i), one)), zero), one, zero);
153    nir_ssa_def *denormal_fp16 = half_rounded(b, value, guard, sticky, sign, mode);
154 
155    nir_push_else(b, NULL);
156 
157    /* Handle underflow. Nonzero values need to shift up or down for round-up or round-down */
158    nir_ssa_def *underflowed_fp16 = zero;
159    if (mode == nir_rounding_mode_ru ||
160        mode == nir_rounding_mode_rd) {
161       nir_push_if(b, nir_i2b1(b, abs));
162 
163       if (mode == nir_rounding_mode_ru)
164          underflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), zero, one);
165       else
166          underflowed_fp16 = nir_bcsel(b, nir_i2b1(b, sign), one, zero);
167 
168       nir_push_else(b, NULL);
169       nir_pop_if(b, NULL);
170       underflowed_fp16 = nir_if_phi(b, underflowed_fp16, zero);
171    }
172 
173    nir_pop_if(b, NULL);
174    nir_ssa_def *underflowed_or_denorm_fp16 = nir_if_phi(b, denormal_fp16, underflowed_fp16);
175 
176    nir_pop_if(b, NULL);
177    nir_ssa_def *finite_fp16 = nir_if_phi(b, normal_fp16, underflowed_or_denorm_fp16);
178 
179    nir_ssa_def *finite_or_overflowed_fp16 = finite_fp16;
180    if (mode != nir_rounding_mode_rtne) {
181       nir_pop_if(b, NULL);
182       finite_or_overflowed_fp16 = nir_if_phi(b, overflowed_fp16, finite_fp16);
183    }
184 
185    nir_pop_if(b, NULL);
186    nir_ssa_def *fp16 = nir_if_phi(b, inf_nanfp16, finite_or_overflowed_fp16);
187 
188    return nir_u2u16(b, nir_ior(b, fp16, nir_ushr(b, sign, nir_imm_int(b, 16))));
189 }
190 
191 static nir_ssa_def *
lower_fp16_cast_impl(nir_builder * b,nir_instr * instr,void * data)192 lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
193 {
194    nir_ssa_def *src, *dst;
195    uint8_t *swizzle = NULL;
196    nir_rounding_mode mode = nir_rounding_mode_rtne;
197 
198    if (instr->type == nir_instr_type_alu) {
199       nir_alu_instr *alu = nir_instr_as_alu(instr);
200       src = alu->src[0].src.ssa;
201       swizzle = alu->src[0].swizzle;
202       dst = &alu->dest.dest.ssa;
203       switch (alu->op) {
204       case nir_op_f2f16:
205       case nir_op_f2f16_rtne:
206          break;
207       case nir_op_f2f16_rtz:
208          mode = nir_rounding_mode_rtz;
209          break;
210       default: unreachable("Should've been filtered");
211       }
212    } else {
213       nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
214       assert(nir_intrinsic_src_type(intrin) == nir_type_float32);
215       src = intrin->src[0].ssa;
216       dst = &intrin->dest.ssa;
217       mode = nir_intrinsic_rounding_mode(intrin);
218    }
219 
220    nir_ssa_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };
221 
222    for (unsigned i = 0; i < dst->num_components; i++) {
223       nir_ssa_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i);
224       rets[i] = float_to_half_impl(b, comp, mode);
225    }
226 
227    return nir_vec(b, rets, dst->num_components);
228 }
229 
230 bool
nir_lower_fp16_casts(nir_shader * shader)231 nir_lower_fp16_casts(nir_shader *shader)
232 {
233    return nir_shader_lower_instructions(shader,
234                                         lower_fp16_casts_filter,
235                                         lower_fp16_cast_impl,
236                                         NULL);
237 }
238