use super::{Msg, Reply};
use crate::auth;
use crate::key::PubKey;
use crate::msg;
use crate::negotiation;
use crate::negotiation::Named;
use crate::negotiation::Select;
use crate::session::*;
use crate::{ChannelId, ChannelOpenFailure, Error, Sig};
use cryptovec::CryptoVec;
use std::cell::RefCell;
use thrussh_keys::encoding::{Encoding, Reader};
// Per-thread `CryptoVec` behind a `RefCell`; going by its name, presumably a
// scratch buffer for signature data.
// NOTE(review): not referenced by any code shown in this file — confirm it is
// still used elsewhere before keeping it.
thread_local! {
static SIGNATURE_BUFFER: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
}
impl super::Session {
    /// Handles one decrypted packet received while the transport is encrypted.
    ///
    /// Key re-exchange traffic is processed here:
    /// * `KEXINIT` either answers a re-key already pending in `enc.rekey` or
    ///   starts one from the stored `enc.exchange`.
    /// * `KEX_ECDH_REPLY` is checked against the server key, after which
    ///   `NEWKEYS` is written.
    /// * `NEWKEYS` replays the packets buffered during the re-key and then
    ///   switches to the new keys.
    ///
    /// While a re-key is in its `KexInit` phase, other packets are buffered
    /// in `pending_reads`. Every remaining packet goes to `process_packet`.
    ///
    /// # Errors
    ///
    /// * `Error::Inconsistent` on an empty payload, or on an unexpected packet
    ///   while `KEX_ECDH_REPLY` is awaited.
    /// * `Error::Kex` when `NEWKEYS` is expected but another packet arrives.
    /// * `Error::Pending` when too much data is buffered during a re-key.
    ///
    /// Errors from the key-exchange helpers, `flush` and the handler are
    /// propagated.
    pub(crate) async fn client_read_encrypted<C: super::Handler>(
        mut self,
        client: &mut Option<C>,
        buf: &[u8],
    ) -> Result<Self, C::Error> {
        debug!(
            "client_read_encrypted, buf = {:?}",
            &buf[..buf.len().min(20)]
        );
        // Everything below dispatches on the message-type byte; an empty
        // payload from the peer must become an error, not an index panic.
        if buf.is_empty() {
            return Err(Error::Inconsistent.into());
        }
        if buf[0] == msg::KEXINIT {
            if let Some(ref mut enc) = self.common.encrypted {
                if let Some(Kex::KexInit(kexinit)) = enc.rekey.take() {
                    // A `KexInit` was already pending: parse the server's reply.
                    enc.rekey = Some(Kex::KexDhDone(kexinit.client_parse(
                        self.common.config.as_ref(),
                        &self.common.cipher,
                        buf,
                        &mut self.common.write_buffer,
                    )?));
                } else if let Some(exchange) = enc.exchange.take() {
                    // No re-key pending on our side: the server started one.
                    let kexinit = KexInit::received_rekey(
                        exchange,
                        negotiation::Client::read_kex(buf, &self.common.config.as_ref().preferred)?,
                        &enc.session_id,
                    );
                    enc.rekey = Some(Kex::KexDhDone(kexinit.client_parse(
                        self.common.config.as_ref(),
                        &self.common.cipher,
                        buf,
                        &mut self.common.write_buffer,
                    )?));
                }
            } else {
                unreachable!()
            }
            self.flush()?;
            return Ok(self);
        }
        if let Some(ref mut enc) = self.common.encrypted {
            match enc.rekey.take() {
                Some(Kex::KexDhDone(mut kexdhdone)) => {
                    if kexdhdone.names.ignore_guessed {
                        // Skip exactly one packet (presumably the peer's wrongly
                        // guessed first kex packet, RFC 4253 §7.1).
                        kexdhdone.names.ignore_guessed = false;
                        enc.rekey = Some(Kex::KexDhDone(kexdhdone));
                        return Ok(self);
                    } else if buf[0] == msg::KEX_ECDH_REPLY {
                        enc.rekey = Some(kexdhdone.server_key_check(true, client, buf).await?);
                        self.common
                            .cipher
                            .write(&[msg::NEWKEYS], &mut self.common.write_buffer);
                        // In strict mode the outgoing sequence number restarts
                        // at 0 once NEWKEYS has been written.
                        if self.common.write_buffer.strict {
                            self.common.write_buffer.seqn = std::num::Wrapping(0);
                        }
                        self.flush()?;
                        return Ok(self);
                    } else {
                        error!("Wrong packet received");
                        return Err(Error::Inconsistent.into());
                    }
                }
                Some(Kex::NewKeys(newkeys)) => {
                    if buf[0] != msg::NEWKEYS {
                        return Err(Error::Kex.into());
                    }
                    self.common.write_buffer.bytes = 0;
                    enc.last_rekey = std::time::Instant::now();
                    enc.flush_all_pending();
                    // Replay packets that were buffered while the re-key ran.
                    let mut pending = std::mem::take(&mut self.pending_reads);
                    for p in pending.drain(..) {
                        self = self.process_packet(client, &p).await?
                    }
                    // Give the drained vector back so its allocation is reused.
                    self.pending_reads = pending;
                    self.pending_len = 0;
                    self.common.newkeys(newkeys);
                    self.flush()?;
                    return Ok(self);
                }
                Some(Kex::KexInit(k)) => {
                    enc.rekey = Some(Kex::KexInit(k));
                    // Saturating arithmetic: a huge target window must not
                    // overflow the limit computation.
                    self.pending_len = self.pending_len.saturating_add(buf.len() as u32);
                    if self.pending_len > self.target_window_size.saturating_mul(2) {
                        return Err(Error::Pending.into());
                    }
                    self.pending_reads.push(CryptoVec::from_slice(buf));
                    return Ok(self);
                }
                rek => enc.rekey = rek,
            }
        }
        self.process_packet(client, buf).await
    }

