1# 2# Copyright (C) 2014 Intel Corporation 3# 4# Permission is hereby granted, free of charge, to any person obtaining a 5# copy of this software and associated documentation files (the "Software"), 6# to deal in the Software without restriction, including without limitation 7# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8# and/or sell copies of the Software, and to permit persons to whom the 9# Software is furnished to do so, subject to the following conditions: 10# 11# The above copyright notice and this permission notice (including the next 12# paragraph) shall be included in all copies or substantial portions of the 13# Software. 14# 15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21# IN THE SOFTWARE. 22 23import ast 24from collections import defaultdict 25import itertools 26import struct 27import sys 28import mako.template 29import re 30import traceback 31 32from nir_opcodes import opcodes, type_sizes 33 34# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 35nir_search_max_comm_ops = 8 36 37# These opcodes are only employed by nir_search. This provides a mapping from 38# opcode to destination type. 
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_cond_index(conds, cond):
    """Return the index assigned to ``cond`` in the ``conds`` dict, adding it
    with the next free index if it hasn't been seen before.  A false-y
    condition (e.g. None) maps to -1, meaning "no condition".
    """
    if cond:
        if cond in conds:
            return conds[cond]
        else:
            cond_index = len(conds)
            conds[cond] = cond_index
            return cond_index
    else:
        return -1

def get_c_opcode(op):
    """Map an opcode name to its C enum name: search-only conversion opcodes
    get the nir_search_op_ prefix, real NIR opcodes get nir_op_.
    """
    if op in conv_opcode_types:
        return 'nir_search_op_' + op
    else:
        return 'nir_op_' + op

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
    """Extract the bit-size suffix from a type string like "float32".
    Returns 0 when the type is unsized (no bit-size suffix).
    """
    m = _type_re.match(type_str)
    assert m.group('type')

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))

# Represents a set of variables, each with a unique id
class VarSet(object):
    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False;

    def __getitem__(self, name):
        # Looking up an unknown name allocates a fresh id, unless the set has
        # been locked (i.e. we're processing the replacement expression, where
        # every variable must already exist in the search expression).
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        self.immutable = True

class SearchExpression(object):
    """Thin wrapper around a search-pattern tuple: (opcode, src0, src1, ...).
    Carries the extra ignore_exact flag that a raw tuple cannot express.
    """
    def __init__(self, expr):
        self.opcode = expr[0]
        self.sources = expr[1:]
        self.ignore_exact = False

    @staticmethod
    def create(val):
        # Accept either a raw tuple or an already-wrapped SearchExpression.
        if isinstance(val, tuple):
            return SearchExpression(val)
        else:
            assert(isinstance(val, SearchExpression))
            return val

    def __repr__(self):
        l = [self.opcode, *self.sources]
        if self.ignore_exact:
            l.append('ignore_exact')
        return repr((*l,))

