use rayon::prelude::*;

/// Report whether any entry of `indices` falls outside `vs`.
///
/// Note the polarity: the result is `true` when the bounds check *fails*
/// (at least one index is `>= vs.len()`), and `false` when every index can be
/// used. The scan runs on the rayon thread pool.
#[inline(always)]
fn validate_bounds<T>(vs: &[T], indices: &[usize]) -> bool {
    let len = vs.len();
    let every_index_fits = indices.par_iter().all(|&idx| idx < len);
    !every_index_fits
}

/// Sequentially store `value` at every position listed in `indices`.
///
/// Every index is checked before anything is written. If any index is
/// `>= slice.len()`, the slice is left untouched and `false` is returned.
/// Otherwise all writes are performed in order and `true` is returned.
/// Repeated indices are allowed; they simply write the same value again.
pub fn writes<T: Copy>(slice: &mut [T], value: T, indices: &[usize]) -> bool {
    let len = slice.len();
    let all_in_bounds = indices.iter().all(|&idx| idx < len);
    if all_in_bounds {
        indices.iter().for_each(|&idx| slice[idx] = value);
    }
    all_in_bounds
}

/// Sequentially store `val` at every position listed in `indices`, silently
/// skipping any index that is `>= vs.len()`.
///
/// In-bounds indices are written in the order they appear; out-of-bound ones
/// have no effect and are not reported.
pub fn writes_discard<T: Copy>(vs: &mut [T], val: T, indices: &[usize]) {
    for &idx in indices {
        // `get_mut` yields `None` exactly for out-of-bound positions.
        if let Some(slot) = vs.get_mut(idx) {
            *slot = val;
        }
    }
}

/// Raw `*mut T` wrapper that can be moved into rayon closures.
///
/// Raw pointers are neither `Send` nor `Sync`, so the parallel writers wrap the
/// slice's base pointer in this type to hand it to worker threads. The wrapper
/// performs no synchronization of its own. Every user must guarantee that:
/// - no two threads access the same element concurrently, and
/// - the pointee outlives the parallel section.
struct TmpRef<T>(*mut T);

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, even though copying the pointer never copies
// a `T`.
impl<T> Clone for TmpRef<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for TmpRef<T> {}

// SAFETY: sending the wrapper lets another thread write and read `T` values
// through the pointer, which requires `T: Send + Sync`. Exclusive access per
// element is the responsibility of the code that dereferences the pointer.
unsafe impl<T: Send + Sync> Send for TmpRef<T> {}
// SAFETY: a shared `&TmpRef<T>` only allows copying the pointer out. The same
// per-element exclusivity obligation applies to whoever dereferences it.
unsafe impl<T: Send + Sync> Sync for TmpRef<T> {}

/// Write `value` into `slice` with the given `indices` in parallel
pub fn writes_par<T: Copy + Send + Sync>(slice: &mut [T], value: T, indices: &[usize]) -> bool {
    let oob = validate_bounds(slice, indices);
    if oob {
        return false;
    }

    let tref = TmpRef(slice.as_mut_ptr());
    indices.par_iter().for_each(move |i| {
        let tmp = tref.clone();
        unsafe {
            let p = tmp.0.add(*i);
            p.write(value);
        }
    });
    true
}

/// Split `indices` into four consecutive, non-overlapping sub-slices.
///
/// The first three parts each hold `indices.len() / 4` elements. The fourth
/// part takes everything that is left, including the remainder of the
/// division. With fewer than four elements, the first three parts are empty.
fn partition(indices: &[usize]) -> [&[usize]; 4] {
    let quarter = indices.len() / 4;
    let (first, rest) = indices.split_at(quarter);
    let (second, rest) = rest.split_at(quarter);
    let (third, fourth) = rest.split_at(quarter);
    [first, second, third, fourth]
}

/// Write `value` into `slice` with the given `indices` in parallel, discarding
/// the writes for any out of bound indexes
pub fn writes_par_shard_discard<T: Copy + Send + Sync>(
    slice: &mut [T],
    value: T,
    indices: &[usize],
) {
    let size = slice.len();

    let tref = TmpRef(slice.as_mut_ptr());
    if size <= 4 {
        for v in slice.iter_mut() {
            *v = value;
        }
    } else {
        let shards = partition(indices);
        shards.par_iter().for_each(move |shard| {
            let tmp = tref.clone();
            let slice = unsafe { std::slice::from_raw_parts_mut(tmp.0, size) };
            writes_discard(slice, value, shard);
        });
    }
}

/// Write `value` into `slice` at every in-bounds position in `indices`, in
/// parallel. Out-of-bound indices are silently discarded.
///
/// Out-of-bound entries and duplicates are filtered out before the parallel
/// phase. Duplicate indices would otherwise let two tasks write the same
/// element at once, which is a data race (undefined behavior) even when both
/// store the same value.
pub fn writes_par_discard<T: Copy + Send + Sync>(slice: &mut [T], value: T, indices: &[usize]) {
    let size = slice.len();

    let mut targets: Vec<usize> = indices
        .par_iter()
        .copied()
        .filter(|&i| i < size)
        .collect();
    targets.par_sort_unstable();
    targets.dedup();

    let tref = TmpRef(slice.as_mut_ptr());
    targets.par_iter().for_each(move |&i| {
        // Copy the whole wrapper so the closure captures `TmpRef` (Send + Sync)
        // rather than only its raw-pointer field (edition-2021 disjoint capture).
        let base = tref;
        // SAFETY: `i < size` was checked when building `targets`, so `add(i)`
        // stays inside the slice. `targets` is deduplicated, so each element has
        // exactly one writer. `slice` is exclusively borrowed until `for_each`
        // returns.
        unsafe { base.0.add(i).write(value) };
    });
}