• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Copyright (C) 2014 Intel Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22
23import ast
24from collections import defaultdict
25import itertools
26import struct
27import sys
28import mako.template
29import re
30import traceback
31
32from nir_opcodes import opcodes, type_sizes
33
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
# NOTE(review): presumably the cap on commutative expressions per pattern;
# the code that enforces it is not in this part of the file.
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
#
# Membership in this table is what marks an opcode as a size-generic
# conversion elsewhere in this file: get_c_opcode() gives such opcodes the
# 'nir_search_op_' prefix instead of 'nir_op_', Expression refuses an explicit
# "@bits" suffix on them, and BitSizeValidator.validate_value() requires
# exactly one source and skips bit-size unification for them.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}
52
def get_cond_index(conds, cond):
    """Return the index assigned to ``cond`` in ``conds``, registering it if new.

    ``conds`` is a mapping from condition to index that is mutated in place:
    a condition seen for the first time receives the next free index, i.e.
    the current size of the mapping.  A falsy ``cond`` (None or an empty
    string) means "no condition"; it is reported as -1 and ``conds`` is left
    untouched.
    """
    if not cond:
        return -1

    if cond not in conds:
        conds[cond] = len(conds)
    return conds[cond]
63
def get_c_opcode(op):
   """Return the C enumerant name for the opcode ``op``.

   Opcodes listed in conv_opcode_types get the 'nir_search_op_' prefix; every
   other opcode gets the 'nir_op_' prefix.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
69
# A type string is a base type optionally followed by a bit size, e.g.
# "float" or "int32".
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size encoded in ``type_str``, or 0 if it is unsized.

   Asserts that the string starts with one of the known base types.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return int(bits) if bits is not None else 0
80
81# Represents a set of variables, each with a unique id
class VarSet(object):
   """Hands out a unique, sequential integer id for every variable name.

   Looking up an unknown name allocates the next id, until lock() is called;
   after that, looking up an unknown name trips an assertion instead.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      # Fast path: the name already has an id.
      if name in self.names:
         return self.names[name]

      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      """Forbid the allocation of ids for names not seen so far."""
      self.immutable = True
97
class SearchExpression(object):
   """An (opcode, source, ...) pattern plus an ``ignore_exact`` flag.

   The first element of the sequence given to the constructor is the opcode;
   the remaining elements are the sources.
   """

   def __init__(self, expr):
      opcode, sources = expr[0], expr[1:]
      self.opcode = opcode
      self.sources = sources
      self.ignore_exact = False

   @staticmethod
   def create(val):
      """Wrap a tuple in a SearchExpression; pass an existing one through."""
      if not isinstance(val, tuple):
         assert(isinstance(val, SearchExpression))
         return val
      return SearchExpression(val)

   def __repr__(self):
      # Print as the plain tuple form, with a trailing marker when the
      # ignore_exact flag is set.
      fields = (self.opcode,) + tuple(self.sources)
      if self.ignore_exact:
         fields = fields + ('ignore_exact',)
      return repr(fields)
