#ifndef SHISHUA_HALF_H
#define SHISHUA_HALF_H
// Identifiers for the available implementations. SHISHUA_TARGET is expected
// to hold one of these values; it decides which implementation gets compiled.
#define SHISHUA_TARGET_SCALAR 0
#define SHISHUA_TARGET_AVX2 1
#define SHISHUA_TARGET_SSE2 2
#define SHISHUA_TARGET_NEON 3
// Auto-detect a target unless the includer defined SHISHUA_TARGET beforehand.
#ifndef SHISHUA_TARGET
// AVX2 is picked only for 64-bit x86 builds that have AVX2 enabled.
# if defined(__AVX2__) && (defined(__x86_64__) || defined(_M_X64))
# define SHISHUA_TARGET SHISHUA_TARGET_AVX2
// SSE2 is picked for every other 64-bit x86 build, and for builds where the
// compiler advertises SSE2 through __SSE2__ or _M_IX86_FP >= 2.
# elif defined(__x86_64__) || defined(_M_X64) || defined(__SSE2__) \
|| (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
# define SHISHUA_TARGET SHISHUA_TARGET_SSE2
// NEON is picked only when the compiler is clang; other compilers targeting
// NEON fall through to the scalar implementation.
// NOTE(review): the reason for the clang-only restriction is not visible in
// this file; confirm before relaxing it.
# elif (defined(__ARM_NEON) || defined(__ARM_NEON__)) && defined(__clang__)
# define SHISHUA_TARGET SHISHUA_TARGET_NEON
// Everything else uses the portable scalar implementation.
# else
# define SHISHUA_TARGET SHISHUA_TARGET_SCALAR
# endif
#endif
#if SHISHUA_TARGET == SHISHUA_TARGET_AVX2
# include "shishua-half-avx2.h"
#elif SHISHUA_TARGET == SHISHUA_TARGET_SSE2
# include "shishua-half-sse2.h"
#elif SHISHUA_TARGET == SHISHUA_TARGET_NEON
# include "shishua-half-neon.h"
#else
#include <stdint.h>
#include <stddef.h>
#include <string.h>
#include <assert.h>
// Complete generator state for the scalar implementation: sixteen 64-bit
// words split across three arrays.
typedef struct prng_state {
  // Main mixing state. prng_gen processes it as two groups of four words:
  // state[0..3] and state[4..7].
  uint64_t state[8];
  // The 32 bytes that the next prng_gen step writes out (little-endian),
  // before being recomputed from the state.
  uint64_t output[4];
  // Per-word counters. Each step adds them to state[4..7] and then bumps
  // them by 7, 5, 3 and 1 respectively.
  uint64_t counter[4];
} prng_state;
// SHISHUA_RESTRICT: pointer qualifier promising the compiler that the pointee
// is not accessed through any other pointer. It is applied to prng_gen's
// parameters and #undef'd again right after that function.
//   - GNU-compatible compilers and MSVC: the __restrict extension, which is
//     accepted regardless of the selected language standard.
//   - Any other compiler in C99 mode or later: the standard `restrict`
//     keyword (__STDC_VERSION__ is not defined when compiling as C++, so this
//     branch is never taken there).
//   - Otherwise: expands to nothing, which is always valid.
#if defined(__GNUC__) || defined(_MSC_VER)
# define SHISHUA_RESTRICT __restrict
#elif defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
# define SHISHUA_RESTRICT restrict
#else
# define SHISHUA_RESTRICT
#endif
// Stores the 64-bit value `val` into the 8 bytes at `dst`, least significant
// byte first. `dst` needs no particular alignment.
//
// When the build is known to be little-endian (or SHISHUA_NATIVE_ENDIAN is
// defined, in which case the native byte order is used as-is), the value is
// copied with a single memcpy. Otherwise the bytes are emitted one at a time.
static inline void shishua_write_le64(void *dst, uint64_t val) {
#if !(defined(SHISHUA_NATIVE_ENDIAN) \
   || defined(_WIN32) \
   || (defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__) \
   || defined(__LITTLE_ENDIAN__))
  // Portable path: byte k receives bits [8k, 8k + 7] of val.
  uint8_t *out = (uint8_t *)dst;
  for (unsigned shift = 0; shift < 64; shift += 8) {
    *out++ = (uint8_t)(val >> shift);
  }
#else
  // Native layout already matches (or was explicitly requested).
  memcpy(dst, &val, sizeof val);
#endif
}
// Runs the generator for size / 32 steps. Each step optionally emits the 32
// bytes currently held in state->output and then advances the state.
//
//   state  generator state; updated in place.
//   buf    destination for the generated bytes, or NULL to advance the state
//          without writing anything (prng_init uses this to discard output).
//   size   number of bytes to produce; must be a multiple of 32.
//
// The size requirement is enforced with assert() only, so it is not checked
// when NDEBUG is defined. The loop always works in whole 32-byte steps.
static inline void prng_gen(prng_state *SHISHUA_RESTRICT state, uint8_t *SHISHUA_RESTRICT buf, size_t size) {
  assert((size % 32 == 0) && "buf's size must be a multiple of 32 bytes.");

  for (size_t offset = 0; offset < size; offset += 32) {
    // Amounts added to counter[0..3] on every step.
    const uint64_t counter_step[4] = { 7, 5, 3, 1 };
    // Shuffle tables: word k of the shuffled state takes its low 32 bits from
    // the high half of state[from_high[k]] and its high 32 bits from the low
    // half of state[from_low[k]]. Treating each group of four words as one
    // 256-bit little-endian value, this rotates state[0..3] by five 32-bit
    // words and state[4..7] by three 32-bit words.
    const uint8_t from_high[8] = { 2,3,0,1, 5,6,7,4 };
    const uint8_t from_low[8]  = { 3,0,1,2, 6,7,4,5 };
    uint64_t *const s = state->state;
    uint64_t shuffled[8];

    // Emit the output computed by the previous step (little-endian words).
    if (buf != NULL) {
      uint8_t *const dst = buf + offset;
      for (size_t k = 0; k < 4; k++) {
        shishua_write_le64(dst + 8 * k, state->output[k]);
      }
    }

    // Fold the counters into the second group of state words, then advance
    // the counters.
    for (size_t k = 0; k < 4; k++) {
      s[k + 4] += state->counter[k];
      state->counter[k] += counter_step[k];
    }

    // Build the shuffled copy from the state as it stands after the counter
    // addition above.
    for (size_t k = 0; k < 8; k++) {
      const uint64_t low_bits = s[from_high[k]] >> 32;
      const uint64_t high_bits = s[from_low[k]] << 32;
      shuffled[k] = low_bits | high_bits;
    }

    // Combine: each group is shifted right (by 1 and by 3 respectively) and
    // added to its shuffled counterpart; the next output XORs the shifted
    // first group with the shuffled second group.
    for (size_t k = 0; k < 4; k++) {
      const uint64_t shifted_lo = s[k] >> 1;
      const uint64_t shifted_hi = s[k + 4] >> 3;
      state->output[k] = shifted_lo ^ shuffled[k + 4];
      s[k] = shifted_lo + shuffled[k];
      s[k + 4] = shifted_hi + shuffled[k + 4];
    }
  }
}
#undef SHISHUA_RESTRICT
// Initial contents of the eight state words: prng_init copies this table into
// prng_state.state before XORing in the seed.
// NOTE(review): presumably the fractional hex digits of the golden ratio (the
// name and the first word 0x9E3779B97F4A7C15 match) -- verify against the
// upstream SHISHUA sources.
// Declared const because the table is only ever read; this keeps every
// translation unit that includes the header from carrying a writable copy.
static const uint64_t phi[8] = {
  0x9E3779B97F4A7C15, 0xF39CC0605CEDC834, 0x1082276BF3A27251, 0xF86C6A11D0C18E95,
  0x2767F0B153D27B7F, 0x0347045B5BF1827F, 0x01886F0928403002, 0xC1D64BA40F335E36,
};
// Initializes *s from a 256-bit seed.
//
//   s     state to initialize; every field is overwritten.
//   seed  four 64-bit seed words; only read.
//
// Procedure:
//   1. Zero the whole structure (so output and counter start at 0) and load
//      the phi table into state[0..7].
//   2. XOR seed[i] into state[2 * i]. Only the even-indexed state words
//      receive seed material; the odd-indexed words keep their phi values.
//   3. Run four mixing rounds. Each round advances the generator by five
//      32-byte steps with the output discarded, then moves state[4..7] into
//      state[0..3] and the latest output into state[4..7].
//
// NOTE(review): this function has external linkage although it is defined in
// a header; including the header from more than one translation unit of the
// same program would presumably produce duplicate-definition link errors.
// Linkage is left untouched here so existing callers keep working.
void prng_init(prng_state *s, uint64_t seed[4]) {
  // Block-scoped constants rather than #define/#undef pairs: defining and
  // then undefining macros named STEPS and ROUNDS inside a header would
  // clobber macros of the same name belonging to the including program.
  enum {
    shishua_init_rounds = 4, // number of mixing rounds
    shishua_init_steps = 5   // 32-byte generator steps discarded per round
  };

  memset(s, 0, sizeof(prng_state));
  memcpy(s->state, phi, sizeof(phi));

  for (size_t i = 0; i < 4; i++) {
    s->state[i * 2] ^= seed[i];
  }

  for (size_t i = 0; i < shishua_init_rounds; i++) {
    // NULL buffer: advance the state without emitting any bytes.
    prng_gen(s, NULL, 32 * shishua_init_steps);
    for (size_t j = 0; j < 4; j++) {
      s->state[j + 0] = s->state[j + 4];
      s->state[j + 4] = s->output[j];
    }
  }
}
#endif #endif