use crate::sshbuffer::SSHBuffer;
use crate::Error;
use byteorder::{BigEndian, ByteOrder};
use std::num::Wrapping;
use tokio::io::{AsyncRead, AsyncReadExt};
pub mod chacha20poly1305;
pub mod clear;
/// Describes one cipher algorithm and how to build its per-direction state.
pub struct Cipher {
    /// Algorithm name.
    pub _name: Name,
    /// Key length; presumably the number of key bytes the constructors below
    /// expect — TODO confirm against the algorithm implementations.
    pub key_len: usize,
    /// Builds the decrypting (`OpeningCipher`) state from raw key bytes.
    pub make_opening_cipher: fn(&[u8]) -> OpeningCipher,
    /// Builds the encrypting (`SealingCipher`) state from raw key bytes.
    pub make_sealing_cipher: fn(&[u8]) -> SealingCipher,
}
/// Decrypting cipher state, one variant per supported algorithm.
///
/// Stored as `CipherPair::remote_to_local` and used by `read`.
pub enum OpeningCipher {
    Clear(clear::Key),
    Chacha20Poly1305(chacha20poly1305::OpeningKey),
}

impl OpeningCipher {
    /// Borrows the wrapped key as an `OpeningKey` trait object, so callers
    /// can decrypt without matching on the algorithm themselves.
    fn as_opening_key(&self) -> &dyn OpeningKey {
        match self {
            OpeningCipher::Clear(key) => key,
            OpeningCipher::Chacha20Poly1305(key) => key,
        }
    }
}
/// Encrypting cipher state, one variant per supported algorithm.
///
/// Stored as `CipherPair::local_to_remote` and used by `CipherPair::write`.
pub enum SealingCipher {
    Clear(clear::Key),
    Chacha20Poly1305(chacha20poly1305::SealingKey),
}

impl SealingCipher {
    /// Borrows the wrapped key as a `SealingKey` trait object, so callers
    /// can encrypt without matching on the algorithm themselves.
    fn as_sealing_key(&self) -> &dyn SealingKey {
        match self {
            SealingCipher::Clear(key) => key,
            SealingCipher::Chacha20Poly1305(key) => key,
        }
    }
}
/// Name of a cipher algorithm, backed by a string with `'static` lifetime.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Name(&'static str);

impl AsRef<str> for Name {
    /// Returns the wrapped name string.
    fn as_ref(&self) -> &str {
        let Name(name) = *self;
        name
    }
}
/// Cipher state for both directions of a connection.
pub struct CipherPair {
    /// Encrypts outgoing packets (used by `CipherPair::write`).
    pub local_to_remote: SealingCipher,
    /// Decrypts incoming packets (used by `read`).
    pub remote_to_local: OpeningCipher,
}

/// Prints a placeholder instead of the fields: the variants hold cipher keys,
/// which must not end up in logs. Previously this wrote nothing at all, which
/// made `{:?}` output silently empty.
impl std::fmt::Debug for CipherPair {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        f.write_str("CipherPair { .. }")
    }
}

