1# 2# Copyright (C) 2020 Collabora, Ltd. 3# Copyright (C) 2018 Alyssa Rosenzweig 4# Copyright (C) 2016 Intel Corporation 5# 6# Permission is hereby granted, free of charge, to any person obtaining a 7# copy of this software and associated documentation files (the "Software"), 8# to deal in the Software without restriction, including without limitation 9# the rights to use, copy, modify, merge, publish, distribute, sublicense, 10# and/or sell copies of the Software, and to permit persons to whom the 11# Software is furnished to do so, subject to the following conditions: 12# 13# The above copyright notice and this permission notice (including the next 14# paragraph) shall be included in all copies or substantial portions of the 15# Software. 16# 17# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL 20# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 22# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS 23# IN THE SOFTWARE. 24 25import argparse 26import sys 27import math 28 29a = 'a' 30b = 'b' 31c = 'c' 32 33algebraic_late = [ 34 # ineg must be lowered late, but only for integers; floats will try to 35 # have modifiers attached... hence why this has to be here rather than 36 # a more standard lower_negate approach 37 38 (('ineg', a), ('isub', 0, a)), 39] 40 41for isz in ('8', '16', '32'): 42 for osz in ('16', '32', '64'): 43 algebraic_late += [(('b2f' + osz, 'a@' + isz), ('b' + isz + 'csel', a, 1.0, 0.0))] 44 45# There's no native integer min/max instruction, lower those to cmp+bcsel 46for sz in ('8', '16', '32'): 47 for t in ('i', 'u'): 48 algebraic_late += [ 49 ((t + 'min', 'a@' + sz, 'b@' + sz), ('b' + sz + 'csel', (t + 'lt' + sz, a, b), a, b)), 50 ((t + 'max', 'a@' + sz, 'b@' + sz), ('b' + sz + 'csel', (t + 'lt' + sz, b, a), a, b)) 51 ] 52 53# Midgard is able to type convert down by only one "step" per instruction; if 54# NIR wants more than one step, we need to break up into multiple instructions 55 56converts = [] 57 58for op in ('u2u', 'i2i', 'f2f', 'i2f', 'u2f', 'f2i', 'f2u'): 59 srcsz_max = 64 60 dstsz_max = 64 61 # 8 bit float doesn't exist 62 srcsz_min = 8 if op[0] != 'f' else 16 63 dstsz_min = 8 if op[2] != 'f' else 16 64 dstsz = dstsz_min 65 # Iterate over all possible destination and source sizes 66 while dstsz <= dstsz_max: 67 srcsz = srcsz_min 68 while srcsz <= srcsz_max: 69 # Size converter lowering is only needed if src and dst sizes are 70 # spaced by a factor > 2. 71 # Type converter lowering is needed as soon as src_size != dst_size 72 if srcsz != dstsz and ((srcsz * 2 != dstsz and srcsz != dstsz * 2) or op[0] != op[2]): 73 cursz = srcsz 74 rule = a 75 # When converting down we first do the type conversion followed 76 # by one or more size conversions. When converting up, we do 77 # the type conversion at the end. This way we don't have to 78 # deal with the fact that f2f8 doesn't exists. 79 sizeconvop = op[0] + '2' + op[0] if srcsz < dstsz else op[2] + '2' + op[2] 80 if srcsz > dstsz and op[0] != op[2]: 81 rule = (op + str(int(cursz)), rule) 82 while cursz != dstsz: 83 cursz = cursz / 2 if dstsz < srcsz else cursz * 2 84 rule = (sizeconvop + str(int(cursz)), rule) 85 if srcsz < dstsz and op[0] != op[2]: 86 rule = (op + str(int(cursz)), rule) 87 converts += [((op + str(int(dstsz)), 'a@' + str(int(srcsz))), rule)] 88 srcsz *= 2 89 dstsz *= 2 90 91# Bifrost doesn't have fp16 for a lot of special ops 92SPECIAL = ['fexp2', 'flog2', 'fsin', 'fcos'] 93 94for op in SPECIAL: 95 converts += [((op + '@16', a), ('f2f16', (op, ('f2f32', a))))] 96 97converts += [(('f2b32', a), ('fneu32', a, 0.0)), 98 (('i2b32', a), ('ine32', a, 0)), 99 (('b2i32', a), ('iand', 'a@32', 1))] 100 101def main(): 102 parser = argparse.ArgumentParser() 103 parser.add_argument('-p', '--import-path', required=True) 104 args = parser.parse_args() 105 sys.path.insert(0, args.import_path) 106 run() 107 108 109def run(): 110 import nir_algebraic # pylint: disable=import-error 111 112 print('#include "bifrost_nir.h"') 113 114 print(nir_algebraic.AlgebraicPass("bifrost_nir_lower_algebraic_late", 115 algebraic_late + converts).render()) 116 117if __name__ == '__main__': 118 main() 119