use crate::Error;
use crate::mac;
use crate::sshbuffer::SSHBuffer;
use std::num::Wrapping;
use tokio::io::{AsyncRead, AsyncReadExt};
#[cfg(feature = "openssl")]
pub mod aes256_ctr;
#[cfg(feature = "openssl")]
pub mod aes256_gcm;
pub mod chacha20poly1305;
pub mod clear;
/// Descriptor of one cipher algorithm: how it is named, how much key and IV
/// material it needs, and how to build its receiving and sending halves.
pub struct Cipher {
    /// Name identifying this cipher.
    pub _name: Name,
    /// Length of the key material this cipher expects (presumably bytes).
    pub key_len: usize,
    /// Length of the IV this cipher expects (presumably bytes).
    pub iv_len: usize,
    /// Builds the receiving half from (key, iv, MAC key).
    pub make_opening_cipher: fn(&[u8], &[u8], mac::MacKey) -> OpeningCipher,
    /// Builds the sending half from (key, iv, MAC key).
    pub make_sealing_cipher: fn(&[u8], &[u8], mac::MacKey) -> SealingCipher,
}
/// Receiving half of a [`CipherPair`]: one variant per supported algorithm,
/// each wrapping that algorithm's opening (decryption) key.
pub enum OpeningCipher {
    /// Pass-through "clear" key, as used by [`CLEAR_PAIR`].
    Clear(clear::Key),
    Chacha20Poly1305(chacha20poly1305::OpeningKey),
    #[cfg(feature = "openssl")]
    Aes256Gcm(aes256_gcm::Key),
    #[cfg(feature = "openssl")]
    Aes256Ctr(aes256_ctr::Key),
}

impl OpeningCipher {
    /// Borrows the wrapped key as an [`OpeningKey`] trait object, so callers
    /// can drive it without matching on the algorithm.
    fn as_opening_key(&self) -> &dyn OpeningKey {
        match self {
            OpeningCipher::Clear(key) => key,
            OpeningCipher::Chacha20Poly1305(key) => key,
            #[cfg(feature = "openssl")]
            OpeningCipher::Aes256Gcm(key) => key,
            #[cfg(feature = "openssl")]
            OpeningCipher::Aes256Ctr(key) => key,
        }
    }
}
/// Sending half of a [`CipherPair`]: one variant per supported algorithm,
/// each wrapping that algorithm's sealing (encryption) key.
pub enum SealingCipher {
    /// Pass-through "clear" key, as used by [`CLEAR_PAIR`].
    Clear(clear::Key),
    Chacha20Poly1305(chacha20poly1305::SealingKey),
    #[cfg(feature = "openssl")]
    Aes256Gcm(aes256_gcm::Key),
    #[cfg(feature = "openssl")]
    Aes256Ctr(aes256_ctr::Key),
}

