#ifndef SHISHUA_NEON_H
#define SHISHUA_NEON_H
#include <stdint.h>
#include <stddef.h>
#include <assert.h>
#include <arm_neon.h>
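// The wide state is kept as NEON 128-bit vectors:
// - state:   four 256-bit state rows, each split into a (lo, hi) vector pair;
// - output:  the 128 bytes produced by the latest round;
// - counter: four 64-bit counters, split across two vectors.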
typedef struct prng_state {
  uint64x2_t state[8];
  uint64x2_t output[8];
  uint64x2_t counter[2];
} prng_state;
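// Build a uint64x2_t from two 64-bit constants (lane 0 = a, lane 1 = b).
// On little-endian GCC/Clang a compound literal is enough; otherwise fall
// back to the vdup/vcombine intrinsics.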
#if defined(__GNUC__) && (defined(__BYTE_ORDER__) && __BYTE_ORDER__==__ORDER_LITTLE_ENDIAN__)
# define SHISHUA_VSETQ_N_U64(a, b) (__extension__(uint64x2_t) { a, b })
#else
# define SHISHUA_VSETQ_N_U64(a, b) vcombine_u64(vdup_n_u64(a), vdup_n_u64(b))
#endif
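// Byte-wise extract across two 64-bit vectors. It is used below to rotate
// the 32-bit lanes of each 256-bit state row, one half at a time.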
#define SHISHUA_VEXTQ_U8(Rn, Rm, Imm) \
  vreinterpretq_u64_u8(               \
    vextq_u8(                         \
      vreinterpretq_u8_u64(Rn),       \
      vreinterpretq_u8_u64(Rm),       \
      (Imm)                           \
    )                                 \
  )
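// `restrict` tells the compiler that the state and the output buffer do not
// alias, which can enable better scheduling of the stores in the hot loop.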
#if defined(__GNUC__) || defined(_MSC_VER)
# define SHISHUA_RESTRICT __restrict
#else
# define SHISHUA_RESTRICT
#endif
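// Fill `buf` with `size` bytes of pseudorandom output.
// `size` must be a multiple of 128 bytes. `buf` may be NULL, in which case
// the state is advanced without writing output (used during initialization).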
static inline void prng_gen(prng_state *SHISHUA_RESTRICT s, uint8_t *SHISHUA_RESTRICT buf, size_t size) {
  uint8_t *b = buf;
  uint64x2_t counter_lo = s->counter[0], counter_hi = s->counter[1];
  // Per-lane counter increments (7, 5, 3, 1): odd, so each 64-bit lane only
  // repeats after 2^64 steps.
  uint64x2_t increment_lo = SHISHUA_VSETQ_N_U64(7, 5);
  uint64x2_t increment_hi = SHISHUA_VSETQ_N_U64(3, 1);
  assert((size % 128 == 0) && "buf's size must be a multiple of 128 bytes.");
  for (size_t i = 0; i < size; i += 128) {
    // Flush the 128 output bytes produced by the previous iteration.
    // buf == NULL means "advance the state only" (used by prng_init).
    if (buf != NULL) {
      for (size_t j = 0; j < 8; j++) {
        vst1q_u8(b, vreinterpretq_u8_u64(s->output[j]));
        b += 16;
      }
    }
    // Each half j handles two 256-bit state rows, kept as (lo, hi) vector pairs.
    for (size_t j = 0; j < 2; j++) {
      uint64x2_t s0_lo = s->state[j * 4 + 0],
                 s0_hi = s->state[j * 4 + 1],
                 s1_lo = s->state[j * 4 + 2],
                 s1_hi = s->state[j * 4 + 3],
                 t0_lo, t0_hi, t1_lo, t1_hi,
                 u_lo, u_hi;
      // Mix the per-lane counters into s1.
      s1_lo = vaddq_u64(s1_lo, counter_lo); s1_hi = vaddq_u64(s1_hi, counter_hi);
      // Rotate the 32-bit lanes of each 256-bit row, half by half.
      t0_lo = SHISHUA_VEXTQ_U8(s0_hi, s0_lo, 4); t0_hi = SHISHUA_VEXTQ_U8(s0_lo, s0_hi, 4);
      t1_lo = SHISHUA_VEXTQ_U8(s1_lo, s1_hi, 12); t1_hi = SHISHUA_VEXTQ_U8(s1_hi, s1_lo, 12);
      // u = s0 >> 1 on each 64-bit lane.
      u_lo = vshrq_n_u64(s0_lo, 1); u_hi = vshrq_n_u64(s0_hi, 1);
#if defined(__clang__)
      // Empty asm with register operands acts as an optimization barrier,
      // working around suboptimal clang code generation around these shifts.
      __asm__("" : "+w" (u_lo), "+w" (u_hi));
#endif
      // New state: shifted values added to the rotated ones
      // (vsraq_n computes t1 + (s1 >> 3) in a single instruction).
      s->state[4 * j + 0] = vaddq_u64(t0_lo, u_lo);
      s->state[4 * j + 1] = vaddq_u64(t0_hi, u_hi);
      s->state[4 * j + 2] = vsraq_n_u64(t1_lo, s1_lo, 3);
      s->state[4 * j + 3] = vsraq_n_u64(t1_hi, s1_hi, 3);
      // First half of the output: shifted state XOR rotated state.
      s->output[2 * j + 0] = veorq_u64(u_lo, t1_lo);
      s->output[2 * j + 1] = veorq_u64(u_hi, t1_hi);
    }
    // Second half of the output: XOR state rows across the two halves.
    s->output[4] = veorq_u64(s->state[0], s->state[6]);
    s->output[5] = veorq_u64(s->state[1], s->state[7]);
    s->output[6] = veorq_u64(s->state[2], s->state[4]);
    s->output[7] = veorq_u64(s->state[3], s->state[5]);
    counter_lo = vaddq_u64(counter_lo, increment_lo);
    counter_hi = vaddq_u64(counter_hi, increment_hi);
  }
  s->counter[0] = counter_lo;
  s->counter[1] = counter_hi;
}
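// Nothing-up-my-sleeve constants: hexadecimal digits of the golden ratio.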
static uint64_t phi[16] = {
  0x9E3779B97F4A7C15, 0xF39CC0605CEDC834, 0x1082276BF3A27251, 0xF86C6A11D0C18E95,
  0x2767F0B153D27B7F, 0x0347045B5BF1827F, 0x01886F0928403002, 0xC1D64BA40F335E36,
  0xF06AD7AE9717877E, 0x85839D6EFFBD7DC6, 0x64D325D1C5371682, 0xCADD0CCCFDFFBBE1,
  0x626E33B8D04B4331, 0xBBF73C790D94F79D, 0x471C4AB3ED3D82A5, 0xFEC507705E4AE6E5,
};
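// Initialize the state from a 256-bit seed. The seed words are XORed into
// the phi constants, then the generator runs a number of throwaway rounds so
// the seed diffuses through the whole state.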
void prng_init(prng_state *s, uint64_t seed[4]) {
  s->counter[0] = vdupq_n_u64(0);
  s->counter[1] = vdupq_n_u64(0);
  // Warm-up parameters: ROUNDS throwaway rounds of STEPS * 128 bytes each.
# define ROUNDS 13
# define STEPS 1
  // Each 64-bit seed word goes in the low lane of a vector, then gets XORed
  // into every other 64-bit lane of the phi constants.
  uint64x2_t seed_0 = SHISHUA_VSETQ_N_U64(seed[0], 0);
  uint64x2_t seed_1 = SHISHUA_VSETQ_N_U64(seed[1], 0);
  uint64x2_t seed_2 = SHISHUA_VSETQ_N_U64(seed[2], 0);
  uint64x2_t seed_3 = SHISHUA_VSETQ_N_U64(seed[3], 0);
  s->state[0] = veorq_u64(seed_0, vld1q_u64(&phi[ 0]));
  s->state[1] = veorq_u64(seed_1, vld1q_u64(&phi[ 2]));
  s->state[2] = veorq_u64(seed_2, vld1q_u64(&phi[ 4]));
  s->state[3] = veorq_u64(seed_3, vld1q_u64(&phi[ 6]));
  s->state[4] = veorq_u64(seed_2, vld1q_u64(&phi[ 8]));
  s->state[5] = veorq_u64(seed_3, vld1q_u64(&phi[10]));
  s->state[6] = veorq_u64(seed_0, vld1q_u64(&phi[12]));
  s->state[7] = veorq_u64(seed_1, vld1q_u64(&phi[14]));
  // Warm-up: run the generator without emitting output, and feed the output
  // vectors back into the state (with the 256-bit rows in reverse order) so
  // that the seed spreads through all of it.
  for (int i = 0; i < ROUNDS; i++) {
    prng_gen(s, NULL, 128 * STEPS);
    s->state[0] = s->output[6]; s->state[1] = s->output[7];
    s->state[2] = s->output[4]; s->state[3] = s->output[5];
    s->state[4] = s->output[2]; s->state[5] = s->output[3];
    s->state[6] = s->output[0]; s->state[7] = s->output[1];
  }
# undef STEPS
# undef ROUNDS
}
#undef SHISHUA_VSETQ_N_U64
#undef SHISHUA_VEXTQ_U8
#undef SHISHUA_RESTRICT
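// A minimal usage sketch (illustrative only: the header's file name, the seed
// values, and the buffer size below are assumptions, not part of this API):
//
//   #include "shishua-neon.h"
//   #include <stdio.h>
//
//   int main(void) {
//     prng_state s;
//     uint64_t seed[4] = {1, 2, 3, 4};
//     uint8_t buf[256];               // any multiple of 128 bytes
//     prng_init(&s, seed);
//     prng_gen(&s, buf, sizeof(buf));
//     printf("%02x %02x\n", buf[0], buf[1]);
//     return 0;
//   }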
#endif