• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/bin/env python
2COPYRIGHT = """\
3/*
4 * Copyright 2021 Intel Corporation
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a
7 * copy of this software and associated documentation files (the
8 * "Software"), to deal in the Software without restriction, including
9 * without limitation the rights to use, copy, modify, merge, publish,
10 * distribute, sub license, and/or sell copies of the Software, and to
11 * permit persons to whom the Software is furnished to do so, subject to
12 * the following conditions:
13 *
14 * The above copyright notice and this permission notice (including the
15 * next paragraph) shall be included in all copies or substantial portions
16 * of the Software.
17 *
18 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
19 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NON-INFRINGEMENT.
21 * IN NO EVENT SHALL VMWARE AND/OR ITS SUPPLIERS BE LIABLE FOR
22 * ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
23 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
24 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 */
26"""
27
28import argparse
29import os.path
30import re
31import sys
32
33from grl_parser import parse_grl_file
34
class Writer(object):
    """Indentation-aware text emitter wrapped around a file-like object.

    Newlines are written lazily: a trailing newline passed to write() is
    held back and emitted at the start of the next write(), so that the
    indentation in effect at that later point is applied to the new line.
    """

    def __init__(self, file):
        self._file = file
        self._indent = 0
        # True while the pending (not yet written) newline is outstanding.
        self._new_line = True

    def push_indent(self, levels=4):
        """Increase the indentation by `levels` spaces."""
        self._indent = self._indent + levels

    def pop_indent(self, levels=4):
        """Decrease the indentation by `levels` spaces."""
        self._indent = self._indent - levels

    def write(self, s, *fmt):
        """Write `s`, formatted with str.format(*fmt) when `fmt` is given.

        Without `fmt` the text is written verbatim, so literal braces do
        not need to be doubled in that case.
        """
        text = ('\n' + s) if self._new_line else s
        # Hold back one trailing newline until the next call.
        self._new_line = text.endswith('\n')
        if self._new_line:
            text = text[:-1]
        if fmt:
            text = text.format(*fmt)
        line_break = '\n' + ' ' * self._indent
        self._file.write(text.replace('\n', line_break))
57
58# Internal Representation
59
class Value(object):
    """Base class for anything that can be used as an operand.

    A value has an optional `name`, a `zone` (the code uses the strings
    'cpu' and 'gpu') and a `live` flag that the dead-code pass toggles.
    """

    def __init__(self, name=None, zone=None):
        self.name = name
        self._zone = zone
        self.live = False

    @property
    def zone(self):
        """The zone of this value; it must have been set."""
        zone = self._zone
        assert zone is not None
        return zone

    def is_reg(self):
        """Plain values are not registers; subclasses override this."""
        return False

    def c_val(self):
        """Return the C expression for this value (its name)."""
        if not self.name:
            # Debug aid: show the offending object before the assert fires.
            print(self)
        assert self.name
        return self.name

    def c_cpu_val(self):
        """Return the C expression, asserting the value lives on the CPU."""
        assert self.zone == 'cpu'
        return self.c_val()

    def c_gpu_val(self):
        """Return a C expression usable as a GPU-side operand.

        CPU values are wrapped in mi_imm(); GPU values are used directly.
        """
        if self.zone != 'gpu':
            return 'mi_imm({})'.format(self.c_cpu_val())
        return self.c_val()
89
class Constant(Value):
    """An integer literal; always in the 'cpu' zone."""

    def __init__(self, value):
        super().__init__(zone='cpu')
        self.value = value

    def c_val(self):
        """Render the literal as C: decimal below 100, otherwise hex with
        a `u` suffix below 2**32 and `ull` at or above it."""
        number = self.value
        if number >= (1 << 32):
            return '0x{:x}ull'.format(number)
        if number >= 100:
            return '0x{:x}u'.format(number)
        return str(number)
102
class Register(Value):
    """A named value in the 'gpu' zone that reports itself as a register."""

    def __init__(self, name):
        super().__init__(name=name, zone='gpu')

    def is_reg(self):
        """Registers are, by definition, registers."""
        return True
109
class FixedGPR(Register):
    """A numbered general purpose register, named REG<num>."""

    def __init__(self, num):
        reg_name = 'REG{}'.format(num)
        super().__init__(reg_name)
        self.num = num

    def write_c(self, w):
        """Emit the C declaration that reserves this GPR in the builder."""
        w.write('UNUSED struct mi_value {} = mi_reserve_gpr(&b, {});\n',
                self.name, self.num)
118
class GroupSizeRegister(Register):
    """One of the DISPATCHDIM_X/Y/Z registers, selected by `comp` (0..2)."""

    def __init__(self, comp):
        axis = 'XYZ'[comp]
        super().__init__('DISPATCHDIM_' + axis)
        self.comp = comp
