mod config;
use std::{
convert::TryFrom,
env, fs,
process::exit,
time::{Duration, SystemTime},
};
use anyhow::Context;
use ruma::{
api::client::r0::{
alias::get_alias,
filter::FilterDefinition,
membership::{join_room_by_id, kick_user},
sync::sync_events::{self, JoinedRoom, Rooms},
},
assign,
events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType},
presence::PresenceState,
RoomId, UserId,
};
use ruma_client::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::Notify;
use tokio_stream::StreamExt as _;
use tracing::{debug, error, error_span, info, trace};
use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _};
use crate::config::Config;
#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().with_target(false))
.with(
tracing_subscriber::filter::EnvFilter::from_default_env()
.add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()),
)
.init();
let config_path = match env::args_os().nth(1) {
Some(x) => x,
None => {
eprintln!("Usage: {} <config path>", env::args().next().unwrap());
exit(1)
}
};
let config: Config = toml::from_slice(&fs::read(config_path).context("reading config")?)
.context("parsing config")?;
let client = Client::new(config.homeserver.clone(), None);
let session = client
.log_in(&config.username, &config.password, None, Some("idlebot"))
.await?;
let id = session.identification.as_ref().unwrap().user_id.clone();
debug!(%id, "logged in");
for room_alias in &config.rooms {
let room_id = client
.request(get_alias::Request::new(room_alias))
.await?
.room_id;
client
.request(join_room_by_id::Request::new(&room_id))
.await?;
}
debug!("joined all rooms");
let mut filter = FilterDefinition::ignore_all();
let fields = [
"type".into(),
"sender".into(),
"content".into(),
"state_key".into(),
];
filter.event_fields = Some(&fields);
filter.room.rooms = None;
filter.room.state.not_senders = std::slice::from_ref(&id);
filter.room.timeline.not_senders = std::slice::from_ref(&id);
filter.room.account_data.types = Some(&[]);
let ephemeral_types = ["m.receipt".into()];
filter.room.ephemeral.types = Some(&ephemeral_types);
let mut initial_filter = filter.clone();
initial_filter.room.timeline.limit = Some(0u32.into());
let initial_tys = ["m.room.member".into()];
initial_filter.room.state.types = Some(&initial_tys);
let initial_filter = initial_filter.into();
let filter = filter.into();
let initial_sync = client
.request(assign!(sync_events::Request::new(), {
filter: Some(&initial_filter),
}))
.await?;
let state = State::new(&config)?;
state.update(&config, initial_sync.rooms, false);
info!("synchronized");
let mut sync_stream = Box::pin(client.sync(
Some(&filter),
initial_sync.next_batch,
&PresenceState::Online,
Some(Duration::from_secs(30)),
));
let update_complete = Notify::new();
let update = async {
while let Some(res) = sync_stream.try_next().await? {
state.update(&config, res.rooms, true);
update_complete.notify_one();
}
Ok(())
};
let sleep = async {
loop {
let least_active = state.least_active();
if let Some((time, room, user)) = least_active {
let dt = (time + Duration::from_secs_f32(config.idle_days * 60.0 * 60.0 * 24.0))
.duration_since(SystemTime::now())
.unwrap_or(Duration::new(0, 0));
debug!(%room, %user, time = ?dt, "waiting");
tokio::select! {
_ = tokio::time::sleep(dt) => {
state.forget_user(&room, &user);
let mut kick = kick_user::Request::new(&room, &user);
kick.reason = config.reason.as_ref().map(|x| &x[..]);
if let Err(error) = client.request(kick).await {
error!(%room, %user, %error, "kick failed");
state.refresh_user(&room, &user);
} else {
info!(%room, %user, "kicked");
}
}
_ = update_complete.notified() => {}
}
} else {
debug!("waiting for users");
update_complete.notified().await;
}
}
};
loop {
tokio::select! {
result = update => { return result; }
_ = sleep => unreachable!(),
}
}
}
pub struct State {
env: heed::Env,
user_time: heed::Database<heed::types::Unit, heed::types::OwnedType<u64>>,
time_user: heed::Database<heed::types::Unit, heed::types::Unit>,
}
impl State {
fn new(config: &Config) -> anyhow::Result<Self> {
std::fs::create_dir_all(&config.state_path).context("accessing state")?;
let env = heed::EnvOpenOptions::new()
.max_dbs(20) .open(&config.state_path)
.map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
let user_time = env
.create_database(Some("user-time"))
.map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
let time_user = env
.create_database(Some("time-user"))
.map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
Ok(Self {
env,
user_time,
time_user,
})
}
fn update(&self, config: &Config, rooms: Rooms, state_fresh: bool) {
let mut txn = self.env.write_txn().unwrap();
for (id, room) in rooms.join {
self.update_room(&mut txn, &config, id, room, state_fresh);
}
txn.commit().unwrap();
}
fn least_active(&self) -> Option<(SystemTime, RoomId, UserId)> {
let txn = self.env.read_txn().unwrap();
let time_user = self.time_user.remap_key_type::<TimeUserEntry>();
let ((time, key), ()) = time_user.first(&txn).unwrap()?;
let time = SystemTime::UNIX_EPOCH + Duration::from_secs(u64::from_be(time));
Some((
time,
RoomId::try_from(key.room).unwrap(),
UserId::try_from(key.user).unwrap(),
))
}
fn update_room(
&self,
txn: &mut heed::RwTxn<'_, '_>,
config: &Config,
id: RoomId,
room: JoinedRoom,
state_fresh: bool,
) {
let span = error_span!("synchronizing", room = %id);
let _guard = span.enter();
for event in room.state.events {
self.handle_event(txn, config, &id, event.json(), state_fresh);
}
for event in room.timeline.events {
self.handle_event(txn, config, &id, event.json(), true);
}
for event in room.ephemeral.events {
match event.deserialize() {
Ok(event) => self.handle_ephemeral(txn, &id, event),
Err(error) => error!(%error, "malformed ephemeral event"),
}
}
}
fn handle_event(
&self,
txn: &mut heed::RwTxn<'_, '_>,
config: &Config,
room: &RoomId,
event: &serde_json::value::RawValue,
freshen: bool,
) {
let event: AnyEvent = match serde_json::from_str(event.get()) {
Err(e) => {
error!(raw = event.get(), error = %e, "error parsing event");
return;
}
Ok(e) => e,
};
if !config.blacklist.contains(&event.sender) {
self.refresh_user_inner(txn, room, &event.sender, freshen);
}
if event.ty == EventType::RoomMember
&& event.content.map_or(false, |content| {
content
.membership
.map_or(false, |membership| membership != MembershipState::Join)
})
{
if let Some(key) = event.state_key {
match UserId::try_from(key) {
Ok(user) => {
if !config.blacklist.contains(&user) {
self.forget_user_inner(txn, room, &user);
}
}
Err(e) => {
error!(key, error = %e, "malformed state key for membership event");
}
}
}
}
}
fn handle_ephemeral(
&self,
txn: &mut heed::RwTxn<'_, '_>,
room: &RoomId,
event: AnySyncEphemeralRoomEvent,
) {
let content = match event {
AnySyncEphemeralRoomEvent::Receipt(x) => x.content,
_ => return,
};
for receipts in content.values() {
let read = match receipts.read {
Some(ref x) => x,
None => continue,
};
for user in read.keys() {
self.refresh_user_inner(txn, room, user, true);
}
}
}
fn refresh_user(&self, room: &RoomId, user: &UserId) {
let mut txn = self.env.write_txn().unwrap();
self.refresh_user_inner(&mut txn, room, user, true);
txn.commit().unwrap();
}
fn refresh_user_inner(
&self,
txn: &mut heed::RwTxn<'_, '_>,
room: &RoomId,
user: &UserId,
freshen: bool,
) {
let key = Key {
room: room.as_str(),
user: user.as_str(),
};
let time = SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap()
.as_secs();
let user_time = self
.user_time
.remap_key_type::<heed::types::SerdeBincode<Key>>();
let time_user = self.time_user.remap_key_type::<TimeUserEntry>();
let prev_time = user_time.get(txn, &key).unwrap();
if prev_time.is_some() && !freshen {
return;
}
user_time.put(txn, &key, &time).unwrap();
if let Some(prev_time) = prev_time {
time_user.delete(txn, &(prev_time.to_be(), key)).unwrap();
}
time_user.put(txn, &(time.to_be(), key), &()).unwrap();
trace!(id = %user, age = time - prev_time.unwrap_or(time), "refreshed user");
}
fn forget_user(&self, room: &RoomId, user: &UserId) {
let mut txn = self.env.write_txn().unwrap();
self.forget_user_inner(&mut txn, room, user);
txn.commit().unwrap();
}
fn forget_user_inner(&self, txn: &mut heed::RwTxn<'_, '_>, room: &RoomId, user: &UserId) {
let key = Key {
room: room.as_str(),
user: user.as_str(),
};
let user_time = self
.user_time
.remap_key_type::<heed::types::SerdeBincode<Key>>();
let time_user = self.time_user.remap_key_type::<TimeUserEntry>();
let time = user_time.get(txn, &key).unwrap();
let known = user_time.delete(txn, &key).unwrap();
if let Some(time) = time {
time_user.delete(txn, &(time.to_be(), key)).unwrap();
}
if known {
trace!(id = %user, "forgot user");
}
}
}
/// Names one tracked member: a user within a particular room.
///
/// Keys are stored with bincode, which writes struct fields in declaration
/// order, so `room`-then-`user` is part of the on-disk format and must stay
/// as is.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
struct Key<'a> {
    room: &'a str,
    user: &'a str,
}

