1 /*
2 * Copyright © 2018 Intel Corporation
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 */
23 #include <math.h>
24 #include <float.h>
25 #include "nir.h"
26 #include "nir_range_analysis.h"
27 #include "util/hash_table.h"
28
29 /**
30 * Analyzes a sequence of operations to determine some aspects of the range of
31 * the result.
32 */
33
34 static bool
is_not_negative(enum ssa_ranges r)35 is_not_negative(enum ssa_ranges r)
36 {
37 return r == gt_zero || r == ge_zero || r == eq_zero;
38 }
39
40 static void *
pack_data(const struct ssa_result_range r)41 pack_data(const struct ssa_result_range r)
42 {
43 return (void *)(uintptr_t)(r.range | r.is_integral << 8);
44 }
45
46 static struct ssa_result_range
unpack_data(const void * p)47 unpack_data(const void *p)
48 {
49 const uintptr_t v = (uintptr_t) p;
50
51 return (struct ssa_result_range){v & 0xff, (v & 0x0ff00) != 0};
52 }
53
54 static void *
pack_key(const struct nir_alu_instr * instr,nir_alu_type type)55 pack_key(const struct nir_alu_instr *instr, nir_alu_type type)
56 {
57 uintptr_t type_encoding;
58 uintptr_t ptr = (uintptr_t) instr;
59
60 /* The low 2 bits have to be zero or this whole scheme falls apart. */
61 assert((ptr & 0x3) == 0);
62
63 /* NIR is typeless in the sense that sequences of bits have whatever
64 * meaning is attached to them by the instruction that consumes them.
65 * However, the number of bits must match between producer and consumer.
66 * As a result, the number of bits does not need to be encoded here.
67 */
68 switch (nir_alu_type_get_base_type(type)) {
69 case nir_type_int: type_encoding = 0; break;
70 case nir_type_uint: type_encoding = 1; break;
71 case nir_type_bool: type_encoding = 2; break;
72 case nir_type_float: type_encoding = 3; break;
73 default: unreachable("Invalid base type.");
74 }
75
76 return (void *)(ptr | type_encoding);
77 }
78
79 static nir_alu_type
nir_alu_src_type(const nir_alu_instr * instr,unsigned src)80 nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
81 {
82 return nir_alu_type_get_base_type(nir_op_infos[instr->op].input_types[src]) |
83 nir_src_bit_size(instr->src[src].src);
84 }
85
86 static struct ssa_result_range
analyze_constant(const struct nir_alu_instr * instr,unsigned src,nir_alu_type use_type)87 analyze_constant(const struct nir_alu_instr *instr, unsigned src,
88 nir_alu_type use_type)
89 {
90 uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3,
91 4, 5, 6, 7,
92 8, 9, 10, 11,
93 12, 13, 14, 15 };
94
95 /* If the source is an explicitly sized source, then we need to reset
96 * both the number of components and the swizzle.
97 */
98 const unsigned num_components = nir_ssa_alu_instr_src_components(instr, src);
99
100 for (unsigned i = 0; i < num_components; ++i)
101 swizzle[i] = instr->src[src].swizzle[i];
102
103 const nir_load_const_instr *const load =
104 nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);
105
106 struct ssa_result_range r = { unknown, false };
107
108 switch (nir_alu_type_get_base_type(use_type)) {
109 case nir_type_float: {
110 double min_value = DBL_MAX;
111 double max_value = -DBL_MAX;
112 bool any_zero = false;
113 bool all_zero = true;
114
115 r.is_integral = true;
116
117 for (unsigned i = 0; i < num_components; ++i) {
118 const double v = nir_const_value_as_float(load->value[swizzle[i]],
119 load->def.bit_size);
120
121 if (floor(v) != v)
122 r.is_integral = false;
123
124 any_zero = any_zero || (v == 0.0);
125 all_zero = all_zero && (v == 0.0);
126 min_value = MIN2(min_value, v);
127 max_value = MAX2(max_value, v);
128 }
129
130 assert(any_zero >= all_zero);
131 assert(isnan(max_value) || max_value >= min_value);
132
133 if (all_zero)
134 r.range = eq_zero;
135 else if (min_value > 0.0)
136 r.range = gt_zero;
137 else if (min_value == 0.0)
138 r.range = ge_zero;
139 else if (max_value < 0.0)
140 r.range = lt_zero;
141 else if (max_value == 0.0)
142 r.range = le_zero;
143 else if (!any_zero)
144 r.range = ne_zero;
145 else
146 r.range = unknown;
147
148 return r;
149 }
150
151 case nir_type_int:
152 case nir_type_bool: {
153 int64_t min_value = INT_MAX;
154 int64_t max_value = INT_MIN;
155 bool any_zero = false;
156 bool all_zero = true;
157
158 for (unsigned i = 0; i < num_components; ++i) {
159 const int64_t v = nir_const_value_as_int(load->value[swizzle[i]],
160 load->def.bit_size);
161
162 any_zero = any_zero || (v == 0);
163 all_zero = all_zero && (v == 0);
164 min_value = MIN2(min_value, v);
165 max_value = MAX2(max_value, v);
166 }
167
168 assert(any_zero >= all_zero);
169 assert(max_value >= min_value);
170
171 if (all_zero)
172 r.range = eq_zero;
173 else if (min_value > 0)
174 r.range = gt_zero;
175 else if (min_value == 0)
176 r.range = ge_zero;
177 else if (max_value < 0)
178 r.range = lt_zero;
179 else if (max_value == 0)
180 r.range = le_zero;
181 else if (!any_zero)
182 r.range = ne_zero;
183 else
184 r.range = unknown;
185
186 return r;
187 }
188
189 case nir_type_uint: {
190 bool any_zero = false;
191 bool all_zero = true;
192
193 for (unsigned i = 0; i < num_components; ++i) {
194 const uint64_t v = nir_const_value_as_uint(load->value[swizzle[i]],
195 load->def.bit_size);
196
197 any_zero = any_zero || (v == 0);
198 all_zero = all_zero && (v == 0);
199 }
200
201 assert(any_zero >= all_zero);
202
203 if (all_zero)
204 r.range = eq_zero;
205 else if (any_zero)
206 r.range = ge_zero;
207 else
208 r.range = gt_zero;
209
210 return r;
211 }
212
213 default:
214 unreachable("Invalid alu source type");
215 }
216 }
217
/**
 * Short-hand name for use in the tables in analyze_expression.  If this name
 * becomes a problem on some compiler, we can change it to _.
 *
 * The alias is seven characters wide, the same as the other range names
 * (lt_zero, ge_zero, ...), so 'unknown' entries stand out visually while the
 * table columns stay aligned.
 */
#define _______ unknown
223
224
/* Loop-unrolling hints, spelled the way each compiler expects.
 *
 * pragma_unroll_5 and pragma_unroll_7 are placed directly in front of a loop.
 * Where no per-loop spelling is available they expand to nothing.
 */