123
class Member(Value):
    """A `value.member` access; inherits the zone of the accessed value."""

    def __init__(self, value, member):
        super().__init__(zone=value.zone)
        self.value = value
        self.member = member

    def is_reg(self):
        """A member is a register exactly when its parent value is one."""
        return self.value.is_reg()

    def c_val(self):
        """Return the C expression for the member access.

        CPU values use plain `.` field access.  GPU values must be
        registers and only support the 'hi' / 'lo' halves.
        """
        base = self.value.c_val()
        if self.zone != 'gpu':
            return '.'.join([base, self.member])

        assert isinstance(self.value, Register)
        if self.member == 'hi':
            return 'mi_value_half({}, true)'.format(base)
        if self.member == 'lo':
            return 'mi_value_half({}, false)'.format(base)
        assert False, 'Invalid member: {}'.format(self.member)
145
class OffsetOf(Value):
    """An `offsetof(type, field)` expression; always in the 'cpu' zone."""

    def __init__(self, mk, expr):
        super().__init__(zone='cpu')
        # expr is a ('member', <type name>, <field name>) tuple.
        assert isinstance(expr, tuple) and expr[0] == 'member'
        type_name = expr[1]
        self.type = mk.m.get_type(type_name)
        self.field = expr[2]

    def c_val(self):
        """Return the C offsetof() expression."""
        c_type = self.type.c_name
        return 'offsetof({}, {})'.format(c_type, self.field)
155
class Scope(object):
    """A lexical scope mapping names to definitions, chained to a parent."""

    def __init__(self, m, mk, parent):
        self.m = m
        self.mk = mk
        self.parent = parent
        self.defs = {}

    def add_def(self, d, name=None):
        """Register `d` under `name` (defaults to `d.name`).

        Redefining a name within the same scope is an assertion failure.
        """
        key = d.name if name is None else name
        assert key not in self.defs
        self.defs[key] = d

    def get_def(self, name):
        """Look `name` up in this scope, then in each enclosing scope."""
        scope = self
        while name not in scope.defs:
            assert scope.parent, 'Unknown definition: "{}"'.format(name)
            scope = scope.parent
        return scope.defs[name]
174
class Statement(object):
    """Base class for emitted statements; holds the list of source values.

    `srcs` may be a list or tuple and is always copied into a fresh list,
    so optimization passes can rewrite `self.srcs` in place without
    touching the caller's sequence.
    """

    def __init__(self, srcs=()):
        # An immutable default avoids the shared-mutable-default pitfall.
        assert isinstance(srcs, (list, tuple))
        self.srcs = list(srcs)
179
class SSAStatement(Statement, Value):
    """A statement that also produces a value, stored in a C temporary.

    Each instance gets a unique `_tmpN` C name from a class-wide counter.
    """

    # Class-wide counter used to number the generated temporaries.
    _count = 0

    def __init__(self, zone, srcs):
        Statement.__init__(self, srcs)
        Value.__init__(self, None, zone)
        index = SSAStatement._count
        SSAStatement._count = index + 1
        self.c_name = '_tmp{}'.format(index)

    def c_val(self):
        """The C expression for an SSA value is its temporary's name."""
        return self.c_name

    def write_c_refs(self, w):
        """Emit extra references when the value is consumed more than once.

        Requires `self.uses` to have been set beforehand (see
        MetaKernel.count_ssa_value_uses) and the value to be in the gpu zone.
        """
        assert self.zone == 'gpu'
        assert self.uses > 0
        extra_refs = self.uses - 1
        if extra_refs > 0:
            w.write('mi_value_add_refs(&b, {}, {});\n',
                    self.c_name, extra_refs)
198
class Half(SSAStatement):
    """Extracts the 'hi' or 'lo' half of its single source value."""

    def __init__(self, value, half):
        assert half in ('hi', 'lo')
        super().__init__(None, [value])
        self.half = half

    @property
    def zone(self):
        """A half lives wherever its source lives."""
        return self.srcs[0].zone

    def write_c(self, w):
        """Emit the C declaration of the temporary holding the half."""
        assert self.half in ('hi', 'lo')
        src = self.srcs[0]
        want_hi = self.half == 'hi'

        if self.zone == 'cpu':
            if want_hi:
                template = 'uint32_t {} = (uint64_t)({}) >> 32;\n'
            else:
                template = 'uint32_t {} = {};\n'
            w.write(template, self.c_name, src.c_cpu_val())
            return

        if want_hi:
            template = 'struct mi_value {} = mi_value_half({}, true);\n'
        else:
            template = 'struct mi_value {} = mi_value_half({}, false);\n'
        w.write(template, self.c_name, src.c_gpu_val())
        self.write_c_refs(w)
