use crate::auth;
use crate::negotiation;
use crate::pty::Pty;
use crate::session::*;
use crate::ssh_read::SshRead;
use crate::sshbuffer::*;
use crate::{ChannelId, ChannelMsg, ChannelOpenFailure, Disconnect, Limits, Preferred, Sig};
use cryptovec::CryptoVec;
use futures::task::{Context, Poll};
use futures::Future;
use std::cell::RefCell;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use thrussh_keys::encoding::{Encoding, Reader};
use thrussh_keys::key;
use thrussh_keys::key::parse_public_key;
use tokio;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::pin;
mod kex;
use crate::cipher;
use crate::{msg, Error};
mod encrypted;
mod session;
use tokio::sync::mpsc::*;
pub mod proxy;
/// Client-side state of one SSH connection. It is moved into the background
/// task spawned by `connect_stream` and driven by `Session::run`.
pub struct Session {
/// Transport state (write buffer, cipher, key-exchange state, config, ...).
common: CommonSession<Arc<Config>>,
/// Requests coming from the user-facing `Handle` and its `Channel`s.
receiver: Receiver<Msg>,
/// Replies consumed by the `Handle` (authentication outcome, sign requests, ...).
sender: UnboundedSender<Reply>,
/// Per-channel senders used to route channel events to each `Channel`.
channels: HashMap<ChannelId, UnboundedSender<OpenChannelMsg>>,
/// Initialised from `Config::window_size` in `connect_stream`.
target_window_size: u32,
/// Initialised empty in `connect_stream`; not used elsewhere in this file.
pending_reads: Vec<CryptoVec>,
/// Initialised to 0 in `connect_stream`; not used elsewhere in this file.
pending_len: u32,
}
impl Drop for Session {
    /// Emits a debug trace when the session state is torn down, which happens
    /// when the background task returns.
    fn drop(&mut self) {
        debug!("drop session");
    }
}
/// Messages sent from the session task back to the `Handle`.
#[derive(Debug)]
enum Reply {
/// Authentication succeeded (mapped to `Ok(true)` by the `authenticate_*` methods).
AuthSuccess,
/// Authentication was rejected (mapped to `Ok(false)`).
AuthFailure,
/// Sent by the default `Handler::channel_open_failure`.
ChannelOpenFailure,
/// `data` must be signed with `key`; answered by `Handle::authenticate_future`
/// with a `Msg::Signed`.
SignRequest {
key: thrussh_keys::key::PublicKey,
data: CryptoVec,
},
}
/// Requests sent from the `Handle` and `Channel`s to the session task, which
/// dispatches them in `Session::run`.
#[derive(Debug)]
enum Msg {
/// Start authenticating `user` with `method`.
Authenticate {
user: String,
method: auth::Method,
},
/// Signature produced by `Handle::authenticate_future` in answer to a
/// `Reply::SignRequest`. NOTE(review): `Session::run` ignores this variant,
/// so it is presumably consumed elsewhere — confirm.
Signed {
data: CryptoVec,
},
/// Open a session channel; its events will be forwarded to `sender`.
ChannelOpenSession {
sender: UnboundedSender<OpenChannelMsg>,
},
/// Open an X11 channel; its events will be forwarded to `sender`.
ChannelOpenX11 {
originator_address: String,
originator_port: u32,
sender: UnboundedSender<OpenChannelMsg>,
},
/// Open a direct TCP/IP channel; its events will be forwarded to `sender`.
ChannelOpenDirectTcpIp {
host_to_connect: String,
port_to_connect: u32,
originator_address: String,
originator_port: u32,
sender: UnboundedSender<OpenChannelMsg>,
},
/// Ask the server to forward connections on `address:port` (not channel-specific).
TcpIpForward {
want_reply: bool,
address: String,
port: u32,
},
/// Cancel a previous `TcpIpForward`.
CancelTcpIpForward {
want_reply: bool,
address: String,
port: u32,
},
/// Close the connection.
Disconnect {
reason: Disconnect,
description: String,
language_tag: String,
},
/// Send `data` on channel `id`.
Data {
id: ChannelId,
data: CryptoVec,
},
/// Send extended data of type `ext` on channel `id`.
ExtendedData {
id: ChannelId,
data: CryptoVec,
ext: u32,
},
/// Send end-of-file on channel `id`.
Eof {
id: ChannelId,
},
/// Request a pseudo-terminal on channel `id`.
RequestPty {
id: ChannelId,
want_reply: bool,
term: String,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
terminal_modes: Vec<(Pty, u32)>,
},
/// Request a shell on channel `id`.
RequestShell {
id: ChannelId,
want_reply: bool,
},
/// Execute `command` on channel `id`.
Exec {
id: ChannelId,
want_reply: bool,
command: String,
},
/// Deliver `signal` on channel `id`.
Signal {
id: ChannelId,
signal: Sig,
},
/// Start subsystem `name` on channel `id`.
RequestSubsystem {
id: ChannelId,
want_reply: bool,
name: String,
},
/// Request X11 forwarding on channel `id`.
RequestX11 {
id: ChannelId,
want_reply: bool,
single_connection: bool,
x11_authentication_protocol: String,
x11_authentication_cookie: String,
x11_screen_number: u32,
},
/// Set an environment variable on channel `id`.
SetEnv {
id: ChannelId,
want_reply: bool,
variable_name: String,
variable_value: String,
},
/// Report a terminal size change on channel `id`.
WindowChange {
id: ChannelId,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
},
}
/// Events delivered from the session task to a single channel's receiver.
#[derive(Debug)]
enum OpenChannelMsg {
/// The server confirmed the channel open (sent by the default
/// `Handler::channel_open_confirmation`); completes `wait_channel_confirmation`.
Open {
id: ChannelId,
max_packet_size: u32,
window_size: u32,
},
/// Any other channel event, surfaced to the user through `Channel::wait`.
Msg(ChannelMsg),
}
/// User-facing handle to a client connection. Requests are queued to the
/// session task over `sender`, and replies arrive on `receiver`. Awaiting the
/// handle (it implements `Future`) yields the session task's final result.
pub struct Handle<H: Handler> {
sender: Sender<Msg>,
receiver: UnboundedReceiver<Reply>,
/// The spawned `Session::run` task.
join: tokio::task::JoinHandle<Result<(), H::Error>>,
}
impl<H: Handler> Drop for Handle<H> {
    /// Emits a debug trace when the handle goes away. Its request sender is
    /// dropped with it; once every `Channel` sender is gone too, the session
    /// task's request stream ends.
    fn drop(&mut self) {
        debug!("drop handle");
    }
}
/// Cloneable pair of the session's request sender and the channel id the
/// requests refer to.
#[derive(Clone)]
pub struct ChannelSender {
sender: Sender<Msg>,
id: ChannelId,
}
/// A client-side SSH channel. Requests go to the session task through
/// `sender`, and events for this channel arrive on `receiver`.
pub struct Channel {
sender: ChannelSender,
receiver: UnboundedReceiver<OpenChannelMsg>,
/// Maximum packet size from the channel-open confirmation.
max_packet_size: u32,
/// Remaining send window as tracked locally: decremented as data is sent,
/// updated from `WindowAdjusted` events.
window_size: u32,
}
impl<H: Handler> Handle<H> {
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub async fn authenticate_password<U: Into<String>, P: Into<String>>(
&mut self,
user: U,
password: P,
) -> Result<bool, Error> {
let user = user.into();
self.sender
.send(Msg::Authenticate {
user,
method: auth::Method::Password {
password: password.into(),
},
})
.await
.map_err(|_| Error::SendError)?;
loop {
match self.receiver.recv().await {
Some(Reply::AuthSuccess) => return Ok(true),
Some(Reply::AuthFailure) => return Ok(false),
None => return Ok(false),
_ => {}
}
}
}
pub async fn authenticate_publickey<U: Into<String>>(
&mut self,
user: U,
key: Arc<key::KeyPair>,
) -> Result<bool, Error> {
let user = user.into();
self.sender
.send(Msg::Authenticate {
user,
method: auth::Method::PublicKey { key },
})
.await
.map_err(|_| Error::SendError)?;
loop {
match self.receiver.recv().await {
Some(Reply::AuthSuccess) => return Ok(true),
Some(Reply::AuthFailure) => return Ok(false),
None => return Ok(false),
_ => {}
}
}
}
pub async fn authenticate_future<U: Into<String>, S: auth::Signer>(
&mut self,
user: U,
key: key::PublicKey,
mut future: S,
) -> (S, Result<bool, S::Error>) {
let user = user.into();
if let Err(_) = self
.sender
.send(Msg::Authenticate {
user,
method: auth::Method::FuturePublicKey { key },
})
.await
{
return (future, Err((crate::SendError {}).into()));
}
loop {
let reply = self.receiver.recv().await;
match reply {
Some(Reply::AuthSuccess) => return (future, Ok(true)),
Some(Reply::AuthFailure) => return (future, Ok(false)),
Some(Reply::SignRequest { key, data }) => {
let (f, data) = future.auth_publickey_sign(&key, data).await;
future = f;
let data = match data {
Ok(data) => data,
Err(e) => return (future, Err(e.into())),
};
if let Err(_) = self.sender.send(Msg::Signed { data }).await {
return (future, Err((crate::SendError {}).into()));
}
}
None => return (future, Ok(false)),
_ => {}
}
}
}
async fn wait_channel_confirmation(
&self,
mut receiver: UnboundedReceiver<OpenChannelMsg>,
) -> Result<Channel, Error> {
loop {
match receiver.recv().await {
Some(OpenChannelMsg::Open {
id,
max_packet_size,
window_size,
}) => {
return Ok(Channel {
sender: ChannelSender {
sender: self.sender.clone(),
id,
},
receiver,
max_packet_size,
window_size,
});
}
None => {
return Err(Error::Disconnect.into());
}
msg => {
debug!("msg = {:?}", msg);
}
}
}
}
pub async fn channel_open_session(&mut self) -> Result<Channel, Error> {
let (sender, receiver) = unbounded_channel();
self.sender
.send(Msg::ChannelOpenSession { sender })
.await
.map_err(|_| Error::SendError)?;
self.wait_channel_confirmation(receiver).await
}
pub async fn channel_open_x11<A: Into<String>>(
&mut self,
originator_address: A,
originator_port: u32,
) -> Result<Channel, Error> {
let (sender, receiver) = unbounded_channel();
self.sender
.send(Msg::ChannelOpenX11 {
originator_address: originator_address.into(),
originator_port,
sender,
})
.await
.map_err(|_| Error::SendError)?;
self.wait_channel_confirmation(receiver).await
}
pub async fn channel_open_direct_tcpip<A: Into<String>, B: Into<String>>(
&mut self,
host_to_connect: A,
port_to_connect: u32,
originator_address: B,
originator_port: u32,
) -> Result<Channel, Error> {
let (sender, receiver) = unbounded_channel();
self.sender
.send(Msg::ChannelOpenDirectTcpIp {
host_to_connect: host_to_connect.into(),
port_to_connect,
originator_address: originator_address.into(),
originator_port,
sender,
})
.await
.map_err(|_| Error::SendError)?;
self.wait_channel_confirmation(receiver).await
}
pub async fn disconnect(
&mut self,
reason: Disconnect,
description: &str,
language_tag: &str,
) -> Result<(), Error> {
self.sender
.send(Msg::Disconnect {
reason,
description: description.into(),
language_tag: language_tag.into(),
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
}
impl Channel {
pub fn id(&self) -> ChannelId {
self.sender.id
}
pub fn writable_packet_size(&self) -> usize {
self.max_packet_size.min(self.window_size) as usize
}
pub async fn request_pty(
&mut self,
want_reply: bool,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
terminal_modes: &[(Pty, u32)],
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::RequestPty {
id: self.sender.id,
want_reply,
term: term.to_string(),
col_width,
row_height,
pix_width,
pix_height,
terminal_modes: terminal_modes.to_vec(),
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn request_shell(&mut self, want_reply: bool) -> Result<(), Error> {
self.sender
.sender
.send(Msg::RequestShell {
id: self.sender.id,
want_reply,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn exec<A: Into<String>>(
&mut self,
want_reply: bool,
command: A,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::Exec {
id: self.sender.id,
want_reply,
command: command.into(),
})
.await
.map_err(|e| {
debug!("e = {:?}", e);
Error::SendError
})?;
Ok(())
}
pub async fn signal(&mut self, signal: Sig) -> Result<(), Error> {
self.sender
.sender
.send(Msg::Signal {
id: self.sender.id,
signal,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn request_subsystem<A: Into<String>>(
&mut self,
want_reply: bool,
name: A,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::RequestSubsystem {
id: self.sender.id,
want_reply,
name: name.into(),
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn tcpip_forward<A: Into<String>>(
&mut self,
want_reply: bool,
address: A,
port: u32,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::TcpIpForward {
want_reply,
address: address.into(),
port,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn cancel_tcpip_forward<A: Into<String>>(
&mut self,
want_reply: bool,
address: A,
port: u32,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::CancelTcpIpForward {
want_reply,
address: address.into(),
port,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn request_x11<A: Into<String>, B: Into<String>>(
&mut self,
want_reply: bool,
single_connection: bool,
x11_authentication_protocol: A,
x11_authentication_cookie: B,
x11_screen_number: u32,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::RequestX11 {
id: self.sender.id,
want_reply,
single_connection,
x11_authentication_protocol: x11_authentication_protocol.into(),
x11_authentication_cookie: x11_authentication_cookie.into(),
x11_screen_number,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn set_env<A: Into<String>, B: Into<String>>(
&mut self,
want_reply: bool,
variable_name: A,
variable_value: B,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::SetEnv {
id: self.sender.id,
want_reply,
variable_name: variable_name.into(),
variable_value: variable_value.into(),
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn window_change(
&mut self,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
) -> Result<(), Error> {
self.sender
.sender
.send(Msg::WindowChange {
id: self.sender.id,
col_width,
row_height,
pix_width,
pix_height,
})
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn data<R: tokio::io::AsyncReadExt + std::marker::Unpin>(
&mut self,
data: R,
) -> Result<(), Error> {
self.send_data(None, data).await
}
pub async fn extended_data<R: tokio::io::AsyncReadExt + std::marker::Unpin>(
&mut self,
ext: u32,
data: R,
) -> Result<(), Error> {
self.send_data(Some(ext), data).await
}
async fn send_data<R: tokio::io::AsyncReadExt + std::marker::Unpin>(
&mut self,
ext: Option<u32>,
mut data: R,
) -> Result<(), Error> {
let mut total = 0;
loop {
while self.window_size == 0 {
match self.receiver.recv().await {
Some(OpenChannelMsg::Msg(ChannelMsg::WindowAdjusted { new_size })) => {
debug!("window adjusted: {:?}", new_size);
self.window_size = new_size;
break;
}
Some(OpenChannelMsg::Msg(msg)) => {
debug!("unexpected channel msg: {:?}", msg);
}
Some(_) => debug!("unexpected channel msg"),
None => break,
}
}
debug!(
"sending data, self.window_size = {:?}, self.max_packet_size = {:?}, total = {:?}",
self.window_size, self.max_packet_size, total
);
let sendable = self.window_size.min(self.max_packet_size) as usize;
debug!("sendable {:?}", sendable);
let mut c = CryptoVec::new_zeroed(sendable);
let n = data.read(&mut c[..]).await?;
total += n;
c.resize(n);
self.window_size -= n as u32;
self.send_data_packet(ext, c).await?;
if n == 0 {
break;
} else if self.window_size > 0 {
continue;
}
}
Ok(())
}
async fn send_data_packet(&mut self, ext: Option<u32>, data: CryptoVec) -> Result<(), Error> {
self.sender
.sender
.send(if let Some(ext) = ext {
Msg::ExtendedData {
id: self.sender.id,
ext,
data,
}
} else {
Msg::Data {
id: self.sender.id,
data,
}
})
.await
.map_err(|e| {
error!("{:?}", e);
Error::SendError
})?;
Ok(())
}
pub async fn eof(&mut self) -> Result<(), Error> {
self.sender
.sender
.send(Msg::Eof { id: self.sender.id })
.await
.map_err(|_| Error::SendError)?;
Ok(())
}
pub async fn wait(&mut self) -> Option<ChannelMsg> {
loop {
match self.receiver.recv().await {
Some(OpenChannelMsg::Msg(ChannelMsg::WindowAdjusted { new_size })) => {
self.window_size += new_size;
return Some(ChannelMsg::WindowAdjusted { new_size });
}
Some(OpenChannelMsg::Msg(msg)) => return Some(msg),
None => return None,
_ => {}
}
}
}
}
impl<H: Handler> Future for Handle<H> {
type Output = Result<(), H::Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match Future::poll(Pin::new(&mut self.join), cx) {
Poll::Ready(r) => Poll::Ready(match r {
Ok(Ok(x)) => Ok(x),
Err(e) => Err(crate::Error::from(e).into()),
Ok(Err(e)) => Err(e),
}),
Poll::Pending => Poll::Pending,
}
}
}
use std::net::ToSocketAddrs;
pub async fn connect<H: Handler + Send + 'static, T: ToSocketAddrs>(
config: Arc<Config>,
addr: T,
handler: H,
) -> Result<Handle<H>, H::Error> {
let addr = addr
.to_socket_addrs()
.map_err(crate::Error::from)?
.next()
.unwrap();
let socket = TcpStream::connect(addr).await.map_err(crate::Error::from)?;
connect_stream(config, socket, handler).await
}
/// Starts an SSH client session over an already-connected `stream`.
///
/// Sends the client identification string, reads the server's, queues the
/// first key-exchange message and spawns [`Session::run`] on the runtime.
/// Waits until the first key exchange completes, or until the session task
/// drops the notification without completing it, then returns a [`Handle`]
/// to the session.
pub async fn connect_stream<H, R>(
    config: Arc<Config>,
    mut stream: R,
    handler: H,
) -> Result<Handle<H>, H::Error>
where
    H: Handler + Send + 'static,
    R: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // Our identification string goes out first.
    let mut write_buffer = SSHBuffer::new();
    write_buffer.send_ssh_id(config.as_ref().client_id.as_bytes());
    stream
        .write_all(&write_buffer.buffer)
        .await
        .map_err(crate::Error::from)?;

    let mut stream = SshRead::new(stream);
    let server_id = stream.read_ssh_id().await?;

    // Requests flow Handle -> session (bounded); replies flow back (unbounded).
    let (request_tx, request_rx) = channel(10);
    let (reply_tx, reply_rx) = unbounded_channel();

    let max_packet = config.maximum_packet_size;
    if max_packet > 65535 {
        error!(
            "Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
            max_packet
        );
    }

    let target_window_size = config.window_size;
    let common = CommonSession {
        write_buffer,
        kex: None,
        auth_user: String::new(),
        auth_method: None,
        cipher: Arc::new(cipher::CLEAR_PAIR),
        encrypted: None,
        config,
        wants_reply: false,
        disconnected: false,
        buffer: CryptoVec::new(),
    };
    let mut session = Session {
        common,
        receiver: request_rx,
        sender: reply_tx,
        channels: HashMap::new(),
        target_window_size,
        pending_reads: Vec::new(),
        pending_len: 0,
    };
    session.read_ssh_id(server_id)?;

    let (encrypted_tx, encrypted_rx) = tokio::sync::oneshot::channel();
    let join = tokio::spawn(session.run(stream, handler, Some(encrypted_tx)));
    // A dropped notifier (session ended early) is not an error here; the
    // outcome is reported when the handle itself is awaited.
    let _ = encrypted_rx.await;
    Ok(Handle {
        sender: request_tx,
        receiver: reply_rx,
        join,
    })
}
/// Reads one packet from `stream_read` into `buffer` (after clearing it)
/// using `cipher`.
///
/// Returns the byte count together with the stream and buffer, so the caller
/// can re-arm the next read with them.
async fn start_reading<R: AsyncRead + Unpin>(
    mut stream_read: R,
    mut buffer: SSHBuffer,
    cipher: Arc<crate::cipher::CipherPair>,
) -> Result<(usize, R, SSHBuffer), Error> {
    buffer.buffer.clear();
    let bytes_read = cipher::read(&mut stream_read, &mut buffer, &cipher).await?;
    let handed_back = (bytes_read, stream_read, buffer);
    Ok(handed_back)
}
impl Session {
/// Main loop of the background task spawned by `connect_stream`.
///
/// First sends whatever is queued in the write buffer. Then, until
/// `common.disconnected` is set, races reading the next server packet
/// (dispatched through `reply`) against requests from the `Handle`/`Channel`s.
/// Queued output is flushed after every event. `encrypted_signal` is passed to
/// `reply`, which fires it when the first key exchange completes.
async fn run<H: Handler + Send, R: AsyncRead + AsyncWrite + Unpin + Send>(
mut self,
mut stream: SshRead<R>,
handler: H,
mut encrypted_signal: Option<tokio::sync::oneshot::Sender<()>>,
) -> Result<(), H::Error> {
// Send what `read_ssh_id` queued in `connect_stream` before reading anything.
self.flush()?;
if !self.common.write_buffer.buffer.is_empty() {
debug!("writing {:?} bytes", self.common.write_buffer.buffer.len());
stream
.write_all(&self.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
stream.flush().await.map_err(crate::Error::from)?;
}
self.common.write_buffer.buffer.clear();
// Scratch output buffer for decompressing incoming payloads.
let mut decomp = CryptoVec::new();
// The handler is moved into and back out of callbacks, hence the `Option`.
let mut handler = Some(handler);
let (stream_read, mut stream_write) = stream.split();
let buffer = SSHBuffer::new();
// A single pinned read future, raced in `select!` below and re-armed with
// the returned stream/buffer after each packet.
let reading = start_reading(stream_read, buffer, self.common.cipher.clone());
pin!(reading);
while !self.common.disconnected {
tokio::select! {
r = &mut reading => {
let (stream_read, mut buffer) = match r {
Ok((_, stream_read, buffer)) => (stream_read, buffer),
Err(e) => return Err(e.into())
};
// Too short to contain anything past the first 5 bytes: stop reading.
if buffer.buffer.len() < 5 {
break
}
// Skip the first 5 bytes (presumably packet length + padding length)
// and, once encryption is active, decompress the payload.
let buf = if let Some(ref mut enc) = self.common.encrypted {
if let Ok(buf) = enc.decompress.decompress(
&buffer.buffer[5..],
&mut decomp,
) {
buf
} else {
break
}
} else {
&buffer.buffer[5..]
};
if !buf.is_empty() {
if buf[0] == crate::msg::DISCONNECT {
break;
} else if buf[0] > 4 {
// Message numbers <= 4 other than DISCONNECT (IGNORE, UNIMPLEMENTED,
// DEBUG in RFC 4253 numbering) are not dispatched.
buffer.strict = self.common.write_buffer.strict;
// Strict key exchange: the read sequence number restarts after NEWKEYS.
if buffer.strict && buf[0] == crate::msg::NEWKEYS {
buffer.seqn = std::num::Wrapping(0u32);
}
self = reply(self, &mut handler, &mut encrypted_signal, &buf[..]).await?;
}
}
reading.set(start_reading(stream_read, buffer, self.common.cipher.clone()));
}
// User requests are only served while no key exchange is in progress;
// `is_rekeying` is also true before encryption is first established.
msg = self.receiver.recv(), if !self.is_rekeying() => {
match msg {
Some(Msg::Authenticate { user, method }) => {
self.write_auth_request_if_needed(&user, method);
}
// NOTE(review): dropped here; presumably signatures are read from
// `self.receiver` directly by the authentication code elsewhere — confirm.
Some(Msg::Signed { .. }) => {},
// Channel opens register the channel's event sender under the new id
// so that later server events can be routed to the `Channel`.
Some(Msg::ChannelOpenSession { sender }) => {
let id = self.channel_open_session()?;
self.channels.insert(id, sender);
}
Some(Msg::ChannelOpenX11 { originator_address, originator_port, sender }) => {
let id = self.channel_open_x11(&originator_address, originator_port)?;
self.channels.insert(id, sender);
}
Some(Msg::ChannelOpenDirectTcpIp { host_to_connect, port_to_connect, originator_address, originator_port, sender }) => {
let id = self.channel_open_direct_tcpip(&host_to_connect, port_to_connect, &originator_address, originator_port)?;
self.channels.insert(id, sender);
}
Some(Msg::TcpIpForward { want_reply, address, port }) => {
self.tcpip_forward(want_reply, &address, port)
},
Some(Msg::CancelTcpIpForward { want_reply, address, port }) => {
self.cancel_tcpip_forward(want_reply, &address, port)
},
Some(Msg::Disconnect { reason, description, language_tag }) => {
self.disconnect(reason, &description, &language_tag)
},
Some(Msg::Data { data, id }) => { self.data(id, data) },
Some(Msg::Eof { id }) => { self.eof(id); },
Some(Msg::ExtendedData { data, ext, id }) => { self.extended_data(id, ext, data); },
Some(Msg::RequestPty { id, want_reply, term, col_width, row_height, pix_width, pix_height, terminal_modes }) => {
self.request_pty(id, want_reply, &term, col_width, row_height, pix_width, pix_height, &terminal_modes)
},
Some(Msg::WindowChange { id, col_width, row_height, pix_width, pix_height }) => {
self.window_change(id, col_width, row_height, pix_width, pix_height)
},
Some(Msg::RequestX11 { id, want_reply, single_connection, x11_authentication_protocol, x11_authentication_cookie, x11_screen_number }) => {
self.request_x11(id, want_reply, single_connection, &x11_authentication_protocol, &x11_authentication_cookie, x11_screen_number)
},
Some(Msg::SetEnv { id, want_reply, variable_name, variable_value }) => {
self.set_env(id, want_reply, &variable_name, &variable_value)
},
Some(Msg::RequestShell { id, want_reply }) => {
self.request_shell(want_reply, id)
},
Some(Msg::Exec { id, want_reply, command }) => {
self.exec(id, want_reply, &command)
},
Some(Msg::Signal { id, signal }) => {
self.signal(id, signal)
},
Some(Msg::RequestSubsystem { id, want_reply, name }) => {
self.request_subsystem(want_reply, id, &name)
},
// Every request sender (the `Handle` and all `ChannelSender`s) is gone.
None => {
self.common.disconnected = true;
break
}
}
}
}
// Push out everything queued while handling this event.
self.flush()?;
if !self.common.write_buffer.buffer.is_empty() {
debug!(
"writing to stream: {:?} bytes",
self.common.write_buffer.buffer.len()
);
stream_write
.write_all(&self.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
stream_write.flush().await.map_err(crate::Error::from)?;
}
self.common.write_buffer.buffer.clear();
// Deferred compression: once the state reaches `InitCompression`, start
// the outgoing compressor and move to `Authenticated`.
if let Some(ref mut enc) = self.common.encrypted {
if let EncryptedState::InitCompression = enc.state {
enc.client_compression.init_compress(&mut enc.compress);
enc.state = EncryptedState::Authenticated;
}
}
}
debug!("disconnected");
// Only shut the write half down when `disconnected` was set. The `break`s
// above (short read, decompression failure, server DISCONNECT) don't set it.
if self.common.disconnected {
stream_write.shutdown().await.map_err(crate::Error::from)?;
}
Ok(())
}
/// True while a rekey is in progress, and also before encryption has been
/// established (no `encrypted` state yet).
fn is_rekeying(&self) -> bool {
if let Some(ref enc) = self.common.encrypted {
enc.rekey.is_some()
} else {
true
}
}
/// Records the server's identification string `sshid` and our own client id
/// in a fresh `Exchange`. Then writes the initial key-exchange message through
/// `KexInit::client_write` into the (cleared) write buffer and enters
/// `Kex::KexInit`.
fn read_ssh_id(&mut self, sshid: &[u8]) -> Result<(), Error> {
let mut exchange = Exchange::new();
exchange.server_id.extend(sshid);
exchange
.client_id
.extend(self.common.config.as_ref().client_id.as_bytes());
let mut kexinit = KexInit {
exchange: exchange,
algo: None,
sent: false,
session_id: None,
nonstrict_packets_received: false,
};
self.common.write_buffer.buffer.clear();
kexinit.client_write(
self.common.config.as_ref(),
&mut self.common.cipher,
&mut self.common.write_buffer,
)?;
self.common.kex = Some(Kex::KexInit(kexinit));
Ok(())
}
/// Moves queued encrypted output into the write buffer (no-op before
/// encryption). If `enc.flush` reports, based on `config.limits`, that keys
/// should be re-exchanged and no rekey is running, starts one: it writes a
/// new KEXINIT and stores it in `enc.rekey`.
fn flush(&mut self) -> Result<(), Error> {
if let Some(ref mut enc) = self.common.encrypted {
if enc.flush(
&self.common.config.as_ref().limits,
&mut self.common.cipher,
&mut self.common.write_buffer,
)? {
info!("Re-exchanging keys");
if enc.rekey.is_none() {
if let Some(exchange) = std::mem::replace(&mut enc.exchange, None) {
let mut kexinit = KexInit::initiate_rekey(exchange, &enc.session_id);
kexinit.client_write(
&self.common.config.as_ref(),
&mut self.common.cipher,
&mut self.common.write_buffer,
)?;
enc.rekey = Some(Kex::KexInit(kexinit))
}
}
}
}
Ok(())
}
/// Forwards `msg` to the `Channel` registered under `channel`.
///
/// Returns `false` if no such channel is registered. A failed send (the
/// `Channel` was dropped) is ignored and still returns `true`.
pub fn send_channel_msg(&self, channel: ChannelId, msg: ChannelMsg) -> bool {
if let Some(chan) = self.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(msg)).unwrap_or(());
true
} else {
false
}
}
}
// Per-thread scratch buffer used by `KexDhDone::server_key_check` when computing
// the exchange hash. It is cleared before each use; presumably kept around to
// avoid an allocation per key exchange.
thread_local! {
static HASH_BUFFER: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
}
impl KexDhDone {
/// Handles the server's key-exchange reply `buf`. The caller, `reply`, only
/// passes packets whose first byte is `msg::KEX_ECDH_REPLY`.
///
/// Steps:
/// 1. Parse the server host key.
/// 2. Unless `rekey` is set, ask the handler to approve the key
///    (`Error::UnknownKey` if refused).
/// 3. Read the server's ephemeral key and signature.
/// 4. Compute the shared secret and the exchange hash.
/// 5. Verify the signature over the hash (`Error::WrongServerSig` on mismatch).
/// 6. Derive the new keys and return `Kex::NewKeys` with `sent` set.
///
/// Panics if `handler` is `None` when `rekey` is false.
async fn server_key_check<H: Handler>(
mut self,
rekey: bool,
handler: &mut Option<H>,
buf: &[u8],
) -> Result<Kex, H::Error> {
// Start after the message-number byte.
let mut reader = buf.reader(1);
let pubkey = reader.read_string().map_err(crate::Error::from)?; let pubkey = parse_public_key(pubkey).map_err(crate::Error::from)?;
debug!("server_public_Key: {:?}", pubkey);
if !rekey {
// The handler is moved into the callback and put back afterwards.
let h = handler.take().unwrap();
let (h, check) = h.check_server_key(&pubkey).await?;
*handler = Some(h);
if !check {
return Err(Error::UnknownKey.into());
}
}
HASH_BUFFER.with(|buffer| {
let mut buffer = buffer.borrow_mut();
buffer.clear();
let hash = {
let server_ephemeral = reader.read_string().map_err(crate::Error::from)?;
self.exchange.server_ephemeral.extend(server_ephemeral);
let signature = reader.read_string().map_err(crate::Error::from)?;
self.kex
.compute_shared_secret(&self.exchange.server_ephemeral)?;
debug!("kexdhdone.exchange = {:?}", self.exchange);
let hash = self
.kex
.compute_exchange_hash(&pubkey, &self.exchange, &mut buffer)?;
debug!("exchange hash: {:?}", hash);
// The signature field is itself a (type name, blob) pair.
let signature = {
let mut sig_reader = signature.reader(0);
// NOTE(review): `sig_type` is only logged, not compared with the host
// key's algorithm here — confirm `verify_server_auth` rejects mismatches.
let sig_type = sig_reader.read_string().map_err(crate::Error::from)?;
debug!("sig_type: {:?}", sig_type);
sig_reader.read_string().map_err(crate::Error::from)?
};
use thrussh_keys::key::Verify;
debug!("signature: {:?}", signature);
if !pubkey.verify_server_auth(hash.as_ref(), signature) {
debug!("wrong server sig");
return Err(Error::WrongServerSig.into());
}
hash
};
let mut newkeys = self.compute_keys(hash, false)?;
newkeys.sent = true;
Ok(Kex::NewKeys(newkeys))
})
}
}
/// Processes one incoming packet `buf` (first byte = message number) and
/// returns the updated session. The action depends on the initial
/// key-exchange state in `session.common.kex`:
///
/// - `KexInit`: the packet is parsed with `client_parse` (as the server's
///   KEXINIT) and the state moves to `KexDhDone`. This happens when an
///   algorithm is already chosen, the packet is a KEXINIT, or encryption is
///   not yet active. Otherwise the packet is only recorded as non-strict.
///   Under strict key exchange (`write_buffer.strict`), having recorded such a
///   packet fails with `Error::KexInit`.
/// - `KexDhDone`: if `names.ignore_guessed` is set, clears it and drops this
///   packet. Otherwise expects `KEX_ECDH_REPLY`: verifies it
///   (`server_key_check`), sends NEWKEYS and, under strict kex, resets the
///   outgoing sequence number. Anything else is `Error::Inconsistent`.
/// - `NewKeys`: expects NEWKEYS (`Error::Kex` otherwise). Fires `sender` once
///   to signal that encryption is up, then switches the session to the
///   encrypted `WaitingServiceRequest` state.
/// - Any other state is restored unchanged. With no key exchange in progress
///   (`None`), the packet goes to `client_read_encrypted`.
async fn reply<H: Handler>(
mut session: Session,
handler: &mut Option<H>,
sender: &mut Option<tokio::sync::oneshot::Sender<()>>,
buf: &[u8],
) -> Result<Session, H::Error> {
match session.common.kex.take() {
Some(Kex::KexInit(mut kexinit)) => {
if kexinit.algo.is_some()
|| buf[0] == msg::KEXINIT
|| session.common.encrypted.is_none()
{
let nonstrict_received = kexinit.nonstrict_packets_received;
session.common.kex = Some(Kex::KexDhDone(kexinit.client_parse(
session.common.config.as_ref(),
&session.common.cipher,
buf,
&mut session.common.write_buffer,
)?));
// `strict` is checked after `client_parse`, presumably because parsing
// the server's KEXINIT is what determines it.
if session.common.write_buffer.strict && nonstrict_received {
return Err(Error::KexInit.into());
}
session.flush()?;
} else {
kexinit.nonstrict_packets_received = true;
session.common.kex = Some(Kex::KexInit(kexinit))
}
Ok(session)
}
Some(Kex::KexDhDone(mut kexdhdone)) => {
if kexdhdone.names.ignore_guessed {
kexdhdone.names.ignore_guessed = false;
session.common.kex = Some(Kex::KexDhDone(kexdhdone));
Ok(session)
} else if buf[0] == msg::KEX_ECDH_REPLY {
session.common.kex = Some(kexdhdone.server_key_check(false, handler, buf).await?);
session
.common
.cipher
.write(&[msg::NEWKEYS], &mut session.common.write_buffer);
session.flush()?;
// Strict key exchange: the write sequence number restarts after NEWKEYS.
if session.common.write_buffer.strict {
session.common.write_buffer.seqn = std::num::Wrapping(0);
}
Ok(session)
} else {
error!("Wrong packet received");
Err(Error::Inconsistent.into())
}
}
Some(Kex::NewKeys(newkeys)) => {
debug!("newkeys received");
if buf[0] != msg::NEWKEYS {
return Err(Error::Kex.into());
}
// Unblock `connect_stream`, which waits for the first key exchange.
if let Some(sender) = sender.take() {
sender.send(()).unwrap_or(());
}
session.common.encrypted(
EncryptedState::WaitingServiceRequest {
accepted: false,
sent: false,
},
newkeys,
);
Ok(session)
}
Some(kex) => {
session.common.kex = Some(kex);
Ok(session)
}
None => session.client_read_encrypted(handler, buf).await,
}
}
/// Client configuration, shared with the session task through an `Arc`.
#[derive(Debug)]
pub struct Config {
/// Identification string sent to the server and recorded in the key exchange.
pub client_id: String,
/// Passed to `enc.flush` in `Session::flush`, which decides when keys are re-exchanged.
pub limits: Limits,
/// Initial value of `Session::target_window_size`.
pub window_size: u32,
/// Maximum packet size; `connect_stream` logs an error (but proceeds) above 65535.
pub maximum_packet_size: u32,
/// Algorithm preferences; presumably consulted during negotiation (not read in this file).
pub preferred: negotiation::Preferred,
/// NOTE(review): not read anywhere in this file — confirm it is honoured elsewhere.
pub connection_timeout: Option<std::time::Duration>,
}
impl Default for Config {
    /// Default client settings:
    /// - client id `SSH-2.0-<crate name>_<crate version>`;
    /// - default limits;
    /// - a 2 MiB window and 32 KiB maximum packet size;
    /// - the default client algorithm preferences;
    /// - no connection timeout.
    fn default() -> Config {
        const DEFAULT_WINDOW_SIZE: u32 = 2 * 1024 * 1024;
        const DEFAULT_MAX_PACKET_SIZE: u32 = 32 * 1024;
        let client_id = concat!(
            "SSH-2.0-",
            env!("CARGO_PKG_NAME"),
            "_",
            env!("CARGO_PKG_VERSION")
        )
        .to_string();
        Config {
            client_id,
            limits: Limits::default(),
            window_size: DEFAULT_WINDOW_SIZE,
            maximum_packet_size: DEFAULT_MAX_PACKET_SIZE,
            preferred: Preferred::DEFAULT_CLIENT,
            connection_timeout: None,
        }
    }
}
pub trait Handler: Sized {
type Error: From<crate::Error> + Send;
type FutureBool: Future<Output = Result<(Self, bool), Self::Error>> + Send;
type FutureUnit: Future<Output = Result<(Self, Session), Self::Error>> + Send;
fn finished_bool(self, b: bool) -> Self::FutureBool;
fn finished(self, session: Session) -> Self::FutureUnit;
#[allow(unused_variables)]
fn auth_banner(self, banner: &str, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn check_server_key(self, server_public_key: &key::PublicKey) -> Self::FutureBool {
self.finished_bool(false)
}
#[allow(unused_variables)]
fn channel_open_confirmation(
self,
id: ChannelId,
max_packet_size: u32,
window_size: u32,
session: Session,
) -> Self::FutureUnit {
if let Some(channel) = session.channels.get(&id) {
channel
.send(OpenChannelMsg::Open {
id,
max_packet_size,
window_size,
})
.unwrap_or(());
} else {
error!("no channel for id {:?}", id);
}
self.finished(session)
}
#[allow(unused_variables)]
fn channel_success(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::Success))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn channel_close(self, channel: ChannelId, mut session: Session) -> Self::FutureUnit {
session.channels.remove(&channel);
self.finished(session)
}
#[allow(unused_variables)]
fn channel_eof(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::Eof))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn channel_open_failure(
self,
channel: ChannelId,
reason: ChannelOpenFailure,
description: &str,
language: &str,
mut session: Session,
) -> Self::FutureUnit {
session.channels.remove(&channel);
session.sender.send(Reply::ChannelOpenFailure).unwrap_or(());
self.finished(session)
}
#[allow(unused_variables)]
fn channel_open_forwarded_tcpip(
self,
channel: ChannelId,
connected_address: &str,
connected_port: u32,
originator_address: &str,
originator_port: u32,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn data(self, channel: ChannelId, data: &[u8], session: Session) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::Data {
data: CryptoVec::from_slice(data),
}))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn extended_data(
self,
channel: ChannelId,
ext: u32,
data: &[u8],
session: Session,
) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::ExtendedData {
ext,
data: CryptoVec::from_slice(data),
}))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn xon_xoff(
self,
channel: ChannelId,
client_can_do: bool,
session: Session,
) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::XonXoff { client_can_do }))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn exit_status(
self,
channel: ChannelId,
exit_status: u32,
session: Session,
) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::ExitStatus { exit_status }))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn exit_signal(
self,
channel: ChannelId,
signal_name: Sig,
core_dumped: bool,
error_message: &str,
lang_tag: &str,
session: Session,
) -> Self::FutureUnit {
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::ExitSignal {
signal_name,
core_dumped,
error_message: error_message.to_string(),
lang_tag: lang_tag.to_string(),
}))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn window_adjusted(
self,
channel: ChannelId,
mut new_size: u32,
mut session: Session,
) -> Self::FutureUnit {
if let Some(ref mut enc) = session.common.encrypted {
new_size -= enc.flush_pending(channel) as u32;
}
if let Some(chan) = session.channels.get(&channel) {
chan.send(OpenChannelMsg::Msg(ChannelMsg::WindowAdjusted { new_size }))
.unwrap_or(())
}
self.finished(session)
}
#[allow(unused_variables)]
fn adjust_window(&mut self, channel: ChannelId, window: u32) -> u32 {
window
}
}