// Copyright © 2023 Kim Altintop <kim@eagain.io>
// SPDX-License-Identifier: GPL-2.0-only

use core::fmt;
use std::{
    collections::HashMap,
    iter,
    ops::Deref,
    str::FromStr,
};

use anyhow::{
    bail,
    Context as _,
};
use base64::{
    engine::general_purpose::URL_SAFE_NO_PAD,
    Engine as _,
};
use derive_more::From;
use ed25519_compact as ed25519;
use itertools::Itertools as _;
use pgp::Deserializable as _;
use serde_json::json;
use url::Url;

use yapma_common::http::{
    self,
    Response as _,
};

#[cfg(test)]
mod tests;

/// Errors returned when parsing DIDs and DID URLs.
pub mod error {
    use thiserror::Error;

    /// Errors from parsing a [`DID`](super::DID).
    #[derive(Debug, Error)]
    pub enum DID {
        /// The input is not a DID, e.g. its URI scheme is not `did`.
        #[error("not a DID")]
        NotADid,
        /// The DID method is not supported; only `web` is.
        #[error("unsupported method")]
        UnsupportedMethod,
        /// No method segment could be read from the input.
        #[error("unexpected eof")]
        UnexpectedEof,
        /// The method-specific identifier (carried as the payload) is
        /// malformed, e.g. one of its `:`-separated segments contains a `/`.
        #[error("malformed method-specific identifier: `{0}`")]
        MalformedIdentifier(String),
        /// The input, or the `https` URL derived from its method-specific
        /// identifier, failed to parse.
        #[error("invalid method-specific identifier")]
        InvalidIdentifier(#[from] url::ParseError),
    }

    /// Errors from parsing a DID URL.
    ///
    /// NOTE(review): not constructed anywhere in this file — presumably used
    /// elsewhere in the crate; confirm.
    #[derive(Debug, Error)]
    pub enum DIDUrl {
        /// The input is not a DID URL.
        #[error("not a DID URL")]
        NotADidUrl,
        /// The input failed to parse as a URL.
        #[error("not a valid URL")]
        InvalidUrl(#[from] url::ParseError),
    }
}

/// A [DID Document], only capturing the fields we're interested in.
///
/// Specifically, we only care about the [`DID`], the subset of verification
/// methods we support (as per [`VerificationMethod`]), and any `alsoKnownAs`
/// references. All other properties of the document are ignored.
///
/// Note that every `verificationMethod` entry must carry a public key in one
/// of the supported encodings: an entry with an unsupported (or no) key makes
/// deserialization of the whole document fail, rather than being skipped.
///
/// [DID Document]: https://www.w3.org/TR/did-core/#dfn-did-documents
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PartialDocument {
    /// The document's `id`.
    pub id: DID,
    /// The document's `verificationMethod` entries.
    ///
    /// This property is required here: a document without it fails to
    /// deserialize. NOTE(review): DID Core makes `verificationMethod`
    /// optional — consider `#[serde(default)]` if such documents should
    /// deserialize (and then fail `validate` instead).
    #[serde(deserialize_with = "deserialize_verification_methods")]
    pub verification_method: Vec<VerificationMethod>,
    /// The document's `alsoKnownAs` URLs; empty if the property is absent.
    #[serde(default)]
    pub also_known_as: Vec<Url>,
}

impl PartialDocument {
    /// Validate the document according to the [`ValidDocument`] criteria.
    ///
    /// Returns `None` if validation doesn't pass. Consumes `self`, as an
    /// invalid document isn't of any further use.
    pub fn validate(self) -> Option<ValidDocument> {
        let verification_method = self
            .verification_method
            .into_iter()
            .filter(|method| method.controller == self.id && method.public_key.has_ed25519())
            .collect::<Vec<_>>();

        not(verification_method.is_empty()).then(|| {
            ValidDocument(Self {
                verification_method,
                ..self
            })
        })
    }
}

/// A [`PartialDocument`] which passed the validation criteria required for
/// use in `yapma`.
///
/// Only [`PartialDocument::validate`] constructs values of this type.
///
/// The criteria are that the document contains at least one
/// [`VerificationMethod`] for which:
///
/// * the `controller` equals the document `id`, and
/// * the `public_key` contains an Ed25519 key: a bare Ed25519 key always
///   does, while a PGP key must have an Ed25519 primary key or at least one
///   Ed25519 subkey.
///
/// Every [`VerificationMethod`] failing these criteria has been removed from
/// the `verification_method` list of the wrapped [`PartialDocument`].
pub struct ValidDocument(PartialDocument);

impl Deref for ValidDocument {
    type Target = PartialDocument;