225            self.write_c_refs(w)
226
class Expression(SSAStatement):
    """A unary, binary or ternary operator applied to source values.

    The expression is evaluated on the CPU when every source is in the
    'cpu' zone, and through the mi_builder helpers as soon as one source is
    in the 'gpu' zone.
    """

    # GPU binary operators that map directly onto one builder call.
    # Value: (format string, True when the operands must be swapped).
    _GPU_BINARY = {
        '+':  ('mi_iadd(&b, {}, {});\n', False),
        '-':  ('mi_isub(&b, {}, {});\n', False),
        '&':  ('mi_iand(&b, {}, {});\n', False),
        '|':  ('mi_ior(&b, {}, {});\n', False),
        '==': ('mi_ieq(&b, {}, {});\n', False),
        '<':  ('mi_ult(&b, {}, {});\n', False),
        # a > b  is emitted as  b < a
        '>':  ('mi_ult(&b, {}, {});\n', True),
        # a <= b is emitted as  b >= a
        '<=': ('mi_uge(&b, {}, {});\n', True),
    }

    # GPU shifts: (immediate-count form, register-count form).
    _GPU_SHIFT = {
        '<<': ('mi_ishl_imm(&b, {}, {});\n', 'mi_ishl(&b, {}, {});\n'),
        '>>': ('mi_ushr_imm(&b, {}, {});\n', 'mi_ushr(&b, {}, {});\n'),
    }

    def __init__(self, mk, op, *srcs):
        super().__init__(None, srcs)
        self.op = op

    @property
    def zone(self):
        """'gpu' if any source is in the gpu zone, otherwise 'cpu'."""
        zone = 'cpu'
        for s in self.srcs:
            if s.zone == 'gpu':
                zone = 'gpu'
        return zone

    def _write_c_cpu(self, w):
        """Emit a plain C expression computed at command-build time."""
        w.write('uint64_t {} = ', self.c_name)
        c_cpu_vals = [s.c_cpu_val() for s in self.srcs]
        if len(self.srcs) == 1:
            w.write('({} {})', self.op, c_cpu_vals[0])
        elif len(self.srcs) == 2:
            w.write('({} {} {})', c_cpu_vals[0], self.op, c_cpu_vals[1])
        else:
            # Only the ternary operator takes three sources.  The original
            # code referenced a bare `op` here, which raised NameError.
            assert len(self.srcs) == 3 and self.op == '?'
            w.write('({} ? {} : {})', *c_cpu_vals)
        w.write(';\n')

    def write_c(self, w):
        """Emit the C declaration of the temporary holding the result.

        Raises AssertionError for a GPU operator with no builder mapping.
        """
        if self.zone == 'cpu':
            self._write_c_cpu(w)
            return

        w.write('struct mi_value {} = ', self.c_name)
        op = self.op
        if op == '~':
            w.write('mi_inot(&b, {});\n', self.srcs[0].c_gpu_val())
        elif op in self._GPU_BINARY:
            template, swapped = self._GPU_BINARY[op]
            first, second = self.srcs[0], self.srcs[1]
            if swapped:
                first, second = second, first
            w.write(template, first.c_gpu_val(), second.c_gpu_val())
        elif op in self._GPU_SHIFT:
            imm_template, reg_template = self._GPU_SHIFT[op]
            value, count = self.srcs[0], self.srcs[1]
            if count.zone == 'cpu':
                # Shift count known on the CPU: use the immediate form.
                w.write(imm_template, value.c_gpu_val(), count.c_cpu_val())
            else:
                w.write(reg_template, value.c_gpu_val(), count.c_gpu_val())
        else:
            assert False, 'Unknown expression opcode: {}'.format(self.op)
        self.write_c_refs(w)
298
class StoreReg(Statement):
    """Stores a value into a register (or a half of one)."""

    def __init__(self, mk, reg, value):
        loaded = mk.load_value(value)
        super().__init__([loaded])
        self.reg = mk.parse_value(reg)
        # The destination must resolve to a register.
        assert self.reg.is_reg()

    def write_c(self, w):
        """Emit the mi_store() call for this assignment."""
        src = self.srcs[0]
        w.write('mi_store(&b, {}, {});\n',
                self.reg.c_gpu_val(), src.c_gpu_val())
309
class LoadMem(SSAStatement):
    """A memory load of `bit_size` bits; the result is a gpu-zone value."""

    def __init__(self, mk, bit_size, addr):
        address = mk.load_value(addr)
        super().__init__('gpu', [address])
        self.bit_size = bit_size

    def write_c(self, w):
        """Emit the C declaration of the temporary holding the loaded value.

        A gpu-zone address is only supported for 64-bit loads.
        """
        address = self.srcs[0]
        w.write('struct mi_value {} = ', self.c_name)
        if address.zone != 'cpu':
            assert self.bit_size == 64
            w.write('mi_load_mem64_offset(&b, anv_address_from_u64(0), {});\n',
                    address.c_gpu_val())
        else:
            w.write('mi_mem{}(anv_address_from_u64({}));\n',
                    self.bit_size, address.c_cpu_val())
        self.write_c_refs(w)
326
class StoreMem(Statement):
    """A memory store of `bit_size` bits; srcs are [address, data]."""

    def __init__(self, mk, bit_size, addr, src):
        address = mk.load_value(addr)
        data = mk.load_value(src)
        super().__init__([address, data])
        self.bit_size = bit_size

    def write_c(self, w):
        """Emit the store call.

        A gpu-zone address is only supported for 64-bit stores.
        """
        address, data = self.srcs
        if address.zone != 'cpu':
            assert self.bit_size == 64
            w.write('mi_store_mem64_offset(&b, anv_address_from_u64(0), {}, {});\n',
                    address.c_gpu_val(), data.c_gpu_val())
            return
        w.write('mi_store(&b, mi_mem{}(anv_address_from_u64({})), {});\n',
                self.bit_size, address.c_cpu_val(), data.c_gpu_val())
