// Plugin to add Pijul support to the Nix package manager
#ifdef NIX_VERSION

#include "repo.h"
#include "debug.h"

#include <chrono>

#include <cache.hh>
#include <fetch-settings.hh>
#include <fetchers.hh>
#include <store-api.hh>

#if NIX_VERSION >= 0x022000
#include <posix-source-accessor.hh>
#endif

#if NIX_VERSION >= 0x022300
#include <store-path-accessor.hh>
#elif NIX_VERSION >= 0x022100
#include <fs-input-accessor.hh>
#endif

#include <date/date.h>
#include <nlohmann/json.hpp>

using nixpluginpijul::getRepoStatus;
using nixpluginpijul::getState;
using nixpluginpijul::getTrackedFiles;
using nixpluginpijul::isRepoDirty;
using nixpluginpijul::record;
using nixpluginpijul::RepoStatus;

using namespace std::string_literals;
using namespace std::string_view_literals;

namespace nix::fetchers
{

struct PijulInputScheme : InputScheme {
#if NIX_VERSION >= 0x022400
    [[nodiscard]] std::optional<Input> inputFromURL(const Settings &settings, const ParsedURL &url, bool requireTree) const override
#elif NIX_VERSION >= 0x021800
    [[nodiscard]] std::optional<Input> inputFromURL(const ParsedURL &url, bool requireTree) const override
#else
    [[nodiscard]] std::optional<Input> inputFromURL(const ParsedURL &url) const override
#endif
    {
        DBG_BEGIN

        if (url.scheme != "pijul+http" && url.scheme != "pijul+https" && url.scheme != "pijul+ssh" && url.scheme != "pijul+file") {
            return {};
        }

        auto url2(url);
        url2.scheme = std::string(url2.scheme, 6);
        url2.query.clear();

        Attrs attrs;
        attrs.emplace("type"s, "pijul"s);

        for (const auto &[name, value] : url.query) {
            if (name == "channel" || name == "state") {
                attrs.emplace(name, value);
            } else {
                url2.query.emplace(name, value);
            }
        }

        attrs.emplace("url"s, url2.to_string());

#if NIX_VERSION >= 0x022400
        return inputFromAttrs(settings, attrs);
#else
        return inputFromAttrs(attrs);
#endif

        DBG_END
    }

#if NIX_VERSION >= 0x022400
    [[nodiscard]] std::optional<Input> inputFromAttrs(const Settings &settings, const Attrs &attrs) const override
#else
    [[nodiscard]] std::optional<Input> inputFromAttrs(const Attrs &attrs) const override
#endif
    {
        DBG_BEGIN

        if (maybeGetStrAttr(attrs, "type") != "pijul") {
            return {};
        }

#if NIX_VERSION < 0x021900
        for (const auto &[name, _] : attrs) {
            if (name != "type"sv && name != "url"sv && name != "channel"sv && name != "state"sv && name != "narHash"sv && name != "lastModified"sv) {
                throw Error("unsupported Pijul input attribute '%s'"s, name);
            }
        }
#endif

        parseURL(getStrAttr(attrs, "url"));

#if NIX_VERSION >= 0x022400
        Input input(settings);
#else
        Input input;
#endif
        input.attrs = attrs;

#if NIX_VERSION < 0x022100
        if (maybeGetStrAttr(input.attrs, "channel") && maybeGetStrAttr(input.attrs, "state")) {
            input.locked = true;
        }
#endif

        return input;

        DBG_END
    }

#if NIX_VERSION < 0x021900
    [[nodiscard]] bool hasAllInfo(const Input &input) const override
    {
        return maybeGetIntAttr(input.attrs, "lastModified").has_value();
    }
#endif