class Value(object):
    """Base class for all values in a search/replace pattern: constants,
    variables and (sub-)expressions.
    """
    @staticmethod
    def create(val, name_base, varset, algebraic_pass):
        """Factory dispatching on the Python type of ``val``: tuples become
        Expressions, strings become Variables, and bool/float/int literals
        become Constants.
        """
        if isinstance(val, bytes):
            val = val.decode('utf-8')

        if isinstance(val, tuple) or isinstance(val, SearchExpression):
            return Expression(val, name_base, varset, algebraic_pass)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, str):
            return Variable(val, name_base, varset, algebraic_pass)
        elif isinstance(val, (bool, float, int)):
            return Constant(val, name_base)

    def __init__(self, val, name, type_str):
        self.in_val = str(val)
        self.name = name
        self.type_str = type_str

    def __str__(self):
        return self.in_val

    def get_bit_size(self):
        """Get the physical bit-size that has been chosen for this value, or if
        there is none, the canonical value which currently represents this
        bit-size class. Variables will be preferred, i.e. if there are any
        variables in the equivalence class, the canonical value will be a
        variable. We do this since we'll need to know which variable each value
        is equivalent to when constructing the replacement expression. This is
        the "find" part of the union-find algorithm.
        """
        bit_size = self

        while isinstance(bit_size, Value):
            if bit_size._bit_size is None:
                break
            bit_size = bit_size._bit_size

        # Path compression: remember the representative so later lookups are
        # O(1).
        if bit_size is not self:
            self._bit_size = bit_size
        return bit_size

    def set_bit_size(self, other):
        """Make self.get_bit_size() return what other.get_bit_size() returned
        before calling this, or just "other" if it's a concrete bit-size. This is
        the "union" part of the union-find algorithm.
        """

        self_bit_size = self.get_bit_size()
        other_bit_size = other if isinstance(other, int) else other.get_bit_size()

        if self_bit_size == other_bit_size:
            return

        self_bit_size._bit_size = other_bit_size

    @property
    def type_enum(self):
        return "nir_search_value_" + self.type_str

    @property
    def c_bit_size(self):
        bit_size = self.get_bit_size()
        if isinstance(bit_size, int):
            return bit_size
        elif isinstance(bit_size, Variable):
            return -bit_size.index - 1
        else:
            # If the bit-size class is neither a variable, nor an actual bit-size, then
            # - If it's in the search expression, we don't need to check anything
            # - If it's in the replace expression, either it's ambiguous (in which
            #   case we'd reject it), or it equals the bit-size of the search value
            # We represent these cases with a 0 bit-size.
            return 0

    # Mako template emitting the C struct initializer for this value (one of
    # the union members of nir_search_value).
    __template = mako.template.Template(""" { .${val.type_str} = {
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond_index},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   ${'true' if val.exact else 'false'},
   ${'true' if val.ignore_exact else 'false'},
   ${val.c_opcode()},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   { ${', '.join(src.array_index for src in val.sources)} },
   ${val.cond_index},
% endif
 } },
""")

    def render(self, cache):
        """Render the C initializer for this value, deduplicating identical
        initializers through ``cache`` and recording the resulting table index
        in self.array_index.
        """
        struct_init = self.__template.render(val=self,
                                             Constant=Constant,
                                             Variable=Variable,
                                             Expression=Expression)
        if struct_init in cache:
            # If it's in the cache, register a name remap in the cache and render
            # only a comment saying it's been remapped
            self.array_index = cache[struct_init]
            return "   /* {} -> {} in the cache */\n".format(self.name,
                                                             cache[struct_init])
        else:
            self.array_index = str(cache["next_index"])
            cache[struct_init] = self.array_index
            cache["next_index"] += 1
            return struct_init

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
    def __init__(self, val, name):
        Value.__init__(self, val, name, "constant")

        # String constants may carry an explicit bit-size, e.g. "0x1@32".
        if isinstance(val, (str)):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self._bit_size = int(m.group('bits')) if m.group('bits') else None
        else:
            self.value = val
            self._bit_size = None

        # Booleans are always 1-bit in NIR's search representation.
        if isinstance(self.value, bool):
            assert self._bit_size is None or self._bit_size == 1
            self._bit_size = 1

    def hex(self):
        """Render the constant's bit pattern as the C token to emit.  Floats
        are encoded as the raw bits of their 64-bit (double) representation.
        """
        if isinstance(self.value, (bool)):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, int):
            return hex(self.value)
        elif isinstance(self.value, float):
            return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
        else:
            assert False

    def type(self):
        if isinstance(self.value, (bool)):
            return "nir_type_bool"
        elif isinstance(self.value, int):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two constants are equivalent.

        This check is much weaker than equality. One generally cannot be
        used in place of the other. Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.value == other.value

# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?"
                          r"$")

