// Copyright (c) 2023 Huawei Device Co., Ltd.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::Mutex;

// Keeps track of parked workers and of how many workers are currently
// active or searching for work.
pub(super) struct Sleeper {
    record: Record,
    idle_list: Mutex<Vec<usize>>,
    num_workers: usize,
}

impl Sleeper {
    pub fn new(num_workers: usize) -> Self {
        Sleeper {
            record: Record::new(num_workers),
            idle_list: Mutex::new(Vec::with_capacity(num_workers)),
            num_workers,
        }
    }

    // Returns true if the worker with this index is currently parked.
    pub fn is_parked(&self, worker_index: &usize) -> bool {
        let idle_list = self.idle_list.lock().unwrap();
        idle_list.contains(worker_index)
    }

    // Pops a parked worker to wake up. Returns `None` if every worker is
    // already active or if some worker is still searching for work.
    pub fn pop_worker(&self) -> Option<usize> {
        let (active_num, searching_num) = self.record.load_state();
        if active_num >= self.num_workers || searching_num > 0 {
            return None;
        }

        let mut idle_list = self.idle_list.lock().unwrap();

        let res = idle_list.pop();
        if res.is_some() {
            self.record.inc_active_num();
        }
        res
    }

    // Return true if it's the last thread going to sleep.
    pub fn push_worker(&self, worker_index: usize) -> bool {
        let mut idle_list = self.idle_list.lock().unwrap();
        idle_list.push(worker_index);

        self.record.dec_active_num()
    }

    // Tries to increment the searching worker count. Succeeds only while
    // fewer than half of the active workers are already searching.
    pub fn try_inc_searching_num(&self) -> bool {
        let (active_num, searching_num) = self.record.load_state();
        if searching_num * 2 < active_num {
            // increment searching worker number
            self.record.inc_searching_num();
            return true;
        }
        false
    }

    // Return true if it's the last searching thread.
    pub fn dec_searching_num(&self) -> bool {
        self.record.dec_searching_num()
    }

    #[cfg(feature = "metrics")]
    pub(crate) fn load_state(&self) -> (usize, usize) {
        self.record.load_state()
    }
}

const ACTIVE_WORKER_SHIFT: usize = 16;
const SEARCHING_MASK: usize = (1 << ACTIVE_WORKER_SHIFT) - 1;
const ACTIVE_MASK: usize = !SEARCHING_MASK;

//       32 bits          16 bits        16 bits
// |-------------------| working num | searching num |
struct Record(AtomicUsize);

impl Record {
    fn new(active_num: usize) -> Self {
        Self(AtomicUsize::new(active_num << ACTIVE_WORKER_SHIFT))
    }

    // Return true if it is the last searching thread
    fn dec_searching_num(&self) -> bool {
        let ret = self.0.fetch_sub(1, SeqCst);
        (ret & SEARCHING_MASK) == 1
    }

    fn inc_searching_num(&self) {
        self.0.fetch_add(1, SeqCst);
    }

    fn inc_active_num(&self) {
        let inc = 1 << ACTIVE_WORKER_SHIFT;

        self.0.fetch_add(inc, SeqCst);
    }

    // Return true if the active worker count drops to zero.
    fn dec_active_num(&self) -> bool {
        let dec = 1 << ACTIVE_WORKER_SHIFT;

        let ret = self.0.fetch_sub(dec, SeqCst);
        let active_num = ((ret & ACTIVE_MASK) >> ACTIVE_WORKER_SHIFT) - 1;
        active_num == 0
    }

    fn load_state(&self) -> (usize, usize) {
        let union_num = self.0.load(SeqCst);

        let searching_num = union_num & SEARCHING_MASK;
        let active_num = (union_num & ACTIVE_MASK) >> ACTIVE_WORKER_SHIFT;

        (active_num, searching_num)
    }
}
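
// A minimal sketch of how the packed state and the park/wake flow could be
// exercised, assuming this module compiles inside its parent scheduler crate.
// The module name and the two-worker scenario below are illustrative, not
// part of the original file.
#[cfg(test)]
mod sleeper_sketch_tests {
    use super::*;

    #[test]
    fn record_packs_active_and_searching_counts() {
        // A new record starts with every worker counted as active and
        // no worker searching.
        let record = Record::new(2);
        assert_eq!(record.load_state(), (2, 0));

        // Searching workers occupy the low 16 bits, active workers the bits
        // above ACTIVE_WORKER_SHIFT.
        record.inc_searching_num();
        assert_eq!(record.load_state(), (2, 1));

        // The sole searching worker is also the last one to stop searching.
        assert!(record.dec_searching_num());
    }

    #[test]
    fn last_sleeper_and_wake_up() {
        let sleeper = Sleeper::new(2);

        // The first worker to park is not the last active thread.
        assert!(!sleeper.push_worker(0));
        assert!(sleeper.is_parked(&0));

        // The second worker to park is the last active thread going to sleep.
        assert!(sleeper.push_worker(1));

        // Waking pops the most recently parked worker and re-activates it.
        assert_eq!(sleeper.pop_worker(), Some(1));
    }
}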