use std;
use std::net::ToSocketAddrs;
use std::sync::Arc;
use futures::future::Future;
use thrussh_keys::key;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpListener;
use tokio::pin;
use crate::session::*;
use crate::ssh_read::*;
use crate::sshbuffer::*;
use crate::*;
mod kex;
mod session;
pub use self::session::*;
mod encrypted;
/// Server-wide settings, shared through an `Arc` by every accepted connection.
#[derive(Debug)]
pub struct Config {
    /// Identification string sent to the client when a connection starts
    /// (`run_stream`), and recorded as `server_id` in the key-exchange data.
    pub server_id: String,
    /// Authentication methods; the default enables all of them.
    pub methods: auth::MethodSet,
    /// Optional banner text (`None` by default). Presumably shown to clients
    /// during authentication — confirm against the session code.
    pub auth_banner: Option<&'static str>,
    /// One second by default. Presumably the delay applied to rejected
    /// authentication attempts — TODO confirm where it is consumed.
    pub auth_rejection_time: std::time::Duration,
    /// Server key pairs, empty by default (presumably host keys used during
    /// key exchange — confirm).
    pub keys: Vec<key::KeyPair>,
    /// `Limits::default()` by default; semantics are defined elsewhere.
    pub limits: Limits,
    /// Initial value of a session's `target_window_size`; 2 MiB by default.
    pub window_size: u32,
    /// 32768 by default; `run` logs an error if this exceeds 65535.
    pub maximum_packet_size: u32,
    /// Algorithm preferences; `Preferred::DEFAULT_SERVER` by default.
    pub preferred: Preferred,
    /// 10 by default. Presumably the cap on authentication attempts per
    /// connection — confirm.
    pub max_auth_attempts: usize,
    /// Bounds the wait for the client's identification string and, in
    /// `run_stream`, the idle time between connection events. `None`
    /// disables both; 600 seconds by default.
    pub connection_timeout: Option<std::time::Duration>,
}
impl Default for Config {
    /// Builds the default server configuration.
    ///
    /// The identification string is `SSH-2.0-<crate name>_<crate version>`.
    /// All authentication methods are enabled and no host keys are set.
    /// Sizes, timeouts and limits take the values written below.
    fn default() -> Self {
        Self {
            server_id: concat!(
                "SSH-2.0-",
                env!("CARGO_PKG_NAME"),
                "_",
                env!("CARGO_PKG_VERSION")
            )
            .to_string(),
            methods: auth::MethodSet::all(),
            auth_banner: None,
            auth_rejection_time: std::time::Duration::from_secs(1),
            keys: Vec::new(),
            // 2 MiB window and 32 KiB packets.
            window_size: 2 * 1024 * 1024,
            maximum_packet_size: 32 * 1024,
            limits: Limits::default(),
            preferred: Preferred::DEFAULT_SERVER,
            max_auth_attempts: 10,
            // Ten minutes.
            connection_timeout: Some(std::time::Duration::from_secs(10 * 60)),
        }
    }
}
/// Iterator over the strings of a client reply, handed to
/// `Handler::auth_keyboard_interactive`.
///
/// `pos` is the read cursor into the encoded message. `n` is the number of
/// strings still to be yielded.
#[derive(Debug)]
pub struct Response<'a> {
    pos: thrussh_keys::encoding::Position<'a>,
    n: u32,
}
impl<'a> Iterator for Response<'a> {
    type Item = &'a [u8];

