use std::{borrow::Cow, fmt::Display, marker::PhantomData};
use thiserror::Error;
pub use url::Url;
pub use request_type::{Any, Gemini};
mod request_type;
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub struct Request<T: request_type::RequestType> {
url: Url,
_phantom: PhantomData<T>,
}
pub type AnyRequest = Request<request_type::Any>;
pub type GeminiRequest = Request<request_type::Gemini>;
#[derive(Debug, Copy, Clone, Error)]
pub struct InvalidRequest {
_priv: (),
}
impl Display for InvalidRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "invalid request")
}
}
impl InvalidRequest {
fn new() -> Self {
InvalidRequest { _priv: () }
}
}
impl<T: request_type::RequestType> Request<T> {
    /// Maximum accepted length of a request URL, in bytes.
    pub const MAX_URL_LEN: usize = 1024;
    /// Port used for Gemini requests whose URL does not name one.
    pub const DEFAULT_PORT: u16 = 1965;
    /// URL scheme identifying Gemini requests.
    pub const GEMINI_SCHEME: &'static str = "gemini";

    /// Parses `uri` into an [`AnyRequest`].
    ///
    /// If `uri` has no explicit `scheme://` prefix it is treated as a
    /// scheme-less Gemini URL and `gemini://` is prepended before parsing.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidRequest`] if the URL fails to parse or is longer than
    /// [`Self::MAX_URL_LEN`] bytes.
    pub fn from_uri(uri: &str) -> Result<AnyRequest, InvalidRequest> {
        let uri = if Self::has_explicit_scheme(uri) {
            Cow::Borrowed(uri)
        } else {
            Cow::Owned(format!("{}://{}", Self::GEMINI_SCHEME, uri))
        };
        let url = Url::parse(&uri).map_err(|_| InvalidRequest::new())?;
        Self::from_url(url)
    }

    /// Returns `true` if `uri` contains `://` before any `/`, `?` or `#`.
    ///
    /// A `://` appearing after one of those delimiters belongs to a path,
    /// query or fragment (e.g. `host/p?to=http://x`), not to a scheme.
    fn has_explicit_scheme(uri: &str) -> bool {
        match uri.find("://") {
            Some(idx) => !uri[..idx].contains(|c: char| matches!(c, '/' | '?' | '#')),
            None => false,
        }
    }

    /// Wraps an already-parsed URL in an [`AnyRequest`].
    ///
    /// # Errors
    ///
    /// Returns [`InvalidRequest`] if the serialized URL is longer than
    /// [`Self::MAX_URL_LEN`] bytes.
    pub fn from_url(url: Url) -> Result<AnyRequest, InvalidRequest> {
        if url.as_str().len() > Self::MAX_URL_LEN {
            Err(InvalidRequest::new())
        } else {
            Ok(Request {
                url,
                _phantom: PhantomData,
            })
        }
    }

    /// Builds a [`GeminiRequest`] for `path` on `host`, using
    /// [`Self::DEFAULT_PORT`] when `port` is `None`.
    ///
    /// `path` may be given with or without its leading `/`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidRequest`] if the assembled URL does not parse, is
    /// longer than [`Self::MAX_URL_LEN`] bytes, fails
    /// [`Self::is_gemini_request`], or if `host` contains characters that
    /// change how the URL is split (so the parsed port is not the requested one).
    pub fn gemini_request(
        host: &str,
        port: Option<u16>,
        path: &str,
    ) -> Result<GeminiRequest, InvalidRequest> {
        let port = port.unwrap_or(Self::DEFAULT_PORT);
        // The format string already supplies the '/' after the authority;
        // keeping a caller-supplied one would produce a "//path" URL path.
        let path = path.strip_prefix('/').unwrap_or(path);
        let url = format!("{}://{}:{}/{}", Self::GEMINI_SCHEME, host, port, path);
        let url = Url::parse(&url).map_err(|_| InvalidRequest::new())?;
        // A `host` containing '/', '?' or '#' ends the authority early and
        // pushes the port into the path/query/fragment; reject it instead of
        // returning a request for a different host without the requested port.
        if url.port() != Some(port) {
            return Err(InvalidRequest::new());
        }
        // Share the length check and Gemini validation (scheme, host present,
        // no credentials) with the other constructors.
        Self::from_url(url)?.into_gemini_request()
    }

    /// Returns `true` if the URL uses the `gemini` scheme, has an authority
    /// with a host, carries no username or password, and can be a base URL.
    pub fn is_gemini_request(&self) -> bool {
        self.scheme() == Self::GEMINI_SCHEME
            && self.url.has_authority()
            && self.url.username().is_empty()
            && self.url.password().is_none()
            && self.url.host().is_some()
            && !self.url.cannot_be_a_base()
    }

    /// Returns the URL scheme (e.g. `gemini`).
    pub fn scheme(&self) -> &str {
        self.url.scheme()
    }