/// Key codec of the `time-user` index: an activity timestamp followed by the
/// member it belongs to.
///
/// Callers store the timestamp as `secs.to_be()`. Presumably this makes the
/// bincode bytes big-endian, so LMDB's bytewise key order is chronological
/// (NOTE(review): relies on bincode's default little-endian fixed-width
/// integers — confirm against the heed/bincode versions in use).
type TimeUserEntry<'a> = heed::types::SerdeBincode<(u64, Key<'a>)>;
/// The handful of room-event fields the activity tracker inspects; serde
/// ignores every other field of the event JSON.
#[derive(Deserialize)]
struct AnyEvent<'a> {
    /// The event's JSON `type`.
    #[serde(rename = "type")]
    ty: EventType,
    /// The user who sent the event.
    sender: UserId,
    /// The event's `state_key`, if it has one; membership handling parses it
    /// as the affected user's ID.
    ///
    /// NOTE(review): a borrowed `&str` cannot be deserialized by serde_json
    /// when the JSON string contains escape sequences, which makes the whole
    /// event fail to parse. `#[serde(borrow)] Cow<'a, str>` would avoid that,
    /// but its users would need updating at the same time.
    state_key: Option<&'a str>,
    /// The event content, of which only `membership` is read.
    content: Option<Content>,
}
/// The only part of an event's `content` that the tracker looks at; all other
/// content fields are ignored during deserialization.
#[derive(PartialEq, Deserialize)]
struct Content {
    /// The `membership` value of an `m.room.member` event; `None` when absent.
    membership: Option<MembershipState>,
}