//! Matrix bot for kicking idle users.
mod config;

use std::{
    convert::TryFrom,
    env, fs,
    process::exit,
    time::{Duration, SystemTime},
};

use anyhow::Context;
use ruma::{
    api::client::r0::{
        alias::get_alias,
        filter::FilterDefinition,
        membership::{join_room_by_id, kick_user},
        sync::sync_events::{self, JoinedRoom, Rooms},
    },
    assign,
    events::{room::member::MembershipState, AnySyncEphemeralRoomEvent, EventType},
    presence::PresenceState,
    RoomId, UserId,
};
use ruma_client::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::Notify;
use tokio_stream::StreamExt as _;
use tracing::{debug, error, error_span, info, trace};
use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _};

use crate::config::Config;

#[tokio::main(flavor = "current_thread")]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::registry()
        .with(tracing_subscriber::fmt::layer().with_target(false))
        .with(
            tracing_subscriber::filter::EnvFilter::from_default_env()
                .add_directive(tracing_subscriber::filter::LevelFilter::INFO.into()),
        )
        .init();

    let config_path = match env::args_os().nth(1) {
        Some(x) => x,
        None => {
            eprintln!("Usage: {} <config path>", env::args().next().unwrap());
            exit(1)
        }
    };

    let config: Config = toml::from_slice(&fs::read(config_path).context("reading config")?)
        .context("parsing config")?;

    let client = Client::new(config.homeserver.clone(), None);
    let session = client
        .log_in(&config.username, &config.password, None, Some("idlebot"))
        .await?;
    let id = session.identification.as_ref().unwrap().user_id.clone();
    debug!(%id, "logged in");

    for room_alias in &config.rooms {
        let room_id = client
            .request(get_alias::Request::new(room_alias))
            .await?
            .room_id;
        client
            .request(join_room_by_id::Request::new(&room_id))
            .await?;
    }
    debug!("joined all rooms");

    let mut filter = FilterDefinition::ignore_all();
    let fields = [
        "type".into(),
        "sender".into(),
        "content".into(),
        "state_key".into(),
    ];
    filter.event_fields = Some(&fields);
    filter.room.rooms = None;
    filter.room.state.not_senders = std::slice::from_ref(&id);
    filter.room.timeline.not_senders = std::slice::from_ref(&id);
    filter.room.account_data.types = Some(&[]);
    let ephemeral_types = ["m.receipt".into()];
    filter.room.ephemeral.types = Some(&ephemeral_types);
    let mut initial_filter = filter.clone();
    initial_filter.room.timeline.limit = Some(0u32.into());
    let initial_tys = ["m.room.member".into()];
    initial_filter.room.state.types = Some(&initial_tys);
    let initial_filter = initial_filter.into();
    let filter = filter.into();

    let initial_sync = client
        .request(assign!(sync_events::Request::new(), {
            filter: Some(&initial_filter),
        }))
        .await?;

    // FIXME: Should run before connecting. https://github.com/Kerollmops/heed/issues/106
    let state = State::new(&config)?;

    state.update(&config, initial_sync.rooms, false);

    info!("synchronized");

    let mut sync_stream = Box::pin(client.sync(
        Some(&filter),
        initial_sync.next_batch,
        &PresenceState::Online,
        Some(Duration::from_secs(30)),
    ));

    let update_complete = Notify::new();

    let update = async {
        while let Some(res) = sync_stream.try_next().await? {
            state.update(&config, res.rooms, true);
            update_complete.notify_one();
        }

        Ok(())
    };

    let sleep = async {
        loop {
            let least_active = state.least_active();
            if let Some((time, room, user)) = least_active {
                let dt = (time + Duration::from_secs_f32(config.idle_days * 60.0 * 60.0 * 24.0))
                    .duration_since(SystemTime::now())
                    .unwrap_or(Duration::new(0, 0));
                debug!(%room, %user, time = ?dt, "waiting");
                tokio::select! {
                    _ = tokio::time::sleep(dt) => {
                        state.forget_user(&room, &user);
                        let mut kick = kick_user::Request::new(&room, &user);
                        kick.reason = config.reason.as_ref().map(|x| &x[..]);
                        if let Err(error) = client.request(kick).await {
                            error!(%room, %user, %error, "kick failed");
                            // Stick them back in the DB so we'll try again eventually
                            state.refresh_user(&room, &user);
                        } else {
                            info!(%room, %user, "kicked");
                        }
                    }
                    _ = update_complete.notified() => {}
                }
            } else {
                debug!("waiting for users");
                update_complete.notified().await;
            }
        }
    };

    loop {
        tokio::select! {
            result = update => { return result; }
            _ = sleep => unreachable!(),
        }
    }
}

/// Persistent record of when each tracked room member was last seen active.
///
/// Two heed (LMDB) databases are kept in step inside the same transactions, so both
/// "when was this member last active?" and "who has been idle longest?" are single lookups.
/// Both databases declare `Unit` as their key codec and are remapped to the real key type
/// at each use site.
pub struct State {
    /// Environment that owns both databases.
    env: heed::Env,
    /// `(room, user)` → last-activity time, in seconds since the Unix epoch.
    user_time: heed::Database<heed::types::Unit, heed::types::OwnedType<u64>>,
    /// `(timestamp, (room, user))` → `()`: an index of tracked members ordered by timestamp.
    time_user: heed::Database<heed::types::Unit, heed::types::Unit>,
}