    /// Runs one non-key-exchange packet through the client's authentication
    /// state machine.
    ///
    /// * `WaitingServiceRequest`: expects `SERVICE_ACCEPT` for `ssh-userauth`.
    ///   It then writes the stored auth request, if one has been set.
    /// * `WaitingAuthRequest`: handles `USERAUTH_SUCCESS`, `USERAUTH_BANNER`,
    ///   `USERAUTH_FAILURE` and `USERAUTH_PK_OK`. The last one triggers the
    ///   signed public-key request.
    /// * `Authenticated`: forwards the packet to `client_read_authenticated`.
    ///
    /// # Errors
    ///
    /// * `Error::Inconsistent` for unexpected messages during authentication,
    ///   or when the signing reply channel closes before `Msg::Signed` arrives.
    /// * `Error::NoAuthMethod` when the server leaves no usable method.
    /// * `Error::SendError` when a reply cannot be sent.
    ///
    /// Parse and handler errors are propagated.
    async fn process_packet<H: super::Handler>(
        mut self,
        client: &mut Option<H>,
        buf: &[u8],
    ) -> Result<Self, H::Error> {
        debug!("buf = {:?} bytes", buf.len());
        trace!("buf = {:?}", buf);
        let mut is_authenticated = false;
        if let Some(ref mut enc) = self.common.encrypted {
            match enc.state {
                EncryptedState::WaitingServiceRequest {
                    ref mut accepted, ..
                } => {
                    debug!(
                        "waiting service request, {:?} {:?}",
                        buf[0],
                        msg::SERVICE_ACCEPT
                    );
                    if buf[0] == msg::SERVICE_ACCEPT {
                        let mut r = buf.reader(1);
                        if r.read_string().map_err(crate::Error::from)? == b"ssh-userauth" {
                            *accepted = true;
                            if let Some(ref meth) = self.common.auth_method {
                                let auth_request = auth::AuthRequest {
                                    methods: auth::MethodSet::all(),
                                    partial_success: false,
                                    current: None,
                                    rejection_count: 0,
                                };
                                let len = enc.write.len();
                                if enc.write_auth_request(&self.common.auth_user, meth) {
                                    debug!("enc: {:?}", &enc.write[len..]);
                                    enc.state = EncryptedState::WaitingAuthRequest(auth_request)
                                }
                            } else {
                                debug!("no auth method")
                            }
                        }
                    } else {
                        debug!("unknown message: {:?}", buf);
                        return Err(Error::Inconsistent.into());
                    }
                }
                EncryptedState::WaitingAuthRequest(ref mut auth_request) => {
                    if buf[0] == msg::USERAUTH_SUCCESS {
                        debug!("userauth_success");
                        self.sender
                            .send(Reply::AuthSuccess)
                            .map_err(|_| Error::SendError)?;
                        enc.state = EncryptedState::InitCompression;
                        enc.server_compression.init_decompress(&mut enc.decompress);
                        return Ok(self);
                    } else if buf[0] == msg::USERAUTH_BANNER {
                        let mut r = buf.reader(1);
                        let banner = r.read_string().map_err(crate::Error::from)?;
                        // Non-UTF-8 banners are silently dropped.
                        if let Ok(banner) = std::str::from_utf8(banner) {
                            let c = client.take().unwrap();
                            let (c, s) = c.auth_banner(banner, self).await?;
                            *client = Some(c);
                            return Ok(s);
                        } else {
                            return Ok(self);
                        }
                    } else if buf[0] == msg::USERAUTH_FAILURE {
                        debug!("userauth_failure");
                        let mut r = buf.reader(1);
                        let remaining_methods = r.read_string().map_err(crate::Error::from)?;
                        debug!(
                            "remaining methods {:?}",
                            std::str::from_utf8(remaining_methods)
                        );
                        // Rebuild the allowed set from the comma-separated list;
                        // unknown method names are ignored.
                        auth_request.methods = auth::MethodSet::empty();
                        for method in remaining_methods.split(|&c| c == b',') {
                            if let Some(m) = auth::MethodSet::from_bytes(method) {
                                auth_request.methods |= m
                            }
                        }
                        let no_more_methods = auth_request.methods.is_empty();
                        self.common.auth_method = None;
                        self.sender
                            .send(Reply::AuthFailure)
                            .map_err(|_| Error::SendError)?;
                        if no_more_methods {
                            return Err(Error::NoAuthMethod.into());
                        }
                    } else if buf[0] == msg::USERAUTH_PK_OK {
                        debug!("userauth_pk_ok");
                        if let Some(auth::CurrentRequest::PublicKey {
                            ref mut sent_pk_ok, ..
                        }) = auth_request.current
                        {
                            *sent_pk_ok = true;
                        }
                        match self.common.auth_method.take() {
                            Some(auth_method @ auth::Method::PublicKey { .. }) => {
                                // Key is available locally: sign and send directly.
                                self.common.buffer.clear();
                                enc.client_send_signature(
                                    &self.common.auth_user,
                                    &auth_method,
                                    &mut self.common.buffer,
                                )?
                            }
                            Some(auth::Method::FuturePublicKey { key }) => {
                                // Signing happens elsewhere: ship the data to the
                                // handle side and wait for the signed buffer.
                                debug!("public key");
                                self.common.buffer.clear();
                                let i = enc.client_make_to_sign(
                                    &self.common.auth_user,
                                    &key,
                                    &mut self.common.buffer,
                                );
                                let len = self.common.buffer.len();
                                let to_sign =
                                    std::mem::replace(&mut self.common.buffer, CryptoVec::new());
                                self.sender
                                    .send(Reply::SignRequest { key, data: to_sign })
                                    .map_err(|_| Error::SendError)?;
                                self.common.buffer = loop {
                                    match self.receiver.recv().await {
                                        Some(Msg::Signed { data }) => break data,
                                        Some(_) => {}
                                        // The other side of the channel is gone, so
                                        // no signature can arrive; looping on `None`
                                        // would never terminate.
                                        None => return Err(Error::Inconsistent.into()),
                                    }
                                };
                                // An unchanged length presumably means nothing was
                                // appended (no signature), so nothing is sent.
                                if self.common.buffer.len() != len {
                                    push_packet!(enc.write, {
                                        enc.write.extend(&self.common.buffer[i..]);
                                    })
                                }
                            }
                            _ => {}
                        }
                    } else {
                        debug!("unknown message: {:?}", buf);
                        return Err(Error::Inconsistent.into());
                    }
                }
                EncryptedState::InitCompression => unreachable!(),
                EncryptedState::Authenticated => is_authenticated = true,
            }
        }
        if is_authenticated {
            self.client_read_authenticated(client, buf).await
        } else {
            Ok(self)
        }
    }

