//! Bindings to the seekable variant of the ZSTD compression format.
use super::{CStream, Error, ZSTDError, ZSTD_reset_session_only};
use crate::bindings::*;
use libc::*;
use std::hash::Hasher;
use std::{cmp, slice};
use xxhash_rust::xxh64::Xxh64;
// use twox_hash::xxh3::Hash64;

mod framelog;
use framelog::*;

/// Largest frame size, in uncompressed bytes, that `SeekableCStream::new`
/// accepts. It is also the frame size used when the caller passes 0.
const MAX_FRAME_DECOMPR_SIZE: usize = 0x8000_0000;
/// Initial capacity requested from `FrameLog::with_capacity` when a new
/// compressor is created.
const FRAMELOG_STARTING_CAPACITY: usize = 16;

/// The type of seekable compressors.
///
/// Input is split into frames of at most `max_frame_size` uncompressed
/// bytes. Every finished frame is recorded in `framelog`, from which the
/// seek table is written when the stream is ended.
pub struct SeekableCStream {
    /// Wrapped ZSTD compression stream handed to the ZSTD calls.
    cstream: CStream,
    /// Log of the frames finished so far; source of the seek table.
    framelog: FrameLog,
    /// Compressed bytes produced so far for the frame in progress.
    frame_c_size: u32,
    /// Input bytes consumed so far by the frame in progress.
    frame_d_size: u32,
    /// Running XXH64 state over the input of the frame in progress; only
    /// updated when `framelog.checksum_flag` is non-zero.
    xxh_state: Xxh64,
    /// Maximum number of input bytes a single frame may contain.
    max_frame_size: usize,
    /// Set once `end_stream` has started writing the seek table, so that it
    /// no longer tries to close a frame on later calls.
    writing_seek_table: bool,
}

// SAFETY / NOTE(review): this asserts that a `SeekableCStream` may be moved to
// another thread. That presumably holds because the wrapped ZSTD stream is not
// tied to the thread that created it and every method takes `&mut self` --
// confirm against `CStream` and `FrameLog`, which are defined elsewhere.
unsafe impl Send for SeekableCStream {}

impl SeekableCStream {
    /// Create a compressor with the given level and frame size. When
    /// seeking in the file, frames are decompressed one by one, so
    /// this should be chosen of a size similar to the chunks that
    /// will be decompressed.
    pub fn new(level: usize, frame_size: usize) -> Result<Self, Error> {
        let cstream = unsafe { ZSTD_createCStream() };

        if cstream.is_null() {
            return Err(Error::Null);
        }

        if frame_size > MAX_FRAME_DECOMPR_SIZE {
            Err(Error::FParamUnsupported(frame_size, MAX_FRAME_DECOMPR_SIZE))
        } else {
            let max_frame_size = if frame_size > 0 {
                frame_size
            } else {
                MAX_FRAME_DECOMPR_SIZE
            };

            unsafe {
                let result = ZSTD_initCStream(cstream, level as c_int);
                if ZSTD_isError(result) != 0 {
                    return Err(Error::ZSTD(ZSTDError(result)));
                }
            };

            Ok(SeekableCStream {
                cstream: CStream { p: cstream },
                framelog: FrameLog::with_capacity(FRAMELOG_STARTING_CAPACITY),
                frame_c_size: 0,
                frame_d_size: 0,
                xxh_state: Xxh64::new(0),
                max_frame_size,
                writing_seek_table: false,
            })
        }
    }

    fn end_frame(&mut self, output: &mut ZSTD_outBuffer) -> Result<usize, Error> {
        let prev_out_pos = output.pos;
        let ret = unsafe { ZSTD_endStream(self.cstream.p, output) };

        self.frame_c_size += (output.pos - prev_out_pos) as u32;

        if ret != 0 {
            return Ok(ret);
        }

        let checksum = if self.framelog.checksum_flag != 0 {
            self.xxh_state.finish() as u32
        } else {
            0
        };

        self.framelog
            .log_frame(self.frame_c_size, self.frame_d_size, checksum)?;

        self.frame_c_size = 0;
        self.frame_d_size = 0;

        unsafe {
            let r = ZSTD_CCtx_reset(self.cstream.p, ZSTD_reset_session_only);
            if ZSTD_isError(r) != 0 {
                return Err(Error::ZSTD(ZSTDError(r)));
            }
        };

        if self.framelog.checksum_flag != 0 {
            self.xxh_state = Xxh64::new(0);
        }

        Ok(0)
    }

