use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
State,
},
response::IntoResponse,
routing::get,
Router,
};
use axum_extra::{headers::UserAgent, TypedHeader};
use remote_test::{ServerMessage, WorkPackage};
use tracing::level_filters::LevelFilter;
use std::{collections::HashMap, net::SocketAddr};
use std::{ops::ControlFlow, process::Stdio};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, Layer};
use axum::extract::connect_info::ConnectInfo;
use futures::{sink::SinkExt, stream::StreamExt};
use tokio::sync::mpsc::{unbounded_channel, UnboundedSender};
use actix::prelude::*;
#[derive(Default)]
struct Worker {
clients: HashMap<SocketAddr, UnboundedSender<ServerMessage>>,
}
impl Worker {
fn send_all(&self, msg: ServerMessage) {
for (_, c) in &self.clients {
c.send(msg.clone()).unwrap();
}
}
}
impl Actor for Worker {
type Context = Context<Self>;
}
#[derive(Message)]
#[rtype(result = "()")]
struct StartWork(WorkPackage);
impl Handler<StartWork> for Worker {
type Result = ();
fn handle(&mut self, StartWork(msg): StartWork, _ctx: &mut Context<Self>) -> Self::Result {
self.send_all(ServerMessage::Ok { data: None });
std::fs::write(&msg.path, &msg.binary).unwrap();
let dir = msg.path.parent().unwrap().to_owned();
let child = std::process::Command::new(msg.path)
.args(msg.args)
.envs(msg.env)
.current_dir(msg.cwd.unwrap_or(dir))
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.unwrap();
let output = child.wait_with_output().unwrap();
let stdout = String::from_utf8(output.stdout).unwrap();
let stderr = String::from_utf8(output.stderr).unwrap();
self.send_all(ServerMessage::ExecutionFinished { stdout, stderr });
}
}
/// Registers a client's outgoing channel with the [`Worker`] so it receives broadcasts.
///
/// Replies with the still-open channel previously registered for the same address, if any.
#[derive(Message)]
#[rtype(result = "Option<UnboundedSender<ServerMessage>>")]
struct Register((SocketAddr, UnboundedSender<ServerMessage>));

impl Handler<Register> for Worker {
    type Result = Option<UnboundedSender<ServerMessage>>;

