#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)

import ast
from collections import defaultdict
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes, type_sizes

# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c
nir_search_max_comm_ops = 8

# These opcodes are only employed by nir_search.  This provides a mapping from
# opcode to destination type.
conv_opcode_types = {
    'i2f' : 'float',
    'u2f' : 'float',
    'f2f' : 'float',
    'f2u' : 'uint',
    'f2i' : 'int',
    'u2u' : 'uint',
    'i2i' : 'int',
    'b2f' : 'float',
    'b2i' : 'int',
    'i2b' : 'bool',
    'f2b' : 'bool',
}

def get_c_opcode(op):
   # Generic conversion opcodes are handled by nir_search itself; everything
   # else maps to a real NIR opcode enum value.
   if op in conv_opcode_types:
      return 'nir_search_op_' + op
   else:
      return 'nir_op_' + op

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")

def type_bits(type_str):
   """Return the bit size encoded in a type string, or 0 if it is unsized."""
   m = _type_re.match(type_str)
   assert m.group('type')

   if m.group('bits') is None:
      return 0
   else:
      return int(m.group('bits'))

# Represents a set of variables, each with a unique id
class VarSet(object):
   def __init__(self):
      self.names = {}
      self.ids = itertools.count()
      self.immutable = False

   def __getitem__(self, name):
      # Allocate a fresh id on first use; once locked (replace expression),
      # only variables already seen in the search expression are allowed.
      if name not in self.names:
         assert not self.immutable, "Unknown replacement variable: " + name
         self.names[name] = next(self.ids)

      return self.names[name]

   def lock(self):
      self.immutable = True

class Value(object):
   @staticmethod
   def create(val, name_base, varset):
      """Construct the appropriate Value subclass for a pattern element.

      Tuples become Expressions, strings become Variables, and Python
      numbers/bools become Constants.
      """
      if isinstance(val, bytes):
         val = val.decode('utf-8')

      if isinstance(val, tuple):
         return Expression(val, name_base, varset)
      elif isinstance(val, Expression):
         return val
      elif isinstance(val, str):
         return Variable(val, name_base, varset)
      elif isinstance(val, (bool, float, int)):
         return Constant(val, name_base)
      # NOTE(review): any other type silently yields None — presumably callers
      # only pass the types above; confirm before relying on this.

   def __init__(self, val, name, type_str):
      self.in_val = str(val)
      self.name = name
      self.type_str = type_str

   def __str__(self):
      return self.in_val

   def get_bit_size(self):
      """Get the physical bit-size that has been chosen for this value, or if
      there is none, the canonical value which currently represents this
      bit-size class. Variables will be preferred, i.e. if there are any
      variables in the equivalence class, the canonical value will be a
      variable. We do this since we'll need to know which variable each value
      is equivalent to when constructing the replacement expression. This is
      the "find" part of the union-find algorithm.
      """
      bit_size = self

      # Chase the chain of class representatives until we hit either a
      # concrete int bit-size or a value with no parent.
      while isinstance(bit_size, Value):
         if bit_size._bit_size is None:
            break
         bit_size = bit_size._bit_size

      # Path compression: point directly at the representative.
      if bit_size is not self:
         self._bit_size = bit_size
      return bit_size

   def set_bit_size(self, other):
      """Make self.get_bit_size() return what other.get_bit_size() returned
      before calling this, or just "other" if it's a concrete bit-size. This is
      the "union" part of the union-find algorithm.
      """

      self_bit_size = self.get_bit_size()
      other_bit_size = other if isinstance(other, int) else other.get_bit_size()

      if self_bit_size == other_bit_size:
         return

      self_bit_size._bit_size = other_bit_size

   @property
   def type_enum(self):
      return "nir_search_value_" + self.type_str

   @property
   def c_type(self):
      return "nir_search_" + self.type_str

   def __c_name(self, cache):
      # A previous render may have remapped this value onto an identical
      # cached struct; use the cached name in that case.
      if cache is not None and self.name in cache:
         return cache[self.name]
      else:
         return self.name

   def c_value_ptr(self, cache):
      return "&{0}.value".format(self.__c_name(cache))

   def c_ptr(self, cache):
      return "&{0}".format(self.__c_name(cache))

   @property
   def c_bit_size(self):
      bit_size = self.get_bit_size()
      if isinstance(bit_size, int):
         return bit_size
      elif isinstance(bit_size, Variable):
         # Negative encoding means "the bit size of variable #index".
         return -bit_size.index - 1
      else:
         # If the bit-size class is neither a variable, nor an actual bit-size, then
         # - If it's in the search expression, we don't need to check anything
         # - If it's in the replace expression, either it's ambiguous (in which
         #   case we'd reject it), or it equals the bit-size of the search value
         # We represent these cases with a 0 bit-size.
         return 0

   __template = mako.template.Template("""{
   { ${val.type_enum}, ${val.c_bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
   ${val.swizzle()},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'}, ${'true' if val.exact else 'false'},
   ${val.comm_expr_idx}, ${val.comm_exprs},
   ${val.c_opcode()},
   { ${', '.join(src.c_value_ptr(cache) for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

   def render(self, cache):
      """Render this value as a C struct initializer, deduplicating identical
      initializers through the cache (struct text -> C identifier)."""
      struct_init = self.__template.render(val=self, cache=cache,
                                           Constant=Constant,
                                           Variable=Variable,
                                           Expression=Expression)
      if cache is not None and struct_init in cache:
         # If it's in the cache, register a name remap in the cache and render
         # only a comment saying it's been remapped
         cache[self.name] = cache[struct_init]
         return "/* {} -> {} in the cache */\n".format(self.name,
                                                       cache[struct_init])
      else:
         if cache is not None:
            cache[struct_init] = self.name
         return "static const {} {} = {}\n".format(self.c_type, self.name,
                                                   struct_init)

_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")

class Constant(Value):
   def __init__(self, val, name):
      Value.__init__(self, val, name, "constant")

      if isinstance(val, str):
         # String form, e.g. "0x3f@32": literal value plus optional bit size.
         m = _constant_re.match(val)
         self.value = ast.literal_eval(m.group('value'))
         self._bit_size = int(m.group('bits')) if m.group('bits') else None
      else:
         self.value = val
         self._bit_size = None

      if isinstance(self.value, bool):
         assert self._bit_size is None or self._bit_size == 1
         self._bit_size = 1

   def hex(self):
      """Return the C literal spelling of this constant's bit pattern."""
      if isinstance(self.value, bool):
         return 'NIR_TRUE' if self.value else 'NIR_FALSE'
      if isinstance(self.value, int):
         return hex(self.value)
      elif isinstance(self.value, float):
         # Reinterpret the double's bits as a 64-bit integer.
         return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
      else:
         assert False

   def type(self):
      if isinstance(self.value, bool):
         return "nir_type_bool"
      elif isinstance(self.value, int):
         return "nir_type_int"
      elif isinstance(self.value, float):
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two constants are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.value == other.value

# The $ at the end forces there to be an error if any part of the string
# doesn't match one of the field patterns.
_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?"
                          r"(?P<swiz>\.[xyzw]+)?"
                          r"$")