117
class Value(object):
   """Common base of the pattern nodes: Constant, Variable and Expression.

   Provides the factory that turns raw pattern pieces into nodes, the
   union-find bookkeeping for bit-size classes, and the rendering of a node
   into a C initializer.

   Value.__init__ does not set ``_bit_size``; each subclass in this file sets
   it in its own constructor to an explicit bit size or None.
   """
   @staticmethod
   def create(val, name_base, varset, algebraic_pass):
      """Build the node matching the Python type of ``val``.

      bytes are decoded as UTF-8 first.  Then:
      - tuple or SearchExpression -> Expression
      - Expression                -> returned unchanged
      - str                       -> Variable
      - bool, float or int        -> Constant
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple) or isinstance(val, SearchExpression):
         return Expression(val, name_base, varset, algebraic_pass)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, str):
         return Variable(val, name_base, varset, algebraic_pass)
      elif isinstance(val, (bool, float, int)):
         return Constant(val, name_base)
      # NOTE(review): any other type falls through and None is returned
      # implicitly.

   def __init__(self, val, name, type_str):
      """Record the printable form of ``val``, the node name and its kind.

      ``type_str`` is used to build both the union member name and the
      nir_search_value_* enumerant in the rendered initializer.
      """
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow the _bit_size links until we reach either an int (a physical
      # bit size) or a Value whose _bit_size is None (the class root).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Point self directly at the result so later lookups are shorter.
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      # Already in the same class (or already the same physical size).
      if self_bit_size == other_bit_size:
         return

      # NOTE(review): this assignment requires self_bit_size to be a Value;
      # if it were an int that differs from other_bit_size it would fail.
      # The callers visible in this file check compatibility beforehand.
      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      """Name of the C enumerant identifying this kind of value."""
      return "nir_search_value_" + self.type_str

   @property
   def c_bit_size(self):
      """Bit size as written into the rendered initializer.

      A physical bit size is emitted as-is, a bit size tied to a variable is
      emitted as the negative number ``-index - 1``, and anything else is
      emitted as 0.
      """
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template producing one initializer entry.  The fields emitted depend
   # on the concrete subclass of ``val``.  For an Expression it reads
   # ``src.array_index`` of every source, so the sources must have been
   # rendered first (Expression.render does this).
   __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${'true' if val.nsz else 'false'},
      ${'true' if val.nnan else 'false'},
      ${'true' if val.ninf else 'false'},
      ${'true' if val.swizzle_y else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

   def render(self, cache):
      """Render this value as an initializer entry, deduplicated via ``cache``.

      ``cache`` maps rendered text to the array index (as a string) it was
      first assigned, and must also hold an integer under the "next_index"
      key.  If identical text was rendered before, that index is reused and
      only a comment is returned; otherwise the next index is taken and the
      rendered text is returned.  In both cases the index is stored in
      ``self.array_index``.
      """
      struct_init = self.__template.render(val=self,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         self.array_index = cache[struct_init]
         return "   /* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         self.array_index = str(cache["next_index"])
         cache[struct_init] = self.array_index
         cache["next_index"] += 1
         return struct_init
236
# A constant given as a string is a Python literal optionally followed by
# "@<bits>".
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal bool, int or float appearing in a pattern.

   The constant may be given either as a Python value or as a string such as
   "1.0@32", in which case the part before the "@" is evaluated as a Python
   literal and the part after it is the explicit bit size.  Boolean constants
   always end up with a bit size of 1.
   """

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         parsed = _constant_re.match(val)
         self.value = ast.literal_eval(parsed.group('value'))
         bits = parsed.group('bits')
         self._bit_size = int(bits) if bits else None
      else:
         self.value, self._bit_size = val, None

      if isinstance(self.value, bool):
         assert self._bit_size in (None, 1)
         self._bit_size = 1

   def hex(self):
      """Return the constant as it is spelled in the generated C code.

      Booleans become NIR_TRUE / NIR_FALSE; ints and floats become their
      64-bit pattern in hexadecimal with an "ull" suffix.
      """
      value = self.value

      # bool has to be tested first because it is a subclass of int.
      if isinstance(value, bool):
         return 'NIR_TRUE' if value else 'NIR_FALSE'

      if isinstance(value, int):
         # Explicitly sign-extend negative integers to 64-bit, ensuring
         # correct handling of -INT32_MIN which is not representable in
         # 32-bit.
         if value < 0:
            raw = struct.unpack('Q', struct.pack('q', value))[0]
         else:
            raw = value
      elif isinstance(value, float):
         # Reinterpret the double's bytes as an unsigned 64-bit integer.
         raw = struct.unpack('Q', struct.pack('d', value))[0]
      else:
         assert False

      return hex(raw) + 'ull'

   def type(self):
      """Return the nir_type_* name for the Python type of the value.

      Returns None when the value is not a bool, int or float.
      """
      # Order matters: bool must come before int.
      for py_type, nir_type in ((bool, "nir_type_bool"),
                                (int, "nir_type_int"),
                                (float, "nir_type_float")):
         if isinstance(self.value, py_type):
            return nir_type
      return None

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      Only the values are compared; the bit sizes are not.
      """
      return isinstance(other, type(self)) and self.value == other.value
290
291# The $ at the end forces there to be an error if any part of the string
292# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")

class Variable(Value):
   """A named placeholder in a pattern.

   The textual form is ``[#]name[@[type][bits]][(cond)][.swizzle]``:
   a leading "#" marks the variable as constant, "@" introduces an optional
   required type and/or bit size, a parenthesized part names a condition and
   a trailing ".xyzw"-style suffix is a swizzle.
   """

   def __init__(self, val, name, varset, algebraic_pass):
      Value.__init__(self, val, name, "variable")

      parsed = _var_name_re.match(val)
      assert parsed and parsed.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = parsed.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      bits = parsed.group('bits')

      self.is_constant = parsed.group('const') is not None
      # Registers the condition with the pass (if there is one).
      self.cond_index = get_cond_index(algebraic_pass.variable_cond,
                                       parsed.group('cond'))
      self.required_type = parsed.group('type')
      self._bit_size = int(bits) if bits else None
      self.swiz = parsed.group('swiz')

      # Booleans default to 1 bit; an explicit size must be a valid bool size.
      if self.required_type == 'bool':
         if self._bit_size is None:
            self._bit_size = 1
         else:
            assert self._bit_size in type_sizes(self.required_type)

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      # Looking the name up allocates an id on first use (or asserts if the
      # variable set has been locked).
      self.index = varset[self.var_name]

   def type(self):
      """Return the nir_type_* name of the required type, or None if unset.

      Both 'int' and 'uint' map to nir_type_int.
      """
      nir_types = {
         'bool': "nir_type_bool",
         'int': "nir_type_int",
         'uint': "nir_type_int",
         'float': "nir_type_float",
      }
      return nir_types.get(self.required_type)

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      Two variables are equivalent when they have the same index.
      """
      return isinstance(other, type(self)) and self.index == other.index

   def swizzle(self):
      """Return the swizzle as a C array initializer.

      Without an explicit swizzle this is the identity over 16 components.
      """
      if self.swiz is None:
         return '{' + ', '.join(str(i) for i in range(16)) + '}'

      # 'a'..'p' name components 0..15; 'x', 'y', 'z', 'w' are aliases for
      # the first four.
      component = {c: i for i, c in enumerate('abcdefghijklmnop')}
      component.update(zip('xyzw', range(4)))

      # self.swiz starts with the '.' separator, which is skipped.
      picked = [str(component[c]) for c in self.swiz[1:]]
      return '{' + ', '.join(picked) + '}'
363
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?(?P<swizzle_y>\.y)?")

class Expression(Value):
   """An operation node: an opcode applied to up to four source values.

   The opcode string has the form ``[~][!]opcode[@bits][(conds)][.y]``:
   "~" marks the expression inexact, "!" marks it exact, "@bits" gives an
   explicit bit size, the parenthesized part is a comma-separated list of
   conditions and ".y" sets the swizzle_y flag.
   """

   def __init__(self, expr, name_base, varset, algebraic_pass):
      Value.__init__(self, expr, name_base, "expression")

      search_expr = SearchExpression.create(expr)

      parsed = _opcode_re.match(search_expr.opcode)
      assert parsed and parsed.group('opcode') is not None

      bits = parsed.group('bits')

      self.opcode = parsed.group('opcode')
      self._bit_size = int(bits) if bits else None
      self.inexact = parsed.group('inexact') is not None
      self.exact = parsed.group('exact') is not None
      self.ignore_exact = search_expr.ignore_exact
      self.cond = parsed.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # Split "(a,b,c)" into its individual entries.
      if self.cond:
         remaining = {key: True for key in self.cond[1:-1].split(",")}
      else:
         remaining = {}

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      # nsz, nnan and ninf are special conditions, so we treat them specially
      # too.
      self.many_commutative_expressions = remaining.pop('many-comm-expr', False)
      self.nsz = remaining.pop('nsz', False)
      self.nnan = remaining.pop('nnan', False)
      self.ninf = remaining.pop('ninf', False)
      self.swizzle_y = parsed.group('swizzle_y') is not None

      # At most one real condition may be left once the special ones are
      # removed.
      assert len(remaining) <= 1
      self.cond = remaining.popitem()[0] if remaining else None

      # Deduplicate references to the condition functions for the expressions
      # and save the index for the order they were added.
      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

      sources = []
      for position, src in enumerate(search_expr.sources):
         src_name = "{0}_{1}".format(name_base, position)
         sources.append(Value.create(src, src_name, varset, algebraic_pass))
      self.sources = sources

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      Two expressions are equivalent when they have the same opcode, the same
      number of sources, and pairwise equivalent sources.  This
      implementation does not check for equivalence due to commutativity,
      but it could.
      """
      return (isinstance(other, type(self)) and
              len(self.sources) == len(other.sources) and
              self.opcode == other.opcode and
              all(mine.equivalent(theirs)
                  for mine, theirs in zip(self.sources, other.sources)))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions.

      Sets ``comm_expr_idx`` to ``base_idx`` if this expression counts as
      commutative and to -1 otherwise, sets ``comm_exprs`` to the number of
      commutative expressions in this subtree, and returns that number.
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      #
      # An expression whose first two sources are equivalent is not counted.
      counts_as_commutative = (
         self.opcode not in conv_opcode_types and
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and
         len(self.sources) >= 2 and
         not self.sources[0].equivalent(self.sources[1]))

      if counts_as_commutative:
         self.comm_expr_idx = base_idx
         self.comm_exprs = 1
      else:
         self.comm_expr_idx = -1

      # Number the sub-expressions after this one, in source order.
      for src in self.sources:
         if not isinstance(src, Expression):
            continue
         src.__index_comm_exprs(base_idx + self.comm_exprs)
         self.comm_exprs += src.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      """Return the C enumerant name of this expression's opcode."""
      return get_c_opcode(self.opcode)

   def render(self, cache):
      """Render all sources, in order, followed by this expression itself."""
      pieces = [src.render(cache) for src in self.sources]
      pieces.append(super().render(cache))
      return "".join(pieces)
470
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to AND a 32-bit value with a 64-bit value, which would be invalid.

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      """Create a validator for the variables registered in ``varset``.

      Note that ``self.is_search``, which most other methods read, is only
      set by validate().
      """
      # One slot per variable id; merge_variables() stores in each slot the
      # first Variable node seen with that id.
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.

      ``a`` is a Value; ``b`` is either a Value or a physical bit size (int).
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Make the less specialized class point at the more specialized one.
      # compare_bitsizes() only returns a negative result when b_bit_size is
      # not an int, so calling set_bit_size() on it is valid.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            # First use of this variable: it becomes the representative.
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      # Constants and variables have nothing to infer on their own.
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      # Validate bottom-up: sources first, then this expression.
      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            # Unsized source: every such source is unified with the first one.
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            # Sized source: it must have exactly the bit size the opcode
            # declares for this input.
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         # Unsized destination: it shares the bit size of the unsized sources,
         # if there are any.
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         # Sized destination: the expression has the declared bit size.
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that the bit size of every replacement value is determined.

      The bit-size class of ``val`` must be a physical bit size, a Variable,
      or the bit-size class of ``search``.  Expression sources are checked
      recursively, and variables in the replacement must not carry a constant
      marker, a condition or a required type.
      """
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)
      elif isinstance(val, Variable):
          # These catch problems when someone copies and pastes the search
          # into the replacement.
          assert not val.is_constant, \
              'Replacement variables must not be marked constant.'

          assert val.cond_index == -1, \
              'Replacement variables must not have a condition.'

          assert not val.required_type, \
              'Replacement variables must not have a required type.'

   def validate(self, search, replace):
      """Validate the bit sizes of a search/replace pair.

      Runs the inference in search mode over ``search`` (after merging the
      uses of each variable across both expressions), then in replace mode
      over ``replace``, ties the bit size of ``replace`` to that of
      ``search`` and finally checks that every replacement value has a
      determined bit size.  Any inconsistency trips an assertion.
      """
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
783
# Process-wide counter of transform ids. Every SearchAndReplace instance draws
# the next value and uses it to name its "search<N>" / "replace<N>" values.
_optimization_ids = itertools.count()

# Ordered list of the distinct condition strings used by the transforms.
# 'true' sits at index 0 and is the condition given to transforms that do not
# specify one. SearchAndReplace appends unseen conditions and stores the index
# of its own condition; the pass template emits one condition_flags[] entry
# per element of this list.
condition_list = ['true']
787
class SearchAndReplace(object):
   """One transform: a search pattern, its replacement and a condition.

   ``transform`` is indexable as ``(search, replace[, condition])``. When no
   condition is given it defaults to ``'true'``. The condition is registered
   in the module-level ``condition_list`` and its position there is kept in
   ``condition_index``.
   """
   def __init__(self, transform, algebraic_pass):
      self.id = next(_optimization_ids)

      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Reuse the slot of an already registered condition, otherwise claim
      # a new one at the end of the list.
      try:
         self.condition_index = condition_list.index(self.condition)
      except ValueError:
         condition_list.append(self.condition)
         self.condition_index = len(condition_list) - 1

      varset = VarSet()
      if not isinstance(search, Expression):
         search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)
      self.search = search

      # The variable set is locked once the search side exists and before
      # the replacement is created.
      varset.lock()

      if not isinstance(replace, Value):
         replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)
      self.replace = replace

      BitSizeValidator(varset).validate(self.search, self.replace)
817
class TreeAutomaton(object):
   """Bottom-up tree automaton that quickly locates the left-hand sides of
   transforms.

   Tree automatons generalize classical NFA's and DFA's: the transition
   function derives the state of a parent node from the states of its
   children. A deterministic automaton is constructed to match patterns,
   with an algorithm similar to the classical NFA to DFA construction. At the
   moment only opcodes and constants are matched (the actual constant value
   is not checked); the more detailed checking of the leaves is left to the
   search function. The automaton therefore acts as a quick filter for the
   search function and needs only n + 1 table lookups for each n-source
   operation.

   The implementation follows the theory described in "Tree Automatons: Two
   Taxonomies and a Toolkit." In the language of that reference this is a
   frontier-to-root deterministic automaton using only symbol filtering. The
   filtering is crucial to reduce both the time taken to generate the tables
   and the size of the tables.
   """
   def __init__(self, transforms):
      self.patterns = [xform.search for xform in transforms]
      self._compute_items()
      self._build_table()

   class IndexMap(object):
      """An ordered collection of distinct hashable objects.

      Objects can be looked up by index, and the index of an object can be
      found through a hash table. Compared to a list, index() takes constant
      time. Compared to a set, the iteration order is stable.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         # The list is replaced rather than emptied in place.
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         """Insert obj unless already present; return its index either way."""
         try:
            return self.map[obj]
         except KeyError:
            position = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = position
            return position

      def __repr__(self):
         return 'IndexMap([{}])'.format(', '.join(map(repr, self.objects)))

   class Item(object):
      """An "item" in the language of "Tree Automatons": a subtree of some
      pattern, representing a potential partial match at runtime.

      Items are deduplicated, so identical subtrees of different patterns
      share one object. Each item also carries the extra bookkeeping used by
      the table construction.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # Indices of the patterns that have this item as their root node.
         self.patterns = []
         # Opcodes of the parents of this item; used to speed up filtering.
         self.parent_ops = set()

      def __str__(self):
         parts = [self.opcode]
         parts.extend(str(child) for child in self.children)
         return '({})'.format(', '.join(parts))

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Collect the deduplicated set of all items of all patterns."""
      # Maps (opcode, sources) to the item.
      self.items = {}

      # All opcodes used by the patterns, in first-use order. Later used to
      # avoid building and emitting tables for opcodes that never occur.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties

         key = (opcode, children)
         item = self.items.get(key)
         if item is None:
            item = self.items[key] = self.Item(opcode, children)

         if commutative:
            # The same item also answers for the first two sources swapped.
            swapped = (children[1], children[0]) + children[2:]
            self.items[opcode, swapped] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: the actual constant value is thrown away!
            return self.const

         if isinstance(src, Variable):
            # Note: which variable it is gets thrown away here! The wildcard
            # item is equivalent to nu in "Tree Automatons."
            return self.const if src.is_constant else self.wildcard

         assert isinstance(src, Expression)
         opcode = src.opcode
         unsized = opcode.rstrip('0123456789')
         if unsized in conv_opcode_types:
            # Matches that use conversion opcodes with a specific type, like
            # f2i1, are tricky. Either the automaton matches specific NIR
            # opcodes like nir_op_f2i1, which needs separate items for each
            # possible NIR opcode for patterns with a generic opcode like
            # f2i, or it matches the search opcode, which needs f2i1 mapped
            # to f2i while constructing the automaton. The latter is done
            # here.
            opcode = unsized

         # The opcode is registered before the sources are visited.
         self.opcodes.add(opcode)
         children = tuple(process_subpattern(child) for child in src.sources)
         item = get_item(opcode, children, pattern)
         for child in children:
            child.parent_ops.add(opcode)
         return item

      for index, root in enumerate(self.patterns):
         process_subpattern(root, index)

   def _build_table(self):
      """Build the transition table between match sets.

      This is based off of Algorithm 5.7.38 "Reachability-based tabulation
      of Cl . Comp_a and Filt_{a,i} using integers to identify match sets."
      The list of all possible "match sets" or "states" (each the set of
      Item's that match a given instruction) is built up simultaneously with
      the transition table between those states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states).
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None.
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index.
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Every state whose index is at least worklist_index is on the
      # worklist of newly created states. Each opcode also has a worklist of
      # newly filtered states, delimited in the same way by
      # worklist_indices. worklist_index corresponds to p in the original
      # algorithm, while worklist_indices is p_{a,j} (although since
      # filtering is only by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(int)

      # Opcodes whose filtered worklist is non-empty, so that opcodes with
      # nothing to process are not scanned while building the transition
      # table. It corresponds to new_a in the original algorithm.
      new_opcodes = self.IndexMap()

      # Drain the global worklist: filter each new state for every opcode,
      # update the filter tables and extend the filtered worklists whenever
      # a new filtered state shows up. Similar to ComputeRepresenterSets()
      # in the original algorithm, although that only processes a single
      # state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state_index = self.worklist_index
            state = self.states[state_index]

            # Each pattern is assigned to a unique item, so no deduplication
            # is needed. They are sorted so that they are visited at runtime
            # in the order they are specified in the source.
            matched = sorted(pattern
                             for item in state
                             for pattern in item.patterns)

            if matched:
               # Append this state's patterns to the global table.
               self.state_pattern_offsets.append(len(self.state_patterns))
               self.state_patterns.extend(matched)
               self.state_patterns.append(None)
            else:
               # Point to the initial sentinel in the global table.
               self.state_pattern_offsets.append(0)

            # Compute the filter table entries for this state and update
            # the filtered worklists.
            for op in self.opcodes:
               op_filter = self.filter[op]
               op_rep = self.rep[op]
               filtered = frozenset(item for item in state
                                    if op in item.parent_ops)
               size_before = len(op_rep)
               rep_index = op_rep.add(filtered)
               if rep_index == size_before:
                  # add() appended, so this filtered state is new.
                  new_opcodes.add(op)
               assert len(op_filter) == state_index
               op_filter.append(rep_index)

            self.worklist_index = state_index + 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. Their indices must match the definitions of
      # WILDCARD_STATE and CONST_STATE, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const, self.wildcard)))
      process_new_states()

      while new_opcodes:
         for op in new_opcodes:
            op_rep = self.rep[op]
            op_table = self.table[op]
            first_new = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs
            rep_count = len(op_rep)

            # Visit every source combination in which at least one source is
            # on the worklist.
            for src_indices in itertools.product(range(rep_count),
                                                 repeat=num_srcs):
               if not any(idx >= first_new for idx in src_indices):
                  continue

               src_sets = [op_rep[idx] for idx in src_indices]

               # Try all possible pairings of source items and collect the
               # corresponding parent items. This is Comp_a from the paper.
               parent = {self.items[op, combo]
                         for combo in itertools.product(*src_sets)
                         if (op, combo) in self.items}

               # Matching something else could always start with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               op_table[src_indices] = self.states.add(frozenset(parent))

            worklist_indices[op] = rep_count
         new_opcodes.clear()
         process_new_states()
1085
# Mako template for the C source of one algebraic pass; it is rendered by
# AlgebraicPass.render(). The template reads these names from the render
# context: pass_name, xforms, expression_cond, variable_cond, automaton,
# condition_list, params, get_c_opcode and itertools.
#
# In order, it emits: the search/replace value array, the optional
# expression/variable condition arrays, the transform array (where a None
# entry of automaton.state_patterns becomes a sentinel row), the per-opcode
# filter and transition tables, the state -> transform offset array, the
# nir_algebraic_table tying these together, and finally the pass entry point
# named after pass_name, which evaluates every entry of condition_list into
# condition_flags[] before running nir_algebraic_impl on each function impl.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union ${pass_name}_values[] = {
% for xform in xforms:
   /* ${xform.search} => ${xform.replace} */
${xform.search.render(cache)}
${xform.replace.render(cache)}
% endfor
};

% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
   ${cond[0]},
% endfor
};
% endif

% if variable_cond:
static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
   ${cond[0]},
% endfor
};
% endif

static const struct transform ${pass_name}_transforms[] = {
% for i in automaton.state_patterns:
% if i is not None:
   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
% else:
   { ~0, ~0, ~0 }, /* Sentinel */

% endif
% endfor
};

static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
% if all(e == 0 for e in automaton.filter[op]):
      .filter = NULL,
% else:
      .filter = (const uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
% endif
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (const uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

/* Mapping from state index to offset in transforms (0 being no transforms) */
static const uint16_t ${pass_name}_transform_offsets[] = {
% for offset in automaton.state_pattern_offsets:
   ${offset},
% endfor
};

static const nir_algebraic_table ${pass_name}_table = {
   .transforms = ${pass_name}_transforms,
   .transform_offsets = ${pass_name}_transform_offsets,
   .pass_op_table = ${pass_name}_pass_op_table,
   .values = ${pass_name}_values,
   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
};

bool
${pass_name}(
   nir_shader *shader
% for type, name in params:
   , ${type} ${name}
% endfor
) {
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function_impl(impl, shader) {
     progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table);
   }

   return progress;
}
""")
1205
1206
class AlgebraicPass(object):
   """A named set of algebraic transforms that can be rendered to C source.

   Construction parses every transform that is not already a
   SearchAndReplace, indexes the transforms by the opcode of their search
   expression, checks the number of commutative expressions of each search
   pattern against nir_search_max_comm_ops and builds the TreeAutomaton for
   the accepted transforms. All problems are reported on stderr; if there
   was at least one, the process exits with status 1 at the end of
   construction.
   """
   # params is a list of `("type", "name")` tuples
   def __init__(self, pass_name, transforms, params=None):
      self.xforms = []
      # Maps a search opcode to the transforms whose search root uses it.
      self.opcode_xforms = defaultdict(list)
      self.pass_name = pass_name
      # Condition tables; render() emits their keys ordered by their values.
      # NOTE(review): presumably filled in through the `algebraic_pass`
      # argument while the transforms are parsed -- confirm against
      # Expression/Variable.
      self.expression_cond = {}
      self.variable_cond = {}
      # A fresh list per instance instead of one shared mutable default.
      self.params = [] if params is None else params

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            # Exception (not a bare except) still covers the AssertionErrors
            # raised by validation, but lets KeyboardInterrupt and
            # SystemExit propagate.
            try:
               xform = SearchAndReplace(xform, self)
            except Exception:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # An unsized conversion opcode is registered once for every
            # size of its destination type.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               # No exception is active here, so there is no traceback to
               # print; show the search pattern like the branch below.
               print("  " + str(xform.search), file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      if error:
         sys.exit(1)


   def render(self):
      """Return the C source of this pass, rendered from the pass template.

      The condition dictionaries are passed as lists of (key, value) pairs
      sorted by value.
      """
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools,
                                             params=self.params)
1283
def ignore_exact(*expr):
   """Create a search expression whose ``ignore_exact`` flag is set.

   The positional arguments are handed, as one tuple, to
   SearchExpression.create(); the result is flagged and returned. The
   replacement expression isn't necessarily exact if the search expression
   is exact.
   """
   flagged = SearchExpression.create(expr)
   flagged.ignore_exact = True
   return flagged
1289