use anyhow::bail;
use async_trait::async_trait;
use serde::de::DeserializeOwned;
use tracing::debug;
use url::Url;
pub use http::{
HeaderMap,
StatusCode,
};
/// Result type used throughout this module; the error type is [`anyhow::Error`].
pub type Result<T> = anyhow::Result<T>;
/// Types that can be converted into a parsed [`Url`].
///
/// Implemented in this file for [`Url`] (identity), `&str` and `String`
/// (parsed with [`Url::parse`]); a parse failure is returned as an error.
///
/// NOTE(review): the `Send` supertrait presumably exists so that a
/// `U: IntoUrl` argument can be moved into the `Send` future that
/// `#[async_trait]` generates for [`Client::get`] — confirm.
pub trait IntoUrl: Send {
    /// Consumes `self` and returns the corresponding [`Url`].
    fn into_url(self) -> Result<Url>;
}
/// A [`Url`] is already parsed, so the conversion is the identity and
/// never fails.
impl IntoUrl for Url {
    fn into_url(self) -> Result<Self> {
        let already_parsed = self;
        Ok(already_parsed)
    }
}
/// Parses a string slice with [`Url::parse`]; a malformed URL becomes an
/// error.
impl IntoUrl for &str {
    fn into_url(self) -> Result<Url> {
        let parsed = Url::parse(self)?;
        Ok(parsed)
    }
}
/// Parses an owned string with [`Url::parse`]; a malformed URL becomes an
/// error.
impl IntoUrl for String {
    fn into_url(self) -> Result<Url> {
        Url::parse(&self).map_err(Into::into)
    }
}
/// Minimal asynchronous HTTP client: only `GET` requests are supported.
///
/// In this file it is implemented for `reqwest::Client` and, when built for
/// tests or with the `testing` feature, by mock and caching wrappers.
#[async_trait]
pub trait Client: Send + Sync {
    /// The response type returned by [`Client::get`].
    type Response: Response;
    /// Sends a `GET` request to `url`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The implementations in this file fail when
    /// `url` cannot be converted via [`IntoUrl::into_url`].
    async fn get<U: IntoUrl>(&self, url: U) -> Result<Self::Response>;
}
/// An HTTP response as returned by [`Client::get`].
///
/// The body-reading methods take `self` by value, so the body can be read
/// once. Status and headers can be inspected before that.
#[async_trait]
pub trait Response: Send + Sync {
    /// Consumes the response and returns its body as a `String`.
    async fn text(self) -> Result<String>;
    /// Consumes the response and deserialises its body as JSON into `T`.
    async fn json<T: DeserializeOwned>(self) -> Result<T>;
    /// The response's HTTP status code.
    fn status(&self) -> StatusCode;
    /// Length of the body in bytes, if the implementation knows it.
    fn content_length(&self) -> Option<u64>;
    /// The response headers.
    fn headers(&self) -> &HeaderMap;
}
/// Production [`Client`] backed by `reqwest`.
///
/// Non-success (non-2xx) responses are turned into errors, so callers only
/// ever receive successful responses.
#[async_trait]
impl Client for reqwest::Client {
    type Response = reqwest::Response;

    /// Sends a `GET` request to `url`.
    ///
    /// # Errors
    ///
    /// Fails if `url` cannot be converted, if the request cannot be sent, or
    /// if the server answers with a non-success status. In the last case the
    /// error message contains both the status and the response body, because
    /// the body alone may be empty.
    async fn get<U: IntoUrl>(&self, url: U) -> Result<Self::Response> {
        // Fully qualified so it is obvious this is reqwest's inherent `get`
        // and not a recursive call to this trait method.
        let resp = reqwest::Client::get(self, url.into_url()?).send().await?;
        let status = resp.status();
        debug!("{status}");
        if !status.is_success() {
            debug!("{:#?}", resp.headers());
            let body = resp.text().await?;
            bail!("request failed with status {status}: {body}");
        }
        Ok(resp)
    }
}
/// Adapts `reqwest::Response` to this module's [`Response`] trait by
/// forwarding each method to reqwest's inherent method of the same name.
#[async_trait]
impl Response for reqwest::Response {
    // Fully qualified calls make it explicit that these forward to reqwest's
    // inherent methods rather than recursing into the trait methods.
    fn status(&self) -> StatusCode {
        reqwest::Response::status(self)
    }

    fn content_length(&self) -> Option<u64> {
        reqwest::Response::content_length(self)
    }

    fn headers(&self) -> &HeaderMap {
        reqwest::Response::headers(self)
    }

    async fn text(self) -> Result<String> {
        reqwest::Response::text(self).await.map_err(Into::into)
    }

