use super::{CStream, Error, ZSTDError, ZSTD_reset_session_only};
use crate::bindings::*;
use libc::*;
use std::hash::Hasher;
use std::{cmp, slice};
use xxhash_rust::xxh64::Xxh64;
mod framelog;
use framelog::*;
/// Largest accepted uncompressed frame size, in bytes. `SeekableCStream::new` rejects a
/// larger `frame_size` and uses this value when `frame_size` is `0`.
const MAX_FRAME_DECOMPR_SIZE: usize = 0x80000000;
/// Capacity handed to `FrameLog::with_capacity` when a `SeekableCStream` is created.
const FRAMELOG_STARTING_CAPACITY: usize = 16;
/// Streaming compressor that cuts its input into frames of at most `max_frame_size`
/// uncompressed bytes, records every finished frame in a `FrameLog`, and finally emits the
/// log as a seek table (see `end_stream`).
pub struct SeekableCStream {
/// Wrapper around the pointer returned by `ZSTD_createCStream`.
cstream: CStream,
/// Per-frame records (compressed size, decompressed size, checksum) and seek-table writer.
framelog: FrameLog,
/// Compressed bytes emitted so far for the frame currently being written.
frame_c_size: u32,
/// Uncompressed bytes consumed so far by the frame currently being written.
frame_d_size: u32,
/// Running XXH64 state over the current frame's input; only fed when
/// `framelog.checksum_flag` is non-zero.
xxh_state: Xxh64,
/// Number of uncompressed bytes after which the current frame is ended.
max_frame_size: usize,
/// Set once `end_stream` has started writing the seek table, so the frame is not ended twice.
writing_seek_table: bool,
}
// SAFETY: NOTE(review) — this impl is presumably required because `CStream` stores the raw
// pointer obtained from `ZSTD_createCStream`. It is sound only if that context has no
// thread affinity and is never shared between threads without synchronisation; every
// method here that touches it takes `&mut self`. Confirm against the zstd documentation.
unsafe impl Send for SeekableCStream {}
impl SeekableCStream {
pub fn new(level: usize, frame_size: usize) -> Result<Self, Error> {
let cstream = unsafe { ZSTD_createCStream() };
if cstream.is_null() {
return Err(Error::Null);
}
if frame_size > MAX_FRAME_DECOMPR_SIZE {
Err(Error::FParamUnsupported(frame_size, MAX_FRAME_DECOMPR_SIZE))
} else {
let max_frame_size = if frame_size > 0 {
frame_size
} else {
MAX_FRAME_DECOMPR_SIZE
};
unsafe {
let result = ZSTD_initCStream(cstream, level as c_int);
if ZSTD_isError(result) != 0 {
return Err(Error::ZSTD(ZSTDError(result)));
}
};
Ok(SeekableCStream {
cstream: CStream { p: cstream },
framelog: FrameLog::with_capacity(FRAMELOG_STARTING_CAPACITY),
frame_c_size: 0,
frame_d_size: 0,
xxh_state: Xxh64::new(0),
max_frame_size,
writing_seek_table: false,
})
}
}
fn end_frame(&mut self, output: &mut ZSTD_outBuffer) -> Result<usize, Error> {
let prev_out_pos = output.pos;
let ret = unsafe { ZSTD_endStream(self.cstream.p, output) };
self.frame_c_size += (output.pos - prev_out_pos) as u32;
if ret != 0 {
return Ok(ret);
}
let checksum = if self.framelog.checksum_flag != 0 {
self.xxh_state.finish() as u32
} else {
0
};
self.framelog
.log_frame(self.frame_c_size, self.frame_d_size, checksum)?;
self.frame_c_size = 0;
self.frame_d_size = 0;
unsafe {
let r = ZSTD_CCtx_reset(self.cstream.p, ZSTD_reset_session_only);
if ZSTD_isError(r) != 0 {
return Err(Error::ZSTD(ZSTDError(r)));
}
};
if self.framelog.checksum_flag != 0 {
self.xxh_state = Xxh64::new(0);
}
Ok(0)
}
pub fn compress(&mut self, output: &mut [u8], input: &[u8]) -> Result<(usize, usize), Error> {
let mut output = ZSTD_outBuffer {
dst: output.as_mut_ptr() as *mut c_void,
size: output.len() as size_t,
pos: 0,
};
let len = cmp::min(
input.len(),
self.max_frame_size - self.frame_d_size as usize,
);
let mut in_tmp = ZSTD_inBuffer {
src: input.as_ptr() as *const c_void,
size: len,
pos: 0,
};
if len > 0 {
let ret = unsafe { ZSTD_compressStream(self.cstream.p, &mut output, &mut in_tmp) };
if self.framelog.checksum_flag != 0 {
self.xxh_state
.write(unsafe { slice::from_raw_parts(in_tmp.src as *const _, in_tmp.pos) });
}
self.frame_c_size += output.pos as u32;
self.frame_d_size += in_tmp.pos as u32;
if unsafe { ZSTD_isError(ret) } != 0 {
return Err(Error::ZSTD(ZSTDError(ret)));
}
}
if self.max_frame_size == self.frame_d_size as usize {
let ret = self.end_frame(&mut output)?;
if unsafe { ZSTD_isError(ret) } != 0 {
return Err(Error::ZSTD(ZSTDError(ret)));
}
}
Ok((output.pos as usize, in_tmp.pos as usize))
}
pub fn end_stream(&mut self, output: &mut [u8]) -> Result<usize, Error> {
let mut output_ = ZSTD_outBuffer {
dst: output.as_mut_ptr() as *mut c_void,
size: output.len() as size_t,
pos: 0,
};
if !self.writing_seek_table && self.frame_d_size != 0 {
let end_frame = self.end_frame(&mut output_)?;
if unsafe { ZSTD_isError(end_frame) } != 0 {
return Err(Error::ZSTD(ZSTDError(end_frame)));
}
if end_frame != 0 {
return Ok(end_frame + self.framelog.seek_table_size());
}
}
self.writing_seek_table = true;
let result = self.framelog.write_seek_table(output, &mut output_.pos)?;
if unsafe { ZSTD_isError(result) } != 0 {
return Err(Error::ZSTD(ZSTDError(result)));
}
Ok(output_.pos as usize)
}
}
#[cfg(feature = "threadpool")]
mod parallel_compress {
    use super::*;
    use xxhash_rust::xxh64::xxh64;