    /// Stores `tx` for `who`. First it drops every entry whose receiver has gone away
    /// (e.g. its websocket task ended), so `clients` does not grow without bound as
    /// peers connect and disconnect.
    fn handle(&mut self, Register((who, tx)): Register, _ctx: &mut Context<Self>) -> Self::Result {
        self.clients.retain(|_, client| !client.is_closed());
        self.clients.insert(who, tx)
    }
}
/// Starts the [`Worker`] actor and serves the websocket endpoint `/ws` on
/// `127.0.0.1:3000`, logging at INFO level and above.
#[actix_rt::main]
async fn main() {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer().with_filter(LevelFilter::INFO))
        .init();
    let addr = Worker::default().start();
    let app = Router::new().route("/ws", get(ws_handler)).with_state(addr);
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .expect("failed to bind 127.0.0.1:3000");
    // INFO, not DEBUG: the subscriber above filters at INFO, so a `debug!` here would
    // never be printed.
    tracing::info!(
        "listening on {}",
        listener
            .local_addr()
            .expect("a bound listener has a local address")
    );
    // Connect info is required by the `ConnectInfo<SocketAddr>` extractor in `ws_handler`.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<SocketAddr>(),
    )
    .await
    .expect("HTTP server terminated with an error");
}
/// Upgrades a `/ws` request to a websocket and hands the socket to `handle_socket`.
///
/// Prints the peer's `User-Agent` header (or "Unknown browser" when it is absent)
/// together with its address before upgrading.
async fn ws_handler(
    State(worker): State<Addr<Worker>>,
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<UserAgent>>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
) -> impl IntoResponse {
    let user_agent = user_agent.map_or_else(
        || String::from("Unknown browser"),
        |TypedHeader(agent)| agent.to_string(),
    );
    println!("`{user_agent}` at {addr} connected.");
    ws.on_upgrade(move |socket| handle_socket(socket, addr, worker))
}
/// Drives a single websocket connection from `who` until either direction stops.
///
/// It runs in four stages:
/// 1. Sends an initial ping.
/// 2. Registers this client with the [`Worker`] so it receives broadcasts.
/// 3. Runs two tasks: one feeds every incoming frame to [`process_message`] and queues
///    its reply, and the other serialises queued [`ServerMessage`]s as JSON text frames.
/// 4. When either task finishes, aborts the other.
async fn handle_socket(mut socket: WebSocket, who: SocketAddr, worker: Addr<Worker>) {
    if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() {
        println!("Pinged {who}...");
    } else {
        println!("Could not send ping {who}!");
        return;
    }
    let (tx, mut rx) = unbounded_channel::<ServerMessage>();
    // Register before reading any frame. If a work package arrives in the very first
    // frame, this client must already be in the broadcast list, or its `Ok` and
    // `ExecutionFinished` messages would never reach it.
    if let Err(e) = worker.send(Register((who, tx.clone()))).await {
        println!("Could not register {who} with the worker: {e}");
        return;
    }
    let (mut sender, mut receiver) = socket.split();
    let mut recv_task = tokio::spawn(async move {
        let mut cnt = 0usize;
        loop {
            let msg = match receiver.next().await {
                Some(Ok(msg)) => msg,
                Some(Err(_)) => {
                    println!("client {who} abruptly disconnected");
                    break;
                }
                None => break,
            };
            cnt += 1;
            match process_message(msg, who, worker.clone()).await {
                ControlFlow::Continue(reply) => {
                    // A send error means the outgoing task is gone; nothing left to do.
                    if tx.send(reply).is_err() {
                        break;
                    }
                }
                ControlFlow::Break(()) => break,
            }
        }
        cnt
    });
    let mut send_task = tokio::spawn(async move {
        let mut cnt = 0usize;
        while let Some(msg) = rx.recv().await {
            let text = match serde_json::to_string(&msg) {
                Ok(text) => text,
                Err(e) => {
                    println!("Could not serialize message for {who}: {e}");
                    continue;
                }
            };
            // The client has gone away: stop instead of panicking the task.
            if sender.send(Message::Text(text)).await.is_err() {
                break;
            }
            cnt += 1;
        }
        cnt
    });
    tokio::select! {
        rv_a = (&mut send_task) => {
            match rv_a {
                Ok(a) => println!("{a} messages sent to {who}"),
                Err(a) => println!("Error sending messages {a:?}")
            }
            recv_task.abort();
        },
        rv_b = (&mut recv_task) => {
            match rv_b {
                Ok(b) => println!("Received {b} messages"),
                Err(b) => println!("Error receiving messages {b:?}")
            }
            send_task.abort();
        }
    }
    println!("Websocket context {who} destroyed");
}
/// Handles one incoming websocket frame from `who`.
///
/// - Text frames are parsed as a JSON [`WorkPackage`] and sent to the worker as
///   [`StartWork`]. This awaits the worker's reply before returning.
/// - Binary frames are logged.
/// - Close frames are logged and end the connection ([`ControlFlow::Break`]).
/// - Any other frame (ping/pong) is ignored.
///
/// Otherwise returns [`ControlFlow::Continue`] carrying the reply to queue for the client:
/// - `ServerMessage::Error` when the JSON cannot be parsed or the worker cannot be reached.
/// - `ServerMessage::Ok { data: None }` in every other case.
async fn process_message(
    msg: Message,
    who: SocketAddr,
    worker: Addr<Worker>,
) -> ControlFlow<(), ServerMessage> {
    match msg {
        Message::Text(t) => {
            let pkg = match serde_json::from_str::<WorkPackage>(&t) {
                Ok(pkg) => pkg,
                Err(e) => {
                    return ControlFlow::Continue(ServerMessage::Error {
                        msg: format!("{e}"),
                    });
                }
            };
            // A closed mailbox (worker stopped) is reported to the client rather than
            // panicking this connection's task.
            if let Err(e) = worker.send(StartWork(pkg)).await {
                return ControlFlow::Continue(ServerMessage::Error {
                    msg: format!("{e}"),
                });
            }
        }
        Message::Binary(d) => {
            println!(">>> {} sent {} bytes: {:?}", who, d.len(), d);
        }
        Message::Close(c) => {
            if let Some(cf) = c {
                println!(
                    ">>> {} sent close with code {} and reason `{}`",
                    who, cf.code, cf.reason
                );
            } else {
                println!(">>> {who} somehow sent close message without CloseFrame");
            }
            return ControlFlow::Break(());
        }
        _ => (),
    }
    ControlFlow::Continue(ServerMessage::Ok { data: None })
}