use axum::{extract::State, http};
use http_body_util::BodyExt;
use hyper::header::*;
use tracing::*;
/// Response body type used throughout this module: one fully-buffered chunk of bytes.
type Body = http_body_util::Full<bytes::Bytes>;

/// Value sent in the `Strict-Transport-Security` header: a max-age of
/// 31536000 seconds (365 days), applied to subdomains as well.
pub const HSTS: &str = "max-age=31536000; includeSubDomains";

/// An `http` response builder together with a flag that forces the body to be
/// treated as HTML (and minified) when the response is finished.
pub struct BodyBuilder {
    /// Builder carrying the status line and headers of the response.
    pub http: http::response::Builder,
    /// Force HTML minification even if no `text/html` `Content-Type` is set.
    pub html: bool,
}
impl BodyBuilder {
    /// Finishes the response with `body`.
    ///
    /// The body is minified as HTML first when `self.html` is set, or when the
    /// builder's `Content-Type` media type is `text/html`. Media-type
    /// parameters such as `charset` are ignored, and the comparison is
    /// ASCII case-insensitive, as media types are.
    ///
    /// # Errors
    /// Returns the `http::Error` recorded by the underlying builder, if any.
    fn body(self, body: bytes::Bytes) -> Result<hyper::Response<Body>, http::Error> {
        let html_content_type = self
            .http
            .headers_ref()
            .and_then(|headers| headers.get(CONTENT_TYPE))
            .and_then(|value| value.to_str().ok())
            .map(|content_type| {
                // Drop parameters: "text/html; charset=utf-8" -> "text/html".
                let media_type = content_type
                    .split_once(';')
                    .map_or(content_type, |(media, _)| media);
                media_type.trim().eq_ignore_ascii_case("text/html")
            })
            .unwrap_or(false);
        let data = if self.html || html_content_type {
            minify_html::minify(&body, &minify_html::Cfg::default()).into()
        } else {
            body
        };
        self.http.body(data.into())
    }
}
#[axum::debug_handler]
pub async fn node_proxy_resp(
State(config): State<crate::Config>,
req: axum::extract::Request,
) -> hyper::Response<Body> {
node_proxy(&config, req).await.unwrap()
}
pub async fn node_proxy(
config: &crate::config::Config,
req: axum::extract::Request,
) -> Result<hyper::Response<Body>, axum::http::Error> {
match node_proxy_(config, req).await {
Ok(resp) => {
let (mut parts, body) = resp.into_parts();
let body = body.collect().await.unwrap().to_bytes();
parts.headers.insert(
hyper::header::STRICT_TRANSPORT_SECURITY,
HSTS.try_into().unwrap(),
);
let html = parts
.headers
.get(CONTENT_TYPE)
.and_then(|x| x.to_str().ok())
== Some("text/html");
let data = if html {
minify_html::minify(&body, &minify_html::Cfg::default()).into()
} else {
body
};
let len = data.len();
let mut resp = hyper::Response::new(data.into());
*resp.status_mut() = parts.status;
let h = resp.headers_mut();
for (k, v) in parts.headers {
h.insert(k.unwrap(), v);
}
h.insert(
hyper::header::CONTENT_LENGTH,
len.to_string().parse().unwrap(),
);
Ok(resp)
}
Err(e) => {
error!("proxy error {:?}", e);
let resp = http::response::Response::builder()
.status(http::StatusCode::INTERNAL_SERVER_ERROR)
.header(http::header::STRICT_TRANSPORT_SECURITY, HSTS);
Ok(BodyBuilder {
http: resp,
html: false,
}
.body("".into())?)
}
}
}
pub async fn node_proxy_(
config: &crate::config::Config,
req: axum::extract::Request,
) -> Result<hyper::Response<hyper::body::Incoming>, crate::Error> {
let mut sender = if let Some(ref svelte_socket) = config.svelte_socket {
use hyper_util::rt::TokioIo;
use tokio::net::UnixStream;
let stream = UnixStream::connect(svelte_socket).await?;
let io = TokioIo::new(stream);
let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("Connection failed: {:?}", err);
}
});
sender
} else {
use hyper_util::rt::TokioIo;
use tokio::net::TcpStream;
let stream = TcpStream::connect("localhost:5173").await?;
let io = TokioIo::new(stream);
let (sender, conn) = hyper::client::conn::http1::handshake(io).await?;
tokio::task::spawn(async move {
if let Err(err) = conn.await {
error!("Connection failed: {:?}", err);
}
});
sender
};
let mut req_ = hyper::Request::builder().method(req.method().clone()).uri(
hyper::Uri::builder()
.path_and_query(req.uri().path_and_query().unwrap().clone())
.build()
.unwrap(),
);
for (k, v) in req.headers() {
req_ = req_.header(k, v)
}
if cfg!(debug_assertions) {
req_ = req_.header(HOST, "localhost:5173");
} else {
req_ = req_.header(HOST, "nest.pijul.org");
}
let body = req.into_body();
let req_ = req_.body(body)?;
Ok(sender.send_request(req_).await?)
}