    /// Destination buffer that one compressed frame is written into.
    pub trait Dst: Send {
        /// Pointer to the start of the buffer, handed to zstd as the output location.
        fn as_mut_ptr(&mut self) -> *mut u8;
        /// The whole buffer as a byte slice.
        fn as_slice(&self) -> &[u8];
        /// Capacity of the buffer in bytes.
        fn len(&self) -> usize;
        /// Creates a fresh buffer.
        fn new() -> Self;
    }

    impl<const N: usize> Dst for [u8; N] {
        fn as_mut_ptr(&mut self) -> *mut u8 {
            // Go through the slice explicitly: `self.as_mut_ptr()` would resolve to this
            // very trait method and recurse.
            self[..].as_mut_ptr()
        }
        fn as_slice(&self) -> &[u8] {
            &self[..]
        }
        fn len(&self) -> usize {
            N
        }
        fn new() -> Self {
            // Zero-initialised: producing an uninitialised `[u8; N]` through
            // `MaybeUninit::assume_init` is undefined behaviour.
            [0u8; N]
        }
    }

    /// One independently compressed frame plus the data needed for its seek-table entry.
    ///
    /// `Send` is derived automatically: every field is `Send` because `Dst: Send`.
    struct CompressedFrame<D: Dst> {
        /// Uncompressed size of the chunk.
        src_size: u32,
        /// Number of valid compressed bytes at the start of `dst`.
        dst_size: u32,
        /// Low 32 bits of the XXH64 hash (seed 0) of the uncompressed chunk.
        checksum: u32,
        /// Buffer holding the compressed bytes.
        dst: D,
    }