class Variable(Value):
   def __init__(self, val, name, varset):
      Value.__init__(self, val, name, "variable")

      m = _var_name_re.match(val)
      assert m and m.group('name') is not None, \
            "Malformed variable name \"{}\".".format(val)

      self.var_name = m.group('name')

      # Prevent common cases where someone puts quotes around a literal
      # constant.  If we want to support names that have numeric or
      # punctuation characters, we can make the first assertion more flexible.
      assert self.var_name.isalpha()
      assert self.var_name != 'True'
      assert self.var_name != 'False'

      self.is_constant = m.group('const') is not None
      self.cond = m.group('cond')
      self.required_type = m.group('type')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.swiz = m.group('swiz')

      # A bool variable is implicitly 1-bit unless an explicit (valid) bool
      # size was requested.
      if self.required_type == 'bool':
         if self._bit_size is not None:
            assert self._bit_size in type_sizes(self.required_type)
         else:
            self._bit_size = 1

      if self.required_type is not None:
         assert self.required_type in ('float', 'bool', 'int', 'uint')

      self.index = varset[self.var_name]

   def type(self):
      # Note: both 'int' and 'uint' map to nir_type_int here.
      if self.required_type == 'bool':
         return "nir_type_bool"
      elif self.required_type in ('int', 'uint'):
         return "nir_type_int"
      elif self.required_type == 'float':
         return "nir_type_float"

   def equivalent(self, other):
      """Check that two variables are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      """
      if not isinstance(other, type(self)):
         return False

      return self.index == other.index

   def swizzle(self):
      """Render the swizzle as a C array initializer (identity if absent)."""
      if self.swiz is not None:
         swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3,
                     'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3,
                     'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7,
                     'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11,
                     'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 }
         return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}'
      return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}'

_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")