#if defined(__clang__)
/* clang wants _Pragma("unroll X") */
#define pragma_unroll_5 _Pragma("unroll 5")
#define pragma_unroll_7 _Pragma("unroll 7")
#elif defined(__GNUC__) && __GNUC__ >= 8
/* gcc wants _Pragma("GCC unroll X") */
#define pragma_unroll_5 _Pragma("GCC unroll 5")
#define pragma_unroll_7 _Pragma("GCC unroll 7")
#elif defined(__GNUC__)
/* Older gcc: no per-loop hint is used; request loop unrolling for the file. */
#pragma GCC optimize ("unroll-loops")
#define pragma_unroll_5
#define pragma_unroll_7
#else
/* MSVC doesn't have C99's _Pragma() */
#define pragma_unroll_5
#define pragma_unroll_7
#endif
244
245
246 #ifndef NDEBUG
/* Check that swapping the two indices of a table never changes the entry,
 * i.e. t[a][b] == t[b][a] for every pair.  The check is performed only the
 * first time the enclosing statement is executed.
 */
#define ASSERT_TABLE_IS_COMMUTATIVE(t)                                  \
   do {                                                                 \
      static bool already_checked = false;                              \
      if (!already_checked) {                                           \
         already_checked = true;                                        \
         pragma_unroll_7                                                \
         for (unsigned row = 0; row < ARRAY_SIZE(t); row++) {           \
            pragma_unroll_7                                             \
            for (unsigned col = 0; col < ARRAY_SIZE(t[0]); col++)       \
               assert(t[row][col] == t[col][row]);                      \
         }                                                              \
      }                                                                 \
   } while (false)
260
/* Check that every entry on the diagonal of a table equals its own index,
 * i.e. t[x][x] == x.  The check is performed only the first time the
 * enclosing statement is executed.
 */
#define ASSERT_TABLE_IS_DIAGONAL(t)                                     \
   do {                                                                 \
      static bool already_checked = false;                              \
      if (!already_checked) {                                           \
         already_checked = true;                                        \
         pragma_unroll_7                                                \
         for (unsigned idx = 0; idx < ARRAY_SIZE(t); idx++)             \
            assert(t[idx][idx] == idx);                                 \
      }                                                                 \
   } while (false)
271
272 static enum ssa_ranges
union_ranges(enum ssa_ranges a,enum ssa_ranges b)273 union_ranges(enum ssa_ranges a, enum ssa_ranges b)
274 {
275 static const enum ssa_ranges union_table[last_range + 1][last_range + 1] = {
276 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
277 /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
278 /* lt_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, le_zero },
279 /* le_zero */ { _______, le_zero, le_zero, _______, _______, _______, le_zero },
280 /* gt_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, ge_zero },
281 /* ge_zero */ { _______, _______, _______, ge_zero, ge_zero, _______, ge_zero },
282 /* ne_zero */ { _______, ne_zero, _______, ne_zero, _______, ne_zero, _______ },
283 /* eq_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
284 };
285
286 ASSERT_TABLE_IS_COMMUTATIVE(union_table);
287 ASSERT_TABLE_IS_DIAGONAL(union_table);
288
289 return union_table[a][b];
290 }
291
/* Verify that the 'unknown' entry in each row (or column) of the table is the
 * union of all the other values in the row (or column).
 *
 * Both loops run up to and including last_range so that the eq_zero row and
 * column are checked and eq_zero entries take part in the union.  That gives
 * last_range + 1 outer iterations and last_range - 1 inner iterations, which
 * is what the unroll hints (7 and 5) expect.
 *
 * The check is performed only the first time the enclosing statement is
 * executed.
 */
#define ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(t)              \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i <= last_range; i++) {                   \
            enum ssa_ranges col_range = t[i][unknown + 1];              \
            enum ssa_ranges row_range = t[unknown + 1][i];              \
                                                                        \
            pragma_unroll_5                                             \
            for (unsigned j = unknown + 2; j <= last_range; j++) {      \
               col_range = union_ranges(col_range, t[i][j]);            \
               row_range = union_ranges(row_range, t[j][i]);            \
            }                                                           \
                                                                        \
            assert(col_range == t[i][unknown]);                         \
            assert(row_range == t[unknown][i]);                         \
         }                                                              \
      }                                                                 \
   } while (false)
316
/* For most operations, joining the result range for a strict inequality with
 * the result range for equality should give the result range for the
 * matching non-strict inequality, e.g.
 *
 *    union_ranges(range(op(lt_zero)), range(op(eq_zero))) == range(op(le_zero))
 *
 * This is the single-source form: t is indexed by one range.
 *
 * Does not apply to selection-like opcodes (bcsel, fmin, fmax, etc.).
 */
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(t) \
   do {                                                                  \
      assert(t[le_zero] == union_ranges(t[lt_zero], t[eq_zero]));        \
      assert(t[ge_zero] == union_ranges(t[gt_zero], t[eq_zero]));        \
   } while (false)
328
/* Two-source form of the "strict inequality joined with equality gives the
 * non-strict inequality" check.  With one index held fixed, the property is
 * verified along the other index, for both the row and the column direction.
 *
 * The loop runs up to and including last_range so that the eq_zero row and
 * column are verified as well; that is the 7 iterations the unroll hint
 * expects.
 *
 * The check is performed only the first time the enclosing statement is
 * executed.
 */
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(t) \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i <= last_range; i++) {                   \
            assert(union_ranges(t[i][lt_zero], t[i][eq_zero]) == t[i][le_zero]); \
            assert(union_ranges(t[i][gt_zero], t[i][eq_zero]) == t[i][ge_zero]); \
            assert(union_ranges(t[lt_zero][i], t[eq_zero][i]) == t[le_zero][i]); \
            assert(union_ranges(t[gt_zero][i], t[eq_zero][i]) == t[ge_zero][i]); \
         }                                                              \
      }                                                                 \
   } while (false)
343
/* Three unordered pairs of ranges each cover "everything": (lt_zero,
 * ge_zero), (le_zero, gt_zero), and (eq_zero, ne_zero).  Joining the table
 * entries of either member of such a pair should therefore give the same
 * value as the table's 'unknown' entry.  union_ranges is commutative, so
 * only one ordering of each pair is checked.
 *
 * Does not apply to selection-like opcodes (bcsel, fmin, fmax, etc.).
 *
 * Where this check can be used there is no need to also use
 * ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_*_SOURCE.  For any range X,
 * union_ranges(X, X) == X.  The disjoint pairs cover all of the non-unknown
 * possibilities, so the union of all the unions of disjoint pairs is
 * equivalent to the union of "others."
 */
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(t)            \
   do {                                                                 \
      assert(t[unknown] == union_ranges(t[lt_zero], t[ge_zero]));       \
      assert(t[unknown] == union_ranges(t[le_zero], t[gt_zero]));       \
      assert(t[unknown] == union_ranges(t[eq_zero], t[ne_zero]));       \
   } while (false)
