use crate::{find_files::*, mount, Error};
use futures::{SinkExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::ffi::CString;
use std::os::unix::fs::PermissionsExt;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tracing::*;

/// A request to build a package.
///
/// Requests are bincode-encoded and sent to the container process
/// together with a request id.
///
/// NOTE(review): the derived bincode encoding presumably follows the
/// field declaration order, so reordering fields would change the wire
/// format -- confirm before touching the layout.
#[derive(bincode::Encode, bincode::Decode, Debug)]
pub struct BuildRequest {
    /// Name of the package, to be appended to the hash to form a store path.
    /// It is also one of the inputs of the derivation hash.
    pub name: String,
    /// Store paths to mount into the build environment. Each path is
    /// bind-mounted at the same location under the build root, and the
    /// list is hashed in order as part of the derivation hash.
    pub paths: Vec<String>,
    /// Build script to run inside the build environment. It is written
    /// to a `<hash>-builder.sh` file in the store and executed directly.
    pub script: String,
    /// Target platform. Hashed into the derivation hash and used to
    /// look up libraries under `/usr/lib/<target>` when patching ELF files.
    pub target: String,
    /// Expected output hash. When this is `Some`, the build is started
    /// without a new network namespace.
    pub output_hash: Option<String>,
}

/// The channel used to communicate with the container process.
///
/// Built by `serve` for the parent process and consumed by `forward`.
pub struct ContainerChannel {
    /// Socket from which responses of the container process are read.
    r: std::os::unix::net::UnixStream,
    /// Socket to which requests for the container process are written.
    w: std::os::unix::net::UnixStream,
}

/// Forwards build requests from `receiver` to the container process
/// and routes every response back to the matching requester.
///
/// Each request is tagged with a fresh `u64` id, bincode-encoded as
/// `(id, BuildRequest)` and written to the channel as one
/// length-delimited frame. Responses are decoded as
/// `(id, ProcessResult)` and delivered through the oneshot sender that
/// was registered under that id.
///
/// Returns `Ok(())` once `receiver` has been closed (all senders
/// dropped) and every in-flight request has been answered.
///
/// # Errors
///
/// Returns an error if the channel sockets cannot be registered with
/// Tokio, if writing or reading a frame fails, if the container process
/// closes its end of the channel, or if a response cannot be decoded.
/// Requesters that are still pending at that point observe their
/// oneshot sender being dropped.
pub async fn forward(
    mut receiver: tokio::sync::mpsc::UnboundedReceiver<(
        BuildRequest,
        tokio::sync::oneshot::Sender<Result<PathBuf, String>>,
    )>,
    c: ContainerChannel,
) -> Result<(), Error> {
    let r = tokio::net::UnixStream::from_std(c.r)?;
    let w = tokio::net::UnixStream::from_std(c.w)?;

    let encoder = tokio_util::codec::LengthDelimitedCodec::new();
    let mut writer = tokio_util::codec::FramedWrite::new(w, encoder);

    let decoder = tokio_util::codec::LengthDelimitedCodec::new();
    let mut reader = tokio_util::codec::FramedRead::new(r, decoder);

    // Requests that were sent to the process and not answered yet.
    let mut pending = HashMap::new();
    let mut id = 1u64;
    // Once the request channel is closed, `recv()` resolves to `None`
    // immediately on every poll; the branch is disabled from then on so
    // that the loop does not spin.
    let mut receiver_open = true;
    loop {
        if !receiver_open && pending.is_empty() {
            return Ok(());
        }
        tokio::select! {
            x = receiver.recv(), if receiver_open => {
                let Some((msg, resp)) = x else {
                    receiver_open = false;
                    continue;
                };
                let encoded = match bincode::encode_to_vec(&(id, msg), bincode::config::standard()) {
                    Ok(encoded) => encoded,
                    Err(e) => {
                        // Only this request is affected: report the
                        // failure to its requester and keep serving.
                        resp.send(Err(format!("{:?}", e))).unwrap_or(());
                        continue;
                    }
                };
                debug!("sending to process {:?}", encoded.len());
                pending.insert(id, resp);
                writer.send(bytes::Bytes::from(encoded)).await?;
                id += 1;
            }
            x = reader.next() => {
                let frame = match x {
                    Some(Ok(frame)) => frame,
                    Some(Err(e)) => return Err(e.into()),
                    None => {
                        return Err(std::io::Error::new(
                            std::io::ErrorKind::UnexpectedEof,
                            "container process closed the channel",
                        )
                        .into())
                    }
                };
                debug!("received process response {:?}", frame.len());
                let ((resp_id, resp), _) = bincode::decode_from_slice::<(u64, ProcessResult), _>(
                    &frame,
                    bincode::config::standard(),
                )
                .map_err(|e| {
                    std::io::Error::new(std::io::ErrorKind::InvalidData, format!("{:?}", e))
                })?;
                match pending.remove(&resp_id) {
                    // The requester may have gone away; that is not an error.
                    Some(chan) => chan.send(resp).unwrap_or(()),
                    None => warn!("response for unknown request id {}", resp_id),
                }
            }
        }
    }
}

/// Outcome of one build as sent back by the container process: the
/// output store path on success, or the debug-formatted error on failure.
type ProcessResult = Result<PathBuf, String>;

/// Start the container process. This function forks: the child serves
/// build requests until the parent closes its end of the channel and
/// then exits, while the parent returns. This needs to be a
/// separate process in order to run as root and create namespaces.
///
/// `user` and `store_path` are handed to the request loop of the child.
/// The returned [`ContainerChannel`] is the parent's end of the two
/// socket pairs connecting both processes.
///
/// # Panics
///
/// Panics if the socket pairs cannot be created or configured, if
/// `fork` fails, or (in the child) if the Tokio runtime cannot be
/// started.
pub fn serve(user: &str, store_path: &Path) -> ContainerChannel {
    // (r0, w0): the child writes responses to `w0`, the parent reads `r0`.
    // (r1, w1): the parent writes requests to `w1`, the child reads `r1`.
    let (r0, w0) = std::os::unix::net::UnixStream::pair().unwrap();
    let (r1, w1) = std::os::unix::net::UnixStream::pair().unwrap();
    r0.set_nonblocking(true).unwrap();
    r1.set_nonblocking(true).unwrap();
    w0.set_nonblocking(true).unwrap();
    w1.set_nonblocking(true).unwrap();

    // SAFETY: `fork` takes no arguments; the child below only uses the
    // sockets created above and a runtime it creates itself.
    let pid = unsafe { libc::fork() };
    if pid < 0 {
        // Without a child nothing would ever answer on the channel.
        panic!("fork failed: {:?}", std::io::Error::last_os_error());
    }
    if pid == 0 {
        // Close the inherited copies of the parent's ends. While the
        // child holds `w1` open, `r1` can never reach end-of-file, so
        // the child would outlive the parent forever.
        drop(r0);
        drop(w1);
        let rt = tokio::runtime::Runtime::new().unwrap();
        rt.block_on(async move {
            let r1 = tokio::net::UnixStream::from_std(r1).unwrap();
            let w0 = tokio::net::UnixStream::from_std(w0).unwrap();
            async_serve(user, store_path, r1, w0).await
        });
        // The request loop ends when the parent closes the channel:
        // leave without unwinding into the parent's copy of the stack.
        std::process::exit(0)
    } else {
        // Close the child's ends in the parent for the same reason.
        drop(r1);
        drop(w0);
        ContainerChannel { r: r0, w: w1 }
    }
}

/// Inner server callable from a Tokio runtime, used since Tokio
/// runtimes don't survive forks.
async fn async_serve(
    user: &str,
    store: &Path,
    r: tokio::net::UnixStream,
    w: tokio::net::UnixStream,
) {
    let encoder = tokio_util::codec::LengthDelimitedCodec::new();
    let writer = Arc::new(tokio::sync::Mutex::new(
        tokio_util::codec::FramedWrite::new(w, encoder),
    ));
    let decoder = tokio_util::codec::LengthDelimitedCodec::new();
    let mut reader = tokio_util::codec::FramedRead::new(r, decoder);
    debug!("drv process waiting");
    while let Some(received) = reader.next().await {
        if let Ok(received) = received {
            debug!("drv_process received {:?}", received.len());
            let ((id, rec_msg), _) = bincode::decode_from_slice::<(u64, BuildRequest), _>(
                &received,
                bincode::config::standard(),
            )
            .unwrap();
            let writer = writer.clone();
            let store = store.to_path_buf();
            let user = user.to_string();
            tokio::spawn(async move {
                let result: ProcessResult =
                    run_in_container(&user, &store, rec_msg).map_err(|e| format!("{:?}", e));
                debug!("result {:?}", result);
                let v = bincode::encode_to_vec(&(id, result), bincode::config::standard()).unwrap();
                let mut bytes = bytes::BytesMut::new();
                debug!("drv_process replying {:?}", v.len());
                bytes.extend_from_slice(&v);
                writer.lock().await.send(bytes.into()).await.unwrap()
            });
        }
    }
    info!("drv_process exited");
}

fn run_in_container(user: &str, store: &Path, r: BuildRequest) -> Result<PathBuf, Error> {
    let mut hasher = blake3::Hasher::new();
    hasher.update(r.name.as_bytes());
    hasher.update(b"\n");
    hasher.update(r.target.as_bytes());
    hasher.update(b"\n");
    debug!("run in container, path = {:#?}", r.paths);
    for p in r.paths.iter() {
        hasher.update(p.as_bytes());
        hasher.update(b"\n");
    }
    hasher.update(r.script.as_bytes());
    hasher.update(b"\n");
    let name = data_encoding::HEXLOWER.encode(hasher.finalize().as_bytes());

    // Guest dest.
    let dest = store.join(&name);

    if std::fs::metadata(&dest).is_ok() {
        let mut output_hasher = blake3::Hasher::new();
        let blakesums = dest.join("blake3sums");
        let file = match std::fs::File::open(&blakesums) {
            Ok(file) => file,
            Err(e) => {
                error!("Error {:?} {:?}: {:?}", blakesums, dest, e);
                return Err(e.into());
            }
        };
        output_hasher.update_reader(file)?;
        return Ok(store.join(&format!(
            "{}-{}",
            data_encoding::HEXLOWER.encode(output_hasher.finalize().as_bytes()),
            r.name,
        )));
    }

    // Tmp host path where things will be mounted.
    let tmp_dir = store.join(format!("{}.drv", name));

    // Full path of the store in the host.
    let tmp_store = Path::new(&tmp_dir).join(store.strip_prefix("/").unwrap());

    std::fs::create_dir_all(&tmp_store)?;

    // Host dest.
    let tmp_dest = tmp_store.join(&name);

    let newnet = if r.output_hash.is_some() {
        0
    } else {
        libc::CLONE_NEWNET
    };
    let pid = unsafe {
        libc::syscall(
            libc::SYS_clone3,
            &libc::clone_args {
                flags: (libc::CLONE_NEWPID | newnet | libc::CLONE_NEWNS) as u64,
                pidfd: 0,
                parent_tid: 0,
                child_tid: 0,
                stack: 0,
                stack_size: 0,
                tls: 0,
                set_tid: 0,
                set_tid_size: 0,
                cgroup: 0,
                exit_signal: libc::SIGCHLD as u64,
            },
            size_of::<libc::clone_args>(),
        )
    };

    if pid == 0 {
        match std::panic::catch_unwind(|| {
            inner_process(user, &r, &tmp_dir, &dest, &store, &tmp_store, &name)
        }) {
            Ok(Ok(())) => std::process::exit(0),
            Ok(Err(Error::BuildReturn { status })) => std::process::exit(status),
            _ => std::process::exit(1),
        }
    } else {
        debug!("waitpid");
        let mut status = 0;
        unsafe { libc::waitpid(pid as i32, &mut status, 0) };
        info!("return status {:?}", status);
        if status != 0 {
            debug!("returning error");
            return Err(Error::BuildReturn { status });
        }
        // Now that the paths have been unmounted, delete.
        if let Ok(dir) = std::fs::read_dir(&tmp_store) {
            for entry in dir {
                std::fs::remove_dir(&entry?.path()).unwrap_or(());
            }
            std::fs::remove_dir(&tmp_store).unwrap_or(());
        }
    }

    // And hash the output.
    debug!("tmp_dest {:?}", tmp_dest);
    let Ok(hashed) = hash_all(&tmp_dest) else {
        return Err(Error::NoDestDir);
    };
    let out = store.join(&format!(
        "{}-{}",
        data_encoding::HEXLOWER.encode(hashed.as_bytes()),
        r.name,
    ));

    std::fs::remove_dir_all(&dest).unwrap_or(());

    // Should be a symlink
    std::fs::rename(&tmp_dest, &dest).unwrap();
    match std::os::unix::fs::symlink(&dest, &out) {
        Ok(()) => (),
        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
            let got = std::fs::read_link(&out).unwrap();
            if dest != got {
                Err::<(), _>(Error::WrongResultSymlink {
                    expected: dest,
                    got,
                })
                .unwrap();
            }
        }
        Err(e) => return Err(e.into()),
    }
    Ok(out)
}