    /// Compress one chunk of input, and write it into the output.
    ///
    /// At most as many input bytes as still fit in the current frame are
    /// offered to ZSTD, so a single call may consume only part of `input`.
    /// When the current frame reaches its maximum size, this function also
    /// tries to close it; if `output` is full before the frame is completely
    /// flushed, the next call consumes no input and keeps flushing.
    ///
    /// If successful, this function returns two integers `(out_pos,
    /// in_pos)`, where `out_pos` is the number of bytes written in
    /// `output`, and `in_pos` is the number of input bytes consumed.
    /// Callers should feed `&input[in_pos..]` again until everything has
    /// been consumed.
    ///
    /// # Errors
    ///
    /// Returns `Error::ZSTD` if ZSTD reports an error while compressing or
    /// while closing a frame, and propagates errors from the frame log.
    pub fn compress(&mut self, output: &mut [u8], input: &[u8]) -> Result<(usize, usize), Error> {
        let mut out_buf = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };

        // Never offer more than what still fits in the frame in progress.
        let len = cmp::min(
            input.len(),
            self.max_frame_size - self.frame_d_size as usize,
        );

        let mut in_buf = ZSTD_inBuffer {
            src: input.as_ptr() as *const c_void,
            size: len,
            pos: 0,
        };

        // `len == 0` means the previous frame is full but not yet flushed:
        // skip compression and go straight to closing it.
        if len > 0 {
            let ret = unsafe { ZSTD_compressStream(self.cstream.p, &mut out_buf, &mut in_buf) };

            if self.framelog.checksum_flag != 0 {
                // `in_buf.src` is `input.as_ptr()`, so the consumed prefix is
                // a plain (bounds-checked) sub-slice of `input`.
                self.xxh_state.write(&input[..in_buf.pos]);
            }

            // The counters are updated before the error check, so they
            // reflect whatever ZSTD did consume and produce.
            self.frame_c_size += out_buf.pos as u32;
            self.frame_d_size += in_buf.pos as u32;

            if unsafe { ZSTD_isError(ret) } != 0 {
                return Err(Error::ZSTD(ZSTDError(ret)));
            }
        }

        if self.max_frame_size == self.frame_d_size as usize {
            let ret = self.end_frame(&mut out_buf)?;

            if unsafe { ZSTD_isError(ret) } != 0 {
                return Err(Error::ZSTD(ZSTDError(ret)));
            }
        }

        Ok((out_buf.pos as usize, in_buf.pos as usize))
    }

    /// Finish writing the message, i.e. write the remaining pending block.
    ///
    /// Closes the frame in progress (if any input is pending) and then
    /// writes the seek table into `output`.
    ///
    /// Returns the number of bytes written into `output` by this call. If
    /// `output` fills up before the pending frame is completely flushed,
    /// only the bytes that fit are written and counted; the caller should
    /// drain them and call `end_stream` again with more room.
    ///
    /// NOTE(review): a non-zero, non-error result of `write_seek_table` is
    /// not surfaced to the caller -- confirm how callers are expected to
    /// detect that the seek table has been written completely.
    ///
    /// # Errors
    ///
    /// Returns `Error::ZSTD` if ZSTD reports an error while closing the
    /// frame or writing the seek table, and propagates errors from the
    /// frame log.
    pub fn end_stream(&mut self, output: &mut [u8]) -> Result<usize, Error> {
        let mut output_ = ZSTD_outBuffer {
            dst: output.as_mut_ptr() as *mut c_void,
            size: output.len() as size_t,
            pos: 0,
        };

        if !self.writing_seek_table && self.frame_d_size != 0 {
            let end_frame = self.end_frame(&mut output_)?;

            if unsafe { ZSTD_isError(end_frame) } != 0 {
                return Err(Error::ZSTD(ZSTDError(end_frame)));
            }

            if end_frame != 0 {
                // The frame is not fully flushed yet. Every other path of
                // this function reports bytes written, so report them here
                // as well rather than a size hint the caller could not tell
                // apart from a byte count.
                return Ok(output_.pos as usize);
            }
        }

        self.writing_seek_table = true;
        let result = self.framelog.write_seek_table(output, &mut output_.pos)?;

        if unsafe { ZSTD_isError(result) } != 0 {
            return Err(Error::ZSTD(ZSTDError(result)));
        }

        Ok(output_.pos as usize)
    }
}

#[cfg(feature = "threadpool")]
mod parallel_compress {
    use super::*;
    use xxhash_rust::xxh64::xxh64;
    /// A fixed-capacity byte buffer that a compressed frame can be written
    /// into by a worker thread.
    pub trait Dst: Send {
        /// Raw pointer to the start of the buffer, for handing to ZSTD.
        fn as_mut_ptr(&mut self) -> *mut u8;
        /// The whole buffer as a byte slice.
        fn as_slice(&self) -> &[u8];
        /// Capacity of the buffer in bytes.
        fn len(&self) -> usize;
        /// Create a fresh buffer; its contents must be initialised, since
        /// `as_slice` exposes them as `&[u8]`.
        fn new() -> Self;
    }

