• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#
2# Copyright (C) 2014 Intel Corporation
3#
4# Permission is hereby granted, free of charge, to any person obtaining a
5# copy of this software and associated documentation files (the "Software"),
6# to deal in the Software without restriction, including without limitation
7# the rights to use, copy, modify, merge, publish, distribute, sublicense,
8# and/or sell copies of the Software, and to permit persons to whom the
9# Software is furnished to do so, subject to the following conditions:
10#
11# The above copyright notice and this permission notice (including the next
12# paragraph) shall be included in all copies or substantial portions of the
13# Software.
14#
15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21# IN THE SOFTWARE.
22
23import ast
24from collections import defaultdict
25import itertools
26import struct
27import sys
28import mako.template
29import re
30import traceback
31
32from nir_opcodes import opcodes, type_sizes
33
# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
# NOTE(review): not referenced in this chunk; presumably consumed by the
# code-generation half of nir_algebraic -- verify before renaming.
nir_search_max_comm_ops = 8
36
# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.  These are the "unsized" conversion opcodes:
# unlike real NIR conversion ops they carry no bit size, only a base type.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}
52
def get_cond_index(conds, cond):
    """Return a stable index for *cond*, registering it in *conds* if new.

    A falsy condition (None or empty) maps to -1.  Otherwise the condition
    is assigned the next sequential index the first time it is seen, and
    that same index is returned on every later lookup.
    """
    if not cond:
        return -1
    # len(conds) is evaluated before the insert, so a new condition gets
    # the next free slot; an existing one keeps its original index.
    return conds.setdefault(cond, len(conds))
63
def get_c_opcode(op):
   """Map a python-level opcode name to the C enumerant used by nir_search.

   The conversion opcodes that exist only inside nir_search get the
   'nir_search_op_' prefix; every real NIR opcode gets 'nir_op_'.
   """
   prefix = 'nir_search_op_' if op in conv_opcode_types else 'nir_op_'
   return prefix + op
