use crate::bindings::*;
use crate::{
Error, ZSTDError, ZSTD_DStream, ZSTD_reset_session_only, BLOCK_SIZE_MAX, MAGIC_SKIPPABLE_START,
SEEKABLE_MAGIC_NUMBER, SEEK_TABLE_FOOTER_SIZE, SKIPPABLE_HEADER_SIZE,
};
use libc::*;
use std::hash::Hasher;
use xxhash_rust::xxh64::Xxh64;
/// Length in bytes of the `in_buff` and `out_buff` scratch buffers that every
/// `Seekable` allocates; it also caps how much is read from the source at once.
const SEEKABLE_BUFF_SIZE: usize = BLOCK_SIZE_MAX;
#[inline(always)]
fn slice_to_num(buff: &[u8]) -> Result<u32, Error> {
let mut b = [0; 4];
b.clone_from_slice(buff);
Ok(u32::from_le_bytes(b))
}
/// Decoder for a zstd stream that carries a seek table, reading the compressed
/// bytes from `src`.
pub struct Seekable<R> {
/// zstd decompression stream; created in `make_seekable` and freed in `Drop`.
dstream: *mut ZSTD_DStream,
/// Frame offsets and checksums parsed by `load_seek_table`.
seek_table: SeekTable,
/// Underlying reader the compressed data and the seek table are read from.
src: R,
/// Decompressed position `dstream` has reached; `init_advanced` sets it to
/// `u64::MAX`.
decompressed_offset: u64,
/// Index of the frame `dstream` was last positioned on; `init_advanced` sets
/// it to `u32::MAX`.
cur_frame: u32,
/// Scratch buffer for bytes read from `src` (seek table and compressed input).
in_buff: Vec<u8>,
/// Scratch buffer receiving decompressed bytes that lie before the requested
/// offset.
out_buff: Vec<u8>,
/// Running XXH64 state over the decompressed bytes of the current frame, used
/// for the checksum comparison in `decompress`.
xxh_state: Xxh64,
}
// SAFETY: the raw `dstream` pointer is what prevents the automatic `Send` impl;
// it is created in `make_seekable` and freed in `Drop`, so it travels with the
// value that owns it. The reader `src` is moved to the other thread as well,
// hence `R` itself has to be `Send`.
// NOTE(review): assumes a zstd decompression stream may be used from a thread
// other than the one that created it — confirm against the zstd documentation.
unsafe impl<R: Send> Send for Seekable<R> {}
impl<R> Drop for Seekable<R> {
    /// Releases the zstd stream, if one is held, and clears the pointer.
    fn drop(&mut self) {
        if self.dstream.is_null() {
            return;
        }
        // SAFETY: the pointer is non-null here and is cleared right after the
        // call, so it is handed to `ZSTD_freeDStream` at most once.
        unsafe {
            ZSTD_freeDStream(self.dstream);
        }
        self.dstream = std::ptr::null_mut();
    }
}
/// One row of the seek table, describing where a frame starts.
struct SeekEntry {
/// Start of the frame in the compressed source: the running sum of the
/// compressed sizes listed before it in the seek table.
c_offset: u64,
/// Start of the frame in the decompressed data: the running sum of the
/// decompressed sizes listed before it in the seek table.
d_offset: u64,
/// Checksum read from the seek table, compared in `decompress` with the low
/// 32 bits of the frame's XXH64 digest; 0 when the table stores no checksums.
checksum: u32,
}
/// In-memory form of the seek table found at the end of the source.
struct SeekTable {
/// One entry per frame, followed by a final entry holding the total
/// compressed and decompressed sizes.
entries: Vec<SeekEntry>,
/// Non-zero when the entries carry checksums; `new` starts it at 1 and
/// `load_seek_table` overwrites it from the table's descriptor byte.
checksum_flag: u32,
}
impl SeekTable {
fn new() -> Self {
SeekTable {
entries: Vec::new(),
checksum_flag: 1,
}
}
}
impl<R: std::io::Read + std::io::Seek> Seekable<R> {
pub fn init(source: R) -> Result<Self, Error> {
let mut seekable = Self::make_seekable(source)?;
seekable.init_advanced()?;
Ok(seekable)
}
fn make_seekable(source: R) -> Result<Self, Error> {
unsafe {
let dstream = ZSTD_createDStream();
if dstream.is_null() {
Err(Error::Null)
} else {
Ok(Seekable {
dstream,
seek_table: SeekTable::new(),
src: source,
decompressed_offset: 0,
cur_frame: 0,
in_buff: vec![0; SEEKABLE_BUFF_SIZE],
out_buff: vec![0; SEEKABLE_BUFF_SIZE],
xxh_state: Xxh64::new(0),
})
}
}
}
fn load_seek_table(&mut self) -> Result<(), Error> {
self.src
.seek(std::io::SeekFrom::End(-(SEEK_TABLE_FOOTER_SIZE as i64)))?;
self.src
.read_exact(&mut self.in_buff[..SEEK_TABLE_FOOTER_SIZE])?;
let prefix = slice_to_num(&self.in_buff[5..9])?;
if prefix != SEEKABLE_MAGIC_NUMBER {
return Err(Error::PrefixUnknown(prefix));
}
let sfd = self.in_buff[4];
let checksum_flag = (sfd >> 7) as usize;
if ((sfd >> 2) & 0x1f) != 0 {
return Err(Error::Corruption("when looking the checksum flag"));
}
let num_frames = slice_to_num(&self.in_buff[..4])? as usize;
let size_p_entry: usize = 8 + if checksum_flag != 0 { 4 } else { 0 };
let table_size = size_p_entry * num_frames;
let frame_size = table_size + SEEK_TABLE_FOOTER_SIZE + SKIPPABLE_HEADER_SIZE;
let mut remaining = frame_size as usize - SEEK_TABLE_FOOTER_SIZE;
let to_read = std::cmp::min(remaining, SEEKABLE_BUFF_SIZE);
self.src
.seek(std::io::SeekFrom::End(-(frame_size as i64)))?;
self.src.read_exact(&mut self.in_buff[..to_read])?;
remaining -= to_read;
let mut prefix = slice_to_num(&self.in_buff[..4])?;
if prefix != (MAGIC_SKIPPABLE_START | 0xE) {
return Err(Error::PrefixUnknown(prefix));
}
prefix = slice_to_num(&self.in_buff[4..8])?;
if prefix as usize + SKIPPABLE_HEADER_SIZE != frame_size {
return Err(Error::PrefixUnknown(prefix));
}
let mut entries: Vec<SeekEntry> = Vec::with_capacity((num_frames + 1) as usize);
let mut pos = 8;
let (mut c_offset, mut d_offset) = (0, 0);
for idx in 0..num_frames {
if pos + size_p_entry > SEEKABLE_BUFF_SIZE {
let offset = SEEKABLE_BUFF_SIZE - pos;
let to_read = std::cmp::min(remaining, SEEKABLE_BUFF_SIZE - offset);
self.in_buff.copy_within(pos..pos + offset, 0);
self.src
.read_exact(&mut self.in_buff[offset..offset + to_read])?;
remaining -= to_read;
pos = 0;
}
entries.push(SeekEntry {
c_offset,
d_offset,
checksum: 0,
});
c_offset += slice_to_num(&self.in_buff[pos..pos + 4])? as u64;
pos += 4;
d_offset += slice_to_num(&self.in_buff[pos..pos + 4])? as u64;
pos += 4;
if checksum_flag != 0 {
entries[idx].checksum = slice_to_num(&self.in_buff[pos..pos + 4])?;
pos += 4;
}
}
entries.push(SeekEntry {
c_offset,
d_offset,
checksum: 0,
});
self.seek_table.entries = entries;
self.seek_table.checksum_flag = checksum_flag as u32;
Ok(())
}
fn init_advanced(&mut self) -> Result<(), Error> {
self.load_seek_table()?;
self.decompressed_offset = u64::MAX;
self.cur_frame = u32::MAX;
unsafe {
let dstream_init = ZSTD_initDStream(self.dstream);
if ZSTD_isError(dstream_init) != 0 {
Err(Error::ZSTD(ZSTDError(dstream_init)))
} else {
Ok(())
}
}
}
pub fn decompress_frame(&mut self, dest: &mut [u8], index: usize) -> Result<usize, Error> {
let dec_size = self.get_frame_decompressed_size(index)?;
if (dest.len() as u64) < dec_size {
Err(Error::DSizeTooSmall(dest.len() as u64, dec_size))
} else {
self.decompress(dest, self.seek_table.entries[index].d_offset)
}
}
pub fn decompress(&mut self, out: &mut [u8], offset: u64) -> Result<usize, Error> {
let eos = self.seek_table.entries.last().unwrap().d_offset;
let len = out.len() as u64;
let len = if offset + len > eos {
eos - offset
} else {
len
};
let mut tgt_frame = self.seekable_offset_to_frame_index(offset);
let mut inn = ZSTD_inBuffer {
src: std::ptr::null() as *const c_void,
size: 0,
pos: 0,
};
loop {
if tgt_frame as usize != self.cur_frame as usize || offset != self.decompressed_offset {
self.decompressed_offset = self.seek_table.entries[tgt_frame].d_offset;
self.cur_frame = tgt_frame as u32;
self.src.seek(std::io::SeekFrom::Start(
self.seek_table.entries[tgt_frame].c_offset,
))?;
inn.src = self.in_buff.as_ptr() as *const _ as *const c_void;
self.xxh_state.reset(0);
unsafe {
let r = ZSTD_DCtx_reset(self.dstream, ZSTD_reset_session_only);
if ZSTD_isError(r) != 0 {
return Err(Error::ZSTD(ZSTDError(r)));
}
}
}
while self.decompressed_offset < offset + len as u64 {
let (mut out_tmp, slice_tmp) = if self.decompressed_offset < offset {
(
ZSTD_outBuffer {
dst: self.out_buff.as_mut_ptr() as *mut c_void,
size: std::cmp::min(
SEEKABLE_BUFF_SIZE,
(offset - self.decompressed_offset) as size_t,
),
pos: 0,
},
(&self.out_buff).as_ref(),
)
} else {
(
ZSTD_outBuffer {
dst: out.as_mut_ptr() as *mut c_void,
size: len as usize,
pos: (self.decompressed_offset - offset) as size_t,
},
&*out,
)
};
let prev_out_pos = out_tmp.pos;
let mut to_read;
unsafe {
to_read = ZSTD_decompressStream(self.dstream, &mut out_tmp, &mut inn);
if ZSTD_isError(to_read) != 0 {
return Err(Error::ZSTD(ZSTDError(to_read)));
}
}
if self.seek_table.checksum_flag != 0 {
self.xxh_state.write(&slice_tmp[prev_out_pos..out_tmp.pos]);
}
let forward_progress = (out_tmp.pos - prev_out_pos) as u64;
self.decompressed_offset += forward_progress;
if to_read == 0 {
let f = self.xxh_state.finish();
let f = f as u32;
if self.seek_table.checksum_flag != 0
&& f != self.seek_table.entries[tgt_frame].checksum
{
return Err(Error::Corruption("during decompression"));
}
if self.decompressed_offset < offset + len as u64 {
tgt_frame = self.seekable_offset_to_frame_index(self.decompressed_offset);
}
break;
}
if inn.pos == inn.size {
to_read = std::cmp::min(to_read, SEEKABLE_BUFF_SIZE);
self.in_buff.resize(to_read, 0u8);
self.src.read_exact(&mut self.in_buff)?;
inn.size = to_read;
inn.pos = 0;
}
}
if self.decompressed_offset == offset + len as u64 {
break;
}
}
Ok(len as usize)
}
}
impl<R> Seekable<R> {
#[inline(always)]
pub fn get_num_frames(&self) -> usize {
self.seek_table.entries.len() - 1
}
#[inline(always)]
fn get_frame(&self, frame_index: usize) -> Result<&'_ SeekEntry, Error> {
let max_frames = self.get_num_frames();
if frame_index >= max_frames {
Err(Error::FIndexTooLarge(frame_index, max_frames))
} else {
Ok(&self.seek_table.entries[frame_index])
}
}
pub fn get_frame_compressed_offset(&self, frame_index: usize) -> Result<u64, Error> {
Ok(self.get_frame(frame_index)?.c_offset)
}
pub fn get_frame_compressed_size(&self, frame_index: usize) -> Result<u64, Error> {
let entry = self.get_frame(frame_index)?;
Ok(self.seek_table.entries[frame_index + 1].c_offset - entry.c_offset)
}
pub fn get_frame_decompressed_offset(&self, frame_index: usize) -> Result<u64, Error> {
Ok(self.get_frame(frame_index)?.d_offset)
}
pub fn get_frame_decompressed_size(&self, frame_index: usize) -> Result<u64, Error> {
let entry = self.get_frame(frame_index)?;
Ok(self.seek_table.entries[frame_index + 1].d_offset - entry.d_offset)
}
pub fn seekable_offset_to_frame_index(&self, offset: u64) -> usize {
let n_frames = self.get_num_frames();
assert!(self.seek_table.entries.len() >= n_frames);
if offset >= self.seek_table.entries[n_frames].d_offset {
return n_frames;
}
let (mut lo, mut hi) = (0, n_frames);
while lo + 1 < hi {
let mid = lo + ((hi - lo) >> 1);
if self.seek_table.entries[mid].d_offset <= offset {
lo = mid
} else {
hi = mid;
}
}
lo
}
}
impl<'a> Seekable<std::io::Cursor<&'a [u8]>> {
    /// Opens a seekable zstd stream held entirely in memory.
    ///
    /// `input` is wrapped in a cursor, after which the decompression stream is
    /// created and initialised exactly as for any other reader.
    pub fn init_buf(input: &'a [u8]) -> Result<Self, Error> {
        let cursor = std::io::Cursor::new(input);
        Seekable::make_seekable(cursor).and_then(|mut opened| {
            opened.init_advanced()?;
            Ok(opened)
        })
    }
}
impl Seekable<std::fs::File> {
    /// Opens the seekable zstd file at path `name` for reading.
    ///
    /// The file is opened read-only and is never created or modified.
    ///
    /// # Errors
    ///
    /// Returns the I/O error when the file cannot be opened (for instance when
    /// it does not exist), and propagates any error from creating the
    /// decompression stream or loading the seek table.
    pub fn init_file(name: &str) -> Result<Self, Error> {
        // `File::open`, not `File::create`: creating would truncate the archive
        // and hand back a write-only handle that cannot be read from.
        let source = std::fs::File::open(name)?;
        let mut seekable = Seekable::make_seekable(source)?;
        seekable.init_advanced()?;
        Ok(seekable)
    }
}
/// Checks the XXH64 digest of a single byte against a known value, then
/// decompresses the hashed section of a recorded pijul change from offset 0.
#[test]
fn pijul_change() {
    use serde_derive::*;

    /// Fixed-size header at the start of the change file.
    #[derive(Deserialize)]
    pub struct Offsets {
        pub version: u64,
        pub hashed_len: u64,
        pub unhashed_off: u64,
        pub unhashed_len: u64,
        pub contents_off: u64,
        pub contents_len: u64,
        pub total: u64,
    }

    let mut hasher = xxhash_rust::xxh64::Xxh64::new(0);
    hasher.update(&[4]);
    let expected = 0x64b9da3ed69d6732;
    let digest = hasher.finish();
    assert_eq!(expected, digest);

    let change = include_bytes!(
        "../../.pijul/changes/IH/334Q5ACWE4TNQYYOOF6GWV6CRXOEM6542NVNPA6HRIZ3CBFKEAC.change"
    );
    // The header is deserialized from the first `size_of::<Offsets>()` bytes;
    // the compressed hashed section runs from there up to `unhashed_off`.
    let header_len = std::mem::size_of::<Offsets>();
    let offsets: Offsets = bincode::deserialize(&change[..header_len]).unwrap();
    let hashed_end = offsets.unhashed_off as usize;
    let mut seekable = Seekable::init_buf(&change[header_len..hashed_end]).unwrap();
    let mut decoded = vec![0; offsets.hashed_len as usize];
    seekable.decompress(&mut decoded[..], 0).unwrap();
}