    /// Returns the underlying URL.
    pub fn url(&self) -> &Url {
        &self.url
    }
}
impl AnyRequest {
    /// Converts this request into a [`GeminiRequest`], setting the URL port to
    /// [`Request::DEFAULT_PORT`] when it does not specify one.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidRequest`] if [`Request::is_gemini_request`] is `false`
    /// or if the default port cannot be set on the URL.
    pub fn into_gemini_request(mut self) -> Result<GeminiRequest, InvalidRequest> {
        if !self.is_gemini_request() {
            return Err(InvalidRequest::new());
        }
        if self.url.port().is_none() {
            // `set_port` fails only for URLs without a host, cannot-be-a-base
            // URLs, or the `file` scheme, all of which `is_gemini_request`
            // rules out; propagate rather than panic in library code anyway.
            self.url
                .set_port(Some(Self::DEFAULT_PORT))
                .map_err(|_| InvalidRequest::new())?;
        }
        Ok(Request {
            url: self.url,
            _phantom: PhantomData,
        })
    }
}
impl GeminiRequest {
    /// Returns the host this request is addressed to.
    ///
    /// # Panics
    ///
    /// Panics if the URL has no host. Requests produced by
    /// `AnyRequest::into_gemini_request` always have one, because it checks
    /// `is_gemini_request`, which requires a host.
    pub fn host(&self) -> &str {
        self.url
            .host_str()
            .expect("GeminiRequest invariant violated: URL has no host")
    }

    /// Returns the URL path of this request.
    pub fn path(&self) -> &str {
        self.url.path()
    }

    /// Returns the port of this request, falling back to
    /// [`Self::DEFAULT_PORT`] if the URL does not specify one.
    pub fn port(&self) -> u16 {
        self.url.port().unwrap_or(Self::DEFAULT_PORT)
    }
}
#[cfg(feature = "parsers")]
pub mod parse {
    use std::str;

    use nom::{
        bytes::streaming::{tag, take_until},
        combinator::map_res,
        error::{context, make_error, ErrorKind},
        sequence::terminated,
        IResult,
    };

    use super::*;

    /// Longest possible well-formed request line: a URL of at most
    /// `MAX_URL_LEN` bytes followed by CRLF.
    const MAX_REQUEST_LEN: usize = AnyRequest::MAX_URL_LEN + 2;

    /// Parses one CRLF-terminated request line from the front of `input`.
    ///
    /// On success returns the bytes after the CRLF and the parsed
    /// [`AnyRequest`].
    ///
    /// # Errors
    ///
    /// * `nom::Err::Incomplete` if no CRLF has arrived yet and the input is
    ///   still short enough to become a valid request.
    /// * `nom::Err::Error` if the line is not UTF-8, is rejected by
    ///   [`AnyRequest::from_uri`], or if `MAX_REQUEST_LEN` bytes have arrived
    ///   without a CRLF.
    pub fn request(input: &[u8]) -> IResult<&[u8], AnyRequest> {
        let result = context(
            "request",
            map_res(terminated(take_until("\r\n"), tag("\r\n")), parse_request_line),
        )(input);
        match result {
            // `take_until` reports `Incomplete` whenever CRLF is absent. Once
            // more bytes than any valid request line have arrived, no CRLF can
            // make this input valid, so fail instead of inviting the caller to
            // keep buffering without bound.
            Err(nom::Err::Incomplete(_)) if input.len() >= MAX_REQUEST_LEN => {
                Err(nom::Err::Error(make_error(input, ErrorKind::TakeUntil)))
            }
            other => other,
        }
    }

    /// Decodes a request line (CRLF already stripped) into a request.
    fn parse_request_line(line: &[u8]) -> Result<AnyRequest, InvalidRequest> {
        let uri = str::from_utf8(line).map_err(|_| InvalidRequest::new())?;
        AnyRequest::from_uri(uri)
    }

    #[cfg(test)]
    mod test {
        use super::*;

        #[test]
        fn test_gemini_request() {
            let bytes = b"gemini://foo.bar.baz:1966/path\r\n";
            assert!(request(bytes).unwrap().1.is_gemini_request())
        }

        #[test]
        fn test_gemini_request_no_port() {
            let bytes = b"gemini://foo.bar.baz/path\r\n";
            assert!(request(bytes).unwrap().1.is_gemini_request())
        }

        #[test]
        fn test_generic_request() {
            let bytes = b"http://goggle.com/snoop\r\n";
            assert!(request(bytes).is_ok())
        }

        #[test]
        fn test_no_scheme() {
            let bytes = b"foo.bar.baz/path\r\n";
            assert!(request(bytes).unwrap().1.is_gemini_request())
        }

        #[test]
        fn test_incomplete_request() {
            let bytes = b"gemini://foo.bar.baz/path";
            assert!(matches!(request(bytes), Err(nom::Err::Incomplete(_))))
        }

        #[test]
        fn test_overlong_request_without_crlf() {
            let bytes = vec![b'a'; MAX_REQUEST_LEN];
            assert!(matches!(request(&bytes), Err(nom::Err::Error(_))))
        }
    }
}