class Expression(Value):
   def __init__(self, expr, name_base, varset):
      Value.__init__(self, expr, name_base, "expression")
      assert isinstance(expr, tuple)

      m = _opcode_re.match(expr[0])
      assert m and m.group('opcode') is not None

      self.opcode = m.group('opcode')
      self._bit_size = int(m.group('bits')) if m.group('bits') else None
      self.inexact = m.group('inexact') is not None
      self.exact = m.group('exact') is not None
      self.cond = m.group('cond')

      assert not self.inexact or not self.exact, \
            'Expression cannot be both exact and inexact.'

      # "many-comm-expr" isn't really a condition.  It's notification to the
      # generator that this pattern is known to have too many commutative
      # expressions, and an error should not be generated for this case.
      self.many_commutative_expressions = False
      if self.cond and self.cond.find("many-comm-expr") >= 0:
         # Split the condition into a comma-separated list.  Remove
         # "many-comm-expr".  If there is anything left, put it back together.
         c = self.cond[1:-1].split(",")
         c.remove("many-comm-expr")

         self.cond = "({})".format(",".join(c)) if c else None
         self.many_commutative_expressions = True

      self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset)
                       for (i, src) in enumerate(expr[1:]) ]

      # nir_search_expression::srcs is hard-coded to 4
      assert len(self.sources) <= 4

      if self.opcode in conv_opcode_types:
         assert self._bit_size is None, \
                'Expression cannot use an unsized conversion opcode with ' \
                'an explicit size; that\'s silly.'

      self.__index_comm_exprs(0)

   def equivalent(self, other):
      """Check that two expressions are equivalent.

      This check is much weaker than equality.  One generally cannot be
      used in place of the other.  Using this implementation for the __eq__
      will break BitSizeValidator.

      This implementation does not check for equivalence due to commutativity,
      but it could.

      """
      if not isinstance(other, type(self)):
         return False

      if len(self.sources) != len(other.sources):
         return False

      if self.opcode != other.opcode:
         return False

      return all(s.equivalent(o) for s, o in zip(self.sources, other.sources))

   def __index_comm_exprs(self, base_idx):
      """Recursively count and index commutative expressions
      """
      self.comm_exprs = 0

      # A note about the explicit "len(self.sources)" check. The list of
      # sources comes from user input, and that input might be bad.  Check
      # that the expected second source exists before accessing it. Without
      # this check, a unit test that does "('iadd', 'a')" will crash.
      if self.opcode not in conv_opcode_types and \
         "2src_commutative" in opcodes[self.opcode].algebraic_properties and \
         len(self.sources) >= 2 and \
         not self.sources[0].equivalent(self.sources[1]):
         self.comm_expr_idx = base_idx
         self.comm_exprs += 1
      else:
         self.comm_expr_idx = -1

      for s in self.sources:
         if isinstance(s, Expression):
            s.__index_comm_exprs(base_idx + self.comm_exprs)
            self.comm_exprs += s.comm_exprs

      return self.comm_exprs

   def c_opcode(self):
      return get_c_opcode(self.opcode)

   def render(self, cache):
      # Sources must be rendered (and declared) before the expression that
      # references them.
      srcs = "\n".join(src.render(cache) for src in self.sources)
      return srcs + super(Expression, self).render(cache)

