#![warn(unused_extern_crates)]
use axum::{
    debug_handler,
    extract::{MatchedPath, Path, Request, State},
    http::StatusCode,
    response::{IntoResponse, Redirect, Response},
    routing::{any, get, post},
    Json, Router,
};
use axum_extra::extract::cookie::SignedCookieJar;
use axum_response_cache::CacheLayer;
use diesel::{ExpressionMethods, NullableExpressionMethods, OptionalExtension, QueryDsl};
use diesel_async::*;
use serde::*;
use std::collections::BTreeMap;
use std::net::{Ipv6Addr, SocketAddr, SocketAddrV6};
use std::sync::Arc;
use tower_http::cors::CorsLayer;
use tower_http::{compression::CompressionLayer, trace::TraceLayer};
use tracing::*;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

mod auth;
mod change;
mod channel;
mod config;
mod config_file;
mod db;
mod discussions;
mod email;
mod hooks;
mod identicon;
mod markdown;
mod permissions;
mod proxy;
mod replication;
mod repository;
mod settings;
mod ssh;

/// Shared application state, handed to every axum handler (via `with_state`)
/// and to the SSH worker.
///
/// The parsed configuration is shared through an `Arc`; the repository lock
/// table and the replication worker are cloned through their own `Clone` impls.
#[derive(Clone)]
pub struct Config {
    config: Arc<config::Config>,
    repo_locks: repository::RepositoryLocks,
    replicator: ::replication::Worker<replication::H>,
}

/// Lets extractors that need `axum_csrf::CsrfConfig` obtain it from the state.
impl axum::extract::FromRef<Config> for axum_csrf::CsrfConfig {
    fn from_ref(state: &Config) -> Self {
        let csrf = &state.config.csrf;
        csrf.clone()
    }
}

/// Exposes the fields of the loaded `config::Config` directly on the state,
/// e.g. `state.db` instead of `state.config.db`.
impl std::ops::Deref for Config {
    type Target = config::Config;
    fn deref(&self) -> &config::Config {
        &*self.config
    }
}

