/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */
#include <math.h>
#include <float.h>
#include "nir.h"
#include "nir_range_analysis.h"
#include "util/hash_table.h"

/**
 * Analyzes a sequence of operations to determine some aspects of the range of
 * the result.
 */

static bool
is_not_negative(enum ssa_ranges r)
{
   return r == gt_zero || r == ge_zero || r == eq_zero;
}

static void *
pack_data(const struct ssa_result_range r)
{
   return (void *)(uintptr_t)(r.range | r.is_integral << 8);
}

static struct ssa_result_range
unpack_data(const void *p)
{
   const uintptr_t v = (uintptr_t) p;

   return (struct ssa_result_range){v & 0xff, (v & 0x0ff00) != 0};
}
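
/* pack_data()/unpack_data() squeeze a struct ssa_result_range into the
 * pointer-sized hash table value: the range enum lives in the low byte and
 * is_integral in bit 8.  For example (a sketch; the numeric enum values
 * follow the column order of the tables below, with unknown == 0), a result
 * of { gt_zero, true } round-trips through (range | 1 << 8) unchanged.
 */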

static void *
pack_key(const struct nir_alu_instr *instr, nir_alu_type type)
{
   uintptr_t type_encoding;
   uintptr_t ptr = (uintptr_t) instr;

   /* The low 2 bits have to be zero or this whole scheme falls apart. */
   assert((ptr & 0x3) == 0);

   /* NIR is typeless in the sense that sequences of bits have whatever
    * meaning is attached to them by the instruction that consumes them.
    * However, the number of bits must match between producer and consumer.
    * As a result, the number of bits does not need to be encoded here.
    */
   switch (nir_alu_type_get_base_type(type)) {
   case nir_type_int:   type_encoding = 0; break;
   case nir_type_uint:  type_encoding = 1; break;
   case nir_type_bool:  type_encoding = 2; break;
   case nir_type_float: type_encoding = 3; break;
   default: unreachable("Invalid base type.");
   }

   return (void *)(ptr | type_encoding);
}
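
/* Because the base type is packed into the two low (alignment) bits of the
 * instruction pointer, the same instruction can be cached once per use type:
 * e.g., pack_key(alu, nir_type_float32) and pack_key(alu, nir_type_int32)
 * yield distinct keys (tag 3 vs. tag 0), so a value read both as float and
 * as int gets separate cache entries.
 */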

static nir_alu_type
nir_alu_src_type(const nir_alu_instr *instr, unsigned src)
{
   return nir_alu_type_get_base_type(nir_op_infos[instr->op].input_types[src]) |
          nir_src_bit_size(instr->src[src].src);
}

static struct ssa_result_range
analyze_constant(const struct nir_alu_instr *instr, unsigned src,
                 nir_alu_type use_type)
{
   uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0, 1, 2, 3,
                                               4, 5, 6, 7,
                                               8, 9, 10, 11,
                                               12, 13, 14, 15 };

   /* If the source is an explicitly sized source, then we need to reset
    * both the number of components and the swizzle.
    */
   const unsigned num_components = nir_ssa_alu_instr_src_components(instr, src);

   for (unsigned i = 0; i < num_components; ++i)
      swizzle[i] = instr->src[src].swizzle[i];

   const nir_load_const_instr *const load =
      nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);

   struct ssa_result_range r = { unknown, false };

   switch (nir_alu_type_get_base_type(use_type)) {
   case nir_type_float: {
      double min_value = DBL_MAX;
      double max_value = -DBL_MAX;
      bool any_zero = false;
      bool all_zero = true;

      r.is_integral = true;

      for (unsigned i = 0; i < num_components; ++i) {
         const double v = nir_const_value_as_float(load->value[swizzle[i]],
                                                   load->def.bit_size);

         if (floor(v) != v)
            r.is_integral = false;

         any_zero = any_zero || (v == 0.0);
         all_zero = all_zero && (v == 0.0);
         min_value = MIN2(min_value, v);
         max_value = MAX2(max_value, v);
      }

      assert(any_zero >= all_zero);
      assert(isnan(max_value) || max_value >= min_value);

      if (all_zero)
         r.range = eq_zero;
      else if (min_value > 0.0)
         r.range = gt_zero;
      else if (min_value == 0.0)
         r.range = ge_zero;
      else if (max_value < 0.0)
         r.range = lt_zero;
      else if (max_value == 0.0)
         r.range = le_zero;
      else if (!any_zero)
         r.range = ne_zero;
      else
         r.range = unknown;

      return r;
   }

   case nir_type_int:
   case nir_type_bool: {
      int64_t min_value = INT64_MAX;
      int64_t max_value = INT64_MIN;
      bool any_zero = false;
      bool all_zero = true;

      for (unsigned i = 0; i < num_components; ++i) {
         const int64_t v = nir_const_value_as_int(load->value[swizzle[i]],
                                                  load->def.bit_size);

         any_zero = any_zero || (v == 0);
         all_zero = all_zero && (v == 0);
         min_value = MIN2(min_value, v);
         max_value = MAX2(max_value, v);
      }

      assert(any_zero >= all_zero);
      assert(max_value >= min_value);

      if (all_zero)
         r.range = eq_zero;
      else if (min_value > 0)
         r.range = gt_zero;
      else if (min_value == 0)
         r.range = ge_zero;
      else if (max_value < 0)
         r.range = lt_zero;
      else if (max_value == 0)
         r.range = le_zero;
      else if (!any_zero)
         r.range = ne_zero;
      else
         r.range = unknown;

      return r;
   }

   case nir_type_uint: {
      bool any_zero = false;
      bool all_zero = true;

      for (unsigned i = 0; i < num_components; ++i) {
         const uint64_t v = nir_const_value_as_uint(load->value[swizzle[i]],
                                                    load->def.bit_size);

         any_zero = any_zero || (v == 0);
         all_zero = all_zero && (v == 0);
      }

      assert(any_zero >= all_zero);

      if (all_zero)
         r.range = eq_zero;
      else if (any_zero)
         r.range = ge_zero;
      else
         r.range = gt_zero;

      return r;
   }

   default:
      unreachable("Invalid alu source type");
   }
}
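
/* For example (illustrative constants): a float vec2(0.0, 2.0) source yields
 * { ge_zero, is_integral = true } because the smallest component is exactly
 * zero and every component is a whole number, while vec2(-1.5, 2.0) yields
 * { ne_zero, is_integral = false }.
 */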

/**
 * Short-hand name for use in the tables in analyze_expression.  If this name
 * becomes a problem on some compiler, we can change it to _.
 */
#define _______ unknown


#if defined(__clang__)
   /* clang wants _Pragma("unroll X") */
   #define pragma_unroll_5 _Pragma("unroll 5")
   #define pragma_unroll_7 _Pragma("unroll 7")
/* gcc wants _Pragma("GCC unroll X") */
#elif defined(__GNUC__)
   #if __GNUC__ >= 8
      #define pragma_unroll_5 _Pragma("GCC unroll 5")
      #define pragma_unroll_7 _Pragma("GCC unroll 7")
   #else
      #pragma GCC optimize ("unroll-loops")
      #define pragma_unroll_5
      #define pragma_unroll_7
   #endif
#else
   /* MSVC doesn't have C99's _Pragma() */
   #define pragma_unroll_5
   #define pragma_unroll_7
#endif


#ifndef NDEBUG
#define ASSERT_TABLE_IS_COMMUTATIVE(t)                        \
   do {                                                       \
      static bool first = true;                               \
      if (first) {                                            \
         first = false;                                       \
         pragma_unroll_7                                      \
         for (unsigned r = 0; r < ARRAY_SIZE(t); r++) {       \
            pragma_unroll_7                                   \
            for (unsigned c = 0; c < ARRAY_SIZE(t[0]); c++)   \
               assert(t[r][c] == t[c][r]);                    \
         }                                                    \
      }                                                       \
   } while (false)

#define ASSERT_TABLE_IS_DIAGONAL(t)                           \
   do {                                                       \
      static bool first = true;                               \
      if (first) {                                            \
         first = false;                                       \
         pragma_unroll_7                                      \
         for (unsigned r = 0; r < ARRAY_SIZE(t); r++)         \
            assert(t[r][r] == r);                             \
      }                                                       \
   } while (false)

static enum ssa_ranges
union_ranges(enum ssa_ranges a, enum ssa_ranges b)
{
   static const enum ssa_ranges union_table[last_range + 1][last_range + 1] = {
      /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
      /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
      /* lt_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, le_zero },
      /* le_zero */ { _______, le_zero, le_zero, _______, _______, _______, le_zero },
      /* gt_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, ge_zero },
      /* ge_zero */ { _______, _______, _______, ge_zero, ge_zero, _______, ge_zero },
      /* ne_zero */ { _______, ne_zero, _______, ne_zero, _______, ne_zero, _______ },
      /* eq_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
   };

   ASSERT_TABLE_IS_COMMUTATIVE(union_table);
   ASSERT_TABLE_IS_DIAGONAL(union_table);

   return union_table[a][b];
}
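
/* For example, union_ranges(lt_zero, eq_zero) == le_zero: a value that is
 * either strictly negative or exactly zero is known non-positive.  By
 * contrast, union_ranges(lt_zero, ge_zero) == unknown because together the
 * two ranges cover every possible value.
 */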

/* Verify that the 'unknown' entry in each row (or column) of the table is the
 * union of all the other values in the row (or column).
 */
#define ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(t)              \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i < last_range; i++) {                    \
            enum ssa_ranges col_range = t[i][unknown + 1];              \
            enum ssa_ranges row_range = t[unknown + 1][i];              \
                                                                        \
            pragma_unroll_5                                             \
            for (unsigned j = unknown + 2; j < last_range; j++) {       \
               col_range = union_ranges(col_range, t[i][j]);            \
               row_range = union_ranges(row_range, t[j][i]);            \
            }                                                           \
                                                                        \
            assert(col_range == t[i][unknown]);                         \
            assert(row_range == t[unknown][i]);                         \
         }                                                              \
      }                                                                 \
   } while (false)

/* For most operations, the union of ranges for a strict inequality and
 * equality should be the range of the non-strict inequality (e.g.,
 * union_ranges(range(op(lt_zero)), range(op(eq_zero))) == range(op(le_zero))).
 *
 * Does not apply to selection-like opcodes (bcsel, fmin, fmax, etc.).
 */
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(t) \
   do {                                                                 \
      assert(union_ranges(t[lt_zero], t[eq_zero]) == t[le_zero]);       \
      assert(union_ranges(t[gt_zero], t[eq_zero]) == t[ge_zero]);       \
   } while (false)

#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(t) \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i < last_range; i++) {                    \
            assert(union_ranges(t[i][lt_zero], t[i][eq_zero]) == t[i][le_zero]); \
            assert(union_ranges(t[i][gt_zero], t[i][eq_zero]) == t[i][ge_zero]); \
            assert(union_ranges(t[lt_zero][i], t[eq_zero][i]) == t[le_zero][i]); \
            assert(union_ranges(t[gt_zero][i], t[eq_zero][i]) == t[ge_zero][i]); \
         }                                                              \
      }                                                                 \
   } while (false)

/* Several other unordered tuples span the range of "everything."  Each should
 * have the same value as unknown: (lt_zero, ge_zero), (le_zero, gt_zero), and
 * (eq_zero, ne_zero).  union_ranges is already commutative, so only one
 * ordering needs to be checked.
 *
 * Does not apply to selection-like opcodes (bcsel, fmin, fmax, etc.).
 *
 * In cases where this can be used, it is unnecessary to also use
 * ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_*_SOURCE.  For any range X,
 * union_ranges(X, X) == X.  The disjoint ranges cover all of the non-unknown
 * possibilities, so the union of all the unions of disjoint ranges is
 * equivalent to the union of "others."
 */
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(t)            \
   do {                                                                 \
      assert(union_ranges(t[lt_zero], t[ge_zero]) == t[unknown]);       \
      assert(union_ranges(t[le_zero], t[gt_zero]) == t[unknown]);       \
      assert(union_ranges(t[eq_zero], t[ne_zero]) == t[unknown]);       \
   } while (false)

#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)            \
   do {                                                                 \
      static bool first = true;                                         \
      if (first) {                                                      \
         first = false;                                                 \
         pragma_unroll_7                                                \
         for (unsigned i = 0; i < last_range; i++) {                    \
            assert(union_ranges(t[i][lt_zero], t[i][ge_zero]) ==        \
                   t[i][unknown]);                                      \
            assert(union_ranges(t[i][le_zero], t[i][gt_zero]) ==        \
                   t[i][unknown]);                                      \
            assert(union_ranges(t[i][eq_zero], t[i][ne_zero]) ==        \
                   t[i][unknown]);                                      \
                                                                        \
            assert(union_ranges(t[lt_zero][i], t[ge_zero][i]) ==        \
                   t[unknown][i]);                                      \
            assert(union_ranges(t[le_zero][i], t[gt_zero][i]) ==        \
                   t[unknown][i]);                                      \
            assert(union_ranges(t[eq_zero][i], t[ne_zero][i]) ==        \
                   t[unknown][i]);                                      \
         }                                                              \
      }                                                                 \
   } while (false)

#else
#define ASSERT_TABLE_IS_COMMUTATIVE(t)
#define ASSERT_TABLE_IS_DIAGONAL(t)
#define ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(t)
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(t)
#define ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(t)
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(t)
#define ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(t)
#endif

/**
 * Analyze an expression to determine the range of its result
 *
 * The end result of this analysis is a token that communicates something
 * about the range of values.  There's an implicit grammar that produces
 * tokens from sequences of literal values, other tokens, and operations.
 * This function implements this grammar as a recursive-descent parser.  Some
 * (but not all) of the grammar is listed in-line in the function.
 */
static struct ssa_result_range
analyze_expression(const nir_alu_instr *instr, unsigned src,
                   struct hash_table *ht, nir_alu_type use_type)
{
   /* Ensure that the _Pragma("GCC unroll 7") above is correct. */
   STATIC_ASSERT(last_range + 1 == 7);

   if (!instr->src[src].src.is_ssa)
      return (struct ssa_result_range){unknown, false};

   if (nir_src_is_const(instr->src[src].src))
      return analyze_constant(instr, src, use_type);

   if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
      return (struct ssa_result_range){unknown, false};

   const struct nir_alu_instr *const alu =
       nir_instr_as_alu(instr->src[src].src.ssa->parent_instr);

   /* Bail if the type of the instruction generating the value does not match
    * the type the value will be interpreted as.  int/uint/bool can be
    * reinterpreted trivially.  The most important cases are between float and
    * non-float.
    */
   if (alu->op != nir_op_mov && alu->op != nir_op_bcsel) {
      const nir_alu_type use_base_type =
         nir_alu_type_get_base_type(use_type);
      const nir_alu_type src_base_type =
         nir_alu_type_get_base_type(nir_op_infos[alu->op].output_type);

      if (use_base_type != src_base_type &&
          (use_base_type == nir_type_float ||
           src_base_type == nir_type_float)) {
         return (struct ssa_result_range){unknown, false};
      }
   }

   struct hash_entry *he = _mesa_hash_table_search(ht, pack_key(alu, use_type));
   if (he != NULL)
      return unpack_data(he->data);

   struct ssa_result_range r = {unknown, false};

   /* ge_zero: ge_zero + ge_zero
    *
    * gt_zero: gt_zero + eq_zero
    *        | gt_zero + ge_zero
    *        | eq_zero + gt_zero   # Addition is commutative
    *        | ge_zero + gt_zero   # Addition is commutative
    *        | gt_zero + gt_zero
    *        ;
    *
    * le_zero: le_zero + le_zero
    *
    * lt_zero: lt_zero + eq_zero
    *        | lt_zero + le_zero
    *        | eq_zero + lt_zero   # Addition is commutative
    *        | le_zero + lt_zero   # Addition is commutative
    *        | lt_zero + lt_zero
    *        ;
    *
    * ne_zero: eq_zero + ne_zero
    *        | ne_zero + eq_zero   # Addition is commutative
    *        ;
    *
    * eq_zero: eq_zero + eq_zero
    *        ;
    *
    * All other cases are 'unknown'.  The seemingly odd entry is (ne_zero,
    * ne_zero): the operands could be, e.g., -5 and +5, whose sum is 0 and
    * therefore not ne_zero.
    */
   static const enum ssa_ranges fadd_table[last_range + 1][last_range + 1] = {
      /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
      /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
      /* lt_zero */ { _______, lt_zero, lt_zero, _______, _______, _______, lt_zero },
      /* le_zero */ { _______, lt_zero, le_zero, _______, _______, _______, le_zero },
      /* gt_zero */ { _______, _______, _______, gt_zero, gt_zero, _______, gt_zero },
      /* ge_zero */ { _______, _______, _______, gt_zero, ge_zero, _______, ge_zero },
      /* ne_zero */ { _______, _______, _______, _______, _______, _______, ne_zero },
      /* eq_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
   };

   ASSERT_TABLE_IS_COMMUTATIVE(fadd_table);
   ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(fadd_table);
   ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(fadd_table);
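
   /* Example lookup: fadd_table[gt_zero][ge_zero] == gt_zero, i.e., adding a
    * non-negative value to a strictly positive value keeps the sum strictly
    * positive.
    */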

   /* Due to flush-to-zero semantics of floating-point numbers with very
    * small magnitudes, we can never really be sure a result will be
    * non-zero.
    *
    * ge_zero: ge_zero * ge_zero
    *        | ge_zero * gt_zero
    *        | ge_zero * eq_zero
    *        | le_zero * lt_zero
    *        | lt_zero * le_zero  # Multiplication is commutative
    *        | le_zero * le_zero
    *        | gt_zero * ge_zero  # Multiplication is commutative
    *        | eq_zero * ge_zero  # Multiplication is commutative
    *        | a * a              # Left source == right source
    *        | gt_zero * gt_zero
    *        | lt_zero * lt_zero
    *        ;
    *
    * le_zero: ge_zero * le_zero
    *        | ge_zero * lt_zero
    *        | lt_zero * ge_zero  # Multiplication is commutative
    *        | le_zero * ge_zero  # Multiplication is commutative
    *        | le_zero * gt_zero
    *        | lt_zero * gt_zero
    *        | gt_zero * lt_zero  # Multiplication is commutative
    *        ;
    *
    * eq_zero: eq_zero * <any>
    *        | <any> * eq_zero    # Multiplication is commutative
    *        ;
    *
    * All other cases are 'unknown'.
    */
   static const enum ssa_ranges fmul_table[last_range + 1][last_range + 1] = {
      /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
      /* unknown */ { _______, _______, _______, _______, _______, _______, eq_zero },
      /* lt_zero */ { _______, ge_zero, ge_zero, le_zero, le_zero, _______, eq_zero },
      /* le_zero */ { _______, ge_zero, ge_zero, le_zero, le_zero, _______, eq_zero },
      /* gt_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
      /* ge_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
      /* ne_zero */ { _______, _______, _______, _______, _______, _______, eq_zero },
      /* eq_zero */ { eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero, eq_zero }
   };

   ASSERT_TABLE_IS_COMMUTATIVE(fmul_table);
   ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(fmul_table);
   ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(fmul_table);
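
   /* Example lookup: fmul_table[gt_zero][gt_zero] == ge_zero rather than
    * gt_zero, because the product of two tiny positive values may flush to
    * zero.
    */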

   static const enum ssa_ranges fneg_table[last_range + 1] = {
   /* unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
      _______, gt_zero, ge_zero, lt_zero, le_zero, ne_zero, eq_zero
   };

   ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(fneg_table);
   ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(fneg_table);


   switch (alu->op) {
   case nir_op_b2f32:
   case nir_op_b2i32:
      r = (struct ssa_result_range){ge_zero, alu->op == nir_op_b2f32};
      break;

   case nir_op_bcsel: {
      const struct ssa_result_range left =
         analyze_expression(alu, 1, ht, use_type);
      const struct ssa_result_range right =
         analyze_expression(alu, 2, ht, use_type);

      r.is_integral = left.is_integral && right.is_integral;

      /* le_zero: bcsel(<any>, le_zero, lt_zero)
       *        | bcsel(<any>, eq_zero, lt_zero)
       *        | bcsel(<any>, le_zero, eq_zero)
       *        | bcsel(<any>, lt_zero, le_zero)
       *        | bcsel(<any>, lt_zero, eq_zero)
       *        | bcsel(<any>, eq_zero, le_zero)
       *        | bcsel(<any>, le_zero, le_zero)
       *        ;
       *
       * lt_zero: bcsel(<any>, lt_zero, lt_zero)
       *        ;
       *
       * ge_zero: bcsel(<any>, ge_zero, ge_zero)
       *        | bcsel(<any>, ge_zero, gt_zero)
       *        | bcsel(<any>, ge_zero, eq_zero)
       *        | bcsel(<any>, gt_zero, ge_zero)
       *        | bcsel(<any>, eq_zero, ge_zero)
       *        ;
       *
       * gt_zero: bcsel(<any>, gt_zero, gt_zero)
       *        ;
       *
       * ne_zero: bcsel(<any>, ne_zero, gt_zero)
       *        | bcsel(<any>, ne_zero, lt_zero)
       *        | bcsel(<any>, gt_zero, lt_zero)
       *        | bcsel(<any>, gt_zero, ne_zero)
       *        | bcsel(<any>, lt_zero, ne_zero)
       *        | bcsel(<any>, lt_zero, gt_zero)
       *        | bcsel(<any>, ne_zero, ne_zero)
       *        ;
       *
       * eq_zero: bcsel(<any>, eq_zero, eq_zero)
       *        ;
       *
       * All other cases are 'unknown'.
       *
       * The ranges could be tightened if the range of the first source is
       * known.  However, opt_algebraic will (eventually) eliminate the bcsel
       * if the condition is known.
       */
      static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
         /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
         /* unknown */ { _______, _______, _______, _______, _______, _______, _______ },
         /* lt_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, le_zero },
         /* le_zero */ { _______, le_zero, le_zero, _______, _______, _______, le_zero },
         /* gt_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, ge_zero },
         /* ge_zero */ { _______, _______, _______, ge_zero, ge_zero, _______, ge_zero },
         /* ne_zero */ { _______, ne_zero, _______, ne_zero, _______, ne_zero, _______ },
         /* eq_zero */ { _______, le_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
      };

      ASSERT_TABLE_IS_COMMUTATIVE(table);
      ASSERT_TABLE_IS_DIAGONAL(table);
      ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);

      r.range = table[left.range][right.range];
      break;
   }

   case nir_op_i2f32:
   case nir_op_u2f32:
      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      r.is_integral = true;

      if (r.range == unknown && alu->op == nir_op_u2f32)
         r.range = ge_zero;

      break;

   case nir_op_fabs:
      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      switch (r.range) {
      case unknown:
      case le_zero:
      case ge_zero:
         r.range = ge_zero;
         break;

      case lt_zero:
      case gt_zero:
      case ne_zero:
         r.range = gt_zero;
         break;

      case eq_zero:
         break;
      }

      break;

   case nir_op_fadd: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range right =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));

      r.is_integral = left.is_integral && right.is_integral;
      r.range = fadd_table[left.range][right.range];
      break;
   }

   case nir_op_fexp2: {
      /* If the parameter might be less than zero, the mathematical result
       * will be in (0, 1).  For sufficiently large magnitude negative
       * parameters, the result will flush to zero.
       */
      static const enum ssa_ranges table[last_range + 1] = {
      /* unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
         ge_zero, ge_zero, ge_zero, gt_zero, gt_zero, ge_zero, gt_zero
      };

      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_1_SOURCE(table);
      ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_1_SOURCE(table);

      r.is_integral = r.is_integral && is_not_negative(r.range);
      r.range = table[r.range];
      break;
   }

   case nir_op_fmax: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range right =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));

      r.is_integral = left.is_integral && right.is_integral;

      /* gt_zero: fmax(gt_zero, *)
       *        | fmax(*, gt_zero)        # Treat fmax as commutative
       *        ;
       *
       * ge_zero: fmax(ge_zero, ne_zero)
       *        | fmax(ge_zero, lt_zero)
       *        | fmax(ge_zero, le_zero)
       *        | fmax(ge_zero, eq_zero)
       *        | fmax(ne_zero, ge_zero)  # Treat fmax as commutative
       *        | fmax(lt_zero, ge_zero)  # Treat fmax as commutative
       *        | fmax(le_zero, ge_zero)  # Treat fmax as commutative
       *        | fmax(eq_zero, ge_zero)  # Treat fmax as commutative
       *        | fmax(ge_zero, ge_zero)
       *        ;
       *
       * le_zero: fmax(le_zero, lt_zero)
       *        | fmax(lt_zero, le_zero)  # Treat fmax as commutative
       *        | fmax(le_zero, le_zero)
       *        ;
       *
       * lt_zero: fmax(lt_zero, lt_zero)
       *        ;
       *
       * ne_zero: fmax(ne_zero, lt_zero)
       *        | fmax(lt_zero, ne_zero)  # Treat fmax as commutative
       *        | fmax(ne_zero, ne_zero)
       *        ;
       *
       * eq_zero: fmax(eq_zero, le_zero)
       *        | fmax(eq_zero, lt_zero)
       *        | fmax(le_zero, eq_zero)  # Treat fmax as commutative
       *        | fmax(lt_zero, eq_zero)  # Treat fmax as commutative
       *        | fmax(eq_zero, eq_zero)
       *        ;
       *
       * All other cases are 'unknown'.
       */
      static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
         /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
         /* unknown */ { _______, _______, _______, gt_zero, ge_zero, _______, _______ },
         /* lt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
         /* le_zero */ { _______, le_zero, le_zero, gt_zero, ge_zero, _______, eq_zero },
         /* gt_zero */ { gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero, gt_zero },
         /* ge_zero */ { ge_zero, ge_zero, ge_zero, gt_zero, ge_zero, ge_zero, ge_zero },
         /* ne_zero */ { _______, ne_zero, _______, gt_zero, ge_zero, ne_zero, _______ },
         /* eq_zero */ { _______, eq_zero, eq_zero, gt_zero, ge_zero, _______, eq_zero }
      };

      /* Treat fmax as commutative. */
      ASSERT_TABLE_IS_COMMUTATIVE(table);
      ASSERT_TABLE_IS_DIAGONAL(table);
      ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);

      r.range = table[left.range][right.range];
      break;
   }

   case nir_op_fmin: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range right =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));

      r.is_integral = left.is_integral && right.is_integral;

      /* lt_zero: fmin(lt_zero, *)
       *        | fmin(*, lt_zero)        # Treat fmin as commutative
       *        ;
       *
       * le_zero: fmin(le_zero, ne_zero)
       *        | fmin(le_zero, gt_zero)
       *        | fmin(le_zero, ge_zero)
       *        | fmin(le_zero, eq_zero)
       *        | fmin(ne_zero, le_zero)  # Treat fmin as commutative
       *        | fmin(gt_zero, le_zero)  # Treat fmin as commutative
       *        | fmin(ge_zero, le_zero)  # Treat fmin as commutative
       *        | fmin(eq_zero, le_zero)  # Treat fmin as commutative
       *        | fmin(le_zero, le_zero)
       *        ;
       *
       * ge_zero: fmin(ge_zero, gt_zero)
       *        | fmin(gt_zero, ge_zero)  # Treat fmin as commutative
       *        | fmin(ge_zero, ge_zero)
       *        ;
       *
       * gt_zero: fmin(gt_zero, gt_zero)
       *        ;
       *
       * ne_zero: fmin(ne_zero, gt_zero)
       *        | fmin(gt_zero, ne_zero)  # Treat fmin as commutative
       *        | fmin(ne_zero, ne_zero)
       *        ;
       *
       * eq_zero: fmin(eq_zero, ge_zero)
       *        | fmin(eq_zero, gt_zero)
       *        | fmin(ge_zero, eq_zero)  # Treat fmin as commutative
       *        | fmin(gt_zero, eq_zero)  # Treat fmin as commutative
       *        | fmin(eq_zero, eq_zero)
       *        ;
       *
       * All other cases are 'unknown'.
       */
      static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
         /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
         /* unknown */ { _______, lt_zero, le_zero, _______, _______, _______, _______ },
         /* lt_zero */ { lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero, lt_zero },
         /* le_zero */ { le_zero, lt_zero, le_zero, le_zero, le_zero, le_zero, le_zero },
         /* gt_zero */ { _______, lt_zero, le_zero, gt_zero, ge_zero, ne_zero, eq_zero },
         /* ge_zero */ { _______, lt_zero, le_zero, ge_zero, ge_zero, _______, eq_zero },
         /* ne_zero */ { _______, lt_zero, le_zero, ne_zero, _______, ne_zero, _______ },
         /* eq_zero */ { _______, lt_zero, le_zero, eq_zero, eq_zero, _______, eq_zero }
      };

      /* Treat fmin as commutative. */
      ASSERT_TABLE_IS_COMMUTATIVE(table);
      ASSERT_TABLE_IS_DIAGONAL(table);
      ASSERT_UNION_OF_OTHERS_MATCHES_UNKNOWN_2_SOURCE(table);

      r.range = table[left.range][right.range];
      break;
   }

   case nir_op_fmul: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range right =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));

      r.is_integral = left.is_integral && right.is_integral;

      /* x * x => ge_zero */
      if (left.range != eq_zero && nir_alu_srcs_equal(alu, alu, 0, 1)) {
         /* Even if x > 0, the result of x*x can be zero when x is, for
          * example, a subnormal number.
          */
         r.range = ge_zero;
      } else if (left.range != eq_zero && nir_alu_srcs_negative_equal(alu, alu, 0, 1)) {
         /* -x * x => le_zero. */
         r.range = le_zero;
      } else
         r.range = fmul_table[left.range][right.range];

      break;
   }

   case nir_op_frcp:
      r = (struct ssa_result_range){
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
         false
      };
      break;

   case nir_op_mov:
      r = analyze_expression(alu, 0, ht, use_type);
      break;

   case nir_op_fneg:
      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      r.range = fneg_table[r.range];
      break;

   case nir_op_fsat:
      r = analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      switch (r.range) {
      case le_zero:
      case lt_zero:
         r.range = eq_zero;
         r.is_integral = true;
         break;

      case eq_zero:
         assert(r.is_integral);
         /* fall through */
      case gt_zero:
      case ge_zero:
         /* The fsat doesn't add any information in these cases. */
         break;

      case ne_zero:
      case unknown:
         /* Since the result must be in [0, 1], the value must be >= 0. */
         r.range = ge_zero;
         break;
      }
      break;

   case nir_op_fsign:
      r = (struct ssa_result_range){
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0)).range,
         true
      };
      break;

   case nir_op_fsqrt:
   case nir_op_frsq:
      r = (struct ssa_result_range){ge_zero, false};
      break;

   case nir_op_ffloor: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      r.is_integral = true;

      if (left.is_integral || left.range == le_zero || left.range == lt_zero)
         r.range = left.range;
      else if (left.range == ge_zero || left.range == gt_zero)
         r.range = ge_zero;
      else if (left.range == ne_zero)
         r.range = unknown;

      break;
   }

   case nir_op_fceil: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      r.is_integral = true;

      if (left.is_integral || left.range == ge_zero || left.range == gt_zero)
         r.range = left.range;
      else if (left.range == le_zero || left.range == lt_zero)
         r.range = le_zero;
      else if (left.range == ne_zero)
         r.range = unknown;

      break;
   }

   case nir_op_ftrunc: {
      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));

      r.is_integral = true;

      if (left.is_integral)
         r.range = left.range;
      else if (left.range == ge_zero || left.range == gt_zero)
         r.range = ge_zero;
      else if (left.range == le_zero || left.range == lt_zero)
         r.range = le_zero;
      else if (left.range == ne_zero)
         r.range = unknown;

      break;
   }

   case nir_op_flt:
   case nir_op_fge:
   case nir_op_feq:
   case nir_op_fneu:
   case nir_op_ilt:
   case nir_op_ige:
   case nir_op_ieq:
   case nir_op_ine:
   case nir_op_ult:
   case nir_op_uge:
      /* Boolean results are 0 or -1. */
      r = (struct ssa_result_range){le_zero, false};
      break;

   case nir_op_fpow: {
      /* Due to flush-to-zero semantics of floating-point numbers with very
       * small magnitudes, we can never really be sure a result will be
       * non-zero.
       *
       * NIR uses pow() and powf() to constant evaluate nir_op_fpow.  The man
       * page for that function says:
       *
       *    If y is 0, the result is 1.0 (even if x is a NaN).
       *
       * gt_zero: pow(*, eq_zero)
       *        | pow(eq_zero, lt_zero)   # 0^-y = +inf
       *        | pow(eq_zero, le_zero)   # 0^-y = +inf or 0^0 = 1.0
       *        ;
       *
       * eq_zero: pow(eq_zero, gt_zero)
       *        ;
       *
       * ge_zero: pow(gt_zero, gt_zero)
       *        | pow(gt_zero, ge_zero)
       *        | pow(gt_zero, lt_zero)
       *        | pow(gt_zero, le_zero)
       *        | pow(gt_zero, ne_zero)
       *        | pow(gt_zero, unknown)
       *        | pow(ge_zero, gt_zero)
       *        | pow(ge_zero, ge_zero)
       *        | pow(ge_zero, lt_zero)
       *        | pow(ge_zero, le_zero)
       *        | pow(ge_zero, ne_zero)
       *        | pow(ge_zero, unknown)
       *        | pow(eq_zero, ge_zero)  # 0^0 = 1.0 or 0^+y = 0.0
       *        | pow(eq_zero, ne_zero)  # 0^-y = +inf or 0^+y = 0.0
       *        | pow(eq_zero, unknown)  # union of all other y cases
       *        ;
       *
       * All other cases are unknown.
       *
       * We could do better if the right operand is a constant, integral
       * value.
       */
      static const enum ssa_ranges table[last_range + 1][last_range + 1] = {
         /* left\right   unknown  lt_zero  le_zero  gt_zero  ge_zero  ne_zero  eq_zero */
         /* unknown */ { _______, _______, _______, _______, _______, _______, gt_zero },
         /* lt_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
         /* le_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
         /* gt_zero */ { ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, gt_zero },
         /* ge_zero */ { ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, ge_zero, gt_zero },
         /* ne_zero */ { _______, _______, _______, _______, _______, _______, gt_zero },
         /* eq_zero */ { ge_zero, gt_zero, gt_zero, eq_zero, ge_zero, ge_zero, gt_zero },
      };

      const struct ssa_result_range left =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range right =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));

      ASSERT_UNION_OF_DISJOINT_MATCHES_UNKNOWN_2_SOURCE(table);
      ASSERT_UNION_OF_EQ_AND_STRICT_INEQ_MATCHES_NONSTRICT_2_SOURCE(table);

      r.is_integral = left.is_integral && right.is_integral &&
                      is_not_negative(right.range);
      r.range = table[left.range][right.range];
      break;
   }

   case nir_op_ffma: {
      const struct ssa_result_range first =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range second =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
      const struct ssa_result_range third =
         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));

      r.is_integral = first.is_integral && second.is_integral &&
                      third.is_integral;

      enum ssa_ranges fmul_range;

      if (first.range != eq_zero && nir_alu_srcs_equal(alu, alu, 0, 1)) {
         /* See handling of nir_op_fmul for explanation of why ge_zero is the
          * range.
          */
         fmul_range = ge_zero;
      } else if (first.range != eq_zero && nir_alu_srcs_negative_equal(alu, alu, 0, 1)) {
         /* -x * x => le_zero */
         fmul_range = le_zero;
      } else
         fmul_range = fmul_table[first.range][second.range];

      r.range = fadd_table[fmul_range][third.range];
      break;
   }
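
   /* Worked example of the ffma decomposition: for ffma(x, x, c) with c
    * known gt_zero, the multiply contributes ge_zero (x*x may flush to
    * zero), and fadd_table[ge_zero][gt_zero] == gt_zero, so the sum is known
    * strictly positive.
    */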

   case nir_op_flrp: {
      const struct ssa_result_range first =
         analyze_expression(alu, 0, ht, nir_alu_src_type(alu, 0));
      const struct ssa_result_range second =
         analyze_expression(alu, 1, ht, nir_alu_src_type(alu, 1));
      const struct ssa_result_range third =
         analyze_expression(alu, 2, ht, nir_alu_src_type(alu, 2));

      r.is_integral = first.is_integral && second.is_integral &&
                      third.is_integral;

      /* Decompose the flrp to first + third * (second + -first) */
      const enum ssa_ranges inner_fadd_range =
         fadd_table[second.range][fneg_table[first.range]];

      const enum ssa_ranges fmul_range =
         fmul_table[third.range][inner_fadd_range];

      r.range = fadd_table[first.range][fmul_range];
      break;
   }

   default:
      r = (struct ssa_result_range){unknown, false};
      break;
   }

   if (r.range == eq_zero)
      r.is_integral = true;

   _mesa_hash_table_insert(ht, pack_key(alu, use_type), pack_data(r));
   return r;
}

#undef _______

struct ssa_result_range
nir_analyze_range(struct hash_table *range_ht,
                  const nir_alu_instr *instr, unsigned src)
{
   return analyze_expression(instr, src, range_ht,
                             nir_alu_src_type(instr, src));
}
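
/* Minimal usage sketch (hypothetical caller; any hash table with pointer
 * keys works as the memoization cache, e.g. one made with
 * _mesa_pointer_hash_table_create()):
 *
 *    struct hash_table *ht = _mesa_pointer_hash_table_create(NULL);
 *    const struct ssa_result_range r = nir_analyze_range(ht, alu, 0);
 *
 *    if (r.range == gt_zero || r.range == ge_zero || r.range == eq_zero)
 *       ... source 0 of alu is known to be non-negative ...
 *
 *    _mesa_hash_table_destroy(ht, NULL);
 */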

static uint32_t bitmask(uint32_t size) {
   return size >= 32 ? 0xffffffffu : ((uint32_t)1 << size) - 1u;
}

static uint64_t mul_clamp(uint32_t a, uint32_t b)
{
   if (a != 0 && (a * b) / a != b)
      return (uint64_t)UINT32_MAX + 1;
   else
      return a * b;
}
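
/* For example, bitmask(8) == 0xffu and bitmask(32) == 0xffffffffu, while
 * mul_clamp(0x10000, 0x10000) detects that the 32-bit product wraps to 0 and
 * returns (uint64_t)UINT32_MAX + 1 instead.
 */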

/* recursively gather at most "buf_size" phi/bcsel sources */
static unsigned
search_phi_bcsel(nir_ssa_scalar scalar, nir_ssa_scalar *buf, unsigned buf_size, struct set *visited)
{
   if (_mesa_set_search(visited, scalar.def))
      return 0;
   _mesa_set_add(visited, scalar.def);

   if (scalar.def->parent_instr->type == nir_instr_type_phi) {
      nir_phi_instr *phi = nir_instr_as_phi(scalar.def->parent_instr);
      unsigned num_sources_left = exec_list_length(&phi->srcs);
      if (buf_size >= num_sources_left) {
         unsigned total_added = 0;
         nir_foreach_phi_src(src, phi) {
            unsigned added = search_phi_bcsel(
               (nir_ssa_scalar){src->src.ssa, 0}, buf + total_added, buf_size - num_sources_left, visited);
            buf_size -= added;
            total_added += added;
            num_sources_left--;
         }
         return total_added;
      }
   }

   if (nir_ssa_scalar_is_alu(scalar)) {
      nir_op op = nir_ssa_scalar_alu_op(scalar);

      if ((op == nir_op_bcsel || op == nir_op_b32csel) && buf_size >= 2) {
         /* Chase the two selected values (sources 1 and 2); source 0 is the
          * boolean condition, which does not contribute a value.
          */
         nir_ssa_scalar src1 = nir_ssa_scalar_chase_alu_src(scalar, 1);
         nir_ssa_scalar src2 = nir_ssa_scalar_chase_alu_src(scalar, 2);

         unsigned added = search_phi_bcsel(src1, buf, buf_size - 1, visited);
         buf_size -= added;
         added += search_phi_bcsel(src2, buf + added, buf_size, visited);
         return added;
      }
   }

   buf[0] = scalar;
   return 1;
}
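
/* For example, for v = phi(a, bcsel(c, b, d)) this fills buf with the leaf
 * scalars { a, b, d } (assuming they fit), so a caller can bound v by the
 * maximum of the bounds of the gathered scalars.
 */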

static nir_variable *
lookup_input(nir_shader *shader, unsigned driver_location)
{
   return nir_find_variable_with_driver_location(shader, nir_var_shader_in,
                                                 driver_location);
}

uint32_t
nir_unsigned_upper_bound(nir_shader *shader, struct hash_table *range_ht,
                         nir_ssa_scalar scalar,
                         const nir_unsigned_upper_bound_config *config)
{
   assert(scalar.def->bit_size <= 32);

   if (nir_ssa_scalar_is_const(scalar))
      return nir_ssa_scalar_as_uint(scalar);

   /* keys can't be 0, so we have to add 1 to the index */
   void *key = (void*)(((uintptr_t)(scalar.def->index + 1) << 4) | scalar.comp);
   struct hash_entry *he = _mesa_hash_table_search(range_ht, key);
   if (he != NULL)
      return (uintptr_t)he->data;

   uint32_t max = bitmask(scalar.def->bit_size);

   if (scalar.def->parent_instr->type == nir_instr_type_intrinsic) {
      uint32_t res = max;
      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(scalar.def->parent_instr);
      switch (intrin->intrinsic) {
      case nir_intrinsic_load_local_invocation_index:
         if (shader->info.cs.local_size_variable) {
            res = config->max_work_group_invocations - 1;
         } else {
            res = (shader->info.cs.local_size[0] *
                   shader->info.cs.local_size[1] *
                   shader->info.cs.local_size[2]) - 1u;
         }
         break;
      case nir_intrinsic_load_local_invocation_id:
         if (shader->info.cs.local_size_variable)
            res = config->max_work_group_size[scalar.comp] - 1u;
         else
            res = shader->info.cs.local_size[scalar.comp] - 1u;
         break;
      case nir_intrinsic_load_work_group_id:
         res = config->max_work_group_count[scalar.comp] - 1u;
         break;
      case nir_intrinsic_load_num_work_groups:
         res = config->max_work_group_count[scalar.comp];
         break;
      case nir_intrinsic_load_global_invocation_id:
         if (shader->info.cs.local_size_variable) {
            res = mul_clamp(config->max_work_group_size[scalar.comp],
                            config->max_work_group_count[scalar.comp]) - 1u;
         } else {
            res = (shader->info.cs.local_size[scalar.comp] *
                   config->max_work_group_count[scalar.comp]) - 1u;
         }
         break;
      case nir_intrinsic_load_subgroup_invocation:
      case nir_intrinsic_first_invocation:
      case nir_intrinsic_mbcnt_amd:
         res = config->max_subgroup_size - 1;
         break;
      case nir_intrinsic_load_subgroup_size:
         res = config->max_subgroup_size;
         break;
      case nir_intrinsic_load_subgroup_id:
      case nir_intrinsic_load_num_subgroups: {
         uint32_t work_group_size = config->max_work_group_invocations;
         if (!shader->info.cs.local_size_variable) {
            work_group_size = shader->info.cs.local_size[0] *
                              shader->info.cs.local_size[1] *
                              shader->info.cs.local_size[2];
         }
         res = (work_group_size + config->min_subgroup_size - 1) / config->min_subgroup_size;
         if (intrin->intrinsic == nir_intrinsic_load_subgroup_id)
            res--;
         break;
      }
      case nir_intrinsic_load_input: {
         if (shader->info.stage == MESA_SHADER_VERTEX && nir_src_is_const(intrin->src[0])) {
            nir_variable *var = lookup_input(shader, nir_intrinsic_base(intrin));
            if (var) {
               int loc = var->data.location - VERT_ATTRIB_GENERIC0;
               if (loc >= 0)
                  res = config->vertex_attrib_max[loc];
            }
         }
         break;
      }
      case nir_intrinsic_reduce:
      case nir_intrinsic_inclusive_scan:
      case nir_intrinsic_exclusive_scan: {
         nir_op op = nir_intrinsic_reduction_op(intrin);
         if (op == nir_op_umin || op == nir_op_umax || op == nir_op_imin || op == nir_op_imax)
            res = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
         break;
      }
      case nir_intrinsic_read_first_invocation:
      case nir_intrinsic_read_invocation:
      case nir_intrinsic_shuffle:
      case nir_intrinsic_shuffle_xor:
      case nir_intrinsic_shuffle_up:
      case nir_intrinsic_shuffle_down:
      case nir_intrinsic_quad_broadcast:
      case nir_intrinsic_quad_swap_horizontal:
      case nir_intrinsic_quad_swap_vertical:
      case nir_intrinsic_quad_swap_diagonal:
      case nir_intrinsic_quad_swizzle_amd:
      case nir_intrinsic_masked_swizzle_amd:
         res = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
         break;
      case nir_intrinsic_write_invocation_amd: {
         uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[0].ssa, 0}, config);
         uint32_t src1 = nir_unsigned_upper_bound(shader, range_ht, (nir_ssa_scalar){intrin->src[1].ssa, 0}, config);
         res = MAX2(src0, src1);
         break;
      }
      default:
         break;
      }
      if (res != max)
         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
      return res;
   }
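
   /* For example (illustrative sizes): with a fixed local size of 64x1x1,
    * load_local_invocation_index is bounded by 64*1*1 - 1 = 63, which is
    * often enough to prove that derived address arithmetic cannot overflow.
    */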

   if (scalar.def->parent_instr->type == nir_instr_type_phi) {
      bool cyclic = false;
      nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
         if (nir_block_dominates(scalar.def->parent_instr->block, src->pred)) {
            cyclic = true;
            break;
         }
      }

      uint32_t res = 0;
      if (cyclic) {
         _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)max);

         struct set *visited = _mesa_pointer_set_create(NULL);
         nir_ssa_scalar defs[64];
         unsigned def_count = search_phi_bcsel(scalar, defs, 64, visited);
         _mesa_set_destroy(visited, NULL);

         for (unsigned i = 0; i < def_count; i++)
            res = MAX2(res, nir_unsigned_upper_bound(shader, range_ht, defs[i], config));
      } else {
         nir_foreach_phi_src(src, nir_instr_as_phi(scalar.def->parent_instr)) {
            res = MAX2(res, nir_unsigned_upper_bound(
               shader, range_ht, (nir_ssa_scalar){src->src.ssa, 0}, config));
         }
      }

      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
      return res;
   }

   if (nir_ssa_scalar_is_alu(scalar)) {
      nir_op op = nir_ssa_scalar_alu_op(scalar);

      switch (op) {
      case nir_op_umin:
      case nir_op_imin:
      case nir_op_imax:
      case nir_op_umax:
      case nir_op_iand:
      case nir_op_ior:
      case nir_op_ixor:
      case nir_op_ishl:
      case nir_op_imul:
      case nir_op_ushr:
      case nir_op_ishr:
      case nir_op_iadd:
      case nir_op_umod:
      case nir_op_udiv:
      case nir_op_bcsel:
      case nir_op_b32csel:
      case nir_op_ubfe:
      case nir_op_bfm:
      case nir_op_f2u32:
      case nir_op_fmul:
         break;
      default:
         return max;
      }

      uint32_t src0 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 0), config);
      uint32_t src1 = max, src2 = max;
      if (nir_op_infos[op].num_inputs > 1)
         src1 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 1), config);
      if (nir_op_infos[op].num_inputs > 2)
         src2 = nir_unsigned_upper_bound(shader, range_ht, nir_ssa_scalar_chase_alu_src(scalar, 2), config);

      uint32_t res = max;
      switch (op) {
      case nir_op_umin:
         res = src0 < src1 ? src0 : src1;
         break;
      case nir_op_imin:
      case nir_op_imax:
      case nir_op_umax:
         res = src0 > src1 ? src0 : src1;
         break;
      case nir_op_iand:
         res = bitmask(util_last_bit64(src0)) & bitmask(util_last_bit64(src1));
         break;
      case nir_op_ior:
      case nir_op_ixor:
         res = bitmask(util_last_bit64(src0)) | bitmask(util_last_bit64(src1));
         break;
      case nir_op_ishl:
         if (util_last_bit64(src0) + src1 > scalar.def->bit_size)
            res = max; /* overflow */
         else
            res = src0 << MIN2(src1, scalar.def->bit_size - 1u);
         break;
      case nir_op_imul:
         if (src0 != 0 && (src0 * src1) / src0 != src1)
            res = max;
         else
            res = src0 * src1;
         break;
      case nir_op_ushr: {
         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
         if (nir_ssa_scalar_is_const(src1_scalar))
            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
         else
            res = src0;
         break;
      }
      case nir_op_ishr: {
         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
         if (src0 <= 2147483647 && nir_ssa_scalar_is_const(src1_scalar))
            res = src0 >> nir_ssa_scalar_as_uint(src1_scalar);
         else
            res = src0;
         break;
      }
      case nir_op_iadd:
         if (src0 + src1 < src0)
            res = max; /* overflow */
         else
            res = src0 + src1;
         break;
      case nir_op_umod:
         res = src1 ? src1 - 1 : 0;
         break;
      case nir_op_udiv: {
         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
         if (nir_ssa_scalar_is_const(src1_scalar))
            res = nir_ssa_scalar_as_uint(src1_scalar) ? src0 / nir_ssa_scalar_as_uint(src1_scalar) : 0;
         else
            res = src0;
         break;
      }
      case nir_op_bcsel:
      case nir_op_b32csel:
         res = src1 > src2 ? src1 : src2;
         break;
      case nir_op_ubfe:
         res = bitmask(MIN2(src2, scalar.def->bit_size));
         break;
      case nir_op_bfm: {
         nir_ssa_scalar src1_scalar = nir_ssa_scalar_chase_alu_src(scalar, 1);
         if (nir_ssa_scalar_is_const(src1_scalar)) {
            src0 = MIN2(src0, 31);
            src1 = nir_ssa_scalar_as_uint(src1_scalar) & 0x1fu;
            res = bitmask(src0) << src1;
         } else {
            src0 = MIN2(src0, 31);
            src1 = MIN2(src1, 31);
            res = bitmask(MIN2(src0 + src1, 32));
         }
         break;
      }
      /* limited floating-point support for f2u32(fmul(load_input(), <constant>)) */
      case nir_op_f2u32:
         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
         if (src0 < 0x7f800000u) {
            float val;
            memcpy(&val, &src0, 4);
            res = (uint32_t)val;
         }
         break;
      case nir_op_fmul:
         /* infinity/NaN starts at 0x7f800000u, negative numbers at 0x80000000 */
         if (src0 < 0x7f800000u && src1 < 0x7f800000u) {
            float src0_f, src1_f;
            memcpy(&src0_f, &src0, 4);
            memcpy(&src1_f, &src1, 4);
            /* not a proper rounding-up multiplication, but should be good enough */
            float max_f = ceilf(src0_f) * ceilf(src1_f);
            memcpy(&res, &max_f, 4);
         }
         break;
      default:
         res = max;
         break;
      }
      _mesa_hash_table_insert(range_ht, key, (void*)(uintptr_t)res);
      return res;
   }

   return max;
}

bool
nir_addition_might_overflow(nir_shader *shader, struct hash_table *range_ht,
                            nir_ssa_scalar ssa, unsigned const_val,
                            const nir_unsigned_upper_bound_config *config)
{
   if (nir_ssa_scalar_is_alu(ssa)) {
      nir_op alu_op = nir_ssa_scalar_alu_op(ssa);

      /* iadd(imul(a, #b), #c) */
      if (alu_op == nir_op_imul || alu_op == nir_op_ishl) {
         nir_ssa_scalar mul_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
         nir_ssa_scalar mul_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
         uint32_t stride = 1;
         if (nir_ssa_scalar_is_const(mul_src0))
            stride = nir_ssa_scalar_as_uint(mul_src0);
         else if (nir_ssa_scalar_is_const(mul_src1))
            stride = nir_ssa_scalar_as_uint(mul_src1);

         if (alu_op == nir_op_ishl)
            stride = 1u << (stride % 32u);

         if (!stride || const_val <= UINT32_MAX - (UINT32_MAX / stride * stride))
            return false;
      }
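
      /* Example: with stride == 4 the multiply only produces multiples of 4,
       * which are at most UINT32_MAX / 4 * 4 == 0xfffffffc, so adding a
       * const_val of at most 3 can never wrap around 32 bits.
       */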

      /* iadd(iand(a, #b), #c) */
      if (alu_op == nir_op_iand) {
         nir_ssa_scalar and_src0 = nir_ssa_scalar_chase_alu_src(ssa, 0);
         nir_ssa_scalar and_src1 = nir_ssa_scalar_chase_alu_src(ssa, 1);
         uint32_t mask = 0xffffffff;
         if (nir_ssa_scalar_is_const(and_src0))
            mask = nir_ssa_scalar_as_uint(and_src0);
         else if (nir_ssa_scalar_is_const(and_src1))
            mask = nir_ssa_scalar_as_uint(and_src1);
         if (mask == 0 || const_val < (1u << (ffs(mask) - 1)))
            return false;
      }
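
      /* Example: with mask == 0xfffffff0 the masked value always has its low
       * four bits clear, so adding a const_val below 1 << 4 == 16 only fills
       * those bits and cannot carry past UINT32_MAX.
       */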
   }

   uint32_t ub = nir_unsigned_upper_bound(shader, range_ht, ssa, config);
   return const_val + ub < const_val;
}
1497