/*
 * Copyright © 2015 Red Hat
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Rob Clark <robclark@freedesktop.org>
 */

#include "nir.h"
#include "nir_builder.h"
/* This pass has two paths.
 *
 * One (nir_lower_idiv_fast) lowers idiv/udiv/umod and is based on
 * NV50LegalizeSSA::handleDIV().
 *
 * Note that this path probably does not have enough precision for
 * compute shaders.  Perhaps we want a second, higher-precision (looping)
 * version of this?  Or perhaps we assume that if you can do compute
 * shaders, you can also branch out to a pre-optimized shader library
 * routine.
 *
 * The other path (nir_lower_idiv_precise) is based off of code used by
 * LLVM's AMDGPU target.  It should handle 32-bit idiv/irem/imod/udiv/umod
 * exactly.
 */
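
/* Example usage (a sketch, not taken from this file): a driver without
 * native integer division might run the pass from its NIR optimization
 * loop, choosing the path via nir_lower_idiv_options, e.g.:
 *
 *    const nir_lower_idiv_options idiv_options = {
 *       .imprecise_32bit_lowering = false,
 *       .allow_fp16 = true,
 *    };
 *    NIR_PASS(progress, shader, nir_lower_idiv, &idiv_options);
 *
 * Only the two option fields read below are shown; whether the fast or the
 * precise path is appropriate depends on the driver's precision
 * requirements (e.g. compute vs. graphics-only use).
 */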

static nir_ssa_def *
convert_instr(nir_builder *bld, nir_op op,
              nir_ssa_def *numer, nir_ssa_def *denom)
{
   nir_ssa_def *af, *bf, *a, *b, *q, *r, *rt;
   bool is_signed;

   is_signed = (op == nir_op_idiv ||
                op == nir_op_imod ||
                op == nir_op_irem);

   if (is_signed) {
      af = nir_i2f32(bld, numer);
      bf = nir_i2f32(bld, denom);
      af = nir_fabs(bld, af);
      bf = nir_fabs(bld, bf);
      a = nir_iabs(bld, numer);
      b = nir_iabs(bld, denom);
   } else {
      af = nir_u2f32(bld, numer);
      bf = nir_u2f32(bld, denom);
      a = numer;
      b = denom;
   }

   /* get first result: */
   bf = nir_frcp(bld, bf);
   /* yes, really: the isub operates on the raw float bits, nudging the
    * reciprocal estimate down by two ulps:
    */
   bf = nir_isub(bld, bf, nir_imm_int(bld, 2));
   q = nir_fmul(bld, af, bf);

   if (is_signed) {
      q = nir_f2i32(bld, q);
   } else {
      q = nir_f2u32(bld, q);
   }

   /* get error of first result: */
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   r = nir_u2f32(bld, r);
   r = nir_fmul(bld, r, bf);
   r = nir_f2u32(bld, r);

   /* add quotients: */
   q = nir_iadd(bld, q, r);

   /* correction: if modulus >= divisor, add 1 */
   r = nir_imul(bld, q, b);
   r = nir_isub(bld, a, r);
   rt = nir_uge(bld, r, b);

   if (op == nir_op_umod) {
      q = nir_bcsel(bld, rt, nir_isub(bld, r, b), r);
   } else {
      r = nir_b2i32(bld, rt);

      q = nir_iadd(bld, q, r);
      if (is_signed) {
         /* fix the sign: */
         r = nir_ixor(bld, numer, denom);
         r = nir_ilt(bld, r, nir_imm_int(bld, 0));
         b = nir_ineg(bld, q);
         q = nir_bcsel(bld, r, b, q);

         if (op == nir_op_imod || op == nir_op_irem) {
            q = nir_imul(bld, q, denom);
            q = nir_isub(bld, numer, q);
            if (op == nir_op_imod) {
               q = nir_bcsel(bld, nir_ieq_imm(bld, q, 0),
                             nir_imm_int(bld, 0),
                             nir_bcsel(bld, r, nir_iadd(bld, q, denom), q));
            }
         }
      }
   }

   return q;
}

/* ported from LLVM's AMDGPUTargetLowering::LowerUDIVREM */
static nir_ssa_def *
emit_udiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, bool modulo)
{
   nir_ssa_def *rcp = nir_frcp(bld, nir_u2f32(bld, denom));
   rcp = nir_f2u32(bld, nir_fmul_imm(bld, rcp, 4294966784.0));

   nir_ssa_def *neg_rcp_times_denom =
      nir_imul(bld, rcp, nir_ineg(bld, denom));
   rcp = nir_iadd(bld, rcp, nir_umul_high(bld, rcp, neg_rcp_times_denom));
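   /* rcp now holds roughly floor(2^32 / denom) as a 32-bit fixed-point
    * reciprocal: 4294966784.0 is the float with bit pattern 0x4F7FFFFE,
    * i.e. 2^32 - 512 (just below 2^32), and the umul_high line above acts
    * as one Newton-Raphson refinement of that fixed-point estimate.
    */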

   /* Get initial estimate for quotient/remainder, then refine the estimate
    * in two iterations after */
   nir_ssa_def *quotient = nir_umul_high(bld, numer, rcp);
   nir_ssa_def *num_s_remainder = nir_imul(bld, quotient, denom);
   nir_ssa_def *remainder = nir_isub(bld, numer, num_s_remainder);
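   /* The two fix-up steps below only ever add one to the quotient and
    * subtract denom from the remainder, so for the result to be exact (as
    * the file header promises for this path) the estimate above must never
    * be too high and can be at most two too low.
    */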

   /* First refinement step */
   nir_ssa_def *remainder_ge_den = nir_uge(bld, remainder, denom);
   if (!modulo) {
      quotient = nir_bcsel(bld, remainder_ge_den,
                           nir_iadd_imm(bld, quotient, 1), quotient);
   }
   remainder = nir_bcsel(bld, remainder_ge_den,
                         nir_isub(bld, remainder, denom), remainder);

   /* Second refinement step */
   remainder_ge_den = nir_uge(bld, remainder, denom);
   if (modulo) {
      return nir_bcsel(bld, remainder_ge_den, nir_isub(bld, remainder, denom),
                       remainder);
   } else {
      return nir_bcsel(bld, remainder_ge_den, nir_iadd_imm(bld, quotient, 1),
                       quotient);
   }
}

/* ported from LLVM's AMDGPUTargetLowering::LowerSDIVREM */
static nir_ssa_def *
emit_idiv(nir_builder *bld, nir_ssa_def *numer, nir_ssa_def *denom, nir_op op)
{
   nir_ssa_def *lh_sign = nir_ilt(bld, numer, nir_imm_int(bld, 0));
   nir_ssa_def *rh_sign = nir_ilt(bld, denom, nir_imm_int(bld, 0));

   nir_ssa_def *lhs = nir_iabs(bld, numer);
   nir_ssa_def *rhs = nir_iabs(bld, denom);

   if (op == nir_op_idiv) {
      nir_ssa_def *d_sign = nir_ixor(bld, lh_sign, rh_sign);
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, false);
      return nir_bcsel(bld, d_sign, nir_ineg(bld, res), res);
   } else {
      nir_ssa_def *res = emit_udiv(bld, lhs, rhs, true);
      res = nir_bcsel(bld, lh_sign, nir_ineg(bld, res), res);
      if (op == nir_op_imod) {
         /* res currently has the sign of the numerator (irem semantics);
          * imod's result instead follows the sign of the denominator, so
          * shift a nonzero remainder by denom when the operand signs differ.
          */
         nir_ssa_def *cond = nir_ieq_imm(bld, res, 0);
         cond = nir_ior(bld, nir_ieq(bld, lh_sign, rh_sign), cond);
         res = nir_bcsel(bld, cond, res, nir_iadd(bld, res, denom));
      }
      return res;
   }
}

static nir_ssa_def *
convert_instr_precise(nir_builder *bld, nir_op op,
                      nir_ssa_def *numer, nir_ssa_def *denom)
{
   if (op == nir_op_udiv || op == nir_op_umod)
      return emit_udiv(bld, numer, denom, op == nir_op_umod);
   else
      return emit_idiv(bld, numer, denom, op);
}

static nir_ssa_def *
convert_instr_small(nir_builder *b, nir_op op,
                    nir_ssa_def *numer, nir_ssa_def *denom,
                    const nir_lower_idiv_options *options)
{
   unsigned sz = numer->bit_size;
   nir_alu_type int_type = nir_op_infos[op].output_type | sz;

   /* Use a float type twice as wide as the integer type (fp16 for 8-bit,
    * fp32 for 16-bit) when fp16 is allowed; otherwise always use fp32.
    */
   nir_alu_type float_type = nir_type_float | (options->allow_fp16 ? sz * 2 : 32);

   nir_ssa_def *p = nir_type_convert(b, numer, int_type, float_type);
   nir_ssa_def *q = nir_type_convert(b, denom, int_type, float_type);

   /* Take 1/q but offset mantissa by 1 to correct for rounding. This is
    * needed for correct results and has been checked exhaustively for
    * all pairs of 16-bit integers */
   nir_ssa_def *rcp = nir_iadd_imm(b, nir_frcp(b, q), 1);
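   /* (The nir_iadd_imm above operates on the raw float bits, so it bumps
    * the reciprocal up by one ulp rather than adding 1.0.)
    */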

   /* Divide by multiplying by adjusted reciprocal */
   nir_ssa_def *res = nir_fmul(b, p, rcp);

   /* Convert back to integer space with rounding inferred by type */
   res = nir_type_convert(b, res, float_type, int_type);

   /* Get remainder given the quotient */
   if (op == nir_op_umod || op == nir_op_imod || op == nir_op_irem)
      res = nir_isub(b, numer, nir_imul(b, denom, res));
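   /* At this point res is the truncated quotient (udiv/idiv) or the
    * corresponding remainder, which for the signed ops carries the
    * numerator's sign (irem semantics); only imod still needs the sign
    * adjustment below.
    */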

   /* Adjust for sign, see constant folding definition: imod is defined so
    * that a nonzero result takes the sign of the denominator, so add denom
    * when the operand signs differ and the remainder is nonzero.
    */
   if (op == nir_op_imod) {
      nir_ssa_def *zero = nir_imm_zero(b, 1, sz);
      nir_ssa_def *diff_sign =
         nir_ine(b, nir_ige(b, numer, zero), nir_ige(b, denom, zero));

      nir_ssa_def *adjust = nir_iand(b, diff_sign, nir_ine(b, res, zero));
      res = nir_iadd(b, res, nir_bcsel(b, adjust, denom, zero));
   }

   return res;
}

static nir_ssa_def *
lower_idiv(nir_builder *b, nir_instr *instr, void *_data)
{
   const nir_lower_idiv_options *options = _data;
   nir_alu_instr *alu = nir_instr_as_alu(instr);

   nir_ssa_def *numer = nir_ssa_for_alu_src(b, alu, 0);
   nir_ssa_def *denom = nir_ssa_for_alu_src(b, alu, 1);

   /* Mark everything we emit as exact so later float optimizations cannot
    * perturb the carefully tuned lowering sequences.
    */
   b->exact = true;

   if (numer->bit_size < 32)
      return convert_instr_small(b, alu->op, numer, denom, options);
   else if (options->imprecise_32bit_lowering)
      return convert_instr(b, alu->op, numer, denom);
   else
      return convert_instr_precise(b, alu->op, numer, denom);
}

static bool
inst_is_idiv(const nir_instr *instr, UNUSED const void *_state)
{
   if (instr->type != nir_instr_type_alu)
      return false;

   nir_alu_instr *alu = nir_instr_as_alu(instr);

   if (alu->dest.dest.ssa.bit_size > 32)
      return false;

   switch (alu->op) {
   case nir_op_idiv:
   case nir_op_udiv:
   case nir_op_imod:
   case nir_op_umod:
   case nir_op_irem:
      return true;
   default:
      return false;
   }
}

bool
nir_lower_idiv(nir_shader *shader, const nir_lower_idiv_options *options)
{
   return nir_shader_lower_instructions(shader,
                                        inst_is_idiv,
                                        lower_idiv,
                                        (void *)options);
}