    impl<D: Dst> CompressedFrame<D> {
        /// The compressed bytes of this frame.
        fn as_slice(&self) -> &[u8] {
            &self.dst.as_slice()[..self.dst_size as usize]
        }
    }

    /// Compresses `src` as a single zstd frame into a freshly created `D`.
    ///
    /// # Errors
    ///
    /// * `Error::FParamUnsupported` if `src` is longer than `MAX_FRAME_DECOMPR_SIZE`, the
    ///   same limit `SeekableCStream::new` enforces on frame sizes.
    /// * `Error::ZSTD` if `ZSTD_compress` fails, e.g. because `D` is too small.
    fn compress_frame<D: Dst>(src: &[u8], level: usize) -> Result<CompressedFrame<D>, Error> {
        if src.len() > MAX_FRAME_DECOMPR_SIZE {
            return Err(Error::FParamUnsupported(src.len(), MAX_FRAME_DECOMPR_SIZE));
        }
        let mut dst = D::new();
        let capacity = dst.len();
        // Saturate rather than wrap when the requested level does not fit in a `c_int`.
        let level = level.min(c_int::MAX as usize) as c_int;
        let written = unsafe {
            ZSTD_compress(
                dst.as_mut_ptr() as *mut c_void,
                capacity as size_t,
                src.as_ptr() as *const c_void,
                src.len() as size_t,
                level,
            )
        };
        if unsafe { ZSTD_isError(written) } != 0 {
            return Err(Error::ZSTD(ZSTDError(written)));
        }
        let checksum = xxh64(src, 0) as u32;
        Ok(CompressedFrame {
            // Lossless: `src.len()` was checked against `MAX_FRAME_DECOMPR_SIZE` above.
            src_size: src.len() as u32,
            dst_size: written as u32,
            checksum,
            dst,
        })
    }

    /// Compresses `src` in `chunk_size`-byte chunks on a pool of `jobs` threads, then writes
    /// the frames in input order to `output`, followed by the frame log.
    ///
    /// Every chunk is compressed into its own `D`, so `D` must be large enough to hold the
    /// compressed form of one chunk. `jobs == 0` is treated as `1`.
    ///
    /// # Errors
    ///
    /// The first compression error received, any I/O error from `output`, any error from
    /// the frame log, or an I/O-derived error if a worker died without producing its frame.
    ///
    /// # Panics
    ///
    /// Panics if `chunk_size` is `0`.
    pub fn parallel_compress<W: std::io::Write, D: Dst + 'static>(
        src: &'static [u8],
        mut output: W,
        level: usize,
        jobs: usize,
        chunk_size: usize,
    ) -> Result<(), Error> {
        use std::sync::mpsc::channel;
        use threadpool::ThreadPool;

        assert!(chunk_size > 0, "chunk_size must be non-zero");

        let n = src.chunks(chunk_size).len();
        // `ThreadPool::new` panics on zero threads.
        let pool = ThreadPool::new(jobs.max(1));
        let (tx, rx) = channel();
        for (i, chunk) in src.chunks(chunk_size).enumerate() {
            let tx = tx.clone();
            pool.execute(move || {
                // The receiver is gone if the caller already returned with an error; there
                // is nobody left to report to, so a failed send is ignored.
                let _ = tx.send((i, compress_frame::<D>(chunk, level)));
            });
        }
        // Drop the original sender so the receive loop ends when every worker is done,
        // even if one of them died without sending.
        drop(tx);

        let mut frames: Vec<Option<CompressedFrame<D>>> = (0..n).map(|_| None).collect();
        for (i, frame) in rx.iter() {
            frames[i] = Some(frame?);
        }

        let mut log = FrameLog::new();
        for frame in frames {
            let frame = match frame {
                Some(frame) => frame,
                None => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "a compression worker exited without producing its frame",
                    )
                    .into())
                }
            };
            output.write_all(frame.as_slice())?;
            log.log_frame(frame.dst_size, frame.src_size, frame.checksum)?;
        }
        log.write_all(&mut output)?;
        Ok(())
    }
}
#[cfg(feature = "threadpool")]
pub use parallel_compress::*;