// Copyright © 2022 Collabora, Ltd.
// SPDX-License-Identifier: MIT

use crate::api::{GetDebugFlags, DEBUG};
use crate::bitset::BitSet;
use crate::ir::*;
use crate::liveness::{BlockLiveness, Liveness, SimpleLiveness};

use std::cmp::{max, Ordering};
use std::collections::{HashMap, HashSet};

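// A set of SSA values tracked in insertion order.  The HashSet gives O(1)
// duplicate checks while the Vec keeps iteration order stable so the
// allocator behaves deterministically.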
struct KillSet {
    set: HashSet<SSAValue>,
    vec: Vec<SSAValue>,
}

impl KillSet {
    pub fn new() -> KillSet {
        KillSet {
            set: HashSet::new(),
            vec: Vec::new(),
        }
    }

    pub fn clear(&mut self) {
        self.set.clear();
        self.vec.clear();
    }

    pub fn insert(&mut self, ssa: SSAValue) {
        if self.set.insert(ssa) {
            self.vec.push(ssa);
        }
    }

    pub fn iter(&self) -> std::slice::Iter<'_, SSAValue> {
        self.vec.iter()
    }

    pub fn is_empty(&self) -> bool {
        self.vec.is_empty()
    }
}

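// Records upcoming uses of scalar SSA values that constrain where we would
// like them to live: either a fixed register (e.g. an FS output) or a
// particular component of a vector source.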
enum SSAUse {
    FixedReg(u32),
    Vec(SSARef),
}

struct SSAUseMap {
    ssa_map: HashMap<SSAValue, Vec<(usize, SSAUse)>>,
}

impl SSAUseMap {
    fn add_fixed_reg_use(&mut self, ip: usize, ssa: SSAValue, reg: u32) {
        let v = self.ssa_map.entry(ssa).or_default();
        v.push((ip, SSAUse::FixedReg(reg)));
    }

    fn add_vec_use(&mut self, ip: usize, vec: SSARef) {
        if vec.comps() == 1 {
            return;
        }

        for ssa in vec.iter() {
            let v = self.ssa_map.entry(*ssa).or_default();
            v.push((ip, SSAUse::Vec(vec)));
        }
    }

    fn find_vec_use_after(&self, ssa: SSAValue, ip: usize) -> Option<&SSAUse> {
        if let Some(v) = self.ssa_map.get(&ssa) {
            let p = v.partition_point(|(uip, _)| *uip <= ip);
            if p == v.len() {
                None
            } else {
                let (_, u) = &v[p];
                Some(u)
            }
        } else {
            None
        }
    }

    pub fn add_block(&mut self, b: &BasicBlock) {
        for (ip, instr) in b.instrs.iter().enumerate() {
            match &instr.op {
                Op::FSOut(op) => {
                    for (i, src) in op.srcs.iter().enumerate() {
                        let out_reg = u32::try_from(i).unwrap();
                        if let SrcRef::SSA(ssa) = src.src_ref {
                            assert!(ssa.comps() == 1);
                            self.add_fixed_reg_use(ip, ssa[0], out_reg);
                        }
                    }
                }
                _ => {
                    // We don't care about predicates because they're scalar
                    for src in instr.srcs() {
                        if let SrcRef::SSA(ssa) = src.src_ref {
                            self.add_vec_use(ip, ssa);
                        }
                    }
                }
            }
        }
    }

    pub fn for_block(b: &BasicBlock) -> SSAUseMap {
        let mut am = SSAUseMap {
            ssa_map: HashMap::new(),
        };
        am.add_block(b);
        am
    }
}

#[derive(Clone, Copy, Eq, Hash, PartialEq)]
enum LiveRef {
    SSA(SSAValue),
    Phi(u32),
}

#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct LiveValue {
    pub live_ref: LiveRef,
    pub reg_ref: RegRef,
}

// We need a stable ordering of live values so that RA is deterministic
impl Ord for LiveValue {
    fn cmp(&self, other: &Self) -> Ordering {
        let s_file = u8::from(self.reg_ref.file());
        let o_file = u8::from(other.reg_ref.file());
        match s_file.cmp(&o_file) {
            Ordering::Equal => {
                let s_idx = self.reg_ref.base_idx();
                let o_idx = other.reg_ref.base_idx();
                s_idx.cmp(&o_idx)
            }
            ord => ord,
        }
    }
}

impl PartialOrd for LiveValue {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

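// Per-register-file allocator state.  ssa_reg maps each live SSA value to
// the register currently holding it; used and reg_ssa track the reverse
// mapping so both directions can be queried cheaply.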
#[derive(Clone)]
struct RegAllocator {
    file: RegFile,
    num_regs: u32,
    used: BitSet,
    reg_ssa: Vec<SSAValue>,
    ssa_reg: HashMap<SSAValue, u32>,
}

impl RegAllocator {
    pub fn new(file: RegFile, num_regs: u32) -> Self {
        Self {
            file: file,
            num_regs: num_regs,
            used: BitSet::new(),
            reg_ssa: Vec::new(),
            ssa_reg: HashMap::new(),
        }
    }

    fn file(&self) -> RegFile {
        self.file
    }

    pub fn num_regs_used(&self) -> u32 {
        self.ssa_reg.len().try_into().unwrap()
    }

    pub fn reg_is_used(&self, reg: u32) -> bool {
        self.used.get(reg.try_into().unwrap())
    }

    fn reg_range_is_unused(&self, reg: u32, comps: u8) -> bool {
        for c in 0..u32::from(comps) {
            if self.reg_is_used(reg + c) {
                return false;
            }
        }
        true
    }

    pub fn try_get_reg(&self, ssa: SSAValue) -> Option<u32> {
        self.ssa_reg.get(&ssa).cloned()
    }

    pub fn try_get_ssa(&self, reg: u32) -> Option<SSAValue> {
        if self.reg_is_used(reg) {
            Some(self.reg_ssa[usize::try_from(reg).unwrap()])
        } else {
            None
        }
    }

    pub fn try_get_vec_reg(&self, vec: &SSARef) -> Option<u32> {
        let Some(reg) = self.try_get_reg(vec[0]) else {
            return None;
        };

        let align = u32::from(vec.comps()).next_power_of_two();
        if reg % align != 0 {
            return None;
        }

        for c in 1..vec.comps() {
            let ssa = vec[usize::from(c)];
            if self.try_get_reg(ssa) != Some(reg + u32::from(c)) {
                return None;
            }
        }
        Some(reg)
    }

    pub fn free_ssa(&mut self, ssa: SSAValue) -> u32 {
        assert!(ssa.file() == self.file);
        let reg = self.ssa_reg.remove(&ssa).unwrap();
        assert!(self.reg_is_used(reg));
        let reg_usize = usize::try_from(reg).unwrap();
        assert!(self.reg_ssa[reg_usize] == ssa);
        self.used.remove(reg_usize);
        reg
    }

    pub fn assign_reg(&mut self, ssa: SSAValue, reg: u32) {
        assert!(ssa.file() == self.file);
        assert!(reg < self.num_regs);
        assert!(!self.reg_is_used(reg));

        let reg_usize = usize::try_from(reg).unwrap();
        if reg_usize >= self.reg_ssa.len() {
            self.reg_ssa.resize(reg_usize + 1, SSAValue::NONE);
        }
        self.reg_ssa[reg_usize] = ssa;
        let old = self.ssa_reg.insert(ssa, reg);
        assert!(old.is_none());
        self.used.insert(reg_usize);
    }

    pub fn try_find_unused_reg_range(
        &self,
        start_reg: u32,
        align: u32,
        comps: u8,
    ) -> Option<u32> {
        assert!(comps > 0 && u32::from(comps) <= self.num_regs);

        let mut next_reg = start_reg;
        loop {
            let reg: u32 = self
                .used
                .next_unset(usize::try_from(next_reg).unwrap())
                .try_into()
                .unwrap();

            // Ensure we're properly aligned
            let reg = reg.next_multiple_of(align);

            // Ensure we're in-bounds. This also serves as a check to ensure
            // that u8::try_from(reg + i) will succeed.
            if reg > self.num_regs - u32::from(comps) {
                return None;
            }

            if self.reg_range_is_unused(reg, comps) {
                return Some(reg);
            }

            next_reg = reg + align;
        }
    }

    pub fn alloc_scalar(
        &mut self,
        ip: usize,
        sum: &SSAUseMap,
        ssa: SSAValue,
    ) -> u32 {
        if let Some(u) = sum.find_vec_use_after(ssa, ip) {
            match u {
                SSAUse::FixedReg(reg) => {
                    if !self.reg_is_used(*reg) {
                        self.assign_reg(ssa, *reg);
                        return *reg;
                    }
                }
                SSAUse::Vec(vec) => {
                    let mut comp = u8::MAX;
                    for c in 0..vec.comps() {
                        if vec[usize::from(c)] == ssa {
                            comp = c;
                            break;
                        }
                    }
                    assert!(comp < vec.comps());

                    let align = u32::from(vec.comps()).next_power_of_two();
                    for c in 0..vec.comps() {
                        if c == comp {
                            continue;
                        }

                        let other = vec[usize::from(c)];
                        let Some(other_reg) = self.try_get_reg(other) else {
                            continue;
                        };

                        let vec_reg = other_reg & !(align - 1);
                        if other_reg != vec_reg + u32::from(c) {
                            continue;
                        }

                        let reg = vec_reg + u32::from(comp);
                        if reg < self.num_regs && !self.reg_is_used(reg) {
                            self.assign_reg(ssa, reg);
                            return reg;
                        }
                    }

                    // We weren't able to pair it with an already allocated
                    // register but maybe we can at least find an aligned one.
                    if let Some(reg) =
                        self.try_find_unused_reg_range(0, align, 1)
                    {
                        self.assign_reg(ssa, reg);
                        return reg;
                    }
                }
            }
        }

        let reg = self
            .try_find_unused_reg_range(0, 1, 1)
            .expect("Failed to find free register");
        self.assign_reg(ssa, reg);
        reg
    }
}

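// Wraps a RegAllocator for the duration of a single instruction.  Registers
// touched while processing that instruction are pinned so later operands
// can't reuse them, and any values displaced along the way are recorded in
// `evicted` and moved to free registers by finish() via a parallel copy.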
struct PinnedRegAllocator<'a> {
    ra: &'a mut RegAllocator,
    pcopy: OpParCopy,
    pinned: BitSet,
    evicted: HashMap<SSAValue, u32>,
}

impl<'a> PinnedRegAllocator<'a> {
    fn new(ra: &'a mut RegAllocator) -> Self {
        PinnedRegAllocator {
            ra: ra,
            pcopy: OpParCopy::new(),
            pinned: Default::default(),
            evicted: HashMap::new(),
        }
    }

    fn file(&self) -> RegFile {
        self.ra.file()
    }

    fn pin_reg(&mut self, reg: u32) {
        self.pinned.insert(reg.try_into().unwrap());
    }

    fn pin_reg_range(&mut self, reg: u32, comps: u8) {
        for c in 0..u32::from(comps) {
            self.pin_reg(reg + c);
        }
    }

    fn reg_is_pinned(&self, reg: u32) -> bool {
        self.pinned.get(reg.try_into().unwrap())
    }

    fn reg_range_is_unpinned(&self, reg: u32, comps: u8) -> bool {
        for c in 0..u32::from(comps) {
            if self.reg_is_pinned(reg + c) {
                return false;
            }
        }
        true
    }

    fn assign_pin_reg(&mut self, ssa: SSAValue, reg: u32) -> RegRef {
        self.pin_reg(reg);
        self.ra.assign_reg(ssa, reg);
        RegRef::new(self.file(), reg, 1)
    }

    pub fn assign_pin_vec_reg(&mut self, vec: SSARef, reg: u32) -> RegRef {
        for c in 0..vec.comps() {
            let ssa = vec[usize::from(c)];
            self.assign_pin_reg(ssa, reg + u32::from(c));
        }
        RegRef::new(self.file(), reg, vec.comps())
    }

    fn try_find_unpinned_reg_range(
        &self,
        start_reg: u32,
        align: u32,
        comps: u8,
    ) -> Option<u32> {
        let mut next_reg = start_reg;
        loop {
            let reg: u32 = self
                .pinned
                .next_unset(usize::try_from(next_reg).unwrap())
                .try_into()
                .unwrap();

            // Ensure we're properly aligned
            let reg = reg.next_multiple_of(align);

            // Ensure we're in-bounds. This also serves as a check to ensure
            // that u8::try_from(reg + i) will succeed.
            if reg > self.ra.num_regs - u32::from(comps) {
                return None;
            }

            if self.reg_range_is_unpinned(reg, comps) {
                return Some(reg);
            }

            next_reg = reg + align;
        }
    }

    pub fn evict_ssa(&mut self, ssa: SSAValue, old_reg: u32) {
        assert!(ssa.file() == self.file());
        assert!(!self.reg_is_pinned(old_reg));
        self.evicted.insert(ssa, old_reg);
    }

    pub fn evict_reg_if_used(&mut self, reg: u32) {
        assert!(!self.reg_is_pinned(reg));

        if let Some(ssa) = self.ra.try_get_ssa(reg) {
            self.ra.free_ssa(ssa);
            self.evict_ssa(ssa, reg);
        }
    }

    fn move_ssa_to_reg(&mut self, ssa: SSAValue, new_reg: u32) {
        if let Some(old_reg) = self.ra.try_get_reg(ssa) {
            assert!(self.evicted.get(&ssa).is_none());
            assert!(!self.reg_is_pinned(old_reg));

            if new_reg == old_reg {
                self.pin_reg(new_reg);
            } else {
                self.ra.free_ssa(ssa);
                self.evict_reg_if_used(new_reg);

                self.pcopy.push(
                    RegRef::new(self.file(), new_reg, 1).into(),
                    RegRef::new(self.file(), old_reg, 1).into(),
                );

                self.assign_pin_reg(ssa, new_reg);
            }
        } else if let Some(old_reg) = self.evicted.remove(&ssa) {
            self.evict_reg_if_used(new_reg);

            self.pcopy.push(
                RegRef::new(self.file(), new_reg, 1).into(),
                RegRef::new(self.file(), old_reg, 1).into(),
            );

            self.assign_pin_reg(ssa, new_reg);
        } else {
            panic!("Unknown SSA value");
        }
    }

    fn finish(mut self, pcopy: &mut OpParCopy) {
        pcopy.dsts_srcs.append(&mut self.pcopy.dsts_srcs);

        if !self.evicted.is_empty() {
            // Sort so we get determinism, even if the hash map order changes
            // from one run to another or due to rust compiler updates.
            let mut evicted: Vec<_> = self.evicted.drain().collect();
            evicted.sort_by_key(|(_, reg)| *reg);

            for (ssa, old_reg) in evicted {
                let mut next_reg = 0;
                let new_reg = loop {
                    let reg = self
                        .ra
                        .try_find_unused_reg_range(next_reg, 1, 1)
                        .expect("Failed to find free register");
                    if !self.reg_is_pinned(reg) {
                        break reg;
                    }
                    next_reg = reg + 1;
                };

                pcopy.push(
                    RegRef::new(self.file(), new_reg, 1).into(),
                    RegRef::new(self.file(), old_reg, 1).into(),
                );
                self.assign_pin_reg(ssa, new_reg);
            }
        }
    }

    pub fn try_get_vec_reg(&self, vec: &SSARef) -> Option<u32> {
        self.ra.try_get_vec_reg(vec)
    }

    pub fn collect_vector(&mut self, vec: &SSARef) -> RegRef {
        if let Some(reg) = self.try_get_vec_reg(vec) {
            self.pin_reg_range(reg, vec.comps());
            return RegRef::new(self.file(), reg, vec.comps());
        }

        let comps = vec.comps();
        let align = u32::from(comps).next_power_of_two();

        let reg = self
            .ra
            .try_find_unused_reg_range(0, align, comps)
            .or_else(|| {
                for c in 0..comps {
                    let ssa = vec[usize::from(c)];
                    let Some(comp_reg) = self.ra.try_get_reg(ssa) else {
                        continue;
                    };

                    let vec_reg = comp_reg & !(align - 1);
                    if comp_reg != vec_reg + u32::from(c) {
                        continue;
                    }

                    if vec_reg + u32::from(comps) > self.ra.num_regs {
                        continue;
                    }

                    if self.reg_range_is_unpinned(vec_reg, comps) {
                        return Some(vec_reg);
                    }
                }
                None
            })
            .or_else(|| self.try_find_unpinned_reg_range(0, align, comps))
            .expect("Failed to find an unpinned register range");

        for c in 0..comps {
            let ssa = vec[usize::from(c)];
            self.move_ssa_to_reg(ssa, reg + u32::from(c));
        }

        RegRef::new(self.file(), reg, comps)
    }

    pub fn alloc_vector(&mut self, vec: SSARef) -> RegRef {
        let comps = vec.comps();
        let align = u32::from(comps).next_power_of_two();

        if let Some(reg) = self.ra.try_find_unused_reg_range(0, align, comps) {
            return self.assign_pin_vec_reg(vec, reg);
        }

        let reg = self
            .try_find_unpinned_reg_range(0, align, comps)
            .expect("Failed to find an unpinned register range");

        for c in 0..comps {
            self.evict_reg_if_used(reg + u32::from(c));
        }
        self.assign_pin_vec_reg(vec, reg)
    }

    pub fn free_killed(&mut self, killed: &KillSet) {
        for ssa in killed.iter() {
            if ssa.file() == self.file() {
                self.ra.free_ssa(*ssa);
            }
        }
    }
}

impl Drop for PinnedRegAllocator<'_> {
    fn drop(&mut self) {
        assert!(self.evicted.is_empty());
    }
}

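// Rewrites every SSA source (and the predicate) of instr that lives in ra's
// register file into a register reference, gathering vector sources into
// contiguous, aligned register ranges first.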
fn instr_remap_srcs_file(instr: &mut Instr, ra: &mut PinnedRegAllocator) {
    // Collect vector sources first since those may silently pin some of our
    // scalar sources.
    for src in instr.srcs_mut() {
        if let SrcRef::SSA(ssa) = &src.src_ref {
            if ssa.file() == ra.file() && ssa.comps() > 1 {
                src.src_ref = ra.collect_vector(ssa).into();
            }
        }
    }

    if let PredRef::SSA(pred) = instr.pred.pred_ref {
        if pred.file() == ra.file() {
            instr.pred.pred_ref = ra.collect_vector(&pred.into()).into();
        }
    }

    for src in instr.srcs_mut() {
        if let SrcRef::SSA(ssa) = &src.src_ref {
            if ssa.file() == ra.file() && ssa.comps() == 1 {
                src.src_ref = ra.collect_vector(ssa).into();
            }
        }
    }
}

fn instr_alloc_scalar_dsts_file(
    instr: &mut Instr,
    ip: usize,
    sum: &SSAUseMap,
    ra: &mut RegAllocator,
) {
    for dst in instr.dsts_mut() {
        if let Dst::SSA(ssa) = dst {
            assert!(ssa.comps() == 1);
            if ssa.file() == ra.file() {
                let reg = ra.alloc_scalar(ip, sum, ssa[0]);
                *dst = RegRef::new(ra.file(), reg, 1).into();
            }
        }
    }
}

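// Assigns registers in a single register file for one instruction.  The
// tricky part is vector destinations: we first try to reuse the registers of
// killed vector sources, then try allocating fresh unused ranges, and only
// fall back to evicting live values when neither works.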
fn instr_assign_regs_file(
    instr: &mut Instr,
    ip: usize,
    sum: &SSAUseMap,
    killed: &KillSet,
    pcopy: &mut OpParCopy,
    ra: &mut RegAllocator,
) {
    struct VecDst {
        dst_idx: usize,
        comps: u8,
        killed: Option<SSARef>,
        reg: u32,
    }

    let mut vec_dsts = Vec::new();
    let mut vec_dst_comps = 0;
    for (i, dst) in instr.dsts().iter().enumerate() {
        if let Dst::SSA(ssa) = dst {
            if ssa.file() == ra.file() && ssa.comps() > 1 {
                vec_dsts.push(VecDst {
                    dst_idx: i,
                    comps: ssa.comps(),
                    killed: None,
                    reg: u32::MAX,
                });
                vec_dst_comps += ssa.comps();
            }
        }
    }

    // No vector destinations is the easy case
    if vec_dst_comps == 0 {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);
        pra.free_killed(killed);
        pra.finish(pcopy);
        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
        return;
    }

    // Predicates can't be vectors.  This lets us ignore instr.pred in our
    // analysis for the cases below. Only the easy case above needs to care
    // about them.
    assert!(!ra.file().is_predicate());

    let mut avail = killed.set.clone();
    let mut killed_vecs = Vec::new();
    for src in instr.srcs() {
        if let SrcRef::SSA(vec) = src.src_ref {
            if vec.comps() > 1 {
                let mut vec_killed = true;
                for ssa in vec.iter() {
                    if ssa.file() != ra.file() || !avail.contains(ssa) {
                        vec_killed = false;
                        break;
                    }
                }
                if vec_killed {
                    for ssa in vec.iter() {
                        avail.remove(ssa);
                    }
                    killed_vecs.push(vec);
                }
            }
        }
    }

    vec_dsts.sort_by_key(|v| v.comps);
    killed_vecs.sort_by_key(|v| v.comps());

    let mut next_dst_reg = 0;
    let mut vec_dsts_map_to_killed_srcs = true;
    let mut could_trivially_allocate = true;
    for vec_dst in vec_dsts.iter_mut().rev() {
        while let Some(src) = killed_vecs.pop() {
            if src.comps() >= vec_dst.comps {
                vec_dst.killed = Some(src);
                break;
            }
        }
        if vec_dst.killed.is_none() {
            vec_dsts_map_to_killed_srcs = false;
        }

        let align = u32::from(vec_dst.comps).next_power_of_two();
        if let Some(reg) =
            ra.try_find_unused_reg_range(next_dst_reg, align, vec_dst.comps)
        {
            vec_dst.reg = reg;
            next_dst_reg = reg + u32::from(vec_dst.comps);
        } else {
            could_trivially_allocate = false;
        }
    }

    if vec_dsts_map_to_killed_srcs {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);

        for vec_dst in &mut vec_dsts {
            let src_vec = vec_dst.killed.as_ref().unwrap();
            vec_dst.reg = pra.try_get_vec_reg(src_vec).unwrap();
        }

        pra.free_killed(killed);

        for vec_dst in vec_dsts {
            let dst = &mut instr.dsts_mut()[vec_dst.dst_idx];
            *dst = pra
                .assign_pin_vec_reg(*dst.as_ssa().unwrap(), vec_dst.reg)
                .into();
        }

        pra.finish(pcopy);

        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    } else if could_trivially_allocate {
        let mut pra = PinnedRegAllocator::new(ra);
        for vec_dst in vec_dsts {
            let dst = &mut instr.dsts_mut()[vec_dst.dst_idx];
            *dst = pra
                .assign_pin_vec_reg(*dst.as_ssa().unwrap(), vec_dst.reg)
                .into();
        }

        instr_remap_srcs_file(instr, &mut pra);
        pra.free_killed(killed);
        pra.finish(pcopy);
        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    } else {
        let mut pra = PinnedRegAllocator::new(ra);
        instr_remap_srcs_file(instr, &mut pra);

        // Allocate vector destinations first so we have the most freedom.
        // Scalar destinations can fill in holes.
        for dst in instr.dsts_mut() {
            if let Dst::SSA(ssa) = dst {
                if ssa.file() == pra.file() && ssa.comps() > 1 {
                    *dst = pra.alloc_vector(*ssa).into();
                }
            }
        }

        pra.free_killed(killed);
        pra.finish(pcopy);

        instr_alloc_scalar_dsts_file(instr, ip, sum, ra);
    }
}

impl PerRegFile<RegAllocator> {
    pub fn assign_reg(&mut self, ssa: SSAValue, reg: RegRef) {
        assert!(reg.file() == ssa.file());
        assert!(reg.comps() == 1);
        self[ssa.file()].assign_reg(ssa, reg.base_idx());
    }

    pub fn free_killed(&mut self, killed: &KillSet) {
        for ssa in killed.iter() {
            self[ssa.file()].free_ssa(*ssa);
        }
    }
}

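// Register allocation state for one basic block.  live_in records the
// register chosen for each value (or phi) that is live into the block and
// phi_out records where each outgoing phi source ended up, so second_pass()
// can stitch predecessor and successor blocks together with parallel copies.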
struct AssignRegsBlock {
    ra: PerRegFile<RegAllocator>,
    pcopy_tmp_gprs: u8,
    live_in: Vec<LiveValue>,
    phi_out: HashMap<u32, SrcRef>,
}

impl AssignRegsBlock {
    fn new(num_regs: &PerRegFile<u32>, pcopy_tmp_gprs: u8) -> AssignRegsBlock {
        AssignRegsBlock {
            ra: PerRegFile::new_with(|file| {
                RegAllocator::new(file, num_regs[file])
            }),
            pcopy_tmp_gprs: pcopy_tmp_gprs,
            live_in: Vec::new(),
            phi_out: HashMap::new(),
        }
    }

    fn get_scalar(&self, ssa: SSAValue) -> RegRef {
        let ra = &self.ra[ssa.file()];
        let reg = ra.try_get_reg(ssa).expect("Unknown SSA value");
        RegRef::new(ssa.file(), reg, 1)
    }

    fn alloc_scalar(
        &mut self,
        ip: usize,
        sum: &SSAUseMap,
        ssa: SSAValue,
    ) -> RegRef {
        let ra = &mut self.ra[ssa.file()];
        let reg = ra.alloc_scalar(ip, sum, ssa);
        RegRef::new(ssa.file(), reg, 1)
    }

    fn try_coalesce(&mut self, ssa: SSAValue, src: &Src) -> bool {
        debug_assert!(src.src_mod.is_none());
        let SrcRef::Reg(src_reg) = src.src_ref else {
            return false;
        };
        debug_assert!(src_reg.comps() == 1);

        if src_reg.file() != ssa.file() {
            return false;
        }

        let ra = &mut self.ra[src_reg.file()];
        if ra.reg_is_used(src_reg.base_idx()) {
            return false;
        }

        ra.assign_reg(ssa, src_reg.base_idx());
        true
    }

    fn pcopy_tmp(&self) -> Option<RegRef> {
        if self.pcopy_tmp_gprs > 0 {
            Some(RegRef::new(
                RegFile::GPR,
                self.ra[RegFile::GPR].num_regs,
                self.pcopy_tmp_gprs,
            ))
        } else {
            None
        }
    }

    fn assign_regs_instr(
        &mut self,
        mut instr: Box<Instr>,
        ip: usize,
        sum: &SSAUseMap,
        srcs_killed: &KillSet,
        dsts_killed: &KillSet,
        pcopy: &mut OpParCopy,
    ) -> Option<Box<Instr>> {
        match &mut instr.op {
            Op::Undef(undef) => {
                if let Dst::SSA(ssa) = undef.dst {
                    assert!(ssa.comps() == 1);
                    self.alloc_scalar(ip, sum, ssa[0]);
                }
                assert!(srcs_killed.is_empty());
                self.ra.free_killed(dsts_killed);
                None
            }
            Op::PhiSrcs(phi) => {
                for (id, src) in phi.srcs.iter() {
                    assert!(src.src_mod.is_none());
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        self.phi_out.insert(*id, reg.into());
                    } else {
                        self.phi_out.insert(*id, src.src_ref);
                    }
                }
                assert!(dsts_killed.is_empty());
                None
            }
            Op::PhiDsts(phi) => {
                assert!(instr.pred.is_true());

                for (id, dst) in phi.dsts.iter() {
                    if let Dst::SSA(ssa) = dst {
                        assert!(ssa.comps() == 1);
                        let reg = self.alloc_scalar(ip, sum, ssa[0]);
                        self.live_in.push(LiveValue {
                            live_ref: LiveRef::Phi(*id),
                            reg_ref: reg,
                        });
                    }
                }
                assert!(srcs_killed.is_empty());
                self.ra.free_killed(dsts_killed);

                None
            }
            Op::Break(op) => {
                for src in op.srcs_as_mut_slice() {
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        src.src_ref = reg.into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                if let Dst::SSA(ssa) = &op.bar_out {
                    let reg = *op.bar_in.src_ref.as_reg().unwrap();
                    self.ra.assign_reg(ssa[0], reg);
                    op.bar_out = reg.into();
                }

                self.ra.free_killed(dsts_killed);

                Some(instr)
            }
            Op::BSSy(op) => {
                for src in op.srcs_as_mut_slice() {
                    if let SrcRef::SSA(ssa) = src.src_ref {
                        assert!(ssa.comps() == 1);
                        let reg = self.get_scalar(ssa[0]);
                        src.src_ref = reg.into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                if let Dst::SSA(ssa) = &op.bar_out {
                    let reg = *op.bar_in.src_ref.as_reg().unwrap();
                    self.ra.assign_reg(ssa[0], reg);
                    op.bar_out = reg.into();
                }

                self.ra.free_killed(dsts_killed);

                Some(instr)
            }
            Op::Copy(copy) => {
                if let SrcRef::SSA(src_vec) = &copy.src.src_ref {
                    debug_assert!(src_vec.comps() == 1);
                    let src_ssa = &src_vec[0];
                    copy.src.src_ref = self.get_scalar(*src_ssa).into();
                }

                self.ra.free_killed(srcs_killed);

                let mut del_copy = false;
                if let Dst::SSA(dst_vec) = &mut copy.dst {
                    debug_assert!(dst_vec.comps() == 1);
                    let dst_ssa = &dst_vec[0];

                    if self.try_coalesce(*dst_ssa, &copy.src) {
                        del_copy = true;
                    } else {
                        copy.dst = self.alloc_scalar(ip, sum, *dst_ssa).into();
                    }
                }

                self.ra.free_killed(dsts_killed);

                if del_copy {
                    None
                } else {
                    Some(instr)
                }
            }
            Op::ParCopy(pcopy) => {
                for (_, src) in pcopy.dsts_srcs.iter_mut() {
                    if let SrcRef::SSA(src_vec) = src.src_ref {
                        debug_assert!(src_vec.comps() == 1);
                        let src_ssa = &src_vec[0];
                        src.src_ref = self.get_scalar(*src_ssa).into();
                    }
                }

                self.ra.free_killed(srcs_killed);

                // Try to coalesce destinations into sources, if possible
                pcopy.dsts_srcs.retain(|dst, src| match dst {
                    Dst::None => false,
                    Dst::SSA(dst_vec) => {
                        debug_assert!(dst_vec.comps() == 1);
                        !self.try_coalesce(dst_vec[0], src)
                    }
                    Dst::Reg(_) => true,
                });

                for (dst, _) in pcopy.dsts_srcs.iter_mut() {
                    if let Dst::SSA(dst_vec) = dst {
                        debug_assert!(dst_vec.comps() == 1);
                        *dst = self.alloc_scalar(ip, sum, dst_vec[0]).into();
                    }
                }

                self.ra.free_killed(dsts_killed);

                pcopy.tmp = self.pcopy_tmp();
                if pcopy.is_empty() {
                    None
                } else {
                    Some(instr)
                }
            }
            Op::FSOut(out) => {
                for src in out.srcs.iter_mut() {
                    if let SrcRef::SSA(src_vec) = src.src_ref {
                        debug_assert!(src_vec.comps() == 1);
                        let src_ssa = &src_vec[0];
                        src.src_ref = self.get_scalar(*src_ssa).into();
                    }
                }

                self.ra.free_killed(srcs_killed);
                assert!(dsts_killed.is_empty());

                // This should be the last instruction and its sources should
                // be the last free GPRs.
                debug_assert!(self.ra[RegFile::GPR].num_regs_used() == 0);

                for (i, src) in out.srcs.iter().enumerate() {
                    let reg = u32::try_from(i).unwrap();
                    let dst = RegRef::new(RegFile::GPR, reg, 1);
                    pcopy.push(dst.into(), *src);
                }

                None
            }
            _ => {
                for file in self.ra.values_mut() {
                    instr_assign_regs_file(
                        &mut instr,
                        ip,
                        sum,
                        srcs_killed,
                        pcopy,
                        file,
                    );
                }
                self.ra.free_killed(dsts_killed);
                Some(instr)
            }
        }
    }

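    // Walks the block's instructions in order, seeding the allocators from
    // the predecessor block's state, computing which values die at each
    // instruction, and assigning registers instruction by instruction.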
    fn first_pass<BL: BlockLiveness>(
        &mut self,
        b: &mut BasicBlock,
        bl: &BL,
        pred_ra: Option<&PerRegFile<RegAllocator>>,
    ) {
        // Populate live in from the register file we're handed.  We'll add more
        // live in when we process the OpPhiDst, if any.
        if let Some(pred_ra) = pred_ra {
            for (raf, pred_raf) in self.ra.values_mut().zip(pred_ra.values()) {
                for (ssa, reg) in &pred_raf.ssa_reg {
                    if bl.is_live_in(ssa) {
                        raf.assign_reg(*ssa, *reg);
                        self.live_in.push(LiveValue {
                            live_ref: LiveRef::SSA(*ssa),
                            reg_ref: RegRef::new(raf.file(), *reg, 1),
                        });
                    }
                }
            }
        }

        let sum = SSAUseMap::for_block(b);

        let mut instrs = Vec::new();
        let mut srcs_killed = KillSet::new();
        let mut dsts_killed = KillSet::new();

        for (ip, instr) in b.instrs.drain(..).enumerate() {
            // Build up the kill set
            srcs_killed.clear();
            if let PredRef::SSA(ssa) = &instr.pred.pred_ref {
                if !bl.is_live_after_ip(ssa, ip) {
                    srcs_killed.insert(*ssa);
                }
            }
            for src in instr.srcs() {
                for ssa in src.iter_ssa() {
                    if !bl.is_live_after_ip(ssa, ip) {
                        srcs_killed.insert(*ssa);
                    }
                }
            }

            dsts_killed.clear();
            for dst in instr.dsts() {
                if let Dst::SSA(vec) = dst {
                    for ssa in vec.iter() {
                        if !bl.is_live_after_ip(ssa, ip) {
                            dsts_killed.insert(*ssa);
                        }
                    }
                }
            }

            let mut pcopy = OpParCopy::new();
            pcopy.tmp = self.pcopy_tmp();

            let instr = self.assign_regs_instr(
                instr,
                ip,
                &sum,
                &srcs_killed,
                &dsts_killed,
                &mut pcopy,
            );

            if !pcopy.is_empty() {
                if DEBUG.annotate() {
                    instrs.push(Instr::new_boxed(OpAnnotate {
                        annotation: "generated by assign_regs".into(),
                    }));
                }
                instrs.push(Instr::new_boxed(pcopy));
            }

            if let Some(instr) = instr {
                instrs.push(instr);
            }
        }

        // Sort live-in to maintain determinism
        self.live_in.sort();

        b.instrs = instrs;
    }

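    // Emits the parallel copy that moves this block's outgoing values into
    // the registers the successor block expects its live-in values in.  The
    // copy is inserted before the final branch, if the block has one.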
    fn second_pass(&self, target: &AssignRegsBlock, b: &mut BasicBlock) {
        let mut pcopy = OpParCopy::new();
        pcopy.tmp = self.pcopy_tmp();

        for lv in &target.live_in {
            let src = match lv.live_ref {
                LiveRef::SSA(ssa) => SrcRef::from(self.get_scalar(ssa)),
                LiveRef::Phi(phi) => *self.phi_out.get(&phi).unwrap(),
            };
            let dst = lv.reg_ref;
            if let SrcRef::Reg(src_reg) = src {
                if dst == src_reg {
                    continue;
                }
            }
            pcopy.push(dst.into(), src.into());
        }

        if DEBUG.annotate() {
            b.instrs.push(Instr::new_boxed(OpAnnotate {
                annotation: "generated by assign_regs".into(),
            }));
        }
        if b.branch().is_some() {
            b.instrs.insert(b.instrs.len() - 1, Instr::new_boxed(pcopy));
        } else {
            b.instrs.push(Instr::new_boxed(pcopy));
        }
    }
}

impl Shader {
    pub fn assign_regs(&mut self) {
        assert!(self.functions.len() == 1);
        let f = &mut self.functions[0];

        // Convert to CSSA before we spill or assign registers
        f.to_cssa();

        let mut live = SimpleLiveness::for_function(f);
        let mut max_live = live.calc_max_live(f);

        // We want at least one temporary GPR reserved for parallel copies.
        let mut tmp_gprs = 1_u8;

        let spill_files = [RegFile::Pred, RegFile::Bar];
        for file in spill_files {
            let num_regs = file.num_regs(self.info.sm);
            if max_live[file] > num_regs {
                f.spill_values(file, num_regs);

                // Re-calculate liveness after we spill
                live = SimpleLiveness::for_function(f);
                max_live = live.calc_max_live(f);

                if file == RegFile::Bar {
                    tmp_gprs = max(tmp_gprs, 2);
                }
            }
        }

        // An instruction can have at most 4 vector sources/destinations.  In
        // order to ensure we always succeed at allocation, regardless of
        // arbitrary choices, we need at least 16 GPRs.
        let mut gpr_limit = max(max_live[RegFile::GPR], 16);
        let mut total_gprs = gpr_limit + u32::from(tmp_gprs);

        let max_gprs = RegFile::GPR.num_regs(self.info.sm);
        if total_gprs > max_gprs {
            // If we're spilling GPRs, we need to reserve 2 GPRs for OpParCopy
            // lowering because it needs to be able to lower Mem copies which
            // require a temporary
            tmp_gprs = max(tmp_gprs, 2);
            total_gprs = max_gprs;
            gpr_limit = total_gprs - u32::from(tmp_gprs);

            f.spill_values(RegFile::GPR, gpr_limit);

            // Re-calculate liveness one last time
            live = SimpleLiveness::for_function(f);
        }

        self.info.num_gprs = total_gprs.try_into().unwrap();

        // We do a maximum here because nak_from_nir may set num_barriers to 1
        // in the case where there is an OpBar.
        self.info.num_barriers = max(
            self.info.num_barriers,
            max_live[RegFile::Bar].try_into().unwrap(),
        );

        let limit = PerRegFile::new_with(|file| {
            if file == RegFile::GPR {
                gpr_limit
            } else {
                file.num_regs(self.info.sm)
            }
        });

        let mut blocks: Vec<AssignRegsBlock> = Vec::new();
        for b_idx in 0..f.blocks.len() {
            let pred = f.blocks.pred_indices(b_idx);
            let pred_ra = if pred.is_empty() {
                None
            } else {
                // Start with the previous block's.
                Some(&blocks[pred[0]].ra)
            };

            let bl = live.block_live(b_idx);

            let mut arb = AssignRegsBlock::new(&limit, tmp_gprs);
            arb.first_pass(&mut f.blocks[b_idx], bl, pred_ra);

            assert!(blocks.len() == b_idx);
            blocks.push(arb);
        }

        for b_idx in 0..f.blocks.len() {
            let arb = &blocks[b_idx];
            for sb_idx in f.blocks.succ_indices(b_idx).to_vec() {
                arb.second_pass(&blocks[sb_idx], &mut f.blocks[b_idx]);
            }
        }
    }
}