69
_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Extract the bit-size suffix from a type string such as 'float32'.

   Returns 0 for unsized types like 'int' or 'float'.  The string must
   start with a recognized base type name.
   """
   match = _type_re.match(type_str)
   assert match.group('type')

   bits = match.group('bits')
   return int(bits) if bits is not None else 0
80
class VarSet(object):
   """A set of named variables, each assigned a unique sequential id.

   Ids are handed out on first lookup.  Once lock() has been called,
   looking up a never-before-seen name is an error; this catches replace
   expressions that reference variables absent from the search expression.
   """

   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      if name in self.names:
         return self.names[name]
      assert not self.immutable, "Unknown replacement variable: " + name
      new_id = next(self.ids)
      self.names[name] = new_id
      return new_id

   def lock(self):
      # From here on, only already-registered names may be looked up.
      self.immutable = True
97
class SearchExpression(object):
   """Thin wrapper around an (opcode, src0, src1, ...) search tuple."""

   def __init__(self, expr):
      # The head of the tuple is the opcode string; the tail is the sources.
      self.opcode = expr[0]
      self.sources = expr[1:]
      self.ignore_exact = False

   @staticmethod
   def create(val):
      """Coerce a raw tuple into a SearchExpression; pass instances through."""
      if isinstance(val, SearchExpression):
         return val
      assert isinstance(val, tuple)
      return SearchExpression(val)

   def __repr__(self):
      parts = [self.opcode, *self.sources]
      if self.ignore_exact:
         parts.append('ignore_exact')
      return repr(tuple(parts))
117
class Value(object):
   """Base class for all nodes in a search/replace pattern.

   Concrete subclasses are Constant, Variable and Expression.  Every Value
   carries a ``_bit_size`` link (initialized by the subclass constructors)
   that serves as the parent pointer of a union-find structure over
   bit-size equivalence classes; see get_bit_size() / set_bit_size().
   """

   @staticmethod
   def create(val, name_base, varset, algebraic_pass):
      """Build the appropriate Value subclass for a python-level pattern.

      Tuples and SearchExpressions become Expressions, strings become
      Variables, and bool/int/float literals become Constants.  Existing
      Expressions are passed through unchanged.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple) or isinstance(val, SearchExpression):
         return Expression(val, name_base, varset, algebraic_pass)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, str):
         return Variable(val, name_base, varset, algebraic_pass)
      elif isinstance(val, (bool, float, int)):
         return Constant(val, name_base)
      # NOTE(review): any other type silently falls through and returns
      # None -- consider asserting here instead.

   def __init__(self, val, name, type_str):
      # Original pattern text, kept so __str__ (and therefore the error
      # messages formatted by BitSizeValidator) shows what the user wrote.
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow the chain of _bit_size parent links until we hit either an
      # int (a concrete physical size) or a Value with no parent (the
      # class representative).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative so the next
      # lookup is O(1).
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      # C enumerant for the nir_search_value kind (constant/variable/expression).
      return "nir_search_value_" + self.type_str

   @property
   def c_bit_size(self):
      # Encoding used by the C side: a positive value is a concrete bit
      # size, a negative value references a variable's bit size by index.
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         # case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   # Mako template emitting one C union initializer for this value; which
   # arm is rendered depends on the concrete subclass passed as ``val``.
   __template = mako.template.Template("""   { .${val.type_str} = {
      { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
      ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
      ${val.index}, /* ${val.var_name} */
      ${'true' if val.is_constant else 'false'},
      ${val.type() or 'nir_type_invalid' },
      ${val.cond_index},
      ${val.swizzle()},
% elif isinstance(val, Expression):
      ${'true' if val.inexact else 'false'},
      ${'true' if val.exact else 'false'},
      ${'true' if val.ignore_exact else 'false'},
      ${val.c_opcode()},
      ${val.comm_expr_idx}, ${val.comm_exprs},
      { ${', '.join(src.array_index for src in val.sources)} },
      ${val.cond_index},
% endif
   } },
""")

   def render(self, cache):
      """Render this value as a C initializer, deduplicating byte-identical
      initializers through *cache* and recording the resulting table index
      in self.array_index.
      """
      struct_init = self.__template.render(val=self,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         self.array_index = cache[struct_init]
         return "   /* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         self.array_index = str(cache["next_index"])
         cache[struct_init] = self.array_index
         cache["next_index"] += 1
         return struct_init
232
_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   """A literal constant in a search or replace pattern.

   String forms such as '0x3@32' carry an optional '@bits' suffix giving
   an explicit bit size; bare python bool/int/float literals are taken
   as-is with no explicit size.
   """

   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      # Booleans are always 1-bit; an explicit size, if given, must agree.
      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Render the constant as the C literal used in the generated table."""
      v = self.value
      if isinstance(v, bool):
         return 'NIR_TRUE' if v else 'NIR_FALSE'
      if isinstance(v, int):
         return hex(v)
      if isinstance(v, float):
         # Bit-cast the double to its 64-bit integer representation.
         return hex(struct.unpack('Q', struct.pack('d', v))[0])
      assert False

   def type(self):
      """Map the python value's type onto the nir_alu_type base type."""
      if isinstance(self.value, bool):
         return "nir_type_bool"
      if isinstance(self.value, int):
         return "nir_type_int"
      if isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for __eq__
      would break BitSizeValidator.
      """
      if not isinstance(other, type(self)):
         return False

      return self.value == other.value
281
# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")

class Variable(Value):
   """A named variable in a pattern, e.g. 'a' or '#b@32(is_used_once).x'.

   The name syntax, parsed by _var_name_re above, is:
      '#'      -- the variable may only match constants (is_constant)
      '@type'  -- required base type, optionally followed by a bit size
      '(cond)' -- a condition function that must hold for a match
      '.xyzw'  -- a swizzle applied to the variable
   """

   def __init__(self, val, name, varset, algebraic_pass):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      # Conditions are interned per-pass so each unique condition string
      # gets a single index in the generated table.
      self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      # Booleans default to 1-bit; an explicit size must be a valid bool size.
      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      # Same name -> same id within one transform (see VarSet).
      self.index = varset[self.var_name]

   def type(self):
      """Map the required base type onto nir_alu_type; returns None when no
      type was required.  Note both 'int' and 'uint' map to nir_type_int.
      """
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      """Return the C initializer for the 16-entry swizzle array; identity
      swizzle when none was specified.
      """
      if self.swiz is not None:
         # Two aliases per component: xyzw and a..p both address channels.
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'
354
_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   """An expression node: an opcode applied to a list of source Values.

   The opcode string may carry modifiers, parsed by _opcode_re above:
      '~'      -- sets self.inexact
      '!'      -- sets self.exact (mutually exclusive with '~')
      '@bits'  -- an explicit bit size for the expression
      '(cond)' -- a condition function for the match
   """

   def __init__(self, expr, name_base, varset, algebraic_pass):
      Value.__init__(self, expr, name_base, "expression")

      expr = SearchExpression.create(expr)

      m = _opcode_re.match(expr.opcode)
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.ignore_exact = expr.ignore_exact
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")
         assert(len(c) <= 1)

         self.cond = c[0] if c else None
         self.many_commutative_expressions = True

      # Deduplicate references to the condition functions for the expressions
      # and save the index for the order they were added.
      self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

      # Recursively build the source Values; each gets a derived name for
      # debugging/rendering purposes.
      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
                       for (i, src) in enumerate(expr.sources) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Render the sources first so that their array_index attributes exist
      # by the time the parent initializer (which references them) renders.
      srcs = "".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)
461
class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression. It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One union-find slot per distinct variable index; filled lazily by
      # merge_variables().
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of a and b to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one so that
      # get_bit_size() always resolves to the most specialized member.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      # Note: self.is_search is set by validate() before this is called.
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      """Check that every replacement subexpression resolved to something
      known when building the replacement: a concrete bit size, a variable's
      bit size, or the bit size of the search expression's root.
      """
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)
      elif isinstance(val, Variable):
          # These catch problems when someone copies and pastes the search
          # into the replacement.
          assert not val.is_constant, \
              'Replacement variables must not be marked constant.'

          assert val.cond_index == -1, \
              'Replacement variables must not have a condition.'

          assert not val.required_type, \
              'Replacement variables must not have a required type.'

   def validate(self, search, replace):
      """Run the full bit-size validation for a search/replace pair."""
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
               search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)
774
# Monotonically increasing counter used to give each SearchAndReplace a
# unique id (and thereby unique generated-value names).
_optimization_ids = itertools.count()

# Global list of condition strings shared across all passes; index 0 is the
# always-true condition. SearchAndReplace appends new conditions as seen.
condition_list = ['true']
778
class SearchAndReplace(object):
   """A single parsed transform: a search pattern, its replacement, and an
   optional guard condition, validated for consistent bit sizes."""

   def __init__(self, transform, algebraic_pass):
      self.id = next(_optimization_ids)

      # A transform is (search, replace) with an optional third condition.
      search, replace = transform[0], transform[1]
      self.condition = transform[2] if len(transform) > 2 else 'true'

      # Register the condition in the shared global list and record its slot.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      self.search = search if isinstance(search, Expression) else \
         Expression(search, "search{0}".format(self.id), varset, algebraic_pass)

      # The replacement may not introduce variables absent from the search.
      varset.lock()

      self.replace = replace if isinstance(replace, Value) else \
         Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)

      BitSizeValidator(varset).validate(self.search, self.replace)
808
class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      # Only the search (left-hand) side of each transform is matched.
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         # Raises KeyError if obj was never added.
         return self.map[obj]

      def add(self, obj):
         # Returns the index of obj, inserting it at the end if it is new.
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed
         # up filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         # For commutative opcodes, register the swapped-source key as an
         # alias of the same item so both argument orders match it.
         if commutative:
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      # Leaf items shared by every pattern: a match-anything wildcard and a
      # match-any-constant item.
      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2i1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2i1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2i, or we
               # construct it to match the search opcode, in which case we
               # need to map f2i1 to f2i when constructing the automaton. Here
               # we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # Lists of pattern matches separated by None
      self.state_patterns = [None]
      # Offset in the ->transforms table for each state index
      self.state_pattern_offsets = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with a index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
      self.worklist_index = 0
      worklist_indices = defaultdict(lambda: 0)

      # This is the set of opcodes for which the filtered worklist is non-empty.
      # It's used to avoid scanning opcodes for which there is nothing to
      # process when building the transition table. It corresponds to new_a in
      # the original algorithm.
      new_opcodes = self.IndexMap()

      # Process states on the global worklist, filtering them for each opcode,
      # updating the filter tables, and updating the filtered worklists if any
      # new filtered states are found. Similar to ComputeRepresenterSets() in
      # the original algorithm, although that only processes a single state.
      def process_new_states():
         while self.worklist_index < len(self.states):
            state = self.states[self.worklist_index]
            # Calculate pattern matches for this state. Each pattern is
            # assigned to a unique item, so we don't have to worry about
            # deduplicating them here. However, we do have to sort them so
            # that they're visited at runtime in the order they're specified
            # in the source.
            patterns = list(sorted(p for item in state for p in item.patterns))

            if patterns:
                # Add our patterns to the global table.
                self.state_pattern_offsets.append(len(self.state_patterns))
                self.state_patterns.extend(patterns)
                self.state_patterns.append(None)
            else:
                # Point to the initial sentinel in the global table.
                self.state_pattern_offsets.append(0)

            # calculate filter table for this state, and update filtered
            # worklists.
            for op in self.opcodes:
               filt = self.filter[op]
               rep = self.rep[op]
               filtered = frozenset(item for item in state if \
                  op in item.parent_ops)
               if filtered in rep:
                  rep_index = rep.index(filtered)
               else:
                  rep_index = rep.add(filtered)
                  new_opcodes.add(op)
               assert len(filt) == self.worklist_index
               filt.append(rep_index)
            self.worklist_index += 1

      # There are two start states: one which can only match as a wildcard,
      # and one which can match as a wildcard or constant. These will be the
      # states of intrinsics/other instructions and load_const instructions,
      # respectively. The indices of these must match the definitions of
      # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can
      # initialize things correctly.
      self.states.add(frozenset((self.wildcard,)))
      self.states.add(frozenset((self.const,self.wildcard)))
      process_new_states()

      # Fixed-point iteration: keep extending the transition table until no
      # new states (and hence no new filtered states) are discovered.
      while len(new_opcodes) > 0:
         for op in new_opcodes:
            rep = self.rep[op]
            table = self.table[op]
            op_worklist_index = worklist_indices[op]
            if op in conv_opcode_types:
               num_srcs = 1
            else:
               num_srcs = opcodes[op].num_inputs

            # Iterate over all possible source combinations where at least one
            # is on the worklist.
            for src_indices in itertools.product(range(len(rep)), repeat=num_srcs):
               if all(src_idx < op_worklist_index for src_idx in src_indices):
                  continue

               srcs = tuple(rep[src_idx] for src_idx in src_indices)

               # Try all possible pairings of source items and add the
               # corresponding parent items. This is Comp_a from the paper.
               parent = set(self.items[op, item_srcs] for item_srcs in
                  itertools.product(*srcs) if (op, item_srcs) in self.items)

               # We could always start matching something else with a
               # wildcard. This is Cl from the paper.
               parent.add(self.wildcard)

               table[src_indices] = self.states.add(frozenset(parent))
            worklist_indices[op] = len(rep)
         new_opcodes.clear()
         process_new_states()
1076
# Mako template that renders one complete algebraic pass as C source: the
# flattened search/replace value table, the optional expression/variable
# condition-function tables, the transform table with sentinel entries, the
# per-opcode automaton tables, and the pass entry point that evaluates
# condition flags and runs nir_algebraic_impl over every function impl.
_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following ${len(xforms)}
 * transforms:
% for xform in xforms:
 *    ${xform.search} => ${xform.replace}
% endfor
 */

<% cache = {"next_index": 0} %>
static const nir_search_value_union ${pass_name}_values[] = {
% for xform in xforms:
   /* ${xform.search} => ${xform.replace} */
${xform.search.render(cache)}
${xform.replace.render(cache)}
% endfor
};

% if expression_cond:
static const nir_search_expression_cond ${pass_name}_expression_cond[] = {
% for cond in expression_cond:
   ${cond[0]},
% endfor
};
% endif

% if variable_cond:
static const nir_search_variable_cond ${pass_name}_variable_cond[] = {
% for cond in variable_cond:
   ${cond[0]},
% endfor
};
% endif

static const struct transform ${pass_name}_transforms[] = {
% for i in automaton.state_patterns:
% if i is not None:
   { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} },
% else:
   { ~0, ~0, ~0 }, /* Sentinel */

% endif
% endfor
};

static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = {
% for op in automaton.opcodes:
   [${get_c_opcode(op)}] = {
% if all(e == 0 for e in automaton.filter[op]):
      .filter = NULL,
% else:
      .filter = (const uint16_t []) {
      % for e in automaton.filter[op]:
         ${e},
      % endfor
      },
% endif
      <%
        num_filtered = len(automaton.rep[op])
      %>
      .num_filtered_states = ${num_filtered},
      .table = (const uint16_t []) {
      <%
        num_srcs = len(next(iter(automaton.table[op])))
      %>
      % for indices in itertools.product(range(num_filtered), repeat=num_srcs):
         ${automaton.table[op][indices]},
      % endfor
      },
   },
% endfor
};

/* Mapping from state index to offset in transforms (0 being no transforms) */
static const uint16_t ${pass_name}_transform_offsets[] = {
% for offset in automaton.state_pattern_offsets:
   ${offset},
% endfor
};

static const nir_algebraic_table ${pass_name}_table = {
   .transforms = ${pass_name}_transforms,
   .transform_offsets = ${pass_name}_transform_offsets,
   .pass_op_table = ${pass_name}_pass_op_table,
   .values = ${pass_name}_values,
   .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" },
   .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" },
};

bool
${pass_name}(
   nir_shader *shader
% for type, name in params:
   , ${type} ${name}
% endfor
) {
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values));
   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function_impl(impl, shader) {
     progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table);
   }

   return progress;
}
""")
1196
1197
class AlgebraicPass(object):
   """A complete algebraic pass: parsed transforms plus the automaton and
   rendering machinery needed to emit the pass as C source.

   Parses every transform into a SearchAndReplace, reporting all failures
   before exiting, groups transforms by the root opcode of their search
   pattern, enforces the commutative-expression limit of nir_search.c, and
   builds the matching TreeAutomaton.
   """

   # params is a list of `("type", "name")` tuples
   def __init__(self, pass_name, transforms, params=None):
      self.xforms = []
      self.opcode_xforms = defaultdict(lambda : [])
      self.pass_name = pass_name
      # Condition name -> index maps, filled in while parsing expressions and
      # variables; rendered into the generated condition tables.
      self.expression_cond = {}
      self.variable_cond = {}
      # None sentinel instead of a mutable default argument; behaves the same
      # for callers that passed no params.
      self.params = params if params is not None else []

      error = False

      for xform in transforms:
         if not isinstance(xform, SearchAndReplace):
            try:
               xform = SearchAndReplace(xform, self)
            # Catch Exception (not bare except) so SystemExit and
            # KeyboardInterrupt still propagate.
            except Exception:
               print("Failed to parse transformation:", file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               traceback.print_exc(file=sys.stderr)
               print('', file=sys.stderr)
               error = True
               continue

         self.xforms.append(xform)
         if xform.search.opcode in conv_opcode_types:
            # A generic conversion opcode (e.g. f2i) stands for every sized
            # variant (f2i8, f2i16, ...), so index the transform under each.
            dst_type = conv_opcode_types[xform.search.opcode]
            for size in type_sizes(dst_type):
               sized_opcode = xform.search.opcode + str(size)
               self.opcode_xforms[sized_opcode].append(xform)
         else:
            self.opcode_xforms[xform.search.opcode].append(xform)

         # Check to make sure the search pattern does not unexpectedly contain
         # more commutative expressions than match_expression (nir_search.c)
         # can handle.
         comm_exprs = xform.search.comm_exprs

         if xform.search.many_commutative_expressions:
            if comm_exprs <= nir_search_max_comm_ops:
               # The transform is annotated "many-comm-expr" but does not
               # actually exceed the limit: the annotation is stale.
               # Bug fixes vs. original: the format argument referenced the
               # non-existent name nir_search_max_comm_op (NameError on this
               # path), and traceback.print_exc was called with no active
               # exception (it would just print "NoneType: None").
               print("Transform expected to have too many commutative " \
                     "expression but did not " \
                     "({} <= {}).".format(comm_exprs, nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform), file=sys.stderr)
               print('', file=sys.stderr)
               error = True
         else:
            if comm_exprs > nir_search_max_comm_ops:
               print("Transformation with too many commutative expressions " \
                     "({} > {}).  Modify pattern or annotate with " \
                     "\"many-comm-expr\".".format(comm_exprs,
                                                  nir_search_max_comm_ops),
                     file=sys.stderr)
               print("  " + str(xform.search), file=sys.stderr)
               print("{}".format(xform.search.cond), file=sys.stderr)
               error = True

      self.automaton = TreeAutomaton(self.xforms)

      # Only exit after every transform has been checked, so all errors are
      # reported in one run.
      if error:
         sys.exit(1)


   def render(self):
      """Render this pass to C source using the module-level mako template."""
      return _algebraic_pass_template.render(pass_name=self.pass_name,
                                             xforms=self.xforms,
                                             opcode_xforms=self.opcode_xforms,
                                             condition_list=condition_list,
                                             automaton=self.automaton,
                                             expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]),
                                             variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]),
                                             get_c_opcode=get_c_opcode,
                                             itertools=itertools,
                                             params=self.params)
1274
# The replacement expression isn't necessarily exact if the search expression is exact.
def ignore_exact(*expr):
   """Wrap the given pattern in a SearchExpression whose exactness flag is
   ignored during matching."""
   wrapped = SearchExpression.create(expr)
   wrapped.ignore_exact = True
   return wrapped
1280