1 /*
2  * Copyright © 2014 Intel Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Jason Ekstrand (jason@jlekstrand.net)
25  *
26  */
27 
28 #include <inttypes.h>
29 #include "nir_search.h"
30 #include "nir_builder.h"
31 #include "nir_worklist.h"
32 #include "util/half_float.h"
33 
34 /* This should be the same as nir_search_max_comm_ops in nir_algebraic.py. */
35 #define NIR_SEARCH_MAX_COMM_OPS 8
36 
/* Transient state for matching one search expression against an instruction
 * and building its replacement.
 */
struct match_state {
   /* True once any inexact ("~") sub-expression has been matched. */
   bool inexact_match;
   /* True once any matched ALU instruction was marked exact (and the
    * pattern does not ignore exactness).
    */
   bool has_exact_alu;
   /* Bitfield, one bit per commutative sub-expression: whether its first
    * two sources are swapped for this match attempt.
    */
   uint8_t comm_op_direction;
   /* Bitfield of variable indices already bound in variables[]. */
   unsigned variables_seen;

   /* Used for running the automaton on newly-constructed instructions. */
   struct util_dynarray *states;
   const struct per_op_table *pass_op_table;
   const nir_algebraic_table *table;

   /* Source bound to each search variable while matching. */
   nir_alu_src variables[NIR_SEARCH_MAX_VARIABLES];
   /* Passed through to the table's variable condition callbacks. */
   struct hash_table *range_ht;
};
51 
/* Forward declarations: matching recurses through the expression tree, and
 * the automaton is re-run on instructions built for replacements.
 */
static bool
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state);
static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                        const struct per_op_table *pass_op_table);

/* The identity swizzle: component i reads component i. */
static const uint8_t identity_swizzle[NIR_MAX_VEC_COMPONENTS] =
{
    0,  1,  2,  3,
    4,  5,  6,  7,
    8,  9, 10, 11,
   12, 13, 14, 15,
};
67 
68 /**
69  * Check if a source produces a value of the given type.
70  *
71  * Used for satisfying 'a@type' constraints.
72  */
73 static bool
src_is_type(nir_src src,nir_alu_type type)74 src_is_type(nir_src src, nir_alu_type type)
75 {
76    assert(type != nir_type_invalid);
77 
78    if (!src.is_ssa)
79       return false;
80 
81    if (src.ssa->parent_instr->type == nir_instr_type_alu) {
82       nir_alu_instr *src_alu = nir_instr_as_alu(src.ssa->parent_instr);
83       nir_alu_type output_type = nir_op_infos[src_alu->op].output_type;
84 
85       if (type == nir_type_bool) {
86          switch (src_alu->op) {
87          case nir_op_iand:
88          case nir_op_ior:
89          case nir_op_ixor:
90             return src_is_type(src_alu->src[0].src, nir_type_bool) &&
91                    src_is_type(src_alu->src[1].src, nir_type_bool);
92          case nir_op_inot:
93             return src_is_type(src_alu->src[0].src, nir_type_bool);
94          default:
95             break;
96          }
97       }
98 
99       return nir_alu_type_get_base_type(output_type) == type;
100    } else if (src.ssa->parent_instr->type == nir_instr_type_intrinsic) {
101       nir_intrinsic_instr *intr = nir_instr_as_intrinsic(src.ssa->parent_instr);
102 
103       if (type == nir_type_bool) {
104          return intr->intrinsic == nir_intrinsic_load_front_face ||
105                 intr->intrinsic == nir_intrinsic_load_helper_invocation;
106       }
107    }
108 
109    /* don't know */
110    return false;
111 }
112 
/* Check whether NIR opcode \p nop satisfies search opcode \p sop.
 *
 * Search opcodes <= nir_last_opcode are plain NIR opcodes and must match
 * exactly.  The values above that range are bit-size-agnostic conversion
 * opcodes that match any of the sized NIR conversion variants.
 */
