#
# Copyright (C) 2014 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.
#
# Authors:
#    Jason Ekstrand (jason@jlekstrand.net)

from __future__ import print_function
import ast
import itertools
import struct
import sys
import mako.template
import re
import traceback

from nir_opcodes import opcodes

# Python 2/3 compatibility helpers.  ``long`` and ``unicode`` only exist on
# Python 2, so they are referenced solely inside the version-guarded branch.
if sys.version_info < (3, 0):
    _integer_types = (int, long)  # noqa: F821
    _string_types = (str, unicode)  # noqa: F821
else:
    _integer_types = (int,)
    _string_types = (str,)

_type_re = re.compile(r"(?P<type>int|uint|bool|float)?(?P<bits>\d+)?")


def type_bits(type_str):
    """Return the bit size encoded in a type string, or 0 if it is unsized.

    The string must start with one of ``int``, ``uint``, ``bool`` or
    ``float`` (asserted) optionally followed by a decimal bit size.
    """
    m = _type_re.match(type_str)
    assert m.group('type')

    if m.group('bits') is None:
        return 0
    else:
        return int(m.group('bits'))


# Represents a set of variables, each with a unique id
class VarSet(object):
    """Maps variable names to unique, sequentially allocated integer ids."""

    def __init__(self):
        self.names = {}
        self.ids = itertools.count()
        self.immutable = False

    def __getitem__(self, name):
        """Return the id for ``name``, allocating one if the set is unlocked.

        Looking up an unknown name after ``lock()`` trips an assertion.
        """
        if name not in self.names:
            assert not self.immutable, "Unknown replacement variable: " + name
            # next() works on both Python 2.6+ and Python 3, unlike .next().
            self.names[name] = next(self.ids)

        return self.names[name]

    def lock(self):
        """Forbid the creation of any further variables."""
        self.immutable = True


class Value(object):
    """Base class of everything that can appear in a search/replace tree.

    Subclasses are ``Constant``, ``Variable`` and ``Expression``.  Each value
    has a C identifier (``name``) and knows how to render itself as a static
    C structure through the shared Mako template.
    """

    @staticmethod
    def create(val, name_base, varset):
        """Build the Value subclass matching the Python type of ``val``.

        Tuples become Expressions, strings become Variables, numbers and
        booleans become Constants and existing Expressions are returned
        unchanged.

        Raises:
            TypeError: if ``val`` is of none of the supported types.
        """
        if isinstance(val, tuple):
            return Expression(val, name_base, varset)
        elif isinstance(val, Expression):
            return val
        elif isinstance(val, _string_types):
            return Variable(val, name_base, varset)
        elif isinstance(val, (bool, float) + _integer_types):
            return Constant(val, name_base)
        else:
            # Previously this fell through and returned None, which only
            # surfaced much later as an unrelated failure.
            raise TypeError("Cannot create a search value from {0!r}"
                            .format(val))

    __template = mako.template.Template("""
static const ${val.c_type} ${val.name} = {
   { ${val.type_enum}, ${val.bit_size} },
% if isinstance(val, Constant):
   ${val.type()}, { ${val.hex()} /* ${val.value} */ },
% elif isinstance(val, Variable):
   ${val.index}, /* ${val.var_name} */
   ${'true' if val.is_constant else 'false'},
   ${val.type() or 'nir_type_invalid' },
   ${val.cond if val.cond else 'NULL'},
% elif isinstance(val, Expression):
   ${'true' if val.inexact else 'false'},
   nir_op_${val.opcode},
   { ${', '.join(src.c_ptr for src in val.sources)} },
   ${val.cond if val.cond else 'NULL'},
% endif
};""")

    def __init__(self, name, type_str):
        self.name = name
        self.type_str = type_str

    @property
    def type_enum(self):
        """Name of the C enum value identifying this kind of value."""
        return "nir_search_value_" + self.type_str

    @property
    def c_type(self):
        """Name of the C struct type used to store this value."""
        return "nir_search_" + self.type_str

    @property
    def c_ptr(self):
        """C expression yielding a pointer to the embedded base value."""
        return "&{0}.value".format(self.name)

    def render(self):
        """Render this value as a static C definition."""
        return self.__template.render(val=self,
                                      Constant=Constant,
                                      Variable=Variable,
                                      Expression=Expression)


_constant_re = re.compile(r"(?P<value>[^@\(]+)(?:@(?P<bits>\d+))?")