    fn deref(&self) -> &PartialDocument {
        let Self(inner) = self;
        inner
    }
}

impl From<ValidDocument> for PartialDocument {
    /// Unwrap the validated document.
    fn from(valid: ValidDocument) -> Self {
        valid.0
    }
}

/// Create a JSON [DID document] for `did`, with `key` as its only
/// verification method.
///
/// The key is published as an `Ed25519VerificationKey2020` with id
/// `<did>#key-0`, controlled by `did` itself. Its value is a base58btc
/// multibase string of the multikey: the `ed25519-pub` multicodec prefix
/// (`0xed 0x01`) followed by the raw key bytes.
///
/// NOTE(review): the key property is emitted as `publicKeyMultiBase`, which
/// is what this module's deserializer reads; the ed25519-2020 suite spells it
/// `publicKeyMultibase`.
///
/// [DID document]: https://www.w3.org/TR/did-core/#dfn-did-documents
pub fn create(did: DID, key: ed25519::PublicKey) -> serde_json::Value {
    // Multicodec code for `ed25519-pub` (0xed), unsigned-varint encoded.
    const MULTICODEC_ED25519_PUB: [u8; 2] = [0xed, 0x01];

    let mut multikey = [0u8; 2 + ed25519::PublicKey::BYTES];
    let (prefix, raw) = multikey.split_at_mut(MULTICODEC_ED25519_PUB.len());
    prefix.copy_from_slice(&MULTICODEC_ED25519_PUB);
    raw.copy_from_slice(&*key);
    let encoded_key = multibase::encode(multibase::Base::Base58Btc, multikey);

    let key_id = format!("{did}#key-0");

    json!({
        "@context": [
            "https://www.w3.org/ns/did/v1",
            "https://w3id.org/security/suites/ed25519-2020/v1",
        ],
        "id": did.id(),
        "verificationMethod": [{
            "id": key_id,
            "type": "Ed25519VerificationKey2020",
            "controller": did.id(),
            "publicKeyMultiBase": encoded_key,
        }],
    })
}

/// A verification method of a DID document, reduced to the parts we use.
#[derive(Debug)]
pub struct VerificationMethod {
    /// The method's `id`, as found in the document (not parsed or validated).
    pub id: String,
    /// The method's `controller`, which must parse as a [`DID`].
    pub controller: DID,
    /// The public key, decoded from one of the supported `publicKey*`
    /// properties of the method.
    pub public_key: PublicKey,
}

/// The public key of a [`VerificationMethod`].
///
/// Keys given as JWK, base58 or multibase are always Ed25519. PGP keys may
/// use any algorithm; see [`PublicKey::has_ed25519`].
#[derive(Debug, From)]
pub enum PublicKey {
    /// A raw Ed25519 public key.
    Ed25519(ed25519::PublicKey),
    /// A PGP public key, possibly with subkeys.
    Pgp(pgp::SignedPublicKey),
}

impl PublicKey {
    /// `true` if this [`PublicKey`] contains at least one Ed25519 signing key.
    ///
    /// A bare Ed25519 key trivially does; a PGP key is inspected by the free
    /// function `has_ed25519`.
    pub fn has_ed25519(&self) -> bool {
        match self {
            Self::Pgp(pgp_key) => has_ed25519(pgp_key),
            Self::Ed25519(_) => true,
        }
    }
}

/// `true` if the given PGP key contains at least one Ed25519 signing key.
///
/// The primary key is inspected first, then every public subkey, stopping
/// at the first EdDSA key on the Ed25519 curve.
pub fn has_ed25519(key: &pgp::SignedPublicKey) -> bool {
    use pgp::{
        crypto::ecc_curve::ECCCurve,
        types::public::PublicParams,
    };

    let primary = std::iter::once(key.primary_key.public_params());
    let subkeys = key
        .public_subkeys
        .iter()
        .map(|subkey| subkey.key.public_params());

    primary.chain(subkeys).any(|params| {
        matches!(
            params,
            PublicParams::EdDSA {
                curve: ECCCurve::Ed25519,
                ..
            }
        )
    })
}

fn deserialize_verification_methods<'de, D>(
    deserializer: D,
) -> Result<Vec<VerificationMethod>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    struct Visitor;