341
class GoTo(Statement):
    """An unconditional or conditional jump to a labelled target.

    When `cond` is given it becomes the statement's only source; `invert`
    jumps when the condition is false instead.
    """

    def __init__(self, mk, target_id, cond=None, invert=False):
        srcs = []
        if cond is not None:
            srcs.append(mk.load_value(cond))
        super().__init__(srcs)
        self.target_id = target_id
        self.invert = invert
        self.mk = mk

    def write_c(self, w):
        """Emit the mi_goto / mi_goto_if call."""
        # The label may be declared after the goto, so the target is only
        # resolved now that the whole metakernel has been parsed.
        target = self.mk.get_goto_target(self.target_id)

        if not self.srcs:
            w.write('mi_goto(&b, &{});\n', target.c_name)
            return

        cond = self.srcs[0]
        if self.invert:
            w.write('mi_goto_if(&b, mi_inot(&b, {}), &{});\n', cond.c_gpu_val(), target.c_name)
        else:
            w.write('mi_goto_if(&b, {}, &{});\n', cond.c_gpu_val(), target.c_name)
363
class GoToTarget(Statement):
    """A label that GoTo statements can jump to.

    Constructing a target registers it with the owning metakernel.
    """

    def __init__(self, mk, name):
        super().__init__()
        self.name = name
        self.c_name = '_goto_target_' + name
        self.goto_tokens = []

        # Registration only; the call's return value is not needed.
        mk.add_goto_target(self)

    def write_decl(self, w):
        """Emit the C declaration of the target object."""
        w.write('struct mi_goto_target {} = MI_GOTO_TARGET_INIT;\n',
                self.c_name)

    def write_c(self, w):
        """Emit the call that places the target at this point."""
        w.write('mi_goto_target(&b, &{});\n', self.c_name)
379
class Dispatch(Statement):
    """A kernel dispatch.

    The first three sources are the group size, the remaining ones the
    kernel arguments.  With `group_size=None` the dispatch is indirect and
    the DISPATCHDIM_X/Y/Z definitions from the current scope are used as
    the group-size sources.
    """

    def __init__(self, mk, kernel, group_size, args, postsync):
        indirect = group_size is None
        if indirect:
            srcs = [mk.scope.get_def('DISPATCHDIM_{}'.format(d)) for d in 'XYZ']
        else:
            srcs = [mk.load_value(s) for s in group_size]
        srcs = srcs + [mk.load_value(a) for a in args]
        super().__init__(srcs)
        self.kernel = mk.m.kernels[kernel]
        self.indirect = indirect
        self.postsync = postsync

    def write_c(self, w):
        """Emit a braced C block that builds the arguments and dispatches."""
        dims, kernel_args = self.srcs[:3], self.srcs[3:]

        w.write('{\n')
        w.push_indent()

        if self.indirect:
            # Indirect dispatch: no CPU-side group size array.
            gs = 'NULL'
        else:
            w.write('const uint32_t _group_size[3] = {{ {}, {}, {} }};\n',
                    *[d.c_cpu_val() for d in dims])
            gs = '_group_size'

        w.write('const struct anv_kernel_arg _args[] = {\n')
        w.push_indent()
        for arg in kernel_args:
            w.write('{{ .u64 = {} }},\n', arg.c_cpu_val())
        w.pop_indent()
        w.write('};\n')

        w.write('genX(grl_dispatch)(cmd_buffer, {},\n', self.kernel.c_name)
        w.write('                   {}, ARRAY_SIZE(_args), _args);\n', gs)
        w.pop_indent()
        w.write('}\n')
416
class SemWait(Statement):
    """A semaphore-wait statement with no sources.

    Only records `wait`; this class defines no write_c().
    """

    def __init__(self, scope, wait):
        super().__init__()
        self.wait = wait
421
class Control(Statement):
    """A control statement, emitted as a stall plus cache flushes."""

    def __init__(self, scope, wait):
        super().__init__()
        self.wait = wait

    def write_c(self, w):
        """Emit the pending-bits update and apply the pipe flushes.

        The same fixed set of bits is emitted regardless of `self.wait`.
        """
        lines = (
            'cmd_buffer->state.pending_pipe_bits |=\n',
            '    ANV_PIPE_CS_STALL_BIT |\n',
            '    ANV_PIPE_DATA_CACHE_FLUSH_BIT |\n',
            '    ANV_PIPE_UNTYPED_DATAPORT_CACHE_FLUSH_BIT;\n',
            'genX(cmd_buffer_apply_pipe_flushes)(cmd_buffer);\n',
        )
        for line in lines:
            w.write(line)
