//! Bindings to the seekable variant of the ZSTD compression format.
use crate::bindings::*;
use crate::{
    Error, ZSTDError, ZSTD_DStream, ZSTD_reset_session_only, BLOCK_SIZE_MAX, MAGIC_SKIPPABLE_START,
    SEEKABLE_MAGIC_NUMBER, SEEK_TABLE_FOOTER_SIZE, SKIPPABLE_HEADER_SIZE,
};
use libc::*;
use std::hash::Hasher;
use xxhash_rust::xxh64::Xxh64;

/// Length, in bytes, of the `in_buff` and `out_buff` scratch buffers allocated for every
/// `Seekable`; equal to `BLOCK_SIZE_MAX`.
const SEEKABLE_BUFF_SIZE: usize = BLOCK_SIZE_MAX;

#[inline(always)]
fn slice_to_num(buff: &[u8]) -> Result<u32, Error> {
    let mut b = [0; 4];
    b.clone_from_slice(buff);
    Ok(u32::from_le_bytes(b))
}

/// The type of decompressors.
///
/// Bundles a ZSTD decompression stream with the source `R` holding the compressed bytes and
/// the seek table loaded from that source.
pub struct Seekable<R> {
    /// ZSTD stream handle obtained from `ZSTD_createDStream`; freed in `Drop`.
    dstream: *mut ZSTD_DStream,
    /// Per-frame offsets (and checksums) read from the end of the source.
    seek_table: SeekTable,
    /// Source the compressed bytes are read from.
    src: R,
    /// Decompressed offset the stream has reached; `u64::MAX` right after initialisation.
    decompressed_offset: u64,
    /// Index of the frame the stream was last positioned on; `u32::MAX` right after
    /// initialisation.
    cur_frame: u32,
    /// Scratch buffer for bytes read from `src` (seek table and compressed input).
    in_buff: Vec<u8>,
    /// Scratch buffer receiving decompressed bytes that lie before the requested offset.
    out_buff: Vec<u8>,
    /// Running XXH64 state used to check per-frame checksums.
    xxh_state: Xxh64,
}

// SAFETY: the only field that blocks the auto-impl is the raw `dstream` pointer, which is
// created in `make_seekable`, freed in `Drop`, and never handed out by the code in this file,
// so moving it along with its owner is fine. The bound on `R` is required: without it a
// `Seekable` wrapping a non-`Send` source (e.g. one holding an `Rc`) would be sendable too.
unsafe impl<R: Send> Send for Seekable<R> {}

impl<R> Drop for Seekable<R> {
    /// Frees the ZSTD stream if one is held, then clears the handle.
    fn drop(&mut self) {
        if self.dstream.is_null() {
            return;
        }
        // SAFETY: `dstream` is non-null (checked above) and is only ever assigned the result
        // of `ZSTD_createDStream` in `make_seekable`.
        unsafe {
            ZSTD_freeDStream(self.dstream);
        }
        self.dstream = std::ptr::null_mut();
    }
}

/// One row of the seek table: where a frame starts, in both coordinate systems.
struct SeekEntry {
    /// Offset of the frame's first byte in the compressed source.
    c_offset: u64,
    /// Offset of the frame's first byte in the decompressed data.
    d_offset: u64,
    /// Checksum stored for the frame; compared with the low 32 bits of the XXH64 digest of
    /// the frame's decompressed bytes. Left at 0 when the table carries no checksums.
    checksum: u32,
}

/// The seek table read from the end of a seekable ZSTD source.
struct SeekTable {
    /// One entry per frame, followed by a final entry holding the total compressed and
    /// decompressed sizes.
    entries: Vec<SeekEntry>,
    /// Non-zero when the entries carry per-frame checksums that must be verified.
    checksum_flag: u32,
}

impl SeekTable {
    fn new() -> Self {
        SeekTable {
            entries: Vec::new(),
            checksum_flag: 1,
        }
    }
}

impl<R: std::io::Read + std::io::Seek> Seekable<R> {
    /// Builds a decompressor reading from `source`.
    ///
    /// Allocates the ZSTD stream, then loads the seek table and initialises the stream;
    /// the first failure of either step is returned.
    pub fn init(source: R) -> Result<Self, Error> {
        Self::make_seekable(source).and_then(|mut decoder| {
            decoder.init_advanced()?;
            Ok(decoder)
        })
    }