impl SealingCipher {
    /// Borrows the wrapped key as a [`SealingKey`] trait object, so callers
    /// can drive it without matching on the algorithm.
    ///
    /// The returned borrow is tied to `&self` by lifetime elision, which is
    /// exactly what the previous explicit `'a` annotations expressed.
    fn as_sealing_key(&self) -> &dyn SealingKey {
        match self {
            SealingCipher::Clear(key) => key,
            SealingCipher::Chacha20Poly1305(key) => key,
            #[cfg(feature = "openssl")]
            SealingCipher::Aes256Gcm(key) => key,
            #[cfg(feature = "openssl")]
            SealingCipher::Aes256Ctr(key) => key,
        }
    }
}
/// Cipher identifier: a thin, copyable wrapper around a static string.
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct Name(&'static str);

impl AsRef<str> for Name {
    /// Returns the wrapped identifier string.
    fn as_ref(&self) -> &str {
        let Name(inner) = *self;
        inner
    }
}
/// The two directional halves of the current cipher state.
pub struct CipherPair {
    /// Seals outgoing packets (used by `CipherPair::write`).
    pub local_to_remote: SealingCipher,
    /// Opens incoming packets (used by `read`).
    pub remote_to_local: OpeningCipher,
}

impl std::fmt::Debug for CipherPair {
    /// Prints an opaque placeholder instead of the fields, which wrap key
    /// material (presumably why they were never printed). Writing *something*
    /// keeps the output of an enclosing `#[derive(Debug)]` type well-formed;
    /// an empty write produced fragments like `cipher: ,`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
        f.write_str("CipherPair { .. }")
    }
}
/// Cipher state in which both directions use the `Clear` variants;
/// presumably the initial state before any key exchange has completed.
pub const CLEAR_PAIR: CipherPair = CipherPair {
    remote_to_local: OpeningCipher::Clear(clear::Key),
    local_to_remote: SealingCipher::Clear(clear::Key),
};
/// Decryption half of a cipher, as driven by `read` in this module.
pub trait OpeningKey {
/// Number of bytes `read` must pull from the stream before the packet
/// length can be decrypted. `read` reads them into a stack buffer of
/// `MAX_BLOCK_SIZE` bytes, so this must not exceed `MAX_BLOCK_SIZE` (16).
fn length_block_size(&self) -> usize;
/// Decrypts the packet length from the first `length_block_size()` bytes of
/// the packet, using sequence number `seqn`. `read` adds `tag_len()` to the
/// result to obtain the number of bytes expected after the 4-byte length
/// field.
fn decrypt_packet_length(&self, seqn: u32, encrypted_packet_length: &[u8]) -> u32;
/// Length in bytes of the tag that `read` expects at the end of the packet.
fn tag_len(&self) -> usize;
/// Opens a whole packet in place (decrypts it and presumably verifies its tag).
///
/// `ciphertext_in_plaintext_out` is the full received packet including the
/// tag. `tag` is the offset at which the tag begins (`read` passes
/// `buffer.len() - tag_len()`). `read` treats the first byte of the returned
/// plaintext slice as the padding length.
fn open<'a>(
&self,
seqn: u32,
ciphertext_in_plaintext_out: &'a mut [u8],
tag: usize,
) -> Result<&'a [u8], Error>;
}
/// Encryption half of a cipher, as driven by `CipherPair::write`.
pub trait SealingKey {
/// Number of padding bytes to append after the payload `plaintext`.
/// `CipherPair::write` asserts the result fits in a `u8`.
fn padding_length(&self, plaintext: &[u8]) -> usize;
/// Fills `padding_out` (exactly `padding_length(..)` bytes) with padding.
fn fill_padding(&self, padding_out: &mut [u8]);
/// Number of bytes `CipherPair::write` reserves after the padding,
/// presumably for the tag written by `seal`.
fn tag_len(&self) -> usize;
/// Seals a packet in place using sequence number `seqn`.
///
/// `plaintext_in_ciphertext_out` starts at the packet's 4-byte length field
/// and ends with the `tag_len()` bytes reserved by `CipherPair::write`.
/// `plaintext_len` is the number of bytes before that reserved space
/// (length field + `packet_length`).
fn seal(&self, seqn: u32, plaintext_in_ciphertext_out: &mut [u8], plaintext_len: usize);
}
const MAX_BLOCK_SIZE: usize = 16;
pub async fn read<'a, R: AsyncRead + Unpin>(
stream: &'a mut R,
buffer: &'a mut SSHBuffer,
pair: &'a CipherPair,
) -> Result<usize, Error> {
let len_block_size = if buffer.len == 0 {
let mut len = [0; MAX_BLOCK_SIZE];
let len = {
let key = pair.remote_to_local.as_opening_key();
&mut len[..key.length_block_size()]
};
stream.read_exact(len).await?;
debug!("reading, len = {:?}", len);
{
let key = pair.remote_to_local.as_opening_key();
let seqn = buffer.seqn.0;
buffer.buffer.clear();
buffer.buffer.extend(&len);
debug!("reading, seqn = {:?}", seqn);
let len = key.decrypt_packet_length(seqn, len);
buffer.len = len as usize + key.tag_len();
debug!("reading, clear len = {:?}", buffer.len);
key.length_block_size()
}
} else {
0
};
buffer.buffer.resize(buffer.len + 4);
debug!("read_exact {:?}", buffer.len + 4);
stream.read_exact(&mut buffer.buffer[len_block_size..]).await?;
debug!("read_exact done");
let key = pair.remote_to_local.as_opening_key();
let seqn = buffer.seqn.0;
let ciphertext_len = buffer.buffer.len() - key.tag_len();
let plaintext = key.open(seqn, &mut buffer.buffer, ciphertext_len)?;
let padding_length = plaintext[0] as usize;
debug!("reading, padding_length {:?}", padding_length);
let plaintext_end = plaintext
.len()
.checked_sub(padding_length)
.ok_or(Error::IndexOutOfBounds)?;
buffer.seqn += Wrapping(1);
buffer.len = 0;
buffer.buffer.resize(plaintext_end + 4);
Ok(plaintext_end + 4)
}
impl CipherPair {
    /// Frames `payload` as one SSH binary packet, appends it to
    /// `buffer.buffer`, and seals it in place with `local_to_remote`.
    ///
    /// Appended layout before sealing:
    /// 1. `packet_length` as a big-endian `u32`;
    /// 2. `padding_length` as one byte;
    /// 3. `payload`;
    /// 4. padding produced by `fill_padding`;
    /// 5. `tag_len()` reserved bytes.
    ///
    /// Afterwards `buffer.bytes` has grown by `payload.len()` and the
    /// sequence number has been advanced (wrapping).
    ///
    /// # Panics
    /// Panics if the packet length does not fit in a `u32`, or if the key's
    /// padding length does not fit in a `u8`.
    pub fn write(&self, payload: &[u8], buffer: &mut SSHBuffer) {
        let seqn = buffer.seqn.0;
        debug!("writing, seqn = {:?}", seqn);
        let sealing_key = self.local_to_remote.as_sealing_key();

        let padding_length = sealing_key.padding_length(payload);
        debug!("padding length {:?}", padding_length);
        // `packet_length` counts the padding-length byte, payload and padding,
        // but not the length field itself nor the tag.
        let packet_length = PADDING_LENGTH_LEN + payload.len() + padding_length;
        debug!("packet_length {:?}", packet_length);

        let start = buffer.buffer.len();
        assert!(packet_length <= std::u32::MAX as usize);
        buffer.buffer.push_u32_be(packet_length as u32);
        assert!(padding_length <= std::u8::MAX as usize);
        buffer.buffer.push(padding_length as u8);
        buffer.buffer.extend(payload);

        let padding = buffer.buffer.resize_mut(padding_length);
        sealing_key.fill_padding(padding);
        // Reserve space after the padding; the slice itself is not needed here.
        buffer.buffer.resize_mut(sealing_key.tag_len());

        let sealed_len = PACKET_LENGTH_LEN + packet_length;
        sealing_key.seal(seqn, &mut buffer.buffer[start..], sealed_len);

        buffer.bytes += payload.len();
        buffer.seqn += Wrapping(1);
    }
}
/// Size of the `packet_length` field that opens every binary packet (a `u32`).
pub const PACKET_LENGTH_LEN: usize = std::mem::size_of::<u32>();
/// Minimum packet length (16). Not referenced in this file; presumably used by
/// the cipher submodules when choosing padding.
const MINIMUM_PACKET_LEN: usize = 16;
/// Size of the `padding_length` field (a single byte).
const PADDING_LENGTH_LEN: usize = std::mem::size_of::<u8>();