363
/* Two-source form of the disjoint-pairs check.  With one index held fixed,
 * joining the entries for (lt_zero, ge_zero), (le_zero, gt_zero) and
 * (eq_zero, ne_zero) along the other index must each give the 'unknown'
 * entry.  Both the row and the column direction are verified.
 *
 * The loop runs up to and including last_range so that the eq_zero row and
 * column are verified as well; that is the 7 iterations the unroll hint
 * expects.
 *
 * The check is performed only the first time the enclosing statement is
 * executed.
 */
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)            \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i <= last_range; i++) {                   \
            assert(union_ranges(t[i][lt_zero], t[i][ge_zero]) ==        \
                   t[i][unknown]);                                      \
            assert(union_ranges(t[i][le_zero], t[i][gt_zero]) ==        \
                   t[i][unknown]);                                      \
            assert(union_ranges(t[i][eq_zero], t[i][ne_zero]) ==        \
                   t[i][unknown]);                                      \
                                                                        \
            assert(union_ranges(t[lt_zero][i], t[ge_zero][i]) ==        \
                   t[unknown][i]);                                      \
            assert(union_ranges(t[le_zero][i], t[gt_zero][i]) ==        \
                   t[unknown][i]);                                      \
            assert(union_ranges(t[eq_zero][i], t[ne_zero][i]) ==        \
                   t[unknown][i]);                                      \
         }                                                              \
      }                                                                 \
   } while (false)
387
388 #else
/* NDEBUG is defined, so assertions are disabled: every table sanity check
 * expands to nothing.
 */