class Constant(Value):
    """A literal constant, optionally carrying an explicit bit size.

    ``val`` is either a Python bool/int/float or a string of the form
    ``value[@bits]`` whose value part is parsed with ``ast.literal_eval``.
    Booleans are always 32-bit.
    """

    def __init__(self, val, name):
        Value.__init__(self, name, "constant")

        if isinstance(val, _string_types):
            m = _constant_re.match(val)
            self.value = ast.literal_eval(m.group('value'))
            self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        else:
            self.value = val
            self.bit_size = 0

        if isinstance(self.value, bool):
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

    def hex(self):
        """Return the C literal for this constant.

        Booleans map to ``NIR_TRUE``/``NIR_FALSE``, integers to their
        hexadecimal form and floats to the hexadecimal form of the bit
        pattern of their double representation.
        """
        # bool must be tested first because it is a subclass of int.
        if isinstance(self.value, bool):
            return 'NIR_TRUE' if self.value else 'NIR_FALSE'
        if isinstance(self.value, _integer_types):
            return hex(self.value)
        elif isinstance(self.value, float):
            return hex(struct.unpack('Q', struct.pack('d', self.value))[0])
        else:
            assert False

    # Python 2's builtin hex() dispatches to __hex__; Python 3 ignores it,
    # which is why the template calls hex() as a regular method instead.
    __hex__ = hex

    def type(self):
        """Return the nir_alu_type name matching the Python value type."""
        if isinstance(self.value, bool):
            return "nir_type_bool32"
        elif isinstance(self.value, _integer_types):
            return "nir_type_int"
        elif isinstance(self.value, float):
            return "nir_type_float"


_var_name_re = re.compile(r"(?P<const>#)?(?P<name>\w+)"
                          r"(?:@(?P<type>int|uint|bool|float)?(?P<bits>\d+)?)?"
                          r"(?P<cond>\([^\)]+\))?")


class Variable(Value):
    """A named placeholder in a search or replace expression.

    The textual form is ``[#]name[@[type][bits]][(cond)]``: a leading ``#``
    marks the variable as constant-only, ``type`` is one of ``int``,
    ``uint``, ``bool`` or ``float``, ``bits`` is an explicit bit size and
    ``cond`` is copied verbatim into the generated C.
    """

    def __init__(self, val, name, varset):
        Value.__init__(self, name, "variable")

        m = _var_name_re.match(val)
        assert m and m.group('name') is not None

        self.var_name = m.group('name')
        self.is_constant = m.group('const') is not None
        self.cond = m.group('cond')
        self.required_type = m.group('type')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0

        if self.required_type == 'bool':
            assert self.bit_size == 0 or self.bit_size == 32
            self.bit_size = 32

        if self.required_type is not None:
            assert self.required_type in ('float', 'bool', 'int', 'uint')

        self.index = varset[self.var_name]

    def type(self):
        """Return the required nir_alu_type name, or None if unconstrained."""
        if self.required_type == 'bool':
            return "nir_type_bool32"
        elif self.required_type in ('int', 'uint'):
            return "nir_type_int"
        elif self.required_type == 'float':
            return "nir_type_float"


_opcode_re = re.compile(r"(?P<inexact>~)?(?P<opcode>\w+)(?:@(?P<bits>\d+))?"
                        r"(?P<cond>\([^\)]+\))?")


class Expression(Value):
    """An ALU operation applied to a list of source values.

    ``expr`` is a tuple whose first element is ``[~]opcode[@bits][(cond)]``
    (``~`` marks the expression as inexact) and whose remaining elements are
    the sources, each converted with ``Value.create``.
    """

    def __init__(self, expr, name_base, varset):
        Value.__init__(self, name_base, "expression")
        assert isinstance(expr, tuple)

        m = _opcode_re.match(expr[0])
        assert m and m.group('opcode') is not None

        self.opcode = m.group('opcode')
        self.bit_size = int(m.group('bits')) if m.group('bits') else 0
        self.inexact = m.group('inexact') is not None
        self.cond = m.group('cond')
        self.sources = [Value.create(src, "{0}_{1}".format(name_base, i),
                                     varset)
                        for (i, src) in enumerate(expr[1:])]

    def render(self):
        """Render all sources first, then this expression itself."""
        srcs = "\n".join(src.render() for src in self.sources)
        return srcs + super(Expression, self).render()