class BitSizeValidator(object):
   """A class for validating bit sizes of expressions.

   NIR supports multiple bit-sizes on expressions in order to handle things
   such as fp64.  The source and destination of every ALU operation is
   assigned a type and that type may or may not specify a bit size.
Sources
   and destinations whose type does not specify a bit size are considered
   "unsized" and automatically take on the bit size of the corresponding
   register or SSA value.  NIR has two simple rules for bit sizes that are
   validated by nir_validator:

    1) A given SSA def or register has a single bit size that is respected by
       everything that reads from it or writes to it.

    2) The bit sizes of all unsized inputs/outputs on any given ALU
       instruction must match.  They need not match the sized inputs or
       outputs but they must match each other.

   In order to keep nir_algebraic relatively simple and easy-to-use,
   nir_search supports a type of bit-size inference based on the two rules
   above.  This is similar to type inference in many common programming
   languages.  If, for instance, you are constructing an add operation and you
   know the second source is 16-bit, then you know that the other source and
   the destination must also be 16-bit.  There are, however, cases where this
   inference can be ambiguous or contradictory.  Consider, for instance, the
   following transformation:

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   This transformation can potentially cause a problem because usub_borrow is
   well-defined for any bit-size of integer.  However, b2i always generates a
   32-bit result so it could end up replacing a 64-bit expression with one
   that takes two 64-bit values and produces a 32-bit value.  As another
   example, consider this expression:

   (('bcsel', a, b, 0), ('iand', a, b))

   In this case, in the search expression a must be 32-bit but b can
   potentially have any bit size.  If we had a 64-bit b value, we would end up
   trying to and a 32-bit value with a 64-bit value which would be invalid

   This class solves that problem by providing a validation layer that proves
   that a given search-and-replace operation is 100% well-defined before we
   generate any code.  This ensures that bugs are caught at compile time
   rather than at run time.

   Each value maintains a "bit-size class", which is either an actual bit size
   or an equivalence class with other values that must have the same bit size.
   The validator works by combining bit-size classes with each other according
   to the NIR rules outlined above, checking that there are no inconsistencies.
   When doing this for the replacement expression, we make sure to never change
   the equivalence class of any of the search values. We could make the example
   transforms above work by doing some extra run-time checking of the search
   expression, but we make the user specify those constraints themselves, to
   avoid any surprises. Since the replacement bitsizes can only be connected to
   the source bitsize via variables (variables must have the same bitsize in
   the source and replacement expressions) or the roots of the expression (the
   replacement expression must produce the same bit size as the search
   expression), we prevent merging a variable with anything when processing the
   replacement expression, or specializing the search bitsize
   with anything. The former prevents

   (('bcsel', a, b, 0), ('iand', a, b))

   from being allowed, since we'd have to merge the bitsizes for a and b due to
   the 'iand', while the latter prevents

   (('usub_borrow', a, b), ('b2i@32', ('ult', a, b)))

   from being allowed, since the search expression has the bit size of a and b,
   which can't be specialized to 32 which is the bitsize of the replace
   expression.  It also prevents something like:

   (('b2i', ('i2b', a)), ('ineq', a, 0))

   since the bitsize of 'b2i', which can be anything, can't be specialized to
   the bitsize of a.

   After doing all this, we check that every subexpression of the replacement
   was assigned a constant bitsize, the bitsize of a variable, or the bitsize
   of the search expression, since those are the things that are known when
   constructing the replacement expression.  Finally, we record the bitsize
   needed in nir_search_value so that we know what to do when building the
   replacement expression.
   """

   def __init__(self, varset):
      # One union-find class slot per distinct variable id in the varset.
      self._var_classes = [None] * len(varset.names)

   def compare_bitsizes(self, a, b):
      """Determines which bitsize class is a specialization of the other, or
      whether neither is. When we merge two different bitsizes, the
      less-specialized bitsize always points to the more-specialized one, so
      that calling get_bit_size() always gets you the most specialized bitsize.
      The specialization partial order is given by:
      - Physical bitsizes are always the most specialized, and a different
        bitsize can never specialize another.
      - In the search expression, variables can always be specialized to each
        other and to physical bitsizes. In the replace expression, we disallow
        this to avoid adding extra constraints to the search expression that
        the user didn't specify.
      - Expressions and constants without a bitsize can always be specialized to
        each other and variables, but not the other way around.

      We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b,
      and None if they are not comparable (neither a <= b nor b <= a).
      """
      if isinstance(a, int):
         if isinstance(b, int):
            return 0 if a == b else None
         elif isinstance(b, Variable):
            # Variables may only be pinned to a physical size while processing
            # the search expression.
            return -1 if self.is_search else None
         else:
            return -1
      elif isinstance(a, Variable):
         if isinstance(b, int):
            return 1 if self.is_search else None
         elif isinstance(b, Variable):
            return 0 if self.is_search or a.index == b.index else None
         else:
            return -1
      else:
         # a is an expression or constant without an explicit bitsize: the
         # least specialized kind of class.
         if isinstance(b, int):
            return 1
         elif isinstance(b, Variable):
            return 1
         else:
            return 0

   def unify_bit_size(self, a, b, error_msg):
      """Record that a must have the same bit-size as b. If both
      have been assigned conflicting physical bit-sizes, call "error_msg" with
      the bit-sizes of self and other to get a message and raise an error.
      In the replace expression, disallow merging variables with other
      variables and physical bit-sizes as well.
      """
      a_bit_size = a.get_bit_size()
      b_bit_size = b if isinstance(b, int) else b.get_bit_size()

      cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size)

      assert cmp_result is not None, \
         error_msg(a_bit_size, b_bit_size)

      # Point the less-specialized class at the more-specialized one.
      if cmp_result < 0:
         b_bit_size.set_bit_size(a)
      elif not isinstance(a_bit_size, int):
         a_bit_size.set_bit_size(b)

   def merge_variables(self, val):
      """Perform the first part of type inference by merging all the different
      uses of the same variable. We always do this as if we're in the search
      expression, even if we're actually not, since otherwise we'd get errors
      if the search expression specified some constraint but the replace
      expression didn't, because we'd be merging a variable and a constant.
      """
      if isinstance(val, Variable):
         if self._var_classes[val.index] is None:
            self._var_classes[val.index] = val
         else:
            other = self._var_classes[val.index]
            self.unify_bit_size(other, val,
                  lambda other_bit_size, bit_size:
                     'Variable {} has conflicting bit size requirements: ' \
                     'it must have bit size {} and {}'.format(
                        val.var_name, other_bit_size, bit_size))
      elif isinstance(val, Expression):
         for src in val.sources:
            self.merge_variables(src)

   def validate_value(self, val):
      """Validate an expression by performing classic Hindley-Milner
      type inference on bitsizes. This will detect if there are any conflicting
      requirements, and unify variables so that we know which variables must
      have the same bitsize. If we're operating on the replace expression, we
      will refuse to merge different variables together or merge a variable
      with a constant, in order to prevent surprises due to rules unexpectedly
      not matching at runtime.
      """
      if not isinstance(val, Expression):
         return

      # Generic conversion ops are special in that they have a single unsized
      # source and an unsized destination and the two don't have to match.
      # This means there's no validation or unioning to do here besides the
      # len(val.sources) check.
      if val.opcode in conv_opcode_types:
         assert len(val.sources) == 1, \
            "Expression {} has {} sources, expected 1".format(
               val, len(val.sources))
         self.validate_value(val.sources[0])
         return

      nir_op = opcodes[val.opcode]
      assert len(val.sources) == nir_op.num_inputs, \
         "Expression {} has {} sources, expected {}".format(
            val, len(val.sources), nir_op.num_inputs)

      for src in val.sources:
         self.validate_value(src)

      dst_type_bits = type_bits(nir_op.output_type)

      # First, unify all the sources. That way, an error coming up because two
      # sources have an incompatible bit-size won't produce an error message
      # involving the destination.
      first_unsized_src = None
      for src_type, src in zip(nir_op.input_types, val.sources):
         src_type_bits = type_bits(src_type)
         if src_type_bits == 0:
            if first_unsized_src is None:
               first_unsized_src = src
               continue

            if self.is_search:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Source {} of {} must have bit size {}, while source {} ' \
                     'must have incompatible bit size {}'.format(
                        first_unsized_src, val, first_unsized_src_bit_size,
                        src, src_bit_size))
            else:
               self.unify_bit_size(first_unsized_src, src,
                  lambda first_unsized_src_bit_size, src_bit_size:
                     'Sources {} (bit size of {}) and {} (bit size of {}) ' \
                     'of {} may not have the same bit size when building the ' \
                     'replacement expression.'.format(
                        first_unsized_src, first_unsized_src_bit_size, src,
                        src_bit_size, val))
         else:
            if self.is_search:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} must have {} bits, but as a source of nir_op_{} '\
                     'it must have {} bits'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))
            else:
               self.unify_bit_size(src, src_type_bits,
                  lambda src_bit_size, unused:
                     '{} has the bit size of {}, but as a source of ' \
                     'nir_op_{} it must have {} bits, which may not be the ' \
                     'same'.format(
                        src, src_bit_size, nir_op.name, src_type_bits))

      if dst_type_bits == 0:
         if first_unsized_src is not None:
            if self.is_search:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have the bit size of {}, while its source {} ' \
                     'must have incompatible bit size {}'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
            else:
               self.unify_bit_size(val, first_unsized_src,
                  lambda val_bit_size, src_bit_size:
                     '{} must have {} bits, but its source {} ' \
                     '(bit size of {}) may not have that bit size ' \
                     'when building the replacement.'.format(
                        val, val_bit_size, first_unsized_src, src_bit_size))
      else:
         self.unify_bit_size(val, dst_type_bits,
            lambda dst_bit_size, unused:
               '{} must have {} bits, but as a destination of nir_op_{} ' \
               'it must have {} bits'.format(
                  val, dst_bit_size, nir_op.name, dst_type_bits))

   def validate_replace(self, val, search):
      # Every replacement value must end up with a known bitsize: a constant,
      # a variable's bitsize, or the search expression's bitsize.
      bit_size = val.get_bit_size()
      assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \
             bit_size == search.get_bit_size(), \
             'Ambiguous bit size for replacement value {}: ' \
             'it cannot be deduced from a variable, a fixed bit size ' \
             'somewhere, or the search expression.'.format(val)

      if isinstance(val, Expression):
         for src in val.sources:
            self.validate_replace(src, search)

   def validate(self, search, replace):
      self.is_search = True
      self.merge_variables(search)
      self.merge_variables(replace)
      self.validate_value(search)

      self.is_search = False
      self.validate_value(replace)

      # Check that search is always more specialized than replace. Note that
      # we're doing this in replace mode, disallowing merging variables.
      search_bit_size = search.get_bit_size()
      replace_bit_size = replace.get_bit_size()
      cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size)

      assert cmp_result is not None and cmp_result <= 0, \
         'The search expression bit size {} and replace expression ' \
         'bit size {} may not be the same'.format(
            search_bit_size, replace_bit_size)

      # The replacement takes on the search expression's bitsize class.
      replace.set_bit_size(search)

      self.validate_replace(replace, search)