    impl<'de> serde::de::Visitor<'de> for Visitor {
        type Value = Vec<VerificationMethod>;

        fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
            f.write_str("a list of verification methods")
        }

        fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
        where
            S: serde::de::SeqAccess<'de>,
        {
            let mut out = seq.size_hint().map(Vec::with_capacity).unwrap_or_default();
            while let Some(val) = seq.next_element::<VerificationMethodMap>()? {
                let key = if let Some(key) = val.public_key_jwk {
                    key.0.into()
                } else if let Some(key) = val.public_key_base58 {
                    key.0.into()
                } else if let Some(key) = val.public_key_multi_base {
                    key.0.into()
                } else if let Some(key) = val.public_key_pgp {
                    key.0.into()
                } else {
                    return Err(serde::de::Error::missing_field(
                        "VerificationMethodMap::public_key_*",
                    ));
                };

                out.push(VerificationMethod {
                    id: val.id,
                    controller: val.controller,
                    public_key: key,
                })
            }

            Ok(out)
        }
    }

    deserializer.deserialize_seq(Visitor)
}

/// A [decentralized identifier] (DID).
///
/// The only currently supported [method] is [did-web].
///
/// Values are obtained from [`DID::parse`] (which also backs the
/// [`FromStr`], `TryFrom<String>` and serde `Deserialize` implementations),
/// or via `TryFrom<Url>`. [`DID::parse`] stores its input verbatim, and
/// equality is plain string equality: no normalization is applied.
///
/// # Examples
///
/// ```text
/// did:web:example.com
/// did:web:example.com:u:bob
/// ```
///
/// [decentralized identifier]: https://www.w3.org/TR/did-core/#dfn-decentralized-identifiers
/// [method]: https://www.w3.org/TR/did-core/#dfn-did-methods
/// [did-web]: https://w3c-ccg.github.io/did-method-web/
#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
#[serde(try_from = "String")]
pub struct DID(String);

impl DID {
    pub fn parse<S>(s: S) -> Result<Self, error::DID>
    where
        S: AsRef<str> + Into<String>,
    {
        let uri = Url::parse(s.as_ref())?;
        if uri.scheme() != "did" {
            return Err(error::DID::NotADid);
        }

        let mut iter = uri.path().split(':');
        let method = iter.next().ok_or(error::DID::UnexpectedEof)?;
        if method != "web" {
            return Err(error::DID::UnsupportedMethod);
        }
        if iter.any(|segment| segment.contains('/')) {
            return Err(error::DID::MalformedIdentifier(
                uri.path().strip_prefix("web:").unwrap().to_owned(),
            ));
        }
        let iter = uri.path().split(':').skip(1);

        let suf = iter::once("https://").chain(iter).join("/");
        let _ = Url::parse(&suf)?;

        Ok(Self(s.into()))
    }