    /// Yields the next string while the remaining count is non-zero.
    ///
    /// The count is decremented before the read, so a failed read still
    /// uses up one slot. A read error is reported as `None`.
    fn next(&mut self) -> Option<Self::Item> {
        let remaining = self.n.checked_sub(1)?;
        self.n = remaining;
        self.pos.read_string().ok()
    }
}
use std::borrow::Cow;
/// Outcome of an authentication hook on [`Handler`].
#[derive(Debug, PartialEq, Eq)]
pub enum Auth {
    /// The credentials are refused; every default `auth_*` hook returns this.
    Reject,
    /// The credentials are accepted.
    Accept,
    /// The attempted method is not supported by this handler.
    UnsupportedMethod,
    /// Authentication needs another round. Presumably this is used for
    /// keyboard-interactive prompting — confirm in the session code.
    Partial {
        /// Name associated with the prompt round.
        name: Cow<'static, str>,
        /// Instructions associated with the prompt round.
        instructions: Cow<'static, str>,
        /// Prompts, each paired with a flag (presumably "echo input" —
        /// confirm).
        prompts: Cow<'static, [(Cow<'static, str>, bool)]>,
    },
}
pub trait Handler: Sized {
type Error: From<crate::Error> + Send;
type FutureAuth: Future<Output = Result<(Self, Auth), Self::Error>> + Send;
type FutureUnit: Future<Output = Result<(Self, Session), Self::Error>> + Send;
type FutureBool: Future<Output = Result<(Self, Session, bool), Self::Error>> + Send;
fn finished_auth(self, auth: Auth) -> Self::FutureAuth;
fn finished_bool(self, b: bool, session: Session) -> Self::FutureBool;
fn finished(self, session: Session) -> Self::FutureUnit;
#[allow(unused_variables)]
fn auth_none(self, user: &str) -> Self::FutureAuth {
self.finished_auth(Auth::Reject)
}
#[allow(unused_variables)]
fn auth_password(self, user: &str, password: &str) -> Self::FutureAuth {
self.finished_auth(Auth::Reject)
}
#[allow(unused_variables)]
fn auth_publickey(self, user: &str, public_key: &key::PublicKey) -> Self::FutureAuth {
self.finished_auth(Auth::Reject)
}
#[allow(unused_variables)]
fn auth_keyboard_interactive(
self,
user: &str,
submethods: &str,
response: Option<Response>,
) -> Self::FutureAuth {
self.finished_auth(Auth::Reject)
}
#[allow(unused_variables)]
fn channel_close(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn channel_eof(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn channel_open_session(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn channel_open_x11(
self,
channel: ChannelId,
originator_address: &str,
originator_port: u32,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn channel_open_direct_tcpip(
self,
channel: ChannelId,
host_to_connect: &str,
port_to_connect: u32,
originator_address: &str,
originator_port: u32,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn data(self, channel: ChannelId, data: &[u8], session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn extended_data(
self,
channel: ChannelId,
code: u32,
data: &[u8],
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn window_adjusted(
self,
channel: ChannelId,
new_window_size: usize,
mut session: Session,
) -> Self::FutureUnit {
if let Some(ref mut enc) = session.common.encrypted {
enc.flush_pending(channel);
}
self.finished(session)
}
#[allow(unused_variables)]
fn adjust_window(&mut self, channel: ChannelId, current: u32) -> u32 {
current
}
#[allow(unused_variables)]
fn pty_request(
self,
channel: ChannelId,
term: &str,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
modes: &[(Pty, u32)],
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn x11_request(
self,
channel: ChannelId,
single_connection: bool,
x11_auth_protocol: &str,
x11_auth_cookie: &str,
x11_screen_number: u32,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn env_request(
self,
channel: ChannelId,
variable_name: &str,
variable_value: &str,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn shell_request(self, channel: ChannelId, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn exec_request(self, channel: ChannelId, data: &[u8], session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn subsystem_request(
self,
channel: ChannelId,
name: &str,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn window_change_request(
self,
channel: ChannelId,
col_width: u32,
row_height: u32,
pix_width: u32,
pix_height: u32,
session: Session,
) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn signal(self, channel: ChannelId, signal_name: Sig, session: Session) -> Self::FutureUnit {
self.finished(session)
}
#[allow(unused_variables)]
fn tcpip_forward(self, address: &str, port: u32, session: Session) -> Self::FutureBool {
self.finished_bool(false, session)
}
#[allow(unused_variables)]
fn cancel_tcpip_forward(self, address: &str, port: u32, session: Session) -> Self::FutureBool {
self.finished_bool(false, session)
}
}
/// Factory for per-connection handlers, used by [`run`].
pub trait Server {
    /// Handler type created for each accepted connection.
    type Handler: Handler + Send;

    /// Creates the handler for a newly accepted connection.
    ///
    /// `peer_addr` is the remote address, or `None` if it could not be
    /// determined (`run` passes `socket.peer_addr().ok()`).
    fn new(&mut self, peer_addr: Option<std::net::SocketAddr>) -> Self::Handler;
}
/// Listens on `addr` and serves each accepted TCP connection with `config`.
///
/// `addr` is resolved with [`std::net::ToSocketAddrs`], which is a blocking
/// lookup, and the first resolved address is bound. For each accepted
/// connection:
/// - a handler is obtained from `server.new`, with the peer address when it
///   is available;
/// - the connection is driven by [`run_stream`] on its own Tokio task;
/// - that task's result is not observed.
///
/// # Errors
///
/// Returns an error if `addr` cannot be resolved, resolves to no address, or
/// cannot be bound. The accept loop ends, returning `Ok(())`, the first time
/// `accept` fails.
pub async fn run<H: Server + Send + 'static>(
    config: Arc<Config>,
    addr: &str,
    mut server: H,
) -> Result<(), std::io::Error> {
    // Report bad or unresolvable addresses to the caller instead of panicking.
    let sock_addr = addr.to_socket_addrs()?.next().ok_or_else(|| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("{:?} could not be resolved to any address", addr),
        )
    })?;
    let listener = TcpListener::bind(&sock_addr).await?;
    if config.maximum_packet_size > 65535 {
        error!(
            "Maximum packet size ({:?}) should not larger than a TCP packet (65535)",
            config.maximum_packet_size
        );
    }
    while let Ok((socket, _)) = listener.accept().await {
        let config = config.clone();
        let handler = server.new(socket.peer_addr().ok());
        tokio::spawn(run_stream(config, socket, handler));
    }
    Ok(())
}
use std::cell::RefCell;
// Two per-thread `CryptoVec` buffers, each wrapped in a `RefCell`.
// They are not referenced in the code visible here. Presumably they are
// scratch buffers used by the child modules (`session`, `encrypted`,
// `kex`) — TODO confirm before changing or removing them.
thread_local! {
static B1: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
static B2: RefCell<CryptoVec> = RefCell::new(CryptoVec::new());
}
/// Completes once `delay` has elapsed, or never completes when `delay` is `None`.
///
/// `run_stream` uses this as a `select!` arm, so `None` disables the idle timeout.
pub async fn timeout(delay: Option<std::time::Duration>) {
    match delay {
        Some(duration) => tokio::time::sleep(duration).await,
        None => futures::future::pending::<()>().await,
    }
}
/// Reads the next packet from `stream_read` into `buffer`.
///
/// `buffer.buffer` is cleared first, then `cipher::read` fills it using
/// `cipher`. The reader and buffer are moved in and handed back with the
/// count returned by `cipher::read`. This lets `run_stream`, which keeps
/// this future pinned, restart the read without holding borrows across
/// `select!` iterations.
async fn start_reading<R: AsyncRead + Unpin>(
    mut stream_read: R,
    mut buffer: SSHBuffer,
    cipher: Arc<crate::cipher::CipherPair>,
) -> Result<(usize, R, SSHBuffer), Error> {
    buffer.buffer.clear();
    let count = cipher::read(&mut stream_read, &mut buffer, &cipher).await?;
    Ok((count, stream_read, buffer))
}
/// Drives a single SSH server connection over `stream` until it ends.
///
/// Sequence, as implemented below:
/// 1. Writes the identification line built from `config.server_id`.
/// 2. Reads the client identification via `read_ssh_id`. That call also
///    places `KexInit::server_write` output in the session write buffer,
///    which is then flushed to the peer.
/// 3. Loops until `session.common.disconnected` is set. Each iteration races:
///    - the next packet read;
///    - an idle timer of `config.connection_timeout`;
///    - messages queued on the session's internal channel.
///
///    Packets whose first payload byte is above 4 are passed to `reply`.
///    After every event the session write buffer is flushed to the peer.
/// 4. Shuts down the write half and keeps reading until a read returns 0 bytes.
///
/// Returns the handler once the connection has been drained.
///
/// # Errors
///
/// I/O failures (converted through `crate::Error`), read/decryption errors
/// from `start_reading`, and any error returned by `reply`.
///
/// # Panics
///
/// `handler.unwrap()` at the end panics if `handler` is `None` at that point.
/// Presumably `reply` always puts it back — TODO confirm in `session.rs`.
pub async fn run_stream<H: Handler, R>(
config: Arc<Config>,
mut stream: R,
handler: H,
) -> Result<H, H::Error>
where
R: AsyncRead + AsyncWrite + Unpin,
{
// Held in an Option because `reply` receives it as `&mut Option<H>`.
let mut handler = Some(handler);
let delay = config.connection_timeout;
// Decompression output buffer, reused across packets.
let mut decomp = CryptoVec::new();
// Send our identification line before reading the client's.
let mut write_buffer = SSHBuffer::new();
write_buffer.send_ssh_id(config.as_ref().server_id.as_bytes());
stream
.write_all(&write_buffer.buffer[..])
.await
.map_err(crate::Error::from)?;
let mut stream = SshRead::new(&mut stream);
let common = read_ssh_id(config, &mut stream).await?;
// Bounded queue (capacity 10) behind `session.sender`; drained by the
// `session.receiver.recv()` arm of the `select!` below.
let (sender, receiver) = tokio::sync::mpsc::channel(10);
let mut session = Session {
target_window_size: common.config.window_size,
common,
receiver,
sender: server::session::Handle { sender },
pending_reads: Vec::new(),
pending_len: 0,
};
// Send what `read_ssh_id` left in the write buffer.
session.flush()?;
stream
.write_all(&session.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
session.common.write_buffer.buffer.clear();
let (stream_read, mut stream_write) = stream.split();
let buffer = SSHBuffer::new();
// The in-flight read is pinned, so `select!` polls it by reference. It
// therefore survives iterations in which another arm wins. It is replaced
// via `set` after each packet has been handled.
let reading = start_reading(stream_read, buffer, session.common.cipher.clone());
pin!(reading);
// Set when the loop breaks right after a completed read, so the drain loop
// below can restart reading from the returned read half and buffer.
let mut is_reading = None;
while !session.common.disconnected {
tokio::select! {
r = &mut reading => {
let (stream_read, mut buffer) = match r {
Ok((_, stream_read, buffer)) => (stream_read, buffer),
Err(e) => return Err(e.into())
};
// Nothing beyond the 5-byte prefix: stop. Presumably this is EOF —
// TODO confirm what `cipher::read` leaves in the buffer at end of stream.
if buffer.buffer.len() < 5 {
is_reading = Some((stream_read, buffer));
break
}
// The payload starts at offset 5 (assumed: 4-byte packet length plus
// 1-byte padding length). Once encrypted, it is first decompressed
// into `decomp`. A decompression error ends the loop after a debug log.
let buf = if let Some(ref mut enc) = session.common.encrypted {
let d = enc.decompress.decompress(
&buffer.buffer[5..],
&mut decomp,
);
if let Ok(buf) = d {
buf
} else {
debug!("err = {:?}", d);
is_reading = Some((stream_read, buffer));
break
}
} else {
&buffer.buffer[5..]
};
// DISCONNECT ends the loop. Other message numbers of 4 and below are
// skipped without being passed to `reply`.
if !buf.is_empty() {
if buf[0] == crate::msg::DISCONNECT {
debug!("break");
is_reading = Some((stream_read, buffer));
break;
} else if buf[0] > 4 {
// Copy the strict flag from the write side onto the read buffer.
// In strict mode, restart the inbound sequence number at zero when
// NEWKEYS arrives (presumably the strict-kex countermeasure — confirm).
buffer.strict = session.common.write_buffer.strict;
debug!("buffer strict {:?} {:?}", buffer.strict, buf[0]);
if buffer.strict && buf[0] == crate::msg::NEWKEYS {
buffer.seqn = std::num::Wrapping(0u32);
}
match reply(session, &mut handler, &buf[..]).await {
Ok(s) => session = s,
Err(e) => return Err(e),
}
}
}
reading.set(start_reading(stream_read, buffer, session.common.cipher.clone()));
}
// A new timer is created each time the loop re-enters `select!`. It
// therefore fires only after `delay` passes with no other arm completing,
// and never fires when `delay` is `None`.
_ = timeout(delay) => {
debug!("timeout");
break
},
// Queued channel messages are not consumed while a rekey is in progress.
msg = session.receiver.recv(), if !session.is_rekeying() => {
match msg {
Some((id, ChannelMsg::Data { data })) => {
session.data(id, data);
}
Some((id, ChannelMsg::ExtendedData { ext, data })) => {
session.extended_data(id, ext, data);
}
Some((id, ChannelMsg::Eof)) => {
session.eof(id);
}
Some((id, ChannelMsg::Close)) => {
session.close(id);
}
Some((id, ChannelMsg::XonXoff { client_can_do })) => {
session.xon_xoff_request(id, client_can_do);
}
Some((id, ChannelMsg::ExitStatus { exit_status })) => {
session.exit_status_request(id, exit_status);
}
Some((id, ChannelMsg::ExitSignal { signal_name, core_dumped, error_message, lang_tag })) => {
session.exit_signal_request(id, signal_name, core_dumped, &error_message, &lang_tag);
}
Some((id, ChannelMsg::WindowAdjusted { new_size })) => {
debug!("window adjusted to {:?} for channel {:?}", new_size, id);
}
Some((id, ChannelMsg::Success)) => {
debug!("channel success {:?}", id);
}
None => {
debug!("session.receiver: received None");
}
}
}
}
// Send whatever the event handled above queued for the peer.
session.flush()?;
stream_write
.write_all(&session.common.write_buffer.buffer)
.await
.map_err(crate::Error::from)?;
session.common.write_buffer.buffer.clear();
}
debug!("disconnected");
stream_write.shutdown().await.map_err(crate::Error::from)?;
// Drain the read side until a read returns 0 bytes. When the main loop
// exited after a completed read, restart the read from `is_reading`.
// NOTE(review): no timeout applies here, so a peer that never closes its
// side keeps this task alive.
loop {
if let Some((stream_read, buffer)) = is_reading.take() {
reading.set(start_reading(
stream_read,
buffer,
session.common.cipher.clone(),
));
}
let (n, r, b) = (&mut reading).await?;
is_reading = Some((r, b));
if n == 0 {
break;
}
}
Ok(handler.unwrap())
}
/// Reads the client's identification line and sets up the initial session state.
///
/// When `config.connection_timeout` is set, the read is bounded by that
/// duration. The client and server identifications are recorded in a fresh
/// `Exchange`.
///
/// The returned `CommonSession` is in the `KexInit` state with a clear-text
/// cipher pair. Its write buffer already holds the output of
/// `KexInit::server_write`.
///
/// # Errors
///
/// Timeout or read failures while waiting for the identification, and any
/// error from `server_write`.
async fn read_ssh_id<R: AsyncRead + Unpin>(
    config: Arc<Config>,
    read: &mut SshRead<R>,
) -> Result<CommonSession<Arc<Config>>, Error> {
    let sshid = match config.connection_timeout {
        Some(limit) => tokio::time::timeout(limit, read.read_ssh_id()).await??,
        None => read.read_ssh_id().await?,
    };

    let mut exchange = Exchange::new();
    exchange.client_id.extend(sshid);
    exchange
        .server_id
        .extend(config.as_ref().server_id.as_bytes());

    let mut kexinit = KexInit {
        exchange,
        algo: None,
        sent: false,
        session_id: None,
        nonstrict_packets_received: false,
    };

    // No keys have been negotiated yet, so start with the clear-text pair.
    let cipher = Arc::new(cipher::CLEAR_PAIR);
    let mut write_buffer = SSHBuffer::new();
    kexinit.server_write(config.as_ref(), cipher.as_ref(), &mut write_buffer)?;

    Ok(CommonSession {
        write_buffer,
        kex: Some(Kex::KexInit(kexinit)),
        auth_user: String::new(),
        auth_method: None,
        cipher,
        encrypted: None,
        config,
        wants_reply: false,
        disconnected: false,
        buffer: CryptoVec::new(),
    })
}
async fn reply<H: Handler>(
mut session: Session,
handler: &mut Option<H>,
buf: &[u8],
) -> Result<Session, H::Error> {
debug!("buf {:?}", buf);
if session.common.encrypted.is_none() {
match session.common.kex.take() {
Some(Kex::KexInit(mut kexinit)) => {
if kexinit.algo.is_some() || buf[0] == msg::KEXINIT {
let nonstrict_received = kexinit.nonstrict_packets_received;
session.common.kex = Some(kexinit.server_parse(
session.common.config.as_ref(),
&session.common.cipher,
&buf,
&mut session.common.write_buffer,
)?);
if session.common.write_buffer.strict && nonstrict_received {
return Err(Error::KexInit.into())
}
} else {
kexinit.nonstrict_packets_received = true;
session.common.kex = Some(Kex::KexInit(kexinit))
}
}
Some(Kex::KexDh(kexdh)) => {
let client_supports_ext = kexdh.names.client_supports_ext;
session.common.kex = Some(kexdh.parse(
session.common.config.as_ref(),
&session.common.cipher,
buf,
&mut session.common.write_buffer,
)?);
if client_supports_ext {
session.send_server_sig_algs();
}
}
Some(Kex::NewKeys(newkeys)) => {
if buf[0] != msg::NEWKEYS {
return Err(Error::Kex.into());
}
session.common.encrypted(
EncryptedState::WaitingServiceRequest {
sent: false,
accepted: false,
},
newkeys,
);
}
Some(kex) => {
session.common.kex = Some(kex);
}
None => {}
}
Ok(session)
} else {
Ok(session.server_read_encrypted(handler, buf).await?)
}
}