    [[nodiscard]] ParsedURL toURL(const Input &input) const override
    {
        DBG_BEGIN

        auto url = parseURL(getStrAttr(input.attrs, "url"));

        if (url.scheme != "pijul") {
            url.scheme = "pijul+"s + url.scheme;
        }

        if (auto channel = maybeGetStrAttr(input.attrs, "channel"s)) {
            url.query.insert_or_assign("channel"s, std::move(*channel));
        }

        if (auto state = maybeGetStrAttr(input.attrs, "state"s)) {
            url.query.insert_or_assign("state"s, std::move(*state));
        }

        return url;

        DBG_END
    }

#if NIX_VERSION >= 0x022100
    std::pair<StorePath, Input> fetchToStore(ref<Store> store, const Input &_input) const
#else
    std::pair<StorePath, Input> fetch(ref<Store> store, const Input &_input) override
#endif
    {
        DBG_BEGIN

        if (auto localPath = getSourcePath(_input)) {
            return fetchLocal(store, _input, *localPath);
        } else {
            auto [storePath, infoAttrs] = doFetch(store, _input);

            Input input(_input);
            mergeAttrs(input.attrs, std::move(infoAttrs));
            return {std::move(storePath), input};
        }

        DBG_END
    }

#if NIX_VERSION >= 0x021900
    std::optional<Path> getSourcePath(const Input &input) const override
#else
    std::optional<Path> getSourcePath(const Input &input) override
#endif
    {
        DBG_BEGIN

        auto url = parseURL(getStrAttr(input.attrs, "url"));

        if (url.scheme == "file" && !input.getRef() && !input.getRev()) {
            return url.path;
        }

        return {};

        DBG_END
    }

#if NIX_VERSION >= 0x021900
    void putFile(const Input &input, const CanonPath &path, std::string_view contents, std::optional<std::string> commitMsg) const override
    {
        DBG_BEGIN

        auto root = getSourcePath(input);
        assert(root);

#if NIX_VERSION >= 0x022100
        writeFile((CanonPath(*root) / path).abs(), contents);
#else
        writeFile((CanonPath(*root) + path).abs(), contents);
#endif

        record(*commitMsg, *root, {Path(path.rel())});

        DBG_END
    }
#else
    void markChangedFile(const Input &input, std::string_view file, std::optional<std::string> commitMsg) override
    {
        auto root = getSourcePath(input);
        assert(root);

        record(*commitMsg, *root, {std::string(file)});
    }
#endif

#if NIX_VERSION >= 0x021900
    std::string_view schemeName() const override
    {
        return "pijul"sv;
    }

