use bootloader_api::info::{MemoryRegionKind, MemoryRegions};
use x86_64::addr::{PhysAddr, VirtAddr};
use x86_64::registers::control::Cr3;
use x86_64::structures::paging::{
    FrameAllocator, OffsetPageTable, PageSize, PageTable, PhysFrame, Size4KiB,
};

// Size of the physical frames handed out by this module. x86_64 page tables can also
// map 2 MiB and 1 GiB pages (`Size2MiB` / `Size1GiB` in the `x86_64` crate), but only
// 4 KiB frames are used here.
// TODO: support the larger page sizes as well.
const PAGE_SIZE: u64 = Size4KiB::SIZE;

/// A frame allocator that hands out frames, in order, from an iterator.
///
/// Each call to `allocate_frame` consumes the next frame from `usable_frames`.
/// This file implements no way to give frames back.
///
/// The `FrameAllocator` implementation for this type is `unsafe` and relies on
/// `usable_frames` yielding each frame at most once, and only frames that are not
/// otherwise in use. Because the field is public, code that builds this struct
/// directly (instead of via `init_frame_allocator`) must uphold that guarantee itself.
pub struct SimpleFrameAllocator<I> {
    /// Source of frames, consumed front to back.
    pub usable_frames: I,
}

// TODO: also implement `FrameAllocator` for the larger page sizes (`Size2MiB`, `Size1GiB`).
//
// SAFETY: sound only if `usable_frames` yields each frame at most once and only frames
// that nothing else is using. `init_frame_allocator` builds the iterator from regions
// the bootloader reports as `MemoryRegionKind::Usable`; anyone constructing the struct
// directly must uphold the same guarantee.
unsafe impl<I: Iterator<Item = PhysFrame>> FrameAllocator<Size4KiB> for SimpleFrameAllocator<I> {
    /// Returns the next frame from the underlying iterator, or `None` once it runs dry.
    fn allocate_frame(&mut self) -> Option<PhysFrame<Size4KiB>> {
        Iterator::next(&mut self.usable_frames)
    }
}

/// Caller must ensure memory map is valid
pub unsafe fn init_frame_allocator(
    memory_regions: &'static MemoryRegions,
) -> SimpleFrameAllocator<impl Iterator<Item = PhysFrame>> {
    // Collect all physical frames marked as usable
    let usable_frames = memory_regions
        .iter()
        // Filter to only usable frames
        .filter(|region| region.kind == MemoryRegionKind::Usable)
        // Map to start+end range
        .map(|region| region.start..region.end)
        // Collect all pages in region
        .flat_map(|region| region.step_by(PAGE_SIZE as usize))
        // Map to start address
        .map(|start_address| PhysFrame::containing_address(PhysAddr::new(start_address)));

    SimpleFrameAllocator { usable_frames }
}

/// Returns a mutable reference to the level 4 page table that CR3 currently points at.
///
/// The table's physical address is read from CR3 and translated to a virtual address
/// by adding `physical_memory_offset`.
///
/// # Safety
///
/// All of physical memory must be mapped at `physical_memory_offset`; otherwise the
/// returned reference points at unrelated memory. The result is a `&'static mut`, so
/// the caller must not call this again while an earlier returned reference is alive.
unsafe fn active_level_4_table(physical_memory_offset: VirtAddr) -> &'static mut PageTable {
    let (level_4_frame, _) = Cr3::read();
    let table_ptr: *mut PageTable =
        (physical_memory_offset + level_4_frame.start_address().as_u64()).as_mut_ptr();
    &mut *table_ptr
}

/// Creates an [`OffsetPageTable`] over the level 4 page table that is currently active.
///
/// # Safety
///
/// The caller must guarantee that all of physical memory is mapped into the virtual
/// address space starting at `physical_memory_offset`. The returned table holds a
/// `&'static mut` to the active level 4 table, so this must be called at most once;
/// calling it again would create aliasing mutable references.
pub unsafe fn init(physical_memory_offset: VirtAddr) -> OffsetPageTable<'static> {
    OffsetPageTable::new(active_level_4_table(physical_memory_offset), physical_memory_offset)
}