    // The parameter size is the size of a buffer. So if the source is not one, the size is None.
    fn make_seekable(source: R) -> Result<Self, Error> {
        unsafe {
            let dstream = ZSTD_createDStream();
            if dstream.is_null() {
                Err(Error::Null)
            } else {
                Ok(Seekable {
                    dstream,
                    seek_table: SeekTable::new(),
                    src: source,
                    decompressed_offset: 0,
                    cur_frame: 0,
                    in_buff: vec![0; SEEKABLE_BUFF_SIZE],
                    out_buff: vec![0; SEEKABLE_BUFF_SIZE],
                    xxh_state: Xxh64::new(0),
                })
            }
        }
    }

    /// Reads the seek table stored at the end of `src` into `self.seek_table`.
    ///
    /// The footer (the last `SEEK_TABLE_FOOTER_SIZE` bytes) is read first: bytes `0..4` hold
    /// the number of frames, byte `4` the descriptor whose top bit is the checksum flag, and
    /// bytes `5..9` the seekable magic number. The table size is derived from these, the
    /// enclosing skippable frame is re-read from its start, and one `SeekEntry` per frame is
    /// built by accumulating the per-frame compressed and decompressed sizes. A final entry
    /// holding the totals is appended.
    ///
    /// # Errors
    ///
    /// * I/O errors from seeking or reading `src`.
    /// * `Error::PrefixUnknown` when either magic number or the skippable-frame size does
    ///   not match.
    /// * `Error::Corruption` when bits 2..=6 of the descriptor byte are not all zero.
    fn load_seek_table(&mut self) -> Result<(), Error> {
        // Read the fixed-size footer located at the very end of the source.
        self.src
            .seek(std::io::SeekFrom::End(-(SEEK_TABLE_FOOTER_SIZE as i64)))?;

        self.src
            .read_exact(&mut self.in_buff[..SEEK_TABLE_FOOTER_SIZE])?;

        // The seekable magic number occupies the last four bytes of the footer.
        let prefix = slice_to_num(&self.in_buff[5..9])?;
        if prefix != SEEKABLE_MAGIC_NUMBER {
            return Err(Error::PrefixUnknown(prefix));
        }

        // Descriptor byte: the top bit is the checksum flag.
        let sfd = self.in_buff[4];
        let checksum_flag = (sfd >> 7) as usize;

        // Bits 2..=6 of the descriptor must be zero.
        if ((sfd >> 2) & 0x1f) != 0 {
            return Err(Error::Corruption("when looking the checksum flag"));
        }

        let num_frames = slice_to_num(&self.in_buff[..4])? as usize;
        // Each entry is two 32-bit sizes, plus a 32-bit checksum when the flag is set.
        let size_p_entry: usize = 8 + if checksum_flag != 0 { 4 } else { 0 };
        let table_size = size_p_entry * num_frames;
        let frame_size = table_size + SEEK_TABLE_FOOTER_SIZE + SKIPPABLE_HEADER_SIZE;

        // The footer has already been parsed, so it does not need to be read again.
        let mut remaining = frame_size as usize - SEEK_TABLE_FOOTER_SIZE;
        let to_read = std::cmp::min(remaining, SEEKABLE_BUFF_SIZE);

        // Jump back to the start of the skippable frame that holds the table.
        self.src
            .seek(std::io::SeekFrom::End(-(frame_size as i64)))?;

        self.src.read_exact(&mut self.in_buff[..to_read])?;
        remaining -= to_read;

        // Skippable-frame header: magic number, then the frame's content size.
        let mut prefix = slice_to_num(&self.in_buff[..4])?;
        if prefix != (MAGIC_SKIPPABLE_START | 0xE) {
            return Err(Error::PrefixUnknown(prefix));
        }

        prefix = slice_to_num(&self.in_buff[4..8])?;
        if prefix as usize + SKIPPABLE_HEADER_SIZE != frame_size {
            return Err(Error::PrefixUnknown(prefix));
        }

        // One entry per frame plus the trailing entry with the totals.
        let mut entries: Vec<SeekEntry> = Vec::with_capacity((num_frames + 1) as usize);
        // Entries start right after the 8-byte skippable-frame header.
        let mut pos = 8;
        let (mut c_offset, mut d_offset) = (0, 0);

        for idx in 0..num_frames {
            // The next entry would run past the end of the buffer: move the `offset` unread
            // bytes to the front and refill the space behind them from the source.
            if pos + size_p_entry > SEEKABLE_BUFF_SIZE {
                let offset = SEEKABLE_BUFF_SIZE - pos;
                let to_read = std::cmp::min(remaining, SEEKABLE_BUFF_SIZE - offset);
                self.in_buff.copy_within(pos..pos + offset, 0);

                self.src
                    .read_exact(&mut self.in_buff[offset..offset + to_read])?;
                remaining -= to_read;
                pos = 0;
            }

            // The entry records where this frame starts; the sizes read below advance the
            // running offsets to the start of the next frame.
            entries.push(SeekEntry {
                c_offset,
                d_offset,
                checksum: 0,
            });
            c_offset += slice_to_num(&self.in_buff[pos..pos + 4])? as u64;
            pos += 4;
            d_offset += slice_to_num(&self.in_buff[pos..pos + 4])? as u64;
            pos += 4;

            if checksum_flag != 0 {
                entries[idx].checksum = slice_to_num(&self.in_buff[pos..pos + 4])?;
                pos += 4;
            }
        }
        // Trailing entry: total compressed and decompressed sizes.
        entries.push(SeekEntry {
            c_offset,
            d_offset,
            checksum: 0,
        });

        self.seek_table.entries = entries;
        self.seek_table.checksum_flag = checksum_flag as u32;

        Ok(())
    }