432        w.write('genX(cmd_buffer_apply_pipe_flushes)(cmd_buffer);\n')
433
# Maps GRL basic type names to the C type names used in generated code.
# Names not listed here are passed through unchanged by Module.get_type().
TYPE_REMAPS = {
    'dword' : 'uint32_t',
    'qword' : 'uint64_t',
}
438
class Module(object):
    """In-memory representation of one parsed GRL module.

    `elems` is the parsed element list; its first entry must be the
    ('module-name', <name>) tuple.  The remaining entries populate the
    kernels, structs, named constants and metakernels of the module.
    `grl_dir` is the directory used to resolve imported files.
    """

    def __init__(self, grl_dir, elems):
        assert isinstance(elems[0], tuple)
        assert elems[0][0] == 'module-name'
        self.grl_dir = grl_dir
        self.name = elems[0][1]
        self.kernels = {}
        self.structs = {}
        self.constants = []
        self.metakernels = []
        # FixedGPR objects keyed by register number, shared by all
        # metakernels of the module.
        self.regs = {}

        scope = Scope(self, None, None)
        for e in elems[1:]:
            kind = e[0]
            if kind == 'kernel':
                self._add_kernel(Kernel(self, *e[1:]))
            elif kind == 'kernel-module':
                for k in KernelModule(self, *e[1:]).kernels:
                    self._add_kernel(k)
            elif kind == 'struct':
                self._add_struct(Struct(self, *e[1:]))
            elif kind == 'named-constant':
                c = NamedConstant(*e[1:])
                scope.add_def(c)
                self.constants.append(c)
            elif kind == 'meta-kernel':
                self.metakernels.append(MetaKernel(self, scope, *e[1:]))
            elif kind == 'import':
                assert e[2] == 'struct'
                self.import_struct(e[1], e[3])
            else:
                # The original message formatted an undefined name `t`,
                # which raised NameError instead of this assertion.
                assert False, 'Invalid module-level token: {}'.format(kind)

    def _add_kernel(self, k):
        """Register kernel `k`; duplicate kernel names are rejected."""
        assert k.name not in self.kernels
        self.kernels[k.name] = k

    def _add_struct(self, s):
        """Register struct `s` under its name."""
        # NOTE(review): this checks the *kernel* table, exactly as the
        # original code did.  It looks like a copy/paste slip for
        # self.structs, but tightening it could reject inputs that import
        # the same struct twice -- confirm before changing.
        assert s.name not in self.kernels
        self.structs[s.name] = s

    def import_struct(self, filename, struct_name):
        """Import struct `struct_name` from `filename` (relative to grl_dir).

        Raises AssertionError when the file parses to nothing or does not
        define the requested struct.
        """
        elems = parse_grl_file(os.path.join(self.grl_dir, filename), [])
        assert elems
        for e in elems[1:]:
            if e[0] == 'struct' and e[1] == struct_name:
                self._add_struct(Struct(self, *e[1:]))
                return
        assert False, "Struct {0} not found in {1}".format(struct_name, filename)

    def get_type(self, name):
        """Return the Struct called `name`, or a BasicType for anything else
        (after applying TYPE_REMAPS)."""
        if name in self.structs:
            return self.structs[name]
        return BasicType(TYPE_REMAPS.get(name, name))

    def get_fixed_gpr(self, num):
        """Return the module-wide FixedGPR for `num`, creating it on first use."""
        assert isinstance(num, int)
        reg = self.regs.get(num)
        if reg is None:
            reg = FixedGPR(num)
            self.regs[num] = reg
        return reg

    def optimize(self):
        """Run copy propagation and dead code elimination to a fixed point."""
        progress = True
        while progress:
            progress = False

            # Copy Propagation
            for mk in self.metakernels:
                if mk.opt_copy_prop():
                    progress = True

            # Dead Code Elimination: clear all liveness, let every
            # metakernel mark what it reads, then drop what stayed dead.
            for r in self.regs.values():
                r.live = False
            for c in self.constants:
                c.live = False
            for mk in self.metakernels:
                mk.opt_dead_code1()
            for mk in self.metakernels:
                if mk.opt_dead_code2():
                    progress = True

            dead_regs = [n for n, r in self.regs.items() if not r.live]
            for n in dead_regs:
                del self.regs[n]
                progress = True
            self.constants = [c for c in self.constants if c.live]

    def compact_regs(self):
        """Renumber the surviving registers densely from 0.

        Only `num` is updated; the register's `name` is left unchanged.
        """
        old_regs = self.regs
        self.regs = {}
        for i, reg in enumerate(old_regs.values()):
            reg.num = i
            self.regs[i] = reg

    def write_h(self, w):
        """Write struct definitions and metakernel prototypes to `w`."""
        for s in self.structs.values():
            s.write_h(w)
        for mk in self.metakernels:
            mk.write_h(w)

    def write_c(self, w):
        """Write named constants and metakernel definitions to `w`."""
        for c in self.constants:
            c.write_c(w)
        for mk in self.metakernels:
            mk.write_c(w)
548
class Kernel(object):
    """A single CL kernel entry point referenced by the module.

    `ann` is the annotation dict; it must provide 'source' (a path ending
    in '.cl') and 'kernelFunction'.
    """

    def __init__(self, m, name, ann):
        self.name = name
        source = ann['source']
        self.source_file = source
        # Path separators become underscores and the 3-character '.cl'
        # extension is dropped before upper-casing.
        self.kernel_name = source.replace('/', '_')[:-3].upper()
        self.entrypoint = ann['kernelFunction']

        assert source.endswith('.cl')
        parts = ('GRL_CL_KERNEL', self.kernel_name, self.entrypoint.upper())
        self.c_name = '_'.join(parts)