impl State {
    /// Opens (creating if necessary) the state directory at `config.state_path` and the
    /// `user-time` / `time-user` databases inside it.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or the environment or either database
    /// cannot be opened.
    fn new(config: &Config) -> anyhow::Result<Self> {
        std::fs::create_dir_all(&config.state_path).context("accessing state")?;
        let env = heed::EnvOpenOptions::new()
            .max_dbs(20) // FIXME: https://github.com/Kerollmops/heed/issues/106
            .open(&config.state_path)
            .map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
        let user_time = env
            .create_database(Some("user-time"))
            .map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
        let time_user = env
            .create_database(Some("time-user"))
            .map_err(|e| anyhow::anyhow!("opening state: {}", e))?;
        Ok(Self {
            env,
            user_time,
            time_user,
        })
    }

    /// Applies the joined-room part of one sync response in a single write transaction.
    ///
    /// `state_fresh` decides whether senders of *state* events get their activity time
    /// reset (`true`) or are only added when not yet tracked (`false`). Timeline events
    /// and read receipts always reset it.
    ///
    /// # Panics
    /// Panics if the write transaction cannot be opened or committed, or on any database
    /// error while applying events.
    fn update(&self, config: &Config, rooms: Rooms, state_fresh: bool) {
        let mut txn = self.env.write_txn().unwrap();
        for (id, room) in rooms.join {
            self.update_room(&mut txn, config, id, room, state_fresh);
        }
        txn.commit().unwrap();
    }

    /// Returns the tracked `(last activity, room, user)` with the oldest timestamp, or
    /// `None` if no member is tracked.
    ///
    /// # Panics
    /// Panics on database errors or if a stored room/user id no longer parses.
    fn least_active(&self) -> Option<(SystemTime, RoomId, UserId)> {
        let txn = self.env.read_txn().unwrap();
        let time_user = self.time_user.remap_key_type::<TimeUserEntry>();
        // Timestamps are stored as `to_be()` so that the key's leading bytes order by time.
        // NOTE(review): this assumes the bincode codec writes integers little-endian; on a
        // big-endian host `to_be()` is a no-op and the ordering would break — confirm.
        let ((time, key), ()) = time_user.first(&txn).unwrap()?;
        let time = SystemTime::UNIX_EPOCH + Duration::from_secs(u64::from_be(time));
        Some((
            time,
            RoomId::try_from(key.room).unwrap(),
            UserId::try_from(key.user).unwrap(),
        ))
    }

    /// Feeds every state, timeline and ephemeral event of one joined room into the index,
    /// inside a tracing span naming the room.
    fn update_room(
        &self,
        txn: &mut heed::RwTxn<'_, '_>,
        config: &Config,
        id: RoomId,
        room: JoinedRoom,
        state_fresh: bool,
    ) {
        let span = error_span!("synchronizing", room = %id);
        let _guard = span.enter();
        for event in room.state.events {
            self.handle_event(txn, config, &id, event.json(), state_fresh);
        }
        for event in room.timeline.events {
            self.handle_event(txn, config, &id, event.json(), true);
        }
        for event in room.ephemeral.events {
            match event.deserialize() {
                Ok(event) => self.handle_ephemeral(txn, config, &id, event),
                Err(error) => error!(%error, "malformed ephemeral event"),
            }
        }
    }

    /// Processes one raw state or timeline event.
    ///
    /// The sender, unless blacklisted, is recorded as active; their stored time is only
    /// overwritten when `freshen` is set. An `m.room.member` event whose membership is
    /// anything other than `join` stops tracking its target (the state key). Events that
    /// fail to parse are logged and skipped.
    ///
    /// NOTE(review): a sender who has since left can be re-added if one of their older
    /// events is processed after their own leave event within the same batch.
    fn handle_event(
        &self,
        txn: &mut heed::RwTxn<'_, '_>,
        config: &Config,
        room: &RoomId,
        event: &serde_json::value::RawValue,
        freshen: bool,
    ) {
        let event: AnyEvent = match serde_json::from_str(event.get()) {
            Err(e) => {
                error!(raw = event.get(), error = %e, "error parsing event");
                return;
            }
            Ok(e) => e,
        };
        if !config.blacklist.contains(&event.sender) {
            self.refresh_user_inner(txn, room, &event.sender, freshen);
        }
        let not_joined = event.ty == EventType::RoomMember
            && event
                .content
                .and_then(|content| content.membership)
                .map_or(false, |membership| membership != MembershipState::Join);
        if !not_joined {
            return;
        }
        let key = match event.state_key {
            Some(key) => key,
            None => return,
        };
        match UserId::try_from(key) {
            Ok(user) => {
                if !config.blacklist.contains(&user) {
                    self.forget_user_inner(txn, room, &user);
                }
            }
            Err(e) => {
                error!(key, error = %e, "malformed state key for membership event");
            }
        }
    }

