// Package pijul is a library for working with Pijul repositories in Go.
package pijul

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"strconv"

	"golang.org/x/exp/constraints"
)

// parser is the function shape shared by every combinator in this file.
// Given the bytes still to be parsed, it returns (in order) the unconsumed
// remainder, the parsed value, and an error that is non-nil when parsing
// failed.
type parser[T any] func([]byte) ([]byte, T, error)

// takeUntil returns a parser that consumes everything before the first
// occurrence of tag and returns it as the value. The tag itself is NOT
// consumed: it stays at the front of the remaining input. If tag does not
// occur, the parser fails and leaves the input untouched.
func takeUntil(tag string) parser[[]byte] {
	needle := []byte(tag)
	return func(input []byte) ([]byte, []byte, error) {
		if at := bytes.Index(input, needle); at >= 0 {
			return input[at:], input[:at], nil
		}
		return input, nil, fmt.Errorf("not found: %q", tag)
	}
}

// alt returns a parser that tries each option, in order, against the same
// input and returns the result of the first one that succeeds.
//
// If every option fails, the input is returned unconsumed together with the
// zero value of T and the error from the last option tried. An alt with no
// options always fails. Previously it reported success without consuming
// anything, which would silently accept any input.
func alt[T any](options ...parser[T]) parser[T] {
	return func(input []byte) ([]byte, T, error) {
		var zero T
		if len(options) == 0 {
			return input, zero, errors.New("alt: no alternatives to try")
		}
		var lastErr error
		for _, p := range options {
			rest, v, err := p(input)
			if err == nil {
				return rest, v, nil
			}
			lastErr = err
		}
		// Don't leak a failed option's partial value; callers must only see
		// the zero value alongside an error.
		return input, zero, lastErr
	}
}

// mapValue returns a parser that runs p and transforms its result with f.
// If p fails, the input is left unconsumed, the value is the zero U, and p's
// error is returned unchanged.
func mapValue[T, U any](p parser[T], f func(T) U) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		remaining, parsed, err := p(input)
		if err != nil {
			var zero U
			return input, zero, err
		}
		return remaining, f(parsed), nil
	}
}

// mapWithError returns a parser that runs p and converts its result with f,
// which may reject the value.
//
// If either p or f fails, the parser returns the input unconsumed, the zero
// value of U, and the error. This matches mapValue, value and tag, and means
// a rejected conversion doesn't look as though it consumed input or produced
// a partial value. Previously a failing f returned the rest after p plus
// whatever f returned.
func mapWithError[T, U any](p parser[T], f func(T) (U, error)) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		var zero U
		rest, v, err := p(input)
		if err != nil {
			return input, zero, err
		}
		converted, err := f(v)
		if err != nil {
			return input, zero, err
		}
		return rest, converted, nil
	}
}

// tag returns a parser that matches the literal t at the start of the input,
// consuming it and returning t. If the input doesn't start with t, the
// parser fails without consuming anything.
func tag(t string) parser[string] {
	prefix := []byte(t)
	return func(input []byte) ([]byte, string, error) {
		if !bytes.HasPrefix(input, prefix) {
			return input, "", fmt.Errorf("not found: %q", t)
		}
		return input[len(prefix):], t, nil
	}
}

// value returns a parser that runs p, discards whatever p produced, and
// yields val instead. If p fails, the input is left unconsumed and the zero
// value of T is returned with p's error.
func value[T, U any](val T, p parser[U]) parser[T] {
	return func(input []byte) ([]byte, T, error) {
		rest, _, err := p(input)
		if err == nil {
			return rest, val, nil
		}
		var zero T
		return input, zero, err
	}
}

// takeAny returns a parser that consumes the longest prefix of the input made
// up only of characters in set (interpreted as a bytes.TrimLeft cutset). The
// prefix may be empty, so this parser never fails.
func takeAny(set string) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining := bytes.TrimLeft(input, set)
		matched := len(input) - len(remaining)
		return remaining, input[:matched], nil
	}
}

// takeAny1 is like takeAny but requires at least one matching character: it
// consumes the longest non-empty prefix made up only of characters in set,
// and fails without consuming anything if the first byte doesn't match.
func takeAny1(set string) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		remaining := bytes.TrimLeft(input, set)
		matched := len(input) - len(remaining)
		if matched == 0 {
			return input, nil, fmt.Errorf("nothing matching %q was found", set)
		}
		return remaining, input[:matched], nil
	}
}