static bool
nir_op_matches_search_op(nir_op nop, uint16_t sop)
{
   if (sop <= nir_last_opcode)
      return nop == sop;

/* Float-destination conversions have 16/32/64-bit variants. */
#define MATCH_FCONV_CASE(op) \
   case nir_search_op_##op: \
      return nop == nir_op_##op##16 || \
             nop == nir_op_##op##32 || \
             nop == nir_op_##op##64;

/* Integer-destination conversions have 8/16/32/64-bit variants. */
#define MATCH_ICONV_CASE(op) \
   case nir_search_op_##op: \
      return nop == nir_op_##op##8 || \
             nop == nir_op_##op##16 || \
             nop == nir_op_##op##32 || \
             nop == nir_op_##op##64;

/* Boolean-destination conversions have 1-bit and 32-bit variants. */
#define MATCH_BCONV_CASE(op) \
   case nir_search_op_##op: \
      return nop == nir_op_##op##1 || \
             nop == nir_op_##op##32;

   switch (sop) {
   MATCH_FCONV_CASE(i2f)
   MATCH_FCONV_CASE(u2f)
   MATCH_FCONV_CASE(f2f)
   MATCH_ICONV_CASE(f2u)
   MATCH_ICONV_CASE(f2i)
   MATCH_ICONV_CASE(u2u)
   MATCH_ICONV_CASE(i2i)
   MATCH_FCONV_CASE(b2f)
   MATCH_ICONV_CASE(b2i)
   MATCH_BCONV_CASE(i2b)
   MATCH_BCONV_CASE(f2b)
   default:
      unreachable("Invalid nir_search_op");
   }

#undef MATCH_FCONV_CASE
#undef MATCH_ICONV_CASE
#undef MATCH_BCONV_CASE
}
157 
/* Map a NIR opcode to its search opcode.
 *
 * Sized conversion opcodes collapse to the corresponding bit-size-agnostic
 * nir_search_op; every other opcode maps to itself.
 */
uint16_t
nir_search_op_for_nir_op(nir_op nop)
{
#define MATCH_FCONV_CASE(op) \
   case nir_op_##op##16: \
   case nir_op_##op##32: \
   case nir_op_##op##64: \
      return nir_search_op_##op;

#define MATCH_ICONV_CASE(op) \
   case nir_op_##op##8: \
   case nir_op_##op##16: \
   case nir_op_##op##32: \
   case nir_op_##op##64: \
      return nir_search_op_##op;

#define MATCH_BCONV_CASE(op) \
   case nir_op_##op##1: \
   case nir_op_##op##32: \
      return nir_search_op_##op;


   switch (nop) {
   MATCH_FCONV_CASE(i2f)
   MATCH_FCONV_CASE(u2f)
   MATCH_FCONV_CASE(f2f)
   MATCH_ICONV_CASE(f2u)
   MATCH_ICONV_CASE(f2i)
   MATCH_ICONV_CASE(u2u)
   MATCH_ICONV_CASE(i2i)
   MATCH_FCONV_CASE(b2f)
   MATCH_ICONV_CASE(b2i)
   MATCH_BCONV_CASE(i2b)
   MATCH_BCONV_CASE(f2b)
   default:
      /* Non-conversion opcodes are their own search opcode. */
      return nop;
   }

#undef MATCH_FCONV_CASE
#undef MATCH_ICONV_CASE
#undef MATCH_BCONV_CASE
}
200 
/* Map a search opcode back to a concrete NIR opcode.
 *
 * For bit-size-agnostic conversion search opcodes, \p bit_size selects the
 * sized variant; plain NIR opcodes pass through unchanged.
 */
static nir_op
nir_op_for_search_op(uint16_t sop, unsigned bit_size)
{
   if (sop <= nir_last_opcode)
      return sop;

#define RET_FCONV_CASE(op) \
   case nir_search_op_##op: \
      switch (bit_size) { \
      case 16: return nir_op_##op##16; \
      case 32: return nir_op_##op##32; \
      case 64: return nir_op_##op##64; \
      default: unreachable("Invalid bit size"); \
      }

#define RET_ICONV_CASE(op) \
   case nir_search_op_##op: \
      switch (bit_size) { \
      case 8:  return nir_op_##op##8; \
      case 16: return nir_op_##op##16; \
      case 32: return nir_op_##op##32; \
      case 64: return nir_op_##op##64; \
      default: unreachable("Invalid bit size"); \
      }

#define RET_BCONV_CASE(op) \
   case nir_search_op_##op: \
      switch (bit_size) { \
      case 1: return nir_op_##op##1; \
      case 32: return nir_op_##op##32; \
      default: unreachable("Invalid bit size"); \
      }

   switch (sop) {
   RET_FCONV_CASE(i2f)
   RET_FCONV_CASE(u2f)
   RET_FCONV_CASE(f2f)
   RET_ICONV_CASE(f2u)
   RET_ICONV_CASE(f2i)
   RET_ICONV_CASE(u2u)
   RET_ICONV_CASE(i2i)
   RET_FCONV_CASE(b2f)
   RET_ICONV_CASE(b2i)
   RET_BCONV_CASE(i2b)
   RET_BCONV_CASE(f2b)
   default:
      unreachable("Invalid nir_search_op");
   }

#undef RET_FCONV_CASE
#undef RET_ICONV_CASE
#undef RET_BCONV_CASE
}
254 
/* Match one search value (sub-expression, variable, or constant) against
 * source \p src of \p instr.
 *
 * \p num_components and \p swizzle describe which components of the source
 * the enclosing pattern reads.  On success, variable bindings are recorded
 * in \p state; returns false as soon as any constraint fails.
 */