    /// Dispatches a connection-protocol message received after authentication
    /// to the matching `Handler` callback.
    ///
    /// Channel bookkeeping in `self.common.encrypted` is updated first where
    /// needed:
    /// * an open confirmation records the peer's channel parameters;
    /// * a close or an open failure removes the channel;
    /// * a window adjustment grows the send window.
    ///
    /// Unknown channel requests and all global requests are refused, but only
    /// when the peer asked for a reply. Other unknown message types are logged
    /// and ignored.
    async fn client_read_authenticated<C: super::Handler>(
        mut self,
        client: &mut Option<C>,
        buf: &[u8],
    ) -> Result<Self, C::Error> {
        match buf[0] {
            msg::CHANNEL_OPEN_CONFIRMATION => {
                debug!("channel_open_confirmation");
                let mut reader = buf.reader(1);
                // Our channel id, then the peer's id for the same channel.
                let id_send = ChannelId(reader.read_u32().map_err(crate::Error::from)?);
                let id_recv = reader.read_u32().map_err(crate::Error::from)?;
                let window = reader.read_u32().map_err(crate::Error::from)?;
                let max_packet = reader.read_u32().map_err(crate::Error::from)?;
                if let Some(ref mut enc) = self.common.encrypted {
                    if let Some(parameters) = enc.channels.get_mut(&id_send) {
                        parameters.recipient_channel = id_recv;
                        parameters.recipient_window_size = window;
                        parameters.recipient_maximum_packet_size = max_packet;
                        parameters.confirmed = true;
                    } else {
                        return Err(Error::Inconsistent.into());
                    }
                } else {
                    return Err(Error::Inconsistent.into());
                };
                let c = client.take().unwrap();
                let (c, s) = c
                    .channel_open_confirmation(id_send, max_packet, window, self)
                    .await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_CLOSE => {
                debug!("channel_close");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                if let Some(ref mut enc) = self.common.encrypted {
                    enc.channels.remove(&channel_num);
                }
                let c = client.take().unwrap();
                let (c, s) = c.channel_close(channel_num, self).await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_EOF => {
                debug!("channel_eof");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let c = client.take().unwrap();
                let (c, s) = c.channel_eof(channel_num, self).await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_OPEN_FAILURE => {
                debug!("channel_open_failure");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let reason_code =
                    ChannelOpenFailure::from_u32(r.read_u32().map_err(crate::Error::from)?)?;
                let descr = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                    .map_err(crate::Error::from)?;
                let language = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                    .map_err(crate::Error::from)?;
                if let Some(ref mut enc) = self.common.encrypted {
                    enc.channels.remove(&channel_num);
                }
                let c = client.take().unwrap();
                let (c, s) = c
                    .channel_open_failure(channel_num, reason_code, descr, language, self)
                    .await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_DATA => {
                debug!("channel_data");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let data = r.read_string().map_err(crate::Error::from)?;
                let target = self.common.config.window_size;
                let mut c = client.take().unwrap();
                if let Some(ref mut enc) = self.common.encrypted {
                    // When the local window was re-opened, let the handler pick
                    // a new target window size (0 keeps the current one).
                    if enc.adjust_window_size(channel_num, data, target) {
                        let next_window = c.adjust_window(channel_num, self.target_window_size);
                        if next_window > 0 {
                            self.target_window_size = next_window
                        }
                    }
                }
                let (c, s) = c.data(channel_num, &data, self).await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_EXTENDED_DATA => {
                debug!("channel_extended_data");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let extended_code = r.read_u32().map_err(crate::Error::from)?;
                let data = r.read_string().map_err(crate::Error::from)?;
                let target = self.common.config.window_size;
                let mut c = client.take().unwrap();
                if let Some(ref mut enc) = self.common.encrypted {
                    if enc.adjust_window_size(channel_num, data, target) {
                        let next_window = c.adjust_window(channel_num, self.target_window_size);
                        if next_window > 0 {
                            self.target_window_size = next_window
                        }
                    }
                }
                let (c, s) = c
                    .extended_data(channel_num, extended_code, &data, self)
                    .await?;
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_REQUEST => {
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let req = r.read_string().map_err(crate::Error::from)?;
                debug!(
                    "channel_request: {:?} {:?}",
                    channel_num,
                    std::str::from_utf8(req)
                );
                let cl = client.take().unwrap();
                let (c, s) = match req {
                    // NOTE(review): RFC 4254 §7.2 defines "forwarded-tcpip"
                    // (hyphenated) as a channel *open* type, not a channel
                    // request. Unlike the arms below, this arm does not consume
                    // the want-reply byte before the request data. It is left
                    // unchanged; confirm the intent before relying on it.
                    b"forwarded_tcpip" => {
                        let a = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                            .map_err(crate::Error::from)?;
                        let b = r.read_u32().map_err(crate::Error::from)?;
                        let c = std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                            .map_err(crate::Error::from)?;
                        let d = r.read_u32().map_err(crate::Error::from)?;
                        cl.channel_open_forwarded_tcpip(channel_num, a, b, c, d, self)
                            .await?
                    }
                    b"xon-xoff" => {
                        // Skip the want-reply byte.
                        r.read_byte().map_err(crate::Error::from)?;
                        let client_can_do = r.read_byte().map_err(crate::Error::from)?;
                        cl.xon_xoff(channel_num, client_can_do != 0, self).await?
                    }
                    b"exit-status" => {
                        // Skip the want-reply byte.
                        r.read_byte().map_err(crate::Error::from)?;
                        let exit_status = r.read_u32().map_err(crate::Error::from)?;
                        cl.exit_status(channel_num, exit_status, self).await?
                    }
                    b"exit-signal" => {
                        // Skip the want-reply byte.
                        r.read_byte().map_err(crate::Error::from)?;
                        let signal_name =
                            Sig::from_name(r.read_string().map_err(crate::Error::from)?)?;
                        let core_dumped = r.read_byte().map_err(crate::Error::from)?;
                        let error_message =
                            std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                                .map_err(crate::Error::from)?;
                        let lang_tag =
                            std::str::from_utf8(r.read_string().map_err(crate::Error::from)?)
                                .map_err(crate::Error::from)?;
                        cl.exit_signal(
                            channel_num,
                            signal_name,
                            core_dumped != 0,
                            error_message,
                            lang_tag,
                            self,
                        )
                        .await?
                    }
                    _ => {
                        let wants_reply = r.read_byte().map_err(crate::Error::from)?;
                        // RFC 4251 §5: any non-zero boolean byte means TRUE.
                        if wants_reply != 0 {
                            if let Some(ref mut enc) = self.common.encrypted {
                                self.common.wants_reply = false;
                                // CHANNEL_FAILURE is addressed with the *peer's*
                                // channel number (RFC 4254 §5.4), which is stored
                                // as `recipient_channel`, not our local id.
                                if let Some(recipient) = enc
                                    .channels
                                    .get(&channel_num)
                                    .map(|p| p.recipient_channel)
                                {
                                    push_packet!(enc.write, {
                                        enc.write.push(msg::CHANNEL_FAILURE);
                                        enc.write.push_u32_be(recipient)
                                    })
                                }
                            }
                        }
                        info!(
                            "Unknown channel request {:?} {:?}",
                            std::str::from_utf8(req),
                            wants_reply
                        );
                        (cl, self)
                    }
                };
                *client = Some(c);
                Ok(s)
            }
            msg::CHANNEL_WINDOW_ADJUST => {
                debug!("channel_window_adjust");
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let amount = r.read_u32().map_err(crate::Error::from)?;
                let mut new_value = 0;
                debug!("amount: {:?}", amount);
                if let Some(ref mut enc) = self.common.encrypted {
                    if let Some(channel) = enc.channels.get_mut(&channel_num) {
                        // RFC 4254 §5.2 caps the window at 2^32 - 1; a peer
                        // sending an oversized adjustment must not overflow it.
                        channel.recipient_window_size =
                            channel.recipient_window_size.saturating_add(amount);
                        new_value = channel.recipient_window_size;
                    } else {
                        return Err(Error::WrongChannel.into());
                    }
                }
                let c = client.take().unwrap();
                let (c, s) = c.window_adjusted(channel_num, new_value, self).await?;
                *client = Some(c);
                Ok(s)
            }
            msg::GLOBAL_REQUEST => {
                let mut r = buf.reader(1);
                let req = r.read_string().map_err(crate::Error::from)?;
                let wants_reply = r.read_byte().map_err(crate::Error::from)?;
                // RFC 4254 §4: a reply is sent only when "want reply" is TRUE.
                if wants_reply != 0 {
                    if let Some(ref mut enc) = self.common.encrypted {
                        self.common.wants_reply = false;
                        push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE))
                    }
                }
                info!(
                    "Unhandled global request: {:?} {:?}",
                    std::str::from_utf8(req),
                    wants_reply
                );
                Ok(self)
            }
            msg::CHANNEL_SUCCESS => {
                let mut r = buf.reader(1);
                let channel_num = ChannelId(r.read_u32().map_err(crate::Error::from)?);
                let c = client.take().unwrap();
                let (c, s) = c.channel_success(channel_num, self).await?;
                *client = Some(c);
                Ok(s)
            }
            _ => {
                info!("Unhandled packet: {:?}", buf);
                Ok(self)
            }
        }
    }

    /// Records `user` and `meth` as the credentials to authenticate with.
    /// If the session is already able to accept a `USERAUTH_REQUEST`, the
    /// request is written immediately.
    ///
    /// In `WaitingServiceRequest`, the `ssh-userauth` service request is
    /// written first, at most once. The auth request is written right away
    /// when the state is `WaitingAuthRequest`, or when it is
    /// `WaitingServiceRequest` and the service was already accepted.
    /// Otherwise the stored method is used later, when `SERVICE_ACCEPT`
    /// arrives.
    ///
    /// Returns `true` if the auth request was written by this call.
    pub(crate) fn write_auth_request_if_needed(&mut self, user: &str, meth: auth::Method) -> bool {
        let mut is_waiting = false;
        if let Some(ref mut enc) = self.common.encrypted {
            is_waiting = match enc.state {
                EncryptedState::WaitingAuthRequest(_) => true,
                EncryptedState::WaitingServiceRequest {
                    accepted,
                    ref mut sent,
                } => {
                    debug!("sending ssh-userauth service request");
                    if !*sent {
                        // SSH_MSG_SERVICE_REQUEST (5) followed by the
                        // length-prefixed (12) string "ssh-userauth".
                        let p = b"\x05\0\0\0\x0Cssh-userauth";
                        self.common.cipher.write(p, &mut self.common.write_buffer);
                        *sent = true
                    }
                    accepted
                }
                EncryptedState::InitCompression | EncryptedState::Authenticated => false,
            };
            debug!(
                "write_auth_request_if_needed: is_waiting = {:?}",
                is_waiting
            );
            if is_waiting {
                enc.write_auth_request(user, &meth);
            }
        }
        self.common.auth_user.clear();
        self.common.auth_user.push_str(user);
        self.common.auth_method = Some(meth);
        is_waiting
    }
}
impl Encrypted {
    /// Appends one `USERAUTH_REQUEST` packet for `user` to `self.write`.
    ///
    /// The user name and the `ssh-connection` service are common to every
    /// method. The method-specific fields follow:
    /// * `password`: a 0 "change password" flag, then the password
    ///   (RFC 4252 §8).
    /// * Both public-key variants: `publickey` with a 0 "has signature" flag,
    ///   then the key algorithm name and the key. This is a query; the signed
    ///   request is built later, once the server answers `USERAUTH_PK_OK`.
    ///
    /// Always returns `true`.
    fn write_auth_request(&mut self, user: &str, auth_method: &auth::Method) -> bool {
        push_packet!(self.write, {
            self.write.push(msg::USERAUTH_REQUEST);
            self.write.extend_ssh_string(user.as_bytes());
            self.write.extend_ssh_string(b"ssh-connection");
            match *auth_method {
                auth::Method::Password { ref password } => {
                    self.write.extend_ssh_string(b"password");
                    self.write.push(0);
                    self.write.extend_ssh_string(password.as_bytes());
                }
                auth::Method::PublicKey { ref key } => {
                    self.write.extend_ssh_string(b"publickey");
                    self.write.push(0);
                    debug!("write_auth_request: {:?}", key.name());
                    self.write.extend_ssh_string(key.name().as_bytes());
                    key.push_to(&mut self.write);
                }
                auth::Method::FuturePublicKey { ref key, .. } => {
                    self.write.extend_ssh_string(b"publickey");
                    self.write.push(0);
                    self.write.extend_ssh_string(key.name().as_bytes());
                    key.push_to(&mut self.write);
                }
            }
            true
        })
    }

    /// Fills `buffer` (cleared first) with the data to be signed for
    /// public-key authentication.
    ///
    /// The layout is the session id as an SSH string, followed by a
    /// `USERAUTH_REQUEST` for `user` with the "has signature" flag set to 1,
    /// the key algorithm name and the key.
    ///
    /// Returns the offset at which the `USERAUTH_REQUEST` part starts, i.e.
    /// where the bytes to send as the packet body begin.
    fn client_make_to_sign<Key: Named + PubKey>(
        &mut self,
        user: &str,
        key: &Key,
        buffer: &mut CryptoVec,
    ) -> usize {
        buffer.clear();
        buffer.extend_ssh_string(self.session_id.as_ref());
        let i0 = buffer.len();
        buffer.push(msg::USERAUTH_REQUEST);
        buffer.extend_ssh_string(user.as_bytes());
        buffer.extend_ssh_string(b"ssh-connection");
        buffer.extend_ssh_string(b"publickey");
        buffer.push(1);
        buffer.extend_ssh_string(key.name().as_bytes());
        key.push_to(buffer);
        i0
    }

    /// For `auth::Method::PublicKey`, builds the signed `USERAUTH_REQUEST`
    /// in `buffer` and queues it on `self.write`.
    ///
    /// The to-be-signed data is built by `client_make_to_sign`; the key then
    /// (presumably) appends its signature via `add_self_signature`. Only the
    /// part after the session id is sent. Other methods are a no-op.
    ///
    /// # Errors
    ///
    /// Propagates the error from `add_self_signature`.
    fn client_send_signature(
        &mut self,
        user: &str,
        method: &auth::Method,
        buffer: &mut CryptoVec,
    ) -> Result<(), Error> {
        if let auth::Method::PublicKey { ref key } = *method {
            let i0 = self.client_make_to_sign(user, key.as_ref(), buffer);
            key.add_self_signature(buffer)?;
            push_packet!(self.write, {
                self.write.extend(&buffer[i0..]);
            });
        }
        Ok(())
    }
}