func positiveInt(input []byte) (rest []byte, value int, err error) {
	return mapWithError(
		takeAny1("0123456789"),
		func(b []byte) (int, error) {
			return strconv.Atoi(string(b))
		},
	)(input)
}

// delimited returns a parser that runs left, inner and right in sequence and
// yields only inner's value. On failure it returns whatever remainder the
// failing sub-parser reported, with the error.
func delimited[T, U, V any](left parser[T], inner parser[U], right parser[V]) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		afterLeft, _, err := left(input)
		if err != nil {
			var zero U
			return afterLeft, zero, err
		}
		afterInner, val, err := inner(afterLeft)
		if err != nil {
			return afterInner, val, err
		}
		afterRight, _, err := right(afterInner)
		return afterRight, val, err
	}
}

// terminated returns a parser that runs first and then second, yielding
// first's value and discarding second's. On failure it returns the
// remainder reported by the failing sub-parser, with the error.
func terminated[T, U any](first parser[T], second parser[U]) parser[T] {
	return func(input []byte) ([]byte, T, error) {
		rest, val, err := first(input)
		if err == nil {
			rest, _, err = second(rest)
		}
		return rest, val, err
	}
}

// preceded returns a parser that runs first and then second, discarding
// first's value and yielding second's. If first fails, its remainder and
// error are returned with the zero value of U.
func preceded[T, U any](first parser[T], second parser[U]) parser[U] {
	return func(input []byte) ([]byte, U, error) {
		rest, _, err := first(input)
		if err != nil {
			var zero U
			return rest, zero, err
		}
		return second(rest)
	}
}

// space0 consumes zero or more spaces and tabs from the front of input and
// returns them as value. It never fails.
func space0(input []byte) (rest []byte, value []byte, err error) {
	rest = bytes.TrimLeft(input, " \t")
	value = input[:len(input)-len(rest)]
	return rest, value, nil
}

// multispace0 consumes zero or more spaces, tabs, carriage returns and
// newlines from the front of input and returns them as value. It never
// fails.
func multispace0(input []byte) (rest []byte, value []byte, err error) {
	rest = bytes.TrimLeft(input, " \t\r\n")
	value = input[:len(input)-len(rest)]
	return rest, value, nil
}

func lineEnding(input []byte) (rest []byte, value string, err error) {
	return alt(tag("\n"), tag("\r\n"))(input)
}

// takeWhile returns a parser that consumes bytes from the front of the input
// for as long as f reports true, and returns them. The consumed prefix may
// be empty, so this parser never fails.
func takeWhile(f func(byte) bool) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		end := len(input)
		for i, c := range input {
			if !f(c) {
				end = i
				break
			}
		}
		return input[end:], input[:end], nil
	}
}

// recognize returns a parser that runs p but yields the slice of input that
// p consumed instead of p's value. On failure it returns p's remainder and
// error with a nil value.
func recognize[T any](p parser[T]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		rest, _, err := p(input)
		if err != nil {
			return rest, nil, err
		}
		consumed := len(input) - len(rest)
		return rest, input[:consumed], nil
	}
}

// recognize2 returns a parser that runs p1 then p2 and yields the slice of
// input the pair consumed. If either fails, the failing parser's remainder
// and error are returned with a nil value.
func recognize2[T, U any](p1 parser[T], p2 parser[U]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		mid, _, err := p1(input)
		if err != nil {
			return mid, nil, err
		}
		end, _, err := p2(mid)
		if err != nil {
			return end, nil, err
		}
		return end, input[:len(input)-len(end)], nil
	}
}

// recognize3 returns a parser that runs p1, p2 and p3 in sequence and yields
// the slice of input they consumed together. If any of them fails, the
// failing parser's remainder and error are returned with a nil value.
func recognize3[T, U, V any](p1 parser[T], p2 parser[U], p3 parser[V]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		afterFirst, _, err := p1(input)
		if err != nil {
			return afterFirst, nil, err
		}
		afterSecond, _, err := p2(afterFirst)
		if err != nil {
			return afterSecond, nil, err
		}
		afterThird, _, err := p3(afterSecond)
		if err != nil {
			return afterThird, nil, err
		}
		consumed := len(input) - len(afterThird)
		return afterThird, input[:consumed], nil
	}
}

// recognize4 returns a parser that runs p1 through p4 in sequence and yields
// the slice of input they consumed together. If any of them fails, the
// failing parser's remainder and error are returned with a nil value.
func recognize4[T, U, V, W any](p1 parser[T], p2 parser[U], p3 parser[V], p4 parser[W]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		s1, _, err := p1(input)
		if err != nil {
			return s1, nil, err
		}
		s2, _, err := p2(s1)
		if err != nil {
			return s2, nil, err
		}
		s3, _, err := p3(s2)
		if err != nil {
			return s3, nil, err
		}
		s4, _, err := p4(s3)
		if err != nil {
			return s4, nil, err
		}
		consumed := len(input) - len(s4)
		return s4, input[:consumed], nil
	}
}