    /// Treats every user with a read receipt in `event` as active now.
    ///
    /// Non-receipt ephemeral events are ignored. Blacklisted users are skipped, matching
    /// `handle_event`; otherwise a receipt alone would enrol them for kicking.
    fn handle_ephemeral(
        &self,
        txn: &mut heed::RwTxn<'_, '_>,
        config: &Config,
        room: &RoomId,
        event: AnySyncEphemeralRoomEvent,
    ) {
        let content = match event {
            AnySyncEphemeralRoomEvent::Receipt(x) => x.content,
            _ => return,
        };
        for receipts in content.values() {
            let read = match receipts.read {
                Some(ref x) => x,
                None => continue,
            };
            for user in read.keys() {
                if !config.blacklist.contains(user) {
                    self.refresh_user_inner(txn, room, user, true);
                }
            }
        }
    }

    /// Marks `user` in `room` as active right now, in its own write transaction.
    ///
    /// # Panics
    /// Panics on any database error.
    fn refresh_user(&self, room: &RoomId, user: &UserId) {
        let mut txn = self.env.write_txn().unwrap();
        self.refresh_user_inner(&mut txn, room, user, true);
        txn.commit().unwrap();
    }

    /// Sets `user`'s last-activity time in `room` to the current time, updating both
    /// databases.
    ///
    /// When `freshen` is `false` and the user is already tracked, the stored time is kept.
    fn refresh_user_inner(
        &self,
        txn: &mut heed::RwTxn<'_, '_>,
        room: &RoomId,
        user: &UserId,
        freshen: bool,
    ) {
        let key = Key {
            room: room.as_str(),
            user: user.as_str(),
        };
        let now = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let user_time = self
            .user_time
            .remap_key_type::<heed::types::SerdeBincode<Key>>();
        let time_user = self.time_user.remap_key_type::<TimeUserEntry>();

        let prev_time = user_time.get(txn, &key).unwrap();
        if !freshen && prev_time.is_some() {
            return;
        }
        user_time.put(txn, &key, &now).unwrap();
        if let Some(prev) = prev_time {
            // Drop the stale index entry so each member appears in `time_user` at most once.
            time_user.delete(txn, &(prev.to_be(), key)).unwrap();
        }
        time_user.put(txn, &(now.to_be(), key), &()).unwrap();
        // The wall clock can step backwards; plain subtraction would then underflow.
        trace!(id = %user, age = now.saturating_sub(prev_time.unwrap_or(now)), "refreshed user");
    }

    /// Stops tracking `user` in `room`, in its own write transaction.
    ///
    /// # Panics
    /// Panics on any database error.
    fn forget_user(&self, room: &RoomId, user: &UserId) {
        let mut txn = self.env.write_txn().unwrap();
        self.forget_user_inner(&mut txn, room, user);
        txn.commit().unwrap();
    }

    /// Removes `user` in `room` from both databases; a no-op if they are not tracked.
    fn forget_user_inner(&self, txn: &mut heed::RwTxn<'_, '_>, room: &RoomId, user: &UserId) {
        let key = Key {
            room: room.as_str(),
            user: user.as_str(),
        };
        let user_time = self
            .user_time
            .remap_key_type::<heed::types::SerdeBincode<Key>>();
        let time_user = self.time_user.remap_key_type::<TimeUserEntry>();
        let time = match user_time.get(txn, &key).unwrap() {
            Some(time) => time,
            None => return,
        };
        user_time.delete(txn, &key).unwrap();
        time_user.delete(txn, &(time.to_be(), key)).unwrap();
        trace!(id = %user, "forgot user");
    }
}

/// Identifies one tracked member: a user within a specific room.
///
/// Borrows its strings, so a decoded key is only valid while the transaction it was read
/// from is alive.
#[derive(Clone, Copy, Debug, Serialize, Deserialize)]
struct Key<'a> {
    room: &'a str,
    user: &'a str,
}

/// Key codec of the time-ordered index: `(timestamp, member)`, bincode-encoded.
///
/// Callers store the timestamp byte-swapped with `to_be()` so the encoded bytes sort
/// chronologically (assuming the codec writes integers little-endian — NOTE(review): confirm).
type TimeUserEntry<'a> = heed::types::SerdeBincode<(u64, Key<'a>)>;

/// The subset of a room event the bot inspects; every other field is ignored by serde.
///
/// `state_key` borrows from the source JSON text.
#[derive(Deserialize)]
struct AnyEvent<'a> {
    /// Event type, e.g. `m.room.member`.
    #[serde(rename = "type")]
    ty: EventType,
    /// User who sent the event.
    sender: UserId,
    /// Event content; only the `membership` field is read.
    content: Option<Content>,
    /// State key; for membership events, the target user's id.
    state_key: Option<&'a str>,
}

/// The `content` fields relevant to membership tracking.
#[derive(PartialEq, Deserialize)]
struct Content {
    /// The membership state, present on `m.room.member` events.
    membership: Option<MembershipState>,
}