    async fn json<T: DeserializeOwned>(self) -> Result<T> {
        reqwest::Response::json::<T>(self).await.map_err(Into::into)
    }
}
/// Builds the default production [`Client`]: a `reqwest::Client` whose
/// `User-Agent` is `<crate name>/<crate version>`.
///
/// # Panics
///
/// Panics if `reqwest` cannot build the client. Per reqwest's documentation
/// this happens when the TLS backend cannot be initialised or the resolver
/// cannot load the system configuration.
pub fn client() -> impl Client<Response = reqwest::Response> {
    const UA: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
    reqwest::Client::builder()
        .user_agent(UA)
        .build()
        .expect("failed to build the default reqwest HTTP client")
}
/// Test helpers: [`Client`] implementations that avoid the network (mock)
/// or hit it at most once per URL (on-disk cache).
#[cfg(any(test, feature = "testing"))]
pub mod client {
    use std::path::{
        Path,
        PathBuf,
    };
    use std::sync::atomic::{
        AtomicUsize,
        Ordering,
    };

    use anyhow::Context;
    use http::{
        HeaderName,
        HeaderValue,
    };
    use tokio::io::{
        AsyncRead,
        AsyncReadExt as _,
        AsyncWrite,
        AsyncWriteExt as _,
    };
    use tracing::debug_span;

    use super::*;

    /// Returns a [`Client`] whose responses are produced by calling `f` with
    /// the requested URL. No network I/O takes place.
    ///
    /// Unlike the `reqwest` client, this mock does not turn non-success
    /// statuses returned by `f` into errors; they are passed through as-is.
    ///
    /// The response type is exposed as `reqwest::Response` so the mock can be
    /// wrapped by [`caching`].
    pub fn mock<F>(f: F) -> impl Client<Response = reqwest::Response>
    where
        F: Fn(Url) -> ::http::Response<reqwest::Body> + Send + Sync,
    {
        struct C<F> {
            respond: F,
        }

        #[async_trait]
        impl<F> Client for C<F>
        where
            F: Fn(Url) -> ::http::Response<reqwest::Body> + Send + Sync,
        {
            type Response = reqwest::Response;

            async fn get<U: IntoUrl>(&self, url: U) -> Result<Self::Response> {
                let url = url.into_url()?;
                Ok(Self::Response::from((self.respond)(url)))
            }
        }

        C { respond: f }
    }

    /// Wraps `inner` with an on-disk response cache rooted at `path`.
    ///
    /// - Each URL maps to `<path>/<blake3(url) hex>.http`, which holds the
    ///   response serialised as an HTTP/1.1 message.
    /// - On a miss, the response is fetched through `inner` and written
    ///   atomically (temporary file, then rename). The cache directory is
    ///   created if needed.
    /// - Every request, hit or miss, is then answered by parsing the cache
    ///   file. Entries never expire.
    pub fn caching<C, P>(inner: C, path: P) -> impl Client<Response = C::Response>
    where
        C: Client<Response = reqwest::Response>,
        P: Into<PathBuf>,
    {
        struct Caching<C> {
            inner: C,
            path: PathBuf,
        }

        #[async_trait]
        impl<C> Client for Caching<C>
        where
            C: Client<Response = reqwest::Response>,
        {
            type Response = C::Response;

            async fn get<U: IntoUrl>(&self, url: U) -> Result<Self::Response> {
                let url = url.into_url()?;
                let mut cache_path = self
                    .path
                    .join(blake3::hash(url.as_str().as_bytes()).to_hex().as_str());
                cache_path.set_extension("http");
                let span = debug_span!(
                    "cached::get",
                    "url" = url.as_str(),
                    "cache-path" = cache_path.display().to_string()
                );

                if !exists(&cache_path).await? {
                    span.in_scope(|| debug!("cache miss"));
                    let resp = self.inner.get(url).await?;
                    store(resp, &cache_path).await?;
                }

                span.in_scope(|| debug!("cache load"));
                let file = tokio::fs::File::open(&cache_path)
                    .await
                    .with_context(open_context("reading", &cache_path))?;
                read(tokio::io::BufReader::new(file)).await.with_context(|| {
                    format!("failed to load cache file: {}", cache_path.display())
                })
            }
        }

        Caching {
            inner,
            path: path.into(),
        }
    }

    /// Builds a lazily formatted error context for a failed cache-file open.
    fn open_context<'a>(what: &'static str, path: &'a Path) -> impl FnOnce() -> String + 'a {
        move || format!("failed to open cache file for {what}: {}", path.display())
    }

    /// Non-blocking counterpart of [`Path::try_exists`]. Returns `Ok(false)`
    /// only when the metadata lookup reports `NotFound`; any other I/O
    /// error is returned as an error.
    async fn exists(path: &Path) -> Result<bool> {
        match tokio::fs::metadata(path).await {
            Ok(_) => Ok(true),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(false),
            Err(err) => Err(err)
                .with_context(|| format!("failed to stat cache file: {}", path.display())),
        }
    }