// many0 returns a parser that applies p repeatedly until it fails, and
// collects the values. Zero matches is a success (with a nil slice).
// p's failure ends the repetition and is not reported. A successful p that
// consumes no input would loop forever, so that case is reported as an error.
func many0[T any](p parser[T]) parser[[]T] {
	return func(input []byte) ([]byte, []T, error) {
		var items []T
		cur := input
		for {
			next, item, err := p(cur)
			switch {
			case err != nil:
				return cur, items, nil
			case len(next) == len(cur):
				return next, items, errors.New("infinite loop in many0")
			}
			items = append(items, item)
			cur = next
		}
	}
}

// opt makes p optional: if p succeeds, its remainder and a pointer to its
// value are returned. Otherwise the input is returned unconsumed with a nil
// pointer, and p's error is discarded. It never fails.
func opt[T any](p parser[T]) parser[*T] {
	return func(input []byte) ([]byte, *T, error) {
		if rest, v, err := p(input); err == nil {
			return rest, &v, nil
		}
		return input, nil, nil
	}
}

// uint64LE parses a little-endian 64-bit unsigned integer from the first 8
// bytes of input. It fails without consuming anything if fewer than 8 bytes
// are available.
func uint64LE(input []byte) ([]byte, uint64, error) {
	const size = 8
	if len(input) >= size {
		return input[size:], binary.LittleEndian.Uint64(input[:size]), nil
	}
	return input, 0, fmt.Errorf("need 8 bytes to parse a 64-bit integer; only got %d", len(input))
}

// uint32LE parses a little-endian 32-bit unsigned integer from the first 4
// bytes of input. It fails without consuming anything if fewer than 4 bytes
// are available.
func uint32LE(input []byte) ([]byte, uint32, error) {
	const size = 4
	if len(input) >= size {
		return input[size:], binary.LittleEndian.Uint32(input[:size]), nil
	}
	return input, 0, fmt.Errorf("need 4 bytes to parse a 32-bit integer; only got %d", len(input))
}

// uint16LE parses a little-endian 16-bit unsigned integer from the first 2
// bytes of input. It fails without consuming anything if fewer than 2 bytes
// are available.
func uint16LE(input []byte) ([]byte, uint16, error) {
	const size = 2
	if len(input) >= size {
		return input[size:], binary.LittleEndian.Uint16(input[:size]), nil
	}
	return input, 0, fmt.Errorf("need 2 bytes to parse a 16-bit integer; only got %d", len(input))
}

// lengthData returns a parser for length-prefixed data: p reads the length,
// and that many following bytes are returned as the value (aliasing the
// input). It fails if p fails, if the length is negative, or if fewer than
// length bytes remain after the prefix.
func lengthData[T constraints.Integer](p parser[T]) parser[[]byte] {
	return func(data []byte) ([]byte, []byte, error) {
		data, length, err := p(data)
		if err != nil {
			return data, nil, err
		}
		// length comes from the input, so it can be anything. Converting it to
		// int before the bounds check would wrap a uint64 >= 2^63 (or a wide
		// type on 32-bit platforms) to a negative int. The check would then
		// pass and the slice expression below would panic. Reject negatives
		// first, then compare in uint64, which holds every non-negative value
		// of any integer type.
		if length < 0 {
			return data, nil, fmt.Errorf("invalid negative length %d", length)
		}
		if uint64(length) > uint64(len(data)) {
			return data, nil, fmt.Errorf("need %d bytes, only have %d", length, len(data))
		}
		n := int(length) // safe: 0 <= length <= len(data)
		return data[n:], data[:n], nil
	}
}

// toString adapts a byte-slice parser into one that yields a string copy of
// the parsed bytes.
func toString(p parser[[]byte]) parser[string] {
	bytesToString := func(b []byte) string { return string(b) }
	return mapValue(p, bytesToString)
}

func rustString(data []byte) ([]byte, string, error) {
	return toString(lengthData(uint64LE))(data)
}

func optionalString(data []byte) ([]byte, string, error) {
	return mapValue(option(rustString), func(p *string) string {
		if p == nil {
			return ""
		}
		return *p
	})(data)
}