    fn init_advanced(&mut self) -> Result<(), Error> {
        self.load_seek_table()?;

        self.decompressed_offset = u64::MAX;
        self.cur_frame = u32::MAX;

        unsafe {
            let dstream_init = ZSTD_initDStream(self.dstream);
            if ZSTD_isError(dstream_init) != 0 {
                Err(Error::ZSTD(ZSTDError(dstream_init)))
            } else {
                Ok(())
            }
        }
    }

    /// Decompress a single frame. This method internally calls
    /// `decompress`, and `dest` must be exactly the size of the
    /// uncompressed frame.
    pub fn decompress_frame(&mut self, dest: &mut [u8], index: usize) -> Result<usize, Error> {
        let dec_size = self.get_frame_decompressed_size(index)?;

        if (dest.len() as u64) < dec_size {
            Err(Error::DSizeTooSmall(dest.len() as u64, dec_size))
        } else {
            self.decompress(dest, self.seek_table.entries[index].d_offset)
        }
    }

    /// Decompress starting from an offset of the decompressed data.
    ///
    /// Fills `out` with the bytes found at `offset..offset + out.len()`, clamped to the end
    /// of the data, and returns how many bytes were written. A request that starts at or
    /// past the end of the data, or an empty `out`, returns `Ok(0)` without touching the
    /// source.
    ///
    /// This function finds the correct frame to start with, and takes
    /// care of decompressing multiple frames in a row. Every call seeks to the start of the
    /// frame containing `offset` and decodes forward from there, discarding the bytes that
    /// precede `offset`.
    ///
    /// # Errors
    ///
    /// * I/O errors from seeking or reading the source.
    /// * `Error::ZSTD` when the ZSTD stream reports an error.
    /// * `Error::Corruption` when a completed frame does not match its stored checksum.
    pub fn decompress(&mut self, out: &mut [u8], offset: u64) -> Result<usize, Error> {
        let eos = self
            .seek_table
            .entries
            .last()
            .expect("the seek table always ends with a sentinel entry")
            .d_offset;

        // Nothing to produce; this also keeps `eos - offset` below from underflowing.
        if offset >= eos || out.is_empty() {
            return Ok(0);
        }

        // `offset < eos`, so neither the subtraction nor `offset + len` can overflow.
        let len = std::cmp::min(out.len() as u64, eos - offset);
        let end = offset + len;

        let mut tgt_frame = self.seekable_offset_to_frame_index(offset);

        let mut inn = ZSTD_inBuffer {
            src: std::ptr::null() as *const c_void,
            size: 0,
            pos: 0,
        };

        // `inn` only lives for this call, so whatever input a previous call had buffered is
        // gone and the source position no longer matches the stream state. Resuming
        // mid-frame would feed ZSTD a null input pointer, so the first pass always
        // re-seeks to the start of the target frame.
        let mut must_restart = true;

        loop {
            if must_restart
                || tgt_frame != self.cur_frame as usize
                || offset != self.decompressed_offset
            {
                must_restart = false;
                self.decompressed_offset = self.seek_table.entries[tgt_frame].d_offset;
                self.cur_frame = tgt_frame as u32;

                self.src.seek(std::io::SeekFrom::Start(
                    self.seek_table.entries[tgt_frame].c_offset,
                ))?;

                // Start the frame with an empty input window; stale bytes left over from
                // the previous frame must not be fed to the freshly reset stream.
                inn.src = self.in_buff.as_ptr() as *const c_void;
                inn.size = 0;
                inn.pos = 0;
                self.xxh_state.reset(0);

                // SAFETY: `dstream` was null-checked when this value was built.
                unsafe {
                    let r = ZSTD_DCtx_reset(self.dstream, ZSTD_reset_session_only);
                    if ZSTD_isError(r) != 0 {
                        return Err(Error::ZSTD(ZSTDError(r)));
                    }
                }
            }

            while self.decompressed_offset < end {
                // `written` is a safe view of the buffer ZSTD writes into, used to hash the
                // freshly produced bytes without pointer arithmetic.
                let (mut out_tmp, written) = if self.decompressed_offset < offset {
                    // Not at the requested offset yet: decode into the scratch buffer.
                    // The `min` is taken in `u64` so the cast cannot truncate.
                    (
                        ZSTD_outBuffer {
                            dst: self.out_buff.as_mut_ptr() as *mut c_void,
                            size: std::cmp::min(
                                SEEKABLE_BUFF_SIZE as u64,
                                offset - self.decompressed_offset,
                            ) as size_t,
                            pos: 0,
                        },
                        &self.out_buff[..],
                    )
                } else {
                    (
                        ZSTD_outBuffer {
                            dst: out.as_mut_ptr() as *mut c_void,
                            size: len as usize,
                            pos: (self.decompressed_offset - offset) as size_t,
                        },
                        &*out,
                    )
                };

                let prev_out_pos = out_tmp.pos;

                // SAFETY: `out_tmp` describes a live buffer of at least `size` bytes, and
                // `inn` is either empty or points at the first `inn.size` bytes of `in_buff`.
                let mut to_read =
                    unsafe { ZSTD_decompressStream(self.dstream, &mut out_tmp, &mut inn) };
                if unsafe { ZSTD_isError(to_read) } != 0 {
                    return Err(Error::ZSTD(ZSTDError(to_read)));
                }

                if self.seek_table.checksum_flag != 0 {
                    self.xxh_state.write(&written[prev_out_pos..out_tmp.pos]);
                }

                let forward_progress = (out_tmp.pos - prev_out_pos) as u64;
                self.decompressed_offset += forward_progress;

                if to_read == 0 {
                    // Frame complete: verify its checksum (low 32 bits of the XXH64 digest).
                    let digest = self.xxh_state.finish() as u32;
                    if self.seek_table.checksum_flag != 0
                        && digest != self.seek_table.entries[tgt_frame].checksum
                    {
                        return Err(Error::Corruption("during decompression"));
                    }

                    if self.decompressed_offset < end {
                        // More to decode: go around the outer loop to reset the stream on
                        // the next frame.
                        tgt_frame = self.seekable_offset_to_frame_index(self.decompressed_offset);
                    }

                    break;
                }

                // Read in more data once the current input window is used up.
                if inn.pos == inn.size {
                    to_read = std::cmp::min(to_read, SEEKABLE_BUFF_SIZE);
                    self.src.read_exact(&mut self.in_buff[..to_read])?;
                    inn.src = self.in_buff.as_ptr() as *const c_void;
                    inn.size = to_read;
                    inn.pos = 0;
                }
            }

            if self.decompressed_offset == end {
                break;
            }
        }

        Ok(len as usize)
    }
}