static bool
match_value(const nir_algebraic_table *table,
            const nir_search_value *value, nir_alu_instr *instr, unsigned src,
            unsigned num_components, const uint8_t *swizzle,
            struct match_state *state)
{
   uint8_t new_swizzle[NIR_MAX_VEC_COMPONENTS];

   /* Searching only works on SSA values because, if it's not SSA, we can't
    * know if the value changed between one instance of that value in the
    * expression and another.  Also, the replace operation will place reads of
    * that value right before the last instruction in the expression we're
    * replacing so those reads will happen after the original reads and may
    * not be valid if they're register reads.
    */
   assert(instr->src[src].src.is_ssa);

   /* If the source is an explicitly sized source, then we need to reset
    * both the number of components and the swizzle.
    */
   if (nir_op_infos[instr->op].input_sizes[src] != 0) {
      num_components = nir_op_infos[instr->op].input_sizes[src];
      swizzle = identity_swizzle;
   }

   /* Compose the instruction's own source swizzle with the pattern's. */
   for (unsigned i = 0; i < num_components; ++i)
      new_swizzle[i] = instr->src[src].swizzle[swizzle[i]];

   /* If the value has a specific bit size and it doesn't match, bail */
   if (value->bit_size > 0 &&
       nir_src_bit_size(instr->src[src].src) != value->bit_size)
      return false;

   switch (value->type) {
   case nir_search_value_expression:
      /* Sub-expressions can only match ALU instructions. */
      if (instr->src[src].src.ssa->parent_instr->type != nir_instr_type_alu)
         return false;

      return match_expression(table, nir_search_value_as_expression(value),
                              nir_instr_as_alu(instr->src[src].src.ssa->parent_instr),
                              num_components, new_swizzle, state);

   case nir_search_value_variable: {
      nir_search_variable *var = nir_search_value_as_variable(value);
      assert(var->variable < NIR_SEARCH_MAX_VARIABLES);

      if (state->variables_seen & (1 << var->variable)) {
         /* Already bound: this occurrence must be the same SSA def with
          * the same per-component swizzle.
          */
         if (state->variables[var->variable].src.ssa != instr->src[src].src.ssa)
            return false;

         assert(!instr->src[src].abs && !instr->src[src].negate);

         for (unsigned i = 0; i < num_components; ++i) {
            if (state->variables[var->variable].swizzle[i] != new_swizzle[i])
               return false;
         }

         return true;
      } else {
         /* First occurrence: check the variable's constraints, then bind. */
         if (var->is_constant &&
             instr->src[src].src.ssa->parent_instr->type != nir_instr_type_load_const)
            return false;

         if (var->cond_index != -1 && !table->variable_cond[var->cond_index](state->range_ht, instr,
                                                                             src, num_components, new_swizzle))
            return false;

         if (var->type != nir_type_invalid &&
             !src_is_type(instr->src[src].src, var->type))
            return false;

         state->variables_seen |= (1 << var->variable);
         state->variables[var->variable].src = instr->src[src].src;
         state->variables[var->variable].abs = false;
         state->variables[var->variable].negate = false;

         /* Record the composed swizzle; zero-fill unused components. */
         for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; ++i) {
            if (i < num_components)
               state->variables[var->variable].swizzle[i] = new_swizzle[i];
            else
               state->variables[var->variable].swizzle[i] = 0;
         }

         return true;
      }
   }

   case nir_search_value_constant: {
      nir_search_constant *const_val = nir_search_value_as_constant(value);

      if (!nir_src_is_const(instr->src[src].src))
         return false;

      switch (const_val->type) {
      case nir_type_float: {
         nir_load_const_instr *const load =
            nir_instr_as_load_const(instr->src[src].src.ssa->parent_instr);

         /* There are 8-bit and 1-bit integer types, but there are no 8-bit or
          * 1-bit float types.  This prevents potential assertion failures in
          * nir_src_comp_as_float.
          */
         if (load->def.bit_size < 16)
            return false;

         for (unsigned i = 0; i < num_components; ++i) {
            double val = nir_src_comp_as_float(instr->src[src].src,
                                               new_swizzle[i]);
            if (val != const_val->data.d)
               return false;
         }
         return true;
      }

      case nir_type_int:
      case nir_type_uint:
      case nir_type_bool: {
         /* Compare only the bits that exist at the source's bit size. */
         unsigned bit_size = nir_src_bit_size(instr->src[src].src);
         uint64_t mask = u_uintN_max(bit_size);
         for (unsigned i = 0; i < num_components; ++i) {
            uint64_t val = nir_src_comp_as_uint(instr->src[src].src,
                                                new_swizzle[i]);
            if ((val & mask) != (const_val->data.u & mask))
               return false;
         }
         return true;
      }

      default:
         unreachable("Invalid alu source type");
      }
   }