/// Server entry point.
///
/// Sets up tracing, loads the configuration, builds the HTTP router, and
/// spawns the SSH, HTTPS and HTTP (redirect) servers plus the replication
/// worker. The process exits with a distinct status code as soon as any one
/// of them stops: 1 = HTTP, 2 = HTTPS, 3 = SSH, 6 = replicator.
#[tokio::main]
async fn main() {
    // Log filter comes from the environment (`RUST_LOG`-style via
    // `EnvFilter::try_from_default_env`); otherwise this crate and tower_http
    // log at debug and axum rejections at trace.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| {
                format!(
                    "{}=debug,tower_http=debug,axum::rejection=trace",
                    env!("CARGO_CRATE_NAME")
                )
                .into()
            }),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    let (config_file, replication_config) = config::from_app();
    let config = Arc::new(config::from_file(&config_file).await);
    let repo_locks = repository::RepositoryLocks::new(config.clone());
    // CORS is fully open: any origin is allowed.
    let cors = CorsLayer::new().allow_origin(tower_http::cors::Any);

    // The replication config receives the sending half of an unbounded channel;
    // the receiving half (`blrecv`) is consumed by `replicator.worker` in the
    // `select!` below.
    let (blrecv, replicator) = {
        let (blsend, blrecv) = tokio::sync::mpsc::unbounded_channel();
        let rconfig = replication_config.to_config(blsend).await;
        let handler = crate::replication::H::new(repo_locks.clone(), config.db.clone());
        (
            blrecv,
            ::replication::Worker::new(Arc::new(rconfig), handler),
        )
    };

    // Static assets (proxied to the Node frontend) and identicons. In release
    // builds these responses are cached by `CacheLayer` with a lifespan of 3600
    // (presumably seconds); debug builds skip the cache.
    let cached = Router::new()
        .route("/static/{*wildcard}", any(proxy::node_proxy_resp))
        .route("/_app/immutable/{*wildcard}", any(proxy::node_proxy_resp))
        .nest("/identicon", identicon::identicon())
        .layer(TraceLayer::new_for_http())
        .layer(CompressionLayer::new());
    debug!("cached");
    let app: Router<Config> = Router::new();
    let app = app
        .route("/login", post(auth::login))
        .route(
            "/register",
            post(auth::register_post).get(auth::register_get),
        )
        .route(
            "/recover/init",
            post(auth::recover_init)
        )
        .route(
            "/recover/reset",
            post(auth::recover_reset)
        )
        .nest("/api", api_router())
        .route("/", any(root))
        .merge(if cfg!(debug_assertions) {
            cached
        } else {
            cached.layer(CacheLayer::with_lifespan(3600))
        })
        // Anything not matched above is served by the Node frontend.
        .fallback(proxy::node_proxy_resp)
        .layer(cors)
        .layer(
            TraceLayer::new_for_http()
                .make_span_with(|req: &Request| {
                    let method = req.method();
                    let uri = req.uri();
                    let matched_path = req
                        .extensions()
                        .get::<MatchedPath>()
                        .map(|matched_path| matched_path.as_str());
                    tracing::debug_span!("request", %method, %uri, matched_path)
                })
                .on_failure(()),
        )
        .with_state(Config {
            config: config.clone(),
            repo_locks: repo_locks.clone(),
            replicator: replicator.clone(),
        });

    // All listeners bind the IPv6 unspecified address `[::]`.
    let addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
    let ssh_addr = std::net::SocketAddr::V6(SocketAddrV6::new(addr, config_file.ssh.port, 0, 0));

    // `ssh::socket(..)` is awaited here, before the task is spawned, so the SSH
    // socket is created before `drop_privileges` below.
    let ssh_worker = tokio::spawn(ssh::worker(
        Config {
            config: config.clone(),
            repo_locks: repo_locks.clone(),
            replicator: replicator.clone(),
        },
        ssh::socket(&ssh_addr).await,
    ));

    let tls_config = config::make_tls().await.unwrap();
    let http_addr = SocketAddr::V6(SocketAddrV6::new(
        Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
        config_file.http.http_port,
        0,
        0,
    ));
    let https_addr = SocketAddr::V6(SocketAddrV6::new(
        Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0),
        config_file.http.https_port,
        0,
        0,
    ));
    tracing::debug!("listening on {}", https_addr);
    // NOTE(review): unlike the SSH socket, the HTTPS and HTTP listeners are
    // bound inside the spawned tasks, so the bind may run after
    // `drop_privileges` below. If these ports are privileged, confirm this
    // cannot fail, or bind before spawning.
    let https_worker = tokio::spawn(
        axum_server::bind_rustls(https_addr, tls_config).serve(
            app.clone()
                .into_make_service_with_connect_info::<SocketAddr>(),
        ),
    );
    tracing::debug!("listening on {}", http_addr);
    // Plain HTTP: loopback clients get the app, everyone else a redirect to HTTPS.
    let http_worker = tokio::spawn(redirect_http_to_https(
        config_file.http.http_port,
        config_file.http.https_port,
        app,
    ));

    config::drop_privileges(&config_file);

    // Wait for the first worker to stop; any stop is treated as fatal.
    tokio::select!(
        x = http_worker => {
            error!("http worker stopped: {:?}", x);
            std::process::exit(1)
        },
        x = https_worker => {
            error!("https worker stopped: {:?}", x);
            std::process::exit(2)
        },
        x = ssh_worker => {
            error!("ssh worker stopped: {:?}", x);
            std::process::exit(3)
        },
        x = replicator.worker(blrecv) => {
            error!("replicator worker stopped: {:?}", x);
            std::process::exit(6)
        }
    )
}