    impl<const N: usize> Dst for [u8; N] {
        fn as_mut_ptr(&mut self) -> *mut u8 {
            self[..].as_mut_ptr()
        }
        fn as_slice(&self) -> &[u8] {
            &self[..]
        }
        fn len(&self) -> usize {
            N
        }
        fn new() -> Self {
            // Zero-initialise: materialising an uninitialised `[u8; N]`
            // (`MaybeUninit::uninit().assume_init()`) is undefined behaviour,
            // and `as_slice` would expose the uninitialised bytes.
            [0; N]
        }
    }

    /// One independently compressed frame together with the numbers the
    /// frame log needs about it.
    ///
    /// No manual `Send` impl is needed: `Dst` requires `Send` and the other
    /// fields are plain integers, so the compiler derives `Send` on its own
    /// (and would stop doing so if a non-`Send` field were ever added).
    struct CompressedFrame<D: Dst> {
        /// Number of uncompressed input bytes in this frame.
        src_size: u32,
        /// Number of compressed bytes stored at the start of `dst`.
        dst_size: u32,
        /// Low 32 bits of the XXH64 hash (seed 0) of the uncompressed input.
        checksum: u32,
        /// Buffer holding the compressed bytes; only the first `dst_size`
        /// bytes are meaningful.
        dst: D,
    }

    impl<D: Dst> CompressedFrame<D> {
        /// The compressed bytes of this frame, without the unused tail of
        /// the buffer.
        fn as_slice(&self) -> &[u8] {
            &self.dst.as_slice()[..self.dst_size as usize]
        }
    }

    fn compress_frame<D: Dst>(src: &[u8], level: usize) -> Result<CompressedFrame<D>, Error> {
        let mut dst = D::new();

        let ret = unsafe {
            let ret = ZSTD_compress(
                dst.as_mut_ptr() as *mut c_void,
                dst.len() as size_t,
                src.as_ptr() as *const c_void,
                src.len() as size_t,
                level as c_int,
            );

            if ZSTD_isError(ret) != 0 {
                return Err(Error::ZSTD(ZSTDError(ret)));
            }

            ret
        };

        let checksum = xxh64(src, 0) as u32;

        Ok(CompressedFrame {
            src_size: src.len() as u32,
            dst_size: ret as u32,
            checksum,
            dst,
        })
    }

    /// Compress `src` with a pool of `jobs` worker threads and write the
    /// result, followed by the seek table, to `output`.
    ///
    /// `src` is cut into chunks of at most `chunk_size` bytes; each chunk
    /// becomes one frame, compressed on its own into a fresh `D`. All
    /// frames are kept in memory until every worker has reported, and are
    /// then written in input order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a worker, or any error raised
    /// while writing to `output` or while logging a frame.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is 0, or if a worker terminated without
    /// reporting its frame. `ThreadPool::new` is documented to panic when
    /// `jobs` is 0.
    pub fn parallel_compress<W: std::io::Write, D: Dst + 'static>(
        src: &'static [u8],
        mut output: W,
        level: usize,
        jobs: usize,
        chunk_size: usize,
    ) -> Result<(), Error> {
        use std::sync::mpsc::channel;
        use threadpool::ThreadPool;

        let chunks = src.chunks(chunk_size);
        let n = chunks.len();
        let pool = ThreadPool::new(jobs);

        let (tx, rx) = channel();
        for (i, chunk) in chunks.enumerate() {
            let tx = tx.clone();
            pool.execute(move || {
                let frame = compress_frame::<D>(chunk, level);
                // The receiver is gone only if the caller already bailed out
                // on an earlier error; this result is then irrelevant, so a
                // failed send is deliberately ignored instead of panicking
                // inside the worker.
                let _ = tx.send((i, frame));
            });
        }
        // Drop the original sender so the receive loop ends once every
        // worker is done, rather than blocking forever if one of them died
        // before sending.
        drop(tx);

        // Slots start out as `None` and are filled as results arrive (in
        // any order); no uninitialised element is ever read, overwritten
        // or dropped.
        let mut frames: Vec<Option<CompressedFrame<D>>> = Vec::with_capacity(n);
        frames.resize_with(n, || None);
        for (i, frame) in rx.iter() {
            frames[i] = Some(frame?);
        }

        let mut log = FrameLog::new();
        for frame in frames.iter() {
            let frame = frame
                .as_ref()
                .expect("a compression worker terminated without reporting its frame");
            output.write_all(frame.as_slice())?;
            log.log_frame(frame.dst_size, frame.src_size, frame.checksum)?;
        }
        log.write_all(&mut output)?;
        Ok(())
    }
}

#[cfg(feature = "threadpool")]
pub use parallel_compress::*;