   default:
      unreachable("Invalid search value type");
   }
}
392 
/* Try to match search expression \p expr against ALU instruction \p instr.
 *
 * \p num_components and \p swizzle describe which destination components
 * the enclosing pattern reads.  Bindings and exact/inexact bookkeeping are
 * accumulated in \p state.
 */
static bool
match_expression(const nir_algebraic_table *table, const nir_search_expression *expr, nir_alu_instr *instr,
                 unsigned num_components, const uint8_t *swizzle,
                 struct match_state *state)
{
   /* Per-expression condition callback supplied by the algebraic table. */
   if (expr->cond_index != -1 && !table->expression_cond[expr->cond_index](instr))
      return false;

   if (!nir_op_matches_search_op(instr->op, expr->opcode))
      return false;

   assert(instr->dest.dest.is_ssa);

   if (expr->value.bit_size > 0 &&
       instr->dest.dest.ssa.bit_size != expr->value.bit_size)
      return false;

   /* An inexact ("~") pattern must never rewrite an exact instruction,
    * unless the pattern explicitly ignores exactness.
    */
   state->inexact_match = expr->inexact || state->inexact_match;
   state->has_exact_alu = (instr->exact && !expr->ignore_exact) || state->has_exact_alu;
   if (state->inexact_match && state->has_exact_alu)
      return false;

   assert(!instr->dest.saturate);
   assert(nir_op_infos[instr->op].num_inputs > 0);

   /* If we have an explicitly sized destination, we can only handle the
    * identity swizzle.  While dot(vec3(a, b, c).zxy) is a valid
    * expression, we don't have the information right now to propagate that
    * swizzle through.  We can only properly propagate swizzles if the
    * instruction is vectorized.
    */
   if (nir_op_infos[instr->op].output_size != 0) {
      for (unsigned i = 0; i < num_components; i++) {
         if (swizzle[i] != i)
            return false;
      }
   }

   /* If this is a commutative expression and it's one of the first few, look
    * up its direction for the current search operation.  We'll use that value
    * to possibly flip the sources for the match.
    */
   unsigned comm_op_flip =
      (expr->comm_expr_idx >= 0 &&
       expr->comm_expr_idx < NIR_SEARCH_MAX_COMM_OPS) ?
      ((state->comm_op_direction >> expr->comm_expr_idx) & 1) : 0;