/// Patches the ELF file `f` so that it no longer depends on the `/usr`
/// and `/lib` links of the build environment.
///
/// If the interpreter lives under `/usr` or `/lib`, the symlinks on its
/// path are followed and the interpreter is set to the resolved path.
/// For every needed library found as a symlink in `/usr/lib`,
/// `/usr/lib64` or `/usr/lib/<target>`, the directory of the link
/// target is added to the runpath.
///
/// `root` is only used for logging.
///
/// Returns `Ok(false)` when `f` cannot be opened as an ELF file, has no
/// dynamic section, has an interpreter outside `/usr` and `/lib`, or
/// has neither an interpreter nor needed libraries. Otherwise returns
/// the value of `elf.update(None)`.
///
/// # Errors
///
/// Returns an I/O error if `f` cannot be opened for reading and
/// writing, or if the interpreter's symlink chain does not terminate.
///
/// # Panics
///
/// Panics if the ELF file cannot be parsed or updated, or if its
/// interpreter or needed entries are not valid UTF-8.
fn patch_result_elf(root: &Path, f: &Path, target: &str) -> Result<bool, Error> {
    use elfedit::*;

    /// Upper bound on the number of symlinks followed when resolving
    /// the interpreter, so that a link cycle cannot hang the build.
    const MAX_LINK_HOPS: usize = 40;

    /// Whether `p` textually starts with `/usr` or `/lib` (this also
    /// covers `/lib64`).
    fn in_system_dir(p: &Path) -> bool {
        p.to_str()
            .map_or(false, |s| s.starts_with("/usr") || s.starts_with("/lib"))
    }

    info!("patch_elf {:?}", f);
    let file = std::fs::OpenOptions::new()
        .read(true)
        .write(true)
        .open(f)?;

    let mut elf = match Elf::open(&file) {
        Ok(elf) => elf,
        Err(e) => {
            info!("error opening {:?}: {:?}", file, e);
            return Ok(false);
        }
    };
    info!("patching {:?}", f);
    let Some(parsed) = elf.parse().unwrap() else {
        info!("No dynamic section");
        return Ok(false);
    };
    let needed: Vec<_> = parsed
        .needed()
        .map(|x| x.unwrap().to_str().unwrap().to_string())
        .collect();
    if let Some(interp) = parsed.interpreter().unwrap() {
        let original = interp.to_str().unwrap();
        let mut resolved = Path::new(original).to_path_buf();
        debug!("existing interp: {resolved:?}");

        if !in_system_dir(&resolved) {
            // TODO: not sure what else to do here, we
            // might want to set the interpreter to a
            // different value (equivalent to "recompiling
            // everything" on Nix).
            info!("Interpreter is {original}. Already patched?");
            return Ok(false);
        }

        // Follow the links until the path leaves /usr and /lib or
        // stops being a symlink.
        let mut hops = 0;
        while in_system_dir(&resolved) {
            let Ok(link) = std::fs::read_link(&resolved) else {
                break;
            };
            if hops == MAX_LINK_HOPS {
                return Err(std::io::Error::from_raw_os_error(libc::ELOOP).into());
            }
            hops += 1;
            resolved = if link.is_relative() {
                match resolved.parent() {
                    Some(dir) => dir.join(link),
                    None => link,
                }
            } else {
                link
            };
        }
        debug!("resolved interp: {resolved:?}");
        let subst = CString::new(std::os::unix::ffi::OsStrExt::as_bytes(resolved.as_os_str()))
            .expect("paths do not contain NUL bytes");
        info!("set interpreter {:?}", subst);
        elf.set_interpreter(subst.to_bytes_with_nul());
    } else if needed.is_empty() {
        return Ok(false);
    }

    let target_lib_dir = format!("/usr/lib/{target}");
    let search_dirs = ["/usr/lib", "/usr/lib64", target_lib_dir.as_str()];

    let mut deps_h = HashSet::new();
    let mut path = String::new();

    for n in needed.iter() {
        for dir in search_dirs.iter() {
            let link_path = Path::new(dir).join(n);
            let Ok(dep) = std::fs::read_link(&link_path) else {
                continue;
            };
            // A relative target is relative to the directory of the
            // link. Left as is, its parent could be the empty string,
            // which the dynamic loader treats as the current directory.
            let dep = if dep.is_relative() {
                match link_path.parent() {
                    Some(link_dir) => link_dir.join(dep),
                    None => dep,
                }
            } else {
                dep
            };
            if !deps_h.insert(dep.clone()) {
                continue;
            }
            debug!("patch_elf needed: {:?}", dep);
            debug!("root: {:?}", root);
            // The runpath is built as a string: skip directories that
            // cannot be represented instead of panicking.
            let Some(dep_dir) = dep.parent().and_then(Path::to_str) else {
                continue;
            };
            if !path.is_empty() {
                path.push(':')
            }
            path.push_str(dep_dir);
        }
    }
    // The runpath is passed NUL-terminated, so a length of one means
    // that it is empty.
    path.push('\0');
    info!("Setting path {:?}", path);
    if path.len() > 1 {
        elf.set_runpath(&path.as_bytes());
    }

    // NOTE(review): this panics on failure; switch to `?` if `Error`
    // can be built from the error type of `update`.
    Ok(elf.update(None).unwrap())
}