562
class KernelModule(object):
    """A group of kernels that share one source file."""

    def __init__(self, m, name, source, kernels):
        self.name = name
        self.kernels = []
        self.libraries = []

        for entry in kernels:
            if entry[0] != 'kernel':
                # 'library' entries (and anything else) are skipped for now.
                continue
            # The shared source file is written into the kernel's own
            # annotation dict before the Kernel is built from it.
            entry[2]['source'] = source
            self.kernels.append(Kernel(m, *entry[1:]))
576
class BasicType(object):
    """A non-struct type whose C spelling is the same as its name."""

    def __init__(self, name):
        self.name = self.c_name = name
581
class Struct(object):
    """A GRL struct, emitted as `struct grl_<module>_<name>` in C.

    `fields` is a sequence of (type name, field name) pairs; only
    `align == 0` is supported.
    """

    def __init__(self, m, name, fields, align):
        assert align == 0
        self.name = name
        self.c_name = 'struct ' + '_'.join(['grl', m.name, self.name])
        self.fields = []
        for type_name, field_name in fields:
            self.fields.append((m.get_type(type_name), field_name))

    def write_h(self, w):
        """Write the C struct definition to `w`."""
        w.write('{} {{\n', self.c_name)
        w.push_indent()
        for field_type, field_name in self.fields:
            w.write('{} {};\n', field_type.c_name, field_name)
        w.pop_indent()
        w.write('};\n')
596
class NamedConstant(Value):
    """A module-level named constant; a cpu-zone value wrapping a Constant."""

    def __init__(self, name, value):
        super().__init__(name, 'cpu')
        self.name = name
        self.value = Constant(value)
        # Guards against emitting the C definition more than once.
        self.written = False

    def set_module(self, m):
        """No-op."""
        pass

    def write_c(self, w):
        """Emit the `static const uint64_t` definition, at most once."""
        if not self.written:
            w.write('static const uint64_t {} = {};\n',
                    self.name, self.value.c_val())
            self.written = True
613
class MetaKernelParameter(Value):
    """A metakernel parameter: a named cpu-zone value with a resolved type."""

    def __init__(self, mk, type, name):
        super().__init__(name, 'cpu')
        module = mk.m
        self.type = module.get_type(type)