_optimization_ids = itertools.count()

condition_list = ['true']

class SearchAndReplace(object):
   # Wraps one (search, replace[, condition]) transform tuple: parses both
   # expressions against a shared VarSet and bit-size-validates the pair.
   def __init__(self, transform):
      self.id = next(_optimization_ids)

      search = transform[0]
      replace = transform[1]
      if len(transform) > 2:
         self.condition = transform[2]
      else:
         self.condition = 'true'

      # Conditions are interned in the global condition_list; the generated C
      # references them by index.
      if self.condition not in condition_list:
         condition_list.append(self.condition)
      self.condition_index = condition_list.index(self.condition)

      varset = VarSet()
      if isinstance(search, Expression):
         self.search = search
      else:
         self.search = Expression(search, "search{0}".format(self.id), varset)

      # After the search side is parsed, no new variables may be introduced by
      # the replacement.
      varset.lock()

      if isinstance(replace, Value):
         self.replace = replace
      else:
         self.replace = Value.create(replace, "replace{0}".format(self.id), varset)

      BitSizeValidator(varset).validate(self.search, self.replace)

class TreeAutomaton(object):
   """This class calculates a bottom-up tree automaton to quickly search for
   the left-hand sides of transforms. Tree automatons are a generalization of
   classical NFA's and DFA's, where the transition function determines the
   state of the parent node based on the state of its children. We construct a
   deterministic automaton to match patterns, using a similar algorithm to the
   classical NFA to DFA construction. At the moment, it only matches opcodes
   and constants (without checking the actual value), leaving more detailed
   checking to the search function which actually checks the leaves. The
   automaton acts as a quick filter for the search function, requiring only n
   + 1 table lookups for each n-source operation. The implementation is based
   on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit."
   In the language of that reference, this is a frontier-to-root deterministic
   automaton using only symbol filtering. The filtering is crucial to reduce
   both the time taken to generate the tables and the size of the tables.
   """
   def __init__(self, transforms):
      self.patterns = [t.search for t in transforms]
      self._compute_items()
      self._build_table()
      #print('num items: {}'.format(len(set(self.items.values()))))
      #print('num states: {}'.format(len(self.states)))
      #for state, patterns in zip(self.states, self.patterns):
      #   print('{}: num patterns: {}'.format(state, len(patterns)))

   class IndexMap(object):
      """An indexed list of objects, where one can either lookup an object by
      index or find the index associated to an object quickly using a hash
      table. Compared to a list, it has a constant time index(). Compared to a
      set, it provides a stable iteration order.
      """
      def __init__(self, iterable=()):
         self.objects = []
         self.map = {}
         for obj in iterable:
            self.add(obj)

      def __getitem__(self, i):
         return self.objects[i]

      def __contains__(self, obj):
         return obj in self.map

      def __len__(self):
         return len(self.objects)

      def __iter__(self):
         return iter(self.objects)

      def clear(self):
         self.objects = []
         self.map.clear()

      def index(self, obj):
         return self.map[obj]

      def add(self, obj):
         # Deduplicating append: returns the (existing or new) index of obj.
         if obj in self.map:
            return self.map[obj]
         else:
            index = len(self.objects)
            self.objects.append(obj)
            self.map[obj] = index
            return index

      def __repr__(self):
         return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])'

   class Item(object):
      """This represents an "item" in the language of "Tree Automatons." This
      is just a subtree of some pattern, which represents a potential partial
      match at runtime. We deduplicate them, so that identical subtrees of
      different patterns share the same object, and store some extra
      information needed for the main algorithm as well.
      """
      def __init__(self, opcode, children):
         self.opcode = opcode
         self.children = children
         # These are the indices of patterns for which this item is the root node.
         self.patterns = []
         # This is the set of opcodes for parents of this item. Used to speed
         # up filtering.
         self.parent_ops = set()

      def __str__(self):
         return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')'

      def __repr__(self):
         return str(self)

   def _compute_items(self):
      """Build a set of all possible items, deduplicating them."""
      # This is a map from (opcode, sources) to item.
      self.items = {}

      # The set of all opcodes used by the patterns. Used later to avoid
      # building and emitting all the tables for opcodes that aren't used.
      self.opcodes = self.IndexMap()

      def get_item(opcode, children, pattern=None):
         commutative = len(children) >= 2 \
               and "2src_commutative" in opcodes[opcode].algebraic_properties
         item = self.items.setdefault((opcode, children),
                                      self.Item(opcode, children))
         # Commutative ops match with their first two children swapped, too.
         if commutative:
            self.items[opcode, (children[1], children[0]) + children[2:]] = item
         if pattern is not None:
            item.patterns.append(pattern)
         return item

      self.wildcard = get_item("__wildcard", ())
      self.const = get_item("__const", ())

      def process_subpattern(src, pattern=None):
         if isinstance(src, Constant):
            # Note: we throw away the actual constant value!
            return self.const
         elif isinstance(src, Variable):
            if src.is_constant:
               return self.const
            else:
               # Note: we throw away which variable it is here! This special
               # item is equivalent to nu in "Tree Automatons."
               return self.wildcard
         else:
            assert isinstance(src, Expression)
            opcode = src.opcode
            stripped = opcode.rstrip('0123456789')
            if stripped in conv_opcode_types:
               # Matches that use conversion opcodes with a specific type,
               # like f2b1, are tricky.  Either we construct the automaton to
               # match specific NIR opcodes like nir_op_f2b1, in which case we
               # need to create separate items for each possible NIR opcode
               # for patterns that have a generic opcode like f2b, or we
               # construct it to match the search opcode, in which case we
               # need to map f2b1 to f2b when constructing the automaton. Here
               # we do the latter.