// option parses a serialized Rust Option<T>. A leading 0 byte means None
// (nil pointer). A leading 1 byte means Some, and the payload that follows
// is parsed with p. Any other leading byte is an error, and so is empty
// input (io.ErrUnexpectedEOF).
func option[T any](p parser[T]) parser[*T] {
	return func(data []byte) ([]byte, *T, error) {
		if len(data) == 0 {
			return data, nil, io.ErrUnexpectedEOF
		}
		flag, payload := data[0], data[1:]
		if flag == 0 {
			return payload, nil, nil
		}
		if flag != 1 {
			return data, nil, fmt.Errorf("want 0 or 1, got 0x%02x", flag)
		}
		rest, v, err := p(payload)
		if err != nil {
			return rest, nil, err
		}
		return rest, &v, nil
	}
}

// vec parses a serialized Rust Vec<T>: a little-endian 64-bit element count
// followed by that many elements, each parsed with p. It fails with the
// error from the count prefix or from the first element that fails to parse.
func vec[T any](p parser[T]) parser[[]T] {
	return func(data []byte) ([]byte, []T, error) {
		data, length, err := uint64LE(data)
		if err != nil {
			return data, nil, err
		}
		// The count is read from the input and may be corrupt or hostile.
		// make([]T, length) would panic ("len out of range") or exhaust memory
		// for an enormous value. So use it only as a capacity hint, capped at
		// the number of bytes left; append still grows the slice if elements
		// turn out to consume less than a byte each.
		capacity := length
		if remaining := uint64(len(data)); capacity > remaining {
			capacity = remaining
		}
		vector := make([]T, 0, int(capacity))
		for i := uint64(0); i < length; i++ {
			var v T
			data, v, err = p(data)
			if err != nil {
				return data, nil, err
			}
			vector = append(vector, v)
		}
		return data, vector, nil
	}
}

// hashMap parses a serialized Rust HashMap<K,V>: a little-endian 64-bit entry
// count followed by that many key/value pairs, each key parsed with pk and
// each value with pv. A repeated key keeps the last value parsed. It fails
// with the error from the count prefix or from the first key or value that
// fails to parse.
func hashMap[K comparable, V any](pk parser[K], pv parser[V]) parser[map[K]V] {
	return func(data []byte) ([]byte, map[K]V, error) {
		data, length, err := uint64LE(data)
		if err != nil {
			return data, nil, err
		}
		// The count is read from the input and may be corrupt or hostile, so
		// cap the size hint at the bytes remaining rather than letting it
		// drive a huge up-front allocation.
		hint := length
		if remaining := uint64(len(data)); hint > remaining {
			hint = remaining
		}
		m := make(map[K]V, int(hint))
		// Count in uint64: converting length to int wraps to a negative value
		// for counts >= 2^63. The loop would then run zero times and the map
		// would be reported as successfully parsed and empty.
		for i := uint64(0); i < length; i++ {
			var key K
			var val V
			data, key, err = pk(data)
			if err != nil {
				return data, nil, err
			}
			data, val, err = pv(data)
			if err != nil {
				return data, nil, err
			}
			m[key] = val
		}
		return data, m, nil
	}
}

// take returns a parser that consumes exactly n bytes and returns them
// (aliasing the input). If fewer than n bytes are available, it fails with
// io.ErrUnexpectedEOF without consuming anything.
func take(n int) parser[[]byte] {
	return func(data []byte) ([]byte, []byte, error) {
		if len(data) < n {
			return data, nil, io.ErrUnexpectedEOF
		}
		tail, head := data[n:], data[:n]
		return tail, head, nil
	}
}

// assign returns a parser that runs p and, on success, stores p's value in
// *dest and yields the slice of input p consumed. On failure, *dest is left
// untouched and the input is returned unconsumed with p's error.
func assign[T any](dest *T, p parser[T]) parser[[]byte] {
	return func(input []byte) ([]byte, []byte, error) {
		rest, v, err := p(input)
		if err == nil {
			*dest = v
			return rest, input[:len(input)-len(rest)], nil
		}
		return input, nil, err
	}
}

// tuple returns a parser that runs each of parsers in sequence and collects
// their values, in order. On the first failure it returns that parser's
// remainder and error, with a nil slice.
func tuple[T any](parsers ...parser[T]) parser[[]T] {
	return func(data []byte) ([]byte, []T, error) {
		out := make([]T, 0, len(parsers))
		for _, p := range parsers {
			rest, v, err := p(data)
			data = rest
			if err != nil {
				return data, nil, err
			}
			out = append(out, v)
		}
		return data, out, nil
	}
}