#ifndef SHISHUA_SSE2_H
#define SHISHUA_SSE2_H
#include <stdint.h>
#include <stddef.h>
#include <assert.h>
/*
 * SHISHUA_ALIGNR_EPI8(high, low, shift): treat high:low as one 256-bit value,
 * shift it right by `shift` bytes and keep the low 128 bits.
 *
 * Without SSSE3/AVX the operation is emulated with two SSE2 whole-register
 * byte shifts combined by OR; that emulation agrees with palignr for shifts
 * 0..16 (this file only uses 4 and 12).  Pass constant shift amounts: the
 * underlying intrinsics expect immediate operands.
 */
#if !defined(__SSSE3__) && !defined(__AVX__)
# include <emmintrin.h>
# define SHISHUA_ALIGNR_EPI8(high, low, shift) \
    _mm_or_si128(_mm_slli_si128((high), 16 - (shift)), \
                 _mm_srli_si128((low), (shift)))
#else
# include <tmmintrin.h>
# define SHISHUA_ALIGNR_EPI8(high, low, shift) \
    _mm_alignr_epi8((high), (low), (shift))
#endif
/*
 * Complete generator state, mutated in place by prng_init and prng_gen.
 * The field layout is relied on by both functions; do not reorder.
 */
typedef struct prng_state {
__m128i state[8];   /* 8 x 128-bit mixing words; prng_gen updates them as two groups (state[0..3], state[4..7]) */
__m128i output[8];  /* next 128-byte output block: prng_gen copies it to buf (if non-NULL), then recomputes it */
__m128i counter[2]; /* step counter: counter[0]/[1] are added to state[4j+2]/state[4j+3] each step; zeroed by prng_init */
} prng_state;
/*
 * 64-bit-lane vector constructors.
 *
 *   SHISHUA_SET_EPI64X(high, low)  -> vector whose upper 64-bit lane is `high`
 *                                     and lower 64-bit lane is `low`.
 *   SHISHUA_CVTSI64_SI128(value)   -> `value` in the lower lane, zero above.
 *
 * x86-64 targets use the native intrinsics.  Other targets (such as 32-bit
 * x86) assemble the same vectors from four 32-bit halves.  That fallback
 * evaluates each argument more than once, so pass side-effect-free
 * expressions only.
 */
#if !defined(__x86_64__) && !defined(_M_X64)
# define SHISHUA_SET_EPI64X(high, low) \
    _mm_set_epi32((int)(((uint64_t)(high)) >> 32), (int)(high), \
                  (int)(((uint64_t)(low)) >> 32), (int)(low))
# define SHISHUA_CVTSI64_SI128(value) SHISHUA_SET_EPI64X(0, value)
#else
# define SHISHUA_SET_EPI64X(high, low) _mm_set_epi64x(high, low)
# define SHISHUA_CVTSI64_SI128(value) _mm_cvtsi64_si128(value)
#endif
/*
 * SHISHUA_RESTRICT: "this pointer does not alias" qualifier used on
 * prng_gen's parameters.
 *
 * GCC-compatible compilers and MSVC get their __restrict extension.  Any
 * other C99-or-later C compiler gets the standard `restrict` keyword; it
 * previously got no qualifier at all.  C++ is excluded from that branch
 * because `restrict` is not a C++ keyword.  All remaining compilers expand
 * to nothing.
 */
#if defined(__GNUC__) || defined(_MSC_VER)
# define SHISHUA_RESTRICT __restrict
#elif !defined(__cplusplus) && defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
# define SHISHUA_RESTRICT restrict
#else
# define SHISHUA_RESTRICT
#endif
/*
 * Write `size` bytes of random output to `buf` and advance the generator.
 *
 * Each 128-byte step does two things:
 *   1. It copies the current output block (s->output, as left by the
 *      previous call or by prng_init) into `buf`.
 *   2. It updates s->state and the counter, then computes the next
 *      s->output.
 *
 * Pass buf == NULL to advance the generator by ceil(size / 128) steps
 * without writing anything.
 *
 * `size` must be a multiple of 128; this is checked with assert().  When
 * assertions are compiled out (NDEBUG), the old loop stored a full 128 bytes
 * for a trailing partial block and so wrote past buf + size.  Now a trailing
 * partial block is still consumed as one whole step, but only its first
 * size % 128 bytes are written.  Nothing is ever stored beyond buf[size - 1].
 *
 * s and buf are SHISHUA_RESTRICT-qualified: they must not overlap.
 */