913 opcode = stripped 914 self.opcodes.add(opcode) 915 children = tuple(process_subpattern(c) for c in src.sources) 916 item = get_item(opcode, children, pattern) 917 for i, child in enumerate(children): 918 child.parent_ops.add(opcode) 919 return item 920 921 for i, pattern in enumerate(self.patterns): 922 process_subpattern(pattern, i) 923 924 def _build_table(self): 925 """This is the core algorithm which builds up the transition table. It 926 is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 927 Comp_a and Filt_{a,i} using integers to identify match sets." It 928 simultaneously builds up a list of all possible "match sets" or 929 "states", where each match set represents the set of Item's that match a 930 given instruction, and builds up the transition table between states. 931 """ 932 # Map from opcode + filtered state indices to transitioned state. 933 self.table = defaultdict(dict) 934 # Bijection from state to index. q in the original algorithm is 935 # len(self.states) 936 self.states = self.IndexMap() 937 # List of pattern matches for each state index. 938 self.state_patterns = [] 939 # Map from state index to filtered state index for each opcode. 940 self.filter = defaultdict(list) 941 # Bijections from filtered state to filtered state index for each 942 # opcode, called the "representor sets" in the original algorithm. 943 # q_{a,j} in the original algorithm is len(self.rep[op]). 944 self.rep = defaultdict(self.IndexMap) 945 946 # Everything in self.states with a index at least worklist_index is part 947 # of the worklist of newly created states. There is also a worklist of 948 # newly fitered states for each opcode, for which worklist_indices 949 # serves a similar purpose. worklist_index corresponds to p in the 950 # original algorithm, while worklist_indices is p_{a,j} (although since 951 # we only filter by opcode/symbol, it's really just p_a). 
952 self.worklist_index = 0 953 worklist_indices = defaultdict(lambda: 0) 954 955 # This is the set of opcodes for which the filtered worklist is non-empty. 956 # It's used to avoid scanning opcodes for which there is nothing to 957 # process when building the transition table. It corresponds to new_a in 958 # the original algorithm. 959 new_opcodes = self.IndexMap() 960 961 # Process states on the global worklist, filtering them for each opcode, 962 # updating the filter tables, and updating the filtered worklists if any 963 # new filtered states are found. Similar to ComputeRepresenterSets() in 964 # the original algorithm, although that only processes a single state. 965 def process_new_states(): 966 while self.worklist_index < len(self.states): 967 state = self.states[self.worklist_index] 968 969 # Calculate pattern matches for this state. Each pattern is 970 # assigned to a unique item, so we don't have to worry about 971 # deduplicating them here. However, we do have to sort them so 972 # that they're visited at runtime in the order they're specified 973 # in the source. 974 patterns = list(sorted(p for item in state for p in item.patterns)) 975 assert len(self.state_patterns) == self.worklist_index 976 self.state_patterns.append(patterns) 977 978 # calculate filter table for this state, and update filtered 979 # worklists. 980 for op in self.opcodes: 981 filt = self.filter[op] 982 rep = self.rep[op] 983 filtered = frozenset(item for item in state if \ 984 op in item.parent_ops) 985 if filtered in rep: 986 rep_index = rep.index(filtered) 987 else: 988 rep_index = rep.add(filtered) 989 new_opcodes.add(op) 990 assert len(filt) == self.worklist_index 991 filt.append(rep_index) 992 self.worklist_index += 1 993 994 # There are two start states: one which can only match as a wildcard, 995 # and one which can match as a wildcard or constant. These will be the 996 # states of intrinsics/other instructions and load_const instructions, 997 # respectively. 
The indices of these must match the definitions of 998 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 999 # initialize things correctly. 1000 self.states.add(frozenset((self.wildcard,))) 1001 self.states.add(frozenset((self.const,self.wildcard))) 1002 process_new_states() 1003 1004 while len(new_opcodes) > 0: 1005 for op in new_opcodes: 1006 rep = self.rep[op] 1007 table = self.table[op] 1008 op_worklist_index = worklist_indices[op] 1009 if op in conv_opcode_types: 1010 num_srcs = 1 1011 else: 1012 num_srcs = opcodes[op].num_inputs 1013 1014 # Iterate over all possible source combinations where at least one 1015 # is on the worklist. 1016 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1017 if all(src_idx < op_worklist_index for src_idx in src_indices): 1018 continue 1019 1020 srcs = tuple(rep[src_idx] for src_idx in src_indices) 1021 1022 # Try all possible pairings of source items and add the 1023 # corresponding parent items. This is Comp_a from the paper. 1024 parent = set(self.items[op, item_srcs] for item_srcs in 1025 itertools.product(*srcs) if (op, item_srcs) in self.items) 1026 1027 # We could always start matching something else with a 1028 # wildcard. This is Cl from the paper. 
1029 parent.add(self.wildcard) 1030 1031 table[src_indices] = self.states.add(frozenset(parent)) 1032 worklist_indices[op] = len(rep) 1033 new_opcodes.clear() 1034 process_new_states() 1035 1036_algebraic_pass_template = mako.template.Template(""" 1037#include "nir.h" 1038#include "nir_builder.h" 1039#include "nir_search.h" 1040#include "nir_search_helpers.h" 1041 1042/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1043 * transforms: 1044% for xform in xforms: 1045 * ${xform.search} => ${xform.replace} 1046% endfor 1047 */ 1048 1049<% cache = {} %> 1050% for xform in xforms: 1051 ${xform.search.render(cache)} 1052 ${xform.replace.render(cache)} 1053% endfor 1054 1055% for state_id, state_xforms in enumerate(automaton.state_patterns): 1056% if state_xforms: # avoid emitting a 0-length array for MSVC 1057static const struct transform ${pass_name}_state${state_id}_xforms[] = { 1058% for i in state_xforms: 1059 { ${xforms[i].search.c_ptr(cache)}, ${xforms[i].replace.c_value_ptr(cache)}, ${xforms[i].condition_index} }, 1060% endfor 1061}; 1062% endif 1063% endfor 1064 1065static const struct per_op_table ${pass_name}_table[nir_num_search_ops] = { 1066% for op in automaton.opcodes: 1067 [${get_c_opcode(op)}] = { 1068 .filter = (uint16_t []) { 1069 % for e in automaton.filter[op]: 1070 ${e}, 1071 % endfor 1072 }, 1073 <% 1074 num_filtered = len(automaton.rep[op]) 1075 %> 1076 .num_filtered_states = ${num_filtered}, 1077 .table = (uint16_t []) { 1078 <% 1079 num_srcs = len(next(iter(automaton.table[op]))) 1080 %> 1081 % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1082 ${automaton.table[op][indices]}, 1083 % endfor 1084 }, 1085 }, 1086% endfor 1087}; 1088 1089const struct transform *${pass_name}_transforms[] = { 1090% for i in range(len(automaton.state_patterns)): 1091 % if automaton.state_patterns[i]: 1092 ${pass_name}_state${i}_xforms, 1093 % else: 1094 NULL, 1095 % endif 1096% endfor 1097}; 1098 1099const 
uint16_t ${pass_name}_transform_counts[] = { 1100% for i in range(len(automaton.state_patterns)): 1101 % if automaton.state_patterns[i]: 1102 (uint16_t)ARRAY_SIZE(${pass_name}_state${i}_xforms), 1103 % else: 1104 0, 1105 % endif 1106% endfor 1107}; 1108 1109bool 1110${pass_name}(nir_shader *shader) 1111{ 1112 bool progress = false; 1113 bool condition_flags[${len(condition_list)}]; 1114 const nir_shader_compiler_options *options = shader->options; 1115 const shader_info *info = &shader->info; 1116 (void) options; 1117 (void) info; 1118 1119 % for index, condition in enumerate(condition_list): 1120 condition_flags[${index}] = ${condition}; 1121 % endfor 1122 1123 nir_foreach_function(function, shader) { 1124 if (function->impl) { 1125 progress |= nir_algebraic_impl(function->impl, condition_flags, 1126 ${pass_name}_transforms, 1127 ${pass_name}_transform_counts, 1128 ${pass_name}_table); 1129 } 1130 } 1131 1132 return progress; 1133} 1134""") 1135 1136 1137class AlgebraicPass(object): 1138 def __init__(self, pass_name, transforms): 1139 self.xforms = [] 1140 self.opcode_xforms = defaultdict(lambda : []) 1141 self.pass_name = pass_name 1142 1143 error = False 1144 1145 for xform in transforms: 1146 if not isinstance(xform, SearchAndReplace): 1147 try: 1148 xform = SearchAndReplace(xform) 1149 except: 1150 print("Failed to parse transformation:", file=sys.stderr) 1151 print(" " + str(xform), file=sys.stderr) 1152 traceback.print_exc(file=sys.stderr) 1153 print('', file=sys.stderr) 1154 error = True 1155 continue 1156 1157 self.xforms.append(xform) 1158 if xform.search.opcode in conv_opcode_types: 1159 dst_type = conv_opcode_types[xform.search.opcode] 1160 for size in type_sizes(dst_type): 1161 sized_opcode = xform.search.opcode + str(size) 1162 self.opcode_xforms[sized_opcode].append(xform) 1163 else: 1164 self.opcode_xforms[xform.search.opcode].append(xform) 1165 1166 # Check to make sure the search pattern does not unexpectedly contain 1167 # more commutative 
expressions than match_expression (nir_search.c) 1168 # can handle. 1169 comm_exprs = xform.search.comm_exprs 1170 1171 if xform.search.many_commutative_expressions: 1172 if comm_exprs <= nir_search_max_comm_ops: 1173 print("Transform expected to have too many commutative " \ 1174 "expression but did not " \ 1175 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1176 file=sys.stderr) 1177 print(" " + str(xform), file=sys.stderr) 1178 traceback.print_exc(file=sys.stderr) 1179 print('', file=sys.stderr) 1180 error = True 1181 else: 1182 if comm_exprs > nir_search_max_comm_ops: 1183 print("Transformation with too many commutative expressions " \ 1184 "({} > {}). Modify pattern or annotate with " \ 1185 "\"many-comm-expr\".".format(comm_exprs, 1186 nir_search_max_comm_ops), 1187 file=sys.stderr) 1188 print(" " + str(xform.search), file=sys.stderr) 1189 print("{}".format(xform.search.cond), file=sys.stderr) 1190 error = True 1191 1192 self.automaton = TreeAutomaton(self.xforms) 1193 1194 if error: 1195 sys.exit(1) 1196 1197 1198 def render(self): 1199 return _algebraic_pass_template.render(pass_name=self.pass_name, 1200 xforms=self.xforms, 1201 opcode_xforms=self.opcode_xforms, 1202 condition_list=condition_list, 1203 automaton=self.automaton, 1204 get_c_opcode=get_c_opcode, 1205 itertools=itertools) 1206