    /// Stores `resp` at `cache_path` atomically.
    ///
    /// The data goes to a uniquely named temporary sibling (process id plus
    /// an in-process counter), so concurrent misses on the same URL never
    /// interleave their writes in one file. If anything fails, the temporary
    /// file is removed on a best-effort basis.
    async fn store(resp: reqwest::Response, cache_path: &Path) -> Result<()> {
        static NEXT_TMP: AtomicUsize = AtomicUsize::new(0);

        if let Some(dir) = cache_path.parent() {
            tokio::fs::create_dir_all(dir).await.with_context(|| {
                format!("failed to create cache directory: {}", dir.display())
            })?;
        }
        let tmp = cache_path.with_extension(format!(
            "{}.{}.tmp",
            std::process::id(),
            NEXT_TMP.fetch_add(1, Ordering::Relaxed)
        ));
        let result = write_and_rename(resp, &tmp, cache_path).await;
        if result.is_err() {
            // Best effort only: the original error is the one worth reporting.
            let _ = tokio::fs::remove_file(&tmp).await;
        }
        result
    }

    /// Writes `resp` to `tmp`, syncs it to disk and renames it to `dest`.
    async fn write_and_rename(resp: reqwest::Response, tmp: &Path, dest: &Path) -> Result<()> {
        let mut file = tokio::fs::File::create(tmp)
            .await
            .with_context(open_context("writing", tmp))?;
        write(resp, &mut file).await?;
        file.sync_all().await?;
        drop(file);
        tokio::fs::rename(tmp, dest)
            .await
            .with_context(|| format!("failed to move {} to {}", tmp.display(), dest.display()))
    }

    /// Serialises `resp` into `out` as an HTTP/1.1 message: status line,
    /// headers, an empty line, then the full body.
    ///
    /// The status line always says `HTTP/1.1`, regardless of the version the
    /// response actually used. [`read`] parses with `httparse`, which only
    /// accepts HTTP/1.x, so recording e.g. `HTTP/2.0` would make the entry
    /// unreadable. `read` does not restore the version anyway.
    async fn write<W: AsyncWrite + Unpin>(resp: reqwest::Response, mut out: W) -> Result<()> {
        let status = resp.status();
        let mut head = format!("HTTP/1.1 {}", status.as_u16()).into_bytes();
        if let Some(reason) = status.canonical_reason() {
            head.push(b' ');
            head.extend_from_slice(reason.as_bytes());
        }
        head.extend_from_slice(b"\r\n");
        for (name, value) in resp.headers() {
            head.extend_from_slice(name.as_str().as_bytes());
            head.extend_from_slice(b": ");
            head.extend_from_slice(value.as_bytes());
            head.extend_from_slice(b"\r\n");
        }
        head.extend_from_slice(b"\r\n");

        // Download the body before touching `out`, so a failed download
        // writes nothing. Then emit head and body with two writes instead of
        // one per header line.
        let body = resp.bytes().await?;
        out.write_all(&head).await?;
        out.write_all(&body).await?;
        out.flush().await?;
        Ok(())
    }

    /// Parses a message produced by [`write`] back into a `reqwest::Response`.
    ///
    /// At most 64 headers are supported; `httparse` reports an error beyond
    /// that. Every value of a repeated header name (e.g. `Set-Cookie`) is
    /// kept.
    async fn read<R: AsyncRead + Unpin>(mut io: R) -> Result<reqwest::Response> {
        let mut buf = Vec::with_capacity(1024);
        io.read_to_end(&mut buf).await?;

        let mut slots = [httparse::EMPTY_HEADER; 64];
        let mut parsed = httparse::Response::new(&mut slots);
        let httparse::Status::Complete(head_len) = parsed.parse(&buf)? else {
            bail!("Incomplete response")
        };

        // Validate the code up front rather than letting the response builder
        // record an error and then panicking on `headers_mut().unwrap()`.
        let status = StatusCode::from_u16(parsed.code.unwrap_or(200))?;
        let mut headers = HeaderMap::with_capacity(parsed.headers.len());
        for header in parsed.headers.iter() {
            // `append`, not `insert`: `insert` keeps only the last value of a
            // repeated header.
            headers.append(
                HeaderName::from_bytes(header.name.as_bytes())?,
                HeaderValue::from_bytes(header.value)?,
            );
        }

        // Drop the parsed head in place and reuse the buffer as the body.
        buf.drain(..head_len);
        let mut resp = http::Response::new(buf);
        *resp.status_mut() = status;
        *resp.headers_mut() = headers;
        Ok(resp.into())
    }
}