class IntEquivalenceRelation(object):
    """A class representing an equivalence relation on integers.

    Each integer has a canonical form which is the maximum integer to which it
    is equivalent.  Two integers are equivalent precisely when they have the
    same canonical form.

    The convention of maximum is explicitly chosen to make using it in
    BitSizeValidator easier because it means that an actual bit_size (if any)
    will always be the canonical form.
    """
    def __init__(self):
        # Maps a non-canonical integer to a strictly larger equivalent one.
        self._remap = {}

    def get_canonical(self, x):
        """Get the canonical integer corresponding to x."""
        while x in self._remap:
            x = self._remap[x]
        return x

    def add_equiv(self, a, b):
        """Add an equivalence and return the canonical form."""
        canon_a = self.get_canonical(a)
        canon_b = self.get_canonical(b)
        c = max(canon_a, canon_b)

        # Link the canonical representatives rather than a and b themselves.
        # Re-pointing a non-canonical member would overwrite its existing
        # link and detach the rest of its class from the merged one.
        if canon_a != c:
            self._remap[canon_a] = c

        if canon_b != c:
            self._remap[canon_b] = c

        return c


class BitSizeValidator(object):
    """A class for validating bit sizes of expressions.

    NIR supports multiple bit-sizes on expressions in order to handle things
    such as fp64.  The source and destination of every ALU operation is
    assigned a type and that type may or may not specify a bit size.  Sources
    and destinations whose type does not specify a bit size are considered
    "unsized" and automatically take on the bit size of the corresponding
    register or SSA value.  NIR has two simple rules for bit sizes that are
    validated by nir_validator:

     1) A given SSA def or register has a single bit size that is respected by
        everything that reads from it or writes to it.

     2) The bit sizes of all unsized inputs/outputs on any given ALU
        instruction must match.  They need not match the sized inputs or
        outputs but they must match each other.

    In order to keep nir_algebraic relatively simple and easy-to-use,
    nir_search supports a type of bit-size inference based on the two rules
    above.  This is similar to type inference in many common programming
    languages.  If, for instance, you are constructing an add operation and you
    know the second source is 16-bit, then you know that the other source and
    the destination must also be 16-bit.  There are, however, cases where this
    inference can be ambiguous or contradictory.  Consider, for instance, the
    following transformation:

    (('usub_borrow', a, b), ('b2i', ('ult', a, b)))

    This transformation can potentially cause a problem because usub_borrow is
    well-defined for any bit-size of integer.  However, b2i always generates a
    32-bit result so it could end up replacing a 64-bit expression with one
    that takes two 64-bit values and produces a 32-bit value.  As another
    example, consider this expression:

    (('bcsel', a, b, 0), ('iand', a, b))

    In this case, in the search expression a must be 32-bit but b can
    potentially have any bit size.  If we had a 64-bit b value, we would end up
    trying to and a 32-bit value with a 64-bit value which would be invalid.

    This class solves that problem by providing a validation layer that proves
    that a given search-and-replace operation is 100% well-defined before we
    generate any code.  This ensures that bugs are caught at compile time
    rather than at run time.

    The basic operation of the validator is very similar to the bitsize_tree in
    nir_search only a little more subtle.  Instead of simply tracking bit
    sizes, it tracks "bit classes" where each class is represented by an
    integer.  A value of 0 means we don't know anything yet, positive values
    are actual bit-sizes, and negative values are used to track equivalence
    classes of sizes that must be the same but have yet to receive an actual
    size.  The first stage uses the bitsize_tree algorithm to assign bit
    classes to each variable.  If it ever comes across an inconsistency, it
    assert-fails.  Then the second stage uses that information to prove that
    the resulting expression can always validly be constructed.
    """

    def __init__(self, varset):
        self._num_classes = 0
        # One bit class per variable id; 0 means "nothing known yet".
        self._var_classes = [0] * len(varset.names)
        self._class_relation = IntEquivalenceRelation()

    def validate(self, search, replace):
        """Assert that ``replace`` is well-sized wherever ``search`` matches.

        First assigns bit classes to the variables of ``search``, then checks
        ``replace`` against them.  Any inconsistency trips an assertion.
        """
        dst_class = self._propagate_bit_size_up(search)
        if dst_class == 0:
            dst_class = self._new_class()
        self._propagate_bit_class_down(search, dst_class)

        validate_dst_class = self._validate_bit_class_up(replace)
        assert validate_dst_class == 0 or validate_dst_class == dst_class
        self._validate_bit_class_down(replace, dst_class)

    def _new_class(self):
        """Allocate a fresh (negative) class for a not-yet-known size."""
        self._num_classes += 1
        return -self._num_classes

    def _set_var_bit_class(self, var_id, bit_class):
        """Record that variable ``var_id`` belongs to ``bit_class``."""
        assert bit_class != 0
        var_class = self._var_classes[var_id]
        if var_class == 0:
            self._var_classes[var_id] = bit_class
        else:
            canon_class = self._class_relation.get_canonical(var_class)
            assert canon_class < 0 or canon_class == bit_class
            var_class = self._class_relation.add_equiv(var_class, bit_class)
            self._var_classes[var_id] = var_class

    def _get_var_bit_class(self, var_id):
        """Return the canonical bit class of variable ``var_id``."""
        return self._class_relation.get_canonical(self._var_classes[var_id])

    def _propagate_bit_size_up(self, val):
        """Compute explicit bit sizes bottom-up through the search tree.

        Returns the bit size of ``val`` (0 if unknown) and stores the shared
        size of the unsized sources/destination in ``val.common_size`` for
        expressions.
        """
        if isinstance(val, (Constant, Variable)):
            return val.bit_size

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_size = 0
            for i in range(nir_op.num_inputs):
                src_bits = self._propagate_bit_size_up(val.sources[i])
                if src_bits == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_bits == src_type_bits
                else:
                    assert val.common_size == 0 or \
                        src_bits == val.common_size
                    val.common_size = src_bits

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_size != 0:
                    assert val.bit_size == 0 or \
                        val.bit_size == val.common_size
                else:
                    val.common_size = val.bit_size
                return val.common_size

    def _propagate_bit_class_down(self, val, bit_class):
        """Push ``bit_class`` top-down through the search tree.

        Variables reached on the way get their class recorded; sized sources
        receive their fixed size and unsized ones the expression's common
        class.
        """
        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class
            self._set_var_bit_class(val.index, bit_class)

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == 0 or bit_class == dst_type_bits
            else:
                assert val.common_size == 0 or val.common_size == bit_class
                val.common_size = bit_class

            # common_class stays unbound only when the opcode has no inputs,
            # in which case the loop below never reads it.
            if val.common_size:
                common_class = val.common_size
            elif nir_op.num_inputs:
                # If we got here then we have no idea what the actual size is.
                # Instead, we use a generic class
                common_class = self._new_class()

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._propagate_bit_class_down(val.sources[i],
                                                   src_type_bits)
                else:
                    self._propagate_bit_class_down(val.sources[i],
                                                   common_class)

    def _validate_bit_class_up(self, val):
        """Check the replace tree bottom-up and return the class of ``val``.

        Stores the shared class of the unsized sources/destination in
        ``val.common_class`` for expressions.
        """
        if isinstance(val, Constant):
            return val.bit_size

        elif isinstance(val, Variable):
            var_class = self._get_var_bit_class(val.index)
            # By the time we get to validation, every variable should have a
            # class
            assert var_class != 0

            # If we have an explicit size provided by the user, the variable
            # *must* exactly match the search.  It cannot be implicitly sized
            # because otherwise we could end up with a conflict at runtime.
            assert val.bit_size == 0 or val.bit_size == var_class

            return var_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            val.common_class = 0
            for i in range(nir_op.num_inputs):
                src_class = self._validate_bit_class_up(val.sources[i])
                if src_class == 0:
                    continue

                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    assert src_class == src_type_bits
                else:
                    assert val.common_class == 0 or \
                        src_class == val.common_class
                    val.common_class = src_class

            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert val.bit_size == 0 or val.bit_size == dst_type_bits
                return dst_type_bits
            else:
                if val.common_class != 0:
                    assert val.bit_size == 0 or \
                        val.bit_size == val.common_class
                else:
                    val.common_class = val.bit_size
                return val.common_class

    def _validate_bit_class_down(self, val, bit_class):
        """Check the replace tree top-down against ``bit_class``."""
        # At this point, everything *must* have a bit class.  Otherwise, we
        # have a value we don't know how to define.
        assert bit_class != 0

        if isinstance(val, Constant):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Variable):
            assert val.bit_size == 0 or val.bit_size == bit_class

        elif isinstance(val, Expression):
            nir_op = opcodes[val.opcode]
            dst_type_bits = type_bits(nir_op.output_type)
            if dst_type_bits != 0:
                assert bit_class == dst_type_bits
            else:
                assert val.common_class == 0 or val.common_class == bit_class
                val.common_class = bit_class

            for i in range(nir_op.num_inputs):
                src_type_bits = type_bits(nir_op.input_types[i])
                if src_type_bits != 0:
                    self._validate_bit_class_down(val.sources[i],
                                                  src_type_bits)
                else:
                    self._validate_bit_class_down(val.sources[i],
                                                  val.common_class)