/// Sets up the build root under `tmp_dir`, runs the build script in it
/// and patches the ELF files found in the output.
///
/// Can't be async since this needs to fork and Tokio doesn't work
/// across forks.
///
/// The fork is used to manage the mounts: the forked child chroots
/// into `tmp_dir`, drops privileges to `user` and executes the builder,
/// while this process waits for it and then chroots itself to patch
/// the results.
///
/// * `user`: account that owns `tmp_dir` and `tmp_store` and runs the
///   builder.
/// * `r`: the build request (input paths, script and target).
/// * `tmp_dir`: host directory used as the root of the build
///   environment.
/// * `dest`: output directory as seen from inside the build root;
///   passed to the builder as `DESTDIR`.
/// * `store`: store directory as seen from inside the build root.
/// * `tmp_store`: host path of the store inside `tmp_dir`.
/// * `name`: derivation hash, used to name the builder script.
///
/// # Errors
///
/// Returns `Error::BuildReturn` with the raw wait status if the builder
/// does not exit successfully, and I/O errors if the builder script
/// cannot be installed, `user` is unknown, or `fork`, `waitpid` or
/// `chroot` fail.
///
/// # Panics
///
/// Panics if a mount point or link of the build root cannot be set up,
/// or, in the forked child, if dropping privileges or `execve` fails.
fn inner_process(
    user: &str,
    r: &BuildRequest,
    tmp_dir: &Path,
    dest: &Path,
    store: &Path,
    tmp_store: &Path,
    name: &str,
) -> Result<(), Error> {
    let tmp_usr = tmp_dir.join("usr");
    std::fs::create_dir_all(&tmp_usr)?;

    mount::make_root_private().unwrap();

    // The mount handles are forgotten on purpose so that the mounts
    // stay in place for the rest of the build.
    std::mem::forget(mount::Mount::ramfs(&tmp_usr).unwrap());

    std::os::unix::fs::symlink("/usr/lib", &tmp_dir.join("lib")).unwrap_or(());
    std::os::unix::fs::symlink("/usr/lib64", &tmp_dir.join("lib64")).unwrap_or(());

    for host in r.paths.iter() {
        let guest = tmp_dir.join(host.strip_prefix("/").unwrap());
        debug!("mounting {:?} to {:?}", host, guest);
        std::fs::create_dir_all(&guest).unwrap();
        debug!("created {:?} to {:?}", host, guest);
        match mount::Mount::bind(host, guest.to_str().unwrap()) {
            Ok(m) => std::mem::forget(m),
            Err(e) => error!("{e:?}"),
        }
        debug!("mounted {:?}", guest);
        // If the input is itself a symlink to an absolute path, make
        // its target available in the build root as well.
        if let Ok(target) = std::fs::read_link(&host) {
            info!("mounting link {:?}", target);
            if target.is_absolute() {
                let guest = tmp_dir.join(target.strip_prefix("/").unwrap());
                std::fs::create_dir_all(&guest).unwrap();
                match mount::Mount::bind(&target, guest.to_str().unwrap()) {
                    Ok(m) => std::mem::forget(m),
                    Err(e) => error!("{e:?}"),
                }
            }
        }

        // Mirror the `usr` tree of the input into the `usr` of the
        // build root: directories are created, links are copied and
        // other files become links into the input.
        let guest_usr = guest.join("usr");
        if let Ok(find) = find_files(guest_usr.clone()) {
            std::fs::create_dir(&guest_usr).unwrap_or(());
            for (f, m) in find {
                // Stripped: guest path of the link.
                let stripped = f.strip_prefix(&guest).unwrap();

                // Target: host path of the link.
                let target = tmp_dir.join(stripped);
                debug!("mount link {:?} {:?} {:?}", f, target, m);
                if m.is_dir() {
                    std::fs::create_dir(&target).unwrap_or(());
                } else if let Ok(out) = std::fs::read_link(&f) {
                    if std::fs::remove_file(&target).is_err() {
                        std::fs::remove_dir_all(&target).unwrap_or(());
                    }
                    std::os::unix::fs::symlink(&out, &target).unwrap();
                } else {
                    if std::fs::remove_file(&target).is_err() {
                        std::fs::remove_dir_all(&target).unwrap_or(());
                    }
                    std::os::unix::fs::symlink(Path::new(&host).join(&stripped), &target).unwrap();
                }
            }
        }
    }

    // Install the build script in the store: `tmp_builder` is its host
    // path and `builder` its path inside the build root.
    let builder_base = format!("{name}-builder.sh");
    let tmp_builder = tmp_store.join(&builder_base);
    let builder = store.join(&builder_base);
    std::fs::write(&tmp_builder, &r.script)?;
    let mut perm = std::fs::metadata(&tmp_builder)?.permissions();
    perm.set_mode(0o555);
    std::fs::set_permissions(&tmp_builder, perm)?;
    let (uid, gid) = {
        let user_ffi = CString::new(user).unwrap();
        // SAFETY: `user_ffi` is a valid NUL-terminated string.
        let pw = unsafe { libc::getpwnam(user_ffi.as_ptr()) };
        if pw.is_null() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!("unknown user {user:?}"),
            )
            .into());
        }
        // SAFETY: `pw` was checked to be non-null; the fields are
        // copied out before any other call could overwrite the entry.
        let pw = unsafe { &*pw };
        (pw.pw_uid, pw.pw_gid)
    };
    std::os::unix::fs::chown(tmp_dir, Some(uid), Some(gid))?;
    std::os::unix::fs::chown(tmp_store, Some(uid), Some(gid))?;

    // The builder's environment contains nothing but DESTDIR.
    let out_env = CString::new(format!("DESTDIR={}", dest.to_str().unwrap())).unwrap();
    let c = CString::new(builder.to_str().unwrap()).unwrap();

    // SAFETY: `fork` takes no arguments; the child only chroots, drops
    // privileges and replaces itself with the builder.
    let pid = unsafe { libc::fork() };
    if pid < 0 {
        // Otherwise `waitpid(-1, ..)` below would leave `status` at 0
        // and the build would be reported as successful.
        return Err(std::io::Error::last_os_error().into());
    }
    if pid == 0 {
        info!("chrooting to {:?}", tmp_dir);
        privdrop::PrivDrop::default()
            .chroot(tmp_dir)
            .user(user)
            .apply()
            .unwrap();
        std::env::set_current_dir("/").unwrap();
        debug!("execve {:?}", c);
        // SAFETY: both arrays are NULL-terminated and only contain
        // pointers to `CString`s that outlive the call.
        unsafe {
            libc::execve(
                c.as_ptr(),
                [c.as_ptr(), std::ptr::null()].as_ptr(),
                [out_env.as_ptr(), std::ptr::null()].as_ptr(),
            );
        }
        // `execve` only returns on failure.
        panic!("execve failed: {:?}", std::io::Error::last_os_error())
    }

    let mut status = 0;
    loop {
        // SAFETY: `status` is a valid, writable `c_int`.
        if unsafe { libc::waitpid(pid, &mut status, 0) } >= 0 {
            break;
        }
        let err = std::io::Error::last_os_error();
        if err.kind() != std::io::ErrorKind::Interrupted {
            return Err(err.into());
        }
    }
    debug!("fork returned {status}");
    if status != 0 {
        debug!("returning error");
        return Err(Error::BuildReturn { status });
    }

    // Enter the build root, so that `dest` and the `/usr` links resolve
    // the way they did for the builder. If this failed silently, `dest`
    // would be looked up on the host filesystem instead.
    info!("chrooting to {:?}", tmp_dir);
    let root_c = CString::new(tmp_dir.to_str().unwrap()).unwrap();
    // SAFETY: `root_c` is a valid NUL-terminated path.
    if unsafe { libc::chroot(root_c.as_ptr()) } != 0 {
        return Err(std::io::Error::last_os_error().into());
    }

    if let Ok(f) = find_files(dest.to_path_buf()) {
        for (f, _meta) in f {
            // Potentially patch the ELF.
            debug!("patching {f:?}");
            if let Err(e) = patch_result_elf(tmp_dir, &f, &r.target) {
                error!("{:?}", e);
            }
        }
    }
    Ok(())
}