impl<R> Seekable<R> {
    /// Number of frames in the message.
    ///
    /// The seek table keeps one extra trailing entry, which is not counted.
    #[inline(always)]
    pub fn get_num_frames(&self) -> usize {
        let table_len = self.seek_table.entries.len();
        table_len - 1
    }

    #[inline(always)]
    fn get_frame(&self, frame_index: usize) -> Result<&'_ SeekEntry, Error> {
        let max_frames = self.get_num_frames();
        if frame_index >= max_frames {
            Err(Error::FIndexTooLarge(frame_index, max_frames))
        } else {
            Ok(&self.seek_table.entries[frame_index])
        }
    }

    /// Offset of the frame in the compressed data.
    ///
    /// Fails with `Error::FIndexTooLarge` for an out-of-range `frame_index`.
    pub fn get_frame_compressed_offset(&self, frame_index: usize) -> Result<u64, Error> {
        self.get_frame(frame_index).map(|entry| entry.c_offset)
    }
    /// Size of the frame in the compressed data.
    ///
    /// Computed as the distance to the start of the next table entry. Fails with
    /// `Error::FIndexTooLarge` for an out-of-range `frame_index`.
    pub fn get_frame_compressed_size(&self, frame_index: usize) -> Result<u64, Error> {
        let start = self.get_frame(frame_index)?.c_offset;
        let next_start = self.seek_table.entries[frame_index + 1].c_offset;
        Ok(next_start - start)
    }
    /// Offset of the frame in the decompressed data.
    ///
    /// Fails with `Error::FIndexTooLarge` for an out-of-range `frame_index`.
    pub fn get_frame_decompressed_offset(&self, frame_index: usize) -> Result<u64, Error> {
        self.get_frame(frame_index).map(|entry| entry.d_offset)
    }
    /// Size of the frame in the decompressed data.
    ///
    /// Computed as the distance to the start of the next table entry. Fails with
    /// `Error::FIndexTooLarge` for an out-of-range `frame_index`.
    pub fn get_frame_decompressed_size(&self, frame_index: usize) -> Result<u64, Error> {
        let start = self.get_frame(frame_index)?.d_offset;
        let next_start = self.seek_table.entries[frame_index + 1].d_offset;
        Ok(next_start - start)
    }
    /// Perform a binary search to find the frame containing the offset.
    ///
    /// Returns the index of the last frame whose decompressed start is `<= offset`, or
    /// `get_num_frames()` when `offset` is at or past the end of the decompressed data.
    /// Relies on the `d_offset` values being non-decreasing, which holds because they are
    /// built as running sums when the table is loaded.
    pub fn seekable_offset_to_frame_index(&self, offset: u64) -> usize {
        let entries = &self.seek_table.entries;
        let frame_count = self.get_num_frames();

        if entries[frame_count].d_offset <= offset {
            return frame_count;
        }

        // Number of frames starting at or before `offset`; the last of them contains it.
        entries[..frame_count]
            .partition_point(|entry| entry.d_offset <= offset)
            .saturating_sub(1)
    }
}