_optimization_ids = itertools.count()

# C condition expressions referenced by the transforms; each transform stores
# the index of its condition in this module-wide list.
condition_list = ['true']


class SearchAndReplace(object):
    """A single ``(search, replace[, condition])`` transformation.

    Variables may only be introduced by the search expression: the variable
    set is locked before the replacement is parsed.  The pair is checked with
    ``BitSizeValidator`` on construction.
    """

    def __init__(self, transform):
        self.id = next(_optimization_ids)

        search = transform[0]
        replace = transform[1]
        if len(transform) > 2:
            self.condition = transform[2]
        else:
            self.condition = 'true'

        if self.condition not in condition_list:
            condition_list.append(self.condition)
        self.condition_index = condition_list.index(self.condition)

        varset = VarSet()
        if isinstance(search, Expression):
            self.search = search
        else:
            self.search = Expression(search, "search{0}".format(self.id),
                                     varset)

        varset.lock()

        if isinstance(replace, Value):
            self.replace = replace
        else:
            self.replace = Value.create(replace,
                                        "replace{0}".format(self.id), varset)

        BitSizeValidator(varset).validate(self.search, self.replace)


_algebraic_pass_template = mako.template.Template("""
#include "nir.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

#ifndef NIR_OPT_ALGEBRAIC_STRUCT_DEFS
#define NIR_OPT_ALGEBRAIC_STRUCT_DEFS

struct transform {
   const nir_search_expression *search;
   const nir_search_value *replace;
   unsigned condition_offset;
};

#endif

% for (opcode, xform_list) in xform_dict.items():
% for xform in xform_list:
   ${xform.search.render()}
   ${xform.replace.render()}
% endfor

static const struct transform ${pass_name}_${opcode}_xforms[] = {
% for xform in xform_list:
   { &${xform.search.name}, ${xform.replace.c_ptr}, ${xform.condition_index} },
% endfor
};
% endfor

static bool
${pass_name}_block(nir_block *block, const bool *condition_flags,
                   void *mem_ctx)
{
   bool progress = false;

   nir_foreach_instr_reverse_safe(instr, block) {
      if (instr->type != nir_instr_type_alu)
         continue;

      nir_alu_instr *alu = nir_instr_as_alu(instr);
      if (!alu->dest.dest.is_ssa)
         continue;

      switch (alu->op) {
      % for opcode in xform_dict.keys():
      case nir_op_${opcode}:
         for (unsigned i = 0; i < ARRAY_SIZE(${pass_name}_${opcode}_xforms); i++) {
            const struct transform *xform = &${pass_name}_${opcode}_xforms[i];
            if (condition_flags[xform->condition_offset] &&
                nir_replace_instr(alu, xform->search, xform->replace,
                                  mem_ctx)) {
               progress = true;
               break;
            }
         }
         break;
      % endfor
      default:
         break;
      }
   }

   return progress;
}

static bool
${pass_name}_impl(nir_function_impl *impl, const bool *condition_flags)
{
   void *mem_ctx = ralloc_parent(impl);
   bool progress = false;

   nir_foreach_block_reverse(block, impl) {
      progress |= ${pass_name}_block(block, condition_flags, mem_ctx);
   }

   if (progress)
      nir_metadata_preserve(impl, nir_metadata_block_index |
                                  nir_metadata_dominance);

   return progress;
}


bool
${pass_name}(nir_shader *shader)
{
   bool progress = false;
   bool condition_flags[${len(condition_list)}];
   const nir_shader_compiler_options *options = shader->options;
   (void) options;

   % for index, condition in enumerate(condition_list):
   condition_flags[${index}] = ${condition};
   % endfor

   nir_foreach_function(function, shader) {
      if (function->impl)
         progress |= ${pass_name}_impl(function->impl, condition_flags);
   }

   return progress;
}
""")


class AlgebraicPass(object):
    """A named collection of transformations rendered as one C pass.

    Transformations are grouped by the opcode of their search expression.
    If any of them fails to parse or validate, all failures are reported on
    stderr and the process exits with status 1.
    """

    def __init__(self, pass_name, transforms):
        self.xform_dict = {}
        self.pass_name = pass_name

        error = False

        for xform in transforms:
            if not isinstance(xform, SearchAndReplace):
                try:
                    xform = SearchAndReplace(xform)
                except Exception:
                    # Catch Exception rather than everything so that
                    # KeyboardInterrupt and SystemExit still propagate.
                    print("Failed to parse transformation:", file=sys.stderr)
                    print(" " + str(xform), file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    print('', file=sys.stderr)
                    error = True
                    continue

            self.xform_dict.setdefault(xform.search.opcode, []).append(xform)

        if error:
            sys.exit(1)

    def render(self):
        """Return the C source of the pass as a string."""
        return _algebraic_pass_template.render(pass_name=self.pass_name,
                                               xform_dict=self.xform_dict,
                                               condition_list=condition_list)