/// Create the `blake3sums` file of `p` and hash that file.
///
/// Every entry returned by `find_files(p)` that is a symlink is
/// recorded with its link target, and every regular file with the
/// hex-encoded BLAKE3 hash of its contents; other entries are skipped.
/// The records are sorted by path and written to `p/blake3sums` as
/// `<hash or link target> <path>` lines. The returned hash is the
/// BLAKE3 hash of exactly the bytes written to that file.
///
/// # Errors
///
/// Returns an error if `p` cannot be traversed, if a file cannot be
/// read, or if `blake3sums` cannot be written.
fn hash_all(p: &Path) -> Result<blake3::Hash, Error> {
    use std::io::Write;
    use std::os::unix::ffi::OsStrExt;

    // Paths and link targets are handled as raw bytes, so that names
    // that are not valid UTF-8 are hashed faithfully instead of
    // causing a panic.
    let mut hashes: Vec<(_, Vec<u8>)> = Vec::new();
    for (f, _meta) in find_files(p.to_path_buf())? {
        // hash + write
        info!("hashing {:?}", f);
        if let Ok(link) = std::fs::read_link(&f) {
            hashes.push((f, link.as_os_str().as_bytes().to_vec()))
        } else if f.is_file() {
            let file = std::fs::File::open(&f)?;
            let mut hasher = blake3::Hasher::new();
            hasher.update_reader(file)?;
            let hex = data_encoding::HEXLOWER.encode(hasher.finalize().as_bytes());
            hashes.push((f, hex.into_bytes()))
        }
    }
    // Sorting makes the result independent of the traversal order.
    hashes.sort_by(|a, b| a.0.cmp(&b.0));
    info!("hashed all");

    let mut output_hasher = blake3::Hasher::new();
    let blakesums = p.join("blake3sums");
    let mut file = std::io::BufWriter::new(std::fs::File::create(&blakesums)?);
    let mut line = Vec::new();
    for (path, hash) in &hashes {
        line.clear();
        line.extend_from_slice(hash);
        line.push(b' ');
        line.extend_from_slice(path.as_os_str().as_bytes());
        line.push(b'\n');
        // The file and the hasher must see the same bytes: the result
        // is recomputed later from the file alone.
        file.write_all(&line)?;
        output_hasher.update(&line);
    }
    // Dropping a `BufWriter` swallows write errors; flush explicitly.
    file.flush()?;
    Ok(output_hasher.finalize())
}