static inline void prng_gen(prng_state *SHISHUA_RESTRICT s, uint8_t *SHISHUA_RESTRICT buf, size_t size) {
  __m128i counter_lo = s->counter[0], counter_hi = s->counter[1];
  /* Odd per-64-bit-lane increments, (low lane, high lane):
   * (7, 5) for counter[0] and (3, 1) for counter[1]. */
  __m128i increment_lo = SHISHUA_SET_EPI64X(5, 7);
  __m128i increment_hi = SHISHUA_SET_EPI64X(1, 3);
  /* Number of steps, ceil(size / 128).  Counting blocks instead of a running
   * `i += 128` avoids wrap-around for sizes close to SIZE_MAX. */
  size_t blocks = size / 128 + (size % 128 != 0);

  assert((size % 128 == 0) && "buf's size must be a multiple of 128 bytes.");

  for (size_t b = 0; b < blocks; b++) {
    if (buf != NULL) {
      size_t offset = b * 128;
      size_t remaining = size - offset;
      if (remaining >= 128) {
        for (size_t j = 0; j < 8; j++) {
          _mm_storeu_si128((__m128i *)&buf[offset + (16 * j)], s->output[j]);
        }
      } else {
        /* Trailing partial block (only reachable with NDEBUG): stage it
         * locally so no store goes past the caller's `size`. */
        uint8_t tail[128];
        for (size_t j = 0; j < 8; j++) {
          _mm_storeu_si128((__m128i *)&tail[16 * j], s->output[j]);
        }
        for (size_t k = 0; k < remaining; k++) {
          buf[offset + k] = tail[k];
        }
      }
    }

    /* Two independent groups of four state words:
     * j == 0 updates state[0..3], j == 1 updates state[4..7]. */
    for (size_t j = 0; j < 2; j++) {
      __m128i s_lo, s_hi, u0_lo, u0_hi, u1_lo, u1_hi, t_lo, t_hi;

      /* Lane pair 0: state[4j+0], state[4j+1]. */
      s_lo = s->state[4 * j + 0];
      s_hi = s->state[4 * j + 1];
      u0_lo = _mm_srli_epi64(s_lo, 1);
      u0_hi = _mm_srli_epi64(s_hi, 1);
      /* Permute 32-bit words across the two halves (4-byte alignr). */
      t_lo = SHISHUA_ALIGNR_EPI8(s_lo, s_hi, 4);
      t_hi = SHISHUA_ALIGNR_EPI8(s_hi, s_lo, 4);
      s->state[4 * j + 0] = _mm_add_epi64(t_lo, u0_lo);
      s->state[4 * j + 1] = _mm_add_epi64(t_hi, u0_hi);

      /* Lane pair 1: state[4j+2], state[4j+3], with the counter added first. */
      s_lo = s->state[4 * j + 2];
      s_hi = s->state[4 * j + 3];
      s_lo = _mm_add_epi64(s_lo, counter_lo);
      s_hi = _mm_add_epi64(s_hi, counter_hi);
      u1_lo = _mm_srli_epi64(s_lo, 3);
      u1_hi = _mm_srli_epi64(s_hi, 3);
      /* Permute 32-bit words across the two halves (12-byte alignr). */
      t_lo = SHISHUA_ALIGNR_EPI8(s_hi, s_lo, 12);
      t_hi = SHISHUA_ALIGNR_EPI8(s_lo, s_hi, 12);
      s->state[4 * j + 2] = _mm_add_epi64(t_lo, u1_lo);
      s->state[4 * j + 3] = _mm_add_epi64(t_hi, u1_hi);

      /* t_lo/t_hi now hold lane pair 1's permutation, so this mixes the
       * shifted lane-0 words with the permuted lane-1 words. */
      s->output[2 * j + 0] = _mm_xor_si128(u0_lo, t_lo);
      s->output[2 * j + 1] = _mm_xor_si128(u0_hi, t_hi);
    }

    /* Second half of the output block: XOR of updated state words taken
     * from both groups. */
    s->output[4] = _mm_xor_si128(s->state[0], s->state[6]);
    s->output[5] = _mm_xor_si128(s->state[1], s->state[7]);
    s->output[6] = _mm_xor_si128(s->state[4], s->state[2]);
    s->output[7] = _mm_xor_si128(s->state[5], s->state[3]);

    counter_lo = _mm_add_epi64(counter_lo, increment_lo);
    counter_hi = _mm_add_epi64(counter_hi, increment_hi);
  }
  s->counter[0] = counter_lo;
  s->counter[1] = counter_hi;
}
/*
 * Fixed 1024-bit constant that prng_init XORs into the initial state, two
 * words per 128-bit state word.
 *
 * The first word, 0x9E3779B97F4A7C15, is the well-known 64-bit golden-ratio
 * constant floor(2^64 / phi).  The remaining words presumably continue the
 * hexadecimal expansion of (sqrt(5) - 1) / 2.  NOTE(review): those later
 * words were not re-derived here.
 *
 * The table is declared const because it is only ever read.  This lets it
 * live in read-only storage and prevents accidental modification by any
 * translation unit that includes this header.
 */
