#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)

from __future__ import print_function
import ast
from collections import defaultdict
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes, type_sizes

# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search. This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_c_opcode(op):
   """Return the C enum name for an opcode: generic (unsized) conversion
   opcodes only exist in nir_search, everything else is a real NIR opcode.
   """
   if op in conv_opcode_types:
      return 'nir_search_op_' + op
   else:
      return 'nir_op_' + op


# Python 2/3 compatibility shims for integer and string type checks.
if sys.version_info < (3, 0):
    integer_types = (int, long)
    string_type = unicode

else:
    integer_types = (int, )
    string_type = str

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size encoded in a type string like "float32", or 0 if
   the type is unsized (no trailing digits).
   """
   m = _type_re.match(type_str)
   assert m.group('type')

   if m.group('bits') is None:
      return 0
   else:
      return int(m.group('bits'))

# Represents a set of variables, each with a unique id
class VarSet(object):
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False;

   def __getitem__(self, name):
      # Looking up an unknown name allocates a fresh id, unless the set has
      # been locked (i.e. we're processing the replace expression, which may
      # only reference variables from the search expression).
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      self.immutable = True

class Value(object):
   @staticmethod
   def create(val, name_base, varset):
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, string_type):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float) + integer_types):
         return Constant(val, name_base)

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Follow the chain of parent pointers until we hit either a concrete
      # int bit size or a class representative (a Value with no parent).
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative so later
      # lookups are O(1).
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         # Negative values encode "same bit size as variable #index"; the
         # runtime search code decodes this as -(value + 1).
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, (str)):
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Render the constant as a C hex literal for the generated tables."""
      if isinstance(self.value, (bool)):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, integer_types):
         return hex(self.value)
      elif isinstance(self.value, float):
         # Reinterpret the double's bits as a 64-bit unsigned integer.
         i = struct.unpack('Q', struct.pack('d', self.value))[0]
         h = hex(i)

         # On Python 2 this 'L' suffix is automatically added, but not on Python 3
         # Adding it explicitly makes the generated file identical, regardless
         # of the Python version running this script.
         if h[-1] != 'L' and i > sys.maxsize:
            h += 'L'

         return h
      else:
         assert False

   def type(self):
      if isinstance(self.value, (bool)):
         return "nir_type_bool"
      elif isinstance(self.value, integer_types):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.value == other.value

# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      # Emit a C initializer for the variable's swizzle; with no explicit
      # swizzle, emit the identity mapping for all 16 possible components.
      if self.swiz is not None:
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Render all sources before this expression so the generated C
      # declarations appear in dependency order.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.  Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize with anything. The
   former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression.  It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression. Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
       - Physical bitsizes are always the most specialized, and a different
         bitsize can never specialize another.
       - In the search expression, variables can always be specialized to each
         other and to physical bitsizes. In the replace expression, we disallow
         this to avoid adding extra constraints to the search expression that
         the user didn't specify.
       - Expressions and constants without a bitsize can always be specialized to
         each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Always point the less-specialized class at the more-specialized one
      # (see compare_bitsizes); a concrete int can never become a parent link.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      # Every replacement value must have a bit size deducible at build time:
      # a concrete size, a variable's size, or the search root's size.
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
            bit_size == search.get_bit_size(), \
            'Ambiguous bit size for replacement value {}: ' \
            'it cannot be deduced from a variable, a fixed bit size ' \
            'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      replace.set_bit_size(search)

      self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

class SearchAndReplace(object):
   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      # Conditions are deduplicated into a global list; the generated C code
      # indexes into the corresponding condition array.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # Lock the variable set so the replace expression can only reference
      # variables that appeared in the search expression.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed up
         # filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         # For commutative ops, also register the swapped-source form so both
         # orderings map to the same deduplicated item.
         if commutative:
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2b1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2b1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2b, or we
               # construct it to match the search opcode, in which case we
               # need to map f2b1 to f2b when constructing the automaton.
               # Here we do the latter.
               opcode = stripped
            self.opcodes.add(opcode)
            children = tuple(process_subpattern(c) for c in src.sources)
            item = get_item(opcode, children, pattern)
            for i, child in enumerate(children):
               child.parent_ops.add(opcode)
            return item

      for i, pattern in enumerate(self.patterns):
         process_subpattern(pattern, i)

   def _build_table(self):
      """This is the core algorithm which builds up the transition table. It
      is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl .
      Comp_a and Filt_{a,i} using integers to identify match sets." It
      simultaneously builds up a list of all possible "match sets" or
      "states", where each match set represents the set of Item's that match a
      given instruction, and builds up the transition table between states.
      """
      # Map from opcode + filtered state indices to transitioned state.
      self.table = defaultdict(dict)
      # Bijection from state to index. q in the original algorithm is
      # len(self.states)
      self.states = self.IndexMap()
      # List of pattern matches for each state index.
      self.state_patterns = []
      # Map from state index to filtered state index for each opcode.
      self.filter = defaultdict(list)
      # Bijections from filtered state to filtered state index for each
      # opcode, called the "representor sets" in the original algorithm.
      # q_{a,j} in the original algorithm is len(self.rep[op]).
      self.rep = defaultdict(self.IndexMap)

      # Everything in self.states with an index at least worklist_index is part
      # of the worklist of newly created states. There is also a worklist of
      # newly filtered states for each opcode, for which worklist_indices
      # serves a similar purpose. worklist_index corresponds to p in the
      # original algorithm, while worklist_indices is p_{a,j} (although since
      # we only filter by opcode/symbol, it's really just p_a).
