// Copyright © 2022 Collabora, Ltd.
// SPDX-License-Identifier: MIT

use crate::api::{GetDebugFlags, DEBUG};
use crate::bitset::BitSet;
use crate::ir::*;
use crate::liveness::{BlockLiveness, Liveness, SimpleLiveness};

use std::cmp::{max, Ordering};
use std::collections::{HashMap, HashSet};

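// A set of SSA values which also tracks insertion order so that iterating
// over it is deterministic from one run to the next.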
struct KillSet {
    set: HashSet<SSAValue>,
    vec: Vec<SSAValue>,
}

impl KillSet {
    pub fn new() -> KillSet {
        KillSet {
            set: HashSet::new(),
            vec: Vec::new(),
        }
    }

    pub fn clear(&mut self) {
        self.set.clear();
        self.vec.clear();
    }

    pub fn insert(&mut self, ssa: SSAValue) {
        if self.set.insert(ssa) {
            self.vec.push(ssa);
        }
    }

    pub fn iter(&self) -> std::slice::Iter<'_, SSAValue> {
        self.vec.iter()
    }

    pub fn is_empty(&self) -> bool {
        self.vec.is_empty()
    }
}

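// Records, for each SSA value, the later uses within a block where the value
// is consumed either at a fixed register (e.g. an FS output) or as a
// component of a vector source.  alloc_scalar() consults this so it can pick
// registers which are less likely to require copies later on.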
enum SSAUse {
    FixedReg(u32),
    Vec(SSARef),
}

struct SSAUseMap {
    ssa_map: HashMap<SSAValue, Vec<(usize, SSAUse)>>,
}

impl SSAUseMap {
    fn add_fixed_reg_use(&mut self, ip: usize, ssa: SSAValue, reg: u32) {
        let v = self.ssa_map.entry(ssa).or_default();
        v.push((ip, SSAUse::FixedReg(reg)));
    }

    fn add_vec_use(&mut self, ip: usize, vec: SSARef) {
        if vec.comps() == 1 {
            return;
        }

        for ssa in vec.iter() {
            let v = self.ssa_map.entry(*ssa).or_default();
            v.push((ip, SSAUse::Vec(vec)));
        }
    }

    fn find_vec_use_after(&self, ssa: SSAValue, ip: usize) -> Option<&SSAUse> {
        if let Some(v) = self.ssa_map.get(&ssa) {
            let p = v.partition_point(|(uip, _)| *uip <= ip);
            if p == v.len() {
                None
            } else {
                let (_, u) = &v[p];
                Some(u)
            }
        } else {
            None
        }
    }

    pub fn add_block(&mut self, b: &BasicBlock) {
        for (ip, instr) in b.instrs.iter().enumerate() {
            match &instr.op {
                Op::FSOut(op) => {
                    for (i, src) in op.srcs.iter().enumerate() {
                        let out_reg = u32::try_from(i).unwrap();
                        if let SrcRef::SSA(ssa) = src.src_ref {
                            assert!(ssa.comps() == 1);
                            self.add_fixed_reg_use(ip, ssa[0], out_reg);
                        }
                    }
                }
                _ => {
                    // We don't care about predicates because they're scalar
                    for src in instr.srcs() {
                        if let SrcRef::SSA(ssa) = src.src_ref {
                            self.add_vec_use(ip, ssa);
                        }
                    }
                }
            }
        }
    }

    pub fn for_block(b: &BasicBlock) -> SSAUseMap {
        let mut am = SSAUseMap {
            ssa_map: HashMap::new(),
        };
        am.add_block(b);
        am
    }
}

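// A value which is live into a block, identified either by its SSA value or,
// for phi destinations, by the phi index.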
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
enum LiveRef {
    SSA(SSAValue),
    Phi(u32),
}

#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct LiveValue {
    pub live_ref: LiveRef,
    pub reg_ref: RegRef,
}

// We need a stable ordering of live values so that RA is deterministic
impl Ord for LiveValue {
    fn cmp(&self, other: &Self) -> Ordering {
        let s_file = u8::from(self.reg_ref.file());
        let o_file = u8::from(other.reg_ref.file());
        match s_file.cmp(&o_file) {
            Ordering::Equal => {
                let s_idx = self.reg_ref.base_idx();
                let o_idx = other.reg_ref.base_idx();
                s_idx.cmp(&o_idx)
            }
            ord => ord,
        }
    }
}

impl PartialOrd for LiveValue {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

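// Tracks the state of a single register file: which registers are currently
// in use and which SSA value, if any, lives in each register.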
#[derive(Clone)]
struct RegAllocator {
    file: RegFile,
    num_regs: u32,
    used: BitSet,
    reg_ssa: Vec<SSAValue>,
    ssa_reg: HashMap<SSAValue, u32>,
}

impl RegAllocator {
    pub fn new(file: RegFile, num_regs: u32) -> Self {
        Self {
            file: file,
            num_regs: num_regs,
            used: BitSet::new(),
            reg_ssa: Vec::new(),
            ssa_reg: HashMap::new(),
        }
    }

    fn file(&self) -> RegFile {
        self.file
    }

    pub fn num_regs_used(&self) -> u32 {
        self.ssa_reg.len().try_into().unwrap()
    }

    pub fn reg_is_used(&self, reg: u32) -> bool {
        self.used.get(reg.try_into().unwrap())
    }

    fn reg_range_is_unused(&self, reg: u32, comps: u8) -> bool {
        for c in 0..u32::from(comps) {
            if self.reg_is_used(reg + c) {
                return false;
            }
        }
        true
    }

    pub fn try_get_reg(&self, ssa: SSAValue) -> Option<u32> {
        self.ssa_reg.get(&ssa).cloned()
    }

    pub fn try_get_ssa(&self, reg: u32) -> Option<SSAValue> {
        if self.reg_is_used(reg) {
            Some(self.reg_ssa[usize::try_from(reg).unwrap()])
        } else {
            None
        }
    }

    pub fn try_get_vec_reg(&self, vec: &SSARef) -> Option<u32> {
        let Some(reg) = self.try_get_reg(vec[0]) else {
            return None;
        };

        let align = u32::from(vec.comps()).next_power_of_two();
        if reg % align != 0 {
            return None;
        }

        for c in 1..vec.comps() {
            let ssa = vec[usize::from(c)];
            if self.try_get_reg(ssa) != Some(reg + u32::from(c)) {
                return None;
            }
        }
        Some(reg)
    }

    pub fn free_ssa(&mut self, ssa: SSAValue) -> u32 {
        assert!(ssa.file() == self.file);
        let reg = self.ssa_reg.remove(&ssa).unwrap();
        assert!(self.reg_is_used(reg));
        let reg_usize = usize::try_from(reg).unwrap();
        assert!(self.reg_ssa[reg_usize] == ssa);
        self.used.remove(reg_usize);
        reg
    }

    pub fn assign_reg(&mut self, ssa: SSAValue, reg: u32) {
        assert!(ssa.file() == self.file);
        assert!(reg < self.num_regs);
        assert!(!self.reg_is_used(reg));

        let reg_usize = usize::try_from(reg).unwrap();
        if reg_usize >= self.reg_ssa.len() {
            self.reg_ssa.resize(reg_usize + 1, SSAValue::NONE);
        }
        self.reg_ssa[reg_usize] = ssa;
        let old = self.ssa_reg.insert(ssa, reg);
        assert!(old.is_none());
        self.used.insert(reg_usize);
    }

    pub fn try_find_unused_reg_range(
        &self,
        start_reg: u32,
        align: u32,
        comps: u8,
    ) -> Option<u32> {
        assert!(comps > 0 && u32::from(comps) <= self.num_regs);

        let mut next_reg = start_reg;
        loop {
            let reg: u32 = self
                .used
                .next_unset(usize::try_from(next_reg).unwrap())
                .try_into()
                .unwrap();

            // Ensure we're properly aligned
            let reg = reg.next_multiple_of(align);

            // Ensure we're in-bounds. This also serves as a check to ensure
            // that u8::try_from(reg + i) will succeed.
            if reg > self.num_regs - u32::from(comps) {
                return None;
            }

            if self.reg_range_is_unused(reg, comps) {
                return Some(reg);
            }

            next_reg = reg + align;
        }
    }

    pub fn alloc_scalar(
        &mut self,
        ip: usize,
        sum: &SSAUseMap,
        ssa: SSAValue,
    ) -> u32 {
        if let Some(u) = sum.find_vec_use_after(ssa, ip) {
            match u {
                SSAUse::FixedReg(reg) => {
                    if !self.reg_is_used(*reg) {
                        self.assign_reg(ssa, *reg);
                        return *reg;
                    }
                }
                SSAUse::Vec(vec) => {
                    let mut comp = u8::MAX;
                    for c in 0..vec.comps() {
                        if vec[usize::from(c)] == ssa {
                            comp = c;
                            break;
                        }
                    }
                    assert!(comp < vec.comps());

                    let align = u32::from(vec.comps()).next_power_of_two();
                    for c in 0..vec.comps() {
                        if c == comp {
                            continue;
                        }

                        let other = vec[usize::from(c)];
                        let Some(other_reg) = self.try_get_reg(other) else {
                            continue;
                        };

                        let vec_reg = other_reg & !(align - 1);
                        if other_reg != vec_reg + u32::from(c) {
                            continue;
                        }

                        let reg = vec_reg + u32::from(comp);
                        if reg < self.num_regs && !self.reg_is_used(reg) {
                            self.assign_reg(ssa, reg);
                            return reg;
                        }
                    }

                    // We weren't able to pair it with an already allocated
                    // register but maybe we can at least find an aligned one.
                    if let Some(reg) =
                        self.try_find_unused_reg_range(0, align, 1)
                    {
                        self.assign_reg(ssa, reg);
                        return reg;
                    }
                }
            }
        }

        let reg = self
            .try_find_unused_reg_range(0, 1, 1)
            .expect("Failed to find free register");
        self.assign_reg(ssa, reg);
        reg
    }
}

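// A wrapper around a RegAllocator used while assigning registers for a single
// instruction.  Registers handed out for this instruction are pinned so they
// can't be re-used, and any SSA values evicted to make room are re-assigned
// to free registers via a parallel copy when finish() is called.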
struct PinnedRegAllocator<'a> {
    ra: &'a mut RegAllocator,
    pcopy: OpParCopy,
    pinned: BitSet,
    evicted: HashMap<SSAValue, u32>,
}

impl<'a> PinnedRegAllocator<'a> {
    fn new(ra: &'a mut RegAllocator) -> Self {
        PinnedRegAllocator {
            ra: ra,
            pcopy: OpParCopy::new(),
            pinned: Default::default(),
            evicted: HashMap::new(),
        }
    }

    fn file(&self) -> RegFile {
        self.ra.file()
    }

    fn pin_reg(&mut self, reg: u32) {
        self.pinned.insert(reg.try_into().unwrap());
    }

    fn pin_reg_range(&mut self, reg: u32, comps: u8) {
        for c in 0..u32::from(comps) {
            self.pin_reg(reg + c);
        }
    }

    fn reg_is_pinned(&self, reg: u32) -> bool {
        self.pinned.get(reg.try_into().unwrap())
    }

    fn reg_range_is_unpinned(&self, reg: u32, comps: u8) -> bool {
        for c in 0..u32::from(comps) {
            if self.reg_is_pinned(reg + c) {
                return false;
            }
        }
        true
    }

    fn assign_pin_reg(&mut self, ssa: SSAValue, reg: u32) -> RegRef {
        self.pin_reg(reg);
        self.ra.assign_reg(ssa, reg);
        RegRef::new(self.file(), reg, 1)
    }

    pub fn assign_pin_vec_reg(&mut self, vec: SSARef, reg: u32) -> RegRef {
        for c in 0..vec.comps() {
            let ssa = vec[usize::from(c)];
            self.assign_pin_reg(ssa, reg + u32::from(c));
        }
        RegRef::new(self.file(), reg, vec.comps())
    }

    fn try_find_unpinned_reg_range(
        &self,
        start_reg: u32,
        align: u32,
        comps: u8,
    ) -> Option<u32> {
        let mut next_reg = start_reg;
        loop {
            let reg: u32 = self
                .pinned
                .next_unset(usize::try_from(next_reg).unwrap())
                .try_into()
                .unwrap();

            // Ensure we're properly aligned
            let reg = reg.next_multiple_of(align);

            // Ensure we're in-bounds. This also serves as a check to ensure
            // that u8::try_from(reg + i) will succeed.
            if reg > self.ra.num_regs - u32::from(comps) {
                return None;
            }

            if self.reg_range_is_unpinned(reg, comps) {
                return Some(reg);
            }

            next_reg = reg + align;
        }
    }

    pub fn evict_ssa(&mut self, ssa: SSAValue, old_reg: u32) {
        assert!(ssa.file() == self.file());
        assert!(!self.reg_is_pinned(old_reg));
        self.evicted.insert(ssa, old_reg);
    }

    pub fn evict_reg_if_used(&mut self, reg: u32) {
        assert!(!self.reg_is_pinned(reg));

        if let Some(ssa) = self.ra.try_get_ssa(reg) {
            self.ra.free_ssa(ssa);
            self.evict_ssa(ssa, reg);
        }
    }

    fn move_ssa_to_reg(&mut self, ssa: SSAValue, new_reg: u32) {
        if let Some(old_reg) = self.ra.try_get_reg(ssa) {
            assert!(self.evicted.get(&ssa).is_none());
            assert!(!self.reg_is_pinned(old_reg));

            if new_reg == old_reg {
                self.pin_reg(new_reg);
            } else {
                self.ra.free_ssa(ssa);
                self.evict_reg_if_used(new_reg);

                self.pcopy.push(
                    RegRef::new(self.file(), new_reg, 1).into(),
                    RegRef::new(self.file(), old_reg, 1).into(),
                );

                self.assign_pin_reg(ssa, new_reg);
            }
        } else if let Some(old_reg) = self.evicted.remove(&ssa) {
            self.evict_reg_if_used(new_reg);

            self.pcopy.push(
                RegRef::new(self.file(), new_reg, 1).into(),
                RegRef::new(self.file(), old_reg, 1).into(),
            );

            self.assign_pin_reg(ssa, new_reg);
        } else {
            panic!("Unknown SSA value");
        }
    }

    fn finish(mut self, pcopy: &mut OpParCopy) {
        pcopy.dsts_srcs.append(&mut self.pcopy.dsts_srcs);

        if !self.evicted.is_empty() {
            // Sort so we get determinism, even if the hash map order changes
            // from one run to another or due to rust compiler updates.
            let mut evicted: Vec<_> = self.evicted.drain().collect();
            evicted.sort_by_key(|(_, reg)| *reg);

            for (ssa, old_reg) in evicted {
                let mut next_reg = 0;
                let new_reg = loop {
                    let reg = self
                        .ra
                        .try_find_unused_reg_range(next_reg, 1, 1)
                        .expect("Failed to find free register");
                    if !self.reg_is_pinned(reg) {
                        break reg;
                    }
                    next_reg = reg + 1;
                };

                pcopy.push(
                    RegRef::new(self.file(), new_reg, 1).into(),
                    RegRef::new(self.file(), old_reg, 1).into(),
                );
                self.assign_pin_reg(ssa, new_reg);
            }
        }
    }

    pub fn try_get_vec_reg(&self, vec: &SSARef) -> Option<u32> {
        self.ra.try_get_vec_reg(vec)
    }

    pub fn collect_vector(&mut self, vec: &SSARef) -> RegRef {
        if let Some(reg) = self.try_get_vec_reg(vec) {
            self.pin_reg_range(reg, vec.comps());
            return RegRef::new(self.file(), reg, vec.comps());
        }

        let comps = vec.comps();
        let align = u32::from(comps).next_power_of_two();

        let reg = self
            .ra
            .try_find_unused_reg_range(0, align, comps)
            .or_else(|| {
                for c in 0..comps {
                    let ssa = vec[usize::from(c)];
                    let Some(comp_reg) = self.ra.try_get_reg(ssa) else {
                        continue;
                    };

                    let vec_reg = comp_reg & !(align - 1);
                    if comp_reg != vec_reg + u32::from(c) {
                        continue;
                    }

                    if vec_reg + u32::from(comps) > self.ra.num_regs {
                        continue;
                    }

                    if self.reg_range_is_unpinned(vec_reg, comps) {
                        return Some(vec_reg);
                    }
                }
                None
            })
            .or_else(|| self.try_find_unpinned_reg_range(0, align, comps))
            .expect("Failed to find an unpinned register range");

        for c in 0..comps {
            let ssa = vec[usize::from(c)];
            self.move_ssa_to_reg(ssa, reg + u32::from(c));
        }

        RegRef::new(self.file(), reg, comps)
    }

    pub fn alloc_vector(&mut self, vec: SSARef) -> RegRef {
        let comps = vec.comps();
        let align = u32::from(comps).next_power_of_two();

        if let Some(reg) = self.ra.try_find_unused_reg_range(0, align, comps) {
            return self.assign_pin_vec_reg(vec, reg);
        }

        let reg = self
            .try_find_unpinned_reg_range(0, align, comps)
            .expect("Failed to find an unpinned register range");

        for c in 0..comps {
            self.evict_reg_if_used(reg + u32::from(c));
        }
        self.assign_pin_vec_reg(vec, reg)
    }

    pub fn free_killed(&mut self, killed: &KillSet) {
        for ssa in killed.iter() {
            if ssa.file() == self.file() {
                self.ra.free_ssa(*ssa);
            }
        }
    }
}

impl Drop for PinnedRegAllocator<'_> {
    fn drop(&mut self) {
        assert!(self.evicted.is_empty());
    }
}

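// Rewrites all of instr's sources (and predicate) which live in ra's register
// file to register references, collecting vector sources into contiguous,
// aligned register ranges.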
fn instr_remap_srcs_file(instr: &mut Instr, ra: &mut PinnedRegAllocator) {
    // Collect vector sources first since those may silently pin some of our
    // scalar sources.
    for src in instr.srcs_mut() {
        if let SrcRef::SSA(ssa) = &src.src_ref {
            if ssa.file() == ra.file() && ssa.comps() > 1 {
                src.src_ref = ra.collect_vector(ssa).into();
            }
        }
    }

    if let PredRef::SSA(pred) = instr.pred.pred_ref {
        if pred.file() == ra.file() {
            instr.pred.pred_ref = ra.collect_vector(&pred.into()).into();
        }
    }

    for src in instr.srcs_mut() {
        if let SrcRef::SSA(ssa) = &src.src_ref {
            if ssa.file() == ra.file() && ssa.comps() == 1 {
                src.src_ref = ra.collect_vector(ssa).into();
            }
        }
    }
}

fn instr_alloc_scalar_dsts_file(
    instr: &mut Instr,
    ip: usize,
    sum: &SSAUseMap,
    ra: &mut RegAllocator,
) {
    for dst in instr.dsts_mut() {
        if let Dst::SSA(ssa) = dst {
            assert!(ssa.comps() == 1);
            if ssa.file() == ra.file() {
                let reg = ra.alloc_scalar(ip, sum, ssa[0]);
                *dst = RegRef::new(ra.file(), reg, 1).into();
            }
        }
    }
}

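// Assigns registers for one instruction within a single register file.  The
// tricky part is vector destinations: we first try to re-use the registers of
// killed vector sources, then try to reserve fresh register ranges up front,
// and only fall back to the general case of allocating vectors after the
// sources have been remapped, evicting values as needed.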
fn instr_assign_regs_file(
    instr: &mut Instr,
    ip: usize,
    sum: &SSAUseMap,
    killed: &KillSet,
    pcopy: &mut OpParCopy,
    ra: &mut RegAllocator,
) {
    struct VecDst {
        dst_idx: usize,
        comps: u8,
        killed: Option<SSARef>,
        reg: u32,
    }

    let mut vec_dsts = Vec::new();
    let mut vec_dst_comps = 0;
    for (i, dst) in instr.dsts().iter().enumerate() {
        if let Dst::SSA(ssa) = dst {
            if ssa.file() == ra.file() && ssa.comps() > 1 {
                vec_dsts.push(VecDst {
                    dst_idx: i,
                    comps: ssa.comps(),
                    killed: None,
                    reg: u32::MAX,
                });
                vec_dst_comps += ssa.comps();
            }
        }
    }

    // No vector destinations is the easy case
    if vec_dst_comps == 0 {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);
        pra.free_killed(killed);
        pra.finish(pcopy);
        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
        return;
    }

    // Predicates can't be vectors. This lets us ignore instr.pred in our
    // analysis for the cases below. Only the easy case above needs to care
    // about them.
    assert!(!ra.file().is_predicate());

    let mut avail = killed.set.clone();
    let mut killed_vecs = Vec::new();
    for src in instr.srcs() {
        if let SrcRef::SSA(vec) = src.src_ref {
            if vec.comps() > 1 {
                let mut vec_killed = true;
                for ssa in vec.iter() {
                    if ssa.file() != ra.file() || !avail.contains(ssa) {
                        vec_killed = false;
                        break;
                    }
                }
                if vec_killed {
                    for ssa in vec.iter() {
                        avail.remove(ssa);
                    }
                    killed_vecs.push(vec);
                }
            }
        }
    }

    vec_dsts.sort_by_key(|v| v.comps);
    killed_vecs.sort_by_key(|v| v.comps());

    let mut next_dst_reg = 0;
    let mut vec_dsts_map_to_killed_srcs = true;
    let mut could_trivially_allocate = true;
    for vec_dst in vec_dsts.iter_mut().rev() {
        while let Some(src) = killed_vecs.pop() {
            if src.comps() >= vec_dst.comps {
                vec_dst.killed = Some(src);
                break;
            }
        }
        if vec_dst.killed.is_none() {
            vec_dsts_map_to_killed_srcs = false;
        }

        let align = u32::from(vec_dst.comps).next_power_of_two();
        if let Some(reg) =
            ra.try_find_unused_reg_range(next_dst_reg, align, vec_dst.comps)
        {
            vec_dst.reg = reg;
            next_dst_reg = reg + u32::from(vec_dst.comps);
        } else {
            could_trivially_allocate = false;
        }
    }

    if vec_dsts_map_to_killed_srcs {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);

        for vec_dst in &mut vec_dsts {
            let src_vec = vec_dst.killed.as_ref().unwrap();
            vec_dst.reg = pra.try_get_vec_reg(src_vec).unwrap();
        }

        pra.free_killed(killed);

        for vec_dst in vec_dsts {
            let dst = &mut instr.dsts_mut()[vec_dst.dst_idx];
            *dst = pra
                .assign_pin_vec_reg(*dst.as_ssa().unwrap(), vec_dst.reg)
                .into();
        }

        pra.finish(pcopy);

        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    } else if could_trivially_allocate {
        let mut pra = PinnedRegAllocator::new(ra);
        for vec_dst in vec_dsts {
            let dst = &mut instr.dsts_mut()[vec_dst.dst_idx];
            *dst = pra
                .assign_pin_vec_reg(*dst.as_ssa().unwrap(), vec_dst.reg)
                .into();
        }

        instr_remap_srcs_file(instr, &mut pra);
        pra.free_killed(killed);
        pra.finish(pcopy);
        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    } else {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);

        // Allocate vector destinations first so we have the most freedom.
        // Scalar destinations can fill in holes.
        for dst in instr.dsts_mut() {
            if let Dst::SSA(ssa) = dst {
                if ssa.file() == pra.file() && ssa.comps() > 1 {
                    *dst = pra.alloc_vector(*ssa).into();
                }
            }
        }

        pra.free_killed(killed);
        pra.finish(pcopy);

        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    }
}

impl PerRegFile<RegAllocator> {
    pub fn assign_reg(&mut self, ssa: SSAValue, reg: RegRef) {
        assert!(reg.file() == ssa.file());
        assert!(reg.comps() == 1);
        self[ssa.file()].assign_reg(ssa, reg.base_idx());
    }

    pub fn free_killed(&mut self, killed: &KillSet) {
        for ssa in killed.iter() {
            self[ssa.file()].free_ssa(*ssa);
        }
    }
}

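// Per-block register allocation state.  live_in records the register each
// value live into the block landed in and phi_out records where each phi
// source was placed; second_pass() uses both to emit the parallel copies
// which shuffle values into place along block edges.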
struct AssignRegsBlock {
    ra: PerRegFile<RegAllocator>,
    pcopy_tmp_gprs: u8,
    live_in: Vec<LiveValue>,
    phi_out: HashMap<u32, SrcRef>,
}

impl AssignRegsBlock {
    fn new(num_regs: &PerRegFile<u32>, pcopy_tmp_gprs: u8) -> AssignRegsBlock {
        AssignRegsBlock {
            ra: PerRegFile::new_with(|file| {
                RegAllocator::new(file, num_regs[file])
            }),
            pcopy_tmp_gprs: pcopy_tmp_gprs,
            live_in: Vec::new(),
            phi_out: HashMap::new(),
        }
    }

    fn get_scalar(&self, ssa: SSAValue) -> RegRef {
        let ra = &self.ra[ssa.file()];
        let reg = ra.try_get_reg(ssa).expect("Unknown SSA value");
        RegRef::new(ssa.file(), reg, 1)
    }

    fn alloc_scalar(
        &mut self,
        ip: usize,
        sum: &SSAUseMap,
        ssa: SSAValue,
    ) -> RegRef {
        let ra = &mut self.ra[ssa.file()];
        let reg = ra.alloc_scalar(ip, sum, ssa);
        RegRef::new(ssa.file(), reg, 1)
    }

    fn try_coalesce(&mut self, ssa: SSAValue, src: &Src) -> bool {
        debug_assert!(src.src_mod.is_none());
        let SrcRef::Reg(src_reg) = src.src_ref else {
            return false;
        };
        debug_assert!(src_reg.comps() == 1);

        if src_reg.file() != ssa.file() {
            return false;
        }

        let ra = &mut self.ra[src_reg.file()];
        if ra.reg_is_used(src_reg.base_idx()) {
            return false;
        }

        ra.assign_reg(ssa, src_reg.base_idx());
        true
    }

    fn pcopy_tmp(&self) -> Option<RegRef> {
        if self.pcopy_tmp_gprs > 0 {
            Some(RegRef::new(
                RegFile::GPR,
                self.ra[RegFile::GPR].num_regs,
                self.pcopy_tmp_gprs,
            ))
        } else {
            None
        }
    }

    fn assign_regs_instr(
        &mut self,
        mut instr: Box<Instr>,
        ip: usize,
        sum: &SSAUseMap,
        srcs_killed: &KillSet,
        dsts_killed: &KillSet,
        pcopy: &mut OpParCopy,
    ) -> Option<Box<Instr>> {
        match &mut instr.op {
            Op::Undef(undef) => {
                if let Dst::SSA(ssa) = undef.dst {
                    assert!(ssa.comps() == 1);
                    self.alloc_scalar(ip, sum, ssa[0]);
                }
                assert!(srcs_killed.is_empty());
                self.ra.free_killed(dsts_killed);
                None
            }
            Op::PhiSrcs(phi) => {
                for (id, src) in phi.srcs.iter() {
                    assert!(src.src_mod.is_none());
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        self.phi_out.insert(*id, reg.into());
                    } else {
                        self.phi_out.insert(*id, src.src_ref);
                    }
                }
                assert!(dsts_killed.is_empty());
                None
            }
            Op::PhiDsts(phi) => {
                assert!(instr.pred.is_true());

                for (id, dst) in phi.dsts.iter() {
                    if let Dst::SSA(ssa) = dst {
                        assert!(ssa.comps() == 1);
                        let reg = self.alloc_scalar(ip, sum, ssa[0]);
                        self.live_in.push(LiveValue {
                            live_ref: LiveRef::Phi(*id),
                            reg_ref: reg,
                        });
                    }
                }
                assert!(srcs_killed.is_empty());
                self.ra.free_killed(dsts_killed);

                None
            }
            Op::Break(op) => {
                for src in op.srcs_as_mut_slice() {
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        src.src_ref = reg.into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                if let Dst::SSA(ssa) = &op.bar_out {
                    let reg = *op.bar_in.src_ref.as_reg().unwrap();
                    self.ra.assign_reg(ssa[0], reg);
                    op.bar_out = reg.into();
                }

                self.ra.free_killed(dsts_killed);

                Some(instr)
            }
            Op::BSSy(op) => {
                for src in op.srcs_as_mut_slice() {
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        src.src_ref = reg.into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                if let Dst::SSA(ssa) = &op.bar_out {
                    let reg = *op.bar_in.src_ref.as_reg().unwrap();
                    self.ra.assign_reg(ssa[0], reg);
                    op.bar_out = reg.into();
                }

                self.ra.free_killed(dsts_killed);

                Some(instr)
            }
            Op::Copy(copy) => {
                if let SrcRef::SSA(src_vec) = &copy.src.src_ref {
                    debug_assert!(src_vec.comps() == 1);
                    let src_ssa = &src_vec[0];
                    copy.src.src_ref = self.get_scalar(*src_ssa).into();
                }

                self.ra.free_killed(srcs_killed);

                let mut del_copy = false;
                if let Dst::SSA(dst_vec) = &mut copy.dst {
                    debug_assert!(dst_vec.comps() == 1);
                    let dst_ssa = &dst_vec[0];

                    if self.try_coalesce(*dst_ssa, &copy.src) {
                        del_copy = true;
                    } else {
                        copy.dst = self.alloc_scalar(ip, sum, *dst_ssa).into();
                    }
                }

                self.ra.free_killed(dsts_killed);

                if del_copy {
                    None
                } else {
                    Some(instr)
                }
            }
            Op::ParCopy(pcopy) => {
                for (_, src) in pcopy.dsts_srcs.iter_mut() {
                    if let SrcRef::SSA(src_vec) = src.src_ref {
                        debug_assert!(src_vec.comps() == 1);
                        let src_ssa = &src_vec[0];
                        src.src_ref = self.get_scalar(*src_ssa).into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                // Try to coalesce destinations into sources, if possible
                pcopy.dsts_srcs.retain(|dst, src| match dst {
                    Dst::None => false,
                    Dst::SSA(dst_vec) => {
                        debug_assert!(dst_vec.comps() == 1);
                        !self.try_coalesce(dst_vec[0], src)
                    }
                    Dst::Reg(_) => true,
                });

                for (dst, _) in pcopy.dsts_srcs.iter_mut() {
                    if let Dst::SSA(dst_vec) = dst {
                        debug_assert!(dst_vec.comps() == 1);
                        *dst = self.alloc_scalar(ip, sum, dst_vec[0]).into();
                    }
                }

                self.ra.free_killed(dsts_killed);

                pcopy.tmp = self.pcopy_tmp();
                if pcopy.is_empty() {
                    None
                } else {
                    Some(instr)
                }
            }
            Op::FSOut(out) => {
                for src in out.srcs.iter_mut() {
                    if let SrcRef::SSA(src_vec) = src.src_ref {
                        debug_assert!(src_vec.comps() == 1);
                        let src_ssa = &src_vec[0];
                        src.src_ref = self.get_scalar(*src_ssa).into();
                    }
                }

                self.ra.free_killed(srcs_killed);
                assert!(dsts_killed.is_empty());

                // This should be the last instruction and its sources should
                // be the last free GPRs.
                debug_assert!(self.ra[RegFile::GPR].num_regs_used() == 0);

                for (i, src) in out.srcs.iter().enumerate() {
                    let reg = u32::try_from(i).unwrap();
                    let dst = RegRef::new(RegFile::GPR, reg, 1);
                    pcopy.push(dst.into(), *src);
                }

                None
            }
            _ => {
                for file in self.ra.values_mut() {
                    instr_assign_regs_file(
                        &mut instr,
                        ip,
                        sum,
                        srcs_killed,
                        pcopy,
                        file,
                    );
                }
                self.ra.free_killed(dsts_killed);
                Some(instr)
            }
        }
    }

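    // Walks the block, builds per-instruction kill sets from liveness, and
    // assigns registers instruction by instruction, inserting parallel copies
    // wherever values have to move.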
    fn first_pass<BL: BlockLiveness>(
        &mut self,
        b: &mut BasicBlock,
        bl: &BL,
        pred_ra: Option<&PerRegFile<RegAllocator>>,
    ) {
        // Populate live in from the register file we're handed. We'll add more
        // live in when we process the OpPhiDst, if any.
        if let Some(pred_ra) = pred_ra {
            for (raf, pred_raf) in self.ra.values_mut().zip(pred_ra.values()) {
                for (ssa, reg) in &pred_raf.ssa_reg {
                    if bl.is_live_in(ssa) {
                        raf.assign_reg(*ssa, *reg);
                        self.live_in.push(LiveValue {
                            live_ref: LiveRef::SSA(*ssa),
                            reg_ref: RegRef::new(raf.file(), *reg, 1),
                        });
                    }
                }
            }
        }

        let sum = SSAUseMap::for_block(b);

        let mut instrs = Vec::new();
        let mut srcs_killed = KillSet::new();
        let mut dsts_killed = KillSet::new();

        for (ip, instr) in b.instrs.drain(..).enumerate() {
            // Build up the kill set
            srcs_killed.clear();
            if let PredRef::SSA(ssa) = &instr.pred.pred_ref {
                if !bl.is_live_after_ip(ssa, ip) {
                    srcs_killed.insert(*ssa);
                }
            }
            for src in instr.srcs() {
                for ssa in src.iter_ssa() {
                    if !bl.is_live_after_ip(ssa, ip) {
                        srcs_killed.insert(*ssa);
                    }
                }
            }

            dsts_killed.clear();
            for dst in instr.dsts() {
                if let Dst::SSA(vec) = dst {
                    for ssa in vec.iter() {
                        if !bl.is_live_after_ip(ssa, ip) {
                            dsts_killed.insert(*ssa);
                        }
                    }
                }
            }

            let mut pcopy = OpParCopy::new();
            pcopy.tmp = self.pcopy_tmp();

            let instr = self.assign_regs_instr(
                instr,
                ip,
                &sum,
                &srcs_killed,
                &dsts_killed,
                &mut pcopy,
            );

            if !pcopy.is_empty() {
                if DEBUG.annotate() {
                    instrs.push(Instr::new_boxed(OpAnnotate {
                        annotation: "generated by assign_regs".into(),
                    }));
                }
                instrs.push(Instr::new_boxed(pcopy));
            }

            if let Some(instr) = instr {
                instrs.push(instr);
            }
        }

        // Sort live-in to maintain determinism
        self.live_in.sort();

        b.instrs = instrs;
    }

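    // Emits a parallel copy at the end of this block which moves values into
    // the registers the target (successor) block expects them in.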
    fn second_pass(&self, target: &AssignRegsBlock, b: &mut BasicBlock) {
        let mut pcopy = OpParCopy::new();
        pcopy.tmp = self.pcopy_tmp();

        for lv in &target.live_in {
            let src = match lv.live_ref {
                LiveRef::SSA(ssa) => SrcRef::from(self.get_scalar(ssa)),
                LiveRef::Phi(phi) => *self.phi_out.get(&phi).unwrap(),
            };
            let dst = lv.reg_ref;
            if let SrcRef::Reg(src_reg) = src {
                if dst == src_reg {
                    continue;
                }
            }
            pcopy.push(dst.into(), src.into());
        }

        if DEBUG.annotate() {
            b.instrs.push(Instr::new_boxed(OpAnnotate {
                annotation: "generated by assign_regs".into(),
            }));
        }
        if b.branch().is_some() {
            b.instrs.insert(b.instrs.len() - 1, Instr::new_boxed(pcopy));
        } else {
            b.instrs.push(Instr::new_boxed(pcopy));
        }
    }
}

impl Shader {
    pub fn assign_regs(&mut self) {
        assert!(self.functions.len() == 1);
        let f = &mut self.functions[0];

        // Convert to CSSA before we spill or assign registers
        f.to_cssa();

        let mut live = SimpleLiveness::for_function(f);
        let mut max_live = live.calc_max_live(f);

        // We want at least one temporary GPR reserved for parallel copies.
        let mut tmp_gprs = 1_u8;

        let spill_files = [RegFile::Pred, RegFile::Bar];
        for file in spill_files {
            let num_regs = file.num_regs(self.info.sm);
            if max_live[file] > num_regs {
                f.spill_values(file, num_regs);

                // Re-calculate liveness after we spill
                live = SimpleLiveness::for_function(f);
                max_live = live.calc_max_live(f);

                if file == RegFile::Bar {
                    tmp_gprs = max(tmp_gprs, 2);
                }
            }
        }

        // An instruction can have at most 4 vector sources/destinations. In
        // order to ensure we always succeed at allocation, regardless of
        // arbitrary choices, we need at least 16 GPRs.
        let mut gpr_limit = max(max_live[RegFile::GPR], 16);
        let mut total_gprs = gpr_limit + u32::from(tmp_gprs);

        let max_gprs = RegFile::GPR.num_regs(self.info.sm);
        if total_gprs > max_gprs {
            // If we're spilling GPRs, we need to reserve 2 GPRs for OpParCopy
            // lowering because it needs to be able to lower Mem copies which
            // require a temporary
            tmp_gprs = max(tmp_gprs, 2);
            total_gprs = max_gprs;
            gpr_limit = total_gprs - u32::from(tmp_gprs);

            f.spill_values(RegFile::GPR, gpr_limit);

            // Re-calculate liveness one last time
            live = SimpleLiveness::for_function(f);
        }

        self.info.num_gprs = total_gprs.try_into().unwrap();

        // We do a maximum here because nak_from_nir may set num_barriers to 1
        // in the case where there is an OpBar.
        self.info.num_barriers = max(
            self.info.num_barriers,
            max_live[RegFile::Bar].try_into().unwrap(),
        );

        let limit = PerRegFile::new_with(|file| {
            if file == RegFile::GPR {
                gpr_limit
            } else {
                file.num_regs(self.info.sm)
            }
        });

        let mut blocks: Vec<AssignRegsBlock> = Vec::new();
        for b_idx in 0..f.blocks.len() {
            let pred = f.blocks.pred_indices(b_idx);
            let pred_ra = if pred.is_empty() {
                None
            } else {
                // Start with the previous block's register allocation.
                Some(&blocks[pred[0]].ra)
            };

            let bl = live.block_live(b_idx);

            let mut arb = AssignRegsBlock::new(&limit, tmp_gprs);
            arb.first_pass(&mut f.blocks[b_idx], bl, pred_ra);

            assert!(blocks.len() == b_idx);
            blocks.push(arb);
        }

        for b_idx in 0..f.blocks.len() {
            let arb = &blocks[b_idx];
            for sb_idx in f.blocks.succ_indices(b_idx).to_vec() {
                arb.second_pass(&blocks[sb_idx], &mut f.blocks[b_idx]);
            }
        }
    }
}