static const uint64_t phi[16] = {
  0x9E3779B97F4A7C15, 0xF39CC0605CEDC834, 0x1082276BF3A27251, 0xF86C6A11D0C18E95,
  0x2767F0B153D27B7F, 0x0347045B5BF1827F, 0x01886F0928403002, 0xC1D64BA40F335E36,
  0xF06AD7AE9717877E, 0x85839D6EFFBD7DC6, 0x64D325D1C5371682, 0xCADD0CCCFDFFBBE1,
  0x626E33B8D04B4331, 0xBBF73C790D94F79D, 0x471C4AB3ED3D82A5, 0xFEC507705E4AE6E5,
};
/*
 * Seed the generator from four 64-bit words.
 *
 * Setup proceeds as follows:
 *   - The counter is cleared.
 *   - Each 128-bit state word is set to one seed word (in its low 64 bits,
 *     zero above) XORed with two consecutive phi entries.  The high 64 bits
 *     of every state word therefore come straight from phi.
 *   - Every seed word is used twice: seed[0] -> state[0], state[6];
 *     seed[1] -> state[1], state[7]; seed[2] -> state[2], state[4];
 *     seed[3] -> state[3], state[5].
 *   - The generator is then stepped 13 times with buf == NULL.  After each
 *     step the freshly computed output block becomes the new state, in the
 *     permuted order below.
 *
 * s->output need not be initialised beforehand.  prng_gen reads it only
 * when buf != NULL, and every step overwrites all eight output words.
 *
 * Declared static inline (like prng_gen) so that including this header from
 * several translation units does not yield duplicate external definitions.
 * seed is const-qualified because it is only read; callers passing a plain
 * uint64_t array are unaffected.
 */
static inline void prng_init(prng_state *s, const uint64_t seed[4]) {
  /* Block-scope constants rather than #define ROUNDS / STEPS: the header no
   * longer redefines or #undefs caller macros with those common names. */
  enum { SHISHUA_INIT_ROUNDS = 13, SHISHUA_INIT_STEPS = 1 };

  s->counter[0] = _mm_setzero_si128();
  s->counter[1] = _mm_setzero_si128();

  __m128i seed_0 = SHISHUA_CVTSI64_SI128(seed[0]);
  __m128i seed_1 = SHISHUA_CVTSI64_SI128(seed[1]);
  __m128i seed_2 = SHISHUA_CVTSI64_SI128(seed[2]);
  __m128i seed_3 = SHISHUA_CVTSI64_SI128(seed[3]);
  s->state[0] = _mm_xor_si128(seed_0, _mm_loadu_si128((const __m128i *)&phi[ 0]));
  s->state[1] = _mm_xor_si128(seed_1, _mm_loadu_si128((const __m128i *)&phi[ 2]));
  s->state[2] = _mm_xor_si128(seed_2, _mm_loadu_si128((const __m128i *)&phi[ 4]));
  s->state[3] = _mm_xor_si128(seed_3, _mm_loadu_si128((const __m128i *)&phi[ 6]));
  s->state[4] = _mm_xor_si128(seed_2, _mm_loadu_si128((const __m128i *)&phi[ 8]));
  s->state[5] = _mm_xor_si128(seed_3, _mm_loadu_si128((const __m128i *)&phi[10]));
  s->state[6] = _mm_xor_si128(seed_0, _mm_loadu_si128((const __m128i *)&phi[12]));
  s->state[7] = _mm_xor_si128(seed_1, _mm_loadu_si128((const __m128i *)&phi[14]));

  for (int round = 0; round < SHISHUA_INIT_ROUNDS; round++) {
    prng_gen(s, NULL, 128 * SHISHUA_INIT_STEPS);
    /* Feed the output block back in as the new state, permuted. */
    s->state[0] = s->output[6]; s->state[1] = s->output[7];
    s->state[2] = s->output[4]; s->state[3] = s->output[5];
    s->state[4] = s->output[2]; s->state[5] = s->output[3];
    s->state[6] = s->output[0]; s->state[7] = s->output[1];
  }
}
#undef SHISHUA_CVTSI64_SI128
#undef SHISHUA_ALIGNR_EPI8
#undef SHISHUA_SET_EPI64X
#undef SHISHUA_RESTRICT
#endif