/// A pair of `clear` ciphers for both directions.
/// NOTE(review): presumably the state used before key exchange completes —
/// confirm with the session setup code.
pub const CLEAR_PAIR: CipherPair = CipherPair {
    local_to_remote: SealingCipher::Clear(clear::Key),
    remote_to_local: OpeningCipher::Clear(clear::Key),
};
/// Decryption side of a packet cipher, used by `read` to decode incoming packets.
pub trait OpeningKey {
    /// Returns the plaintext form of the 4-byte packet length field, which
    /// `read` interprets as a big-endian `u32`.
    ///
    /// `seqn` is the sequence number of the packet being read.
    fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: [u8; 4]) -> [u8; 4];
    /// Number of tag bytes that follow each packet; `read` adds this to the
    /// decoded packet length when deciding how many bytes to read.
    fn tag_len(&self) -> usize;
    /// Decrypts a packet in place and returns the decrypted bytes.
    ///
    /// `ciphertext_in_plaintext_out` is the whole received packet, including
    /// its 4-byte length field but not the trailing tag, which is passed
    /// separately as `tag`. The returned slice borrows from that buffer, and
    /// `read` treats its first byte as the padding length.
    /// NOTE(review): presumably the slice starts just after the length field
    /// and `Err` signals failed authentication — confirm per implementation.
    fn open<'a>(
        &self,
        seqn: u32,
        ciphertext_in_plaintext_out: &'a mut [u8],
        tag: &[u8],
    ) -> Result<&'a [u8], Error>;
}
/// Encryption side of a packet cipher, used by `CipherPair::write`.
pub trait SealingKey {
    /// Number of padding bytes to append after the payload, which `write`
    /// passes as `plaintext`. `write` panics if this exceeds `u8::MAX`.
    fn padding_length(&self, plaintext: &[u8]) -> usize;
    /// Fills the padding region; `padding_out` is exactly as long as the value
    /// `padding_length` returned.
    fn fill_padding(&self, padding_out: &mut [u8]);
    /// Number of tag bytes `write` reserves after the padding.
    fn tag_len(&self) -> usize;
    /// Encrypts a packet in place and writes its tag.
    ///
    /// `plaintext_in_ciphertext_out` spans the packet from its 4-byte length
    /// field through the padding; `tag_out` is the `tag_len()` bytes after it.
    /// `seqn` is the sequence number of the packet being written.
    fn seal(&self, seqn: u32, plaintext_in_ciphertext_out: &mut [u8], tag_out: &mut [u8]);
}
pub async fn read<'a, R: AsyncRead + Unpin>(
stream: &'a mut R,
buffer: &'a mut SSHBuffer,
pair: &'a CipherPair,
) -> Result<usize, Error> {
if buffer.len == 0 {
let mut len = [0; 4];
stream.read_exact(&mut len).await?;
debug!("reading, len = {:?}", len);
{
let key = pair.remote_to_local.as_opening_key();
let seqn = buffer.seqn.0;
buffer.buffer.clear();
buffer.buffer.extend(&len);
debug!("reading, seqn = {:?}", seqn);
let len = key.decrypt_packet_length(seqn, len);
buffer.len = BigEndian::read_u32(&len) as usize + key.tag_len();
debug!("reading, clear len = {:?}", buffer.len);
}
}
buffer.buffer.resize(buffer.len + 4);
debug!("read_exact {:?}", buffer.len + 4);
stream.read_exact(&mut buffer.buffer[4..]).await?;
debug!("read_exact done");
let key = pair.remote_to_local.as_opening_key();
let seqn = buffer.seqn.0;
let ciphertext_len = buffer.buffer.len() - key.tag_len();
let (ciphertext, tag) = buffer.buffer.split_at_mut(ciphertext_len);
let plaintext = key.open(seqn, ciphertext, tag)?;
let padding_length = plaintext[0] as usize;
debug!("reading, padding_length {:?}", padding_length);
let plaintext_end = plaintext
.len()
.checked_sub(padding_length)
.ok_or(Error::IndexOutOfBounds)?;
buffer.seqn += Wrapping(1);
buffer.len = 0;
buffer.buffer.resize(plaintext_end + 4);
Ok(plaintext_end + 4)
}
impl CipherPair {
    /// Frames `payload` as one packet, seals it with `local_to_remote`, and
    /// appends it to `buffer.buffer`.
    ///
    /// The appended bytes are laid out as
    /// `packet_length (u32, big-endian) || padding_length (u8) || payload || padding || tag`.
    /// Everything before the tag is handed to `seal`, which encrypts it in
    /// place and fills the tag. Afterwards `buffer.bytes` grows by the payload
    /// size and `buffer.seqn` is incremented.
    ///
    /// # Panics
    /// Panics if the packet length does not fit in a `u32`, or if the padding
    /// length chosen by the key does not fit in a `u8`.
    pub fn write(&self, payload: &[u8], buffer: &mut SSHBuffer) {
        let seqn = buffer.seqn.0;
        debug!("writing, seqn = {:?}", seqn);
        let sealer = self.local_to_remote.as_sealing_key();

        let padding_length = sealer.padding_length(payload);
        debug!("padding length {:?}", padding_length);
        let packet_length = PADDING_LENGTH_LEN + payload.len() + padding_length;
        debug!("packet_length {:?}", packet_length);

        // The buffer may already hold earlier packets; this one starts here.
        let packet_start = buffer.buffer.len();

        // Header: length field, then the padding-length byte.
        assert!(packet_length <= std::u32::MAX as usize);
        buffer.buffer.push_u32_be(packet_length as u32);
        assert!(padding_length <= std::u8::MAX as usize);
        buffer.buffer.push(padding_length as u8);

        // Body: payload, key-chosen padding, then room for the tag.
        buffer.buffer.extend(payload);
        let padding = buffer.buffer.resize_mut(padding_length);
        sealer.fill_padding(padding);
        buffer.buffer.resize_mut(sealer.tag_len());

        // Seal everything from the length field through the padding.
        let sealed_len = PACKET_LENGTH_LEN + packet_length;
        let (plaintext, tag) = buffer.buffer[packet_start..].split_at_mut(sealed_len);
        sealer.seal(seqn, plaintext, tag);

        buffer.bytes += payload.len();
        buffer.seqn += Wrapping(1);
    }
}
/// Size in bytes of the `packet_length` field (a big-endian `u32`) that starts every packet.
pub const PACKET_LENGTH_LEN: usize = std::mem::size_of::<u32>();
/// NOTE(review): not referenced in this section — presumably the minimum total
/// packet size used when choosing padding; confirm in the cipher submodules.
const MINIMUM_PACKET_LEN: usize = 16;
/// Size in bytes of the `padding_length` field (a single `u8`) after the packet length.
const PADDING_LENGTH_LEN: usize = std::mem::size_of::<u8>();