    StringSet allowedAttrs() const override
    {
        return {"url"s, "channel"s, "state"s, "narHash"s, "lastModified"s};
    }
#endif

#if NIX_VERSION >= 0x022100
    bool isLocked(const Input &input) const override
    {
        DBG_BEGIN
        return maybeGetStrAttr(input.attrs, "channel") && maybeGetStrAttr(input.attrs, "state");
        DBG_END
    }

#if NIX_VERSION >= 0x022300
    std::pair<ref<SourceAccessor>, Input> getAccessor(ref<Store> store, const Input &_input) const override
#else
    std::pair<ref<InputAccessor>, Input> getAccessor(ref<Store> store, const Input &_input) const override
#endif
    {
        DBG_BEGIN
        Input input(_input);

        auto [storePath, _] = fetchToStore(store, input);

        return {makeStorePathAccessor(store, storePath), input};
        DBG_END
    }
#endif

private:
    static std::pair<StorePath, Attrs> doFetch(const ref<Store> &_store, const Input &input)
    {
        DBG_BEGIN
#if NIX_VERSION >= 0x022000
        Store &store = *_store;
#else
        const auto &store = _store;
#endif

        const auto &name = input.getName();

        const auto url = parseURL(getStrAttr(input.attrs, "url"));
        const auto &repoUrl = url.base;
        const auto channel = maybeGetStrAttr(input.attrs, "channel");
        const auto state = maybeGetStrAttr(input.attrs, "state");

#if NIX_VERSION >= 0x022300
        std::optional<Cache::Key> key;
#else
        std::optional<Attrs> key;
#endif
        bool isLocked = false;

        if (channel && state) {
            isLocked = true;

#if NIX_VERSION >= 0x022300
            key = {"pijul", {
                {"name", name},
                {"channel", *channel},
                {"state", *state},
            }};

            if (auto res = getCache()->lookupStorePath(*key, store)) {
                return {std::move(res->storePath), std::move(res->value)};
            }
#else
            key = {
                {"type", "pijul"},
                {"name", name},
                {"channel", *channel},
                {"state", *state},
            };

            if (auto res = getCache()->lookup(store, *key)) {
                auto &[infoAttrs, storePath] = *res;
                return {std::move(storePath), std::move(infoAttrs)};
            }
#endif
        }

#if NIX_VERSION >= 0x022300
        const Cache::Key impureKey{"pijul", {
            {"name", name},
            {"url", repoUrl},
        }};

        if (auto res = getCache()->lookupStorePath(impureKey, store)) {
            auto &infoAttrs = res->value;

            if ((!channel || *channel == getStrAttr(infoAttrs, "channel")) && (!state || *state == getStrAttr(infoAttrs, "state"))) {
                return {std::move(res->storePath), std::move(infoAttrs)};
            }
        }
#else
        const Attrs impureKey{
            {"type", "pijul"},
            {"name", name},
            {"url", repoUrl},
        };

        if (auto res = getCache()->lookup(store, impureKey)) {
            auto &[infoAttrs, storePath] = *res;

            if ((!channel || *channel == getStrAttr(infoAttrs, "channel")) && (!state || *state == getStrAttr(infoAttrs, "state"))) {
                return {std::move(storePath), std::move(infoAttrs)};
            }
        }
#endif

        auto [storePath, rs] = doFetch(_store, name, repoUrl, channel, state);

        if (!key) {
#if NIX_VERSION >= 0x022300
            key = {"pijul", {
                {"name", name},
            }};
#else
            key = {
                {"type", "pijul"},
                {"name", name},
            };
#endif
        }

#if NIX_VERSION >= 0x022300
        mergeAttrs(key->second,
#else
        mergeAttrs(*key,
#endif
                   {
                       {"channel", rs.channel},
                       {"state", rs.state},
                   });

        Attrs infoAttrs = {
            {"channel", std::move(rs.channel)},
            {"state", std::move(rs.state)},
            {"lastModified", rs.lastModified},
        };

        if (!isLocked) {
#if NIX_VERSION >= 0x022300
            getCache()->upsert(impureKey, store, infoAttrs, storePath);
#else
            getCache()->add(store, impureKey, infoAttrs, storePath, false);
#endif
        }

#if NIX_VERSION >= 0x022300
        getCache()->upsert(*key, store, infoAttrs, storePath);
#else
        getCache()->add(store, *key, infoAttrs, storePath, true);
#endif

        return {std::move(storePath), std::move(infoAttrs)};
        DBG_END
    }

    static std::pair<StorePath, RepoStatus> doFetch(const ref<Store> &store,
                                                    const std::string_view &inputName,
                                                    const std::string_view &repoUrl,
                                                    const std::optional<std::string_view> &channel,
                                                    const std::optional<std::string_view> &state)
    {
        DBG_BEGIN
        const Path tmpDir = createTempDir();
        const AutoDelete delTmpDir(tmpDir, true);
        const auto repoDir = tmpDir + "/source"sv;

        nixpluginpijul::clone(repoUrl, repoDir, channel, state);

        RepoStatus rs = getRepoStatus(repoDir);

        if (channel && *channel != rs.channel) {
            throw Error("channel mismatch: requested %s, got %s"s, *channel, rs.channel);
        }

        if (state && *state != rs.state) {
            throw Error("state mismatch: requested %s, got %s"s, *state, rs.state);
        }

        deletePath(repoDir + "/.pijul"sv);

#if NIX_VERSION >= 0x022300
        auto path = PosixSourceAccessor::createAtRoot(repoDir);
        auto storePath = store->addToStore(inputName, path);
#elif NIX_VERSION >= 0x022100
        auto [accessor, canonPath] = PosixSourceAccessor::createAtRoot(repoDir);
        auto storePath = store->addToStore(inputName, accessor, canonPath);
#elif NIX_VERSION >= 0x022000
        PosixSourceAccessor accessor;
        auto storePath = store->addToStore(inputName, accessor, CanonPath::fromCwd(repoDir));
#else
        auto storePath = store->addToStore(inputName, repoDir);
#endif

        return {std::move(storePath), std::move(rs)};
        DBG_END
    }

    static std::pair<StorePath, Input> fetchLocal(const ref<Store> &store, const Input &_input, const Path &path)
    {
        DBG_BEGIN
        if (_input.attrs.contains("channel"s) || _input.attrs.contains("state"s)) {
            throw Error("no channel/state support for local Pijul repository yet"s);
        }

        Input input(_input);

        bool dirty = isRepoDirty(path);

        if (dirty) {
#if NIX_VERSION >= 0x022400
            auto settings = *input.settings;
#else
            auto settings = fetchSettings;
#endif

            if (!settings.allowDirty) {
                throw Error("Pijul tree '%s' is dirty", path);
            }

            if (settings.warnDirty) {
                warn("Pijul tree '%s' is dirty", path);
            }
        }

        auto files = getTrackedFiles(path);

        Path actualPath(absPath(Path(path)));

        PathFilter filter = [&](const Path &p) -> bool {
            assert(hasPrefix(p, actualPath));
            std::string file(p, actualPath.size() + 1);

            auto st = lstat(p);

            // TODO is this necessary? pijul tracks directories
            if (S_ISDIR(st.st_mode)) {
                auto prefix = file + "/";
                auto i = files.lower_bound(prefix);
                return i != files.end() && hasPrefix(*i, prefix);
            }

            return files.count(file);
        };

#if NIX_VERSION >= 0x022400
        auto storePath = store->addToStore(input.getName(), {getFSSourceAccessor(), CanonPath(actualPath)}, ContentAddressMethod::Raw::NixArchive, HashAlgorithm::SHA256, {}, filter);
#elif NIX_VERSION >= 0x022300
        auto storePath = store->addToStore(input.getName(), {getFSSourceAccessor(), CanonPath(actualPath)}, FileIngestionMethod::Recursive, HashAlgorithm::SHA256, {}, filter);
#elif NIX_VERSION >= 0x022000
        PosixSourceAccessor accessor;
        auto storePath = store->addToStore(input.getName(), accessor, CanonPath{actualPath}, FileIngestionMethod::Recursive, HashAlgorithm::SHA256, {}, filter);
#else
        auto storePath = store->addToStore(input.getName(), actualPath, FileIngestionMethod::Recursive, htSHA256, filter);
#endif

        try {
            const auto [state, timestamp] = getState(path);

            input.attrs.insert_or_assign("lastModified", timestamp);
        } catch (...) {
            input.attrs.insert_or_assign("lastModified", uint64_t(0));
        }

        return {std::move(storePath), input};
        DBG_END
    }

    static void mergeAttrs(Attrs &dest, Attrs &&source)
    {
        while (true) {
            auto next = source.begin();

            if (next == source.end()) {
                break;
            }

            auto handle = source.extract(next);

            mergeOne(dest, std::move(handle.key()), std::move(handle.mapped()));
        }
    }

    static void mergeOne(Attrs &dest, std::string key, Attr attr)
    {
        const auto &d = dest.find(key);

        if (d != dest.end()) {
            if (d->second != attr) {
                throw Error("while merging attrs: value mismatch for %s", d->first);
            }
        } else {
            dest.emplace(std::move(key), std::move(attr));
        }
    }
};

// Register the Pijul input scheme with Nix's fetcher registry when the
// plugin's static initializers run.
[[maybe_unused]] static auto rPijulInputScheme = OnStartup([]() {
    auto scheme = std::make_unique<PijulInputScheme>();
    registerInputScheme(std::move(scheme));
});

} // namespace nix::fetchers

#endif