618
class MetaKernel(object):
    """One metakernel: parameters plus a flat list of statements.

    The constructor parses `statements` into Statement objects (appended
    to `self.statements` in program order).  The opt_* methods are driven
    by Module.optimize(), and write_h()/write_c() emit the C prototype and
    definition.
    """

    def __init__(self, m, m_scope, name, params, ann, statements):
        self.m = m
        self.name = name
        self.c_name = '_'.join(['grl', m.name, self.name])
        self.goto_targets = {}
        self.num_tmps = 0

        mk_scope = Scope(m, self, m_scope)

        self.params = [MetaKernelParameter(self, *p) for p in params]
        for p in self.params:
            mk_scope.add_def(p)

        for comp, axis in enumerate('XYZ'):
            mk_scope.add_def(GroupSizeRegister(comp),
                             name='DISPATCHDIM_' + axis)

        self.statements = []
        self.parse_stmt(mk_scope, statements)
        # The scope is only meaningful while parsing.
        self.scope = None

    def get_tmp(self):
        """Return a fresh per-metakernel `_tmpN` name."""
        tmpN = '_tmp{}'.format(self.num_tmps)
        self.num_tmps += 1
        return tmpN

    def add_stmt(self, stmt):
        """Append `stmt` to the statement list and return it."""
        self.statements.append(stmt)
        return stmt

    def parse_value(self, v):
        """Turn a parsed operand into a Value.

        Accepts an existing Value, a name (REG<n> or a scoped definition),
        an int literal, or a tuple ('member' / 'offsetof' / operator).
        Operator tuples append an Expression statement as a side effect.
        Raises AssertionError for any other operand.
        """
        if isinstance(v, Value):
            return v
        if isinstance(v, str):
            # fullmatch: a name that merely *starts* with REG<digits> must
            # go through the scope lookup rather than reach int().
            if re.fullmatch(r'REG\d+', v):
                return self.m.get_fixed_gpr(int(v[3:]))
            return self.scope.get_def(v)
        if isinstance(v, int):
            return Constant(v)
        if isinstance(v, tuple):
            if v[0] == 'member':
                return Member(self.parse_value(v[1]), v[2])
            if v[0] == 'offsetof':
                return OffsetOf(self, v[1])
            op = v[0]
            srcs = [self.parse_value(s) for s in v[1:]]
            return self.add_stmt(Expression(self, op, *srcs))
        # v may not be subscriptable here, so report it whole.
        assert False, 'Invalid value: {!r}'.format(v)

    def load_value(self, v):
        """Like parse_value(), but a gpu-zone Member is materialized as a
        Half statement so the result can be used as a source."""
        v = self.parse_value(v)
        if isinstance(v, Member) and v.zone == 'gpu':
            v = self.add_stmt(Half(v.value, v.member))
        return v

    def parse_stmt(self, scope, s):
        """Parse statement `s` (or a list of statements) within `scope`.

        A list opens a nested scope.  Raises AssertionError for an unknown
        statement kind.
        """
        self.scope = scope
        if isinstance(s, list):
            subscope = Scope(self.m, self, scope)
            for stmt in s:
                self.parse_stmt(subscope, stmt)
            return

        kind = s[0]
        if kind == 'define':
            scope.add_def(self.parse_value(s[2]), name=s[1])
        elif kind == 'assign':
            self.add_stmt(StoreReg(self, *s[1:]))
        elif kind == 'dispatch':
            self.add_stmt(Dispatch(self, *s[1:]))
        elif kind in ('load-dword', 'load-qword'):
            bit_size = 32 if kind == 'load-dword' else 64
            v = self.add_stmt(LoadMem(self, bit_size, s[2]))
            self.add_stmt(StoreReg(self, s[1], v))
        elif kind == 'store-dword':
            self.add_stmt(StoreMem(self, 32, *s[1:]))
        elif kind == 'store-qword':
            self.add_stmt(StoreMem(self, 64, *s[1:]))
        elif kind == 'goto':
            self.add_stmt(GoTo(self, s[1]))
        elif kind == 'goto-if':
            self.add_stmt(GoTo(self, s[1], s[2]))
        elif kind == 'goto-if-not':
            self.add_stmt(GoTo(self, s[1], s[2], invert=True))
        elif kind == 'label':
            self.add_stmt(GoToTarget(self, s[1]))
        elif kind == 'control':
            self.add_stmt(Control(self, s[1]))
        elif kind == 'sem-wait-while':
            # NOTE(review): emitted as a Control (full stall + flush), not
            # a SemWait -- presumably a conservative stand-in; confirm.
            self.add_stmt(Control(self, s[1]))
        else:
            assert False, 'Invalid statement: {}'.format(kind)

    def add_goto_target(self, t):
        """Register goto target `t`; duplicate label names are rejected."""
        assert t.name not in self.goto_targets
        self.goto_targets[t.name] = t

    def get_goto_target(self, name):
        """Return the goto target registered under `name`."""
        return self.goto_targets[name]

    def opt_copy_prop(self):
        """Replace reads of a GPR by the non-register value last stored to it.

        Known copies are forgotten at every goto and label.  Returns True
        if any source was rewritten.
        """
        progress = False
        copies = {}
        for stmt in self.statements:
            for i, src in enumerate(stmt.srcs):
                if isinstance(src, FixedGPR) and src.num in copies:
                    stmt.srcs[i] = copies[src.num]
                    progress = True

            if isinstance(stmt, StoreReg):
                reg = stmt.reg
                if isinstance(reg, Member):
                    reg = reg.value

                if isinstance(reg, FixedGPR):
                    # Any store invalidates the previous copy; only stores
                    # of non-register values are remembered.
                    copies.pop(reg.num, None)
                    if not stmt.srcs[0].is_reg():
                        copies[reg.num] = stmt.srcs[0]
            elif isinstance(stmt, (GoTo, GoToTarget)):
                copies = {}

        return progress

    def opt_dead_code1(self):
        """First DCE pass: mark read registers live, SSA statements dead."""
        for stmt in self.statements:
            # Mark every register which is read as live
            for src in stmt.srcs:
                if isinstance(src, Register):
                    src.live = True

            # Initialize every SSA statement to dead
            if isinstance(stmt, SSAStatement):
                stmt.live = False

    def opt_dead_code2(self):
        """Second DCE pass: drop dead statements, walking backwards.

        Returns True if at least one statement was removed.
        """
        def yield_live(statements):
            # GPRs that may still be read later; conservatively "all" at
            # the end of the program and at every goto / label.
            gprs_read = set(self.m.regs.keys())
            for stmt in statements:
                if isinstance(stmt, SSAStatement):
                    if not stmt.live:
                        continue
                elif isinstance(stmt, StoreReg):
                    reg = stmt.reg
                    if isinstance(reg, Member):
                        reg = reg.value

                    # NOTE(review): this tests stmt.reg, not the unwrapped
                    # `reg`, so a store to a register half consults the
                    # Member's own flag.  Kept exactly as in the original;
                    # confirm whether `reg.live` was intended.
                    if not stmt.reg.live:
                        continue

                    if isinstance(reg, FixedGPR):
                        if reg.num in gprs_read:
                            gprs_read.remove(reg.num)
                        else:
                            continue
                elif isinstance(stmt, (GoTo, GoToTarget)):
                    gprs_read = set(self.m.regs.keys())

                for src in stmt.srcs:
                    src.live = True
                    if isinstance(src, FixedGPR):
                        gprs_read.add(src.num)
                yield stmt

        old_count = len(self.statements)
        kept = list(yield_live(reversed(self.statements)))
        kept.reverse()
        self.statements = kept
        return len(kept) != old_count

    def count_ssa_value_uses(self):
        """Set `uses` on every SSA statement to the number of reads of it.

        Relies on each SSA statement appearing before its first use.
        """
        for stmt in self.statements:
            if isinstance(stmt, SSAStatement):
                stmt.uses = 0

            for src in stmt.srcs:
                if isinstance(src, SSAStatement):
                    src.uses += 1

    def _write_prototype(self, w, terminator):
        """Write the C function prototype, ending with `terminator`."""
        w.write('void\n')
        w.write('genX({})(\n', self.c_name)
        w.push_indent()
        w.write('struct anv_cmd_buffer *cmd_buffer')
        for p in self.params:
            w.write(',\n{} {}', p.type.c_name, p.name)
        w.write(terminator)
        w.pop_indent()

    def write_h(self, w):
        """Write the C declaration of the metakernel function."""
        self._write_prototype(w, ');\n')

    def write_c(self, w):
        """Write the C definition of the metakernel function."""
        self._write_prototype(w, ')\n')
        w.write('{\n')
        w.push_indent()

        w.write('struct mi_builder b;\n')
        w.write('mi_builder_init(&b, cmd_buffer->device->info, &cmd_buffer->batch);\n')
        w.write('/* TODO: use anv_mocs? */\n')
        w.write('const uint32_t mocs = isl_mocs(&cmd_buffer->device->isl_dev, 0, false);\n')
        w.write('mi_builder_set_mocs(&b, mocs);\n')
        w.write('\n')

        # Every metakernel reserves all of the module's registers.
        for r in self.m.regs.values():
            r.write_c(w)
        w.write('\n')

        for t in self.goto_targets.values():
            t.write_decl(w)
        w.write('\n')

        # Use counts must be final before statements emit their add_refs.
        self.count_ssa_value_uses()
        for s in self.statements:
            s.write_c(w)

        w.pop_indent()

        w.write('}\n')