971 self.worklist_index = 0 972 worklist_indices = defaultdict(lambda: 0) 973 974 # This is the set of opcodes for which the filtered worklist is non-empty. 975 # It's used to avoid scanning opcodes for which there is nothing to 976 # process when building the transition table. It corresponds to new_a in 977 # the original algorithm. 978 new_opcodes = self.IndexMap() 979 980 # Process states on the global worklist, filtering them for each opcode, 981 # updating the filter tables, and updating the filtered worklists if any 982 # new filtered states are found. Similar to ComputeRepresenterSets() in 983 # the original algorithm, although that only processes a single state. 984 def process_new_states(): 985 while self.worklist_index < len(self.states): 986 state = self.states[self.worklist_index] 987 988 # Calculate pattern matches for this state. Each pattern is 989 # assigned to a unique item, so we don't have to worry about 990 # deduplicating them here. However, we do have to sort them so 991 # that they're visited at runtime in the order they're specified 992 # in the source. 993 patterns = list(sorted(p for item in state for p in item.patterns)) 994 assert len(self.state_patterns) == self.worklist_index 995 self.state_patterns.append(patterns) 996 997 # calculate filter table for this state, and update filtered 998 # worklists. 999 for op in self.opcodes: 1000 filt = self.filter[op] 1001 rep = self.rep[op] 1002 filtered = frozenset(item for item in state if \ 1003 op in item.parent_ops) 1004 if filtered in rep: 1005 rep_index = rep.index(filtered) 1006 else: 1007 rep_index = rep.add(filtered) 1008 new_opcodes.add(op) 1009 assert len(filt) == self.worklist_index 1010 filt.append(rep_index) 1011 self.worklist_index += 1 1012 1013 # There are two start states: one which can only match as a wildcard, 1014 # and one which can match as a wildcard or constant. 
These will be the 1015 # states of intrinsics/other instructions and load_const instructions, 1016 # respectively. The indices of these must match the definitions of 1017 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 1018 # initialize things correctly. 1019 self.states.add(frozenset((self.wildcard,))) 1020 self.states.add(frozenset((self.const,self.wildcard))) 1021 process_new_states() 1022 1023 while len(new_opcodes) > 0: 1024 for op in new_opcodes: 1025 rep = self.rep[op] 1026 table = self.table[op] 1027 op_worklist_index = worklist_indices[op] 1028 if op in conv_opcode_types: 1029 num_srcs = 1 1030 else: 1031 num_srcs = opcodes[op].num_inputs 1032 1033 # Iterate over all possible source combinations where at least one 1034 # is on the worklist. 1035 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1036 if all(src_idx < op_worklist_index for src_idx in src_indices): 1037 continue 1038 1039 srcs = tuple(rep[src_idx] for src_idx in src_indices) 1040 1041 # Try all possible pairings of source items and add the 1042 # corresponding parent items. This is Comp_a from the paper. 1043 parent = set(self.items[op, item_srcs] for item_srcs in 1044 itertools.product(*srcs) if (op, item_srcs) in self.items) 1045 1046 # We could always start matching something else with a 1047 # wildcard. This is Cl from the paper. 
1048 parent.add(self.wildcard) 1049 1050 table[src_indices] = self.states.add(frozenset(parent)) 1051 worklist_indices[op] = len(rep) 1052 new_opcodes.clear() 1053 process_new_states() 1054 1055_algebraic_pass_template = mako.template.Template(""" 1056#include "nir.h" 1057#include "nir_builder.h" 1058#include "nir_search.h" 1059#include "nir_search_helpers.h" 1060 1061/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1062 * transforms: 1063% for xform in xforms: 1064 * ${xform.search} => ${xform.replace} 1065% endfor 1066 */ 1067 1068<% cache = {} %> 1069% for xform in xforms: 1070 ${xform.search.render(cache)} 1071 ${xform.replace.render(cache)} 1072% endfor 1073 1074% for state_id, state_xforms in enumerate(automaton.state_patterns): 1075% if state_xforms: # avoid emitting a 0-length array for MSVC 1076static const struct transform ${pass_name}_state${state_id}_xforms[] = { 1077% for i in state_xforms: 1078 { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, 1079% endfor 1080}; 1081% endif 1082% endfor 1083 1084static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { 1085% for op in automaton.opcodes: 1086 [${get_c_opcode(op)}] = { 1087 .filter = (uint16_t []) { 1088 % for e in automaton.filter[op]: 1089 ${e}, 1090 % endfor 1091 }, 1092 <% 1093 num_filtered = len(automaton.rep[op]) 1094 %> 1095 .num_filtered_states = ${num_filtered}, 1096 .table = (uint16_t []) { 1097 <% 1098 num_srcs = len(next(iter(automaton.table[op]))) 1099 %> 1100 % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1101 ${automaton.table[op][indices]}, 1102 % endfor 1103 }, 1104 }, 1105% endfor 1106}; 1107 1108const struct transform *${pass_name}_transforms[] = { 1109% for i in range(len(automaton.state_patterns)): 1110 % if automaton.state_patterns[i]: 1111 ${pass_name}_state${i}_xforms, 1112 % else: 1113 NULL, 1114 % endif 1115% endfor 1116}; 1117 1118const 
uint16_t ${pass_name}_transform_counts[] = { 1119% for i in range(len(automaton.state_patterns)): 1120 % if automaton.state_patterns[i]: 1121 (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms), 1122 % else: 1123 0, 1124 % endif 1125% endfor 1126}; 1127 1128bool 1129${pass_name}(nir_shader *shader) 1130{ 1131 bool progress = false; 1132 bool condition_flags[${len(condition_list)}]; 1133 const nir_shader_compiler_options *options = shader->options; 1134 const shader_info *info = &shader->info; 1135 (void) options; 1136 (void) info; 1137 1138 % for index, condition in enumerate(condition_list): 1139 condition_flags[${index}] = ${condition}; 1140 % endfor 1141 1142 nir_foreach_function(function, shader) { 1143 if (function->impl) { 1144 progress |= nir_algebraic_impl(function->impl, condition_flags, 1145 ${pass_name}_transforms, 1146 ${pass_name}_transform_counts, 1147 ${pass_name}_table); 1148 } 1149 } 1150 1151 return progress; 1152} 1153""") 1154 1155 1156class AlgebraicPass(object): 1157 def __init__(self, pass_name, transforms): 1158 self.xforms = [] 1159 self.opcode_xforms = defaultdict(lambda : []) 1160 self.pass_name = pass_name 1161 1162 error = False 1163 1164 for xform in transforms: 1165 if not isinstance(xform, SearchAndReplace): 1166 try: 1167 xform = SearchAndReplace(xform) 1168 except: 1169 print("Failed to parse transformation:", file=sys.stderr) 1170 print(" " + str(xform), file=sys.stderr) 1171 traceback.print_exc(file=sys.stderr) 1172 print('', file=sys.stderr) 1173 error = True 1174 continue 1175 1176 self.xforms.append(xform) 1177 if xform.search.opcode in conv_opcode_types: 1178 dst_type = conv_opcode_types[xform.search.opcode] 1179 for size in type_sizes(dst_type): 1180 sized_opcode = xform.search.opcode + str(size) 1181 self.opcode_xforms[sized_opcode].append(xform) 1182 else: 1183 self.opcode_xforms[xform.search.opcode].append(xform) 1184 1185 # Check to make sure the search pattern does not unexpectedly contain 1186 # more commutative 
expressions than match_expression (nir_search.c) 1187 # can handle. 1188 comm_exprs = xform.search.comm_exprs 1189 1190 if xform.search.many_commutative_expressions: 1191 if comm_exprs <= nir_search_max_comm_ops: 1192 print("Transform expected to have too many commutative " \ 1193 "expression but did not " \ 1194 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1195 file=sys.stderr) 1196 print(" " + str(xform), file=sys.stderr) 1197 traceback.print_exc(file=sys.stderr) 1198 print('', file=sys.stderr) 1199 error = True 1200 else: 1201 if comm_exprs > nir_search_max_comm_ops: 1202 print("Transformation with too many commutative expressions " \ 1203 "({} > {}). Modify pattern or annotate with " \ 1204 "\"many-comm-expr\".".format(comm_exprs, 1205 nir_search_max_comm_ops), 1206 file=sys.stderr) 1207 print(" " + str(xform.search), file=sys.stderr) 1208 print("{}".format(xform.search.cond), file=sys.stderr) 1209 error = True 1210 1211 self.automaton = TreeAutomaton(self.xforms) 1212 1213 if error: 1214 sys.exit(1) 1215 1216 1217 def render(self): 1218 return _algebraic_pass_template.render(pass_name=self.pass_name, 1219 xforms=self.xforms, 1220 opcode_xforms=self.opcode_xforms, 1221 condition_list=condition_list, 1222 automaton=self.automaton, 1223 get_c_opcode=get_c_opcode, 1224 itertools=itertools) 1225