#define ASSERT_TABLE_IS_COMMUTATIVE(t)
#define ASSERT_TABLE_IS_DIAGONAL(t)
#define ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(t)
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(t)
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(t)
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(t)
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)
396 #endif
397
398 /**
399 * Analyze an expression to determine the range of its result
400 *
401 * The end result of this analysis is a token that communicates something
402 * about the range of values. There's an implicit grammar that produces
403 * tokens from sequences of literal values, other tokens, and operations.
404 * This function implements this grammar as a recursive-descent parser. Some
405 * (but not all) of the grammar is listed in-line in the function.
406 */
407 static struct ssa_result_range
analyze_expression(const nir_alu_instr * instr,unsigned src,struct hash_table * ht,nir_alu_type use_type)408 analyze_expression(const nir_alu_instr *instr, unsigned src,
409 struct hash_table *ht, nir_alu_type use_type)
410 {
411 /* Ensure that the _Pragma("GCC unroll 7") above are correct. */
412 STATIC_ASSERT(last_range + 1 == 7);
413
414 if (!instr->src[src].src.is_ssa)
415 return (struct ssa_result_range){unknown, false};
416
417 if (nir_src_is_const(instr->src[src].src))
418 return analyze_constant(instr, src, use_type);
419
420 if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
421 return (struct ssa_result_range){unknown, false};
422
423 const struct nir_alu_instr *const alu =
424 nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);
425
426 /* Bail if the type of the instruction generating the value does not match
427 * the type the value will be interpreted as. int/uint/bool can be
428 * reinterpreted trivially. The most important cases are between float and
429 * non-float.
430 */
431 if (alu->op != nir_op_mov && alu->op != nir_op_bcsel) {
432 const nir_alu_type use_base_type =
433 nir_alu_type_get_base_type(use_type);
434 const nir_alu_type src_base_type =
435 nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type);
436
437 if (use_base_type != src_base_type &&
438 (use_base_type == nir_type_float ||
439 src_base_type == nir_type_float)) {
440 return (struct ssa_result_range){unknown, false};
441 }
442 }
443
444 struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type));
445 if (he != NULL)
446 return unpack_data(he->data);
447
448 struct ssa_result_range r = {unknown, false};
449
450 /* ge_zero: ge_zero + ge_zero
451 *
452 * gt_zero: gt_zero + eq_zero
453 * | gt_zero + ge_zero
454 * | eq_zero + gt_zero # Addition is commutative
455 * | ge_zero + gt_zero # Addition is commutative
456 * | gt_zero + gt_zero
457 * ;
458 *
459 * le_zero: le_zero + le_zero
460 *
461 * lt_zero: lt_zero + eq_zero
462 * | lt_zero + le_zero
463 * | eq_zero + lt_zero # Addition is commutative
464 * | le_zero + lt_zero # Addition is commutative
465 * | lt_zero + lt_zero
466 * ;
467 *
468 * ne_zero: eq_zero + ne_zero
469 * | ne_zero + eq_zero # Addition is commutative
470 * ;
471 *
472 * eq_zero: eq_zero + eq_zero
473 * ;
474 *
475 * All other cases are 'unknown'. The seeming odd entry is (ne_zero,
476 * ne_zero), but that could be (-5, +5) which is not ne_zero.
477 */
478 static const enum ssa_ranges fadd_table[last_range + 1][last_range + 1] = {
479 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
480 /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
481 /* lt_zero */ { _______, lt_zero, lt_zero, _______, _______, _______, lt_zero },
482 /* le_zero */ { _______, lt_zero, le_zero, _______, _______, _______, le_zero },
483 /* gt_zero */ { _______, _______, _______, gt_zero, gt_zero, _______, gt_zero },
484 /* ge_zero */ { _______, _______, _______, gt_zero, ge_zero, _______, ge_zero },
485 /* ne_zero */ { _______, _______, _______, _______, _______, _______, ne_zero },
486 /* eq_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
487 };
488
489 ASSERT_TABLE_IS_COMMUTATIVE(fadd_table);
490 ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(fadd_table);
491 ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(fadd_table);
492
   /* Due to flush-to-zero semantics of floating-point numbers with very
    * small magnitudes, we can never really be sure a result will be
    * non-zero.
496 *
497 * ge_zero: ge_zero * ge_zero
498 * | ge_zero * gt_zero
499 * | ge_zero * eq_zero
500 * | le_zero * lt_zero
501 * | lt_zero * le_zero # Multiplication is commutative
502 * | le_zero * le_zero
503 * | gt_zero * ge_zero # Multiplication is commutative
504 * | eq_zero * ge_zero # Multiplication is commutative
505 * | a * a # Left source == right source
506 * | gt_zero * gt_zero
507 * | lt_zero * lt_zero
508 * ;
509 *
510 * le_zero: ge_zero * le_zero
511 * | ge_zero * lt_zero
512 * | lt_zero * ge_zero # Multiplication is commutative
513 * | le_zero * ge_zero # Multiplication is commutative
514 * | le_zero * gt_zero
515 * | lt_zero * gt_zero
516 * | gt_zero * lt_zero # Multiplication is commutative
517 * ;
518 *
519 * eq_zero: eq_zero * <any>
520 * <any> * eq_zero # Multiplication is commutative
521 *
522 * All other cases are 'unknown'.
523 */
524 static const enum ssa_ranges fmul_table[last_range + 1][last_range + 1] = {
525 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
526 /* unknown */ { _______, _______, _______, _______, _______, _______, eq_zero },
527 /* lt_zero */ { _______, ge_zero, ge_zero, le_zero, le_zero, _______, eq_zero },
528 /* le_zero */ { _______, ge_zero, ge_zero, le_zero, le_zero, _______, eq_zero },
529 /* gt_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
530 /* ge_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
531 /* ne_zero */ { _______, _______, _______, _______, _______, _______, eq_zero },
532 /* eq_zero */ { eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero }
533 };
534
535 ASSERT_TABLE_IS_COMMUTATIVE(fmul_table);
536 ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(fmul_table);
537 ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(fmul_table);
538
539 static const enum ssa_ranges fneg_table[last_range + 1] = {
540 /* unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
541 _______, gt_zero, ge_zero, lt_zero, le_zero, ne_zero, eq_zero
542 };
543
544 ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(fneg_table);
545 ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(fneg_table);
546
547
548 switch (alu->op) {
549 case nir_op_b2f32:
550 case nir_op_b2i32:
551 r = (struct ssa_result_range){ge_zero, alu->op == nir_op_b2f32};
552 break;
553
554 case nir_op_bcsel: {
555 const struct ssa_result_range left =
556 analyze_expression(alu, 1, ht, use_type);
557 const struct ssa_result_range right =
558 analyze_expression(alu, 2, ht, use_type);
559
560 r.is_integral = left.is_integral && right.is_integral;
561
562 /* le_zero: bcsel(<any>, le_zero, lt_zero)
563 * | bcsel(<any>, eq_zero, lt_zero)
564 * | bcsel(<any>, le_zero, eq_zero)
565 * | bcsel(<any>, lt_zero, le_zero)
566 * | bcsel(<any>, lt_zero, eq_zero)
567 * | bcsel(<any>, eq_zero, le_zero)
568 * | bcsel(<any>, le_zero, le_zero)
569 * ;
570 *
571 * lt_zero: bcsel(<any>, lt_zero, lt_zero)
572 * ;
573 *
574 * ge_zero: bcsel(<any>, ge_zero, ge_zero)
575 * | bcsel(<any>, ge_zero, gt_zero)
576 * | bcsel(<any>, ge_zero, eq_zero)
577 * | bcsel(<any>, gt_zero, ge_zero)
578 * | bcsel(<any>, eq_zero, ge_zero)
579 * ;
580 *
581 * gt_zero: bcsel(<any>, gt_zero, gt_zero)
582 * ;
583 *
584 * ne_zero: bcsel(<any>, ne_zero, gt_zero)
585 * | bcsel(<any>, ne_zero, lt_zero)
586 * | bcsel(<any>, gt_zero, lt_zero)
587 * | bcsel(<any>, gt_zero, ne_zero)
588 * | bcsel(<any>, lt_zero, ne_zero)
589 * | bcsel(<any>, lt_zero, gt_zero)
590 * | bcsel(<any>, ne_zero, ne_zero)
591 * ;
592 *
593 * eq_zero: bcsel(<any>, eq_zero, eq_zero)
594 * ;
595 *
596 * All other cases are 'unknown'.
597 *
       * The ranges could be tightened if the range of the first source is
       * known. However, opt_algebraic will (eventually) eliminate the bcsel
       * if the condition is known.
601 */
602 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
603 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
604 /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
605 /* lt_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, le_zero },
606 /* le_zero */ { _______, le_zero, le_zero, _______, _______, _______, le_zero },
607 /* gt_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, ge_zero },
608 /* ge_zero */ { _______, _______, _______, ge_zero, ge_zero, _______, ge_zero },
609 /* ne_zero */ { _______, ne_zero, _______, ne_zero, _______, ne_zero, _______ },
610 /* eq_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
611 };
612
613 ASSERT_TABLE_IS_COMMUTATIVE(table);
614 ASSERT_TABLE_IS_DIAGONAL(table);
615 ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);
616
617 r.range = table[left.range][right.range];
618 break;
619 }
620
621 case nir_op_i2f32:
622 case nir_op_u2f32:
623 r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
624
625 r.is_integral = true;
626
627 if (r.range == unknown && alu->op == nir_op_u2f32)
628 r.range = ge_zero;
629
630 break;
631
632 case nir_op_fabs:
633 r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
634
635 switch (r.range) {
636 case unknown:
637 case le_zero:
638 case ge_zero:
639 r.range = ge_zero;
640 break;
641
642 case lt_zero:
643 case gt_zero:
644 case ne_zero:
645 r.range = gt_zero;
646 break;
647
648 case eq_zero:
649 break;
650 }
651
652 break;
653
654 case nir_op_fadd: {
655 const struct ssa_result_range left =
656 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
657 const struct ssa_result_range right =
658 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
659
660 r.is_integral = left.is_integral && right.is_integral;
661 r.range = fadd_table[left.range][right.range];
662 break;
663 }
664
665 case nir_op_fexp2: {
      /* If the parameter might be less than zero, the mathematical result
       * will be on (0, 1). For sufficiently large magnitude negative
       * parameters, the result will flush to zero.
       */
670 static const enum ssa_ranges table[last_range + 1] = {
671 /* unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
672 ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero
673 };
674
675 r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
676
677 ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table);
678 ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table);
679
680 r.is_integral = r.is_integral && is_not_negative(r.range);
681 r.range = table[r.range];
682 break;
683 }
684
685 case nir_op_fmax: {
686 const struct ssa_result_range left =
687 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
688 const struct ssa_result_range right =
689 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
690
691 r.is_integral = left.is_integral && right.is_integral;
692
693 /* gt_zero: fmax(gt_zero, *)
694 * | fmax(*, gt_zero) # Treat fmax as commutative
695 * ;
696 *
697 * ge_zero: fmax(ge_zero, ne_zero)
698 * | fmax(ge_zero, lt_zero)
699 * | fmax(ge_zero, le_zero)
700 * | fmax(ge_zero, eq_zero)
701 * | fmax(ne_zero, ge_zero) # Treat fmax as commutative
702 * | fmax(lt_zero, ge_zero) # Treat fmax as commutative
703 * | fmax(le_zero, ge_zero) # Treat fmax as commutative
704 * | fmax(eq_zero, ge_zero) # Treat fmax as commutative
705 * | fmax(ge_zero, ge_zero)
706 * ;
707 *
708 * le_zero: fmax(le_zero, lt_zero)
709 * | fmax(lt_zero, le_zero) # Treat fmax as commutative
710 * | fmax(le_zero, le_zero)
711 * ;
712 *
713 * lt_zero: fmax(lt_zero, lt_zero)
714 * ;
715 *
716 * ne_zero: fmax(ne_zero, lt_zero)
717 * | fmax(lt_zero, ne_zero) # Treat fmax as commutative
718 * | fmax(ne_zero, ne_zero)
719 * ;
720 *
721 * eq_zero: fmax(eq_zero, le_zero)
722 * | fmax(eq_zero, lt_zero)
723 * | fmax(le_zero, eq_zero) # Treat fmax as commutative
724 * | fmax(lt_zero, eq_zero) # Treat fmax as commutative
725 * | fmax(eq_zero, eq_zero)
726 * ;
727 *
728 * All other cases are 'unknown'.
729 */
730 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
731 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
732 /* unknown */ { _______, _______, _______, gt_zero, ge_zero, _______, _______ },
733 /* lt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
734 /* le_zero */ { _______, le_zero, le_zero, gt_zero, ge_zero, _______, eq_zero },
735 /* gt_zero */ { gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero },
736 /* ge_zero */ { ge_zero, ge_zero, ge_zero, gt_zero, ge_zero, ge_zero, ge_zero },
737 /* ne_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, _______ },
738 /* eq_zero */ { _______, eq_zero, eq_zero, gt_zero, ge_zero, _______, eq_zero }
739 };
740
741 /* Treat fmax as commutative. */
742 ASSERT_TABLE_IS_COMMUTATIVE(table);
743 ASSERT_TABLE_IS_DIAGONAL(table);
744 ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);
745
746 r.range = table[left.range][right.range];
747 break;
748 }
749
750 case nir_op_fmin: {
751 const struct ssa_result_range left =
752 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
753 const struct ssa_result_range right =
754 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
755
756 r.is_integral = left.is_integral && right.is_integral;
757
758 /* lt_zero: fmin(lt_zero, *)
759 * | fmin(*, lt_zero) # Treat fmin as commutative
760 * ;
761 *
762 * le_zero: fmin(le_zero, ne_zero)
763 * | fmin(le_zero, gt_zero)
764 * | fmin(le_zero, ge_zero)
765 * | fmin(le_zero, eq_zero)
766 * | fmin(ne_zero, le_zero) # Treat fmin as commutative
767 * | fmin(gt_zero, le_zero) # Treat fmin as commutative
768 * | fmin(ge_zero, le_zero) # Treat fmin as commutative
769 * | fmin(eq_zero, le_zero) # Treat fmin as commutative
770 * | fmin(le_zero, le_zero)
771 * ;
772 *
773 * ge_zero: fmin(ge_zero, gt_zero)
774 * | fmin(gt_zero, ge_zero) # Treat fmin as commutative
775 * | fmin(ge_zero, ge_zero)
776 * ;
777 *
778 * gt_zero: fmin(gt_zero, gt_zero)
779 * ;
780 *
781 * ne_zero: fmin(ne_zero, gt_zero)
782 * | fmin(gt_zero, ne_zero) # Treat fmin as commutative
783 * | fmin(ne_zero, ne_zero)
784 * ;
785 *
786 * eq_zero: fmin(eq_zero, ge_zero)
787 * | fmin(eq_zero, gt_zero)
788 * | fmin(ge_zero, eq_zero) # Treat fmin as commutative
789 * | fmin(gt_zero, eq_zero) # Treat fmin as commutative
790 * | fmin(eq_zero, eq_zero)
791 * ;
792 *
793 * All other cases are 'unknown'.
794 */
795 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
796 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
797 /* unknown */ { _______, lt_zero, le_zero, _______, _______, _______, _______ },
798 /* lt_zero */ { lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero },
799 /* le_zero */ { le_zero, lt_zero, le_zero, le_zero, le_zero, le_zero, le_zero },
800 /* gt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
801 /* ge_zero */ { _______, lt_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
802 /* ne_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, _______ },
803 /* eq_zero */ { _______, lt_zero, le_zero, eq_zero, eq_zero, _______, eq_zero }
804 };
805
806 /* Treat fmin as commutative. */
807 ASSERT_TABLE_IS_COMMUTATIVE(table);
808 ASSERT_TABLE_IS_DIAGONAL(table);
809 ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);
810
811 r.range = table[left.range][right.range];
812 break;
813 }
814
815 case nir_op_fmul: {
816 const struct ssa_result_range left =
817 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
818 const struct ssa_result_range right =
819 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
820
821 r.is_integral = left.is_integral && right.is_integral;
822
823 /* x * x => ge_zero */
824 if (left.range != eq_zero && nir_alu_srcs_equal(alu, alu, 0, 1)) {
825 /* Even if x > 0, the result of x*x can be zero when x is, for
826 * example, a subnormal number.
827 */
828 r.range = ge_zero;
829 } else if (left.range != eq_zero && nir_alu_srcs_negative_equal(alu, alu, 0, 1)) {
830 /* -x * x => le_zero. */
831 r.range = le_zero;
832 } else
833 r.range = fmul_table[left.range][right.range];
834
835 break;
836 }
837
838 case nir_op_frcp:
839 r = (struct ssa_result_range){
840 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
841 false
842 };
843 break;
844
845 case nir_op_mov:
846 r = analyze_expression(alu, 0, ht, use_type);
847 break;
848
849 case nir_op_fneg:
850 r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
851
852 r.range = fneg_table[r.range];
853 break;
854
855 case nir_op_fsat:
856 r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
857
858 switch (r.range) {
859 case le_zero:
860 case lt_zero:
861 r.range = eq_zero;
862 r.is_integral = true;
863 break;
864
865 case eq_zero:
866 assert(r.is_integral);
867 case gt_zero:
868 case ge_zero:
869 /* The fsat doesn't add any information in these cases. */
870 break;
871
872 case ne_zero:
873 case unknown:
874 /* Since the result must be in [0, 1], the value must be >= 0. */
875 r.range = ge_zero;
876 break;
877 }
878 break;
879
880 case nir_op_fsign:
881 r = (struct ssa_result_range){
882 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
883 true
884 };
885 break;
886
887 case nir_op_fsqrt:
888 case nir_op_frsq:
889 r = (struct ssa_result_range){ge_zero, false};
890 break;
891
892 case nir_op_ffloor: {
893 const struct ssa_result_range left =
894 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
895
896 r.is_integral = true;
897
898 if (left.is_integral || left.range == le_zero || left.range == lt_zero)
899 r.range = left.range;
900 else if (left.range == ge_zero || left.range == gt_zero)
901 r.range = ge_zero;
902 else if (left.range == ne_zero)
903 r.range = unknown;
904
905 break;
906 }
907
908 case nir_op_fceil: {
909 const struct ssa_result_range left =
910 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
911
912 r.is_integral = true;
913
914 if (left.is_integral || left.range == ge_zero || left.range == gt_zero)
915 r.range = left.range;
916 else if (left.range == le_zero || left.range == lt_zero)
917 r.range = le_zero;
918 else if (left.range == ne_zero)
919 r.range = unknown;
920
921 break;
922 }
923
924 case nir_op_ftrunc: {
925 const struct ssa_result_range left =
926 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
927
928 r.is_integral = true;
929
930 if (left.is_integral)
931 r.range = left.range;
932 else if (left.range == ge_zero || left.range == gt_zero)
933 r.range = ge_zero;
934 else if (left.range == le_zero || left.range == lt_zero)
935 r.range = le_zero;
936 else if (left.range == ne_zero)
937 r.range = unknown;
938
939 break;
940 }
941
942 case nir_op_flt:
943 case nir_op_fge:
944 case nir_op_feq:
945 case nir_op_fneu:
946 case nir_op_ilt:
947 case nir_op_ige:
948 case nir_op_ieq:
949 case nir_op_ine:
950 case nir_op_ult:
951 case nir_op_uge:
952 /* Boolean results are 0 or -1. */
953 r = (struct ssa_result_range){le_zero, false};
954 break;
955
956 case nir_op_fpow: {
      /* Due to flush-to-zero semantics of floating-point numbers with very
       * small magnitudes, we can never really be sure a result will be
       * non-zero.
960 *
961 * NIR uses pow() and powf() to constant evaluate nir_op_fpow. The man
962 * page for that function says:
963 *
964 * If y is 0, the result is 1.0 (even if x is a NaN).
965 *
966 * gt_zero: pow(*, eq_zero)
967 * | pow(eq_zero, lt_zero) # 0^-y = +inf
968 * | pow(eq_zero, le_zero) # 0^-y = +inf or 0^0 = 1.0
969 * ;
970 *
971 * eq_zero: pow(eq_zero, gt_zero)
972 * ;
973 *
974 * ge_zero: pow(gt_zero, gt_zero)
975 * | pow(gt_zero, ge_zero)
976 * | pow(gt_zero, lt_zero)
977 * | pow(gt_zero, le_zero)
978 * | pow(gt_zero, ne_zero)
979 * | pow(gt_zero, unknown)
980 * | pow(ge_zero, gt_zero)
981 * | pow(ge_zero, ge_zero)
982 * | pow(ge_zero, lt_zero)
983 * | pow(ge_zero, le_zero)
984 * | pow(ge_zero, ne_zero)
985 * | pow(ge_zero, unknown)
986 * | pow(eq_zero, ge_zero) # 0^0 = 1.0 or 0^+y = 0.0
987 * | pow(eq_zero, ne_zero) # 0^-y = +inf or 0^+y = 0.0
988 * | pow(eq_zero, unknown) # union of all other y cases
989 * ;
990 *
991 * All other cases are unknown.
992 *
993 * We could do better if the right operand is a constant, integral
994 * value.
995 */
996 static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
997 /* left\right unknown lt_zero le_zero gt_zero ge_zero ne_zero eq_zero */
998 /* unknown */ { _______, _______, _______, _______, _______, _______, gt_zero },
999 /* lt_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
1000 /* le_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
1001 /* gt_zero */ { ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, gt_zero },
1002 /* ge_zero */ { ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, gt_zero },
1003 /* ne_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
1004 /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero },
1005 };
1006
1007 const struct ssa_result_range left =
1008 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
1009 const struct ssa_result_range right =
1010 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
1011
1012 ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table);
1013 ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table);
1014
1015 r.is_integral = left.is_integral && right.is_integral &&
1016 is_not_negative(right.range);
1017 r.range = table[left.range][right.range];
1018 break;
1019 }
1020
1021 case nir_op_ffma: {
1022 const struct ssa_result_range first =
1023 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
1024 const struct ssa_result_range second =
1025 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
1026 const struct ssa_result_range third =
1027 analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
1028
1029 r.is_integral = first.is_integral && second.is_integral &&
1030 third.is_integral;
1031
1032 enum ssa_ranges fmul_range;
1033
1034 if (first.range != eq_zero && nir_alu_srcs_equal(alu, alu, 0, 1)) {
1035 /* See handling of nir_op_fmul for explanation of why ge_zero is the
1036 * range.
1037 */
1038 fmul_range = ge_zero;
1039 } else if (first.range != eq_zero && nir_alu_srcs_negative_equal(alu, alu, 0, 1)) {
1040 /* -x * x => le_zero */
1041 fmul_range = le_zero;
1042 } else
1043 fmul_range = fmul_table[first.range][second.range];
1044
1045 r.range = fadd_table[fmul_range][third.range];
1046 break;
1047 }
1048
1049 case nir_op_flrp: {
1050 const struct ssa_result_range first =
1051 analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
1052 const struct ssa_result_range second =
1053 analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
1054 const struct ssa_result_range third =
1055 analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));
1056
1057 r.is_integral = first.is_integral && second.is_integral &&
1058 third.is_integral;
1059
1060 /* Decompose the flrp to first + third * (second + -first) */
1061 const enum ssa_ranges inner_fadd_range =
1062 fadd_table[second.range][fneg_table[first.range]];
1063
1064 const enum ssa_ranges fmul_range =
1065 fmul_table[third.range][inner_fadd_range];
1066
1067 r.range = fadd_table[first.range][fmul_range];
1068 break;
1069 }
1070
1071 default:
1072 r = (struct ssa_result_range){unknown, false};
1073 break;
1074 }
1075
1076 if (r.range == eq_zero)
1077 r.is_integral = true;
1078
1079 _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r));
1080 return r;
1081 }
1082
1083 #undef _______
1084
1085 struct ssa_result_range
nir_analyze_range(struct hash_table * range_ht,const nir_alu_instr * instr,unsigned src)1086 nir_analyze_range(struct hash_table *range_ht,
1087 const nir_alu_instr *instr, unsigned src)
1088 {
1089 return analyze_expression(instr, src, range_ht,
1090 nir_alu_src_type(instr, src));
1091 }
1092
/* Returns a value with the low "size" bits set.  Sizes of 32 or more give
 * all ones, which also keeps the shift count below the width of the type.
 */