class Variable(Value):
    def __init__(self, val, name, varset, algebraic_pass):
        Value.__init__(self, val, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

        self.var_name = m.group('name')

        # Prevent common cases where someone puts quotes around a literal
        # constant. If we want to support names that have numeric or
        # punctuation characters, we can make the first assertion more flexible.
        assert self.var_name.isalpha()
        assert self.var_name != 'True'
        assert self.var_name != 'False'

        self.is_constant = m.group('const') is not None
        self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond'))
        self.required_type = m.group('type')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.swiz = m.group('swiz')

        if self.required_type == 'bool':
            if self._bit_size is not None:
                assert self._bit_size in type_sizes(self.required_type)
            else:
                self._bit_size = 1

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        if self.required_type == 'bool':
            return "nir_type_bool"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"

    def equivalent(self, other):
        """Check that two variables are equivalent.

        This check is much weaker than equality. One generally cannot be
        used in place of the other. Using this implementation for the __eq__
        will break BitSizeValidator.

        """
        if not isinstance(other, type(self)):
            return False

        return self.index == other.index

    def swizzle(self):
        """Render the variable's swizzle as a C array initializer; with no
        explicit swizzle, the identity swizzle is emitted.
        """
        if self.swiz is not None:
            swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                        'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                        'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                        'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                        'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
            return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
        return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
    def __init__(self, expr, name_base, varset, algebraic_pass):
        Value.__init__(self, expr, name_base, "expression")

        expr = SearchExpression.create(expr)

        m = _opcode_re.match(expr.opcode)
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self._bit_size = int(m.group('bits')) if m.group('bits') else None
        self.inexact = m.group('inexact') is not None
        self.exact = m.group('exact') is not None
        self.ignore_exact = expr.ignore_exact
        self.cond = m.group('cond')

        assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

        # "many-comm-expr" isn't really a condition.  It's notification to the
        # generator that this pattern is known to have too many commutative
        # expressions, and an error should not be generated for this case.
        self.many_commutative_expressions = False
        if self.cond and self.cond.find("many-comm-expr") >= 0:
            # Split the condition into a comma-separated list.  Remove
            # "many-comm-expr".  If there is anything left, put it back together.
            c = self.cond[1:-1].split(",")
            c.remove("many-comm-expr")
            assert(len(c) <= 1)

            self.cond = c[0] if c else None
            self.many_commutative_expressions = True

        # Deduplicate references to the condition functions for the expressions
        # and save the index for the order they were added.
        self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond)

        self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass)
                         for (i, src) in enumerate(expr.sources) ]

        # nir_search_expression::srcs is hard-coded to 4
        assert len(self.sources) <= 4

        if self.opcode in conv_opcode_types:
            assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

        self.__index_comm_exprs(0)

    def equivalent(self, other):
        """Check that two expressions are equivalent.

        This check is much weaker than equality. One generally cannot be
        used in place of the other. Using this implementation for the __eq__
        will break BitSizeValidator.

        This implementation does not check for equivalence due to commutativity,
        but it could.

        """
        if not isinstance(other, type(self)):
            return False

        if len(self.sources) != len(other.sources):
            return False

        if self.opcode != other.opcode:
            return False

        return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

    def __index_comm_exprs(self, base_idx):
        """Recursively count and index commutative expressions
        """
        self.comm_exprs = 0

        # A note about the explicit "len(self.sources)" check. The list of
        # sources comes from user input, and that input might be bad.  Check
        # that the expected second source exists before accessing it. Without
        # this check, a unit test that does "('iadd', 'a')" will crash.
        if self.opcode not in conv_opcode_types and \
           "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
           len(self.sources) >= 2 and \
           not self.sources[0].equivalent(self.sources[1]):
            self.comm_expr_idx = base_idx
            self.comm_exprs += 1
        else:
            self.comm_expr_idx = -1

        for s in self.sources:
            if isinstance(s, Expression):
                s.__index_comm_exprs(base_idx + self.comm_exprs)
                self.comm_exprs += s.comm_exprs

        return self.comm_exprs

    def c_opcode(self):
        return get_c_opcode(self.opcode)

    def render(self, cache):
        # Render sources depth-first so that each source's array_index is
        # assigned before this expression's initializer references it.
        srcs = "".join(src.render(cache) for src in self.sources)
        return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    Each value maintains a "bit-size class", which is either an actual bit size
    or an equivalence class with other values that must have the same bit size.
    The validator works by combining bit-size classes with each other according
    to the NIR rules outlined above, checking that there are no inconsistencies.
    When doing this for the replacement expression, we make sure to never change
    the equivalence class of any of the search values. We could make the example
    transforms above work by doing some extra run-time checking of the search
    expression, but we make the user specify those constraints themselves, to
    avoid any surprises. Since the replacement bitsizes can only be connected to
    the source bitsize via variables (variables must have the same bitsize in
    the source and replacement expressions) or the roots of the expression (the
    replacement expression must produce the same bit size as the search
    expression), we prevent merging a variable with anything when processing the
    replacement expression, or specializing the search bitsize
    with anything. The former prevents

    (('bcsel', a, b, 0), ('iand', a, b))

    from being allowed, since we'd have to merge the bitsizes for a and b due to
    the 'iand', while the latter prevents

    (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

    from being allowed, since the search expression has the bit size of a and b,
    which can't be specialized to 32 which is the bitsize of the replace
    expression. It also prevents something like:

    (('b2i', ('i2b', a)), ('ineq', a, 0))

    since the bitsize of 'b2i', which can be anything, can't be specialized to
    the bitsize of a.

    After doing all this, we check that every subexpression of the replacement
    was assigned a constant bitsize, the bitsize of a variable, or the bitsize
    of the search expression, since those are the things that are known when
    constructing the replacement expression. Finally, we record the bitsize
    needed in nir_search_value so that we know what to do when building the
    replacement expression.
    """

    def __init__(self, varset):
        # One entry per variable index; filled lazily by merge_variables with
        # the canonical Variable for each equivalence class.
        self._var_classes = [None] * len(varset.names)

    def compare_bitsizes(self, a, b):
        """Determines which bitsize class is a specialization of the other, or
        whether neither is. When we merge two different bitsizes, the
        less-specialized bitsize always points to the more-specialized one, so
        that calling get_bit_size() always gets you the most specialized bitsize.
        The specialization partial order is given by:
        - Physical bitsizes are always the most specialized, and a different
          bitsize can never specialize another.
        - In the search expression, variables can always be specialized to each
          other and to physical bitsizes. In the replace expression, we disallow
          this to avoid adding extra constraints to the search expression that
          the user didn't specify.
        - Expressions and constants without a bitsize can always be specialized to
          each other and variables, but not the other way around.

        We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
        and None if they are not comparable (neither a <= b nor b <= a).
        """
        if isinstance(a, int):
            if isinstance(b, int):
                return 0 if a == b else None
            elif isinstance(b, Variable):
                return -1 if self.is_search else None
            else:
                return -1
        elif isinstance(a, Variable):
            if isinstance(b, int):
                return 1 if self.is_search else None
            elif isinstance(b, Variable):
                return 0 if self.is_search or a.index == b.index else None
            else:
                return -1
        else:
            if isinstance(b, int):
                return 1
            elif isinstance(b, Variable):
                return 1
            else:
                return 0

    def unify_bit_size(self, a, b, error_msg):
        """Record that a must have the same bit-size as b. If both
        have been assigned conflicting physical bit-sizes, call "error_msg" with
        the bit-sizes of self and other to get a message and raise an error.
        In the replace expression, disallow merging variables with other
        variables and physical bit-sizes as well.
        """
        a_bit_size = a.get_bit_size()
        b_bit_size = b if isinstance(b, int) else b.get_bit_size()

        cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

        assert cmp_result is not None, \
            error_msg(a_bit_size, b_bit_size)

        if cmp_result < 0:
            b_bit_size.set_bit_size(a)
        elif not isinstance(a_bit_size, int):
            a_bit_size.set_bit_size(b)

    def merge_variables(self, val):
        """Perform the first part of type inference by merging all the different
        uses of the same variable. We always do this as if we're in the search
        expression, even if we're actually not, since otherwise we'd get errors
        if the search expression specified some constraint but the replace
        expression didn't, because we'd be merging a variable and a constant.
        """
        if isinstance(val, Variable):
            if self._var_classes[val.index] is None:
                self._var_classes[val.index] = val
            else:
                other = self._var_classes[val.index]
                self.unify_bit_size(other, val,
                        lambda other_bit_size, bit_size:
                            'Variable {} has conflicting bit size requirements: ' \
                            'it must have bit size {} and {}'.format(
                                val.var_name, other_bit_size, bit_size))
        elif isinstance(val, Expression):
            for src in val.sources:
                self.merge_variables(src)

    def validate_value(self, val):
        """Validate the an expression by performing classic Hindley-Milner
        type inference on bitsizes. This will detect if there are any conflicting
        requirements, and unify variables so that we know which variables must
        have the same bitsize. If we're operating on the replace expression, we
        will refuse to merge different variables together or merge a variable
        with a constant, in order to prevent surprises due to rules unexpectedly
        not matching at runtime.
        """
        if not isinstance(val, Expression):
            return

        # Generic conversion ops are special in that they have a single unsized
        # source and an unsized destination and the two don't have to match.
        # This means there's no validation or unioning to do here besides the
        # len(val.sources) check.
        if val.opcode in conv_opcode_types:
            assert len(val.sources) == 1, \
                "Expression {} has {} sources, expected 1".format(
                    val, len(val.sources))
            self.validate_value(val.sources[0])
            return

        nir_op = opcodes[val.opcode]
        assert len(val.sources) == nir_op.num_inputs, \
            "Expression {} has {} sources, expected {}".format(
                val, len(val.sources), nir_op.num_inputs)

        for src in val.sources:
            self.validate_value(src)

        dst_type_bits = type_bits(nir_op.output_type)

        # First, unify all the sources. That way, an error coming up because two
        # sources have an incompatible bit-size won't produce an error message
        # involving the destination.
        first_unsized_src = None
        for src_type, src in zip(nir_op.input_types, val.sources):
            src_type_bits = type_bits(src_type)
            if src_type_bits == 0:
                if first_unsized_src is None:
                    first_unsized_src = src
                    continue

                if self.is_search:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Source {} of {} must have bit size {}, while source {} ' \
                            'must have incompatible bit size {}'.format(
                                first_unsized_src, val, first_unsized_src_bit_size,
                                src, src_bit_size))
                else:
                    self.unify_bit_size(first_unsized_src, src,
                        lambda first_unsized_src_bit_size, src_bit_size:
                            'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                            'of {} may not have the same bit size when building the ' \
                            'replacement expression.'.format(
                                first_unsized_src, first_unsized_src_bit_size, src,
                                src_bit_size, val))
            else:
                if self.is_search:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} must have {} bits, but as a source of nir_op_{} '\
                            'it must have {} bits'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))
                else:
                    self.unify_bit_size(src, src_type_bits,
                        lambda src_bit_size, unused:
                            '{} has the bit size of {}, but as a source of ' \
                            'nir_op_{} it must have {} bits, which may not be the ' \
                            'same'.format(
                                src, src_bit_size, nir_op.name, src_type_bits))

        if dst_type_bits == 0:
            if first_unsized_src is not None:
                if self.is_search:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have the bit size of {}, while its source {} ' \
                            'must have incompatible bit size {}'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
                else:
                    self.unify_bit_size(val, first_unsized_src,
                        lambda val_bit_size, src_bit_size:
                            '{} must have {} bits, but its source {} ' \
                            '(bit size of {}) may not have that bit size ' \
                            'when building the replacement.'.format(
                                val, val_bit_size, first_unsized_src, src_bit_size))
        else:
            self.unify_bit_size(val, dst_type_bits,
                lambda dst_bit_size, unused:
                    '{} must have {} bits, but as a destination of nir_op_{} ' \
                    'it must have {} bits'.format(
                        val, dst_bit_size, nir_op.name, dst_type_bits))

    def validate_replace(self, val, search):
        # Every replacement value must end up with a bit size that is known at
        # build time: a literal, a variable's, or the search root's.
        bit_size = val.get_bit_size()
        assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
               bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

        if isinstance(val, Expression):
            for src in val.sources:
                self.validate_replace(src, search)
        elif isinstance(val, Variable):
            # These catch problems when someone copies and pastes the search
            # into the replacement.
            assert not val.is_constant, \
                'Replacement variables must not be marked constant.'

            assert val.cond_index == -1, \
                'Replacement variables must not have a condition.'

            assert not val.required_type, \
                'Replacement variables must not have a required type.'

    def validate(self, search, replace):
        self.is_search = True
        self.merge_variables(search)
        self.merge_variables(replace)
        self.validate_value(search)

        self.is_search = False
        self.validate_value(replace)

        # Check that search is always more specialized than replace. Note that
        # we're doing this in replace mode, disallowing merging variables.
        search_bit_size = search.get_bit_size()
        replace_bit_size = replace.get_bit_size()
        cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

        assert cmp_result is not None and cmp_result <= 0, \
            'The search expression bit size {} and replace expression ' \
            'bit size {} may not be the same'.format(
                search_bit_size, replace_bit_size)

        replace.set_bit_size(search)

        self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

class SearchAndReplace(object):
    """One (search, replace[, condition]) transform: parses both patterns,
    registers the condition, and runs bit-size validation on the pair.
    """
    def __init__(self, transform, algebraic_pass):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        # Deduplicate conditions across all transforms of the pass.
        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass)

        # Lock the variable set so the replacement may only reference
        # variables that already appear in the search pattern.
        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass)

        BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
    """This class calculates a bottom-up tree automaton to quickly search for
    the left-hand sides of transforms. Tree automatons are a generalization of
    classical NFA's and DFA's, where the transition function determines the
    state of the parent node based on the state of its children. We construct a
    deterministic automaton to match patterns, using a similar algorithm to the
    classical NFA to DFA construction.
At the moment, it only matches opcodes 816 and constants (without checking the actual value), leaving more detailed 817 checking to the search function which actually checks the leaves. The 818 automaton acts as a quick filter for the search function, requiring only n 819 + 1 table lookups for each n-source operation. The implementation is based 820 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." 821 In the language of that reference, this is a frontier-to-root deterministic 822 automaton using only symbol filtering. The filtering is crucial to reduce 823 both the time taken to generate the tables and the size of the tables. 824 """ 825 def __init__(self, transforms): 826 self.patterns = [t.search for t in transforms] 827 self._compute_items() 828 self._build_table() 829 #print('num items: {}'.format(len(set(self.items.values())))) 830 #print('num states: {}'.format(len(self.states))) 831 #for state, patterns in zip(self.states, self.patterns): 832 # print('{}: num patterns: {}'.format(state, len(patterns))) 833 834 class IndexMap(object): 835 """An indexed list of objects, where one can either lookup an object by 836 index or find the index associated to an object quickly using a hash 837 table. Compared to a list, it has a constant time index(). Compared to a 838 set, it provides a stable iteration order. 
839 """ 840 def __init__(self, iterable=()): 841 self.objects = [] 842 self.map = {} 843 for obj in iterable: 844 self.add(obj) 845 846 def __getitem__(self, i): 847 return self.objects[i] 848 849 def __contains__(self, obj): 850 return obj in self.map 851 852 def __len__(self): 853 return len(self.objects) 854 855 def __iter__(self): 856 return iter(self.objects) 857 858 def clear(self): 859 self.objects = [] 860 self.map.clear() 861 862 def index(self, obj): 863 return self.map[obj] 864 865 def add(self, obj): 866 if obj in self.map: 867 return self.map[obj] 868 else: 869 index = len(self.objects) 870 self.objects.append(obj) 871 self.map[obj] = index 872 return index 873 874 def __repr__(self): 875 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 876 877 class Item(object): 878 """This represents an "item" in the language of "Tree Automatons." This 879 is just a subtree of some pattern, which represents a potential partial 880 match at runtime. We deduplicate them, so that identical subtrees of 881 different patterns share the same object, and store some extra 882 information needed for the main algorithm as well. 883 """ 884 def __init__(self, opcode, children): 885 self.opcode = opcode 886 self.children = children 887 # These are the indices of patterns for which this item is the root node. 888 self.patterns = [] 889 # This the set of opcodes for parents of this item. Used to speed up 890 # filtering. 891 self.parent_ops = set() 892 893 def __str__(self): 894 return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 895 896 def __repr__(self): 897 return str(self) 898 899 def _compute_items(self): 900 """Build a set of all possible items, deduplicating them.""" 901 # This is a map from (opcode, sources) to item. 902 self.items = {} 903 904 # The set of all opcodes used by the patterns. Used later to avoid 905 # building and emitting all the tables for opcodes that aren't used. 
906 self.opcodes = self.IndexMap() 907 908 def get_item(opcode, children, pattern=None): 909 commutative = len(children) >= 2 \ 910 and "2src_commutative" in opcodes[opcode].algebraic_properties 911 item = self.items.setdefault((opcode, children), 912 self.Item(opcode, children)) 913 if commutative: 914 self.items[opcode, (children[1], children[0]) + children[2:]] = item 915 if pattern is not None: 916 item.patterns.append(pattern) 917 return item 918 919 self.wildcard = get_item("__wildcard", ()) 920 self.const = get_item("__const", ()) 921 922 def process_subpattern(src, pattern=None): 923 if isinstance(src, Constant): 924 # Note: we throw away the actual constant value! 925 return self.const 926 elif isinstance(src, Variable): 927 if src.is_constant: 928 return self.const 929 else: 930 # Note: we throw away which variable it is here! This special 931 # item is equivalent to nu in "Tree Automatons." 932 return self.wildcard 933 else: 934 assert isinstance(src, Expression) 935 opcode = src.opcode 936 stripped = opcode.rstrip('0123456789') 937 if stripped in conv_opcode_types: 938 # Matches that use conversion opcodes with a specific type, 939 # like f2i1, are tricky. Either we construct the automaton to 940 # match specific NIR opcodes like nir_op_f2i1, in which case we 941 # need to create separate items for each possible NIR opcode 942 # for patterns that have a generic opcode like f2i, or we 943 # construct it to match the search opcode, in which case we 944 # need to map f2i1 to f2i when constructing the automaton. Here 945 # we do the latter. 
946 opcode = stripped 947 self.opcodes.add(opcode) 948 children = tuple(process_subpattern(c) for c in src.sources) 949 item = get_item(opcode, children, pattern) 950 for i, child in enumerate(children): 951 child.parent_ops.add(opcode) 952 return item 953 954 for i, pattern in enumerate(self.patterns): 955 process_subpattern(pattern, i) 956 957 def _build_table(self): 958 """This is the core algorithm which builds up the transition table. It 959 is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 960 Comp_a and Filt_{a,i} using integers to identify match sets." It 961 simultaneously builds up a list of all possible "match sets" or 962 "states", where each match set represents the set of Item's that match a 963 given instruction, and builds up the transition table between states. 964 """ 965 # Map from opcode + filtered state indices to transitioned state. 966 self.table = defaultdict(dict) 967 # Bijection from state to index. q in the original algorithm is 968 # len(self.states) 969 self.states = self.IndexMap() 970 # Lists of pattern matches separated by None 971 self.state_patterns = [None] 972 # Offset in the ->transforms table for each state index 973 self.state_pattern_offsets = [] 974 # Map from state index to filtered state index for each opcode. 975 self.filter = defaultdict(list) 976 # Bijections from filtered state to filtered state index for each 977 # opcode, called the "representor sets" in the original algorithm. 978 # q_{a,j} in the original algorithm is len(self.rep[op]). 979 self.rep = defaultdict(self.IndexMap) 980 981 # Everything in self.states with a index at least worklist_index is part 982 # of the worklist of newly created states. There is also a worklist of 983 # newly fitered states for each opcode, for which worklist_indices 984 # serves a similar purpose. 
worklist_index corresponds to p in the 985 # original algorithm, while worklist_indices is p_{a,j} (although since 986 # we only filter by opcode/symbol, it's really just p_a). 987 self.worklist_index = 0 988 worklist_indices = defaultdict(lambda: 0) 989 990 # This is the set of opcodes for which the filtered worklist is non-empty. 991 # It's used to avoid scanning opcodes for which there is nothing to 992 # process when building the transition table. It corresponds to new_a in 993 # the original algorithm. 994 new_opcodes = self.IndexMap() 995 996 # Process states on the global worklist, filtering them for each opcode, 997 # updating the filter tables, and updating the filtered worklists if any 998 # new filtered states are found. Similar to ComputeRepresenterSets() in 999 # the original algorithm, although that only processes a single state. 1000 def process_new_states(): 1001 while self.worklist_index < len(self.states): 1002 state = self.states[self.worklist_index] 1003 # Calculate pattern matches for this state. Each pattern is 1004 # assigned to a unique item, so we don't have to worry about 1005 # deduplicating them here. However, we do have to sort them so 1006 # that they're visited at runtime in the order they're specified 1007 # in the source. 1008 patterns = list(sorted(p for item in state for p in item.patterns)) 1009 1010 if patterns: 1011 # Add our patterns to the global table. 1012 self.state_pattern_offsets.append(len(self.state_patterns)) 1013 self.state_patterns.extend(patterns) 1014 self.state_patterns.append(None) 1015 else: 1016 # Point to the initial sentinel in the global table. 1017 self.state_pattern_offsets.append(0) 1018 1019 # calculate filter table for this state, and update filtered 1020 # worklists. 
1021 for op in self.opcodes: 1022 filt = self.filter[op] 1023 rep = self.rep[op] 1024 filtered = frozenset(item for item in state if \ 1025 op in item.parent_ops) 1026 if filtered in rep: 1027 rep_index = rep.index(filtered) 1028 else: 1029 rep_index = rep.add(filtered) 1030 new_opcodes.add(op) 1031 assert len(filt) == self.worklist_index 1032 filt.append(rep_index) 1033 self.worklist_index += 1 1034 1035 # There are two start states: one which can only match as a wildcard, 1036 # and one which can match as a wildcard or constant. These will be the 1037 # states of intrinsics/other instructions and load_const instructions, 1038 # respectively. The indices of these must match the definitions of 1039 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 1040 # initialize things correctly. 1041 self.states.add(frozenset((self.wildcard,))) 1042 self.states.add(frozenset((self.const,self.wildcard))) 1043 process_new_states() 1044 1045 while len(new_opcodes) > 0: 1046 for op in new_opcodes: 1047 rep = self.rep[op] 1048 table = self.table[op] 1049 op_worklist_index = worklist_indices[op] 1050 if op in conv_opcode_types: 1051 num_srcs = 1 1052 else: 1053 num_srcs = opcodes[op].num_inputs 1054 1055 # Iterate over all possible source combinations where at least one 1056 # is on the worklist. 1057 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1058 if all(src_idx < op_worklist_index for src_idx in src_indices): 1059 continue 1060 1061 srcs = tuple(rep[src_idx] for src_idx in src_indices) 1062 1063 # Try all possible pairings of source items and add the 1064 # corresponding parent items. This is Comp_a from the paper. 1065 parent = set(self.items[op, item_srcs] for item_srcs in 1066 itertools.product(*srcs) if (op, item_srcs) in self.items) 1067 1068 # We could always start matching something else with a 1069 # wildcard. This is Cl from the paper. 
1070 parent.add(self.wildcard) 1071 1072 table[src_indices] = self.states.add(frozenset(parent)) 1073 worklist_indices[op] = len(rep) 1074 new_opcodes.clear() 1075 process_new_states() 1076 1077_algebraic_pass_template = mako.template.Template(""" 1078#include "nir.h" 1079#include "nir_builder.h" 1080#include "nir_search.h" 1081#include "nir_search_helpers.h" 1082 1083/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1084 * transforms: 1085% for xform in xforms: 1086 * ${xform.search} => ${xform.replace} 1087% endfor 1088 */ 1089 1090<% cache = {"next_index": 0} %> 1091static const nir_search_value_union ${pass_name}_values[] = { 1092% for xform in xforms: 1093 /* ${xform.search} => ${xform.replace} */ 1094${xform.search.render(cache)} 1095${xform.replace.render(cache)} 1096% endfor 1097}; 1098 1099% if expression_cond: 1100static const nir_search_expression_cond ${pass_name}_expression_cond[] = { 1101% for cond in expression_cond: 1102 ${cond[0]}, 1103% endfor 1104}; 1105% endif 1106 1107% if variable_cond: 1108static const nir_search_variable_cond ${pass_name}_variable_cond[] = { 1109% for cond in variable_cond: 1110 ${cond[0]}, 1111% endfor 1112}; 1113% endif 1114 1115static const struct transform ${pass_name}_transforms[] = { 1116% for i in automaton.state_patterns: 1117% if i is not None: 1118 { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} }, 1119% else: 1120 { ~0, ~0, ~0 }, /* Sentinel */ 1121 1122% endif 1123% endfor 1124}; 1125 1126static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = { 1127% for op in automaton.opcodes: 1128 [${get_c_opcode(op)}] = { 1129% if all(e == 0 for e in automaton.filter[op]): 1130 .filter = NULL, 1131% else: 1132 .filter = (const uint16_t []) { 1133 % for e in automaton.filter[op]: 1134 ${e}, 1135 % endfor 1136 }, 1137% endif 1138 <% 1139 num_filtered = len(automaton.rep[op]) 1140 %> 1141 .num_filtered_states = 
${num_filtered}, 1142 .table = (const uint16_t []) { 1143 <% 1144 num_srcs = len(next(iter(automaton.table[op]))) 1145 %> 1146 % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1147 ${automaton.table[op][indices]}, 1148 % endfor 1149 }, 1150 }, 1151% endfor 1152}; 1153 1154/* Mapping from state index to offset in transforms (0 being no transforms) */ 1155static const uint16_t ${pass_name}_transform_offsets[] = { 1156% for offset in automaton.state_pattern_offsets: 1157 ${offset}, 1158% endfor 1159}; 1160 1161static const nir_algebraic_table ${pass_name}_table = { 1162 .transforms = ${pass_name}_transforms, 1163 .transform_offsets = ${pass_name}_transform_offsets, 1164 .pass_op_table = ${pass_name}_pass_op_table, 1165 .values = ${pass_name}_values, 1166 .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, 1167 .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, 1168}; 1169 1170bool 1171${pass_name}( 1172 nir_shader *shader 1173% for type, name in params: 1174 , ${type} ${name} 1175% endfor 1176) { 1177 bool progress = false; 1178 bool condition_flags[${len(condition_list)}]; 1179 const nir_shader_compiler_options *options = shader->options; 1180 const shader_info *info = &shader->info; 1181 (void) options; 1182 (void) info; 1183 1184 STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values)); 1185 % for index, condition in enumerate(condition_list): 1186 condition_flags[${index}] = ${condition}; 1187 % endfor 1188 1189 nir_foreach_function_impl(impl, shader) { 1190 progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table); 1191 } 1192 1193 return progress; 1194} 1195""") 1196 1197 1198class AlgebraicPass(object): 1199 # params is a list of `("type", "name")` tuples 1200 def __init__(self, pass_name, transforms, params=[]): 1201 self.xforms = [] 1202 self.opcode_xforms = defaultdict(lambda : []) 1203 self.pass_name = pass_name 1204 
self.expression_cond = {} 1205 self.variable_cond = {} 1206 self.params = params 1207 1208 error = False 1209 1210 for xform in transforms: 1211 if not isinstance(xform, SearchAndReplace): 1212 try: 1213 xform = SearchAndReplace(xform, self) 1214 except: 1215 print("Failed to parse transformation:", file=sys.stderr) 1216 print(" " + str(xform), file=sys.stderr) 1217 traceback.print_exc(file=sys.stderr) 1218 print('', file=sys.stderr) 1219 error = True 1220 continue 1221 1222 self.xforms.append(xform) 1223 if xform.search.opcode in conv_opcode_types: 1224 dst_type = conv_opcode_types[xform.search.opcode] 1225 for size in type_sizes(dst_type): 1226 sized_opcode = xform.search.opcode + str(size) 1227 self.opcode_xforms[sized_opcode].append(xform) 1228 else: 1229 self.opcode_xforms[xform.search.opcode].append(xform) 1230 1231 # Check to make sure the search pattern does not unexpectedly contain 1232 # more commutative expressions than match_expression (nir_search.c) 1233 # can handle. 1234 comm_exprs = xform.search.comm_exprs 1235 1236 if xform.search.many_commutative_expressions: 1237 if comm_exprs <= nir_search_max_comm_ops: 1238 print("Transform expected to have too many commutative " \ 1239 "expression but did not " \ 1240 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1241 file=sys.stderr) 1242 print(" " + str(xform), file=sys.stderr) 1243 traceback.print_exc(file=sys.stderr) 1244 print('', file=sys.stderr) 1245 error = True 1246 else: 1247 if comm_exprs > nir_search_max_comm_ops: 1248 print("Transformation with too many commutative expressions " \ 1249 "({} > {}). 
Modify pattern or annotate with " \ 1250 "\"many-comm-expr\".".format(comm_exprs, 1251 nir_search_max_comm_ops), 1252 file=sys.stderr) 1253 print(" " + str(xform.search), file=sys.stderr) 1254 print("{}".format(xform.search.cond), file=sys.stderr) 1255 error = True 1256 1257 self.automaton = TreeAutomaton(self.xforms) 1258 1259 if error: 1260 sys.exit(1) 1261 1262 1263 def render(self): 1264 return _algebraic_pass_template.render(pass_name=self.pass_name, 1265 xforms=self.xforms, 1266 opcode_xforms=self.opcode_xforms, 1267 condition_list=condition_list, 1268 automaton=self.automaton, 1269 expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1270 variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1271 get_c_opcode=get_c_opcode, 1272 itertools=itertools, 1273 params=self.params) 1274 1275# The replacement expression isn't necessarily exact if the search expression is exact. 1276def ignore_exact(*expr): 1277 expr = SearchExpression.create(expr) 1278 expr.ignore_exact = True 1279 return expr 1280