use std::fmt::{self, Display, Formatter};
use std::future::Future;
use std::io::Write;
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use futures_core::ready;
use http::header::{HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use http::Uri;
use http_body::Body;
use oauth_credentials::Credentials;
use pin_project_lite::pin_project;
use serde::{de, Deserialize};
use crate::error::Error;
use crate::response::RawResponseFuture;
use crate::traits::{HttpService, HttpTryFuture};
use self::private::AuthDeserialize;
/// Token credentials returned by the `oauth/access_token` endpoint, together with the
/// account they were issued for.
///
/// Produced by [`access_token`]; decoded from the endpoint's URL-encoded response body.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct AccessToken {
    /// Token credentials built from the response's `oauth_token` and `oauth_token_secret`.
    pub credentials: Credentials<Box<str>>,
    /// The response's `user_id` field.
    pub user_id: i64,
    /// The response's `screen_name` field.
    pub screen_name: Box<str>,
}
/// Access token returned by the OAuth 2.0 `oauth2/token` endpoint.
///
/// Produced by [`token`]; decoded from the endpoint's JSON response body. Response fields
/// other than `access_token` are not captured.
#[derive(Clone, Debug, Deserialize)]
#[non_exhaustive]
pub struct AccessToken2 {
    /// The response's `access_token` field.
    pub access_token: Box<str>,
}
pin_project! {
    /// Future returned by [`request_token`], [`access_token`] and [`token`].
    ///
    /// Resolves to the response body decoded as `T`, or to an error if the HTTP exchange
    /// fails or the body cannot be decoded.
    pub struct AuthFuture<T, F: HttpTryFuture> {
        // Drives the HTTP request; the `data` of its output is decoded in `poll`.
        #[pin]
        inner: RawResponseFuture<F>,
        // `fn() -> T` records that this future produces a `T` without storing one, so `T`
        // does not influence the future's `Send`/`Sync`/`Unpin` auto traits.
        marker: PhantomData<fn() -> T>,
    }
}
/// Access level passed as the `x_auth_access_type` parameter of [`request_token`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum AuthAccessType {
    /// Rendered as `read`.
    Read,
    /// Rendered as `write`.
    Write,
}
impl<T: AuthDeserialize, F: HttpTryFuture> AuthFuture<T, F> {
    /// Wraps a pending HTTP response future; the body is decoded as `T` once it completes.
    pub(crate) fn new(response: F) -> Self {
        let inner = RawResponseFuture::new(response);
        Self {
            inner,
            marker: PhantomData,
        }
    }
}
impl<T: AuthDeserialize, F: HttpTryFuture> Future for AuthFuture<T, F> {
#[allow(clippy::type_complexity)]
type Output = Result<T, Error<F::Error, <F::Body as Body>::Error>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let res = ready!(self.project().inner.poll(cx)?);
let data = T::deserialize(res.data).ok_or(Error::Unexpected)?;
Poll::Ready(Ok(data))
}
}
impl AuthAccessType {
pub fn as_str(&self) -> &'static str {
match *self {
AuthAccessType::Read => "read",
AuthAccessType::Write => "write",
}
}
}
impl AsRef<str> for AuthAccessType {
    /// Borrows the access type as its wire value; equivalent to [`AuthAccessType::as_str`].
    fn as_ref(&self) -> &str {
        Self::as_str(self)
    }
}
impl Display for AuthAccessType {
    /// Writes the wire value (`read` / `write`).
    ///
    /// Uses [`Formatter::pad`] so that width, fill, alignment and precision flags
    /// (e.g. `format!("{:>8}", t)`) are honoured, the same way `str` is displayed.
    /// Without flags the output is exactly the wire value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
impl AuthDeserialize for Credentials<Box<str>> {
    /// Decodes a URL-encoded (`application/x-www-form-urlencoded`) body; `None` if it fails.
    fn deserialize(input: Bytes) -> Option<Self> {
        let parsed: Result<Self, _> = serde_urlencoded::from_bytes(&input);
        parsed.ok()
    }
}
impl AuthDeserialize for AccessToken {
    /// Decodes a URL-encoded body through `AccessToken`'s `Deserialize` impl; `None` if it fails.
    fn deserialize(input: Bytes) -> Option<Self> {
        Result::ok(serde_urlencoded::from_bytes(&input))
    }
}
impl AuthDeserialize for AccessToken2 {
    /// Decodes a JSON body; `None` if it is not valid JSON of the expected shape.
    fn deserialize(input: Bytes) -> Option<Self> {
        let decoded = serde_json::from_slice::<Self>(&input);
        decoded.ok()
    }
}
/// Asks the `oauth/request_token` endpoint for temporary credentials.
///
/// The request is signed with `client_credentials` and carries `callback` as the OAuth
/// callback. `x_auth_access_type` is passed through as the `x_auth_access_type` request
/// parameter. The request is sent through `client` with a default-constructed body `B`.
/// The returned future resolves to the credentials decoded from the URL-encoded response.
pub fn request_token<C, S, B>(
    client_credentials: &Credentials<C>,
    callback: &str,
    x_auth_access_type: Option<AuthAccessType>,
    client: &mut S,
) -> AuthFuture<Credentials<Box<str>>, S::Future>
where
    C: AsRef<str>,
    S: HttpService<B>,
    B: Default,
{
    let consumer = client_credentials.as_ref();
    let request =
        request_token_request(consumer, callback, x_auth_access_type).map(|()| B::default());
    AuthFuture::new(client.call(request))
}
/// Builds the signed `POST oauth/request_token` request with an empty body.
fn request_token_request(
    client_credentials: Credentials<&str>,
    callback: &str,
    x_auth_access_type: Option<AuthAccessType>,
) -> http::Request<()> {
    const URI: &str = "https://api.twitter.com/oauth/request_token";

    // Extra parameters covered by the OAuth signature; the field name is the parameter name.
    #[derive(oauth::Request)]
    struct RequestToken {
        x_auth_access_type: Option<AuthAccessType>,
    }

    let params = RequestToken { x_auth_access_type };
    let authorization = oauth::Builder::<_, _>::new(client_credentials, oauth::HmacSha1)
        .callback(callback)
        .post(URI, &params);

    let uri = Uri::from_static(URI);
    http::Request::post(uri)
        .header(AUTHORIZATION, authorization)
        .body(())
        .unwrap()
}
/// Exchanges temporary credentials and an OAuth verifier for token credentials via the
/// `oauth/access_token` endpoint.
///
/// The request is signed with `client_credentials` and `temporary_credentials` (e.g. those
/// obtained from [`request_token`]). It is sent through `client` with a default-constructed
/// body `B`. The returned future resolves to an [`AccessToken`].
pub fn access_token<C, T, S, B>(
    client_credentials: &Credentials<C>,
    temporary_credentials: &Credentials<T>,
    oauth_verifier: &str,
    client: &mut S,
) -> AuthFuture<AccessToken, S::Future>
where
    C: AsRef<str>,
    T: AsRef<str>,
    S: HttpService<B>,
    B: Default,
{
    let consumer = client_credentials.as_ref();
    let temporary = temporary_credentials.as_ref();
    let request = access_token_request(consumer, temporary, oauth_verifier).map(|()| B::default());
    AuthFuture::new(client.call(request))
}
/// Builds the signed `POST oauth/access_token` request with an empty body.
fn access_token_request(
    client_credentials: Credentials<&str>,
    temporary_credentials: Credentials<&str>,
    oauth_verifier: &str,
) -> http::Request<()> {
    const URI: &str = "https://api.twitter.com/oauth/access_token";

    // Besides the token and verifier there are no extra parameters, hence `&()`.
    let authorization = oauth::Builder::new(client_credentials, oauth::HmacSha1)
        .token(temporary_credentials)
        .verifier(oauth_verifier)
        .post(URI, &());

    let uri = Uri::from_static(URI);
    http::Request::post(uri)
        .header(AUTHORIZATION, authorization)
        .body(())
        .unwrap()
}
impl<'de> Deserialize<'de> for AccessToken {
    /// Reads the flat `oauth_token`, `oauth_token_secret`, `user_id` and `screen_name`
    /// fields and packs the first two into `credentials`.
    fn deserialize<D: de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // Flat wire shape. It deliberately shares the outer type's name so that serde's
        // generated error messages ("struct AccessToken") stay the same.
        #[derive(Deserialize)]
        struct AccessToken {
            oauth_token: Box<str>,
            oauth_token_secret: Box<str>,
            user_id: i64,
            screen_name: Box<str>,
        }