845
# Text written at the top of the generated header.  It is passed through
# str.format() by Writer.write(): {0} is the include-guard macro name, and
# literal C braces are therefore doubled.
HEADER_PROLOGUE = COPYRIGHT + '''
#include "anv_private.h"
#include "grl/genX_grl.h"

#ifndef {0}
#define {0}

#ifdef __cplusplus
extern "C" {{
#endif

'''

# Text written at the end of the generated header; {0} is the same
# include-guard macro name as in HEADER_PROLOGUE.
HEADER_EPILOGUE = '''
#ifdef __cplusplus
}}
#endif

#endif /* {0} */
'''

# Text written at the top of the generated C file; {0} is the file name of
# the generated header to include.
C_PROLOGUE = COPYRIGHT + '''
#include "{0}"

#include "genxml/gen_macros.h"
#include "genxml/genX_pack.h"
#include "genxml/genX_rt_pack.h"

/* We reserve :
 *    - GPR 14 for secondary command buffer returns
 *    - GPR 15 for conditional rendering
 */
#define MI_BUILDER_NUM_ALLOC_GPRS 14
#define __gen_get_batch_dwords anv_batch_emit_dwords
#define __gen_address_offset anv_address_add
#define __gen_get_batch_address(b, a) anv_batch_address(b, a)
#include "common/mi_builder.h"

#define MI_PREDICATE_RESULT mi_reg32(0x2418)
#define DISPATCHDIM_X mi_reg32(0x2500)
#define DISPATCHDIM_Y mi_reg32(0x2504)
#define DISPATCHDIM_Z mi_reg32(0x2508)
'''
889
def parse_libraries(filenames):
    """Parse every library file and return the libraries keyed by name.

    Each parsed entry must be a 'library' element.  A ('path', <dir>)
    annotation is appended to it so that CL files can be found relative
    to the library file's directory.
    """
    libraries = {}
    for fname in filenames:
        lib_dir = os.path.dirname(fname)
        for lib in parse_grl_file(fname, []):
            assert lib[0] == 'library'
            lib[2].append(('path', lib_dir))
            libraries[lib[1]] = lib
    return libraries
900
def main():
    """Command-line entry point: parse a GRL file and write the C outputs.

    Parses the optional library files and the input GRL file, builds and
    optimizes the Module, then writes the generated header (--out-h) and C
    source (--out-c).
    """
    argparser = argparse.ArgumentParser()
    # Both outputs are opened unconditionally below, so require them up
    # front instead of failing later in open(None).
    argparser.add_argument('--out-c', required=True, help='Output C file')
    argparser.add_argument('--out-h', required=True, help='Output H file')
    argparser.add_argument('--library', dest='libraries', action='append',
                           default=[], help='Libraries to include')
    argparser.add_argument('grl', help="Input  file")
    args = argparser.parse_args()

    grl_dir = os.path.dirname(args.grl)

    libraries = parse_libraries(args.libraries)

    ir = parse_grl_file(args.grl, libraries)

    m = Module(grl_dir, ir)
    m.optimize()
    m.compact_regs()

    with open(args.out_h, 'w') as f:
        # The include guard is the header's base name, upper-cased.
        guard = os.path.splitext(os.path.basename(args.out_h))[0].upper()
        w = Writer(f)
        w.write(HEADER_PROLOGUE, guard)
        m.write_h(w)
        w.write(HEADER_EPILOGUE, guard)

    with open(args.out_c, 'w') as f:
        w = Writer(f)
        w.write(C_PROLOGUE, os.path.basename(args.out_h))
        m.write_c(w)

if __name__ == '__main__':
    main()
934