package pijul
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"golang.org/x/exp/constraints"
)
// parser is the function shape shared by every combinator in this file. It
// reads from the front of input and returns the bytes it did not consume
// (rest), the parsed value, and a non-nil err when parsing failed.
type parser[T any] func(input []byte) (rest []byte, value T, err error)
// takeUntil builds a parser that yields everything in front of the first
// occurrence of tag. The tag itself is not consumed: it stays at the head of
// the returned rest. When tag does not occur, the parser fails and hands the
// input back untouched.
func takeUntil(tag string) parser[[]byte] {
	needle := []byte(tag)
	return func(input []byte) ([]byte, []byte, error) {
		before, _, found := bytes.Cut(input, needle)
		if !found {
			return input, nil, fmt.Errorf("not found: %q", tag)
		}
		// Re-slice from the match position so the tag remains in rest.
		return input[len(before):], before, nil
	}
}
// alt builds a parser that tries each option in order against the same input
// and returns the result of the first option that succeeds.
//
// If every option fails, the original input is returned together with the
// value and error produced by the last option. If alt was given no options at
// all, the parser always fails: there is nothing that could have matched.
func alt[T any](options ...parser[T]) parser[T] {
	return func(input []byte) (rest []byte, value T, err error) {
		// Without this guard err would stay nil below and an empty choice
		// would be reported as a successful parse of the zero value.
		if len(options) == 0 {
			return input, value, errors.New("alt: no alternatives to try")
		}
		for _, option := range options {
			rest, value, err = option(input)
			if err == nil {
				return rest, value, nil
			}
		}
		return input, value, err
	}
}
// mapValue wraps p so that a successfully parsed value is transformed by f.
// When p fails, f is not called and the parser returns the original input, the
// zero value of U, and p's error.
func mapValue[T, U any](p parser[T], f func(T) U) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		remaining, parsed, err := p(input)
		if err == nil {
			return remaining, f(parsed), nil
		}
		var zero U
		return input, zero, err
	}
}
// mapWithError wraps p so that a successfully parsed value is converted by the
// fallible function f.
//
// When p fails, the parser returns the original input, the zero value of U,
// and p's error. When p succeeds, whatever f returns (value and error) is
// passed through together with p's rest, so a failing conversion still
// reports the input as consumed.
func mapWithError[T, U any](p parser[T], f func(T) (U, error)) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		remaining, parsed, err := p(input)
		if err != nil {
			var zero U
			return input, zero, err
		}
		converted, convErr := f(parsed)
		return remaining, converted, convErr
	}
}
// tag builds a parser that matches the literal t at the very start of the
// input. On success it consumes the literal and returns t; otherwise it fails
// and returns the input unchanged.
func tag(t string) parser[string] {
	prefix := []byte(t)
	return func(input []byte) ([]byte, string, error) {
		if !bytes.HasPrefix(input, prefix) {
			return input, "", fmt.Errorf("not found: %q", t)
		}
		return input[len(prefix):], t, nil
	}
}
// value builds a parser that runs p only to decide success and to consume
// input; p's own result is discarded and val is returned instead. When p
// fails, the parser returns the original input, the zero value of T, and p's
// error.
func value[T, U any](val T, p parser[U]) parser[T] {
	return func(input []byte) ([]byte, T, error) {
		remaining, _, err := p(input)
		if err == nil {
			return remaining, val, nil
		}
		var zero T
		return input, zero, err
	}
}
// takeAny builds a parser that consumes the longest leading run of characters
// contained in set (interpreted the way bytes.TrimLeft interprets its cutset).
// The run may be empty, so the parser never fails.
func takeAny(set string) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining := bytes.TrimLeft(input, set)
		consumed := len(input) - len(remaining)
		return remaining, input[:consumed], nil
	}
}
// takeAny1 builds a parser that consumes the longest leading run of characters
// contained in set (interpreted the way bytes.TrimLeft interprets its cutset)
// and requires that run to be non-empty. When nothing matches, the parser
// fails and returns the input unchanged.
func takeAny1(set string) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining := bytes.TrimLeft(input, set)
		consumed := len(input) - len(remaining)
		if consumed == 0 {
			return input, nil, fmt.Errorf("nothing matching %q was found", set)
		}
		return remaining, input[:consumed], nil
	}
}
func positiveInt(input []byte) (rest []byte, value int, err error) {
return mapWithError(
takeAny1("0123456789"),
func(b []byte) (int, error) {
return strconv.Atoi(string(b))
},
)(input)
}
// delimited builds a parser that runs left, inner and right in sequence and
// keeps only inner's value; the results of left and right are discarded.
//
// The first failing step aborts the sequence. In that case the rest reported
// by the failing step is returned as-is (the input is not rewound), along
// with whatever value has been produced so far.
func delimited[T, U, V any](left parser[T], inner parser[U], right parser[V]) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		var result U
		cursor, _, err := left(input)
		if err != nil {
			return cursor, result, err
		}
		if cursor, result, err = inner(cursor); err != nil {
			return cursor, result, err
		}
		cursor, _, err = right(cursor)
		return cursor, result, err
	}
}
// terminated builds a parser that runs first and then second, returning
// first's value and discarding second's.
//
// If first fails, its rest, value and error are returned unchanged. If second
// fails, its rest and error are returned together with first's value.
func terminated[T, U any](first parser[T], second parser[U]) parser[T] {
	return func(input []byte) ([]byte, T, error) {
		afterFirst, kept, err := first(input)
		if err != nil {
			return afterFirst, kept, err
		}
		afterSecond, _, err := second(afterFirst)
		return afterSecond, kept, err
	}
}
// preceded builds a parser that runs first and then second, discarding
// first's value and returning second's.
//
// If first fails, its rest and error are returned with the zero value of U.
// Otherwise the parser's result is exactly what second returns.
func preceded[T, U any](first parser[T], second parser[U]) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		afterFirst, _, err := first(input)
		if err != nil {
			var zero U
			return afterFirst, zero, err
		}
		return second(afterFirst)
	}
}
// space0 consumes any leading spaces and tabs and returns them as value.
// The run may be empty.
func space0(input []byte) (rest []byte, value []byte, err error) {
	blanks := takeAny(" \t")
	return blanks(input)
}
// multispace0 consumes any leading spaces, tabs, carriage returns and line
// feeds and returns them as value. The run may be empty.
func multispace0(input []byte) (rest []byte, value []byte, err error) {
	whitespace := takeAny(" \t\r\n")
	return whitespace(input)
}
func lineEnding(input []byte) (rest []byte, value string, err error) {
return alt(tag("\n"), tag("\r\n"))(input)
}
// takeWhile builds a parser that consumes leading bytes for as long as f
// reports true, stopping at the first byte f rejects or at the end of input.
// The consumed run may be empty, so the parser never fails.
func takeWhile(f func(byte) bool) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		end := len(input)
		for pos, b := range input {
			if !f(b) {
				end = pos
				break
			}
		}
		return input[end:], input[:end], nil
	}
}
// recognize builds a parser that runs p and, instead of p's value, returns
// the raw bytes p consumed. The consumed span is taken to be the prefix of
// input whose length is the difference between input and p's rest.
//
// If p fails, its rest and error are returned with a nil value.
func recognize[T any](p parser[T]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining, _, err := p(input)
		if err != nil {
			return remaining, nil, err
		}
		consumed := len(input) - len(remaining)
		return remaining, input[:consumed], nil
	}
}
// recognize2 builds a parser that runs p1 then p2 and returns the raw bytes
// the two consumed together, discarding their values. The consumed span is
// the prefix of input whose length is the difference between input and the
// final rest.
//
// The first failing parser aborts the sequence; its rest and error are
// returned with a nil value.
func recognize2[T, U any](p1 parser[T], p2 parser[U]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		cursor, _, err := p1(input)
		if err == nil {
			cursor, _, err = p2(cursor)
		}
		if err != nil {
			return cursor, nil, err
		}
		return cursor, input[:len(input)-len(cursor)], nil
	}
}
// recognize3 builds a parser that runs p1, p2 and p3 in sequence and returns
// the raw bytes the three consumed together, discarding their values. The
// consumed span is the prefix of input whose length is the difference between
// input and the final rest.
//
// The first failing parser aborts the sequence; its rest and error are
// returned with a nil value.
func recognize3[T, U, V any](p1 parser[T], p2 parser[U], p3 parser[V]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		cursor, _, err := p1(input)
		if err == nil {
			cursor, _, err = p2(cursor)
		}
		if err == nil {
			cursor, _, err = p3(cursor)
		}
		if err != nil {
			return cursor, nil, err
		}
		return cursor, input[:len(input)-len(cursor)], nil
	}
}
// recognize4 builds a parser that runs p1 through p4 in sequence and returns
// the raw bytes the four consumed together, discarding their values. The
// consumed span is the prefix of input whose length is the difference between
// input and the final rest.
//
// The first failing parser aborts the sequence; its rest and error are
// returned with a nil value.
func recognize4[T, U, V, W any](p1 parser[T], p2 parser[U], p3 parser[V], p4 parser[W]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		cursor, _, err := p1(input)
		if err == nil {
			cursor, _, err = p2(cursor)
		}
		if err == nil {
			cursor, _, err = p3(cursor)
		}
		if err == nil {
			cursor, _, err = p4(cursor)
		}
		if err != nil {
			return cursor, nil, err
		}
		return cursor, input[:len(input)-len(cursor)], nil
	}
}
// many0 builds a parser that applies p repeatedly and collects the values
// until p fails. p's failure is not an error: the values gathered so far
// (possibly none, in which case the slice is nil) are returned together with
// the input position at which p failed.
//
// If p succeeds without consuming anything, repeating it would never
// terminate, so the parser stops with an error instead.
func many0[T any](p parser[T]) parser[[]T] {
	return func(input []byte) ([]byte, []T, error) {
		var items []T
		cursor := input
		for {
			next, item, err := p(cursor)
			switch {
			case err != nil:
				return cursor, items, nil
			case len(next) == len(cursor):
				return next, items, errors.New("infinite loop in many0")
			}
			items = append(items, item)
			cursor = next
		}
	}
}
// opt builds a parser that makes p optional. When p succeeds, a pointer to
// its value is returned. When p fails, the failure is swallowed: the parser
// returns the original input, a nil pointer, and no error.
func opt[T any](p parser[T]) parser[*T] {
	return func(input []byte) ([]byte, *T, error) {
		if remaining, parsed, err := p(input); err == nil {
			return remaining, &parsed, nil
		}
		return input, nil, nil
	}
}
// uint64LE decodes the first 8 bytes of input as a little-endian uint64 and
// returns the bytes that follow. With fewer than 8 bytes available it fails
// and returns the input unchanged.
func uint64LE(input []byte) ([]byte, uint64, error) {
	const width = 8
	if have := len(input); have < width {
		return input, 0, fmt.Errorf("need 8 bytes to parse a 64-bit integer; only got %d", have)
	}
	head, tail := input[:width], input[width:]
	return tail, binary.LittleEndian.Uint64(head), nil
}
// uint32LE decodes the first 4 bytes of input as a little-endian uint32 and
// returns the bytes that follow. With fewer than 4 bytes available it fails
// and returns the input unchanged.
func uint32LE(input []byte) ([]byte, uint32, error) {
	const width = 4
	if have := len(input); have < width {
		return input, 0, fmt.Errorf("need 4 bytes to parse a 32-bit integer; only got %d", have)
	}
	head, tail := input[:width], input[width:]
	return tail, binary.LittleEndian.Uint32(head), nil
}
// uint16LE decodes the first 2 bytes of input as a little-endian uint16 and
// returns the bytes that follow. With fewer than 2 bytes available it fails
// and returns the input unchanged.
func uint16LE(input []byte) ([]byte, uint16, error) {
	const width = 2
	if have := len(input); have < width {
		return input, 0, fmt.Errorf("need 2 bytes to parse a 16-bit integer; only got %d", have)
	}
	head, tail := input[:width], input[width:]
	return tail, binary.LittleEndian.Uint16(head), nil
}
// lengthData builds a parser for length-prefixed data: p parses the length,
// and that many bytes following it are returned as the value.
//
// If p fails, its rest and error are returned. If the parsed length is
// negative or exceeds the bytes remaining after the prefix, the parser fails
// and returns the position just after the prefix as rest.
func lengthData[T constraints.Integer](p parser[T]) parser[[]byte] {
	return func(data []byte) ([]byte, []byte, error) {
		rest, length, err := p(data)
		if err != nil {
			return rest, nil, err
		}
		// Compare as uint64 rather than converting length to int: a large
		// unsigned length would wrap to a negative int, slip past the bounds
		// check and make the slice expressions below panic. A negative signed
		// length would panic the same way, so it is rejected explicitly.
		if length < 0 || uint64(length) > uint64(len(rest)) {
			return rest, nil, fmt.Errorf("need %d bytes, only have %d", length, len(rest))
		}
		n := int(length) // safe: 0 <= length <= len(rest)
		return rest[n:], rest[:n], nil
	}
}
// toString adapts a parser that yields bytes into one that yields the same
// bytes converted to a string. Failures of p are passed through by mapValue.
func toString(p parser[[]byte]) parser[string] {
	convert := func(raw []byte) string { return string(raw) }
	return mapValue(p, convert)
}
func rustString(data []byte) ([]byte, string, error) {
return toString(lengthData(uint64LE))(data)
}
// optionalString parses an optional rustString as encoded by option and
// flattens the result: an absent string is returned as "", so callers cannot
// tell it apart from a present empty string.
func optionalString(data []byte) ([]byte, string, error) {
	unwrap := func(s *string) string {
		if s != nil {
			return *s
		}
		return ""
	}
	return mapValue(option(rustString), unwrap)(data)
}
// option builds a parser for a value preceded by a one-byte presence marker.
//
// A marker of 0 means "absent": one byte is consumed and a nil pointer is
// returned. A marker of 1 means "present": p parses the bytes after the
// marker and a pointer to its value is returned; if p fails, its rest and
// error are returned. Any other marker is an error, as is empty input
// (io.ErrUnexpectedEOF); in both cases the input is returned unchanged.
func option[T any](p parser[T]) parser[*T] {
	return func(data []byte) ([]byte, *T, error) {
		if len(data) == 0 {
			return data, nil, io.ErrUnexpectedEOF
		}
		marker, body := data[0], data[1:]
		if marker == 0 {
			return body, nil, nil
		}
		if marker != 1 {
			return data, nil, fmt.Errorf("want 0 or 1, got 0x%02x", marker)
		}
		remaining, parsed, err := p(body)
		if err != nil {
			return remaining, nil, err
		}
		return remaining, &parsed, nil
	}
}
// vec builds a parser for a counted sequence: a little-endian uint64 element
// count followed by that many elements, each parsed by p.
//
// If the count cannot be read, or any element fails to parse, the failing
// parser's rest and error are returned with a nil slice. A count of zero
// yields an empty, non-nil slice.
func vec[T any](p parser[T]) parser[[]T] {
	return func(data []byte) ([]byte, []T, error) {
		data, length, err := uint64LE(data)
		if err != nil {
			return data, nil, err
		}
		// The count comes straight from the input, so it must not size the
		// allocation directly: a corrupt or hostile prefix would make
		// make() panic or exhaust memory before a single element is read.
		// Cap the up-front capacity by the bytes that are actually left and
		// let append grow the slice if elements turn out to be zero-width.
		capacity := length
		if remaining := uint64(len(data)); capacity > remaining {
			capacity = remaining
		}
		vector := make([]T, 0, capacity)
		for i := uint64(0); i < length; i++ {
			var v T
			data, v, err = p(data)
			if err != nil {
				return data, nil, err
			}
			vector = append(vector, v)
		}
		return data, vector, nil
	}
}
// hashMap builds a parser for a counted map: a little-endian uint64 entry
// count followed by that many entries, each a key parsed by pk immediately
// followed by a value parsed by pv. A key that occurs more than once keeps
// the value of its last occurrence.
//
// If the count cannot be read, or any key or value fails to parse, the
// failing parser's rest and error are returned with a nil map.
func hashMap[K comparable, V any](pk parser[K], pv parser[V]) parser[map[K]V] {
	return func(data []byte) ([]byte, map[K]V, error) {
		data, length, err := uint64LE(data)
		if err != nil {
			return data, nil, err
		}
		// The count comes straight from the input. Use it only as a size
		// hint capped by the bytes that are actually left, so a corrupt or
		// hostile prefix cannot trigger a huge allocation.
		hint := length
		if remaining := uint64(len(data)); hint > remaining {
			hint = remaining
		}
		m := make(map[K]V, hint)
		// Count in uint64: converting length to int would turn very large
		// counts negative, skip the loop and report an empty map as success.
		for i := uint64(0); i < length; i++ {
			var key K
			var val V
			data, key, err = pk(data)
			if err != nil {
				return data, nil, err
			}
			data, val, err = pv(data)
			if err != nil {
				return data, nil, err
			}
			m[key] = val
		}
		return data, m, nil
	}
}
// take builds a parser that consumes exactly n bytes and returns them. When
// fewer than n bytes are available it fails with io.ErrUnexpectedEOF and
// returns the input unchanged. n is expected to be non-negative; a negative
// n makes the slice expressions panic.
func take(n int) parser[[]byte] {
	return func(data []byte) ([]byte, []byte, error) {
		if len(data) < n {
			return data, nil, io.ErrUnexpectedEOF
		}
		tail := data[n:]
		head := data[:n]
		return tail, head, nil
	}
}
// assign builds a parser that runs p and, on success, stores p's value in
// *dest. The parser's own value is the raw bytes p consumed (the prefix of
// input whose length is the difference between input and p's rest).
//
// When p fails, *dest is left untouched and the original input is returned
// with p's error.
func assign[T any](dest *T, p parser[T]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining, parsed, err := p(input)
		if err == nil {
			*dest = parsed
			consumed := len(input) - len(remaining)
			return remaining, input[:consumed], nil
		}
		return input, nil, err
	}
}
// tuple builds a parser that runs the given parsers one after another, each
// on the rest left by the previous one, and returns their values in order.
//
// The first failing parser aborts the sequence; its rest and error are
// returned with a nil slice. With no parsers at all, the result is an empty,
// non-nil slice and the input is returned untouched.
func tuple[T any](parsers ...parser[T]) parser[[]T] {
	return func(data []byte) ([]byte, []T, error) {
		results := make([]T, 0, len(parsers))
		for _, p := range parsers {
			rest, v, err := p(data)
			if err != nil {
				return rest, nil, err
			}
			results = append(results, v)
			data = rest
		}
		return data, results, nil
	}
}