        let AccessToken {
            oauth_token,
            oauth_token_secret,
            user_id,
            screen_name,
        } = AccessToken::deserialize(d)?;

        Ok(Self {
            credentials: Credentials::new(oauth_token, oauth_token_secret),
            user_id,
            screen_name,
        })
    }
}
/// Requests an OAuth 2.0 access token from the `oauth2/token` endpoint using the
/// `client_credentials` grant.
///
/// The client credentials are sent as HTTP Basic authorization. The static request body
/// is converted into `B` via `From<&'static [u8]>`. The returned future resolves to the
/// JSON-decoded [`AccessToken2`].
pub fn token<C, S, B>(
    client_credentials: &Credentials<C>,
    client: &mut S,
) -> AuthFuture<AccessToken2, S::Future>
where
    C: AsRef<str>,
    S: HttpService<B>,
    B: From<&'static [u8]>,
{
    let request = token_request(client_credentials.as_ref()).map(B::from);
    AuthFuture::new(client.call(request))
}
/// Builds the `POST oauth2/token` request for the `client_credentials` grant, authenticated
/// with HTTP Basic credentials and carrying a form-encoded body.
fn token_request(client_credentials: Credentials<&str>) -> http::Request<&'static [u8]> {
    const URI: &str = "https://api.twitter.com/oauth2/token";
    const BODY: &[u8] = b"grant_type=client_credentials";

    http::Request::post(Uri::from_static(URI))
        .header(AUTHORIZATION, basic_auth(client_credentials))
        .header(
            CONTENT_TYPE,
            HeaderValue::from_static("application/x-www-form-urlencoded"),
        )
        .body(BODY)
        .unwrap()
}
/// Builds an HTTP Basic `Authorization` header value (`"Basic <base64>"`) from the client
/// credentials.
///
/// RFC 6749 §2.3.1 and Twitter's app-only authentication instructions require the
/// identifier and secret to each be URL/form-encoded before they are joined with `:` and
/// Base64-encoded. For credentials consisting only of ASCII alphanumerics (and `*-._`) the
/// encoding step changes nothing.
fn basic_auth(credentials: Credentials<&str>) -> String {
    const PREFIX: &str = "Basic ";

    // `identifier:secret`, each part form-urlencoded.
    let mut user_pass =
        String::with_capacity(credentials.identifier.len() + 1 + credentials.secret.len());
    form_urlencode_into(&mut user_pass, credentials.identifier);
    user_pass.push(':');
    form_urlencode_into(&mut user_pass, credentials.secret);

    // Exact length of padded Base64 output for `user_pass`.
    let b64len = (user_pass.len() + 2) / 3 * 4;
    let mut authorization = String::with_capacity(PREFIX.len() + b64len);
    authorization.push_str(PREFIX);
    let mut enc = base64::write::EncoderStringWriter::from(authorization, base64::STANDARD);
    // The encoder writes into an in-memory `String`, so this cannot fail.
    enc.write_all(user_pass.as_bytes()).unwrap();
    enc.into_inner()
}

/// Appends `input` to `out` using the `application/x-www-form-urlencoded` byte serializer:
/// ASCII alphanumerics and `*`, `-`, `.`, `_` are kept as-is, a space becomes `+`, and every
/// other byte is written as `%XX` with uppercase hex digits.
fn form_urlencode_into(out: &mut String, input: &str) {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    for &byte in input.as_bytes() {
        match byte {
            b'*' | b'-' | b'.' | b'_' | b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' => {
                out.push(char::from(byte));
            }
            b' ' => out.push('+'),
            _ => {
                out.push('%');
                out.push(char::from(HEX[usize::from(byte >> 4)]));
                out.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
}
// Declared in a non-`pub` module so the trait can appear in public bounds (e.g. on
// `AuthFuture`) without being nameable by downstream crates — presumably a sealed-trait
// pattern; confirm it is not re-exported elsewhere in the crate.
mod private {
    use bytes::Bytes;

    /// Decodes an endpoint's raw response body into `Self`.
    pub trait AuthDeserialize: Sized {
        /// Returns `None` when `input` cannot be decoded; `AuthFuture` maps that to
        /// `Error::Unexpected`.
        fn deserialize(input: Bytes) -> Option<Self>;
    }
}