diff --git a/kernel/src/vspace/allocator/freelist.rs b/kernel/src/vspace/allocator/freelist.rs
index d9d16fa..6ea15f0 100644
--- a/kernel/src/vspace/allocator/freelist.rs
+++ b/kernel/src/vspace/allocator/freelist.rs
@@ -1,5 +1,6 @@
 // adapted from https://os.phil-opp.com/allocator-designs/#linked-list-allocator
+use crate::utils::then::Then;
 use crate::vspace::addr::{AddressOps, PhysAddr};
 use core::alloc::{GlobalAlloc, Layout};
 use spin::Mutex;
 
@@ -21,6 +22,24 @@ impl ListNode {
     fn end_addr(&self) -> PhysAddr {
         self.start_addr() + self.size
     }
+
+    fn fit(&self, size: usize, align: usize) -> Result<PhysAddr, ()> {
+        let alloc_start = self.start_addr().align_up(align);
+        let alloc_end = alloc_start + size;
+
+        if alloc_end > self.end_addr() {
+            return Err(());
+        }
+
+        let excess_size = (self.end_addr() - alloc_end).as_usize();
+        if excess_size > 0 && excess_size < core::mem::size_of::<ListNode>() {
+            // GlobalAlloc requires that dealloc be called with the original size and alignment,
+            // so we can't "waste" the remaining space and must return an error.
+            return Err(());
+        }
+
+        Ok(alloc_start)
+    }
 }
 
 struct FreeList {
@@ -34,7 +53,47 @@ impl FreeList {
         }
     }
 
-    unsafe fn add_free_region(&mut self, start: PhysAddr, size: usize) {
+    fn alloc_node<F, V>(&mut self, mut predicate: F) -> Option<(&'static mut ListNode, V)>
+    where F: FnMut(&mut ListNode) -> Result<V, ()> {
+        let mut current = &mut self.head;
+        while let Some(ref mut region) = current.next {
+            if let Ok(v) = predicate(region) {
+                let next = region.next.take();
+                let ret = Some((current.next.take().unwrap(), v));
+                current.next = next;
+                return ret;
+            } else {
+                current = current.next.as_mut().unwrap();
+            }
+        }
+        None
+    }
+
+    fn align_layout(layout: Layout) -> (usize, usize) {
+        let layout = layout
+            .align_to(core::mem::align_of::<ListNode>())
+            .expect("adjusting alignment failed")
+            .pad_to_align();
+        let size = layout.size().max(core::mem::size_of::<ListNode>());
+        (size, layout.align())
+    }
+
+    unsafe fn alloc(&mut self, layout: Layout) -> *mut u8 {
+        let (size, align) = Self::align_layout(layout);
+
+        if let Some((region, alloc_start)) = self.alloc_node(|region| region.fit(size, align)) {
+            let alloc_end = alloc_start + size;
+            let excess_size = (region.end_addr() - alloc_end).as_usize();
+            if excess_size > 0 {
+                self.dealloc(alloc_end, excess_size);
+            }
+            alloc_start.as_mut_ptr()
+        } else {
+            core::ptr::null_mut()
+        }
+    }
+
+    unsafe fn dealloc(&mut self, start: PhysAddr, size: usize) {
         assert_eq!(start.align_up(core::mem::align_of::<ListNode>()), start);
         assert!(size >= core::mem::size_of::<ListNode>());
 
@@ -45,39 +104,31 @@ impl FreeList {
         self.head.next = Some(&mut *node_ptr);
     }
 
-    fn find_region(
-        &mut self,
-        size: usize,
-        align: usize,
-    ) -> Option<(&'static mut ListNode, PhysAddr)> {
-        let mut current = &mut self.head;
-        while let Some(ref mut region) = current.next {
-            if let Ok(alloc_start) = Self::alloc_from_region(&region, size, align) {
-                let next = region.next.take();
-                let ret = Some((current.next.take().unwrap(), alloc_start));
-                current.next = next;
-                return ret;
-            } else {
-                current = current.next.as_mut().unwrap();
+    pub fn reserve(&mut self, start: PhysAddr, size: usize) {
+        if let Some((region, _)) = self
+            .alloc_node(|region| (region.start_addr() <= start).chain(|| region.fit(size, 1), ()))
+        {
+            /* layout
+             * region: | before | [start +: size] | after |
+             *         ^        ^                 ^       ^ region.end_addr()
+             *         |        | alloc_start     |
+             *         |        |                 alloc_end
+             *         | region.start_addr()
+             */
+
+            let region_start = region.start_addr();
+            let region_end = region.end_addr();
+
+            let before_size = (start - region_start).as_usize();
+            if before_size > 0 {
+                unsafe { self.dealloc(region_start, before_size) }
+            }
+
+            let after_size = (region_end - (start + size)).as_usize();
+            if after_size > 0 {
+                unsafe { self.dealloc(start + size, after_size) }
             }
         }
-        None
-    }
-
-    fn alloc_from_region(region: &ListNode, size: usize, align: usize) -> Result<PhysAddr, ()> {
-        let alloc_start = region.start_addr().align_up(align);
-        let alloc_end = alloc_start + size;
-
-        if alloc_end > region.end_addr() {
-            return Err(());
-        }
-
-        let excess_size = (region.end_addr() - alloc_end).as_usize();
-        if excess_size > 0 && excess_size < core::mem::size_of::<ListNode>() {
-            return Err(());
-        }
-
-        Ok(alloc_start)
     }
 }
 
@@ -88,42 +139,21 @@ pub struct FreeListAllocator {
 impl FreeListAllocator {
     pub fn new(start: PhysAddr, size: usize) -> Self {
         let mut list = FreeList::new();
-        unsafe { list.add_free_region(start, size) }
+        unsafe { list.dealloc(start, size) }
 
         Self {
             list: Mutex::new(list),
         }
     }
-
-    fn size_align(layout: Layout) -> (usize, usize) {
-        let layout = layout
-            .align_to(core::mem::align_of::<ListNode>())
-            .expect("adjusting alignment failed")
-            .pad_to_align();
-        let size = layout.size().max(core::mem::size_of::<ListNode>());
-        (size, layout.align())
-    }
 }
 
 unsafe impl GlobalAlloc for FreeListAllocator {
     unsafe fn alloc(&self, layout: Layout) -> *mut u8 {
-        let (size, align) = Self::size_align(layout);
-        let mut allocator = self.list.lock();
-
-        if let Some((region, alloc_start)) = allocator.find_region(size, align) {
-            let alloc_end = alloc_start + size;
-            let excess_size = (region.end_addr() - alloc_end).as_usize();
-            if excess_size > 0 {
-                allocator.add_free_region(alloc_end, excess_size);
-            }
-            alloc_start.as_mut_ptr()
-        } else {
-            core::ptr::null_mut()
-        }
+        self.list.lock().alloc(layout)
     }
 
     unsafe fn dealloc(&self, ptr: *mut u8, layout: Layout) {
-        let (size, _) = Self::size_align(layout);
-        self.list.lock().add_free_region(PhysAddr::from(ptr), size)
+        let (size, _) = FreeList::align_layout(layout);
+        self.list.lock().dealloc(PhysAddr::from(ptr), size)
     }
 }
 
@@ -158,4 +188,36 @@ mod tests {
             unsafe { allocator.alloc(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap()) };
         assert_eq!(ptr as usize, BASE.as_usize());
     }
+
+    #[test_case]
+    fn test_freelist_reserve() {
+        const BASE: PhysAddr = PhysAddr(0x80300000);
+
+        let allocator = FreeListAllocator::new(BASE, 32 * PAGE_SIZE);
+        allocator
+            .list
+            .lock()
+            .reserve(BASE + 4 * PAGE_SIZE, 4 * PAGE_SIZE);
+
+        let mut cnt = 32 - 4;
+        loop {
+            let ptr =
+                unsafe { allocator.alloc(Layout::from_size_align(PAGE_SIZE, PAGE_SIZE).unwrap()) };
+            if ptr.is_null() {
+                assert_eq!(cnt, 0);
+                break;
+            }
+
+            let ptr = PhysAddr::from(ptr);
+            assert!(
+                !(BASE + 4 * PAGE_SIZE <= ptr && ptr < BASE + (4 + 4) * PAGE_SIZE),
+                "Bad alloc: returned ptr: {:?}, reserved range: {:?}->{:?}",
+                ptr,
+                BASE + 4 * PAGE_SIZE,
+                BASE + (4 + 4) * PAGE_SIZE
+            );
+
+            cnt -= 1;
+        }
+    }
 }
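
Note for review: the patch imports crate::utils::then::Then but that module is not part of this diff. Inferred purely from the single call site in reserve, `(region.start_addr() <= start).chain(|| region.fit(size, 1), ())`, one plausible shape of the helper is sketched below; the trait name and method come from the import and call site, everything else is an assumption and the real definition may differ.

// Hypothetical sketch of the `Then` helper, inferred from its one call site;
// not taken from crate::utils::then, which this diff does not show.
pub trait Then {
    /// If `self` is true, evaluate `f`; otherwise short-circuit to `Err(err)`.
    fn chain<V, E>(self, f: impl FnOnce() -> Result<V, E>, err: E) -> Result<V, E>;
}

impl Then for bool {
    fn chain<V, E>(self, f: impl FnOnce() -> Result<V, E>, err: E) -> Result<V, E> {
        if self { f() } else { Err(err) }
    }
}

Under this reading, the reserve predicate returns Ok(alloc_start) only for a region that starts at or before the reserved range and is large enough to carve the range out of.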
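Similarly, the address operations the patch leans on (align_up, as_usize, as_mut_ptr, addition of usize offsets, subtraction of addresses) come from kernel/src/vspace/addr.rs, which is also outside this diff. A minimal sketch consistent with the usage here, assuming align_up takes power-of-two alignments and subtraction yields another PhysAddr:

// Hypothetical minimal stand-in for AddressOps/PhysAddr as used by this file;
// the real definitions in kernel/src/vspace/addr.rs may differ.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Debug)]
pub struct PhysAddr(pub usize);

pub trait AddressOps {
    fn align_up(self, align: usize) -> Self;
    fn as_usize(self) -> usize;
    fn as_mut_ptr<T>(self) -> *mut T;
}

impl AddressOps for PhysAddr {
    fn align_up(self, align: usize) -> Self {
        // assumes `align` is a power of two
        PhysAddr((self.0 + align - 1) & !(align - 1))
    }
    fn as_usize(self) -> usize { self.0 }
    fn as_mut_ptr<T>(self) -> *mut T { self.0 as *mut T }
}

impl core::ops::Add<usize> for PhysAddr {
    type Output = PhysAddr;
    fn add(self, rhs: usize) -> PhysAddr { PhysAddr(self.0 + rhs) }
}

impl core::ops::Sub<PhysAddr> for PhysAddr {
    // the diff calls .as_usize() on a difference, so Sub returns PhysAddr here
    type Output = PhysAddr;
    fn sub(self, rhs: PhysAddr) -> PhysAddr { PhysAddr(self.0 - rhs.0) }
}

impl From<*mut u8> for PhysAddr {
    fn from(p: *mut u8) -> Self { PhysAddr(p as usize) }
}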