static uint32_t
bitmask(uint32_t size)
{
   if (size < 32)
      return (uint32_t)~(UINT32_MAX << size);

   return UINT32_MAX;
}
1096
/* Multiplies two 32-bit values.  If the exact product does not fit in 32
 * bits, UINT32_MAX + 1 is returned instead, so every overflowing product
 * maps to the same out-of-range marker.
 */
static uint64_t
mul_clamp(uint32_t a, uint32_t b)
{
   const uint64_t overflow_marker = (uint64_t)UINT32_MAX + 1;
   const uint64_t product = (uint64_t)a * (uint64_t)b;

   return product > UINT32_MAX ? overflow_marker : product;
}
1104
1105 /* recursively gather at most "buf_size" phi/bcsel sources */
1106 static unsigned
search_phi_bcsel(nir_ssa_scalar scalar,nir_ssa_scalar * buf,unsigned buf_size,struct set * visited)1107 search_phi_bcsel(nir_ssa_scalar scalar, nir_ssa_scalar *buf, unsigned buf_size, struct set *visited)
1108 {
1109 if (_mesa_set_search(visited, scalar.def))
1110 return 0;
1111 _mesa_set_add(visited, scalar.def);
1112
1113 if (scalar.def->parent_instr->type == nir_instr_type_phi) {
1114 nir_phi_instr *phi = nir_instr_as_phi(scalar.def->parent_instr);
1115 unsigned num_sources_left = exec_list_length(&phi->srcs);
1116 if (buf_size >= num_sources_left) {
1117 unsigned total_added = 0;
1118 nir_foreach_phi_src(src, phi) {
1119 unsigned added = search_phi_bcsel(
1120 (nir_ssa_scalar){src->src.ssa, 0}, buf + total_added, buf_size - num_sources_left, visited);
1121 buf_size -= added;
1122 total_added += added;
1123 num_sources_left--;
1124 }
1125 return total_added;
1126 }
1127 }
1128
1129 if (nir_ssa_scalar_is_alu(scalar)) {
1130 nir_op op = nir_ssa_scalar_alu_op(scalar);
1131
1132 if ((op == nir_op_bcsel || op == nir_op_b32csel) && buf_size >= 2) {
1133 nir_ssa_scalar src0 = nir_ssa_scalar_chase_alu_src(scalar, 0);
1134 nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
1135
1136 unsigned added = search_phi_bcsel(src0, buf, buf_size - 1, visited);
1137 buf_size -= added;
1138 added += search_phi_bcsel(src1, buf + added, buf_size, visited);
1139 return added;
1140 }
1141 }
1142
1143 buf[0] = scalar;
1144 return 1;
1145 }
1146
1147 static nir_variable *
lookup_input(nir_shader * shader,unsigned driver_location)1148 lookup_input(nir_shader *shader, unsigned driver_location)
1149 {
1150 return nir_find_variable_with_driver_location(shader, nir_var_shader_in,
1151 driver_location);
1152 }
1153
1154 uint32_t
nir_unsigned_upper_bound(nir_shader * shader,struct hash_table * range_ht,nir_ssa_scalar scalar,const nir_unsigned_upper_bound_config * config)1155 nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
1156 nir_ssa_scalar scalar,
1157 const nir_unsigned_upper_bound_config *config)
1158 {
1159 assert(scalar.def->bit_size <= 32);
1160
1161 if (nir_ssa_scalar_is_const(scalar))
1162 return nir_ssa_scalar_as_uint(scalar);
1163
1164 /* keys can't be 0, so we have to add 1 to the index */
1165 void *key = (void*)(((uintptr_t)(scalar.def->index + 1) << 4) | scalar.comp);
1166 struct hash_entry *he = _mesa_hash_table_search(range_ht, key);
1167 if (he != NULL)
1168 return (uintptr_t)he->data;
1169
1170 uint32_t max = bitmask(scalar.def->bit_size);
1171
1172 if (scalar.def->parent_instr->type == nir_instr_type_intrinsic) {
1173 uint32_t res = max;
1174 nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(scalar.def->parent_instr);
1175 switch (intrin->intrinsic) {
1176 case nir_intrinsic_load_local_invocation_index:
1177 if (shader->info.cs.local_size_variable) {
1178 res = config->max_work_group_invocations - 1;
1179 } else {
1180 res = (shader->info.cs.local_size[0] *
1181 shader->info.cs.local_size[1] *
1182 shader->info.cs.local_size[2]) - 1u;
1183 }
1184 break;
1185 case nir_intrinsic_load_local_invocation_id:
1186 if (shader->info.cs.local_size_variable)
1187 res = config->max_work_group_size[scalar.comp] - 1u;
1188 else
1189 res = shader->info.cs.local_size[scalar.comp] - 1u;
1190 break;
1191 case nir_intrinsic_load_work_group_id:
1192 res = config->max_work_group_count[scalar.comp] - 1u;
1193 break;
1194 case nir_intrinsic_load_num_work_groups:
1195 res = config->max_work_group_count[scalar.comp];
1196 break;
1197 case nir_intrinsic_load_global_invocation_id:
1198 if (shader->info.cs.local_size_variable) {
1199 res = mul_clamp(config->max_work_group_size[scalar.comp],
1200 config->max_work_group_count[scalar.comp]) - 1u;
1201 } else {
1202 res = (shader->info.cs.local_size[scalar.comp] *
1203 config->max_work_group_count[scalar.comp]) - 1u;
1204 }
1205 break;
1206 case nir_intrinsic_load_subgroup_invocation:
1207 case nir_intrinsic_first_invocation:
1208 case nir_intrinsic_mbcnt_amd:
1209 res = config->max_subgroup_size - 1;
1210 break;
1211 case nir_intrinsic_load_subgroup_size:
1212 res = config->max_subgroup_size;
1213 break;
1214 case nir_intrinsic_load_subgroup_id:
1215 case nir_intrinsic_load_num_subgroups: {
1216 uint32_t work_group_size = config->max_work_group_invocations;
1217 if (!shader->info.cs.local_size_variable) {
1218 work_group_size = shader->info.cs.local_size[0] *
1219 shader->info.cs.local_size[1] *
1220 shader->info.cs.local_size[2];
1221 }
1222 res = (work_group_size + config->min_subgroup_size - 1) / config->min_subgroup_size;
1223 if (intrin->intrinsic == nir_intrinsic_load_subgroup_id)
1224 res--;
1225 break;
1226 }
1227 case nir_intrinsic_load_input: {
1228 if (shader->info.stage == MESA_SHADER_VERTEX && nir_src_is_const(intrin->src[0])) {
1229 nir_variable *var = lookup_input(shader, nir_intrinsic_base(intrin));
1230 if (var) {
1231 int loc = var->data.location - VERT_ATTRIB_GENERIC0;
1232 if (loc >= 0)
1233 res = config->vertex_attrib_max[loc];
1234 }
1235 }
1236 break;
1237 }
1238 case nir_intrinsic_reduce:
1239 case nir_intrinsic_inclusive_scan:
1240 case nir_intrinsic_exclusive_scan: {
1241 nir_op op = nir_intrinsic_reduction_op(intrin);
1242 if (op == nir_op_umin || op == nir_op_umax || op == nir_op_imin || op == nir_op_imax)
1243 res = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
1244 break;
1245 }
1246 case nir_intrinsic_read_first_invocation:
1247 case nir_intrinsic_read_invocation:
1248 case nir_intrinsic_shuffle:
1249 case nir_intrinsic_shuffle_xor:
1250 case nir_intrinsic_shuffle_up:
1251 case nir_intrinsic_shuffle_down:
1252 case nir_intrinsic_quad_broadcast:
1253 case nir_intrinsic_quad_swap_horizontal:
1254 case nir_intrinsic_quad_swap_vertical:
1255 case nir_intrinsic_quad_swap_diagonal:
1256 case nir_intrinsic_quad_swizzle_amd:
1257 case nir_intrinsic_masked_swizzle_amd:
1258 res = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
1259 break;
1260 case nir_intrinsic_write_invocation_amd: {
1261 uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
1262 uint32_t src1 = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[1].ssa, 0}, config);
1263 res = MAX2(src0, src1);
1264 break;
1265 }
1266 default:
1267 break;
1268 }
1269 if (res != max)
1270 _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
1271 return res;
1272 }
1273
1274 if (scalar.def->parent_instr->type == nir_instr_type_phi) {
1275 bool cyclic = false;
1276 nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
1277 if (nir_block_dominates(scalar.def->parent_instr->block, src->pred)) {
1278 cyclic = true;
1279 break;
1280 }
1281 }
1282
1283 uint32_t res = 0;
1284 if (cyclic) {
1285 _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)max);
1286
1287 struct set *visited = _mesa_pointer_set_create(NULL);
1288 nir_ssa_scalar defs[64];
1289 unsigned def_count = search_phi_bcsel(scalar, defs, 64, visited);
1290 _mesa_set_destroy(visited, NULL);
1291
1292 for (unsigned i = 0; i < def_count; i++)
1293 res = MAX2(res, nir_unsigned_upper_bound(shader, range_ht, defs[i], config));
1294 } else {
1295 nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
1296 res = MAX2(res, nir_unsigned_upper_bound(
1297 shader, range_ht, (nir_ssa_scalar){src->src.ssa, 0}, config));
1298 }
1299 }
1300
1301 _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
1302 return res;
1303 }
1304
1305 if (nir_ssa_scalar_is_alu(scalar)) {
1306 nir_op op = nir_ssa_scalar_alu_op(scalar);
1307
1308 switch (op) {
1309 case nir_op_umin:
1310 case nir_op_imin:
1311 case nir_op_imax:
1312 case nir_op_umax:
1313 case nir_op_iand:
1314 case nir_op_ior:
1315 case nir_op_ixor:
1316 case nir_op_ishl:
1317 case nir_op_imul:
1318 case nir_op_ushr:
1319 case nir_op_ishr:
1320 case nir_op_iadd:
1321 case nir_op_umod:
1322 case nir_op_udiv:
1323 case nir_op_bcsel:
1324 case nir_op_b32csel:
1325 case nir_op_ubfe:
1326 case nir_op_bfm:
1327 case nir_op_f2u32:
1328 case nir_op_fmul:
1329 break;
1330 default:
1331 return max;
1332 }
1333
1334 uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 0), config);
1335 uint32_t src1 = max, src2 = max;
1336 if (nir_op_infos[op].num_inputs > 1)
1337 src1 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 1), config);
1338 if (nir_op_infos[op].num_inputs > 2)
1339 src2 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 2), config);
1340
1341 uint32_t res = max;
1342 switch (op) {
1343 case nir_op_umin:
1344 res = src0 < src1 ? src0 : src1;
1345 break;
1346 case nir_op_imin:
1347 case nir_op_imax:
1348 case nir_op_umax:
1349 res = src0 > src1 ? src0 : src1;
1350 break;
1351 case nir_op_iand:
1352 res = bitmask(util_last_bit64(src0)) & bitmask(util_last_bit64(src1));
1353 break;
1354 case nir_op_ior:
1355 case nir_op_ixor:
1356 res = bitmask(util_last_bit64(src0)) | bitmask(util_last_bit64(src1));
1357 break;
1358 case nir_op_ishl:
1359 if (util_last_bit64(src0) + src1 > scalar.def->bit_size)
1360 res = max; /* overflow */
1361 else
1362 res = src0 << MIN2(src1, scalar.def->bit_size - 1u);
1363 break;
1364 case nir_op_imul:
1365 if (src0 != 0 && (src0 * src1) / src0 != src1)
1366 res = max;
1367 else
1368 res = src0 * src1;
1369 break;
1370 case nir_op_ushr: {
1371 nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
1372 if (nir_ssa_scalar_is_const(src1_scalar))
1373 res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
1374 else
1375 res = src0;
1376 break;
1377 }
1378 case nir_op_ishr: {
1379 nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
1380 if (src0 <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
1381 res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
1382 else
1383 res = src0;
1384 break;
1385 }
1386 case nir_op_iadd:
1387 if (src0 + src1 < src0)
1388 res = max; /* overflow */
1389 else
1390 res = src0 + src1;
1391 break;
1392 case nir_op_umod:
1393 res = src1 ? src1 - 1 : 0;
1394 break;
1395 case nir_op_udiv: {
1396 nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
1397 if (nir_ssa_scalar_is_const(src1_scalar))
1398 res = nir_ssa_scalar_as_uint(src1_scalar) ? src0 / nir_ssa_scalar_as_uint(src1_scalar) : 0;
1399 else
1400 res = src0;
1401 break;
1402 }
1403 case nir_op_bcsel:
1404 case nir_op_b32csel:
1405 res = src1 > src2 ? src1 : src2;
1406 break;
1407 case nir_op_ubfe:
1408 res = bitmask(MIN2(src2, scalar.def->bit_size));
1409 break;
1410 case nir_op_bfm: {
1411 nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
1412 if (nir_ssa_scalar_is_const(src1_scalar)) {
1413 src0 = MIN2(src0, 31);
1414 src1 = nir_ssa_scalar_as_uint(src1_scalar) & 0x1fu;
1415 res = bitmask(src0) << src1;
1416 } else {
1417 src0 = MIN2(src0, 31);
1418 src1 = MIN2(src1, 31);
1419 res = bitmask(MIN2(src0 + src1, 32));
1420 }
1421 break;
1422 }
1423 /* limited floating-point support for f2u32(fmul(load_input(), <constant>)) */
1424 case nir_op_f2u32:
1425 /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
1426 if (src0 < 0x7f800000u) {
1427 float val;
1428 memcpy(&val, &src0, 4);
1429 res = (uint32_t)val;
1430 }
1431 break;
1432 case nir_op_fmul:
1433 /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
1434 if (src0 < 0x7f800000u && src1 < 0x7f800000u) {
1435 float src0_f, src1_f;
1436 memcpy(&src0_f, &src0, 4);
1437 memcpy(&src1_f, &src1, 4);
1438 /* not a proper rounding-up multiplication, but should be good enough */
1439 float max_f = ceilf(src0_f) * ceilf(src1_f);
1440 memcpy(&res, &max_f, 4);
1441 }
1442 break;
1443 default:
1444 res = max;
1445 break;
1446 }
1447 _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
1448 return res;
1449 }
1450
1451 return max;
1452 }
1453
1454 bool
nir_addition_might_overflow(nir_shader * shader,struct hash_table * range_ht,nir_ssa_scalar ssa,unsigned const_val,const nir_unsigned_upper_bound_config * config)1455 nir_addition_might_overflow(nir_shader *shader, struct hash_table *range_ht,
1456 nir_ssa_scalar ssa, unsigned const_val,
1457 const nir_unsigned_upper_bound_config *config)
1458 {
1459 if (nir_ssa_scalar_is_alu(ssa)) {
1460 nir_op alu_op = nir_ssa_scalar_alu_op(ssa);
1461
1462 /* iadd(imul(a, #b), #c) */
1463 if (alu_op == nir_op_imul || alu_op == nir_op_ishl) {
1464 nir_ssa_scalar mul_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
1465 nir_ssa_scalar mul_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
1466 uint32_t stride = 1;
1467 if (nir_ssa_scalar_is_const(mul_src0))
1468 stride = nir_ssa_scalar_as_uint(mul_src0);
1469 else if (nir_ssa_scalar_is_const(mul_src1))
1470 stride = nir_ssa_scalar_as_uint(mul_src1);
1471
1472 if (alu_op == nir_op_ishl)
1473 stride = 1u << (stride % 32u);
1474
1475 if (!stride || const_val <= UINT32_MAX - (UINT32_MAX / stride * stride))
1476 return false;
1477 }
1478
1479 /* iadd(iand(a, #b), #c) */
1480 if (alu_op == nir_op_iand) {
1481 nir_ssa_scalar and_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
1482 nir_ssa_scalar and_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
1483 uint32_t mask = 0xffffffff;
1484 if (nir_ssa_scalar_is_const(and_src0))
1485 mask = nir_ssa_scalar_as_uint(and_src0);
1486 else if (nir_ssa_scalar_is_const(and_src1))
1487 mask = nir_ssa_scalar_as_uint(and_src1);
1488 if (mask == 0 || const_val < (1u << (ffs(mask) - 1)))
1489 return false;
1490 }
1491 }
1492
1493 uint32_t ub = nir_unsigned_upper_bound(shader, range_ht, ssa, config);
1494 return const_val + ub < const_val;
1495 }
1496
1497