   bool matched = true;
   for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) {
      /* 2src_commutative instructions that have 3 sources are only commutative
       * in the first two sources.  Source 2 is always source 2.
       */
      if (!match_value(table, &state->table->values[expr->srcs[i]].value, instr,
                       i < 2 ? i ^ comm_op_flip : i,
                       num_components, swizzle, state)) {
         matched = false;
         break;
      }
   }

   return matched;
}
455 
456 static unsigned
replace_bitsize(const nir_search_value * value,unsigned search_bitsize,struct match_state * state)457 replace_bitsize(const nir_search_value *value, unsigned search_bitsize,
458                 struct match_state *state)
459 {
460    if (value->bit_size > 0)
461       return value->bit_size;
462    if (value->bit_size < 0)
463       return nir_src_bit_size(state->variables[-value->bit_size - 1].src);
464    return search_bitsize;
465 }
466 
467 static nir_alu_src
construct_value(nir_builder * build,const nir_search_value * value,unsigned num_components,unsigned search_bitsize,struct match_state * state,nir_instr * instr)468 construct_value(nir_builder *build,
469                 const nir_search_value *value,
470                 unsigned num_components, unsigned search_bitsize,
471                 struct match_state *state,
472                 nir_instr *instr)
473 {
474    switch (value->type) {
475    case nir_search_value_expression: {
476       const nir_search_expression *expr = nir_search_value_as_expression(value);
477       unsigned dst_bit_size = replace_bitsize(value, search_bitsize, state);
478       nir_op op = nir_op_for_search_op(expr->opcode, dst_bit_size);
479 
480       if (nir_op_infos[op].output_size != 0)
481          num_components = nir_op_infos[op].output_size;
482 
483       nir_alu_instr *alu = nir_alu_instr_create(build->shader, op);
484       nir_ssa_dest_init(&alu->instr, &alu->dest.dest, num_components,
485                         dst_bit_size, NULL);
486       alu->dest.write_mask = (1 << num_components) - 1;
487       alu->dest.saturate = false;
488 
489       /* We have no way of knowing what values in a given search expression
490        * map to a particular replacement value.  Therefore, if the
491        * expression we are replacing has any exact values, the entire
492        * replacement should be exact.
493        */
494       alu->exact = state->has_exact_alu || expr->exact;
495 
496       for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
497          /* If the source is an explicitly sized source, then we need to reset
498           * the number of components to match.
499           */
500          if (nir_op_infos[alu->op].input_sizes[i] != 0)
501             num_components = nir_op_infos[alu->op].input_sizes[i];
502 
503          alu->src[i] = construct_value(build, &state->table->values[expr->srcs[i]].value,
504                                        num_components, search_bitsize,
505                                        state, instr);
506       }
507 
508       nir_builder_instr_insert(build, &alu->instr);
509 
510       assert(alu->dest.dest.ssa.index ==
511              util_dynarray_num_elements(state->states, uint16_t));
512       util_dynarray_append(state->states, uint16_t, 0);
513       nir_algebraic_automaton(&alu->instr, state->states, state->pass_op_table);
514 
515       nir_alu_src val;
516       val.src = nir_src_for_ssa(&alu->dest.dest.ssa);
517       val.negate = false;
518       val.abs = false,
519       memcpy(val.swizzle, identity_swizzle, sizeof val.swizzle);
520 
521       return val;
522    }
523 
524    case nir_search_value_variable: {
525       const nir_search_variable *var = nir_search_value_as_variable(value);
526       assert(state->variables_seen & (1 << var->variable));
527 
528       nir_alu_src val = { NIR_SRC_INIT };
529       nir_alu_src_copy(&val, &state->variables[var->variable]);
530       assert(!var->is_constant);
531 
532       for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++)
533          val.swizzle[i] = state->variables[var->variable].swizzle[var->swizzle[i]];
534 
535       return val;
536    }
537 
538    case nir_search_value_constant: {
539       const nir_search_constant *c = nir_search_value_as_constant(value);
540       unsigned bit_size = replace_bitsize(value, search_bitsize, state);
541 
542       nir_ssa_def *cval;
543       switch (c->type) {
544       case nir_type_float:
545          cval = nir_imm_floatN_t(build, c->data.d, bit_size);
546          break;
547 
548       case nir_type_int:
549       case nir_type_uint:
550          cval = nir_imm_intN_t(build, c->data.i, bit_size);
551          break;
552 
553       case nir_type_bool:
554          cval = nir_imm_boolN_t(build, c->data.u, bit_size);
555          break;
556 
557       default:
558          unreachable("Invalid alu source type");
559       }
560 
561       assert(cval->index ==
562              util_dynarray_num_elements(state->states, uint16_t));
563       util_dynarray_append(state->states, uint16_t, 0);
564       nir_algebraic_automaton(cval->parent_instr, state->states,
565                               state->pass_op_table);
566 
567       nir_alu_src val;
568       val.src = nir_src_for_ssa(cval);
569       val.negate = false;
570       val.abs = false,
571       memset(val.swizzle, 0, sizeof val.swizzle);
572 
573       return val;
574    }
575 
576    default:
577       unreachable("Invalid search value type");
578    }
579 }
580 
/* Debug helper: print a search value (constant, variable, or expression)
 * to stderr, recursing through expression sources.  Appends "@N" when the
 * value carries an explicit bit size.
 */
UNUSED static void dump_value(const nir_algebraic_table *table, const nir_search_value *val)
{
   switch (val->type) {
   case nir_search_value_constant: {
      const nir_search_constant *sconst = nir_search_value_as_constant(val);
      switch (sconst->type) {
      case nir_type_float:
         fprintf(stderr, "%f", sconst->data.d);
         break;
      case nir_type_int:
         fprintf(stderr, "%"PRId64, sconst->data.i);
         break;
      case nir_type_uint:
         fprintf(stderr, "0x%"PRIx64, sconst->data.u);
         break;
      case nir_type_bool:
         fprintf(stderr, "%s", sconst->data.u != 0 ? "True" : "False");
         break;
      default:
         unreachable("bad const type");
      }
      break;
   }

   case nir_search_value_variable: {
      const nir_search_variable *var = nir_search_value_as_variable(val);
      /* Constant-only variables are prefixed with '#', as in the rules. */
      if (var->is_constant)
         fprintf(stderr, "#");
      fprintf(stderr, "%c", var->variable + 'a');
      break;
   }

   case nir_search_value_expression: {
      const nir_search_expression *expr = nir_search_value_as_expression(val);
      fprintf(stderr, "(");
      if (expr->inexact)
         fprintf(stderr, "~");
      switch (expr->opcode) {
/* Bit-size-agnostic conversion opcodes have no nir_op_infos entry, so
 * print their names explicitly.
 */
#define CASE(n) \
      case nir_search_op_##n: fprintf(stderr, #n); break;
      CASE(f2b)
      CASE(b2f)
      CASE(b2i)
      CASE(i2b)
      CASE(i2i)
      CASE(f2i)
      CASE(i2f)
#undef CASE
      default:
         fprintf(stderr, "%s", nir_op_infos[expr->opcode].name);
      }

      /* Conversion search opcodes always take one source. */
      unsigned num_srcs = 1;
      if (expr->opcode <= nir_last_opcode)
         num_srcs = nir_op_infos[expr->opcode].num_inputs;

      for (unsigned i = 0; i < num_srcs; i++) {
         fprintf(stderr, " ");
         dump_value(table, &table->values[expr->srcs[i]].value);
      }

      fprintf(stderr, ")");
      break;
   }
   }

   if (val->bit_size > 0)
      fprintf(stderr, "@%d", val->bit_size);
}
650 
/* Run the automaton on every user of \p instr's SSA def and push the users
 * whose state changed onto \p worklist for further propagation.
 */
