+ use ci::Message;
+ use serde_derive::*;
+ use std::io::Read;
+ use std::net::{Ipv6Addr, SocketAddr, ToSocketAddrs};
+ use std::path::{Path, PathBuf};
+ use std::sync::{Arc, Mutex};
+ use std::time::{Duration, SystemTime};
+ use tracing::*;
+ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
+
+ #[derive(Serialize, Deserialize)]
+ struct ConfigFile {
+ key_path: String,
+ port: u16,
+ timeout_secs: usize,
+ server_public_keys: Vec<String>,
+ log_path: String,
+ tarball_path: String,
+ }
+
+ #[derive(Serialize, Deserialize)]
+ struct BuildResult {
+ finished: chrono::DateTime<chrono::Utc>,
+ status: Option<i32>,
+ link: Option<PathBuf>,
+ job: ci::Job,
+ }
+
+ use clap::*;
+
+ #[derive(Debug, Parser)]
+ pub struct App {
+ #[arg(short, long)]
+ config: PathBuf,
+ }
+
+ #[tokio::main]
+ async fn main() {
+ tracing_subscriber::registry()
+ .with(
+ tracing_subscriber::EnvFilter::try_from_default_env()
+ .unwrap_or_else(|_| "ci=debug".into()),
+ )
+ .with(tracing_subscriber::fmt::layer())
+ .init();
+ let matches = App::parse();
+ let conf: ConfigFile =
+ toml::from_str(&std::fs::read_to_string(&matches.config).unwrap()).unwrap();
+ let config = Arc::new(thrussh::client::Config::default());
+ let addr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0);
+ let addr = (addr, conf.port).to_socket_addrs().unwrap().next().unwrap();
+ let key = Arc::new(thrussh_keys::load_secret_key(&conf.key_path, None).unwrap());
+ let (sender, mut receiver) = tokio::sync::mpsc::channel(1);
+ let client = CiClient {
+ process: Arc::new(Mutex::new(Process::default())),
+ log_path: Path::new(&conf.log_path).to_path_buf(),
+ tarball_path: Path::new(&conf.tarball_path).to_path_buf(),
+ last_window_adjustment: SystemTime::now(),
+ server_public_keys: Arc::new(
+ conf.server_public_keys
+ .iter()
+ .map(|p| thrussh_keys::parse_public_key_base64(p).unwrap())
+ .collect(),
+ ),
+ sender,
+ };
+ loop {
+ if let Err(e) = client
+ .protocol(&addr, config.clone(), key.clone(), &mut receiver)
+ .await
+ {
+ error!("restarting because of error: {:?}", e)
+ }
+ tokio::time::sleep(std::time::Duration::from_secs(1)).await;
+ }
+ }
+
+ #[derive(Clone, Debug)]
+ pub struct CiClient {
+ process: Arc<Mutex<Process>>,
+ tarball_path: PathBuf,
+ log_path: PathBuf,
+ last_window_adjustment: SystemTime,
+ server_public_keys: Arc<Vec<thrussh_keys::key::PublicKey>>,
+ sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
+ }
+
+ #[derive(Debug, Default)]
+ struct Process {
+ child: Option<tokio::task::JoinHandle<Result<(), anyhow::Error>>>,
+ job: Option<ci::Job>,
+ tarball: Option<(std::fs::File, usize)>,
+ }
+
+ impl Process {
+ fn is_ready(&self) -> bool {
+ self.child.is_none() && self.job.is_none()
+ }
+ }
+
+ impl CiClient {
+ pub async fn protocol(
+ &self,
+ addr: &SocketAddr,
+ config: Arc<thrussh::client::Config>,
+ key: Arc<thrussh_keys::key::KeyPair>,
+ receiver: &mut tokio::sync::mpsc::Receiver<(ci::Job, Option<i32>, Option<PathBuf>)>,
+ ) -> Result<(), anyhow::Error> {
+ let mut h = thrussh::client::connect(config, &addr, self.clone()).await?;
+ debug!("Opening session");
+ if !h.authenticate_publickey("ci", key).await? {
+ return Ok(());
+ }
+ let mut channel = h.channel_open_session().await?;
+ channel
+ .data(
+ &bincode::serialize(&Message::Handshake {
+ version: ci::VERSION,
+ id: 0,
+ })
+ .unwrap()[..],
+ )
+ .await?;
+ debug!("handshake done");
+ 'outer: loop {
+ if self.process.lock().unwrap().is_ready() {
+ channel
+ .data(&bincode::serialize(&Message::Ready).unwrap()[..])
+ .await?;
+ debug!("ready");
+ }
+ loop {
+ tokio::select! {
+ msg = channel.wait() => {
+ debug!("msg = {:?}", msg);
+ if !self.handle_msg(&mut channel, &self.sender, msg).await? {
+ break 'outer
+ }
+ }
+ msg = receiver.recv() => {
+ debug!("message {:#?}", msg);
+ if let Some(p) = self.process.lock().unwrap().child.take() {
+ p.await??
+ }
+ if let Some((job, exit_status, path)) = msg {
+ self.send_log(&mut channel, job, exit_status, path).await?
+ }
+ channel.data(&bincode::serialize(&Message::Ready).unwrap()[..]).await?;
+ }
+ }
+ }
+ }
+ Ok(())
+ }
+
+ async fn handle_msg(
+ &self,
+ channel: &mut thrussh::client::Channel,
+ sender: &tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
+ msg: Option<thrussh::ChannelMsg>,
+ ) -> Result<bool, anyhow::Error> {
+ match msg {
+ Some(thrussh::ChannelMsg::Data { data }) => {
+ let mut proc = self.process.lock().unwrap();
+ if let Some((mut f, mut len)) = proc.tarball.take() {
+ debug!("len = {:?}", len);
+ use std::io::Write;
+ f.write_all(&data)?;
+ len -= data.len();
+ if len > 0 {
+ proc.tarball = Some((f, len));
+ }
+ return Ok(true);
+ }
+ let msg = bincode::deserialize::<Message>(&data);
+ debug!("msg = {:?}", msg);
+ match msg {
+ Ok(Message::Job(job)) => {
+ self.handle_job(channel, sender.clone(), &mut proc, job)
+ .await?
+ }
+ Ok(Message::Chunk { id, len, .. }) => {
+ let p = self.tarball_path.join(&format!("{}.tar.gz.tmp", id));
+ if len == 0 {
+ let p2 = self.tarball_path.join(&format!("{}.tar.gz", id));
+ std::fs::rename(&p, &p2)?;
+ proc.tarball = None;
+ let job = proc.job.take().unwrap();
+ self.handle_job(channel, sender.clone(), &mut proc, job)
+ .await?;
+ return Ok(true);
+ }
+ let file = std::fs::OpenOptions::new()
+ .write(true)
+ .create(true)
+ .append(true)
+ .open(&p)
+ .unwrap();
+ proc.tarball = Some((file, len as usize));
+ }
+ Ok(msg) => {
+ debug!("msg = {:?}", msg);
+ }
+ _ => return Ok(false),
+ }
+ }
+ None => return Ok(false),
+ msg => debug!("{:?}", msg),
+ }
+ Ok(true)
+ }
+
+ async fn handle_job(
+ &self,
+ channel: &mut thrussh::client::Channel,
+ sender: tokio::sync::mpsc::Sender<(ci::Job, Option<i32>, Option<PathBuf>)>,
+ process: &mut Process,
+ job: ci::Job,
+ ) -> Result<(), anyhow::Error> {
+ let p = self.tarball_path.join(&format!("{}.tar.gz", job.id));
+ debug!("p = {:?}", p);
+ if std::fs::metadata(&p).is_err() {
+ debug!("getting tarball");
+ channel
+ .data(&bincode::serialize(&Message::GetTarball { id: job.id }).unwrap()[..])
+ .await?;
+ process.job = Some(job);
+ return Ok(());
+ }
+ debug!("tar = {:?}", p);
+ let status = std::process::Command::new("tar")
+ .args(&["-xf", p.to_str().unwrap()])
+ .current_dir(&self.tarball_path)
+ .status()
+ .unwrap();
+ debug!("nix: {:?}", status);
+
+ let tarballp = self.tarball_path.join(job.id.to_string());
+ let logp = self.log_path.clone();
+
+ let result_path = logp.join(&format!("{}.result", job.id));
+ if let Ok(mut f) = std::fs::File::open(&result_path) {
+ if let Ok(build_result) = serde_json::from_reader::<_, BuildResult>(&mut f) {
+ sender
+ .send((build_result.job, build_result.status, build_result.link))
+ .await?;
+ return Ok(());
+ }
+ }
+
+ debug!("p = {:?}", tarballp);
+ process.child = Some(tokio::task::spawn(async move {
+ let mut process = tokio::process::Command::new("nix-build")
+ .arg("default.nix")
+ .current_dir(&tarballp)
+ .stdin(std::process::Stdio::null())
+ .stdout(std::process::Stdio::piped())
+ .stderr(std::process::Stdio::piped())
+ .spawn()
+ .unwrap();
+ let stdout = process.stdout.as_mut().unwrap();
+ let stderr = process.stderr.as_mut().unwrap();
+ let mut fstdout =
+ tokio::fs::File::create(logp.join(&format!("{}.stdout", job.id))).await?;
+ let mut fstderr =
+ tokio::fs::File::create(logp.join(&format!("{}.stderr", job.id))).await?;
+ let (a, b) = futures::future::join(
+ tokio::io::copy(stdout, &mut fstdout),
+ tokio::io::copy(stderr, &mut fstderr),
+ )
+ .await;
+ a?;
+ b?;
+ let status = process.wait().await?;
+ debug!("status = {:?}", status);
+
+ let mut result_file = std::fs::File::create(&result_path)?;
+ let link = std::fs::read_link(&tarballp.join("result")).ok();
+ serde_json::to_writer(
+ &mut result_file,
+ &BuildResult {
+ finished: chrono::Utc::now(),
+ status: status.code(),
+ job: job.clone(),
+ link: link.clone(),
+ },
+ )?;
+
+ sender.send((job, status.code(), link)).await?;
+
+ std::fs::remove_dir_all(&tarballp).unwrap_or(());
+ std::fs::remove_file(&p)?;
+
+ Ok(())
+ }));
+ Ok(())
+ }
+
+ async fn send_log(
+ &self,
+ channel: &mut thrussh::client::Channel,
+ job: ci::Job,
+ exit_status: Option<i32>,
+ path: Option<PathBuf>,
+ ) -> Result<(), anyhow::Error> {
+ let id = job.id;
+ let msg = Message::Log {
+ job,
+ exit_status,
+ path,
+ };
+ channel.data(&bincode::serialize(&msg).unwrap()[..]).await?;
+
+ let mut buf = Vec::with_capacity(4096);
+ debug!(
+ "stdout: {:?}",
+ self.log_path.join(&format!("{}.stdout", id))
+ );
+ if let Ok(ref mut stdout) =
+ std::fs::File::open(&self.log_path.join(&format!("{}.stdout", id)))
+ {
+ let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
+ buf.resize(len, 0);
+ while let Ok(n) = stdout.read(&mut buf) {
+ if n == 0 {
+ channel
+ .data(
+ &bincode::serialize(&Message::Chunk {
+ id,
+ stderr: false,
+ len: 0,
+ })
+ .unwrap()[..],
+ )
+ .await?;
+ break;
+ }
+ channel
+ .data(
+ &bincode::serialize(&Message::Chunk {
+ id,
+ stderr: false,
+ len: n as u32,
+ })
+ .unwrap()[..],
+ )
+ .await?;
+ channel.data(&buf[..n]).await?
+ }
+ }
+ if let Ok(ref mut stdout) =
+ std::fs::File::open(&self.log_path.join(&format!("{}.stderr", id)))
+ {
+ let len = channel.writable_packet_size().min(MAX_BUF_SIZE);
+ buf.resize(len, 0);
+ while let Ok(n) = stdout.read(&mut buf) {
+ if n == 0 {
+ channel
+ .data(
+ &bincode::serialize(&Message::Chunk {
+ id,
+ stderr: true,
+ len: 0,
+ })
+ .unwrap()[..],
+ )
+ .await?;
+ break;
+ }
+ channel
+ .data(
+ &bincode::serialize(&Message::Chunk {
+ id,
+ stderr: true,
+ len: n as u32,
+ })
+ .unwrap()[..],
+ )
+ .await?;
+ channel.data(&buf[..n]).await?
+ }
+ }
+ Ok(())
+ }
+ }
+
/// Upper bound, in bytes (64 KiB), on the buffer `send_log` reads log files into.
const MAX_BUF_SIZE: usize = 64 * 1024;
+
+ impl thrussh::client::Handler for CiClient {
+ type Error = anyhow::Error;
+
+ fn check_server_key(
+ self,
+ server_public_key: &thrussh_keys::key::PublicKey,
+ ) -> impl futures::Future<Output = Result<(Self, bool), Self::Error>> {
+ let valid = self
+ .server_public_keys
+ .iter()
+ .any(|p| p == server_public_key);
+ futures::future::ready(Ok((self, valid)))
+ }
+
+ fn adjust_window(&mut self, _channel: thrussh::ChannelId, target: u32) -> u32 {
+ let elapsed = self.last_window_adjustment.elapsed().unwrap();
+ self.last_window_adjustment = SystemTime::now();
+ if target >= 10_000_000 {
+ return target;
+ }
+ if elapsed < Duration::from_secs(2) {
+ target * 2
+ } else if elapsed > Duration::from_secs(8) {
+ target / 2
+ } else {
+ target
+ }
+ }
+ }