impl<'a> Seekable<std::io::Cursor<&'a [u8]>> {
    /// Initialise a decompressor with an input buffer.
    pub fn init_buf(input: &'a [u8]) -> Result<Self, Error> {
        let source = std::io::Cursor::new(input);
        let mut seekable = Seekable::make_seekable(source)?;
        seekable.init_advanced()?;
        Ok(seekable)
    }
}

impl Seekable<std::fs::File> {
    /// Initialise a decompressor with a file. This method opens the file, and dropping the resulting `Seekable` closes the file.
    ///
    /// The file is opened read-only; it is never created or truncated.
    ///
    /// # Errors
    ///
    /// Returns the I/O error if the file cannot be opened, or any error raised while
    /// loading the seek table and initialising the ZSTD stream.
    pub fn init_file(name: &str) -> Result<Self, Error> {
        // `File::open`, not `File::create`: the latter would truncate the very file we are
        // about to decompress.
        let source = std::fs::File::open(name)?;
        let mut seekable = Seekable::make_seekable(source)?;
        seekable.init_advanced()?;
        Ok(seekable)
    }
}

/// Checks the XXH64 binding against a known digest, then decompresses the hashed section of
/// a Pijul change file embedded at compile time.
#[test]
fn pijul_change() {
    // Known-answer check: XXH64 with seed 0 over the single byte `4`.
    let mut h = xxhash_rust::xxh64::Xxh64::new(0);
    h.update(&[4]);
    let d = 0x64b9da3ed69d6732;
    let f = h.finish();
    assert_eq!(d, f);

    // The change file is embedded into the test binary at compile time.
    let change = include_bytes!(
        "../../.pijul/changes/IH/334Q5ACWE4TNQYYOOF6GWV6CRXOEM6542NVNPA6HRIZ3CBFKEAC.change"
    );
    use serde_derive::*;
    // Header at the start of the change file, made of seven `u64` fields.
    #[derive(Deserialize)]
    pub struct Offsets {
        pub version: u64,
        pub hashed_len: u64,
        pub unhashed_off: u64,
        pub unhashed_len: u64,
        pub contents_off: u64,
        pub contents_len: u64,
        pub total: u64,
    }

    // NOTE(review): assumes the bincode encoding of `Offsets` is as long as its in-memory
    // size (seven fixed-width `u64`s) — confirm against the bincode configuration in use.
    let off0 = std::mem::size_of::<Offsets>();
    let offsets: Offsets = bincode::deserialize(&change[..off0]).unwrap();

    // The seekable stream sits between the header and `unhashed_off`.
    let mut s = Seekable::init_buf(&change[off0..offsets.unhashed_off as usize]).unwrap();

    // Decompress `hashed_len` bytes from the start; the test only checks that this succeeds.
    let mut buf_ = Vec::new();
    buf_.resize(offsets.hashed_len as usize, 0);
    s.decompress(&mut buf_[..], 0).unwrap();
}