    pub fn id(&self) -> &str {
        self
    }

    pub async fn resolve<C: http::Client>(&self, client: &C) -> anyhow::Result<PartialDocument> {
        let mut url = Url::parse(
            &iter::once("https://")
                .chain(self.split(':').skip(2))
                .join("/"),
        )?;
        let is_empty_path = url.path() == "/";
        {
            let mut path = url.path_segments_mut().unwrap();
            path.pop_if_empty();
            if is_empty_path {
                path.push(".well-known");
            }
            path.push("did.json");
        }

        client.get(url).await?.json().await
    }
}

impl fmt::Display for DID {
    /// Writes the DID string as-is (formatting flags such as width are not
    /// applied).
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str(&self.0)
    }
}

impl FromStr for DID {
    type Err = error::DID;

    /// Equivalent to [`DID::parse`].
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        Self::parse(input)
    }
}

impl TryFrom<String> for DID {
    type Error = error::DID;

    /// Equivalent to [`DID::parse`]; also used by the serde `Deserialize`
    /// implementation.
    fn try_from(input: String) -> Result<Self, Self::Error> {
        Self::parse(input)
    }
}

impl TryFrom<Url> for DID {
    type Error = anyhow::Error;

    /// Derive the `did:web` DID whose DID document is served at `url`.
    ///
    /// A trailing `did.json` path segment is dropped, and so is a
    /// `.well-known` segment if it is the only one left (the document of a
    /// bare domain). Any remaining path segments are appended,
    /// `:`-separated. An explicit port (as reported by [`Url::port`]) is
    /// included with its colon percent-encoded as `%3A`.
    ///
    /// For example, `https://example.com/.well-known/did.json` maps to
    /// `did:web:example.com`, and `https://example.com:3000/u/bob/did.json`
    /// to `did:web:example.com%3A3000:u:bob`.
    ///
    /// The URL scheme is not checked.
    ///
    /// # Errors
    ///
    /// If `url` has no host, or its host is an IP address.
    fn try_from(url: Url) -> Result<Self, Self::Error> {
        let host = match url.host().context("missing host")? {
            url::Host::Domain(host) => host,
            url::Host::Ipv4(_) | url::Host::Ipv6(_) => {
                bail!("Host must be a domain name, not IP address")
            },
        };
        let port = url.port().map(|p| format!("%3A{p}")).unwrap_or_default();

        let mut path: Vec<&str> = url.path().split('/').filter(|s| !s.is_empty()).collect();
        if let Some(&"did.json") = path.last() {
            path.pop();
        }
        // `.well-known` is only implied for the root document. Elsewhere it
        // is an ordinary path segment which must be kept for resolution to
        // round-trip.
        if path == [".well-known"] {
            path.clear();
        }

        // Only add `:`-separated path segments if there are any: a trailing
        // `:` would yield a non-canonical DID which never compares equal to
        // the bare-domain one.
        let mut did = format!("did:web:{host}{port}");
        for segment in &path {
            did.push(':');
            did.push_str(segment);
        }

        Ok(Self(did))
    }
}

impl Deref for DID {
    type Target = str;

    /// Borrow the DID string.
    fn deref(&self) -> &str {
        self.0.as_str()
    }
}

impl AsRef<str> for DID {
    /// Borrow the DID string.
    fn as_ref(&self) -> &str {
        &self.0
    }
}

/// Raw shape of a `verificationMethod` entry, before a key is selected.
///
/// Only the key properties we can decode are captured; all other properties
/// (e.g. `type`) are ignored.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct VerificationMethodMap {
    id: String,
    controller: DID,
    public_key_jwk: Option<PublicKeyJwk>,
    public_key_base58: Option<PublicKeyBase58>,
    // `rename_all` yields `publicKeyMultiBase`, which is what `create` emits.
    // The Ed25519VerificationKey2020 / Multikey specs name the property
    // `publicKeyMultibase`, so accept that spelling too.
    #[serde(alias = "publicKeyMultibase")]
    public_key_multi_base: Option<PublicKeyMultibase>,
    public_key_pgp: Option<PublicKeyPgp>,
}