static void
add_uses_to_worklist(nir_instr *instr,
                     nir_instr_worklist *worklist,
                     struct util_dynarray *states,
                     const struct per_op_table *pass_op_table)
{
   nir_ssa_def *def = nir_instr_ssa_def(instr);

   nir_foreach_use_safe(use_src, def) {
      /* Only changed states need to propagate further. */
      if (nir_algebraic_automaton(use_src->parent_instr, states, pass_op_table))
         nir_instr_worklist_push_tail(worklist, use_src->parent_instr);
   }
}
664 
/* Propagate automaton-state changes caused by \p new_instr to all
 * transitive users, and queue every instruction whose state changed on
 * \p algebraic_worklist so it gets reconsidered for rewriting.
 */
static void
nir_algebraic_update_automaton(nir_instr *new_instr,
                               nir_instr_worklist *algebraic_worklist,
                               struct util_dynarray *states,
                               const struct per_op_table *pass_op_table)
{

   nir_instr_worklist *automaton_worklist = nir_instr_worklist_create();

   /* Walk through the tree of uses of our new instruction's SSA value,
    * recursively updating the automaton state until it stabilizes.
    */
   add_uses_to_worklist(new_instr, automaton_worklist, states, pass_op_table);

   nir_instr *instr;
   while ((instr = nir_instr_worklist_pop_head(automaton_worklist))) {
      /* A changed state may enable a new match; re-queue for algebraic. */
      nir_instr_worklist_push_tail(algebraic_worklist, instr);
      add_uses_to_worklist(instr, automaton_worklist, states, pass_op_table);
   }

   nir_instr_worklist_destroy(automaton_worklist);
}
687 
/* Try to match \p search against \p instr and, on success, build the
 * \p replace expression tree in its place.
 *
 * Tries every direction combination for the pattern's commutative
 * sub-expressions.  Returns the new SSA def on success (the old
 * instruction is removed and all its uses rewritten), or NULL if the
 * pattern did not match.
 */