// Avoid an infinite loop between Node and Rust.
/// Handler for `/`.
///
/// A signed-in user is redirected to their own page `/<login>`. Anyone else
/// gets the Node frontend's response, obtained through `proxy::node_proxy`.
#[debug_handler]
async fn root(
    State(config): State<Config>,
    jar: SignedCookieJar,
    req: axum::extract::Request,
) -> Result<Response, crate::Error> {
    match get_user_login(&jar, &config).await? {
        Some((_, login)) => {
            let target = format!("/{}", login);
            Ok(Redirect::to(&target).into_response())
        }
        None => {
            let proxied = proxy::node_proxy(&config, req).await?;
            Ok(proxied.into_response())
        }
    }
}

/// Serves the plain-HTTP port `http` on `[::]`.
///
/// Requests from a loopback peer (IPv6 `::1`, or IPv4 `127.0.0.0/8`, which
/// appears as an IPv4-mapped IPv6 address on this dual-stack listener) are
/// handed to `app` over plain HTTP. Every other request receives a permanent
/// redirect to the same URI on the HTTPS port `https`. A request whose URI
/// cannot be converted gets `400 Bad Request`.
///
/// # Panics
///
/// Panics if the listener cannot be bound or the server fails.
pub async fn redirect_http_to_https(http: u16, https: u16, mut app: Router<()>) {
    /// Rewrites a `Host` value for the HTTPS port. Only a trailing `:<http>`
    /// suffix is replaced by `:<https>`. Digits elsewhere in the host name or
    /// IP address (e.g. `10.0.0.80:80`) are left untouched, and a host without
    /// that suffix is kept as-is.
    fn https_authority(host: &str, http: u16, https: u16) -> String {
        match host.strip_suffix(&format!(":{http}")) {
            Some(bare) => format!("{bare}:{https}"),
            None => host.to_owned(),
        }
    }

    /// Builds the `https://` equivalent of `uri` for the given `Host`.
    fn make_https(
        host: String,
        uri: http::Uri,
        http: u16,
        https: u16,
    ) -> Result<http::Uri, axum::BoxError> {
        let mut parts = uri.into_parts();

        parts.scheme = Some(axum::http::uri::Scheme::HTTPS);

        if parts.path_and_query.is_none() {
            parts.path_and_query = Some("/".parse().unwrap());
        }

        parts.authority = Some(https_authority(&host, http, https).parse()?);

        Ok(http::Uri::from_parts(parts)?)
    }

    use axum::extract::ConnectInfo;
    use axum::response::IntoResponse;
    use axum_extra::extract::Host;
    let redirect = move |Host(host): Host,
                         uri: http::Uri,
                         ConnectInfo(addr): ConnectInfo<SocketAddr>,
                         r: http::Request<axum::body::Body>| async move {
        debug!("redirect {:?}", addr);
        // The listener is bound to `[::]`, so IPv4 peers show up as
        // `::ffff:a.b.c.d`; canonicalise before the loopback test.
        if addr.ip().to_canonical().is_loopback() {
            debug!("is loopback");
            use tower_service::Service;
            Ok(app.call(r).await.into_response())
        } else {
            debug!("not loopback");
            match make_https(host, uri, http, https) {
                Ok(uri) => Ok(Redirect::permanent(&uri.to_string()).into_response()),
                Err(error) => {
                    tracing::warn!(%error, "failed to convert URI to HTTPS");
                    Err(StatusCode::BAD_REQUEST.into_response())
                }
            }
        }
    };

    let addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
    let http_addr = SocketAddr::V6(SocketAddrV6::new(addr, http, 0, 0));

    let listener = tokio::net::TcpListener::bind(http_addr)
        .await
        .unwrap_or_else(|e| panic!("cannot bind HTTP listener on {http_addr}: {e}"));
    tracing::debug!("listening on {}", listener.local_addr().unwrap());
    use axum::handler::HandlerWithoutStateExt;
    axum::serve(
        listener,
        redirect.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await
    .unwrap_or_else(|e| panic!("HTTP server on {http_addr} failed: {e}"));
}

/// JSON body returned by the `user` handler (`GET /api/user/{login}`).
#[derive(Serialize, Debug)]
struct User {
    /// Login of the profile that was requested.
    user: String,
    /// Login of the signed-in viewer, if any.
    login: Option<String>,
    /// Repositories the viewer may read, keyed by repository name.
    repos: BTreeMap<String, Repo>,
    /// Whether the viewer is the user being displayed.
    is_owner: bool,
}

/// Per-repository entry of [`User::repos`].
#[derive(Serialize, Debug)]
struct Repo {
    /// `true` when an anonymous visitor has no read access to the repository.
    private: bool,
}

/// JSON body returned by the `user_id` handler (`GET /api/user`).
#[derive(Serialize, Debug)]
struct UserId {
    id: uuid::Uuid,
    login: String,
}

#[debug_handler]
async fn user_id(
    State(config): State<Config>,
    jar: SignedCookieJar,
) -> (SignedCookieJar, Json<Option<UserId>>) {
    if let Ok(Some((id, login))) = get_user_login(&jar, &config).await {
        (jar, Json(Some(UserId { login, id })))
    } else {
        use axum_extra::extract::cookie::Cookie;
        (jar.remove(Cookie::from("session_id")), Json(None))
    }
}

/// `GET /api/user/{login}`: profile data for the user whose login is `user`.
///
/// The viewer comes from the session cookie. An anonymous viewer is
/// represented by the nil UUID when computing permissions. The response lists
/// the repositories of `user` on which the viewer has `Perm::READ`. Each one is
/// marked `private` when the nil UUID (anonymous) lacks `READ` on it.
///
/// Responds `404` with the body `{}` when no user with that login exists, or
/// when that user's `email_is_invalid` column is set. Permission bit patterns
/// rejected by `Perm::from_bits` count as "no permission" instead of panicking.
#[debug_handler]
async fn user(
    State(config): State<Config>,
    jar: SignedCookieJar,
    Path(user): Path<String>,
) -> Result<Response, Error> {
    use crate::permissions::Perm;
    use db::repositories::dsl as r;
    use db::users::dsl as u;

    debug!("user {:?}", user);
    let (viewer_id, login) = match get_user_login(&jar, &config).await? {
        Some((id, login)) => (id, Some(login)),
        None => (uuid::Uuid::nil(), None),
    };

    // One row per repository of `user` (a single row with a NULL name when the
    // user has none). Each row carries the viewer's and the anonymous
    // permission bits.
    // NOTE(review): for a user without repositories `r::id` is NULL here. This
    // relies on the SQL `permissions()` function returning a non-NULL BIGINT in
    // that case — confirm against the schema.
    let rows = u::users
        .left_join(r::repositories)
        .select((
            u::id,
            r::name.nullable(),
            permissions!(viewer_id, r::id),
            permissions!(uuid::Uuid::nil(), r::id),
        ))
        .filter(u::login.eq(&user))
        .filter(u::email_is_invalid.is_null())
        .get_results::<(uuid::Uuid, Option<String>, i64, i64)>(&mut config.db.get().await?)
        .await?;
    debug!("repos = {:?}", rows);

    // Rows are filtered by login, so they presumably all carry the same
    // `u::id`; take it from the last row.
    let Some(&(uid, ..)) = rows.last() else {
        return Ok((StatusCode::NOT_FOUND, "{}").into_response());
    };

    let can_read = |bits: i64| Perm::from_bits(bits).is_some_and(|p| p.contains(Perm::READ));
    let repos = rows
        .into_iter()
        .filter(|&(_, _, perm, _)| can_read(perm))
        .filter_map(|(_, name, _, perm_all)| {
            name.map(|name| (name, Repo { private: !can_read(perm_all) }))
        })
        .collect();

    // `viewer_id` only names a real user when the session resolved to an
    // existing login. Comparing it here avoids decoding the cookie twice.
    let is_owner = login.is_some() && uid == viewer_id;
    debug!("uid = {:?}", uid);
    let resp = User {
        user,
        login,
        repos,
        is_owner,
    };
    debug!("{:?}", resp);
    Ok(Json(resp).into_response())
}

/// Routes mounted under `/api`. Unmatched paths fall through to `fallback`
/// (plain-text `404`).
fn api_router() -> Router<Config> {
    let users = Router::new()
        .route("/user/{login}", get(user))
        .route("/user", get(user_id));
    let repositories = users
        .nest("/settings", settings::router())
        .route("/tree/{owner}/{repo}", get(repository::router::tree))
        .nest("/admin/{owner}/{repo}", repository::admin::router());
    repositories
        .nest("/change", change::router())
        .nest("/discussion", discussions::router())
        .nest("/channel", channel::router())
        .fallback(fallback)
}

/// Catch-all for unmatched API routes: `404` with a plain-text message naming
/// the requested URI.
async fn fallback(uri: http::Uri) -> (StatusCode, String) {
    let message = format!("No route for {}", uri);
    (StatusCode::NOT_FOUND, message)
}

/// JSON variant of `fallback`: `404` with `{ "error": "No route for <uri>" }`.
///
/// The URI is JSON-escaped before it is embedded. Quotes, backslashes or
/// control characters in the request target therefore cannot produce
/// malformed JSON.
async fn fallback_json(uri: http::Uri) -> (StatusCode, String) {
    use std::fmt::Write as _;
    let mut escaped = String::new();
    for c in uri.to_string().chars() {
        match c {
            '"' => escaped.push_str("\\\""),
            '\\' => escaped.push_str("\\\\"),
            c if c.is_control() => {
                // Writing into a `String` cannot fail.
                let _ = write!(escaped, "\\u{:04x}", u32::from(c));
            }
            c => escaped.push(c),
        }
    }
    (
        StatusCode::NOT_FOUND,
        format!(r#"{{ "error": "No route for {escaped}" }}"#),
    )
}

/// Reads the user id from the signed `session_id` cookie.
///
/// The cookie value is URL-safe base64 (`BASE64URL`) of a bincode-encoded
/// `(Uuid, DateTime<Utc>)` pair. Returns `Ok(None)` when the cookie is absent.
/// Returns `Error::Decode` / `Error::Bincode` when the value cannot be decoded.
/// NOTE(review): the timestamp half is decoded but never checked here; confirm
/// that session expiry is enforced elsewhere.
fn get_user_id(jar: &SignedCookieJar) -> Result<Option<uuid::Uuid>, Error> {
    let Some(session) = jar.get("session_id") else {
        return Ok(None);
    };
    let payload = data_encoding::BASE64URL.decode(session.value().as_bytes())?;
    let (uid, _stamp): (uuid::Uuid, chrono::DateTime<chrono::Utc>) =
        bincode::deserialize(&payload)?;
    Ok(Some(uid))
}

fn get_user_id_strict(jar: &SignedCookieJar) -> Result<uuid::Uuid, Error> {
    if let Some(id) = get_user_id(jar)? {
        Ok(id)
    } else {
        Err(Error::NeedsAuth)
    }
}

/// Resolves the session cookie to `(user id, login)`.
///
/// Returns `Ok(None)` when there is no session cookie, or when its user id
/// has no row in `users`. Cookie-decoding, pool and database errors are
/// propagated.
async fn get_user_login(
    jar: &SignedCookieJar,
    config: &Config,
) -> Result<Option<(uuid::Uuid, String)>, Error> {
    use db::users::dsl as u;
    let Some(id) = get_user_id(jar)? else {
        return Ok(None);
    };
    let mut conn = config.db.get().await?;
    let login = u::users
        .find(id)
        .select(u::login)
        .get_result::<String>(&mut conn)
        .await
        .optional()?;
    Ok(login.map(|login| (id, login)))
}

/// Same lookup as `get_user_login`, with the pair split into two options.
/// Both are `Some` or both are `None`.
async fn get_user_login_(
    jar: &SignedCookieJar,
    config: &Config,
) -> Result<(Option<uuid::Uuid>, Option<String>), Error> {
    let found = get_user_login(jar, config).await?;
    Ok(found.unzip())
}

/// Resolves the session cookie to `(user id, login, email)`.
///
/// Returns `Ok(None)` when there is no session cookie, or when its user id
/// has no row in `users`. Cookie-decoding, pool and database errors are
/// propagated.
async fn get_user_login_email(
    jar: &SignedCookieJar,
    config: &Config,
) -> Result<Option<(uuid::Uuid, String, String)>, Error> {
    use db::users::dsl as u;
    let Some(id) = get_user_id(jar)? else {
        return Ok(None);
    };
    let mut conn = config.db.get().await?;
    let row = u::users
        .find(id)
        .select((u::login, u::email))
        .get_result::<(String, String)>(&mut conn)
        .await
        .optional()?;
    Ok(row.map(|(login, email)| (id, login, email)))
}

async fn get_user_login_email_strict(
    jar: &SignedCookieJar,
    config: &Config,
) -> Result<(uuid::Uuid, String, String), Error> {
    if let Some(id) = get_user_login_email(jar, config).await? {
        Ok(id)
    } else {
        Err(Error::NeedsAuth)
    }
}

async fn get_user_login_strict(
    jar: &SignedCookieJar,
    config: &Config,
) -> Result<(uuid::Uuid, String), Error> {
    if let Some(id) = get_user_login(jar, config).await? {
        Ok(id)
    } else {
        Err(Error::NeedsAuth)
    }
}

use thiserror::Error;

/// Crate-wide error type returned by handlers and helpers.
///
/// Variants marked `#[from]` wrap an underlying library error, so `?`
/// converts those automatically. The remaining variants are raised explicitly
/// by application code. The HTTP status each variant produces is decided by
/// this type's `IntoResponse` impl; clients only ever receive an empty `{}` body.
#[derive(Error, Debug)]
pub enum Error {
    #[error("Lock")]
    Lock,
    #[error("Depended upon")]
    DependedUpon,
    #[error(transparent)]
    SSH(#[from] thrussh::Error),
    #[error(transparent)]
    Hyper(#[from] hyper::Error),
    #[error(transparent)]
    Http(#[from] http::Error),
    #[error(transparent)]
    Join(#[from] tokio::task::JoinError),
    #[error("Repository not found")]
    RepositoryNotFound,
    /// Pijul transaction failure without further detail.
    #[error("Unknown Pijul error")]
    Txn,
    #[error(transparent)]
    Archive(
        #[from]
        libpijul::output::ArchiveError<
            repository::changestore::Error,
            libpijul::pristine::sanakirja::GenericTxn<sanakirja::MutTxn<Arc<sanakirja::Env>, ()>>,
            std::io::Error,
        >,
    ),
    #[error(transparent)]
    Apply(
        #[from]
        libpijul::ApplyError<
            repository::changestore::Error,
            libpijul::pristine::sanakirja::GenericTxn<sanakirja::MutTxn<Arc<sanakirja::Env>, ()>>,
        >,
    ),
    #[error(transparent)]
    PijulChange(#[from] libpijul::change::ChangeError),
    #[error("Hash parse error: {hash:?}")]
    HashParse { hash: String },
    /// The caller holds `got` but the operation needs `required`.
    #[error("Insufficient permissions, required {required:?}, got {got:?}")]
    Permissions {
        required: permissions::Perm,
        got: permissions::Perm,
    },
    /// No valid session; raised by the `*_strict` session helpers.
    #[error("Auth required")]
    NeedsAuth,
    #[error(transparent)]
    Csrf(#[from] axum_csrf::CsrfError),
    #[error(transparent)]
    Io(#[from] std::io::Error),
    #[error(transparent)]
    Decode(#[from] data_encoding::DecodeError),
    #[error(transparent)]
    Bincode(#[from] bincode::Error),
    #[error(transparent)]
    ChangeStore(#[from] repository::changestore::Error),
    #[error(transparent)]
    Sanakirja(#[from] libpijul::pristine::sanakirja::SanakirjaError),
    #[error("The request body was too large, please retry with something smaller.")]
    BodyTooLarge,
    #[error("Wrong token")]
    WrongToken,
    #[error("Change not found")]
    ChangeNotFound,
    #[error("Channel not found: {:?}", channel)]
    ChannelNotFound { channel: String },
    #[error("Inactive user")]
    InactiveUser,
    #[error("OAuth parse")]
    OAuthParse,
    #[error("Wrong signature")]
    WrongSignature,
    #[error("The last channel in a repository cannot be deleted")]
    LastChannelCannotBeDeleted,
    #[error("Protocol error")]
    ProtocolError,
    /// Storage quota exceeded; `quota` is the maximum allowed (units not shown here).
    #[error("Not enough space (max {:?})", quota)]
    Quota { quota: u64 },
    #[error("Forbidden")]
    Forbidden,
    #[error("Ambiguous path")]
    AmbiguousInode,
    #[error("Path not found")]
    InodeNotFound,
    #[error("This channel has changed independently")]
    ConcurrentChannelOps,
    #[error(transparent)]
    Diesel(#[from] diesel::result::Error),
    /// Failure to obtain a database connection from the pool.
    #[error(transparent)]
    Deadpool(#[from] deadpool::managed::PoolError<diesel_async::pooled_connection::PoolError>),
    #[error(transparent)]
    Replication(#[from] ::replication::Error<replication::Error>),
    #[error(transparent)]
    Utf8(#[from] std::str::Utf8Error),
    #[error(transparent)]
    ParseInt(#[from] std::num::ParseIntError),
    #[error(transparent)]
    Time(#[from] std::time::SystemTimeError),
    #[error(transparent)]
    Bs58(#[from] bs58::decode::Error),
    #[error(transparent)]
    Key(#[from] libpijul::key::KeyError),
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    #[error(transparent)]
    Rusoto(#[from] rusoto_core::RusotoError<rusoto_ses::SendEmailError>),
    #[error(transparent)]
    Url(#[from] serde_urlencoded::de::Error),
}

/// Unwraps a Pijul transaction error and converts the inner error, so `?`
/// works directly on `TxnErr`-returning calls.
impl<E: Into<Error> + std::error::Error> From<libpijul::pristine::TxnErr<E>> for Error {
    fn from(err: libpijul::pristine::TxnErr<E>) -> Self {
        let inner: E = err.0;
        inner.into()
    }
}

impl IntoResponse for Error {
    fn into_response(self) -> Response {
        debug!("response {:?}", self);
        match self {
            Error::RepositoryNotFound => (StatusCode::NOT_FOUND, "{}").into_response(),
            Error::ChannelNotFound { .. } => (StatusCode::NOT_FOUND, "{}").into_response(),
            Error::InodeNotFound => (StatusCode::NOT_FOUND, "{}").into_response(),
            Error::Csrf(_) => (StatusCode::FORBIDDEN, "{}").into_response(),
            Error::NeedsAuth => (StatusCode::FORBIDDEN, "{}").into_response(),
            Error::Permissions { .. } => (StatusCode::FORBIDDEN, "{}").into_response(),
            _ => (StatusCode::INTERNAL_SERVER_ERROR, "{}").into_response(),
        }
    }
}

/// Diesel SQL type for the PostgreSQL type named `keyalgorithm`.
/// The type itself is expected to exist in the database schema; it is not created here.
#[derive(diesel::query_builder::QueryId, diesel::sql_types::SqlType)]
#[diesel(postgres_type(name = "keyalgorithm"))]
pub struct Keyalgorithm;

/// Rust-side enum mapped onto [`Keyalgorithm`] by `diesel_derive_enum`.
///
/// `DbValueStyle = "PascalCase"` sets how variant names are written as
/// database values (e.g. `Ed25519`). The type also derives serde
/// `Serialize`/`Deserialize`.
#[derive(
    Debug,
    Clone,
    Copy,
    PartialEq,
    Eq,
    PartialOrd,
    Ord,
    Hash,
    Serialize,
    Deserialize,
    diesel_derive_enum::DbEnum,
)]
#[ExistingTypePath = "Keyalgorithm"]
#[DbValueStyle = "PascalCase"]
pub enum Keyalgorithm_ {
    Ed25519,
}

/// Builds the raw SQL expression `permissions(<user>, <repo>)` as a diesel
/// `BigInt` expression. Both arguments are bound as `Uuid` parameters.
///
/// `permissions` is a database-side SQL function, not defined in Rust. Its
/// result is decoded with `permissions::Perm::from_bits` (see the `user` handler).
#[macro_export]
macro_rules! permissions {
    ($user:expr, $repo:expr) => {{
        diesel::dsl::sql::<diesel::sql_types::BigInt>("permissions(")
            .bind::<diesel::sql_types::Uuid, _>($user)
            .sql(", ")
            .bind::<diesel::sql_types::Uuid, _>($repo)
            .sql(")")
    }};
}

/// Builds the boolean SQL expression `permissions(<user>, <repo>) & <mask> != 0`.
/// In words: the user holds at least one of the bits in `mask` on the repository.
///
/// The two ids are bound as `Uuid` parameters and `mask` as a `BigInt`
/// parameter. Under PostgreSQL operator precedence `&` binds tighter than
/// `!=`, so the comparison applies to the masked value.
#[macro_export]
macro_rules! has_permissions {
    ($user:expr, $repo:expr, $mask:expr) => {{
        diesel::dsl::sql::<diesel::sql_types::Bool>("permissions(")
            .bind::<diesel::sql_types::Uuid, _>($user)
            .sql(", ")
            .bind::<diesel::sql_types::Uuid, _>($repo)
            .sql(") & ")
            .bind::<diesel::sql_types::BigInt, _>($mask)
            .sql(" != 0")
    }};
}

/// Returns `true` when every character of `u` is an ASCII letter, an ASCII
/// digit, `-`, `_` or `.`.
///
/// The empty string is accepted, as are `.` and `..`. Callers needing a
/// non-empty or path-safe name must check that separately.
pub fn is_valid_name(u: &str) -> bool {
    u.chars()
        .all(|c| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.'))
}

/// Returns `true` when every character of `u` is an ASCII letter, an ASCII
/// digit, or one of `-`, `_`, `.`, `:` and space. The `:` presumably allows
/// `repo:channel` strings (see `split_repo_channel`).
///
/// The empty string is accepted.
pub fn is_valid_repo_channel_name(u: &str) -> bool {
    let allowed = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | ':' | ' ');
    u.chars().all(allowed)
}

/// Splits a `repo:channel` path segment into its percent-decoded parts.
///
/// Only the first `:` separates the repository from the channel. Any later `:`
/// stays part of the channel name instead of being silently dropped, so
/// `"r:a:b"` yields `("r", Some("a:b"))`. Without a `:` the whole input is the
/// repository name and the channel is `None`. An encoded colon (`%3A`) is
/// decoded only after splitting, so it never acts as the separator. Invalid
/// UTF-8 produced by decoding is replaced lossily.
pub fn split_repo_channel<'a>(
    repo: &'a str,
) -> (std::borrow::Cow<'a, str>, Option<std::borrow::Cow<'a, str>>) {
    use percent_encoding::percent_decode_str;
    let decode = |s: &'a str| percent_decode_str(s).decode_utf8_lossy();
    match repo.split_once(':') {
        Some((name, channel)) => (decode(name), Some(decode(channel))),
        None => (decode(repo), None),
    }
}