#[derive(Debug)]
struct PublicKeyJwk(ed25519::PublicKey);

impl<'de> serde::Deserialize<'de> for PublicKeyJwk {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let map: HashMap<String, String> = serde::Deserialize::deserialize(deserializer)?;
        match map.get("kty").map(|s| s.as_str()) {
            Some("OKP") => Ok(()),
            Some(x) => Err(serde::de::Error::custom(format!("unsupported kty: {x}"))),
            None => Err(serde::de::Error::custom("missing kty")),
        }?;
        match map.get("crv").map(|s| s.as_str()) {
            Some("Ed25519") => Ok(()),
            Some(x) => Err(serde::de::Error::custom(format!("unsupported crv: {x}"))),
            None => Err(serde::de::Error::custom("missing crv")),
        }?;

        let b64: &str = map.get("x").ok_or(serde::de::Error::custom("missing x"))?;
        let bytes = URL_SAFE_NO_PAD
            .decode(b64)
            .map_err(serde::de::Error::custom)?;
        let x = ed25519::PublicKey::from_slice(&bytes).map_err(serde::de::Error::custom)?;

        Ok(Self(x))
    }
}

/// An Ed25519 public key given as a `publicKeyBase58` value: the raw key
/// bytes, base58-encoded.
#[derive(Debug)]
struct PublicKeyBase58(ed25519::PublicKey);

impl<'de> serde::Deserialize<'de> for PublicKeyBase58 {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        // Deserialize an owned `String`: a borrowed `&str` only works with
        // deserializers that can lend unescaped input (it fails e.g. with
        // `serde_json::from_value` or `serde_json::from_reader`).
        let b58: String = serde::Deserialize::deserialize(deserializer)?;
        let bytes = bs58::decode(b58.as_str())
            .into_vec()
            .map_err(serde::de::Error::custom)?;
        let x = ed25519::PublicKey::from_slice(&bytes).map_err(serde::de::Error::custom)?;

        Ok(Self(x))
    }
}

#[derive(Debug)]
struct PublicKeyMultibase(ed25519::PublicKey);

impl<'de> serde::Deserialize<'de> for PublicKeyMultibase {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let mb: &str = serde::Deserialize::deserialize(deserializer)?;
        let (_base, bytes) = multibase::decode(mb).map_err(serde::de::Error::custom)?;
        if bytes.len() != 34 {
            fn exp() -> impl serde::de::Expected {
                "34"
            }
            return Err(serde::de::Error::invalid_length(bytes.len(), &exp()));
        }
        if bytes[0] != 0xed {
            return Err(serde::de::Error::custom(format!(
                "Expected first byte of multikey to be 0xed, got {}",
                bytes[0]
            )));
        }
        if bytes[1] != 0x01 {
            return Err(serde::de::Error::custom(format!(
                "Expected second byte of multikey to 0x01, got {}",
                bytes[1]
            )));
        }
        let x = ed25519::PublicKey::from_slice(&bytes[2..]).map_err(serde::de::Error::custom)?;

        Ok(Self(x))
    }
}

#[derive(Debug)]
struct PublicKeyPgp(pgp::SignedPublicKey);

impl<'de> serde::Deserialize<'de> for PublicKeyPgp {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let ascii: String = serde::Deserialize::deserialize(deserializer)?;
        let (key, _) =
            pgp::SignedPublicKey::from_string(&ascii).map_err(serde::de::Error::custom)?;

        Ok(Self(key))
    }
}

/// Boolean negation as a plain function, so it can lead an expression such
/// as `not(list.is_empty())`.
fn not(x: bool) -> bool {
    matches!(x, false)
}