nir_ssa_def *
nir_replace_instr(nir_builder *build, nir_alu_instr *instr,
                  struct hash_table *range_ht,
                  struct util_dynarray *states,
                  const nir_algebraic_table *table,
                  const nir_search_expression *search,
                  const nir_search_value *replace,
                  nir_instr_worklist *algebraic_worklist)
{
   uint8_t swizzle[NIR_MAX_VEC_COMPONENTS] = { 0 };

   /* Start from the identity swizzle over the destination components. */
   for (unsigned i = 0; i < instr->dest.dest.ssa.num_components; ++i)
      swizzle[i] = i;

   assert(instr->dest.dest.is_ssa);

   struct match_state state;
   state.inexact_match = false;
   state.has_exact_alu = false;
   state.range_ht = range_ht;
   state.pass_op_table = table->pass_op_table;
   state.table = table;

   STATIC_ASSERT(sizeof(state.comm_op_direction) * 8 >= NIR_SEARCH_MAX_COMM_OPS);

   /* One combination per subset of flippable commutative expressions. */
   unsigned comm_expr_combinations =
      1 << MIN2(search->comm_exprs, NIR_SEARCH_MAX_COMM_OPS);

   bool found = false;
   for (unsigned comb = 0; comb < comm_expr_combinations; comb++) {
      /* The bitfield of directions is just the current iteration.  Hooray for
       * binary.
       */
      state.comm_op_direction = comb;
      state.variables_seen = 0;

      if (match_expression(table, search, instr,
                           instr->dest.dest.ssa.num_components,
                           swizzle, &state)) {
         found = true;
         break;
      }
   }
   if (!found)
      return NULL;

#if 0
   /* NOTE(review): these disabled debug calls predate dump_value() growing
    * its table parameter — they would need updating before re-enabling.
    */
   fprintf(stderr, "matched: ");
   dump_value(&search->value);
   fprintf(stderr, " -> ");
   dump_value(replace);
   fprintf(stderr, " ssa_%d\n", instr->dest.dest.ssa.index);
#endif

   /* If the instruction at the root of the expression tree being replaced is
    * a unary operation, insert the replacement instructions at the location
    * of the source of the unary operation.  Otherwise, insert the replacement
    * instructions at the location of the expression tree root.
    *
    * For the unary operation case, this is done to prevent some spurious code
    * motion that can dramatically extend live ranges.  Imagine an expression
    * like -(A+B) where the addtion and the negation are separated by flow
    * control and thousands of instructions.  If this expression is replaced
    * with -A+-B, inserting the new instructions at the site of the negation
    * could extend the live range of A and B dramtically.  This could increase
    * register pressure and cause spilling.
    *
    * It may well be that moving instructions around is a good thing, but
    * keeping algebraic optimizations and code motion optimizations separate
    * seems safest.
    */
   nir_alu_instr *const src_instr = nir_src_as_alu_instr(instr->src[0].src);
   if (src_instr != NULL &&
       (instr->op == nir_op_fneg || instr->op == nir_op_fabs ||
        instr->op == nir_op_ineg || instr->op == nir_op_iabs ||
        instr->op == nir_op_inot)) {
      /* Insert new instructions *after*.  Otherwise a hypothetical
       * replacement fneg(X) -> fabs(X) would insert the fabs() instruction
       * before X!  This can also occur for things like fneg(X.wzyx) -> X.wzyx
       * in vector mode.  A move instruction to handle the swizzle will get
       * inserted before X.
       *
       * This manifested in a single OpenGL ES 2.0 CTS vertex shader test on
       * older Intel GPU that use vector-mode vertex processing.
       */
      build->cursor = nir_after_instr(&src_instr->instr);
   } else {
      build->cursor = nir_before_instr(&instr->instr);
   }

   state.states = states;

   nir_alu_src val = construct_value(build, replace,
                                     instr->dest.dest.ssa.num_components,
                                     instr->dest.dest.ssa.bit_size,
                                     &state, &instr->instr);

   /* Note that NIR builder will elide the MOV if it's a no-op, which may
    * allow more work to be done in a single pass through algebraic.
    */
   nir_ssa_def *ssa_val =
      nir_mov_alu(build, val, instr->dest.dest.ssa.num_components);
   if (ssa_val->index == util_dynarray_num_elements(states, uint16_t)) {
      /* The MOV was a genuinely new def: give it an automaton state too. */
      util_dynarray_append(states, uint16_t, 0);
      nir_algebraic_automaton(ssa_val->parent_instr, states, table->pass_op_table);
   }

   /* Rewrite the uses of the old SSA value to the new one, and recurse
    * through the uses updating the automaton's state.
    */
   nir_ssa_def_rewrite_uses(&instr->dest.dest.ssa, ssa_val);
   nir_algebraic_update_automaton(ssa_val->parent_instr, algebraic_worklist,
                                  states, table->pass_op_table);

   /* Nothing uses the instr any more, so drop it out of the program.  Note
    * that the instr may be in the worklist still, so we can't free it
    * directly.
    */
   nir_instr_remove(&instr->instr);

   return ssa_val;
}
810 
/* Feed one instruction through the pattern-matching automaton.
 *
 * Recomputes the automaton state for the instruction's SSA def from the
 * states of its sources and returns true iff the stored state changed
 * (meaning its users need to be reprocessed).  Instructions the automaton
 * does not track always return false.
 */
