use super::*;
use crate::msg;
use std::sync::Arc;
use thrussh_keys::encoding::Encoding;
use tokio::sync::mpsc::{Receiver, Sender};
/// State held for one SSH session.
///
/// The methods in this file mostly queue outgoing packets on the encrypted
/// state stored inside `common`, when such a state exists.
pub struct Session {
    /// Shared session state. The code in this file reads its `encrypted`,
    /// `config`, `cipher`, `write_buffer` and `wants_reply` fields.
    pub(crate) common: CommonSession<Arc<Config>>,
    /// Handle that [`Session::handle`] clones and returns to callers.
    pub(crate) sender: Handle,
    /// Receiving half of a channel carrying `(ChannelId, ChannelMsg)` pairs.
    /// NOTE(review): presumably the counterpart of `sender` -- confirm where
    /// the session is constructed.
    pub(crate) receiver: Receiver<(ChannelId, ChannelMsg)>,
    /// Target window size. It is not referenced by the code visible here;
    /// TODO confirm its exact meaning at the use site.
    pub(crate) target_window_size: u32,
}
/// Cloneable handle used to submit `(ChannelId, ChannelMsg)` pairs through
/// the wrapped sender.
#[derive(Clone)]
pub struct Handle {
    /// Sending half of the message channel used by every method on `Handle`.
    pub(crate) sender: Sender<(ChannelId, ChannelMsg)>,
}
impl Handle {
pub async fn data(&mut self, id: ChannelId, data: CryptoVec) -> Result<(), CryptoVec> {
self.sender
.send((id, ChannelMsg::Data { data }))
.await
.map_err(|e| match e.0 {
(_, ChannelMsg::Data { data }) => data,
_ => unreachable!(),
})
}
pub async fn extended_data(
&mut self,
id: ChannelId,
ext: u32,
data: CryptoVec,
) -> Result<(), CryptoVec> {
self.sender
.send((id, ChannelMsg::ExtendedData { ext, data }))
.await
.map_err(|e| match e.0 {
(_, ChannelMsg::ExtendedData { data, .. }) => data,
_ => unreachable!(),
})
}
pub async fn eof(&mut self, id: ChannelId) -> Result<(), ()> {
self.sender
.send((id, ChannelMsg::Eof))
.await
.map_err(|_| ())
}
pub async fn close(&mut self, id: ChannelId) -> Result<(), ()> {
self.sender
.send((id, ChannelMsg::Close))
.await
.map_err(|_| ())
}
pub async fn xon_xoff_request(&mut self, id: ChannelId, client_can_do: bool) -> Result<(), ()> {
self.sender
.send((id, ChannelMsg::XonXoff { client_can_do }))
.await
.map_err(|_| ())
}
pub async fn exit_status_request(&mut self, id: ChannelId, exit_status: u32) -> Result<(), ()> {
self.sender
.send((id, ChannelMsg::ExitStatus { exit_status }))
.await
.map_err(|_| ())
}
pub async fn exit_signal_request(
&mut self,
id: ChannelId,
signal_name: Sig,
core_dumped: bool,
error_message: String,
lang_tag: String,
) -> Result<(), ()> {
self.sender
.send((
id,
ChannelMsg::ExitSignal {
signal_name,
core_dumped,
error_message,
lang_tag,
},
))
.await
.map_err(|_| ())
}
}
impl Session {
pub(crate) fn is_rekeying(&self) -> bool {
if let Some(ref enc) = self.common.encrypted {
enc.rekey.is_some()
} else {
true
}
}
pub fn handle(&self) -> Handle {
self.sender.clone()
}
pub fn writable_packet_size(&self, channel: &ChannelId) -> u32 {
if let Some(ref enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
return channel
.sender_window_size
.min(channel.sender_maximum_packet_size);
}
}
0
}
pub fn window_size(&self, channel: &ChannelId) -> u32 {
if let Some(ref enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
return channel.sender_window_size;
}
}
0
}
pub fn max_packet_size(&self, channel: &ChannelId) -> u32 {
if let Some(ref enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
return channel.sender_maximum_packet_size;
}
}
0
}
pub fn flush(&mut self) -> Result<(), Error> {
if let Some(ref mut enc) = self.common.encrypted {
if enc.flush(
&self.common.config.as_ref().limits,
&self.common.cipher,
&mut self.common.write_buffer,
) {
if enc.rekey.is_none() {
debug!("starting rekeying");
if let Some(exchange) = enc.exchange.take() {
let mut kexinit = KexInit::initiate_rekey(exchange, &enc.session_id);
kexinit.server_write(
&self.common.config.as_ref(),
&mut self.common.cipher,
&mut self.common.write_buffer,
)?;
enc.rekey = Some(Kex::KexInit(kexinit))
}
}
}
}
Ok(())
}
pub fn flush_pending(&mut self, channel: ChannelId) -> usize {
if let Some(ref mut enc) = self.common.encrypted {
enc.flush_pending(channel)
} else {
0
}
}
pub fn sender_window_size(&self, channel: ChannelId) -> usize {
if let Some(ref enc) = self.common.encrypted {
enc.sender_window_size(channel)
} else {
0
}
}
pub fn has_pending_data(&self, channel: ChannelId) -> bool {
if let Some(ref enc) = self.common.encrypted {
enc.has_pending_data(channel)
} else {
false
}
}
pub fn config(&self) -> &Config {
&self.common.config
}
pub fn disconnect(&mut self, reason: Disconnect, description: &str, language_tag: &str) {
self.common.disconnect(reason, description, language_tag);
}
pub fn request_success(&mut self) {
if self.common.wants_reply {
if let Some(ref mut enc) = self.common.encrypted {
self.common.wants_reply = false;
push_packet!(enc.write, enc.write.push(msg::REQUEST_SUCCESS))
}
}
}
pub fn request_failure(&mut self) {
if let Some(ref mut enc) = self.common.encrypted {
self.common.wants_reply = false;
push_packet!(enc.write, enc.write.push(msg::REQUEST_FAILURE))
}
}
pub fn channel_success(&mut self, channel: ChannelId) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get_mut(&channel) {
assert!(channel.confirmed);
if channel.wants_reply {
channel.wants_reply = false;
debug!("channel_success {:?}", channel);
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_SUCCESS);
enc.write.push_u32_be(channel.recipient_channel);
})
}
}
}
}
pub fn channel_failure(&mut self, channel: ChannelId) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get_mut(&channel) {
assert!(channel.confirmed);
if channel.wants_reply {
channel.wants_reply = false;
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_FAILURE);
enc.write.push_u32_be(channel.recipient_channel);
})
}
}
}
}
pub fn channel_open_failure(
&mut self,
channel: ChannelId,
reason: ChannelOpenFailure,
description: &str,
language: &str,
) {
if let Some(ref mut enc) = self.common.encrypted {
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_OPEN_FAILURE);
enc.write.push_u32_be(channel.0);
enc.write.push_u32_be(reason as u32);
enc.write.extend_ssh_string(description.as_bytes());
enc.write.extend_ssh_string(language.as_bytes());
})
}
}
pub fn close(&mut self, channel: ChannelId) {
self.common.byte(channel, msg::CHANNEL_CLOSE);
}
pub fn eof(&mut self, channel: ChannelId) {
self.common.byte(channel, msg::CHANNEL_EOF);
}
pub fn data(&mut self, channel: ChannelId, data: CryptoVec) {
if let Some(ref mut enc) = self.common.encrypted {
enc.data(channel, data)
} else {
unreachable!()
}
}
pub fn extended_data(&mut self, channel: ChannelId, extended: u32, data: CryptoVec) {
if let Some(ref mut enc) = self.common.encrypted {
enc.extended_data(channel, extended, data)
} else {
unreachable!()
}
}
pub fn xon_xoff_request(&mut self, channel: ChannelId, client_can_do: bool) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
assert!(channel.confirmed);
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_REQUEST);
enc.write.push_u32_be(channel.recipient_channel);
enc.write.extend_ssh_string(b"xon-xoff");
enc.write.push(0);
enc.write.push(if client_can_do { 1 } else { 0 });
})
}
}
}
pub fn exit_status_request(&mut self, channel: ChannelId, exit_status: u32) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
assert!(channel.confirmed);
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_REQUEST);
enc.write.push_u32_be(channel.recipient_channel);
enc.write.extend_ssh_string(b"exit-status");
enc.write.push(0);
enc.write.push_u32_be(exit_status)
})
}
}
}
pub fn exit_signal_request(
&mut self,
channel: ChannelId,
signal: Sig,
core_dumped: bool,
error_message: &str,
language_tag: &str,
) {
if let Some(ref mut enc) = self.common.encrypted {
if let Some(channel) = enc.channels.get(&channel) {
assert!(channel.confirmed);
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_REQUEST);
enc.write.push_u32_be(channel.recipient_channel);
enc.write.extend_ssh_string(b"exit-signal");
enc.write.push(0);
enc.write.extend_ssh_string(signal.name().as_bytes());
enc.write.push(if core_dumped { 1 } else { 0 });
enc.write.extend_ssh_string(error_message.as_bytes());
enc.write.extend_ssh_string(language_tag.as_bytes());
})
}
}
}
pub fn channel_open_forwarded_tcpip(
&mut self,
connected_address: &str,
connected_port: u32,
originator_address: &str,
originator_port: u32,
) -> Result<ChannelId, Error> {
let result = if let Some(ref mut enc) = self.common.encrypted {
match enc.state {
EncryptedState::Authenticated => {
debug!("sending open request");
let sender_channel = enc.new_channel(
self.common.config.window_size,
self.common.config.maximum_packet_size,
);
push_packet!(enc.write, {
enc.write.push(msg::CHANNEL_OPEN);
enc.write.extend_ssh_string(b"forwarded-tcpip");
enc.write.push_u32_be(sender_channel.0);
enc.write
.push_u32_be(self.common.config.as_ref().window_size);
enc.write
.push_u32_be(self.common.config.as_ref().maximum_packet_size);
enc.write.extend_ssh_string(connected_address.as_bytes());
enc.write.push_u32_be(connected_port); enc.write.extend_ssh_string(originator_address.as_bytes());
enc.write.push_u32_be(originator_port); });
sender_channel
}
_ => return Err(Error::Inconsistent),
}
} else {
return Err(Error::Inconsistent);
};
Ok(result)
}
}