1# 2# Copyright (C) 2014 Intel Corporation 3# 4# Permission is hereby granted, free of charge, to any person obtaining a 5# copy of this software and associated documentation files (the "Software"), 6# to deal in the Software without restriction, including without limitation 7# the rights to use, copy, modify, merge, publish, distribute, sublicense, 8# and/or sell copies of the Software, and to permit persons to whom the 9# Software is furnished to do so, subject to the following conditions: 10# 11# The above copyright notice and this permission notice (including the next 12# paragraph) shall be included in all copies or substantial portions of the 13# Software. 14# 15# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 18# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 20# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 21# IN THE SOFTWARE. 22 23import ast 24from collections import defaultdict 25import itertools 26import struct 27import sys 28import mako.template 29import re 30import traceback 31 32from nir_opcodes import opcodes, type_sizes 33 34# This should be the same as NIR_SEARCH_MAX_COMM_OPS in nir_search.c 35nir_search_max_comm_ops = 8 36 37# These opcodes are only employed by nir_search. This provides a mapping from 38# opcode to destination type. 39conv_opcode_types = { 40 'i2f' : 'float', 41 'u2f' : 'float', 42 'f2f' : 'float', 43 'f2u' : 'uint', 44 'f2i' : 'int', 45 'u2u' : 'uint', 46 'i2i' : 'int', 47 'b2f' : 'float', 48 'b2i' : 'int', 49 'i2b' : 'bool', 50 'f2b' : 'bool', 51} 52 53def get_cond_index(conds, cond): 54 if cond: 55 if cond in conds: 56 return conds[cond] 57 else: 58 cond_index = len(conds) 59 conds[cond] = cond_index 60 return cond_index 61 else: 62 return -1 63 64def get_c_opcode(op): 65 if op in conv_opcode_types: 66 return 'nir_search_op_' + op 67 else: 68 return 'nir_op_' + op 69 70_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?") 71 72def type_bits(type_str): 73 m = _type_re.match(type_str) 74 assert m.group('type') 75 76 if m.group('bits') is None: 77 return 0 78 else: 79 return int(m.group('bits')) 80 81# Represents a set of variables, each with a unique id 82class VarSet(object): 83 def __init__(self): 84 self.names = {} 85 self.ids = itertools.count() 86 self.immutable = False; 87 88 def __getitem__(self, name): 89 if name not in self.names: 90 assert not self.immutable, "Unknown replacement variable: " + name 91 self.names[name] = next(self.ids) 92 93 return self.names[name] 94 95 def lock(self): 96 self.immutable = True 97 98class SearchExpression(object): 99 def __init__(self, expr): 100 self.opcode = expr[0] 101 self.sources = expr[1:] 102 self.ignore_exact = False 103 104 @staticmethod 105 def create(val): 106 if isinstance(val, tuple): 107 return SearchExpression(val) 108 else: 109 assert(isinstance(val, SearchExpression)) 110 return val 111 112 def __repr__(self): 113 l = [self.opcode, *self.sources] 114 if self.ignore_exact: 115 l.append('ignore_exact') 116 return repr((*l,)) 117 118class Value(object): 119 @staticmethod 120 def create(val, name_base, varset, algebraic_pass): 121 if isinstance(val, bytes): 122 val = val.decode('utf-8') 123 124 if isinstance(val, tuple) or isinstance(val, SearchExpression): 125 return Expression(val, name_base, varset, algebraic_pass) 126 elif isinstance(val, Expression): 127 return val 128 elif isinstance(val, str): 129 return Variable(val, name_base, varset, algebraic_pass) 130 elif isinstance(val, (bool, float, int)): 131 return Constant(val, name_base) 132 133 def __init__(self, val, name, type_str): 134 self.in_val = str(val) 135 self.name = name 136 self.type_str = type_str 137 138 def __str__(self): 139 return self.in_val 140 141 def get_bit_size(self): 142 """Get the physical bit-size that has been chosen for this value, or if 143 there is none, the canonical value which currently represents this 144 bit-size class. Variables will be preferred, i.e. if there are any 145 variables in the equivalence class, the canonical value will be a 146 variable. We do this since we'll need to know which variable each value 147 is equivalent to when constructing the replacement expression. This is 148 the "find" part of the union-find algorithm. 149 """ 150 bit_size = self 151 152 while isinstance(bit_size, Value): 153 if bit_size._bit_size is None: 154 break 155 bit_size = bit_size._bit_size 156 157 if bit_size is not self: 158 self._bit_size = bit_size 159 return bit_size 160 161 def set_bit_size(self, other): 162 """Make self.get_bit_size() return what other.get_bit_size() return 163 before calling this, or just "other" if it's a concrete bit-size. This is 164 the "union" part of the union-find algorithm. 165 """ 166 167 self_bit_size = self.get_bit_size() 168 other_bit_size = other if isinstance(other, int) else other.get_bit_size() 169 170 if self_bit_size == other_bit_size: 171 return 172 173 self_bit_size._bit_size = other_bit_size 174 175 @property 176 def type_enum(self): 177 return "nir_search_value_" + self.type_str 178 179 @property 180 def c_bit_size(self): 181 bit_size = self.get_bit_size() 182 if isinstance(bit_size, int): 183 return bit_size 184 elif isinstance(bit_size, Variable): 185 return -bit_size.index - 1 186 else: 187 # If the bit-size class is neither a variable, nor an actual bit-size, then 188 # - If it's in the search expression, we don't need to check anything 189 # - If it's in the replace expression, either it's ambiguous (in which 190 # case we'd reject it), or it equals the bit-size of the search value 191 # We represent these cases with a 0 bit-size. 192 return 0 193 194 __template = mako.template.Template(""" { .${val.type_str} = { 195 { ${val.type_enum}, ${val.c_bit_size} }, 196% if isinstance(val, Constant): 197 ${val.type()}, { ${val.hex()} /* ${val.value} */ }, 198% elif isinstance(val, Variable): 199 ${val.index}, /* ${val.var_name} */ 200 ${'true' if val.is_constant else 'false'}, 201 ${val.type() or 'nir_type_invalid' }, 202 ${val.cond_index}, 203 ${val.swizzle()}, 204% elif isinstance(val, Expression): 205 ${'true' if val.inexact else 'false'}, 206 ${'true' if val.exact else 'false'}, 207 ${'true' if val.ignore_exact else 'false'}, 208 ${'true' if val.nsz else 'false'}, 209 ${'true' if val.nnan else 'false'}, 210 ${'true' if val.ninf else 'false'}, 211 ${'true' if val.swizzle_y else 'false'}, 212 ${val.c_opcode()}, 213 ${val.comm_expr_idx}, ${val.comm_exprs}, 214 { ${', '.join(src.array_index for src in val.sources)} }, 215 ${val.cond_index}, 216% endif 217 } }, 218""") 219 220 def render(self, cache): 221 struct_init = self.__template.render(val=self, 222 Constant=Constant, 223 Variable=Variable, 224 Expression=Expression) 225 if struct_init in cache: 226 # If it's in the cache, register a name remap in the cache and render 227 # only a comment saying it's been remapped 228 self.array_index = cache[struct_init] 229 return " /* {} -> {} in the cache */\n".format(self.name, 230 cache[struct_init]) 231 else: 232 self.array_index = str(cache["next_index"]) 233 cache[struct_init] = self.array_index 234 cache["next_index"] += 1 235 return struct_init 236 237_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?") 238 239class Constant(Value): 240 def __init__(self, val, name): 241 Value.__init__(self, val, name, "constant") 242 243 if isinstance(val, (str)): 244 m = _constant_re.match(val) 245 self.value = ast.literal_eval(m.group('value')) 246 self._bit_size = int(m.group('bits')) if m.group('bits') else None 247 else: 248 self.value = val 249 self._bit_size = None 250 251 if isinstance(self.value, bool): 252 assert self._bit_size is None or self._bit_size == 1 253 self._bit_size = 1 254 255 def hex(self): 256 if isinstance(self.value, (bool)): 257 return 'NIR_TRUE' if self.value else 'NIR_FALSE' 258 if isinstance(self.value, int): 259 # Explicitly sign-extend negative integers to 64-bit, ensuring correct 260 # handling of -INT32_MIN which is not representable in 32-bit. 261 if self.value < 0: 262 return hex(struct.unpack('Q', struct.pack('q', self.value))[0]) + 'ull' 263 else: 264 return hex(self.value) + 'ull' 265 elif isinstance(self.value, float): 266 return hex(struct.unpack('Q', struct.pack('d', self.value))[0]) + 'ull' 267 else: 268 assert False 269 270 def type(self): 271 if isinstance(self.value, (bool)): 272 return "nir_type_bool" 273 elif isinstance(self.value, int): 274 return "nir_type_int" 275 elif isinstance(self.value, float): 276 return "nir_type_float" 277 278 def equivalent(self, other): 279 """Check that two constants are equivalent. 280 281 This is check is much weaker than equality. One generally cannot be 282 used in place of the other. Using this implementation for the __eq__ 283 will break BitSizeValidator. 284 285 """ 286 if not isinstance(other, type(self)): 287 return False 288 289 return self.value == other.value 290 291# The $ at the end forces there to be an error if any part of the string 292# doesn't match one of the field patterns. 293_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)" 294 r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?" 295 r"(?P<cond>\([^\)]+\))?" 296 r"(?P<swiz>\.[xyzwabcdefghijklmnop]+)?" 297 r"$") 298 299class Variable(Value): 300 def __init__(self, val, name, varset, algebraic_pass): 301 Value.__init__(self, val, name, "variable") 302 303 m = _var_name_re.match(val) 304 assert m and m.group('name') is not None, \ 305 "Malformed variable name \"{}\".".format(val) 306 307 self.var_name = m.group('name') 308 309 # Prevent common cases where someone puts quotes around a literal 310 # constant. If we want to support names that have numeric or 311 # punctuation characters, we can me the first assertion more flexible. 312 assert self.var_name.isalpha() 313 assert self.var_name != 'True' 314 assert self.var_name != 'False' 315 316 self.is_constant = m.group('const') is not None 317 self.cond_index = get_cond_index(algebraic_pass.variable_cond, m.group('cond')) 318 self.required_type = m.group('type') 319 self._bit_size = int(m.group('bits')) if m.group('bits') else None 320 self.swiz = m.group('swiz') 321 322 if self.required_type == 'bool': 323 if self._bit_size is not None: 324 assert self._bit_size in type_sizes(self.required_type) 325 else: 326 self._bit_size = 1 327 328 if self.required_type is not None: 329 assert self.required_type in ('float', 'bool', 'int', 'uint') 330 331 self.index = varset[self.var_name] 332 333 def type(self): 334 if self.required_type == 'bool': 335 return "nir_type_bool" 336 elif self.required_type in ('int', 'uint'): 337 return "nir_type_int" 338 elif self.required_type == 'float': 339 return "nir_type_float" 340 341 def equivalent(self, other): 342 """Check that two variables are equivalent. 343 344 This is check is much weaker than equality. One generally cannot be 345 used in place of the other. Using this implementation for the __eq__ 346 will break BitSizeValidator. 347 348 """ 349 if not isinstance(other, type(self)): 350 return False 351 352 return self.index == other.index 353 354 def swizzle(self): 355 if self.swiz is not None: 356 swizzles = {'x' : 0, 'y' : 1, 'z' : 2, 'w' : 3, 357 'a' : 0, 'b' : 1, 'c' : 2, 'd' : 3, 358 'e' : 4, 'f' : 5, 'g' : 6, 'h' : 7, 359 'i' : 8, 'j' : 9, 'k' : 10, 'l' : 11, 360 'm' : 12, 'n' : 13, 'o' : 14, 'p' : 15 } 361 return '{' + ', '.join([str(swizzles[c]) for c in self.swiz[1:]]) + '}' 362 return '{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}' 363 364_opcode_re = re.compile(r"(?P<inexact>~)?(?P<exact>!)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?" 365 r"(?P<cond>\([^\)]+\))?(?P<swizzle_y>\.y)?") 366 367class Expression(Value): 368 def __init__(self, expr, name_base, varset, algebraic_pass): 369 Value.__init__(self, expr, name_base, "expression") 370 371 expr = SearchExpression.create(expr) 372 373 m = _opcode_re.match(expr.opcode) 374 assert m and m.group('opcode') is not None 375 376 self.opcode = m.group('opcode') 377 self._bit_size = int(m.group('bits')) if m.group('bits') else None 378 self.inexact = m.group('inexact') is not None 379 self.exact = m.group('exact') is not None 380 self.ignore_exact = expr.ignore_exact 381 self.cond = m.group('cond') 382 383 assert not self.inexact or not self.exact, \ 384 'Expression cannot be both exact and inexact.' 385 386 # "many-comm-expr" isn't really a condition. It's notification to the 387 # generator that this pattern is known to have too many commutative 388 # expressions, and an error should not be generated for this case. 389 # nsz, nnan and ninf are special conditions, so we treat them specially too. 390 cond = {k: True for k in self.cond[1:-1].split(",")} if self.cond else {} 391 self.many_commutative_expressions = cond.pop('many-comm-expr', False) 392 self.nsz = cond.pop('nsz', False) 393 self.nnan = cond.pop('nnan', False) 394 self.ninf = cond.pop('ninf', False) 395 self.swizzle_y = m.group('swizzle_y') is not None 396 397 assert len(cond) <= 1 398 self.cond = cond.popitem()[0] if cond else None 399 400 # Deduplicate references to the condition functions for the expressions 401 # and save the index for the order they were added. 402 self.cond_index = get_cond_index(algebraic_pass.expression_cond, self.cond) 403 404 self.sources = [ Value.create(src, "{0}_{1}".format(name_base, i), varset, algebraic_pass) 405 for (i, src) in enumerate(expr.sources) ] 406 407 # nir_search_expression::srcs is hard-coded to 4 408 assert len(self.sources) <= 4 409 410 if self.opcode in conv_opcode_types: 411 assert self._bit_size is None, \ 412 'Expression cannot use an unsized conversion opcode with ' \ 413 'an explicit size; that\'s silly.' 414 415 self.__index_comm_exprs(0) 416 417 def equivalent(self, other): 418 """Check that two variables are equivalent. 419 420 This is check is much weaker than equality. One generally cannot be 421 used in place of the other. Using this implementation for the __eq__ 422 will break BitSizeValidator. 423 424 This implementation does not check for equivalence due to commutativity, 425 but it could. 426 427 """ 428 if not isinstance(other, type(self)): 429 return False 430 431 if len(self.sources) != len(other.sources): 432 return False 433 434 if self.opcode != other.opcode: 435 return False 436 437 return all(s.equivalent(o) for s, o in zip(self.sources, other.sources)) 438 439 def __index_comm_exprs(self, base_idx): 440 """Recursively count and index commutative expressions 441 """ 442 self.comm_exprs = 0 443 444 # A note about the explicit "len(self.sources)" check. The list of 445 # sources comes from user input, and that input might be bad. Check 446 # that the expected second source exists before accessing it. Without 447 # this check, a unit test that does "('iadd', 'a')" will crash. 448 if self.opcode not in conv_opcode_types and \ 449 "2src_commutative" in opcodes[self.opcode].algebraic_properties and \ 450 len(self.sources) >= 2 and \ 451 not self.sources[0].equivalent(self.sources[1]): 452 self.comm_expr_idx = base_idx 453 self.comm_exprs += 1 454 else: 455 self.comm_expr_idx = -1 456 457 for s in self.sources: 458 if isinstance(s, Expression): 459 s.__index_comm_exprs(base_idx + self.comm_exprs) 460 self.comm_exprs += s.comm_exprs 461 462 return self.comm_exprs 463 464 def c_opcode(self): 465 return get_c_opcode(self.opcode) 466 467 def render(self, cache): 468 srcs = "".join(src.render(cache) for src in self.sources) 469 return srcs + super(Expression, self).render(cache) 470 471class BitSizeValidator(object): 472 """A class for validating bit sizes of expressions. 473 474 NIR supports multiple bit-sizes on expressions in order to handle things 475 such as fp64. The source and destination of every ALU operation is 476 assigned a type and that type may or may not specify a bit size. Sources 477 and destinations whose type does not specify a bit size are considered 478 "unsized" and automatically take on the bit size of the corresponding 479 register or SSA value. NIR has two simple rules for bit sizes that are 480 validated by nir_validator: 481 482 1) A given SSA def or register has a single bit size that is respected by 483 everything that reads from it or writes to it. 484 485 2) The bit sizes of all unsized inputs/outputs on any given ALU 486 instruction must match. They need not match the sized inputs or 487 outputs but they must match each other. 488 489 In order to keep nir_algebraic relatively simple and easy-to-use, 490 nir_search supports a type of bit-size inference based on the two rules 491 above. This is similar to type inference in many common programming 492 languages. If, for instance, you are constructing an add operation and you 493 know the second source is 16-bit, then you know that the other source and 494 the destination must also be 16-bit. There are, however, cases where this 495 inference can be ambiguous or contradictory. Consider, for instance, the 496 following transformation: 497 498 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 499 500 This transformation can potentially cause a problem because usub_borrow is 501 well-defined for any bit-size of integer. However, b2i always generates a 502 32-bit result so it could end up replacing a 64-bit expression with one 503 that takes two 64-bit values and produces a 32-bit value. As another 504 example, consider this expression: 505 506 (('bcsel', a, b, 0), ('iand', a, b)) 507 508 In this case, in the search expression a must be 32-bit but b can 509 potentially have any bit size. If we had a 64-bit b value, we would end up 510 trying to and a 32-bit value with a 64-bit value which would be invalid 511 512 This class solves that problem by providing a validation layer that proves 513 that a given search-and-replace operation is 100% well-defined before we 514 generate any code. This ensures that bugs are caught at compile time 515 rather than at run time. 516 517 Each value maintains a "bit-size class", which is either an actual bit size 518 or an equivalence class with other values that must have the same bit size. 519 The validator works by combining bit-size classes with each other according 520 to the NIR rules outlined above, checking that there are no inconsistencies. 521 When doing this for the replacement expression, we make sure to never change 522 the equivalence class of any of the search values. We could make the example 523 transforms above work by doing some extra run-time checking of the search 524 expression, but we make the user specify those constraints themselves, to 525 avoid any surprises. Since the replacement bitsizes can only be connected to 526 the source bitsize via variables (variables must have the same bitsize in 527 the source and replacment expressions) or the roots of the expression (the 528 replacement expression must produce the same bit size as the search 529 expression), we prevent merging a variable with anything when processing the 530 replacement expression, or specializing the search bitsize 531 with anything. The former prevents 532 533 (('bcsel', a, b, 0), ('iand', a, b)) 534 535 from being allowed, since we'd have to merge the bitsizes for a and b due to 536 the 'iand', while the latter prevents 537 538 (('usub_borrow', a, b), ('b2i@32', ('ult', a, b))) 539 540 from being allowed, since the search expression has the bit size of a and b, 541 which can't be specialized to 32 which is the bitsize of the replace 542 expression. It also prevents something like: 543 544 (('b2i', ('i2b', a)), ('ineq', a, 0)) 545 546 since the bitsize of 'b2i', which can be anything, can't be specialized to 547 the bitsize of a. 548 549 After doing all this, we check that every subexpression of the replacement 550 was assigned a constant bitsize, the bitsize of a variable, or the bitsize 551 of the search expresssion, since those are the things that are known when 552 constructing the replacement expresssion. Finally, we record the bitsize 553 needed in nir_search_value so that we know what to do when building the 554 replacement expression. 555 """ 556 557 def __init__(self, varset): 558 self._var_classes = [None] * len(varset.names) 559 560 def compare_bitsizes(self, a, b): 561 """Determines which bitsize class is a specialization of the other, or 562 whether neither is. When we merge two different bitsizes, the 563 less-specialized bitsize always points to the more-specialized one, so 564 that calling get_bit_size() always gets you the most specialized bitsize. 565 The specialization partial order is given by: 566 - Physical bitsizes are always the most specialized, and a different 567 bitsize can never specialize another. 568 - In the search expression, variables can always be specialized to each 569 other and to physical bitsizes. In the replace expression, we disallow 570 this to avoid adding extra constraints to the search expression that 571 the user didn't specify. 572 - Expressions and constants without a bitsize can always be specialized to 573 each other and variables, but not the other way around. 574 575 We return -1 if a <= b (b can be specialized to a), 0 if a = b, 1 if a >= b, 576 and None if they are not comparable (neither a <= b nor b <= a). 577 """ 578 if isinstance(a, int): 579 if isinstance(b, int): 580 return 0 if a == b else None 581 elif isinstance(b, Variable): 582 return -1 if self.is_search else None 583 else: 584 return -1 585 elif isinstance(a, Variable): 586 if isinstance(b, int): 587 return 1 if self.is_search else None 588 elif isinstance(b, Variable): 589 return 0 if self.is_search or a.index == b.index else None 590 else: 591 return -1 592 else: 593 if isinstance(b, int): 594 return 1 595 elif isinstance(b, Variable): 596 return 1 597 else: 598 return 0 599 600 def unify_bit_size(self, a, b, error_msg): 601 """Record that a must have the same bit-size as b. If both 602 have been assigned conflicting physical bit-sizes, call "error_msg" with 603 the bit-sizes of self and other to get a message and raise an error. 604 In the replace expression, disallow merging variables with other 605 variables and physical bit-sizes as well. 606 """ 607 a_bit_size = a.get_bit_size() 608 b_bit_size = b if isinstance(b, int) else b.get_bit_size() 609 610 cmp_result = self.compare_bitsizes(a_bit_size, b_bit_size) 611 612 assert cmp_result is not None, \ 613 error_msg(a_bit_size, b_bit_size) 614 615 if cmp_result < 0: 616 b_bit_size.set_bit_size(a) 617 elif not isinstance(a_bit_size, int): 618 a_bit_size.set_bit_size(b) 619 620 def merge_variables(self, val): 621 """Perform the first part of type inference by merging all the different 622 uses of the same variable. We always do this as if we're in the search 623 expression, even if we're actually not, since otherwise we'd get errors 624 if the search expression specified some constraint but the replace 625 expression didn't, because we'd be merging a variable and a constant. 626 """ 627 if isinstance(val, Variable): 628 if self._var_classes[val.index] is None: 629 self._var_classes[val.index] = val 630 else: 631 other = self._var_classes[val.index] 632 self.unify_bit_size(other, val, 633 lambda other_bit_size, bit_size: 634 'Variable {} has conflicting bit size requirements: ' \ 635 'it must have bit size {} and {}'.format( 636 val.var_name, other_bit_size, bit_size)) 637 elif isinstance(val, Expression): 638 for src in val.sources: 639 self.merge_variables(src) 640 641 def validate_value(self, val): 642 """Validate the an expression by performing classic Hindley-Milner 643 type inference on bitsizes. This will detect if there are any conflicting 644 requirements, and unify variables so that we know which variables must 645 have the same bitsize. If we're operating on the replace expression, we 646 will refuse to merge different variables together or merge a variable 647 with a constant, in order to prevent surprises due to rules unexpectedly 648 not matching at runtime. 649 """ 650 if not isinstance(val, Expression): 651 return 652 653 # Generic conversion ops are special in that they have a single unsized 654 # source and an unsized destination and the two don't have to match. 655 # This means there's no validation or unioning to do here besides the 656 # len(val.sources) check. 657 if val.opcode in conv_opcode_types: 658 assert len(val.sources) == 1, \ 659 "Expression {} has {} sources, expected 1".format( 660 val, len(val.sources)) 661 self.validate_value(val.sources[0]) 662 return 663 664 nir_op = opcodes[val.opcode] 665 assert len(val.sources) == nir_op.num_inputs, \ 666 "Expression {} has {} sources, expected {}".format( 667 val, len(val.sources), nir_op.num_inputs) 668 669 for src in val.sources: 670 self.validate_value(src) 671 672 dst_type_bits = type_bits(nir_op.output_type) 673 674 # First, unify all the sources. That way, an error coming up because two 675 # sources have an incompatible bit-size won't produce an error message 676 # involving the destination. 677 first_unsized_src = None 678 for src_type, src in zip(nir_op.input_types, val.sources): 679 src_type_bits = type_bits(src_type) 680 if src_type_bits == 0: 681 if first_unsized_src is None: 682 first_unsized_src = src 683 continue 684 685 if self.is_search: 686 self.unify_bit_size(first_unsized_src, src, 687 lambda first_unsized_src_bit_size, src_bit_size: 688 'Source {} of {} must have bit size {}, while source {} ' \ 689 'must have incompatible bit size {}'.format( 690 first_unsized_src, val, first_unsized_src_bit_size, 691 src, src_bit_size)) 692 else: 693 self.unify_bit_size(first_unsized_src, src, 694 lambda first_unsized_src_bit_size, src_bit_size: 695 'Sources {} (bit size of {}) and {} (bit size of {}) ' \ 696 'of {} may not have the same bit size when building the ' \ 697 'replacement expression.'.format( 698 first_unsized_src, first_unsized_src_bit_size, src, 699 src_bit_size, val)) 700 else: 701 if self.is_search: 702 self.unify_bit_size(src, src_type_bits, 703 lambda src_bit_size, unused: 704 '{} must have {} bits, but as a source of nir_op_{} '\ 705 'it must have {} bits'.format( 706 src, src_bit_size, nir_op.name, src_type_bits)) 707 else: 708 self.unify_bit_size(src, src_type_bits, 709 lambda src_bit_size, unused: 710 '{} has the bit size of {}, but as a source of ' \ 711 'nir_op_{} it must have {} bits, which may not be the ' \ 712 'same'.format( 713 src, src_bit_size, nir_op.name, src_type_bits)) 714 715 if dst_type_bits == 0: 716 if first_unsized_src is not None: 717 if self.is_search: 718 self.unify_bit_size(val, first_unsized_src, 719 lambda val_bit_size, src_bit_size: 720 '{} must have the bit size of {}, while its source {} ' \ 721 'must have incompatible bit size {}'.format( 722 val, val_bit_size, first_unsized_src, src_bit_size)) 723 else: 724 self.unify_bit_size(val, first_unsized_src, 725 lambda val_bit_size, src_bit_size: 726 '{} must have {} bits, but its source {} ' \ 727 '(bit size of {}) may not have that bit size ' \ 728 'when building the replacement.'.format( 729 val, val_bit_size, first_unsized_src, src_bit_size)) 730 else: 731 self.unify_bit_size(val, dst_type_bits, 732 lambda dst_bit_size, unused: 733 '{} must have {} bits, but as a destination of nir_op_{} ' \ 734 'it must have {} bits'.format( 735 val, dst_bit_size, nir_op.name, dst_type_bits)) 736 737 def validate_replace(self, val, search): 738 bit_size = val.get_bit_size() 739 assert isinstance(bit_size, int) or isinstance(bit_size, Variable) or \ 740 bit_size == search.get_bit_size(), \ 741 'Ambiguous bit size for replacement value {}: ' \ 742 'it cannot be deduced from a variable, a fixed bit size ' \ 743 'somewhere, or the search expression.'.format(val) 744 745 if isinstance(val, Expression): 746 for src in val.sources: 747 self.validate_replace(src, search) 748 elif isinstance(val, Variable): 749 # These catch problems when someone copies and pastes the search 750 # into the replacement. 751 assert not val.is_constant, \ 752 'Replacement variables must not be marked constant.' 753 754 assert val.cond_index == -1, \ 755 'Replacement variables must not have a condition.' 756 757 assert not val.required_type, \ 758 'Replacement variables must not have a required type.' 759 760 def validate(self, search, replace): 761 self.is_search = True 762 self.merge_variables(search) 763 self.merge_variables(replace) 764 self.validate_value(search) 765 766 self.is_search = False 767 self.validate_value(replace) 768 769 # Check that search is always more specialized than replace. Note that 770 # we're doing this in replace mode, disallowing merging variables. 771 search_bit_size = search.get_bit_size() 772 replace_bit_size = replace.get_bit_size() 773 cmp_result = self.compare_bitsizes(search_bit_size, replace_bit_size) 774 775 assert cmp_result is not None and cmp_result <= 0, \ 776 'The search expression bit size {} and replace expression ' \ 777 'bit size {} may not be the same'.format( 778 search_bit_size, replace_bit_size) 779 780 replace.set_bit_size(search) 781 782 self.validate_replace(replace, search) 783 784_optimization_ids = itertools.count() 785 786condition_list = ['true'] 787 788class SearchAndReplace(object): 789 def __init__(self, transform, algebraic_pass): 790 self.id = next(_optimization_ids) 791 792 search = transform[0] 793 replace = transform[1] 794 if len(transform) > 2: 795 self.condition = transform[2] 796 else: 797 self.condition = 'true' 798 799 if self.condition not in condition_list: 800 condition_list.append(self.condition) 801 self.condition_index = condition_list.index(self.condition) 802 803 varset = VarSet() 804 if isinstance(search, Expression): 805 self.search = search 806 else: 807 self.search = Expression(search, "search{0}".format(self.id), varset, algebraic_pass) 808 809 varset.lock() 810 811 if isinstance(replace, Value): 812 self.replace = replace 813 else: 814 self.replace = Value.create(replace, "replace{0}".format(self.id), varset, algebraic_pass) 815 816 BitSizeValidator(varset).validate(self.search, self.replace) 817 818class TreeAutomaton(object): 819 """This class calculates a bottom-up tree automaton to quickly search for 820 the left-hand sides of tranforms. Tree automatons are a generalization of 821 classical NFA's and DFA's, where the transition function determines the 822 state of the parent node based on the state of its children. We construct a 823 deterministic automaton to match patterns, using a similar algorithm to the 824 classical NFA to DFA construction. At the moment, it only matches opcodes 825 and constants (without checking the actual value), leaving more detailed 826 checking to the search function which actually checks the leaves. The 827 automaton acts as a quick filter for the search function, requiring only n 828 + 1 table lookups for each n-source operation. The implementation is based 829 on the theory described in "Tree Automatons: Two Taxonomies and a Toolkit." 830 In the language of that reference, this is a frontier-to-root deterministic 831 automaton using only symbol filtering. The filtering is crucial to reduce 832 both the time taken to generate the tables and the size of the tables. 833 """ 834 def __init__(self, transforms): 835 self.patterns = [t.search for t in transforms] 836 self._compute_items() 837 self._build_table() 838 #print('num items: {}'.format(len(set(self.items.values())))) 839 #print('num states: {}'.format(len(self.states))) 840 #for state, patterns in zip(self.states, self.patterns): 841 # print('{}: num patterns: {}'.format(state, len(patterns))) 842 843 class IndexMap(object): 844 """An indexed list of objects, where one can either lookup an object by 845 index or find the index associated to an object quickly using a hash 846 table. Compared to a list, it has a constant time index(). Compared to a 847 set, it provides a stable iteration order. 848 """ 849 def __init__(self, iterable=()): 850 self.objects = [] 851 self.map = {} 852 for obj in iterable: 853 self.add(obj) 854 855 def __getitem__(self, i): 856 return self.objects[i] 857 858 def __contains__(self, obj): 859 return obj in self.map 860 861 def __len__(self): 862 return len(self.objects) 863 864 def __iter__(self): 865 return iter(self.objects) 866 867 def clear(self): 868 self.objects = [] 869 self.map.clear() 870 871 def index(self, obj): 872 return self.map[obj] 873 874 def add(self, obj): 875 if obj in self.map: 876 return self.map[obj] 877 else: 878 index = len(self.objects) 879 self.objects.append(obj) 880 self.map[obj] = index 881 return index 882 883 def __repr__(self): 884 return 'IndexMap([' + ', '.join(repr(e) for e in self.objects) + '])' 885 886 class Item(object): 887 """This represents an "item" in the language of "Tree Automatons." This 888 is just a subtree of some pattern, which represents a potential partial 889 match at runtime. We deduplicate them, so that identical subtrees of 890 different patterns share the same object, and store some extra 891 information needed for the main algorithm as well. 892 """ 893 def __init__(self, opcode, children): 894 self.opcode = opcode 895 self.children = children 896 # These are the indices of patterns for which this item is the root node. 897 self.patterns = [] 898 # This the set of opcodes for parents of this item. Used to speed up 899 # filtering. 900 self.parent_ops = set() 901 902 def __str__(self): 903 return '(' + ', '.join([self.opcode] + [str(c) for c in self.children]) + ')' 904 905 def __repr__(self): 906 return str(self) 907 908 def _compute_items(self): 909 """Build a set of all possible items, deduplicating them.""" 910 # This is a map from (opcode, sources) to item. 911 self.items = {} 912 913 # The set of all opcodes used by the patterns. Used later to avoid 914 # building and emitting all the tables for opcodes that aren't used. 915 self.opcodes = self.IndexMap() 916 917 def get_item(opcode, children, pattern=None): 918 commutative = len(children) >= 2 \ 919 and "2src_commutative" in opcodes[opcode].algebraic_properties 920 item = self.items.setdefault((opcode, children), 921 self.Item(opcode, children)) 922 if commutative: 923 self.items[opcode, (children[1], children[0]) + children[2:]] = item 924 if pattern is not None: 925 item.patterns.append(pattern) 926 return item 927 928 self.wildcard = get_item("__wildcard", ()) 929 self.const = get_item("__const", ()) 930 931 def process_subpattern(src, pattern=None): 932 if isinstance(src, Constant): 933 # Note: we throw away the actual constant value! 934 return self.const 935 elif isinstance(src, Variable): 936 if src.is_constant: 937 return self.const 938 else: 939 # Note: we throw away which variable it is here! This special 940 # item is equivalent to nu in "Tree Automatons." 941 return self.wildcard 942 else: 943 assert isinstance(src, Expression) 944 opcode = src.opcode 945 stripped = opcode.rstrip('0123456789') 946 if stripped in conv_opcode_types: 947 # Matches that use conversion opcodes with a specific type, 948 # like f2i1, are tricky. Either we construct the automaton to 949 # match specific NIR opcodes like nir_op_f2i1, in which case we 950 # need to create separate items for each possible NIR opcode 951 # for patterns that have a generic opcode like f2i, or we 952 # construct it to match the search opcode, in which case we 953 # need to map f2i1 to f2i when constructing the automaton. Here 954 # we do the latter. 955 opcode = stripped 956 self.opcodes.add(opcode) 957 children = tuple(process_subpattern(c) for c in src.sources) 958 item = get_item(opcode, children, pattern) 959 for i, child in enumerate(children): 960 child.parent_ops.add(opcode) 961 return item 962 963 for i, pattern in enumerate(self.patterns): 964 process_subpattern(pattern, i) 965 966 def _build_table(self): 967 """This is the core algorithm which builds up the transition table. It 968 is based off of Algorithm 5.7.38 "Reachability-based tabulation of Cl . 969 Comp_a and Filt_{a,i} using integers to identify match sets." It 970 simultaneously builds up a list of all possible "match sets" or 971 "states", where each match set represents the set of Item's that match a 972 given instruction, and builds up the transition table between states. 973 """ 974 # Map from opcode + filtered state indices to transitioned state. 975 self.table = defaultdict(dict) 976 # Bijection from state to index. q in the original algorithm is 977 # len(self.states) 978 self.states = self.IndexMap() 979 # Lists of pattern matches separated by None 980 self.state_patterns = [None] 981 # Offset in the ->transforms table for each state index 982 self.state_pattern_offsets = [] 983 # Map from state index to filtered state index for each opcode. 984 self.filter = defaultdict(list) 985 # Bijections from filtered state to filtered state index for each 986 # opcode, called the "representor sets" in the original algorithm. 987 # q_{a,j} in the original algorithm is len(self.rep[op]). 988 self.rep = defaultdict(self.IndexMap) 989 990 # Everything in self.states with a index at least worklist_index is part 991 # of the worklist of newly created states. There is also a worklist of 992 # newly fitered states for each opcode, for which worklist_indices 993 # serves a similar purpose. worklist_index corresponds to p in the 994 # original algorithm, while worklist_indices is p_{a,j} (although since 995 # we only filter by opcode/symbol, it's really just p_a). 996 self.worklist_index = 0 997 worklist_indices = defaultdict(lambda: 0) 998 999 # This is the set of opcodes for which the filtered worklist is non-empty. 1000 # It's used to avoid scanning opcodes for which there is nothing to 1001 # process when building the transition table. It corresponds to new_a in 1002 # the original algorithm. 1003 new_opcodes = self.IndexMap() 1004 1005 # Process states on the global worklist, filtering them for each opcode, 1006 # updating the filter tables, and updating the filtered worklists if any 1007 # new filtered states are found. Similar to ComputeRepresenterSets() in 1008 # the original algorithm, although that only processes a single state. 1009 def process_new_states(): 1010 while self.worklist_index < len(self.states): 1011 state = self.states[self.worklist_index] 1012 # Calculate pattern matches for this state. Each pattern is 1013 # assigned to a unique item, so we don't have to worry about 1014 # deduplicating them here. However, we do have to sort them so 1015 # that they're visited at runtime in the order they're specified 1016 # in the source. 1017 patterns = list(sorted(p for item in state for p in item.patterns)) 1018 1019 if patterns: 1020 # Add our patterns to the global table. 1021 self.state_pattern_offsets.append(len(self.state_patterns)) 1022 self.state_patterns.extend(patterns) 1023 self.state_patterns.append(None) 1024 else: 1025 # Point to the initial sentinel in the global table. 1026 self.state_pattern_offsets.append(0) 1027 1028 # calculate filter table for this state, and update filtered 1029 # worklists. 1030 for op in self.opcodes: 1031 filt = self.filter[op] 1032 rep = self.rep[op] 1033 filtered = frozenset(item for item in state if \ 1034 op in item.parent_ops) 1035 if filtered in rep: 1036 rep_index = rep.index(filtered) 1037 else: 1038 rep_index = rep.add(filtered) 1039 new_opcodes.add(op) 1040 assert len(filt) == self.worklist_index 1041 filt.append(rep_index) 1042 self.worklist_index += 1 1043 1044 # There are two start states: one which can only match as a wildcard, 1045 # and one which can match as a wildcard or constant. These will be the 1046 # states of intrinsics/other instructions and load_const instructions, 1047 # respectively. The indices of these must match the definitions of 1048 # WILDCARD_STATE and CONST_STATE below, so that the runtime C code can 1049 # initialize things correctly. 1050 self.states.add(frozenset((self.wildcard,))) 1051 self.states.add(frozenset((self.const,self.wildcard))) 1052 process_new_states() 1053 1054 while len(new_opcodes) > 0: 1055 for op in new_opcodes: 1056 rep = self.rep[op] 1057 table = self.table[op] 1058 op_worklist_index = worklist_indices[op] 1059 if op in conv_opcode_types: 1060 num_srcs = 1 1061 else: 1062 num_srcs = opcodes[op].num_inputs 1063 1064 # Iterate over all possible source combinations where at least one 1065 # is on the worklist. 1066 for src_indices in itertools.product(range(len(rep)), repeat=num_srcs): 1067 if all(src_idx < op_worklist_index for src_idx in src_indices): 1068 continue 1069 1070 srcs = tuple(rep[src_idx] for src_idx in src_indices) 1071 1072 # Try all possible pairings of source items and add the 1073 # corresponding parent items. This is Comp_a from the paper. 1074 parent = set(self.items[op, item_srcs] for item_srcs in 1075 itertools.product(*srcs) if (op, item_srcs) in self.items) 1076 1077 # We could always start matching something else with a 1078 # wildcard. This is Cl from the paper. 1079 parent.add(self.wildcard) 1080 1081 table[src_indices] = self.states.add(frozenset(parent)) 1082 worklist_indices[op] = len(rep) 1083 new_opcodes.clear() 1084 process_new_states() 1085 1086_algebraic_pass_template = mako.template.Template(""" 1087#include "nir.h" 1088#include "nir_builder.h" 1089#include "nir_search.h" 1090#include "nir_search_helpers.h" 1091 1092/* What follows is NIR algebraic transform code for the following ${len(xforms)} 1093 * transforms: 1094% for xform in xforms: 1095 * ${xform.search} => ${xform.replace} 1096% endfor 1097 */ 1098 1099<% cache = {"next_index": 0} %> 1100static const nir_search_value_union ${pass_name}_values[] = { 1101% for xform in xforms: 1102 /* ${xform.search} => ${xform.replace} */ 1103${xform.search.render(cache)} 1104${xform.replace.render(cache)} 1105% endfor 1106}; 1107 1108% if expression_cond: 1109static const nir_search_expression_cond ${pass_name}_expression_cond[] = { 1110% for cond in expression_cond: 1111 ${cond[0]}, 1112% endfor 1113}; 1114% endif 1115 1116% if variable_cond: 1117static const nir_search_variable_cond ${pass_name}_variable_cond[] = { 1118% for cond in variable_cond: 1119 ${cond[0]}, 1120% endfor 1121}; 1122% endif 1123 1124static const struct transform ${pass_name}_transforms[] = { 1125% for i in automaton.state_patterns: 1126% if i is not None: 1127 { ${xforms[i].search.array_index}, ${xforms[i].replace.array_index}, ${xforms[i].condition_index} }, 1128% else: 1129 { ~0, ~0, ~0 }, /* Sentinel */ 1130 1131% endif 1132% endfor 1133}; 1134 1135static const struct per_op_table ${pass_name}_pass_op_table[nir_num_search_ops] = { 1136% for op in automaton.opcodes: 1137 [${get_c_opcode(op)}] = { 1138% if all(e == 0 for e in automaton.filter[op]): 1139 .filter = NULL, 1140% else: 1141 .filter = (const uint16_t []) { 1142 % for e in automaton.filter[op]: 1143 ${e}, 1144 % endfor 1145 }, 1146% endif 1147 <% 1148 num_filtered = len(automaton.rep[op]) 1149 %> 1150 .num_filtered_states = ${num_filtered}, 1151 .table = (const uint16_t []) { 1152 <% 1153 num_srcs = len(next(iter(automaton.table[op]))) 1154 %> 1155 % for indices in itertools.product(range(num_filtered), repeat=num_srcs): 1156 ${automaton.table[op][indices]}, 1157 % endfor 1158 }, 1159 }, 1160% endfor 1161}; 1162 1163/* Mapping from state index to offset in transforms (0 being no transforms) */ 1164static const uint16_t ${pass_name}_transform_offsets[] = { 1165% for offset in automaton.state_pattern_offsets: 1166 ${offset}, 1167% endfor 1168}; 1169 1170static const nir_algebraic_table ${pass_name}_table = { 1171 .transforms = ${pass_name}_transforms, 1172 .transform_offsets = ${pass_name}_transform_offsets, 1173 .pass_op_table = ${pass_name}_pass_op_table, 1174 .values = ${pass_name}_values, 1175 .expression_cond = ${ pass_name + "_expression_cond" if expression_cond else "NULL" }, 1176 .variable_cond = ${ pass_name + "_variable_cond" if variable_cond else "NULL" }, 1177}; 1178 1179bool 1180${pass_name}( 1181 nir_shader *shader 1182% for type, name in params: 1183 , ${type} ${name} 1184% endfor 1185) { 1186 bool progress = false; 1187 bool condition_flags[${len(condition_list)}]; 1188 const nir_shader_compiler_options *options = shader->options; 1189 const shader_info *info = &shader->info; 1190 (void) options; 1191 (void) info; 1192 1193 STATIC_ASSERT(${str(cache["next_index"])} == ARRAY_SIZE(${pass_name}_values)); 1194 % for index, condition in enumerate(condition_list): 1195 condition_flags[${index}] = ${condition}; 1196 % endfor 1197 1198 nir_foreach_function_impl(impl, shader) { 1199 progress |= nir_algebraic_impl(impl, condition_flags, &${pass_name}_table); 1200 } 1201 1202 return progress; 1203} 1204""") 1205 1206 1207class AlgebraicPass(object): 1208 # params is a list of `("type", "name")` tuples 1209 def __init__(self, pass_name, transforms, params=[]): 1210 self.xforms = [] 1211 self.opcode_xforms = defaultdict(lambda : []) 1212 self.pass_name = pass_name 1213 self.expression_cond = {} 1214 self.variable_cond = {} 1215 self.params = params 1216 1217 error = False 1218 1219 for xform in transforms: 1220 if not isinstance(xform, SearchAndReplace): 1221 try: 1222 xform = SearchAndReplace(xform, self) 1223 except: 1224 print("Failed to parse transformation:", file=sys.stderr) 1225 print(" " + str(xform), file=sys.stderr) 1226 traceback.print_exc(file=sys.stderr) 1227 print('', file=sys.stderr) 1228 error = True 1229 continue 1230 1231 self.xforms.append(xform) 1232 if xform.search.opcode in conv_opcode_types: 1233 dst_type = conv_opcode_types[xform.search.opcode] 1234 for size in type_sizes(dst_type): 1235 sized_opcode = xform.search.opcode + str(size) 1236 self.opcode_xforms[sized_opcode].append(xform) 1237 else: 1238 self.opcode_xforms[xform.search.opcode].append(xform) 1239 1240 # Check to make sure the search pattern does not unexpectedly contain 1241 # more commutative expressions than match_expression (nir_search.c) 1242 # can handle. 1243 comm_exprs = xform.search.comm_exprs 1244 1245 if xform.search.many_commutative_expressions: 1246 if comm_exprs <= nir_search_max_comm_ops: 1247 print("Transform expected to have too many commutative " \ 1248 "expression but did not " \ 1249 "({} <= {}).".format(comm_exprs, nir_search_max_comm_op), 1250 file=sys.stderr) 1251 print(" " + str(xform), file=sys.stderr) 1252 traceback.print_exc(file=sys.stderr) 1253 print('', file=sys.stderr) 1254 error = True 1255 else: 1256 if comm_exprs > nir_search_max_comm_ops: 1257 print("Transformation with too many commutative expressions " \ 1258 "({} > {}). Modify pattern or annotate with " \ 1259 "\"many-comm-expr\".".format(comm_exprs, 1260 nir_search_max_comm_ops), 1261 file=sys.stderr) 1262 print(" " + str(xform.search), file=sys.stderr) 1263 print("{}".format(xform.search.cond), file=sys.stderr) 1264 error = True 1265 1266 self.automaton = TreeAutomaton(self.xforms) 1267 1268 if error: 1269 sys.exit(1) 1270 1271 1272 def render(self): 1273 return _algebraic_pass_template.render(pass_name=self.pass_name, 1274 xforms=self.xforms, 1275 opcode_xforms=self.opcode_xforms, 1276 condition_list=condition_list, 1277 automaton=self.automaton, 1278 expression_cond = sorted(self.expression_cond.items(), key=lambda kv: kv[1]), 1279 variable_cond = sorted(self.variable_cond.items(), key=lambda kv: kv[1]), 1280 get_c_opcode=get_c_opcode, 1281 itertools=itertools, 1282 params=self.params) 1283 1284# The replacement expression isn't necessarily exact if the search expression is exact. 1285def ignore_exact(*expr): 1286 expr = SearchExpression.create(expr) 1287 expr.ignore_exact = True 1288 return expr 1289