static bool
nir_algebraic_automaton(nir_instr *instr, struct util_dynarray *states,
                        const struct per_op_table *pass_op_table)
{
   switch (instr->type) {
   case nir_instr_type_alu: {
      nir_alu_instr *alu = nir_instr_as_alu(instr);
      nir_op op = alu->op;
      uint16_t search_op = nir_search_op_for_nir_op(op);
      const struct per_op_table *tbl = &pass_op_table[search_op];
      if (tbl->num_filtered_states == 0)
         return false;

      /* Calculate the index into the transition table. Note the index
       * calculated must match the iteration order of Python's
       * itertools.product(), which was used to emit the transition
       * table.
       */
      unsigned index = 0;
      for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
         index *= tbl->num_filtered_states;
         /* NOTE(review): when tbl->filter is NULL this source contributes
          * 0 to the index — presumably a NULL filter means all source
          * states collapse to one; confirm against nir_algebraic.py.
          */
         if (tbl->filter)
            index += tbl->filter[*util_dynarray_element(states, uint16_t,
                                                        alu->src[i].src.ssa->index)];
      }

      uint16_t *state = util_dynarray_element(states, uint16_t,
                                              alu->dest.dest.ssa.index);
      if (*state != tbl->table[index]) {
         *state = tbl->table[index];
         return true;
      }
      return false;
   }

   case nir_instr_type_load_const: {
      /* Constants get a dedicated sentinel state. */
      nir_load_const_instr *load_const = nir_instr_as_load_const(instr);
      uint16_t *state = util_dynarray_element(states, uint16_t,
                                              load_const->def.index);
      if (*state != CONST_STATE) {
         *state = CONST_STATE;
         return true;
      }
      return false;
   }

   default:
      return false;
   }
}
861 
862 static bool
nir_algebraic_instr(nir_builder * build,nir_instr * instr,struct hash_table * range_ht,const bool * condition_flags,const nir_algebraic_table * table,struct util_dynarray * states,nir_instr_worklist * worklist)863 nir_algebraic_instr(nir_builder *build, nir_instr *instr,
864                     struct hash_table *range_ht,
865                     const bool *condition_flags,
866                     const nir_algebraic_table *table,
867                     struct util_dynarray *states,
868                     nir_instr_worklist *worklist)
869 {
870 
871    if (instr->type != nir_instr_type_alu)
872       return false;
873 
874    nir_alu_instr *alu = nir_instr_as_alu(instr);
875    if (!alu->dest.dest.is_ssa)
876       return false;
877 
878    unsigned bit_size = alu->dest.dest.ssa.bit_size;
879    const unsigned execution_mode =
880       build->shader->info.float_controls_execution_mode;
881    const bool ignore_inexact =
882       nir_is_float_control_signed_zero_inf_nan_preserve(execution_mode, bit_size) ||
883       nir_is_denorm_flush_to_zero(execution_mode, bit_size);
884 
885    int xform_idx = *util_dynarray_element(states, uint16_t,
886                                           alu->dest.dest.ssa.index);
887    for (const struct transform *xform = &table->transforms[table->transform_offsets[xform_idx]];
888         xform->condition_offset != ~0;
889         xform++) {
890       if (condition_flags[xform->condition_offset] &&
891           !(table->values[xform->search].expression.inexact && ignore_inexact) &&
892           nir_replace_instr(build, alu, range_ht, states, table,
893                             &table->values[xform->search].expression,
894                             &table->values[xform->replace].value, worklist)) {
895          _mesa_hash_table_clear(range_ht, NULL);
896          return true;
897       }
898    }
899 
900    return false;
901 }
902 
903 bool
nir_algebraic_impl(nir_function_impl * impl,const bool * condition_flags,const nir_algebraic_table * table)904 nir_algebraic_impl(nir_function_impl *impl,
905                    const bool *condition_flags,
906                    const nir_algebraic_table *table)
907 {
908    bool progress = false;
909 
910    nir_builder build;
911    nir_builder_init(&build, impl);
912 
913    /* Note: it's important here that we're allocating a zeroed array, since
914     * state 0 is the default state, which means we don't have to visit
915     * anything other than constants and ALU instructions.
916     */
917    struct util_dynarray states = {0};
918    if (!util_dynarray_resize(&states, uint16_t, impl->ssa_alloc)) {
919       nir_metadata_preserve(impl, nir_metadata_all);
920       return false;
921    }
922    memset(states.data, 0, states.size);
923 
924    struct hash_table *range_ht = _mesa_pointer_hash_table_create(NULL);
925 
926    nir_instr_worklist *worklist = nir_instr_worklist_create();
927 
928    /* Walk top-to-bottom setting up the automaton state. */
929    nir_foreach_block(block, impl) {
930       nir_foreach_instr(instr, block) {
931          nir_algebraic_automaton(instr, &states, table->pass_op_table);
932       }
933    }
934 
935    /* Put our instrs in the worklist such that we're popping the last instr
936     * first.  This will encourage us to match the biggest source patterns when
937     * possible.
938     */
939    nir_foreach_block_reverse(block, impl) {
940       nir_foreach_instr_reverse(instr, block) {
941          if (instr->type == nir_instr_type_alu)
942             nir_instr_worklist_push_tail(worklist, instr);
943       }
944    }
945 
946    nir_instr *instr;
947    while ((instr = nir_instr_worklist_pop_head(worklist))) {
948       /* The worklist can have an instr pushed to it multiple times if it was
949        * the src of multiple instrs that also got optimized, so make sure that
950        * we don't try to re-optimize an instr we already handled.
951        */
952       if (exec_node_is_tail_sentinel(&instr->node))
953          continue;
954 
955       progress |= nir_algebraic_instr(&build, instr,
956                                       range_ht, condition_flags,
957                                       table, &states, worklist);
958    }
959 
960    nir_instr_worklist_destroy(worklist);
961    ralloc_free(range_ht);
962    util_dynarray_fini(&states);
963 
964    if (progress) {
965       nir_metadata_preserve(impl, nir_metadata_block_index |
966                                   nir_metadata_dominance);
967    } else {
968       nir_metadata_preserve(impl, nir_metadata_all);
969    }
970 
971    return progress;
972 }
973